//===- ControlFlowToSPIRV.cpp - ControlFlow to SPIR-V Patterns ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the ControlFlow dialect to the
// SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
14 #include "../SPIRVCommon/Pattern.h"
15 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/FormatVariadic.h"
26 #define DEBUG_TYPE "cf-to-spirv-pattern"
30 /// Legailze target block arguments.
31 static LogicalResult
legalizeBlockArguments(Block
&block
, Operation
*op
,
32 PatternRewriter
&rewriter
,
33 const TypeConverter
&converter
) {
34 auto builder
= OpBuilder::atBlockBegin(&block
);
35 for (unsigned i
= 0; i
< block
.getNumArguments(); ++i
) {
36 BlockArgument arg
= block
.getArgument(i
);
37 if (converter
.isLegal(arg
.getType()))
39 Type ty
= arg
.getType();
40 Type newTy
= converter
.convertType(ty
);
42 return rewriter
.notifyMatchFailure(
43 op
, llvm::formatv("failed to legalize type for argument {0})", arg
));
45 unsigned argNum
= arg
.getArgNumber();
46 Location loc
= arg
.getLoc();
47 Value newArg
= block
.insertArgument(argNum
, newTy
, loc
);
48 Value convertedValue
= converter
.materializeSourceConversion(
49 builder
, op
->getLoc(), ty
, newArg
);
50 if (!convertedValue
) {
51 return rewriter
.notifyMatchFailure(
52 op
, llvm::formatv("failed to cast new argument {0} to type {1})",
55 arg
.replaceAllUsesWith(convertedValue
);
56 block
.eraseArgument(argNum
+ 1);
61 //===----------------------------------------------------------------------===//
62 // Operation conversion
63 //===----------------------------------------------------------------------===//
66 /// Converts cf.br to spirv.Branch.
67 struct BranchOpPattern final
: OpConversionPattern
<cf::BranchOp
> {
68 using OpConversionPattern::OpConversionPattern
;
71 matchAndRewrite(cf::BranchOp op
, OpAdaptor adaptor
,
72 ConversionPatternRewriter
&rewriter
) const override
{
73 if (failed(legalizeBlockArguments(*op
.getDest(), op
, rewriter
,
74 *getTypeConverter())))
77 rewriter
.replaceOpWithNewOp
<spirv::BranchOp
>(op
, op
.getDest(),
78 adaptor
.getDestOperands());
83 /// Converts cf.cond_br to spirv.BranchConditional.
84 struct CondBranchOpPattern final
: OpConversionPattern
<cf::CondBranchOp
> {
85 using OpConversionPattern::OpConversionPattern
;
88 matchAndRewrite(cf::CondBranchOp op
, OpAdaptor adaptor
,
89 ConversionPatternRewriter
&rewriter
) const override
{
90 if (failed(legalizeBlockArguments(*op
.getTrueDest(), op
, rewriter
,
91 *getTypeConverter())))
94 if (failed(legalizeBlockArguments(*op
.getFalseDest(), op
, rewriter
,
95 *getTypeConverter())))
98 rewriter
.replaceOpWithNewOp
<spirv::BranchConditionalOp
>(
99 op
, adaptor
.getCondition(), op
.getTrueDest(),
100 adaptor
.getTrueDestOperands(), op
.getFalseDest(),
101 adaptor
.getFalseDestOperands());
107 //===----------------------------------------------------------------------===//
108 // Pattern population
109 //===----------------------------------------------------------------------===//
111 void mlir::cf::populateControlFlowToSPIRVPatterns(
112 const SPIRVTypeConverter
&typeConverter
, RewritePatternSet
&patterns
) {
113 MLIRContext
*context
= patterns
.getContext();
115 patterns
.add
<BranchOpPattern
, CondBranchOpPattern
>(typeConverter
, context
);