//===- ControlFlowToSCF.cpp - ControlFlow to SCF --------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Define conversions from the ControlFlow dialect to the SCF dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Conversion/ControlFlowToSCF/ControlFlowToSCF.h"
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/UB/IR/UBOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/CFGToSCF.h"
25 #define GEN_PASS_DEF_LIFTCONTROLFLOWTOSCFPASS
26 #include "mlir/Conversion/Passes.h.inc"
31 FailureOr
<Operation
*>
32 ControlFlowToSCFTransformation::createStructuredBranchRegionOp(
33 OpBuilder
&builder
, Operation
*controlFlowCondOp
, TypeRange resultTypes
,
34 MutableArrayRef
<Region
> regions
) {
35 if (auto condBrOp
= dyn_cast
<cf::CondBranchOp
>(controlFlowCondOp
)) {
36 assert(regions
.size() == 2);
37 auto ifOp
= builder
.create
<scf::IfOp
>(controlFlowCondOp
->getLoc(),
38 resultTypes
, condBrOp
.getCondition());
39 ifOp
.getThenRegion().takeBody(regions
[0]);
40 ifOp
.getElseRegion().takeBody(regions
[1]);
41 return ifOp
.getOperation();
44 if (auto switchOp
= dyn_cast
<cf::SwitchOp
>(controlFlowCondOp
)) {
45 // `getCFGSwitchValue` returns an i32 that we need to convert to index
47 auto cast
= builder
.create
<arith::IndexCastUIOp
>(
48 controlFlowCondOp
->getLoc(), builder
.getIndexType(),
50 SmallVector
<int64_t> cases
;
51 if (auto caseValues
= switchOp
.getCaseValues())
53 cases
, llvm::map_range(*caseValues
, [](const llvm::APInt
&apInt
) {
54 return apInt
.getZExtValue();
57 assert(regions
.size() == cases
.size() + 1);
59 auto indexSwitchOp
= builder
.create
<scf::IndexSwitchOp
>(
60 controlFlowCondOp
->getLoc(), resultTypes
, cast
, cases
, cases
.size());
62 indexSwitchOp
.getDefaultRegion().takeBody(regions
[0]);
63 for (auto &&[targetRegion
, sourceRegion
] :
64 llvm::zip(indexSwitchOp
.getCaseRegions(), llvm::drop_begin(regions
)))
65 targetRegion
.takeBody(sourceRegion
);
67 return indexSwitchOp
.getOperation();
70 controlFlowCondOp
->emitOpError(
71 "Cannot convert unknown control flow op to structured control flow");
76 ControlFlowToSCFTransformation::createStructuredBranchRegionTerminatorOp(
77 Location loc
, OpBuilder
&builder
, Operation
*branchRegionOp
,
78 Operation
*replacedControlFlowOp
, ValueRange results
) {
79 builder
.create
<scf::YieldOp
>(loc
, results
);
83 FailureOr
<Operation
*>
84 ControlFlowToSCFTransformation::createStructuredDoWhileLoopOp(
85 OpBuilder
&builder
, Operation
*replacedOp
, ValueRange loopVariablesInit
,
86 Value condition
, ValueRange loopVariablesNextIter
, Region
&&loopBody
) {
87 Location loc
= replacedOp
->getLoc();
88 auto whileOp
= builder
.create
<scf::WhileOp
>(loc
, loopVariablesInit
.getTypes(),
91 whileOp
.getBefore().takeBody(loopBody
);
93 builder
.setInsertionPointToEnd(&whileOp
.getBefore().back());
94 // `getCFGSwitchValue` returns a i32. We therefore need to truncate the
95 // condition to i1 first. It is guaranteed to be either 0 or 1 already.
96 builder
.create
<scf::ConditionOp
>(
97 loc
, builder
.create
<arith::TruncIOp
>(loc
, builder
.getI1Type(), condition
),
98 loopVariablesNextIter
);
100 Block
*afterBlock
= builder
.createBlock(&whileOp
.getAfter());
101 afterBlock
->addArguments(
102 loopVariablesInit
.getTypes(),
103 SmallVector
<Location
>(loopVariablesInit
.size(), loc
));
104 builder
.create
<scf::YieldOp
>(loc
, afterBlock
->getArguments());
106 return whileOp
.getOperation();
109 Value
ControlFlowToSCFTransformation::getCFGSwitchValue(Location loc
,
111 unsigned int value
) {
112 return builder
.create
<arith::ConstantOp
>(loc
,
113 builder
.getI32IntegerAttr(value
));
116 void ControlFlowToSCFTransformation::createCFGSwitchOp(
117 Location loc
, OpBuilder
&builder
, Value flag
,
118 ArrayRef
<unsigned int> caseValues
, BlockRange caseDestinations
,
119 ArrayRef
<ValueRange
> caseArguments
, Block
*defaultDest
,
120 ValueRange defaultArgs
) {
121 builder
.create
<cf::SwitchOp
>(loc
, flag
, defaultDest
, defaultArgs
,
122 llvm::to_vector_of
<int32_t>(caseValues
),
123 caseDestinations
, caseArguments
);
126 Value
ControlFlowToSCFTransformation::getUndefValue(Location loc
,
129 return builder
.create
<ub::PoisonOp
>(loc
, type
, nullptr);
132 FailureOr
<Operation
*>
133 ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc
,
137 // TODO: This should create a `ub.unreachable` op. Once such an operation
138 // exists to make the pass independent of the func dialect. For now just
139 // return poison values.
140 Operation
*parentOp
= region
.getParentOp();
141 auto funcOp
= dyn_cast
<func::FuncOp
>(parentOp
);
143 return emitError(loc
, "Cannot create unreachable terminator for '")
144 << parentOp
->getName() << "'";
147 .create
<func::ReturnOp
>(
148 loc
, llvm::map_to_vector(funcOp
.getResultTypes(),
150 return getUndefValue(loc
, builder
, type
);
157 struct LiftControlFlowToSCF
158 : public impl::LiftControlFlowToSCFPassBase
<LiftControlFlowToSCF
> {
162 void runOnOperation() override
{
163 ControlFlowToSCFTransformation transformation
;
165 bool changed
= false;
166 Operation
*op
= getOperation();
167 WalkResult result
= op
->walk([&](func::FuncOp funcOp
) {
168 if (funcOp
.getBody().empty())
169 return WalkResult::advance();
171 auto &domInfo
= funcOp
!= op
? getChildAnalysis
<DominanceInfo
>(funcOp
)
172 : getAnalysis
<DominanceInfo
>();
174 auto visitor
= [&](Operation
*innerOp
) -> WalkResult
{
175 for (Region
®
: innerOp
->getRegions()) {
176 FailureOr
<bool> changedFunc
=
177 transformCFGToSCF(reg
, transformation
, domInfo
);
178 if (failed(changedFunc
))
179 return WalkResult::interrupt();
181 changed
|= *changedFunc
;
183 return WalkResult::advance();
186 if (funcOp
->walk
<WalkOrder::PostOrder
>(visitor
).wasInterrupted())
187 return WalkResult::interrupt();
189 return WalkResult::advance();
191 if (result
.wasInterrupted())
192 return signalPassFailure();
195 markAllAnalysesPreserved();