Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / SCF / TestSCFUtils.cpp
blob3ff7f9966e93da6280ed1e65a035e8aa4f2160f5
1 //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test SCF dialect utils.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 using namespace mlir;
26 namespace {
27 struct TestSCFForUtilsPass
28 : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
31 StringRef getArgument() const final { return "test-scf-for-utils"; }
32 StringRef getDescription() const final { return "test scf.for utils"; }
33 explicit TestSCFForUtilsPass() = default;
34 TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
36 Option<bool> testReplaceWithNewYields{
37 *this, "test-replace-with-new-yields",
38 llvm::cl::desc("Test replacing a loop with a new loop that returns new "
39 "additional yield values"),
40 llvm::cl::init(false)};
42 void runOnOperation() override {
43 func::FuncOp func = getOperation();
44 SmallVector<scf::ForOp, 4> toErase;
46 if (testReplaceWithNewYields) {
47 func.walk([&](scf::ForOp forOp) {
48 if (forOp.getNumResults() == 0)
49 return;
50 auto newInitValues = forOp.getInitArgs();
51 if (newInitValues.empty())
52 return;
53 SmallVector<Value> oldYieldValues =
54 llvm::to_vector(forOp.getYieldedValues());
55 NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
56 ArrayRef<BlockArgument> newBBArgs) {
57 SmallVector<Value> newYieldValues;
58 for (auto yieldVal : oldYieldValues) {
59 newYieldValues.push_back(
60 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
62 return newYieldValues;
64 IRRewriter rewriter(forOp.getContext());
65 if (failed(forOp.replaceWithAdditionalYields(
66 rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true,
67 fn)))
68 signalPassFailure();
69 });
74 struct TestSCFIfUtilsPass
75 : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
76 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
78 StringRef getArgument() const final { return "test-scf-if-utils"; }
79 StringRef getDescription() const final { return "test scf.if utils"; }
80 explicit TestSCFIfUtilsPass() = default;
82 void runOnOperation() override {
83 int count = 0;
84 getOperation().walk([&](scf::IfOp ifOp) {
85 auto strCount = std::to_string(count++);
86 func::FuncOp thenFn, elseFn;
87 OpBuilder b(ifOp);
88 IRRewriter rewriter(b);
89 if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
90 std::string("outlined_then") + strCount, &elseFn,
91 std::string("outlined_else") + strCount))) {
92 this->signalPassFailure();
93 return WalkResult::interrupt();
95 return WalkResult::advance();
96 });
100 static const StringLiteral kTestPipeliningLoopMarker =
101 "__test_pipelining_loop__";
102 static const StringLiteral kTestPipeliningStageMarker =
103 "__test_pipelining_stage__";
104 /// Marker to express the order in which operations should be after
105 /// pipelining.
106 static const StringLiteral kTestPipeliningOpOrderMarker =
107 "__test_pipelining_op_order__";
109 static const StringLiteral kTestPipeliningAnnotationPart =
110 "__test_pipelining_part";
111 static const StringLiteral kTestPipeliningAnnotationIteration =
112 "__test_pipelining_iteration";
114 struct TestSCFPipeliningPass
115 : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
116 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
118 TestSCFPipeliningPass() = default;
119 TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
120 StringRef getArgument() const final { return "test-scf-pipelining"; }
121 StringRef getDescription() const final { return "test scf.forOp pipelining"; }
123 Option<bool> annotatePipeline{
124 *this, "annotate",
125 llvm::cl::desc("Annote operations during loop pipelining transformation"),
126 llvm::cl::init(false)};
128 Option<bool> noEpiloguePeeling{
129 *this, "no-epilogue-peeling",
130 llvm::cl::desc("Use predicates instead of peeling the epilogue."),
131 llvm::cl::init(false)};
133 static void
134 getSchedule(scf::ForOp forOp,
135 std::vector<std::pair<Operation *, unsigned>> &schedule) {
136 if (!forOp->hasAttr(kTestPipeliningLoopMarker))
137 return;
139 schedule.resize(forOp.getBody()->getOperations().size() - 1);
140 forOp.walk([&schedule](Operation *op) {
141 auto attrStage =
142 op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
143 auto attrCycle =
144 op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
145 if (attrCycle && attrStage) {
146 // TODO: Index can be out-of-bounds if ops of the loop body disappear
147 // due to folding.
148 schedule[attrCycle.getInt()] =
149 std::make_pair(op, unsigned(attrStage.getInt()));
154 /// Helper to generate "predicated" version of `op`. For simplicity we just
155 /// wrap the operation in a scf.ifOp operation.
156 static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
157 Value pred) {
158 Location loc = op->getLoc();
159 auto ifOp =
160 rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
161 // True branch.
162 rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
163 ifOp.getThenRegion().front().begin());
164 rewriter.setInsertionPointAfter(op);
165 if (op->getNumResults() > 0)
166 rewriter.create<scf::YieldOp>(loc, op->getResults());
167 // False branch.
168 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
169 SmallVector<Value> elseYieldOperands;
170 elseYieldOperands.reserve(ifOp.getNumResults());
171 if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) {
172 // For sub-views, just clone the op.
173 // NOTE: This is okay in the test because we use dynamic memref sizes, so
174 // the verifier will not complain. Otherwise, we may create a logically
175 // out-of-bounds view and a different technique should be used.
176 Operation *opClone = rewriter.clone(*op);
177 elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
178 } else {
179 // Default to assuming constant numeric values.
180 for (Type type : op->getResultTypes()) {
181 elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
182 loc, rewriter.getZeroAttr(type)));
185 if (op->getNumResults() > 0)
186 rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
187 return ifOp.getOperation();
190 static void annotate(Operation *op,
191 mlir::scf::PipeliningOption::PipelinerPart part,
192 unsigned iteration) {
193 OpBuilder b(op);
194 switch (part) {
195 case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
196 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
197 break;
198 case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
199 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
200 break;
201 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
202 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
203 break;
205 op->setAttr(kTestPipeliningAnnotationIteration,
206 b.getI32IntegerAttr(iteration));
209 void getDependentDialects(DialectRegistry &registry) const override {
210 registry.insert<arith::ArithDialect, memref::MemRefDialect>();
213 void runOnOperation() override {
214 RewritePatternSet patterns(&getContext());
215 mlir::scf::PipeliningOption options;
216 options.getScheduleFn = getSchedule;
217 options.supportDynamicLoops = true;
218 options.predicateFn = predicateOp;
219 if (annotatePipeline)
220 options.annotateFn = annotate;
221 if (noEpiloguePeeling) {
222 options.peelEpilogue = false;
224 scf::populateSCFLoopPipeliningPatterns(patterns, options);
225 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
226 getOperation().walk([](Operation *op) {
227 // Clean up the markers.
228 op->removeAttr(kTestPipeliningStageMarker);
229 op->removeAttr(kTestPipeliningOpOrderMarker);
233 } // namespace
235 namespace mlir {
236 namespace test {
237 void registerTestSCFUtilsPass() {
238 PassRegistration<TestSCFForUtilsPass>();
239 PassRegistration<TestSCFIfUtilsPass>();
240 PassRegistration<TestSCFPipeliningPass>();
242 } // namespace test
243 } // namespace mlir