[RISCV][VLOPT] Add vector narrowing integer right shift instructions to isSupportedIn...
[llvm-project.git] / flang / lib / Optimizer / Transforms / ControlFlowConverter.cpp
blob b09bbf6106dbbb138d50faabf8a746469d476700
1 //===-- ControlFlowConverter.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Dialect/FIRDialect.h"
10 #include "flang/Optimizer/Dialect/FIROps.h"
11 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
12 #include "flang/Optimizer/Dialect/Support/FIRContext.h"
13 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
14 #include "flang/Optimizer/Support/InternalNames.h"
15 #include "flang/Optimizer/Support/TypeCode.h"
16 #include "flang/Optimizer/Transforms/Passes.h"
17 #include "flang/Runtime/derived-api.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/Support/CommandLine.h"
26 namespace fir {
27 #define GEN_PASS_DEF_CFGCONVERSION
28 #include "flang/Optimizer/Transforms/Passes.h.inc"
29 } // namespace fir
31 using namespace fir;
32 using namespace mlir;
34 namespace {
36 // Conversion of fir control ops to more primitive control-flow.
38 // FIR loops that cannot be converted to the affine dialect will remain as
39 // `fir.do_loop` operations. These can be converted to control-flow operations.
41 /// Convert `fir.do_loop` to CFG
42 class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
43 public:
44 using OpRewritePattern::OpRewritePattern;
46 CfgLoopConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW)
47 : mlir::OpRewritePattern<fir::DoLoopOp>(ctx),
48 forceLoopToExecuteOnce(forceLoopToExecuteOnce), setNSW(setNSW) {}
50 llvm::LogicalResult
51 matchAndRewrite(DoLoopOp loop,
52 mlir::PatternRewriter &rewriter) const override {
53 auto loc = loop.getLoc();
54 mlir::arith::IntegerOverflowFlags flags{};
55 if (setNSW)
56 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw);
57 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
58 rewriter.getContext(), flags);
60 // Create the start and end blocks that will wrap the DoLoopOp with an
61 // initalizer and an end point
62 auto *initBlock = rewriter.getInsertionBlock();
63 auto initPos = rewriter.getInsertionPoint();
64 auto *endBlock = rewriter.splitBlock(initBlock, initPos);
66 // Split the first DoLoopOp block in two parts. The part before will be the
67 // conditional block since it already has the induction variable and
68 // loop-carried values as arguments.
69 auto *conditionalBlock = &loop.getRegion().front();
70 conditionalBlock->addArgument(rewriter.getIndexType(), loc);
71 auto *firstBlock =
72 rewriter.splitBlock(conditionalBlock, conditionalBlock->begin());
73 auto *lastBlock = &loop.getRegion().back();
75 // Move the blocks from the DoLoopOp between initBlock and endBlock
76 rewriter.inlineRegionBefore(loop.getRegion(), endBlock);
78 // Get loop values from the DoLoopOp
79 auto low = loop.getLowerBound();
80 auto high = loop.getUpperBound();
81 assert(low && high && "must be a Value");
82 auto step = loop.getStep();
84 // Initalization block
85 rewriter.setInsertionPointToEnd(initBlock);
86 auto diff = rewriter.create<mlir::arith::SubIOp>(loc, high, low);
87 auto distance = rewriter.create<mlir::arith::AddIOp>(loc, diff, step);
88 mlir::Value iters =
89 rewriter.create<mlir::arith::DivSIOp>(loc, distance, step);
91 if (forceLoopToExecuteOnce) {
92 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
93 auto cond = rewriter.create<mlir::arith::CmpIOp>(
94 loc, arith::CmpIPredicate::sle, iters, zero);
95 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
96 iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters);
99 llvm::SmallVector<mlir::Value> loopOperands;
100 loopOperands.push_back(low);
101 auto operands = loop.getIterOperands();
102 loopOperands.append(operands.begin(), operands.end());
103 loopOperands.push_back(iters);
105 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopOperands);
107 // Last loop block
108 auto *terminator = lastBlock->getTerminator();
109 rewriter.setInsertionPointToEnd(lastBlock);
110 auto iv = conditionalBlock->getArgument(0);
111 mlir::Value steppedIndex =
112 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr);
113 assert(steppedIndex && "must be a Value");
114 auto lastArg = conditionalBlock->getNumArguments() - 1;
115 auto itersLeft = conditionalBlock->getArgument(lastArg);
116 auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
117 mlir::Value itersMinusOne =
118 rewriter.create<mlir::arith::SubIOp>(loc, itersLeft, one);
120 llvm::SmallVector<mlir::Value> loopCarried;
121 loopCarried.push_back(steppedIndex);
122 auto begin = loop.getFinalValue() ? std::next(terminator->operand_begin())
123 : terminator->operand_begin();
124 loopCarried.append(begin, terminator->operand_end());
125 loopCarried.push_back(itersMinusOne);
126 rewriter.create<mlir::cf::BranchOp>(loc, conditionalBlock, loopCarried);
127 rewriter.eraseOp(terminator);
129 // Conditional block
130 rewriter.setInsertionPointToEnd(conditionalBlock);
131 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
132 auto comparison = rewriter.create<mlir::arith::CmpIOp>(
133 loc, arith::CmpIPredicate::sgt, itersLeft, zero);
135 auto cond = rewriter.create<mlir::cf::CondBranchOp>(
136 loc, comparison, firstBlock, llvm::ArrayRef<mlir::Value>(), endBlock,
137 llvm::ArrayRef<mlir::Value>());
139 // Copy loop annotations from the do loop to the loop entry condition.
140 if (auto ann = loop.getLoopAnnotation())
141 cond->setAttr("loop_annotation", *ann);
143 // The result of the loop operation is the values of the condition block
144 // arguments except the induction variable on the last iteration.
145 auto args = loop.getFinalValue()
146 ? conditionalBlock->getArguments()
147 : conditionalBlock->getArguments().drop_front();
148 rewriter.replaceOp(loop, args.drop_back());
149 return success();
152 private:
153 bool forceLoopToExecuteOnce;
154 bool setNSW;
157 /// Convert `fir.if` to control-flow
158 class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
159 public:
160 using OpRewritePattern::OpRewritePattern;
162 CfgIfConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce, bool setNSW)
163 : mlir::OpRewritePattern<fir::IfOp>(ctx) {}
165 llvm::LogicalResult
166 matchAndRewrite(IfOp ifOp, mlir::PatternRewriter &rewriter) const override {
167 auto loc = ifOp.getLoc();
169 // Split the block containing the 'fir.if' into two parts. The part before
170 // will contain the condition, the part after will be the continuation
171 // point.
172 auto *condBlock = rewriter.getInsertionBlock();
173 auto opPosition = rewriter.getInsertionPoint();
174 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
175 mlir::Block *continueBlock;
176 if (ifOp.getNumResults() == 0) {
177 continueBlock = remainingOpsBlock;
178 } else {
179 continueBlock = rewriter.createBlock(
180 remainingOpsBlock, ifOp.getResultTypes(),
181 llvm::SmallVector<mlir::Location>(ifOp.getNumResults(), loc));
182 rewriter.create<mlir::cf::BranchOp>(loc, remainingOpsBlock);
185 // Move blocks from the "then" region to the region containing 'fir.if',
186 // place it before the continuation block, and branch to it.
187 auto &ifOpRegion = ifOp.getThenRegion();
188 auto *ifOpBlock = &ifOpRegion.front();
189 auto *ifOpTerminator = ifOpRegion.back().getTerminator();
190 auto ifOpTerminatorOperands = ifOpTerminator->getOperands();
191 rewriter.setInsertionPointToEnd(&ifOpRegion.back());
192 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
193 ifOpTerminatorOperands);
194 rewriter.eraseOp(ifOpTerminator);
195 rewriter.inlineRegionBefore(ifOpRegion, continueBlock);
197 // Move blocks from the "else" region (if present) to the region containing
198 // 'fir.if', place it before the continuation block and branch to it. It
199 // will be placed after the "then" regions.
200 auto *otherwiseBlock = continueBlock;
201 auto &otherwiseRegion = ifOp.getElseRegion();
202 if (!otherwiseRegion.empty()) {
203 otherwiseBlock = &otherwiseRegion.front();
204 auto *otherwiseTerm = otherwiseRegion.back().getTerminator();
205 auto otherwiseTermOperands = otherwiseTerm->getOperands();
206 rewriter.setInsertionPointToEnd(&otherwiseRegion.back());
207 rewriter.create<mlir::cf::BranchOp>(loc, continueBlock,
208 otherwiseTermOperands);
209 rewriter.eraseOp(otherwiseTerm);
210 rewriter.inlineRegionBefore(otherwiseRegion, continueBlock);
213 rewriter.setInsertionPointToEnd(condBlock);
214 rewriter.create<mlir::cf::CondBranchOp>(
215 loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
216 otherwiseBlock, llvm::ArrayRef<mlir::Value>());
217 rewriter.replaceOp(ifOp, continueBlock->getArguments());
218 return success();
222 /// Convert `fir.iter_while` to control-flow.
223 class CfgIterWhileConv : public mlir::OpRewritePattern<fir::IterWhileOp> {
224 public:
225 using OpRewritePattern::OpRewritePattern;
227 CfgIterWhileConv(mlir::MLIRContext *ctx, bool forceLoopToExecuteOnce,
228 bool setNSW)
229 : mlir::OpRewritePattern<fir::IterWhileOp>(ctx), setNSW(setNSW) {}
231 llvm::LogicalResult
232 matchAndRewrite(fir::IterWhileOp whileOp,
233 mlir::PatternRewriter &rewriter) const override {
234 auto loc = whileOp.getLoc();
235 mlir::arith::IntegerOverflowFlags flags{};
236 if (setNSW)
237 flags = bitEnumSet(flags, mlir::arith::IntegerOverflowFlags::nsw);
238 auto iofAttr = mlir::arith::IntegerOverflowFlagsAttr::get(
239 rewriter.getContext(), flags);
241 // Start by splitting the block containing the 'fir.do_loop' into two parts.
242 // The part before will get the init code, the part after will be the end
243 // point.
244 auto *initBlock = rewriter.getInsertionBlock();
245 auto initPosition = rewriter.getInsertionPoint();
246 auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
248 // Use the first block of the loop body as the condition block since it is
249 // the block that has the induction variable and loop-carried values as
250 // arguments. Split out all operations from the first block into a new
251 // block. Move all body blocks from the loop body region to the region
252 // containing the loop.
253 auto *conditionBlock = &whileOp.getRegion().front();
254 auto *firstBodyBlock =
255 rewriter.splitBlock(conditionBlock, conditionBlock->begin());
256 auto *lastBodyBlock = &whileOp.getRegion().back();
257 rewriter.inlineRegionBefore(whileOp.getRegion(), endBlock);
258 auto iv = conditionBlock->getArgument(0);
259 auto iterateVar = conditionBlock->getArgument(1);
261 // Append the induction variable stepping logic to the last body block and
262 // branch back to the condition block. Loop-carried values are taken from
263 // operands of the loop terminator.
264 auto *terminator = lastBodyBlock->getTerminator();
265 rewriter.setInsertionPointToEnd(lastBodyBlock);
266 auto step = whileOp.getStep();
267 mlir::Value stepped =
268 rewriter.create<mlir::arith::AddIOp>(loc, iv, step, iofAttr);
269 assert(stepped && "must be a Value");
271 llvm::SmallVector<mlir::Value> loopCarried;
272 loopCarried.push_back(stepped);
273 auto begin = whileOp.getFinalValue()
274 ? std::next(terminator->operand_begin())
275 : terminator->operand_begin();
276 loopCarried.append(begin, terminator->operand_end());
277 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, loopCarried);
278 rewriter.eraseOp(terminator);
280 // Compute loop bounds before branching to the condition.
281 rewriter.setInsertionPointToEnd(initBlock);
282 auto lowerBound = whileOp.getLowerBound();
283 auto upperBound = whileOp.getUpperBound();
284 assert(lowerBound && upperBound && "must be a Value");
286 // The initial values of loop-carried values is obtained from the operands
287 // of the loop operation.
288 llvm::SmallVector<mlir::Value> destOperands;
289 destOperands.push_back(lowerBound);
290 auto iterOperands = whileOp.getIterOperands();
291 destOperands.append(iterOperands.begin(), iterOperands.end());
292 rewriter.create<mlir::cf::BranchOp>(loc, conditionBlock, destOperands);
294 // With the body block done, we can fill in the condition block.
295 rewriter.setInsertionPointToEnd(conditionBlock);
296 // The comparison depends on the sign of the step value. We fully expect
297 // this expression to be folded by the optimizer or LLVM. This expression
298 // is written this way so that `step == 0` always returns `false`.
299 auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0);
300 auto compl0 = rewriter.create<mlir::arith::CmpIOp>(
301 loc, arith::CmpIPredicate::slt, zero, step);
302 auto compl1 = rewriter.create<mlir::arith::CmpIOp>(
303 loc, arith::CmpIPredicate::sle, iv, upperBound);
304 auto compl2 = rewriter.create<mlir::arith::CmpIOp>(
305 loc, arith::CmpIPredicate::slt, step, zero);
306 auto compl3 = rewriter.create<mlir::arith::CmpIOp>(
307 loc, arith::CmpIPredicate::sle, upperBound, iv);
308 auto cmp0 = rewriter.create<mlir::arith::AndIOp>(loc, compl0, compl1);
309 auto cmp1 = rewriter.create<mlir::arith::AndIOp>(loc, compl2, compl3);
310 auto cmp2 = rewriter.create<mlir::arith::OrIOp>(loc, cmp0, cmp1);
311 // Remember to AND in the early-exit bool.
312 auto comparison =
313 rewriter.create<mlir::arith::AndIOp>(loc, iterateVar, cmp2);
314 rewriter.create<mlir::cf::CondBranchOp>(
315 loc, comparison, firstBodyBlock, llvm::ArrayRef<mlir::Value>(),
316 endBlock, llvm::ArrayRef<mlir::Value>());
317 // The result of the loop operation is the values of the condition block
318 // arguments except the induction variable on the last iteration.
319 auto args = whileOp.getFinalValue()
320 ? conditionBlock->getArguments()
321 : conditionBlock->getArguments().drop_front();
322 rewriter.replaceOp(whileOp, args);
323 return success();
326 private:
327 bool setNSW;
330 /// Convert FIR structured control flow ops to CFG ops.
331 class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> {
332 public:
333 using CFGConversionBase<CfgConversion>::CFGConversionBase;
335 void runOnOperation() override {
336 auto *context = &this->getContext();
337 mlir::RewritePatternSet patterns(context);
338 fir::populateCfgConversionRewrites(patterns, this->forceLoopToExecuteOnce,
339 this->setNSW);
340 mlir::ConversionTarget target(*context);
341 target.addLegalDialect<mlir::affine::AffineDialect,
342 mlir::cf::ControlFlowDialect, FIROpsDialect,
343 mlir::func::FuncDialect>();
345 // apply the patterns
346 target.addIllegalOp<ResultOp, DoLoopOp, IfOp, IterWhileOp>();
347 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
348 if (mlir::failed(mlir::applyPartialConversion(this->getOperation(), target,
349 std::move(patterns)))) {
350 mlir::emitError(mlir::UnknownLoc::get(context),
351 "error in converting to CFG\n");
352 this->signalPassFailure();
357 } // namespace
359 /// Expose conversion rewriters to other passes
360 void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
361 bool forceLoopToExecuteOnce,
362 bool setNSW) {
363 patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>(
364 patterns.getContext(), forceLoopToExecuteOnce, setNSW);