Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Affine / TestLoopFusion.cpp
blob 19011803a793ac1f6731a51710b80329a26b91a6
1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to test various loop fusion utility functions.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/Affine/Analysis/Utils.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
16 #include "mlir/Dialect/Affine/LoopUtils.h"
17 #include "mlir/Dialect/Func/IR/FuncOps.h"
18 #include "mlir/Pass/Pass.h"
20 #define DEBUG_TYPE "test-loop-fusion"
22 using namespace mlir;
23 using namespace mlir::affine;
25 namespace {
27 struct TestLoopFusion
28 : public PassWrapper<TestLoopFusion, OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion)
31 StringRef getArgument() const final { return "test-loop-fusion"; }
32 StringRef getDescription() const final {
33 return "Tests loop fusion utility functions.";
35 void runOnOperation() override;
37 TestLoopFusion() = default;
38 TestLoopFusion(const TestLoopFusion &pass) : PassWrapper(pass){};
40 Option<bool> clTestDependenceCheck{
41 *this, "test-loop-fusion-dependence-check",
42 llvm::cl::desc("Enable testing of loop fusion dependence check"),
43 llvm::cl::init(false)};
45 Option<bool> clTestSliceComputation{
46 *this, "test-loop-fusion-slice-computation",
47 llvm::cl::desc("Enable testing of loop fusion slice computation"),
48 llvm::cl::init(false)};
50 Option<bool> clTestLoopFusionTransformation{
51 *this, "test-loop-fusion-transformation",
52 llvm::cl::desc("Enable testing of loop fusion transformation"),
53 llvm::cl::init(false)};
56 } // namespace
58 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
59 // in range ['loopDepth' + 1, 'maxLoopDepth'].
60 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
61 // Returns false as IR is not transformed.
62 static bool testDependenceCheck(AffineForOp srcForOp, AffineForOp dstForOp,
63 unsigned i, unsigned j, unsigned loopDepth,
64 unsigned maxLoopDepth) {
65 affine::ComputationSliceState sliceUnion;
66 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
67 FusionResult result =
68 affine::canFuseLoops(srcForOp, dstForOp, d, &sliceUnion);
69 if (result.value == FusionResult::FailBlockDependence) {
70 srcForOp->emitRemark("block-level dependence preventing"
71 " fusion of loop nest ")
72 << i << " into loop nest " << j << " at depth " << loopDepth;
75 return false;
78 // Returns the index of 'op' in its block.
79 static unsigned getBlockIndex(Operation &op) {
80 unsigned index = 0;
81 for (auto &opX : *op.getBlock()) {
82 if (&op == &opX)
83 break;
84 ++index;
86 return index;
89 // Returns a string representation of 'sliceUnion'.
90 static std::string
91 getSliceStr(const affine::ComputationSliceState &sliceUnion) {
92 std::string result;
93 llvm::raw_string_ostream os(result);
94 // Slice insertion point format [loop-depth, operation-block-index]
95 unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint);
96 unsigned ipb = getBlockIndex(*sliceUnion.insertPoint);
97 os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb)
98 << ")";
99 assert(sliceUnion.lbs.size() == sliceUnion.ubs.size());
100 os << " loop bounds: ";
101 for (unsigned k = 0, e = sliceUnion.lbs.size(); k < e; ++k) {
102 os << '[';
103 sliceUnion.lbs[k].print(os);
104 os << ", ";
105 sliceUnion.ubs[k].print(os);
106 os << "] ";
108 return os.str();
111 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
112 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
113 /// Emits a string representation of the slice union as a remark on 'loops[j]'
114 /// and marks this as incorrect slice if the slice is invalid. Returns false as
115 /// IR is not transformed.
116 static bool testSliceComputation(AffineForOp forOpA, AffineForOp forOpB,
117 unsigned i, unsigned j, unsigned loopDepth,
118 unsigned maxLoopDepth) {
119 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
120 affine::ComputationSliceState sliceUnion;
121 FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
122 if (result.value == FusionResult::Success) {
123 forOpB->emitRemark("slice (")
124 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
125 << " : " << getSliceStr(sliceUnion) << ")";
126 } else if (result.value == FusionResult::FailIncorrectSlice) {
127 forOpB->emitRemark("Incorrect slice (")
128 << " src loop: " << i << ", dst loop: " << j << ", depth: " << d
129 << " : " << getSliceStr(sliceUnion) << ")";
132 return false;
135 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
136 // ['loopDepth' + 1, 'maxLoopDepth'].
137 // Returns true if loops were successfully fused, false otherwise.
138 static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB,
139 unsigned i, unsigned j,
140 unsigned loopDepth,
141 unsigned maxLoopDepth) {
142 for (unsigned d = loopDepth + 1; d <= maxLoopDepth; ++d) {
143 affine::ComputationSliceState sliceUnion;
144 FusionResult result = affine::canFuseLoops(forOpA, forOpB, d, &sliceUnion);
145 if (result.value == FusionResult::Success) {
146 affine::fuseLoops(forOpA, forOpB, sliceUnion);
147 // Note: 'forOpA' is removed to simplify test output. A proper loop
148 // fusion pass should check the data dependence graph and run memref
149 // region analysis to ensure removing 'forOpA' is safe.
150 forOpA.erase();
151 return true;
154 return false;
157 using LoopFunc = function_ref<bool(AffineForOp, AffineForOp, unsigned, unsigned,
158 unsigned, unsigned)>;
160 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161 // If 'return_on_change' is true, returns on first invocation of 'fn' which
162 // returns true.
163 static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
164 LoopFunc fn, bool returnOnChange = false) {
165 bool changed = false;
166 for (unsigned loopDepth = 0, end = depthToLoops.size(); loopDepth < end;
167 ++loopDepth) {
168 auto &loops = depthToLoops[loopDepth];
169 unsigned numLoops = loops.size();
170 for (unsigned j = 0; j < numLoops; ++j) {
171 for (unsigned k = 0; k < numLoops; ++k) {
172 if (j != k)
173 changed |=
174 fn(loops[j], loops[k], j, k, loopDepth, depthToLoops.size());
175 if (changed && returnOnChange)
176 return true;
180 return changed;
183 void TestLoopFusion::runOnOperation() {
184 std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
185 if (clTestLoopFusionTransformation) {
186 // Run loop fusion until a fixed point is reached.
187 do {
188 depthToLoops.clear();
189 // Gather all AffineForOps by loop depth.
190 gatherLoops(getOperation(), depthToLoops);
192 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193 } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
194 /*returnOnChange=*/true));
195 return;
198 // Gather all AffineForOps by loop depth.
199 gatherLoops(getOperation(), depthToLoops);
201 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202 if (clTestDependenceCheck)
203 iterateLoops(depthToLoops, testDependenceCheck);
204 if (clTestSliceComputation)
205 iterateLoops(depthToLoops, testSliceComputation);
208 namespace mlir {
209 namespace test {
210 void registerTestLoopFusion() { PassRegistration<TestLoopFusion>(); }
211 } // namespace test
212 } // namespace mlir