1 //===- TestLoopFusion.cpp - Test loop fusion ------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements a pass to test various loop fusion utility functions.
11 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

#define DEBUG_TYPE "test-loop-fusion"

using namespace mlir;
using namespace mlir::affine;
28 : public PassWrapper
<TestLoopFusion
, OperationPass
<func::FuncOp
>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopFusion
)
31 StringRef
getArgument() const final
{ return "test-loop-fusion"; }
32 StringRef
getDescription() const final
{
33 return "Tests loop fusion utility functions.";
35 void runOnOperation() override
;
37 TestLoopFusion() = default;
38 TestLoopFusion(const TestLoopFusion
&pass
) : PassWrapper(pass
){};
40 Option
<bool> clTestDependenceCheck
{
41 *this, "test-loop-fusion-dependence-check",
42 llvm::cl::desc("Enable testing of loop fusion dependence check"),
43 llvm::cl::init(false)};
45 Option
<bool> clTestSliceComputation
{
46 *this, "test-loop-fusion-slice-computation",
47 llvm::cl::desc("Enable testing of loop fusion slice computation"),
48 llvm::cl::init(false)};
50 Option
<bool> clTestLoopFusionTransformation
{
51 *this, "test-loop-fusion-transformation",
52 llvm::cl::desc("Enable testing of loop fusion transformation"),
53 llvm::cl::init(false)};
58 // Run fusion dependence check on 'loops[i]' and 'loops[j]' at loop depths
59 // in range ['loopDepth' + 1, 'maxLoopDepth'].
60 // Emits a remark on 'loops[i]' if a fusion-preventing dependence exists.
61 // Returns false as IR is not transformed.
62 static bool testDependenceCheck(AffineForOp srcForOp
, AffineForOp dstForOp
,
63 unsigned i
, unsigned j
, unsigned loopDepth
,
64 unsigned maxLoopDepth
) {
65 affine::ComputationSliceState sliceUnion
;
66 for (unsigned d
= loopDepth
+ 1; d
<= maxLoopDepth
; ++d
) {
68 affine::canFuseLoops(srcForOp
, dstForOp
, d
, &sliceUnion
);
69 if (result
.value
== FusionResult::FailBlockDependence
) {
70 srcForOp
->emitRemark("block-level dependence preventing"
71 " fusion of loop nest ")
72 << i
<< " into loop nest " << j
<< " at depth " << loopDepth
;
78 // Returns the index of 'op' in its block.
79 static unsigned getBlockIndex(Operation
&op
) {
81 for (auto &opX
: *op
.getBlock()) {
89 // Returns a string representation of 'sliceUnion'.
91 getSliceStr(const affine::ComputationSliceState
&sliceUnion
) {
93 llvm::raw_string_ostream
os(result
);
94 // Slice insertion point format [loop-depth, operation-block-index]
95 unsigned ipd
= getNestingDepth(&*sliceUnion
.insertPoint
);
96 unsigned ipb
= getBlockIndex(*sliceUnion
.insertPoint
);
97 os
<< "insert point: (" << std::to_string(ipd
) << ", " << std::to_string(ipb
)
99 assert(sliceUnion
.lbs
.size() == sliceUnion
.ubs
.size());
100 os
<< " loop bounds: ";
101 for (unsigned k
= 0, e
= sliceUnion
.lbs
.size(); k
< e
; ++k
) {
103 sliceUnion
.lbs
[k
].print(os
);
105 sliceUnion
.ubs
[k
].print(os
);
111 /// Computes fusion slice union on 'loops[i]' and 'loops[j]' at loop depths
112 /// in range ['loopDepth' + 1, 'maxLoopDepth'].
113 /// Emits a string representation of the slice union as a remark on 'loops[j]'
114 /// and marks this as incorrect slice if the slice is invalid. Returns false as
115 /// IR is not transformed.
116 static bool testSliceComputation(AffineForOp forOpA
, AffineForOp forOpB
,
117 unsigned i
, unsigned j
, unsigned loopDepth
,
118 unsigned maxLoopDepth
) {
119 for (unsigned d
= loopDepth
+ 1; d
<= maxLoopDepth
; ++d
) {
120 affine::ComputationSliceState sliceUnion
;
121 FusionResult result
= affine::canFuseLoops(forOpA
, forOpB
, d
, &sliceUnion
);
122 if (result
.value
== FusionResult::Success
) {
123 forOpB
->emitRemark("slice (")
124 << " src loop: " << i
<< ", dst loop: " << j
<< ", depth: " << d
125 << " : " << getSliceStr(sliceUnion
) << ")";
126 } else if (result
.value
== FusionResult::FailIncorrectSlice
) {
127 forOpB
->emitRemark("Incorrect slice (")
128 << " src loop: " << i
<< ", dst loop: " << j
<< ", depth: " << d
129 << " : " << getSliceStr(sliceUnion
) << ")";
135 // Attempts to fuse 'forOpA' into 'forOpB' at loop depths in range
136 // ['loopDepth' + 1, 'maxLoopDepth'].
137 // Returns true if loops were successfully fused, false otherwise.
138 static bool testLoopFusionTransformation(AffineForOp forOpA
, AffineForOp forOpB
,
139 unsigned i
, unsigned j
,
141 unsigned maxLoopDepth
) {
142 for (unsigned d
= loopDepth
+ 1; d
<= maxLoopDepth
; ++d
) {
143 affine::ComputationSliceState sliceUnion
;
144 FusionResult result
= affine::canFuseLoops(forOpA
, forOpB
, d
, &sliceUnion
);
145 if (result
.value
== FusionResult::Success
) {
146 affine::fuseLoops(forOpA
, forOpB
, sliceUnion
);
147 // Note: 'forOpA' is removed to simplify test output. A proper loop
148 // fusion pass should check the data dependence graph and run memref
149 // region analysis to ensure removing 'forOpA' is safe.
157 using LoopFunc
= function_ref
<bool(AffineForOp
, AffineForOp
, unsigned, unsigned,
158 unsigned, unsigned)>;
160 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
161 // If 'return_on_change' is true, returns on first invocation of 'fn' which
163 static bool iterateLoops(ArrayRef
<SmallVector
<AffineForOp
, 2>> depthToLoops
,
164 LoopFunc fn
, bool returnOnChange
= false) {
165 bool changed
= false;
166 for (unsigned loopDepth
= 0, end
= depthToLoops
.size(); loopDepth
< end
;
168 auto &loops
= depthToLoops
[loopDepth
];
169 unsigned numLoops
= loops
.size();
170 for (unsigned j
= 0; j
< numLoops
; ++j
) {
171 for (unsigned k
= 0; k
< numLoops
; ++k
) {
174 fn(loops
[j
], loops
[k
], j
, k
, loopDepth
, depthToLoops
.size());
175 if (changed
&& returnOnChange
)
183 void TestLoopFusion::runOnOperation() {
184 std::vector
<SmallVector
<AffineForOp
, 2>> depthToLoops
;
185 if (clTestLoopFusionTransformation
) {
186 // Run loop fusion until a fixed point is reached.
188 depthToLoops
.clear();
189 // Gather all AffineForOps by loop depth.
190 gatherLoops(getOperation(), depthToLoops
);
192 // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
193 } while (iterateLoops(depthToLoops
, testLoopFusionTransformation
,
194 /*returnOnChange=*/true));
198 // Gather all AffineForOps by loop depth.
199 gatherLoops(getOperation(), depthToLoops
);
201 // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
202 if (clTestDependenceCheck
)
203 iterateLoops(depthToLoops
, testDependenceCheck
);
204 if (clTestSliceComputation
)
205 iterateLoops(depthToLoops
, testSliceComputation
);
210 void registerTestLoopFusion() { PassRegistration
<TestLoopFusion
>(); }