Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / lib / Conversion / SCFToEmitC / SCFToEmitC.cpp
blob67a43c43d608b4ae7ae5fa91d800b0e95fc5218f
1 //===- SCFToEmitC.cpp - SCF to EmitC conversion ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert scf.for, scf.if and scf.index_switch
// ops into emitc ops.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/EmitC/IR/EmitC.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/Passes.h"
26 namespace mlir {
27 #define GEN_PASS_DEF_SCFTOEMITC
28 #include "mlir/Conversion/Passes.h.inc"
29 } // namespace mlir
31 using namespace mlir;
32 using namespace mlir::scf;
34 namespace {
36 struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
37 void runOnOperation() override;
40 // Lower scf::for to emitc::for, implementing result values using
41 // emitc::variable's updated within the loop body.
42 struct ForLowering : public OpRewritePattern<ForOp> {
43 using OpRewritePattern<ForOp>::OpRewritePattern;
45 LogicalResult matchAndRewrite(ForOp forOp,
46 PatternRewriter &rewriter) const override;
49 // Create an uninitialized emitc::variable op for each result of the given op.
50 template <typename T>
51 static SmallVector<Value> createVariablesForResults(T op,
52 PatternRewriter &rewriter) {
53 SmallVector<Value> resultVariables;
55 if (!op.getNumResults())
56 return resultVariables;
58 Location loc = op->getLoc();
59 MLIRContext *context = op.getContext();
61 OpBuilder::InsertionGuard guard(rewriter);
62 rewriter.setInsertionPoint(op);
64 for (OpResult result : op.getResults()) {
65 Type resultType = result.getType();
66 Type varType = emitc::LValueType::get(resultType);
67 emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
68 emitc::VariableOp var =
69 rewriter.create<emitc::VariableOp>(loc, varType, noInit);
70 resultVariables.push_back(var);
73 return resultVariables;
76 // Create a series of assign ops assigning given values to given variables at
77 // the current insertion point of given rewriter.
78 static void assignValues(ValueRange values, SmallVector<Value> &variables,
79 PatternRewriter &rewriter, Location loc) {
80 for (auto [value, var] : llvm::zip(values, variables))
81 rewriter.create<emitc::AssignOp>(loc, var, value);
84 SmallVector<Value> loadValues(const SmallVector<Value> &variables,
85 PatternRewriter &rewriter, Location loc) {
86 return llvm::map_to_vector<>(variables, [&](Value var) {
87 Type type = cast<emitc::LValueType>(var.getType()).getValueType();
88 return rewriter.create<emitc::LoadOp>(loc, type, var).getResult();
89 });
92 static void lowerYield(SmallVector<Value> &resultVariables,
93 PatternRewriter &rewriter, scf::YieldOp yield) {
94 Location loc = yield.getLoc();
95 ValueRange operands = yield.getOperands();
97 OpBuilder::InsertionGuard guard(rewriter);
98 rewriter.setInsertionPoint(yield);
100 assignValues(operands, resultVariables, rewriter, loc);
102 rewriter.create<emitc::YieldOp>(loc);
103 rewriter.eraseOp(yield);
106 // Lower the contents of an scf::if/scf::index_switch regions to an
107 // emitc::if/emitc::switch region. The contents of the lowering region is
108 // moved into the respective lowered region, but the scf::yield is replaced not
109 // only with an emitc::yield, but also with a sequence of emitc::assign ops that
110 // set the yielded values into the result variables.
111 static void lowerRegion(SmallVector<Value> &resultVariables,
112 PatternRewriter &rewriter, Region &region,
113 Region &loweredRegion) {
114 rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
115 Operation *terminator = loweredRegion.back().getTerminator();
116 lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
119 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
120 PatternRewriter &rewriter) const {
121 Location loc = forOp.getLoc();
123 // Create an emitc::variable op for each result. These variables will be
124 // assigned to by emitc::assign ops within the loop body.
125 SmallVector<Value> resultVariables =
126 createVariablesForResults(forOp, rewriter);
128 assignValues(forOp.getInits(), resultVariables, rewriter, loc);
130 emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
131 loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
133 Block *loweredBody = loweredFor.getBody();
135 // Erase the auto-generated terminator for the lowered for op.
136 rewriter.eraseOp(loweredBody->getTerminator());
138 IRRewriter::InsertPoint ip = rewriter.saveInsertionPoint();
139 rewriter.setInsertionPointToEnd(loweredBody);
141 SmallVector<Value> iterArgsValues =
142 loadValues(resultVariables, rewriter, loc);
144 rewriter.restoreInsertionPoint(ip);
146 SmallVector<Value> replacingValues;
147 replacingValues.push_back(loweredFor.getInductionVar());
148 replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
150 rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
151 lowerYield(resultVariables, rewriter,
152 cast<scf::YieldOp>(loweredBody->getTerminator()));
154 // Load variables into SSA values after the for loop.
155 SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
157 rewriter.replaceOp(forOp, resultValues);
158 return success();
161 // Lower scf::if to emitc::if, implementing result values as emitc::variable's
162 // updated within the then and else regions.
163 struct IfLowering : public OpRewritePattern<IfOp> {
164 using OpRewritePattern<IfOp>::OpRewritePattern;
166 LogicalResult matchAndRewrite(IfOp ifOp,
167 PatternRewriter &rewriter) const override;
170 } // namespace
172 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
173 PatternRewriter &rewriter) const {
174 Location loc = ifOp.getLoc();
176 // Create an emitc::variable op for each result. These variables will be
177 // assigned to by emitc::assign ops within the then & else regions.
178 SmallVector<Value> resultVariables =
179 createVariablesForResults(ifOp, rewriter);
181 Region &thenRegion = ifOp.getThenRegion();
182 Region &elseRegion = ifOp.getElseRegion();
184 bool hasElseBlock = !elseRegion.empty();
186 auto loweredIf =
187 rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
189 Region &loweredThenRegion = loweredIf.getThenRegion();
190 lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
192 if (hasElseBlock) {
193 Region &loweredElseRegion = loweredIf.getElseRegion();
194 lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
197 rewriter.setInsertionPointAfter(ifOp);
198 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
200 rewriter.replaceOp(ifOp, results);
201 return success();
204 // Lower scf::index_switch to emitc::switch, implementing result values as
205 // emitc::variable's updated within the case and default regions.
206 struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
207 using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
209 LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
210 PatternRewriter &rewriter) const override;
213 LogicalResult
214 IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
215 PatternRewriter &rewriter) const {
216 Location loc = indexSwitchOp.getLoc();
218 // Create an emitc::variable op for each result. These variables will be
219 // assigned to by emitc::assign ops within the case and default regions.
220 SmallVector<Value> resultVariables =
221 createVariablesForResults(indexSwitchOp, rewriter);
223 auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
224 loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
225 indexSwitchOp.getNumCases());
227 // Lowering all case regions.
228 for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
229 loweredSwitch.getCaseRegions())) {
230 lowerRegion(resultVariables, rewriter, std::get<0>(pair),
231 std::get<1>(pair));
234 // Lowering default region.
235 lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
236 loweredSwitch.getDefaultRegion());
238 rewriter.setInsertionPointAfter(indexSwitchOp);
239 SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
241 rewriter.replaceOp(indexSwitchOp, results);
242 return success();
245 void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
246 patterns.add<ForLowering>(patterns.getContext());
247 patterns.add<IfLowering>(patterns.getContext());
248 patterns.add<IndexSwitchOpLowering>(patterns.getContext());
251 void SCFToEmitCPass::runOnOperation() {
252 RewritePatternSet patterns(&getContext());
253 populateSCFToEmitCConversionPatterns(patterns);
255 // Configure conversion to lower out SCF operations.
256 ConversionTarget target(getContext());
257 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
258 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
259 if (failed(
260 applyPartialConversion(getOperation(), target, std::move(patterns))))
261 signalPassFailure();