1 //===- FuncToSPIRV.cpp - Func to SPIR-V Patterns ------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements patterns to convert Func dialect to SPIR-V dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
14 #include "../SPIRVCommon/Pattern.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "llvm/Support/Debug.h"
23 #define DEBUG_TYPE "func-to-spirv-pattern"
27 //===----------------------------------------------------------------------===//
28 // Operation conversion
29 //===----------------------------------------------------------------------===//
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
37 /// Converts func.return to spirv.Return.
38 class ReturnOpPattern final
: public OpConversionPattern
<func::ReturnOp
> {
40 using OpConversionPattern
<func::ReturnOp
>::OpConversionPattern
;
43 matchAndRewrite(func::ReturnOp returnOp
, OpAdaptor adaptor
,
44 ConversionPatternRewriter
&rewriter
) const override
{
45 if (returnOp
.getNumOperands() > 1)
48 if (returnOp
.getNumOperands() == 1) {
49 rewriter
.replaceOpWithNewOp
<spirv::ReturnValueOp
>(
50 returnOp
, adaptor
.getOperands()[0]);
52 rewriter
.replaceOpWithNewOp
<spirv::ReturnOp
>(returnOp
);
58 /// Converts func.call to spirv.FunctionCall.
59 class CallOpPattern final
: public OpConversionPattern
<func::CallOp
> {
61 using OpConversionPattern
<func::CallOp
>::OpConversionPattern
;
64 matchAndRewrite(func::CallOp callOp
, OpAdaptor adaptor
,
65 ConversionPatternRewriter
&rewriter
) const override
{
66 // multiple results func was not converted to spirv.func
67 if (callOp
.getNumResults() > 1)
69 if (callOp
.getNumResults() == 1) {
71 getTypeConverter()->convertType(callOp
.getResult(0).getType());
74 rewriter
.replaceOpWithNewOp
<spirv::FunctionCallOp
>(
75 callOp
, resultType
, adaptor
.getOperands(), callOp
->getAttrs());
77 rewriter
.replaceOpWithNewOp
<spirv::FunctionCallOp
>(
78 callOp
, TypeRange(), adaptor
.getOperands(), callOp
->getAttrs());
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
90 void mlir::populateFuncToSPIRVPatterns(const SPIRVTypeConverter
&typeConverter
,
91 RewritePatternSet
&patterns
) {
92 MLIRContext
*context
= patterns
.getContext();
94 patterns
.add
<ReturnOpPattern
, CallOpPattern
>(typeConverter
, context
);