1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a pass to unroll loops by a specified unroll factor.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/SCF/Utils/Utils.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/Pass/Pass.h"
// Walks the chain of parent operations of `op` and tests each ancestor for
// being an scf::ForOp. `op` itself is never tested: the cursor is advanced to
// the parent before the first check.
// NOTE(review): lines are missing from this excerpt (the embedded numbering
// jumps 24 -> 26 and 27 -> 33). The counter declaration, the action taken on a
// match, the return statement and the closing braces are not visible.
// Presumably this returns the number of enclosing scf.for ops - confirm
// against the complete file.
23 static unsigned getNestingDepth(Operation
*op
) {
24 Operation
*currOp
= op
;
// Step to the parent on every iteration; the loop stops once getParentOp()
// yields a null pointer.
26 while ((currOp
= currOp
->getParentOp())) {
27 if (isa
<scf::ForOp
>(currOp
))
// Test pass, derived from PassWrapper<TestLoopUnrollingPass, OperationPass<>>,
// that gathers the scf::ForOp operations whose getNestingDepth() equals the
// `loop-depth` option and calls loopUnrollByFactor() on each of them with the
// `unroll-factor` option.
// NOTE(review): this excerpt is missing lines (gaps in the embedded numbering,
// e.g. 40, 49-50, 53-54, 60, 62, 64-65, 68, 71 and everything after 78), so
// several closing braces and some initializers are not visible here.
33 struct TestLoopUnrollingPass
34 : public PassWrapper
<TestLoopUnrollingPass
, OperationPass
<>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopUnrollingPass
)
// Command-line argument string of the pass.
37 StringRef
getArgument() const final
{ return "test-loop-unrolling"; }
// Short description string of the pass.
38 StringRef
getDescription() const final
{
39 return "Tests loop unrolling transformation";
41 TestLoopUnrollingPass() = default;
// Copy constructor with an empty body: no member is copied explicitly.
// NOTE(review): presumably needed because the Option members cannot be
// copied - confirm.
42 TestLoopUnrollingPass(const TestLoopUnrollingPass
&) {}
// Programmatic construction: stores the three arguments into the
// corresponding options. unrollUpToFactor is not set by this constructor.
43 explicit TestLoopUnrollingPass(uint64_t unrollFactorParam
,
44 unsigned loopDepthParam
,
45 bool annotateLoopParam
) {
46 unrollFactor
= unrollFactorParam
;
47 loopDepth
= loopDepthParam
;
48 annotateLoop
= annotateLoopParam
;
// Inserts the arith dialect into the registry.
// NOTE(review): the parameter name on the line after DialectRegistry starts
// with a registered-trademark sign; this looks like a mis-decoded HTML
// entity for an ampersand followed by `registry` (the body below uses
// `registry`) - confirm against the complete file.
51 void getDependentDialects(DialectRegistry
®istry
) const override
{
52 registry
.insert
<arith::ArithDialect
>();
// Two phases: first collect the matching loops during the walk, then unroll
// them in a separate loop afterwards.
// NOTE(review): presumably split this way so the IR is not modified while
// it is being walked - confirm.
55 void runOnOperation() override
{
56 SmallVector
<scf::ForOp
, 4> loops
;
57 getOperation()->walk([&](scf::ForOp forOp
) {
// Keep only loops whose nesting depth equals the loopDepth option.
58 if (getNestingDepth(forOp
) == loopDepth
)
59 loops
.push_back(forOp
);
// Callback handed to loopUnrollByFactor(); it sets an `unrolled_iteration`
// attribute on `op`, built from `i` with getUI32IntegerAttr.
// NOTE(review): `this` is captured but no member use is visible; a line is
// missing between embedded lines 61 and 63 - presumably a guard on
// annotateLoop, confirm.
61 auto annotateFn
= [this](unsigned i
, Operation
*op
, OpBuilder b
) {
63 op
->setAttr("unrolled_iteration", b
.getUI32IntegerAttr(i
));
// The result of loopUnrollByFactor() is deliberately discarded via (void).
66 for (auto loop
: loops
)
67 (void)loopUnrollByFactor(loop
, unrollFactor
, annotateFn
);
// Pass options. The initializer of unrollFactor (embedded line 71) and that
// of loopDepth are not visible in this excerpt.
69 Option
<uint64_t> unrollFactor
{*this, "unroll-factor",
70 llvm::cl::desc("Loop unroll factor."),
72 Option
<bool> annotateLoop
{*this, "annotate",
73 llvm::cl::desc("Annotate unrolled iterations."),
74 llvm::cl::init(false)};
// unrollUpToFactor is declared here but not referenced anywhere else in the
// visible code.
75 Option
<bool> unrollUpToFactor
{*this, "unroll-up-to-factor",
76 llvm::cl::desc("Loop unroll up to factor."),
77 llvm::cl::init(false)};
78 Option
<unsigned> loopDepth
{*this, "loop-depth", llvm::cl::desc("Loop depth."),
// Registers TestLoopUnrollingPass by constructing a
// PassRegistration<TestLoopUnrollingPass> temporary.
// NOTE(review): the closing brace of this function and any enclosing
// namespace lines are not visible in this excerpt.
85 void registerTestLoopUnrollingPass() {
86 PassRegistration
<TestLoopUnrollingPass
>();