1 //===- SCFToControlFlow.cpp - SCF to CF conversion ------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a pass to convert scf.for, scf.if, scf.while,
// scf.parallel, scf.forall, scf.index_switch and scf.execute_region ops into
// ControlFlow dialect (CFG) ops.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
16 #include "mlir/Dialect/Arith/IR/Arith.h"
17 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/IRMapping.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "mlir/Transforms/Passes.h"
30 #define GEN_PASS_DEF_SCFTOCONTROLFLOW
31 #include "mlir/Conversion/Passes.h.inc"
35 using namespace mlir::scf
;
39 struct SCFToControlFlowPass
40 : public impl::SCFToControlFlowBase
<SCFToControlFlowPass
> {
41 void runOnOperation() override
;
44 // Create a CFG subgraph for the loop around its body blocks (if the body
45 // contained other loops, they have been already lowered to a flow of blocks).
46 // Maintain the invariants that a CFG subgraph created for any loop has a single
47 // entry and a single exit, and that the entry/exit blocks are respectively
48 // first/last blocks in the parent region. The original loop operation is
49 // replaced by the initialization operations that set up the initial value of
50 // the loop induction variable (%iv) and computes the loop bounds that are loop-
51 // invariant for affine loops. The operations following the original scf.for
52 // are split out into a separate continuation (exit) block. A condition block is
53 // created before the continuation block. It checks the exit condition of the
54 // loop and branches either to the continuation block, or to the first block of
55 // the body. The condition block takes as arguments the values of the induction
56 // variable followed by loop-carried values. Since it dominates both the body
57 // blocks and the continuation block, loop-carried values are visible in all of
58 // those blocks. Induction variable modification is appended to the last block
59 // of the body (which is the exit block from the body subgraph thanks to the
60 // invariant we maintain) along with a branch that loops back to the condition
61 // block. Loop-carried values are the loop terminator operands, which are
62 // forwarded to the branch.
64 // +---------------------------------+
65 // | <code before the ForOp> |
66 // | <definitions of %init...> |
67 // | <compute initial %iv value> |
68 // | cf.br cond(%iv, %init...) |
69 // +---------------------------------+
73 // | +--------------------------------+
74 // | | cond(%iv, %init...): |
75 // | | <compare %iv to upper bound> |
76 // | | cf.cond_br %r, body, end |
77 // | +--------------------------------+
81 // | +--------------------------------+ |
82 // | | body-first: | |
83 // | | <%init visible by dominance> | |
84 // | | <body contents> | |
85 // | +--------------------------------+ |
89 // | +--------------------------------+ |
91 // | | <body contents> | |
92 // | | <operands of yield = %yields>| |
93 // | | %new_iv =<add step to %iv> | |
94 // | | cf.br cond(%new_iv, %yields) | |
95 // | +--------------------------------+ |
97 // |----------- |--------------------
99 // +--------------------------------+
101 // | <code after the ForOp> |
102 // | <%init visible by dominance> |
103 // +--------------------------------+
105 struct ForLowering
: public OpRewritePattern
<ForOp
> {
106 using OpRewritePattern
<ForOp
>::OpRewritePattern
;
108 LogicalResult
matchAndRewrite(ForOp forOp
,
109 PatternRewriter
&rewriter
) const override
;
112 // Create a CFG subgraph for the scf.if operation (including its "then" and
113 // optional "else" operation blocks). We maintain the invariants that the
114 // subgraph has a single entry and a single exit point, and that the entry/exit
115 // blocks are respectively the first/last block of the enclosing region. The
116 // operations following the scf.if are split into a continuation (subgraph
117 // exit) block. The condition is lowered to a chain of blocks that implement the
118 // short-circuit scheme. The "scf.if" operation is replaced with a conditional
119 // branch to either the first block of the "then" region, or to the first block
120 // of the "else" region. In these blocks, "scf.yield" is unconditional branches
121 // to the post-dominating block. When the "scf.if" does not return values, the
122 // post-dominating block is the same as the continuation block. When it returns
123 // values, the post-dominating block is a new block with arguments that
124 // correspond to the values returned by the "scf.if" that unconditionally
125 // branches to the continuation block. This allows block arguments to dominate
126 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
127 // new block allows us to avoid modifying the argument list of an existing
128 // block, which is illegal in a conversion pattern). When the "else" region is
129 // empty, which is only allowed for "scf.if"s that don't return values, the
130 // condition branches directly to the continuation block.
132 // CFG for a scf.if with else and without results.
134 // +--------------------------------+
135 // | <code before the IfOp> |
136 // | cf.cond_br %cond, %then, %else |
137 // +--------------------------------+
141 // +--------------------------------+ |
143 // | <then contents> | |
144 // | cf.br continue | |
145 // +--------------------------------+ |
147 // |---------- |-------------
149 // | +--------------------------------+
151 // | | <else contents> |
152 // | | cf.br continue |
153 // | +--------------------------------+
157 // +--------------------------------+
159 // | <code after the IfOp> |
160 // +--------------------------------+
162 // CFG for a scf.if with results.
164 // +--------------------------------+
165 // | <code before the IfOp> |
166 // | cf.cond_br %cond, %then, %else |
167 // +--------------------------------+
171 // +--------------------------------+ |
173 // | <then contents> | |
174 // | cf.br dom(%args...) | |
175 // +--------------------------------+ |
177 // |---------- |-------------
179 // | +--------------------------------+
181 // | | <else contents> |
182 // | | cf.br dom(%args...) |
183 // | +--------------------------------+
187 // +--------------------------------+
188 // | dom(%args...): |
189 // | cf.br continue |
190 // +--------------------------------+
193 // +--------------------------------+
195 // | <code after the IfOp> |
196 // +--------------------------------+
198 struct IfLowering
: public OpRewritePattern
<IfOp
> {
199 using OpRewritePattern
<IfOp
>::OpRewritePattern
;
201 LogicalResult
matchAndRewrite(IfOp ifOp
,
202 PatternRewriter
&rewriter
) const override
;
205 struct ExecuteRegionLowering
: public OpRewritePattern
<ExecuteRegionOp
> {
206 using OpRewritePattern
<ExecuteRegionOp
>::OpRewritePattern
;
208 LogicalResult
matchAndRewrite(ExecuteRegionOp op
,
209 PatternRewriter
&rewriter
) const override
;
212 struct ParallelLowering
: public OpRewritePattern
<mlir::scf::ParallelOp
> {
213 using OpRewritePattern
<mlir::scf::ParallelOp
>::OpRewritePattern
;
215 LogicalResult
matchAndRewrite(mlir::scf::ParallelOp parallelOp
,
216 PatternRewriter
&rewriter
) const override
;
219 /// Create a CFG subgraph for this loop construct. The regions of the loop need
220 /// not be a single block anymore (for example, if other SCF constructs that
221 /// they contain have been already converted to CFG), but need to be single-exit
222 /// from the last block of each region. The operations following the original
223 /// WhileOp are split into a new continuation block. Both regions of the WhileOp
224 /// are inlined, and their terminators are rewritten to organize the control
225 /// flow implementing the loop as follows.
227 /// +---------------------------------+
228 /// | <code before the WhileOp> |
229 /// | cf.br ^before(%operands...) |
230 /// +---------------------------------+
234 /// | +--------------------------------+
235 /// | | ^before(%bargs...): |
236 /// | | %vals... = <some payload> |
237 /// | +--------------------------------+
241 /// | +--------------------------------+
242 /// | | ^before-last:
243 /// | | %cond = <compute condition> |
244 /// | | cf.cond_br %cond, |
245 /// | | ^after(%vals...), ^cont |
246 /// | +--------------------------------+
248 /// | | -------------|
250 /// | +--------------------------------+ |
251 /// | | ^after(%aargs...): | |
252 /// | | <body contents> | |
253 /// | +--------------------------------+ |
257 /// | +--------------------------------+ |
258 /// | | ^after-last: | |
259 /// | | %yields... = <some payload> | |
260 /// | | cf.br ^before(%yields...) | |
261 /// | +--------------------------------+ |
263 /// |----------- |--------------------
265 /// +--------------------------------+
267 /// | <code after the WhileOp> |
268 /// | <%vals from 'before' region |
269 /// | visible by dominance> |
270 /// +--------------------------------+
272 /// Values are communicated between ex-regions (the groups of blocks that used
273 /// to form a region before inlining) through block arguments of their
274 /// entry blocks, which are visible in all other dominated blocks. Similarly,
275 /// the results of the WhileOp are defined in the 'before' region, which is
276 /// required to have a single existing block, and are therefore accessible in
277 /// the continuation block due to dominance.
278 struct WhileLowering
: public OpRewritePattern
<WhileOp
> {
279 using OpRewritePattern
<WhileOp
>::OpRewritePattern
;
281 LogicalResult
matchAndRewrite(WhileOp whileOp
,
282 PatternRewriter
&rewriter
) const override
;
285 /// Optimized version of the above for the case of the "after" region merely
286 /// forwarding its arguments back to the "before" region (i.e., a "do-while"
287 /// loop). This avoid inlining the "after" region completely and branches back
288 /// to the "before" entry instead.
289 struct DoWhileLowering
: public OpRewritePattern
<WhileOp
> {
290 using OpRewritePattern
<WhileOp
>::OpRewritePattern
;
292 LogicalResult
matchAndRewrite(WhileOp whileOp
,
293 PatternRewriter
&rewriter
) const override
;
296 /// Lower an `scf.index_switch` operation to a `cf.switch` operation.
297 struct IndexSwitchLowering
: public OpRewritePattern
<IndexSwitchOp
> {
298 using OpRewritePattern::OpRewritePattern
;
300 LogicalResult
matchAndRewrite(IndexSwitchOp op
,
301 PatternRewriter
&rewriter
) const override
;
304 /// Lower an `scf.forall` operation to an `scf.parallel` op, assuming that it
305 /// has no shared outputs. Ops with shared outputs should be bufferized first.
306 /// Specialized lowerings for `scf.forall` (e.g., for GPUs) exist in other
308 struct ForallLowering
: public OpRewritePattern
<mlir::scf::ForallOp
> {
309 using OpRewritePattern
<mlir::scf::ForallOp
>::OpRewritePattern
;
311 LogicalResult
matchAndRewrite(mlir::scf::ForallOp forallOp
,
312 PatternRewriter
&rewriter
) const override
;
317 LogicalResult
ForLowering::matchAndRewrite(ForOp forOp
,
318 PatternRewriter
&rewriter
) const {
319 Location loc
= forOp
.getLoc();
321 // Start by splitting the block containing the 'scf.for' into two parts.
322 // The part before will get the init code, the part after will be the end
324 auto *initBlock
= rewriter
.getInsertionBlock();
325 auto initPosition
= rewriter
.getInsertionPoint();
326 auto *endBlock
= rewriter
.splitBlock(initBlock
, initPosition
);
328 // Use the first block of the loop body as the condition block since it is the
329 // block that has the induction variable and loop-carried values as arguments.
330 // Split out all operations from the first block into a new block. Move all
331 // body blocks from the loop body region to the region containing the loop.
332 auto *conditionBlock
= &forOp
.getRegion().front();
333 auto *firstBodyBlock
=
334 rewriter
.splitBlock(conditionBlock
, conditionBlock
->begin());
335 auto *lastBodyBlock
= &forOp
.getRegion().back();
336 rewriter
.inlineRegionBefore(forOp
.getRegion(), endBlock
);
337 auto iv
= conditionBlock
->getArgument(0);
339 // Append the induction variable stepping logic to the last body block and
340 // branch back to the condition block. Loop-carried values are taken from
341 // operands of the loop terminator.
342 Operation
*terminator
= lastBodyBlock
->getTerminator();
343 rewriter
.setInsertionPointToEnd(lastBodyBlock
);
344 auto step
= forOp
.getStep();
345 auto stepped
= rewriter
.create
<arith::AddIOp
>(loc
, iv
, step
).getResult();
349 SmallVector
<Value
, 8> loopCarried
;
350 loopCarried
.push_back(stepped
);
351 loopCarried
.append(terminator
->operand_begin(), terminator
->operand_end());
352 rewriter
.create
<cf::BranchOp
>(loc
, conditionBlock
, loopCarried
);
353 rewriter
.eraseOp(terminator
);
355 // Compute loop bounds before branching to the condition.
356 rewriter
.setInsertionPointToEnd(initBlock
);
357 Value lowerBound
= forOp
.getLowerBound();
358 Value upperBound
= forOp
.getUpperBound();
359 if (!lowerBound
|| !upperBound
)
362 // The initial values of loop-carried values is obtained from the operands
363 // of the loop operation.
364 SmallVector
<Value
, 8> destOperands
;
365 destOperands
.push_back(lowerBound
);
366 llvm::append_range(destOperands
, forOp
.getInitArgs());
367 rewriter
.create
<cf::BranchOp
>(loc
, conditionBlock
, destOperands
);
369 // With the body block done, we can fill in the condition block.
370 rewriter
.setInsertionPointToEnd(conditionBlock
);
371 auto comparison
= rewriter
.create
<arith::CmpIOp
>(
372 loc
, arith::CmpIPredicate::slt
, iv
, upperBound
);
374 auto condBranchOp
= rewriter
.create
<cf::CondBranchOp
>(
375 loc
, comparison
, firstBodyBlock
, ArrayRef
<Value
>(), endBlock
,
378 // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
379 // llvm.loop_annotation attribute.
380 SmallVector
<NamedAttribute
> llvmAttrs
;
381 llvm::copy_if(forOp
->getAttrs(), std::back_inserter(llvmAttrs
),
383 return isa
<LLVM::LLVMDialect
>(attr
.getValue().getDialect());
385 condBranchOp
->setDiscardableAttrs(llvmAttrs
);
386 // The result of the loop operation is the values of the condition block
387 // arguments except the induction variable on the last iteration.
388 rewriter
.replaceOp(forOp
, conditionBlock
->getArguments().drop_front());
392 LogicalResult
IfLowering::matchAndRewrite(IfOp ifOp
,
393 PatternRewriter
&rewriter
) const {
394 auto loc
= ifOp
.getLoc();
396 // Start by splitting the block containing the 'scf.if' into two parts.
397 // The part before will contain the condition, the part after will be the
398 // continuation point.
399 auto *condBlock
= rewriter
.getInsertionBlock();
400 auto opPosition
= rewriter
.getInsertionPoint();
401 auto *remainingOpsBlock
= rewriter
.splitBlock(condBlock
, opPosition
);
402 Block
*continueBlock
;
403 if (ifOp
.getNumResults() == 0) {
404 continueBlock
= remainingOpsBlock
;
407 rewriter
.createBlock(remainingOpsBlock
, ifOp
.getResultTypes(),
408 SmallVector
<Location
>(ifOp
.getNumResults(), loc
));
409 rewriter
.create
<cf::BranchOp
>(loc
, remainingOpsBlock
);
412 // Move blocks from the "then" region to the region containing 'scf.if',
413 // place it before the continuation block, and branch to it.
414 auto &thenRegion
= ifOp
.getThenRegion();
415 auto *thenBlock
= &thenRegion
.front();
416 Operation
*thenTerminator
= thenRegion
.back().getTerminator();
417 ValueRange thenTerminatorOperands
= thenTerminator
->getOperands();
418 rewriter
.setInsertionPointToEnd(&thenRegion
.back());
419 rewriter
.create
<cf::BranchOp
>(loc
, continueBlock
, thenTerminatorOperands
);
420 rewriter
.eraseOp(thenTerminator
);
421 rewriter
.inlineRegionBefore(thenRegion
, continueBlock
);
423 // Move blocks from the "else" region (if present) to the region containing
424 // 'scf.if', place it before the continuation block and branch to it. It
425 // will be placed after the "then" regions.
426 auto *elseBlock
= continueBlock
;
427 auto &elseRegion
= ifOp
.getElseRegion();
428 if (!elseRegion
.empty()) {
429 elseBlock
= &elseRegion
.front();
430 Operation
*elseTerminator
= elseRegion
.back().getTerminator();
431 ValueRange elseTerminatorOperands
= elseTerminator
->getOperands();
432 rewriter
.setInsertionPointToEnd(&elseRegion
.back());
433 rewriter
.create
<cf::BranchOp
>(loc
, continueBlock
, elseTerminatorOperands
);
434 rewriter
.eraseOp(elseTerminator
);
435 rewriter
.inlineRegionBefore(elseRegion
, continueBlock
);
438 rewriter
.setInsertionPointToEnd(condBlock
);
439 rewriter
.create
<cf::CondBranchOp
>(loc
, ifOp
.getCondition(), thenBlock
,
440 /*trueArgs=*/ArrayRef
<Value
>(), elseBlock
,
441 /*falseArgs=*/ArrayRef
<Value
>());
444 rewriter
.replaceOp(ifOp
, continueBlock
->getArguments());
449 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op
,
450 PatternRewriter
&rewriter
) const {
451 auto loc
= op
.getLoc();
453 auto *condBlock
= rewriter
.getInsertionBlock();
454 auto opPosition
= rewriter
.getInsertionPoint();
455 auto *remainingOpsBlock
= rewriter
.splitBlock(condBlock
, opPosition
);
457 auto ®ion
= op
.getRegion();
458 rewriter
.setInsertionPointToEnd(condBlock
);
459 rewriter
.create
<cf::BranchOp
>(loc
, ®ion
.front());
461 for (Block
&block
: region
) {
462 if (auto terminator
= dyn_cast
<scf::YieldOp
>(block
.getTerminator())) {
463 ValueRange terminatorOperands
= terminator
->getOperands();
464 rewriter
.setInsertionPointToEnd(&block
);
465 rewriter
.create
<cf::BranchOp
>(loc
, remainingOpsBlock
, terminatorOperands
);
466 rewriter
.eraseOp(terminator
);
470 rewriter
.inlineRegionBefore(region
, remainingOpsBlock
);
472 SmallVector
<Value
> vals
;
473 SmallVector
<Location
> argLocs(op
.getNumResults(), op
->getLoc());
475 remainingOpsBlock
->addArguments(op
->getResultTypes(), argLocs
))
477 rewriter
.replaceOp(op
, vals
);
482 ParallelLowering::matchAndRewrite(ParallelOp parallelOp
,
483 PatternRewriter
&rewriter
) const {
484 Location loc
= parallelOp
.getLoc();
485 auto reductionOp
= dyn_cast
<ReduceOp
>(parallelOp
.getBody()->getTerminator());
490 // For a parallel loop, we essentially need to create an n-dimensional loop
491 // nest. We do this by translating to scf.for ops and have those lowered in
492 // a further rewrite. If a parallel loop contains reductions (and thus returns
493 // values), forward the initial values for the reductions down the loop
494 // hierarchy and bubble up the results by modifying the "yield" terminator.
495 SmallVector
<Value
, 4> iterArgs
= llvm::to_vector
<4>(parallelOp
.getInitVals());
496 SmallVector
<Value
, 4> ivs
;
497 ivs
.reserve(parallelOp
.getNumLoops());
499 SmallVector
<Value
, 4> loopResults(iterArgs
);
500 for (auto [iv
, lower
, upper
, step
] :
501 llvm::zip(parallelOp
.getInductionVars(), parallelOp
.getLowerBound(),
502 parallelOp
.getUpperBound(), parallelOp
.getStep())) {
503 ForOp forOp
= rewriter
.create
<ForOp
>(loc
, lower
, upper
, step
, iterArgs
);
504 ivs
.push_back(forOp
.getInductionVar());
505 auto iterRange
= forOp
.getRegionIterArgs();
506 iterArgs
.assign(iterRange
.begin(), iterRange
.end());
509 // Store the results of the outermost loop that will be used to replace
510 // the results of the parallel loop when it is fully rewritten.
511 loopResults
.assign(forOp
.result_begin(), forOp
.result_end());
513 } else if (!forOp
.getResults().empty()) {
514 // A loop is constructed with an empty "yield" terminator if there are
516 rewriter
.setInsertionPointToEnd(rewriter
.getInsertionBlock());
517 rewriter
.create
<scf::YieldOp
>(loc
, forOp
.getResults());
520 rewriter
.setInsertionPointToStart(forOp
.getBody());
523 // First, merge reduction blocks into the main region.
524 SmallVector
<Value
> yieldOperands
;
525 yieldOperands
.reserve(parallelOp
.getNumResults());
526 for (int64_t i
= 0, e
= parallelOp
.getNumResults(); i
< e
; ++i
) {
527 Block
&reductionBody
= reductionOp
.getReductions()[i
].front();
528 Value arg
= iterArgs
[yieldOperands
.size()];
529 yieldOperands
.push_back(
530 cast
<ReduceReturnOp
>(reductionBody
.getTerminator()).getResult());
531 rewriter
.eraseOp(reductionBody
.getTerminator());
532 rewriter
.inlineBlockBefore(&reductionBody
, reductionOp
,
533 {arg
, reductionOp
.getOperands()[i
]});
535 rewriter
.eraseOp(reductionOp
);
537 // Then merge the loop body without the terminator.
538 Block
*newBody
= rewriter
.getInsertionBlock();
539 if (newBody
->empty())
540 rewriter
.mergeBlocks(parallelOp
.getBody(), newBody
, ivs
);
542 rewriter
.inlineBlockBefore(parallelOp
.getBody(), newBody
->getTerminator(),
545 // Finally, create the terminator if required (for loops with no results, it
546 // has been already created in loop construction).
547 if (!yieldOperands
.empty()) {
548 rewriter
.setInsertionPointToEnd(rewriter
.getInsertionBlock());
549 rewriter
.create
<scf::YieldOp
>(loc
, yieldOperands
);
552 rewriter
.replaceOp(parallelOp
, loopResults
);
557 LogicalResult
WhileLowering::matchAndRewrite(WhileOp whileOp
,
558 PatternRewriter
&rewriter
) const {
559 OpBuilder::InsertionGuard
guard(rewriter
);
560 Location loc
= whileOp
.getLoc();
562 // Split the current block before the WhileOp to create the inlining point.
563 Block
*currentBlock
= rewriter
.getInsertionBlock();
564 Block
*continuation
=
565 rewriter
.splitBlock(currentBlock
, rewriter
.getInsertionPoint());
567 // Inline both regions.
568 Block
*after
= whileOp
.getAfterBody();
569 Block
*before
= whileOp
.getBeforeBody();
570 rewriter
.inlineRegionBefore(whileOp
.getAfter(), continuation
);
571 rewriter
.inlineRegionBefore(whileOp
.getBefore(), after
);
573 // Branch to the "before" region.
574 rewriter
.setInsertionPointToEnd(currentBlock
);
575 rewriter
.create
<cf::BranchOp
>(loc
, before
, whileOp
.getInits());
577 // Replace terminators with branches. Assuming bodies are SESE, which holds
578 // given only the patterns from this file, we only need to look at the last
579 // block. This should be reconsidered if we allow break/continue in SCF.
580 rewriter
.setInsertionPointToEnd(before
);
581 auto condOp
= cast
<ConditionOp
>(before
->getTerminator());
582 rewriter
.replaceOpWithNewOp
<cf::CondBranchOp
>(condOp
, condOp
.getCondition(),
583 after
, condOp
.getArgs(),
584 continuation
, ValueRange());
586 rewriter
.setInsertionPointToEnd(after
);
587 auto yieldOp
= cast
<scf::YieldOp
>(after
->getTerminator());
588 rewriter
.replaceOpWithNewOp
<cf::BranchOp
>(yieldOp
, before
,
589 yieldOp
.getResults());
591 // Replace the op with values "yielded" from the "before" region, which are
592 // visible by dominance.
593 rewriter
.replaceOp(whileOp
, condOp
.getArgs());
599 DoWhileLowering::matchAndRewrite(WhileOp whileOp
,
600 PatternRewriter
&rewriter
) const {
601 Block
&afterBlock
= *whileOp
.getAfterBody();
602 if (!llvm::hasSingleElement(afterBlock
))
603 return rewriter
.notifyMatchFailure(whileOp
,
604 "do-while simplification applicable "
605 "only if 'after' region has no payload");
607 auto yield
= dyn_cast
<scf::YieldOp
>(&afterBlock
.front());
608 if (!yield
|| yield
.getResults() != afterBlock
.getArguments())
609 return rewriter
.notifyMatchFailure(whileOp
,
610 "do-while simplification applicable "
611 "only to forwarding 'after' regions");
613 // Split the current block before the WhileOp to create the inlining point.
614 OpBuilder::InsertionGuard
guard(rewriter
);
615 Block
*currentBlock
= rewriter
.getInsertionBlock();
616 Block
*continuation
=
617 rewriter
.splitBlock(currentBlock
, rewriter
.getInsertionPoint());
619 // Only the "before" region should be inlined.
620 Block
*before
= whileOp
.getBeforeBody();
621 rewriter
.inlineRegionBefore(whileOp
.getBefore(), continuation
);
623 // Branch to the "before" region.
624 rewriter
.setInsertionPointToEnd(currentBlock
);
625 rewriter
.create
<cf::BranchOp
>(whileOp
.getLoc(), before
, whileOp
.getInits());
627 // Loop around the "before" region based on condition.
628 rewriter
.setInsertionPointToEnd(before
);
629 auto condOp
= cast
<ConditionOp
>(before
->getTerminator());
630 rewriter
.replaceOpWithNewOp
<cf::CondBranchOp
>(condOp
, condOp
.getCondition(),
631 before
, condOp
.getArgs(),
632 continuation
, ValueRange());
634 // Replace the op with values "yielded" from the "before" region, which are
635 // visible by dominance.
636 rewriter
.replaceOp(whileOp
, condOp
.getArgs());
642 IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op
,
643 PatternRewriter
&rewriter
) const {
644 // Split the block at the op.
645 Block
*condBlock
= rewriter
.getInsertionBlock();
646 Block
*continueBlock
= rewriter
.splitBlock(condBlock
, Block::iterator(op
));
648 // Create the arguments on the continue block with which to replace the
649 // results of the op.
650 SmallVector
<Value
> results
;
651 results
.reserve(op
.getNumResults());
652 for (Type resultType
: op
.getResultTypes())
653 results
.push_back(continueBlock
->addArgument(resultType
, op
.getLoc()));
655 // Handle the regions.
656 auto convertRegion
= [&](Region
®ion
) -> FailureOr
<Block
*> {
657 Block
*block
= ®ion
.front();
659 // Convert the yield terminator to a branch to the continue block.
660 auto yield
= cast
<scf::YieldOp
>(block
->getTerminator());
661 rewriter
.setInsertionPoint(yield
);
662 rewriter
.replaceOpWithNewOp
<cf::BranchOp
>(yield
, continueBlock
,
663 yield
.getOperands());
665 // Inline the region.
666 rewriter
.inlineRegionBefore(region
, continueBlock
);
670 // Convert the case regions.
671 SmallVector
<Block
*> caseSuccessors
;
672 SmallVector
<int32_t> caseValues
;
673 caseSuccessors
.reserve(op
.getCases().size());
674 caseValues
.reserve(op
.getCases().size());
675 for (auto [region
, value
] : llvm::zip(op
.getCaseRegions(), op
.getCases())) {
676 FailureOr
<Block
*> block
= convertRegion(region
);
679 caseSuccessors
.push_back(*block
);
680 caseValues
.push_back(value
);
683 // Convert the default region.
684 FailureOr
<Block
*> defaultBlock
= convertRegion(op
.getDefaultRegion());
685 if (failed(defaultBlock
))
688 // Create the switch.
689 rewriter
.setInsertionPointToEnd(condBlock
);
690 SmallVector
<ValueRange
> caseOperands(caseSuccessors
.size(), {});
692 // Cast switch index to integer case value.
693 Value caseValue
= rewriter
.create
<arith::IndexCastOp
>(
694 op
.getLoc(), rewriter
.getI32Type(), op
.getArg());
696 rewriter
.create
<cf::SwitchOp
>(
697 op
.getLoc(), caseValue
, *defaultBlock
, ValueRange(),
698 rewriter
.getDenseI32ArrayAttr(caseValues
), caseSuccessors
, caseOperands
);
699 rewriter
.replaceOp(op
, continueBlock
->getArguments());
703 LogicalResult
ForallLowering::matchAndRewrite(ForallOp forallOp
,
704 PatternRewriter
&rewriter
) const {
705 return scf::forallToParallelLoop(rewriter
, forallOp
);
708 void mlir::populateSCFToControlFlowConversionPatterns(
709 RewritePatternSet
&patterns
) {
710 patterns
.add
<ForallLowering
, ForLowering
, IfLowering
, ParallelLowering
,
711 WhileLowering
, ExecuteRegionLowering
, IndexSwitchLowering
>(
712 patterns
.getContext());
713 patterns
.add
<DoWhileLowering
>(patterns
.getContext(), /*benefit=*/2);
716 void SCFToControlFlowPass::runOnOperation() {
717 RewritePatternSet
patterns(&getContext());
718 populateSCFToControlFlowConversionPatterns(patterns
);
720 // Configure conversion to lower out SCF operations.
721 ConversionTarget
target(getContext());
722 target
.addIllegalOp
<scf::ForallOp
, scf::ForOp
, scf::IfOp
, scf::IndexSwitchOp
,
723 scf::ParallelOp
, scf::WhileOp
, scf::ExecuteRegionOp
>();
724 target
.markUnknownOpDynamicallyLegal([](Operation
*) { return true; });
726 applyPartialConversion(getOperation(), target
, std::move(patterns
))))
730 std::unique_ptr
<Pass
> mlir::createConvertSCFToCFPass() {
731 return std::make_unique
<SCFToControlFlowPass
>();