1 //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a pass to convert scf.for, scf.if and
// scf.index_switch ops into emitc ops.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/EmitC/IR/EmitC.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/Passes.h"
27 #define GEN_PASS_DEF_SCFTOEMITC
28 #include "mlir/Conversion/Passes.h.inc"
32 using namespace mlir::scf
;
36 struct SCFToEmitCPass
: public impl::SCFToEmitCBase
<SCFToEmitCPass
> {
37 void runOnOperation() override
;
40 // Lower scf::for to emitc::for, implementing result values using
41 // emitc::variable's updated within the loop body.
42 struct ForLowering
: public OpRewritePattern
<ForOp
> {
43 using OpRewritePattern
<ForOp
>::OpRewritePattern
;
45 LogicalResult
matchAndRewrite(ForOp forOp
,
46 PatternRewriter
&rewriter
) const override
;
49 // Create an uninitialized emitc::variable op for each result of the given op.
51 static SmallVector
<Value
> createVariablesForResults(T op
,
52 PatternRewriter
&rewriter
) {
53 SmallVector
<Value
> resultVariables
;
55 if (!op
.getNumResults())
56 return resultVariables
;
58 Location loc
= op
->getLoc();
59 MLIRContext
*context
= op
.getContext();
61 OpBuilder::InsertionGuard
guard(rewriter
);
62 rewriter
.setInsertionPoint(op
);
64 for (OpResult result
: op
.getResults()) {
65 Type resultType
= result
.getType();
66 Type varType
= emitc::LValueType::get(resultType
);
67 emitc::OpaqueAttr noInit
= emitc::OpaqueAttr::get(context
, "");
68 emitc::VariableOp var
=
69 rewriter
.create
<emitc::VariableOp
>(loc
, varType
, noInit
);
70 resultVariables
.push_back(var
);
73 return resultVariables
;
76 // Create a series of assign ops assigning given values to given variables at
77 // the current insertion point of given rewriter.
78 static void assignValues(ValueRange values
, SmallVector
<Value
> &variables
,
79 PatternRewriter
&rewriter
, Location loc
) {
80 for (auto [value
, var
] : llvm::zip(values
, variables
))
81 rewriter
.create
<emitc::AssignOp
>(loc
, var
, value
);
84 SmallVector
<Value
> loadValues(const SmallVector
<Value
> &variables
,
85 PatternRewriter
&rewriter
, Location loc
) {
86 return llvm::map_to_vector
<>(variables
, [&](Value var
) {
87 Type type
= cast
<emitc::LValueType
>(var
.getType()).getValueType();
88 return rewriter
.create
<emitc::LoadOp
>(loc
, type
, var
).getResult();
92 static void lowerYield(SmallVector
<Value
> &resultVariables
,
93 PatternRewriter
&rewriter
, scf::YieldOp yield
) {
94 Location loc
= yield
.getLoc();
95 ValueRange operands
= yield
.getOperands();
97 OpBuilder::InsertionGuard
guard(rewriter
);
98 rewriter
.setInsertionPoint(yield
);
100 assignValues(operands
, resultVariables
, rewriter
, loc
);
102 rewriter
.create
<emitc::YieldOp
>(loc
);
103 rewriter
.eraseOp(yield
);
106 // Lower the contents of an scf::if/scf::index_switch regions to an
107 // emitc::if/emitc::switch region. The contents of the lowering region is
108 // moved into the respective lowered region, but the scf::yield is replaced not
109 // only with an emitc::yield, but also with a sequence of emitc::assign ops that
110 // set the yielded values into the result variables.
111 static void lowerRegion(SmallVector
<Value
> &resultVariables
,
112 PatternRewriter
&rewriter
, Region
®ion
,
113 Region
&loweredRegion
) {
114 rewriter
.inlineRegionBefore(region
, loweredRegion
, loweredRegion
.end());
115 Operation
*terminator
= loweredRegion
.back().getTerminator();
116 lowerYield(resultVariables
, rewriter
, cast
<scf::YieldOp
>(terminator
));
119 LogicalResult
ForLowering::matchAndRewrite(ForOp forOp
,
120 PatternRewriter
&rewriter
) const {
121 Location loc
= forOp
.getLoc();
123 // Create an emitc::variable op for each result. These variables will be
124 // assigned to by emitc::assign ops within the loop body.
125 SmallVector
<Value
> resultVariables
=
126 createVariablesForResults(forOp
, rewriter
);
128 assignValues(forOp
.getInits(), resultVariables
, rewriter
, loc
);
130 emitc::ForOp loweredFor
= rewriter
.create
<emitc::ForOp
>(
131 loc
, forOp
.getLowerBound(), forOp
.getUpperBound(), forOp
.getStep());
133 Block
*loweredBody
= loweredFor
.getBody();
135 // Erase the auto-generated terminator for the lowered for op.
136 rewriter
.eraseOp(loweredBody
->getTerminator());
138 IRRewriter::InsertPoint ip
= rewriter
.saveInsertionPoint();
139 rewriter
.setInsertionPointToEnd(loweredBody
);
141 SmallVector
<Value
> iterArgsValues
=
142 loadValues(resultVariables
, rewriter
, loc
);
144 rewriter
.restoreInsertionPoint(ip
);
146 SmallVector
<Value
> replacingValues
;
147 replacingValues
.push_back(loweredFor
.getInductionVar());
148 replacingValues
.append(iterArgsValues
.begin(), iterArgsValues
.end());
150 rewriter
.mergeBlocks(forOp
.getBody(), loweredBody
, replacingValues
);
151 lowerYield(resultVariables
, rewriter
,
152 cast
<scf::YieldOp
>(loweredBody
->getTerminator()));
154 // Load variables into SSA values after the for loop.
155 SmallVector
<Value
> resultValues
= loadValues(resultVariables
, rewriter
, loc
);
157 rewriter
.replaceOp(forOp
, resultValues
);
161 // Lower scf::if to emitc::if, implementing result values as emitc::variable's
162 // updated within the then and else regions.
163 struct IfLowering
: public OpRewritePattern
<IfOp
> {
164 using OpRewritePattern
<IfOp
>::OpRewritePattern
;
166 LogicalResult
matchAndRewrite(IfOp ifOp
,
167 PatternRewriter
&rewriter
) const override
;
172 LogicalResult
IfLowering::matchAndRewrite(IfOp ifOp
,
173 PatternRewriter
&rewriter
) const {
174 Location loc
= ifOp
.getLoc();
176 // Create an emitc::variable op for each result. These variables will be
177 // assigned to by emitc::assign ops within the then & else regions.
178 SmallVector
<Value
> resultVariables
=
179 createVariablesForResults(ifOp
, rewriter
);
181 Region
&thenRegion
= ifOp
.getThenRegion();
182 Region
&elseRegion
= ifOp
.getElseRegion();
184 bool hasElseBlock
= !elseRegion
.empty();
187 rewriter
.create
<emitc::IfOp
>(loc
, ifOp
.getCondition(), false, false);
189 Region
&loweredThenRegion
= loweredIf
.getThenRegion();
190 lowerRegion(resultVariables
, rewriter
, thenRegion
, loweredThenRegion
);
193 Region
&loweredElseRegion
= loweredIf
.getElseRegion();
194 lowerRegion(resultVariables
, rewriter
, elseRegion
, loweredElseRegion
);
197 rewriter
.setInsertionPointAfter(ifOp
);
198 SmallVector
<Value
> results
= loadValues(resultVariables
, rewriter
, loc
);
200 rewriter
.replaceOp(ifOp
, results
);
204 // Lower scf::index_switch to emitc::switch, implementing result values as
205 // emitc::variable's updated within the case and default regions.
206 struct IndexSwitchOpLowering
: public OpRewritePattern
<IndexSwitchOp
> {
207 using OpRewritePattern
<IndexSwitchOp
>::OpRewritePattern
;
209 LogicalResult
matchAndRewrite(IndexSwitchOp indexSwitchOp
,
210 PatternRewriter
&rewriter
) const override
;
214 IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp
,
215 PatternRewriter
&rewriter
) const {
216 Location loc
= indexSwitchOp
.getLoc();
218 // Create an emitc::variable op for each result. These variables will be
219 // assigned to by emitc::assign ops within the case and default regions.
220 SmallVector
<Value
> resultVariables
=
221 createVariablesForResults(indexSwitchOp
, rewriter
);
223 auto loweredSwitch
= rewriter
.create
<emitc::SwitchOp
>(
224 loc
, indexSwitchOp
.getArg(), indexSwitchOp
.getCases(),
225 indexSwitchOp
.getNumCases());
227 // Lowering all case regions.
228 for (auto pair
: llvm::zip(indexSwitchOp
.getCaseRegions(),
229 loweredSwitch
.getCaseRegions())) {
230 lowerRegion(resultVariables
, rewriter
, std::get
<0>(pair
),
234 // Lowering default region.
235 lowerRegion(resultVariables
, rewriter
, indexSwitchOp
.getDefaultRegion(),
236 loweredSwitch
.getDefaultRegion());
238 rewriter
.setInsertionPointAfter(indexSwitchOp
);
239 SmallVector
<Value
> results
= loadValues(resultVariables
, rewriter
, loc
);
241 rewriter
.replaceOp(indexSwitchOp
, results
);
245 void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet
&patterns
) {
246 patterns
.add
<ForLowering
>(patterns
.getContext());
247 patterns
.add
<IfLowering
>(patterns
.getContext());
248 patterns
.add
<IndexSwitchOpLowering
>(patterns
.getContext());
251 void SCFToEmitCPass::runOnOperation() {
252 RewritePatternSet
patterns(&getContext());
253 populateSCFToEmitCConversionPatterns(patterns
);
255 // Configure conversion to lower out SCF operations.
256 ConversionTarget
target(getContext());
257 target
.addIllegalOp
<scf::ForOp
, scf::IfOp
, scf::IndexSwitchOp
>();
258 target
.markUnknownOpDynamicallyLegal([](Operation
*) { return true; });
260 applyPartialConversion(getOperation(), target
, std::move(patterns
))))