//===-- ControlFlowConverter.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "flang/Runtime/derived-api.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"
#include <iterator>
#include <utility>
27 #define GEN_PASS_DEF_CFGCONVERSION
28 #include "flang/Optimizer/Transforms/Passes.h.inc"
36 // Conversion of fir control ops to more primitive control-flow.
38 // FIR loops that cannot be converted to the affine dialect will remain as
39 // `fir.do_loop` operations. These can be converted to control-flow operations.
41 /// Convert `fir.do_loop` to CFG
42 class CfgLoopConv
: public mlir::OpRewritePattern
<fir::DoLoopOp
> {
44 using OpRewritePattern::OpRewritePattern
;
46 CfgLoopConv(mlir::MLIRContext
*ctx
, bool forceLoopToExecuteOnce
, bool setNSW
)
47 : mlir::OpRewritePattern
<fir::DoLoopOp
>(ctx
),
48 forceLoopToExecuteOnce(forceLoopToExecuteOnce
), setNSW(setNSW
) {}
51 matchAndRewrite(DoLoopOp loop
,
52 mlir::PatternRewriter
&rewriter
) const override
{
53 auto loc
= loop
.getLoc();
54 mlir::arith::IntegerOverflowFlags flags
{};
56 flags
= bitEnumSet(flags
, mlir::arith::IntegerOverflowFlags::nsw
);
57 auto iofAttr
= mlir::arith::IntegerOverflowFlagsAttr::get(
58 rewriter
.getContext(), flags
);
60 // Create the start and end blocks that will wrap the DoLoopOp with an
61 // initalizer and an end point
62 auto *initBlock
= rewriter
.getInsertionBlock();
63 auto initPos
= rewriter
.getInsertionPoint();
64 auto *endBlock
= rewriter
.splitBlock(initBlock
, initPos
);
66 // Split the first DoLoopOp block in two parts. The part before will be the
67 // conditional block since it already has the induction variable and
68 // loop-carried values as arguments.
69 auto *conditionalBlock
= &loop
.getRegion().front();
70 conditionalBlock
->addArgument(rewriter
.getIndexType(), loc
);
72 rewriter
.splitBlock(conditionalBlock
, conditionalBlock
->begin());
73 auto *lastBlock
= &loop
.getRegion().back();
75 // Move the blocks from the DoLoopOp between initBlock and endBlock
76 rewriter
.inlineRegionBefore(loop
.getRegion(), endBlock
);
78 // Get loop values from the DoLoopOp
79 auto low
= loop
.getLowerBound();
80 auto high
= loop
.getUpperBound();
81 assert(low
&& high
&& "must be a Value");
82 auto step
= loop
.getStep();
84 // Initalization block
85 rewriter
.setInsertionPointToEnd(initBlock
);
86 auto diff
= rewriter
.create
<mlir::arith::SubIOp
>(loc
, high
, low
);
87 auto distance
= rewriter
.create
<mlir::arith::AddIOp
>(loc
, diff
, step
);
89 rewriter
.create
<mlir::arith::DivSIOp
>(loc
, distance
, step
);
91 if (forceLoopToExecuteOnce
) {
92 auto zero
= rewriter
.create
<mlir::arith::ConstantIndexOp
>(loc
, 0);
93 auto cond
= rewriter
.create
<mlir::arith::CmpIOp
>(
94 loc
, arith::CmpIPredicate::sle
, iters
, zero
);
95 auto one
= rewriter
.create
<mlir::arith::ConstantIndexOp
>(loc
, 1);
96 iters
= rewriter
.create
<mlir::arith::SelectOp
>(loc
, cond
, one
, iters
);
99 llvm::SmallVector
<mlir::Value
> loopOperands
;
100 loopOperands
.push_back(low
);
101 auto operands
= loop
.getIterOperands();
102 loopOperands
.append(operands
.begin(), operands
.end());
103 loopOperands
.push_back(iters
);
105 rewriter
.create
<mlir::cf::BranchOp
>(loc
, conditionalBlock
, loopOperands
);
108 auto *terminator
= lastBlock
->getTerminator();
109 rewriter
.setInsertionPointToEnd(lastBlock
);
110 auto iv
= conditionalBlock
->getArgument(0);
111 mlir::Value steppedIndex
=
112 rewriter
.create
<mlir::arith::AddIOp
>(loc
, iv
, step
, iofAttr
);
113 assert(steppedIndex
&& "must be a Value");
114 auto lastArg
= conditionalBlock
->getNumArguments() - 1;
115 auto itersLeft
= conditionalBlock
->getArgument(lastArg
);
116 auto one
= rewriter
.create
<mlir::arith::ConstantIndexOp
>(loc
, 1);
117 mlir::Value itersMinusOne
=
118 rewriter
.create
<mlir::arith::SubIOp
>(loc
, itersLeft
, one
);
120 llvm::SmallVector
<mlir::Value
> loopCarried
;
121 loopCarried
.push_back(steppedIndex
);
122 auto begin
= loop
.getFinalValue() ? std::next(terminator
->operand_begin())
123 : terminator
->operand_begin();
124 loopCarried
.append(begin
, terminator
->operand_end());
125 loopCarried
.push_back(itersMinusOne
);
126 rewriter
.create
<mlir::cf::BranchOp
>(loc
, conditionalBlock
, loopCarried
);
127 rewriter
.eraseOp(terminator
);
130 rewriter
.setInsertionPointToEnd(conditionalBlock
);
131 auto zero
= rewriter
.create
<mlir::arith::ConstantIndexOp
>(loc
, 0);
132 auto comparison
= rewriter
.create
<mlir::arith::CmpIOp
>(
133 loc
, arith::CmpIPredicate::sgt
, itersLeft
, zero
);
135 auto cond
= rewriter
.create
<mlir::cf::CondBranchOp
>(
136 loc
, comparison
, firstBlock
, llvm::ArrayRef
<mlir::Value
>(), endBlock
,
137 llvm::ArrayRef
<mlir::Value
>());
139 // Copy loop annotations from the do loop to the loop entry condition.
140 if (auto ann
= loop
.getLoopAnnotation())
141 cond
->setAttr("loop_annotation", *ann
);
143 // The result of the loop operation is the values of the condition block
144 // arguments except the induction variable on the last iteration.
145 auto args
= loop
.getFinalValue()
146 ? conditionalBlock
->getArguments()
147 : conditionalBlock
->getArguments().drop_front();
148 rewriter
.replaceOp(loop
, args
.drop_back());
153 bool forceLoopToExecuteOnce
;
157 /// Convert `fir.if` to control-flow
158 class CfgIfConv
: public mlir::OpRewritePattern
<fir::IfOp
> {
160 using OpRewritePattern::OpRewritePattern
;
162 CfgIfConv(mlir::MLIRContext
*ctx
, bool forceLoopToExecuteOnce
, bool setNSW
)
163 : mlir::OpRewritePattern
<fir::IfOp
>(ctx
) {}
166 matchAndRewrite(IfOp ifOp
, mlir::PatternRewriter
&rewriter
) const override
{
167 auto loc
= ifOp
.getLoc();
169 // Split the block containing the 'fir.if' into two parts. The part before
170 // will contain the condition, the part after will be the continuation
172 auto *condBlock
= rewriter
.getInsertionBlock();
173 auto opPosition
= rewriter
.getInsertionPoint();
174 auto *remainingOpsBlock
= rewriter
.splitBlock(condBlock
, opPosition
);
175 mlir::Block
*continueBlock
;
176 if (ifOp
.getNumResults() == 0) {
177 continueBlock
= remainingOpsBlock
;
179 continueBlock
= rewriter
.createBlock(
180 remainingOpsBlock
, ifOp
.getResultTypes(),
181 llvm::SmallVector
<mlir::Location
>(ifOp
.getNumResults(), loc
));
182 rewriter
.create
<mlir::cf::BranchOp
>(loc
, remainingOpsBlock
);
185 // Move blocks from the "then" region to the region containing 'fir.if',
186 // place it before the continuation block, and branch to it.
187 auto &ifOpRegion
= ifOp
.getThenRegion();
188 auto *ifOpBlock
= &ifOpRegion
.front();
189 auto *ifOpTerminator
= ifOpRegion
.back().getTerminator();
190 auto ifOpTerminatorOperands
= ifOpTerminator
->getOperands();
191 rewriter
.setInsertionPointToEnd(&ifOpRegion
.back());
192 rewriter
.create
<mlir::cf::BranchOp
>(loc
, continueBlock
,
193 ifOpTerminatorOperands
);
194 rewriter
.eraseOp(ifOpTerminator
);
195 rewriter
.inlineRegionBefore(ifOpRegion
, continueBlock
);
197 // Move blocks from the "else" region (if present) to the region containing
198 // 'fir.if', place it before the continuation block and branch to it. It
199 // will be placed after the "then" regions.
200 auto *otherwiseBlock
= continueBlock
;
201 auto &otherwiseRegion
= ifOp
.getElseRegion();
202 if (!otherwiseRegion
.empty()) {
203 otherwiseBlock
= &otherwiseRegion
.front();
204 auto *otherwiseTerm
= otherwiseRegion
.back().getTerminator();
205 auto otherwiseTermOperands
= otherwiseTerm
->getOperands();
206 rewriter
.setInsertionPointToEnd(&otherwiseRegion
.back());
207 rewriter
.create
<mlir::cf::BranchOp
>(loc
, continueBlock
,
208 otherwiseTermOperands
);
209 rewriter
.eraseOp(otherwiseTerm
);
210 rewriter
.inlineRegionBefore(otherwiseRegion
, continueBlock
);
213 rewriter
.setInsertionPointToEnd(condBlock
);
214 rewriter
.create
<mlir::cf::CondBranchOp
>(
215 loc
, ifOp
.getCondition(), ifOpBlock
, llvm::ArrayRef
<mlir::Value
>(),
216 otherwiseBlock
, llvm::ArrayRef
<mlir::Value
>());
217 rewriter
.replaceOp(ifOp
, continueBlock
->getArguments());
222 /// Convert `fir.iter_while` to control-flow.
223 class CfgIterWhileConv
: public mlir::OpRewritePattern
<fir::IterWhileOp
> {
225 using OpRewritePattern::OpRewritePattern
;
227 CfgIterWhileConv(mlir::MLIRContext
*ctx
, bool forceLoopToExecuteOnce
,
229 : mlir::OpRewritePattern
<fir::IterWhileOp
>(ctx
), setNSW(setNSW
) {}
232 matchAndRewrite(fir::IterWhileOp whileOp
,
233 mlir::PatternRewriter
&rewriter
) const override
{
234 auto loc
= whileOp
.getLoc();
235 mlir::arith::IntegerOverflowFlags flags
{};
237 flags
= bitEnumSet(flags
, mlir::arith::IntegerOverflowFlags::nsw
);
238 auto iofAttr
= mlir::arith::IntegerOverflowFlagsAttr::get(
239 rewriter
.getContext(), flags
);
241 // Start by splitting the block containing the 'fir.do_loop' into two parts.
242 // The part before will get the init code, the part after will be the end
244 auto *initBlock
= rewriter
.getInsertionBlock();
245 auto initPosition
= rewriter
.getInsertionPoint();
246 auto *endBlock
= rewriter
.splitBlock(initBlock
, initPosition
);
248 // Use the first block of the loop body as the condition block since it is
249 // the block that has the induction variable and loop-carried values as
250 // arguments. Split out all operations from the first block into a new
251 // block. Move all body blocks from the loop body region to the region
252 // containing the loop.
253 auto *conditionBlock
= &whileOp
.getRegion().front();
254 auto *firstBodyBlock
=
255 rewriter
.splitBlock(conditionBlock
, conditionBlock
->begin());
256 auto *lastBodyBlock
= &whileOp
.getRegion().back();
257 rewriter
.inlineRegionBefore(whileOp
.getRegion(), endBlock
);
258 auto iv
= conditionBlock
->getArgument(0);
259 auto iterateVar
= conditionBlock
->getArgument(1);
261 // Append the induction variable stepping logic to the last body block and
262 // branch back to the condition block. Loop-carried values are taken from
263 // operands of the loop terminator.
264 auto *terminator
= lastBodyBlock
->getTerminator();
265 rewriter
.setInsertionPointToEnd(lastBodyBlock
);
266 auto step
= whileOp
.getStep();
267 mlir::Value stepped
=
268 rewriter
.create
<mlir::arith::AddIOp
>(loc
, iv
, step
, iofAttr
);
269 assert(stepped
&& "must be a Value");
271 llvm::SmallVector
<mlir::Value
> loopCarried
;
272 loopCarried
.push_back(stepped
);
273 auto begin
= whileOp
.getFinalValue()
274 ? std::next(terminator
->operand_begin())
275 : terminator
->operand_begin();
276 loopCarried
.append(begin
, terminator
->operand_end());
277 rewriter
.create
<mlir::cf::BranchOp
>(loc
, conditionBlock
, loopCarried
);
278 rewriter
.eraseOp(terminator
);
280 // Compute loop bounds before branching to the condition.
281 rewriter
.setInsertionPointToEnd(initBlock
);
282 auto lowerBound
= whileOp
.getLowerBound();
283 auto upperBound
= whileOp
.getUpperBound();
284 assert(lowerBound
&& upperBound
&& "must be a Value");
286 // The initial values of loop-carried values is obtained from the operands
287 // of the loop operation.
288 llvm::SmallVector
<mlir::Value
> destOperands
;
289 destOperands
.push_back(lowerBound
);
290 auto iterOperands
= whileOp
.getIterOperands();
291 destOperands
.append(iterOperands
.begin(), iterOperands
.end());
292 rewriter
.create
<mlir::cf::BranchOp
>(loc
, conditionBlock
, destOperands
);
294 // With the body block done, we can fill in the condition block.
295 rewriter
.setInsertionPointToEnd(conditionBlock
);
296 // The comparison depends on the sign of the step value. We fully expect
297 // this expression to be folded by the optimizer or LLVM. This expression
298 // is written this way so that `step == 0` always returns `false`.
299 auto zero
= rewriter
.create
<mlir::arith::ConstantIndexOp
>(loc
, 0);
300 auto compl0
= rewriter
.create
<mlir::arith::CmpIOp
>(
301 loc
, arith::CmpIPredicate::slt
, zero
, step
);
302 auto compl1
= rewriter
.create
<mlir::arith::CmpIOp
>(
303 loc
, arith::CmpIPredicate::sle
, iv
, upperBound
);
304 auto compl2
= rewriter
.create
<mlir::arith::CmpIOp
>(
305 loc
, arith::CmpIPredicate::slt
, step
, zero
);
306 auto compl3
= rewriter
.create
<mlir::arith::CmpIOp
>(
307 loc
, arith::CmpIPredicate::sle
, upperBound
, iv
);
308 auto cmp0
= rewriter
.create
<mlir::arith::AndIOp
>(loc
, compl0
, compl1
);
309 auto cmp1
= rewriter
.create
<mlir::arith::AndIOp
>(loc
, compl2
, compl3
);
310 auto cmp2
= rewriter
.create
<mlir::arith::OrIOp
>(loc
, cmp0
, cmp1
);
311 // Remember to AND in the early-exit bool.
313 rewriter
.create
<mlir::arith::AndIOp
>(loc
, iterateVar
, cmp2
);
314 rewriter
.create
<mlir::cf::CondBranchOp
>(
315 loc
, comparison
, firstBodyBlock
, llvm::ArrayRef
<mlir::Value
>(),
316 endBlock
, llvm::ArrayRef
<mlir::Value
>());
317 // The result of the loop operation is the values of the condition block
318 // arguments except the induction variable on the last iteration.
319 auto args
= whileOp
.getFinalValue()
320 ? conditionBlock
->getArguments()
321 : conditionBlock
->getArguments().drop_front();
322 rewriter
.replaceOp(whileOp
, args
);
330 /// Convert FIR structured control flow ops to CFG ops.
331 class CfgConversion
: public fir::impl::CFGConversionBase
<CfgConversion
> {
333 using CFGConversionBase
<CfgConversion
>::CFGConversionBase
;
335 void runOnOperation() override
{
336 auto *context
= &this->getContext();
337 mlir::RewritePatternSet
patterns(context
);
338 fir::populateCfgConversionRewrites(patterns
, this->forceLoopToExecuteOnce
,
340 mlir::ConversionTarget
target(*context
);
341 target
.addLegalDialect
<mlir::affine::AffineDialect
,
342 mlir::cf::ControlFlowDialect
, FIROpsDialect
,
343 mlir::func::FuncDialect
>();
345 // apply the patterns
346 target
.addIllegalOp
<ResultOp
, DoLoopOp
, IfOp
, IterWhileOp
>();
347 target
.markUnknownOpDynamicallyLegal([](Operation
*) { return true; });
348 if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target
,
349 std::move(patterns
)))) {
350 mlir::emitError(mlir::UnknownLoc::get(context
),
351 "error in converting to CFG\n");
352 this->signalPassFailure();
359 /// Expose conversion rewriters to other passes
360 void fir::populateCfgConversionRewrites(mlir::RewritePatternSet
&patterns
,
361 bool forceLoopToExecuteOnce
,
363 patterns
.insert
<CfgLoopConv
, CfgIfConv
, CfgIterWhileConv
>(
364 patterns
.getContext(), forceLoopToExecuteOnce
, setNSW
);