//===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test SCF dialect utils.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

using namespace mlir;

27 struct TestSCFForUtilsPass
28 : public PassWrapper
<TestSCFForUtilsPass
, OperationPass
<func::FuncOp
>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass
)
31 StringRef
getArgument() const final
{ return "test-scf-for-utils"; }
32 StringRef
getDescription() const final
{ return "test scf.for utils"; }
33 explicit TestSCFForUtilsPass() = default;
34 TestSCFForUtilsPass(const TestSCFForUtilsPass
&pass
) : PassWrapper(pass
) {}
36 Option
<bool> testReplaceWithNewYields
{
37 *this, "test-replace-with-new-yields",
38 llvm::cl::desc("Test replacing a loop with a new loop that returns new "
39 "additional yield values"),
40 llvm::cl::init(false)};
42 void runOnOperation() override
{
43 func::FuncOp func
= getOperation();
44 SmallVector
<scf::ForOp
, 4> toErase
;
46 if (testReplaceWithNewYields
) {
47 func
.walk([&](scf::ForOp forOp
) {
48 if (forOp
.getNumResults() == 0)
50 auto newInitValues
= forOp
.getInitArgs();
51 if (newInitValues
.empty())
53 SmallVector
<Value
> oldYieldValues
=
54 llvm::to_vector(forOp
.getYieldedValues());
55 NewYieldValuesFn fn
= [&](OpBuilder
&b
, Location loc
,
56 ArrayRef
<BlockArgument
> newBBArgs
) {
57 SmallVector
<Value
> newYieldValues
;
58 for (auto yieldVal
: oldYieldValues
) {
59 newYieldValues
.push_back(
60 b
.create
<arith::AddFOp
>(loc
, yieldVal
, yieldVal
));
62 return newYieldValues
;
64 IRRewriter
rewriter(forOp
.getContext());
65 if (failed(forOp
.replaceWithAdditionalYields(
66 rewriter
, newInitValues
, /*replaceInitOperandUsesInLoop=*/true,
74 struct TestSCFIfUtilsPass
75 : public PassWrapper
<TestSCFIfUtilsPass
, OperationPass
<ModuleOp
>> {
76 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass
)
78 StringRef
getArgument() const final
{ return "test-scf-if-utils"; }
79 StringRef
getDescription() const final
{ return "test scf.if utils"; }
80 explicit TestSCFIfUtilsPass() = default;
82 void runOnOperation() override
{
84 getOperation().walk([&](scf::IfOp ifOp
) {
85 auto strCount
= std::to_string(count
++);
86 func::FuncOp thenFn
, elseFn
;
88 IRRewriter
rewriter(b
);
89 if (failed(outlineIfOp(rewriter
, ifOp
, &thenFn
,
90 std::string("outlined_then") + strCount
, &elseFn
,
91 std::string("outlined_else") + strCount
))) {
92 this->signalPassFailure();
93 return WalkResult::interrupt();
95 return WalkResult::advance();
100 static const StringLiteral kTestPipeliningLoopMarker
=
101 "__test_pipelining_loop__";
102 static const StringLiteral kTestPipeliningStageMarker
=
103 "__test_pipelining_stage__";
104 /// Marker to express the order in which operations should be after
106 static const StringLiteral kTestPipeliningOpOrderMarker
=
107 "__test_pipelining_op_order__";
109 static const StringLiteral kTestPipeliningAnnotationPart
=
110 "__test_pipelining_part";
111 static const StringLiteral kTestPipeliningAnnotationIteration
=
112 "__test_pipelining_iteration";
114 struct TestSCFPipeliningPass
115 : public PassWrapper
<TestSCFPipeliningPass
, OperationPass
<func::FuncOp
>> {
116 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass
)
118 TestSCFPipeliningPass() = default;
119 TestSCFPipeliningPass(const TestSCFPipeliningPass
&) {}
120 StringRef
getArgument() const final
{ return "test-scf-pipelining"; }
121 StringRef
getDescription() const final
{ return "test scf.forOp pipelining"; }
123 Option
<bool> annotatePipeline
{
125 llvm::cl::desc("Annote operations during loop pipelining transformation"),
126 llvm::cl::init(false)};
128 Option
<bool> noEpiloguePeeling
{
129 *this, "no-epilogue-peeling",
130 llvm::cl::desc("Use predicates instead of peeling the epilogue."),
131 llvm::cl::init(false)};
134 getSchedule(scf::ForOp forOp
,
135 std::vector
<std::pair
<Operation
*, unsigned>> &schedule
) {
136 if (!forOp
->hasAttr(kTestPipeliningLoopMarker
))
139 schedule
.resize(forOp
.getBody()->getOperations().size() - 1);
140 forOp
.walk([&schedule
](Operation
*op
) {
142 op
->getAttrOfType
<IntegerAttr
>(kTestPipeliningStageMarker
);
144 op
->getAttrOfType
<IntegerAttr
>(kTestPipeliningOpOrderMarker
);
145 if (attrCycle
&& attrStage
) {
146 // TODO: Index can be out-of-bounds if ops of the loop body disappear
148 schedule
[attrCycle
.getInt()] =
149 std::make_pair(op
, unsigned(attrStage
.getInt()));
154 /// Helper to generate "predicated" version of `op`. For simplicity we just
155 /// wrap the operation in a scf.ifOp operation.
156 static Operation
*predicateOp(RewriterBase
&rewriter
, Operation
*op
,
158 Location loc
= op
->getLoc();
160 rewriter
.create
<scf::IfOp
>(loc
, op
->getResultTypes(), pred
, true);
162 rewriter
.moveOpBefore(op
, &ifOp
.getThenRegion().front(),
163 ifOp
.getThenRegion().front().begin());
164 rewriter
.setInsertionPointAfter(op
);
165 if (op
->getNumResults() > 0)
166 rewriter
.create
<scf::YieldOp
>(loc
, op
->getResults());
168 rewriter
.setInsertionPointToStart(&ifOp
.getElseRegion().front());
169 SmallVector
<Value
> elseYieldOperands
;
170 elseYieldOperands
.reserve(ifOp
.getNumResults());
171 if (auto viewOp
= dyn_cast
<memref::SubViewOp
>(op
)) {
172 // For sub-views, just clone the op.
173 // NOTE: This is okay in the test because we use dynamic memref sizes, so
174 // the verifier will not complain. Otherwise, we may create a logically
175 // out-of-bounds view and a different technique should be used.
176 Operation
*opClone
= rewriter
.clone(*op
);
177 elseYieldOperands
.append(opClone
->result_begin(), opClone
->result_end());
179 // Default to assuming constant numeric values.
180 for (Type type
: op
->getResultTypes()) {
181 elseYieldOperands
.push_back(rewriter
.create
<arith::ConstantOp
>(
182 loc
, rewriter
.getZeroAttr(type
)));
185 if (op
->getNumResults() > 0)
186 rewriter
.create
<scf::YieldOp
>(loc
, elseYieldOperands
);
187 return ifOp
.getOperation();
190 static void annotate(Operation
*op
,
191 mlir::scf::PipeliningOption::PipelinerPart part
,
192 unsigned iteration
) {
195 case mlir::scf::PipeliningOption::PipelinerPart::Prologue
:
196 op
->setAttr(kTestPipeliningAnnotationPart
, b
.getStringAttr("prologue"));
198 case mlir::scf::PipeliningOption::PipelinerPart::Kernel
:
199 op
->setAttr(kTestPipeliningAnnotationPart
, b
.getStringAttr("kernel"));
201 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue
:
202 op
->setAttr(kTestPipeliningAnnotationPart
, b
.getStringAttr("epilogue"));
205 op
->setAttr(kTestPipeliningAnnotationIteration
,
206 b
.getI32IntegerAttr(iteration
));
209 void getDependentDialects(DialectRegistry
®istry
) const override
{
210 registry
.insert
<arith::ArithDialect
, memref::MemRefDialect
>();
213 void runOnOperation() override
{
214 RewritePatternSet
patterns(&getContext());
215 mlir::scf::PipeliningOption options
;
216 options
.getScheduleFn
= getSchedule
;
217 options
.supportDynamicLoops
= true;
218 options
.predicateFn
= predicateOp
;
219 if (annotatePipeline
)
220 options
.annotateFn
= annotate
;
221 if (noEpiloguePeeling
) {
222 options
.peelEpilogue
= false;
224 scf::populateSCFLoopPipeliningPatterns(patterns
, options
);
225 (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
));
226 getOperation().walk([](Operation
*op
) {
227 // Clean up the markers.
228 op
->removeAttr(kTestPipeliningStageMarker
);
229 op
->removeAttr(kTestPipeliningOpOrderMarker
);
237 void registerTestSCFUtilsPass() {
238 PassRegistration
<TestSCFForUtilsPass
>();
239 PassRegistration
<TestSCFIfUtilsPass
>();
240 PassRegistration
<TestSCFPipeliningPass
>();