Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / SCF / TestLoopUnrolling.cpp
blob 8694a7f9bbd625c4862b4eb9c4f87eccae53e381
1 //===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to unroll loops by a specified unroll factor.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Dialect/SCF/IR/SCF.h"
15 #include "mlir/Dialect/SCF/Utils/Utils.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/Pass/Pass.h"
19 using namespace mlir;
21 namespace {
23 static unsigned getNestingDepth(Operation *op) {
24 Operation *currOp = op;
25 unsigned depth = 0;
26 while ((currOp = currOp->getParentOp())) {
27 if (isa<scf::ForOp>(currOp))
28 depth++;
30 return depth;
33 struct TestLoopUnrollingPass
34 : public PassWrapper<TestLoopUnrollingPass, OperationPass<>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopUnrollingPass)
37 StringRef getArgument() const final { return "test-loop-unrolling"; }
38 StringRef getDescription() const final {
39 return "Tests loop unrolling transformation";
41 TestLoopUnrollingPass() = default;
42 TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
43 explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
44 unsigned loopDepthParam,
45 bool annotateLoopParam) {
46 unrollFactor = unrollFactorParam;
47 loopDepth = loopDepthParam;
48 annotateLoop = annotateLoopParam;
51 void getDependentDialects(DialectRegistry &registry) const override {
52 registry.insert<arith::ArithDialect>();
55 void runOnOperation() override {
56 SmallVector<scf::ForOp, 4> loops;
57 getOperation()->walk([&](scf::ForOp forOp) {
58 if (getNestingDepth(forOp) == loopDepth)
59 loops.push_back(forOp);
60 });
61 auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
62 if (annotateLoop) {
63 op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
66 for (auto loop : loops)
67 (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
69 Option<uint64_t> unrollFactor{*this, "unroll-factor",
70 llvm::cl::desc("Loop unroll factor."),
71 llvm::cl::init(1)};
72 Option<bool> annotateLoop{*this, "annotate",
73 llvm::cl::desc("Annotate unrolled iterations."),
74 llvm::cl::init(false)};
75 Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor",
76 llvm::cl::desc("Loop unroll up to factor."),
77 llvm::cl::init(false)};
78 Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
79 llvm::cl::init(0)};
81 } // namespace
83 namespace mlir {
84 namespace test {
85 void registerTestLoopUnrollingPass() {
86 PassRegistration<TestLoopUnrollingPass>();
88 } // namespace test
89 } // namespace mlir