//===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the core LICM algorithm.
//
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"

#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
#include <queue>

#define DEBUG_TYPE "licm"

using namespace mlir;

27 /// Checks whether the given op can be hoisted by checking that
28 /// - the op and none of its contained operations depend on values inside of the
29 /// loop (by means of calling definedOutside).
30 /// - the op has no side-effects.
31 static bool canBeHoisted(Operation
*op
,
32 function_ref
<bool(OpOperand
&)> condition
) {
33 // Do not move terminators.
34 if (op
->hasTrait
<OpTrait::IsTerminator
>())
37 // Walk the nested operations and check that all used values are either
38 // defined outside of the loop or in a nested region, but not at the level of
40 auto walkFn
= [&](Operation
*child
) {
41 for (OpOperand
&operand
: child
->getOpOperands()) {
42 // Ignore values defined in a nested region.
43 if (op
->isAncestor(operand
.get().getParentRegion()->getParentOp()))
45 if (!condition(operand
))
46 return WalkResult::interrupt();
48 return WalkResult::advance();
50 return !op
->walk(walkFn
).wasInterrupted();
53 static bool canBeHoisted(Operation
*op
,
54 function_ref
<bool(Value
)> definedOutside
) {
56 op
, [&](OpOperand
&operand
) { return definedOutside(operand
.get()); });
59 size_t mlir::moveLoopInvariantCode(
60 ArrayRef
<Region
*> regions
,
61 function_ref
<bool(Value
, Region
*)> isDefinedOutsideRegion
,
62 function_ref
<bool(Operation
*, Region
*)> shouldMoveOutOfRegion
,
63 function_ref
<void(Operation
*, Region
*)> moveOutOfRegion
) {
66 for (Region
*region
: regions
) {
67 LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
68 << *region
->getParentOp() << "\n");
70 std::queue
<Operation
*> worklist
;
71 // Add top-level operations in the loop body to the worklist.
72 for (Operation
&op
: region
->getOps())
75 auto definedOutside
= [&](Value value
) {
76 return isDefinedOutsideRegion(value
, region
);
79 while (!worklist
.empty()) {
80 Operation
*op
= worklist
.front();
82 // Skip ops that have already been moved. Check if the op can be hoisted.
83 if (op
->getParentRegion() != region
)
86 LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op
<< "\n");
87 if (!shouldMoveOutOfRegion(op
, region
) ||
88 !canBeHoisted(op
, definedOutside
))
91 LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
<< "\n");
92 moveOutOfRegion(op
, region
);
95 // Since the op has been moved, we need to check its users within the
96 // top-level of the loop body.
97 for (Operation
*user
: op
->getUsers())
98 if (user
->getParentRegion() == region
)
106 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike
) {
107 return moveLoopInvariantCode(
108 loopLike
.getLoopRegions(),
109 [&](Value value
, Region
*) {
110 return loopLike
.isDefinedOutsideOfLoop(value
);
112 [&](Operation
*op
, Region
*) {
113 return isMemoryEffectFree(op
) && isSpeculatable(op
);
115 [&](Operation
*op
, Region
*) { loopLike
.moveOutOfLoop(op
); });
119 /// Helper data structure that keeps track of equivalent/disjoint subset ops.
120 class MatchingSubsets
{
122 /// Insert a subset op.
123 void insert(SubsetOpInterface op
, bool collectHoistableOps
= true) {
124 allSubsetOps
.push_back(op
);
125 if (!collectHoistableOps
)
127 if (auto extractionOp
=
128 dyn_cast
<SubsetExtractionOpInterface
>(op
.getOperation()))
129 insertExtractionOp(extractionOp
);
130 if (auto insertionOp
=
131 dyn_cast
<SubsetInsertionOpInterface
>(op
.getOperation()))
132 insertInsertionOp(insertionOp
);
135 /// Return a range of matching extraction-insertion subset ops. If there is no
136 /// matching extraction/insertion op, the respective value is empty. Ops are
137 /// skipped if there are other subset ops that are not guaranteed to operate
138 /// on disjoint subsets.
139 auto getHoistableSubsetOps() {
140 return llvm::make_filter_range(
141 llvm::zip(extractions
, insertions
), [&](auto pair
) {
142 auto [extractionOp
, insertionOp
] = pair
;
143 // Hoist only if the extracted and inserted values have the same type.
144 if (extractionOp
&& insertionOp
&&
145 extractionOp
->getResult(0).getType() !=
146 insertionOp
.getSourceOperand().get().getType())
148 // Hoist only if there are no conflicting subset ops.
149 return allDisjoint(extractionOp
, insertionOp
);
153 /// Populate subset ops starting from the given region iter_arg. Return
154 /// "failure" if non-subset ops are found along the path to the loop yielding
155 /// op or if there is no single path to the tied yielded operand. If
156 /// `collectHoistableOps` is set to "false", subset ops are gathered
157 /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
158 LogicalResult
populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike
,
159 BlockArgument iterArg
,
160 bool collectHoistableOps
= true);
163 /// Helper function for equivalence of tensor values. Since only insertion
164 /// subset ops (that are also destination style ops) are followed when
165 /// traversing the SSA use-def chain, all tensor values are equivalent.
166 static bool isEquivalent(Value v1
, Value v2
) { return true; }
168 /// Return "true" if the subsets of the given extraction and insertion ops
169 /// are operating disjoint from the subsets that all other known subset ops
170 /// are operating on.
171 bool allDisjoint(SubsetExtractionOpInterface extractionOp
,
172 SubsetInsertionOpInterface insertionOp
) const {
173 for (SubsetOpInterface other
: allSubsetOps
) {
174 if (other
== extractionOp
|| other
== insertionOp
)
177 !other
.operatesOnDisjointSubset(extractionOp
, isEquivalent
))
180 !other
.operatesOnDisjointSubset(insertionOp
, isEquivalent
))
186 /// Insert a subset extraction op. If the subset is equivalent to an existing
187 /// subset insertion op, pair them up. (If there is already a paired up subset
188 /// extraction op, overwrite the subset extraction op.)
189 void insertExtractionOp(SubsetExtractionOpInterface extractionOp
) {
190 for (auto it
: llvm::enumerate(insertions
)) {
193 auto other
= cast
<SubsetOpInterface
>(it
.value().getOperation());
194 if (other
.operatesOnEquivalentSubset(extractionOp
, isEquivalent
)) {
195 extractions
[it
.index()] = extractionOp
;
199 // There is no known equivalent insertion op. Create a new entry.
200 extractions
.push_back(extractionOp
);
201 insertions
.push_back({});
204 /// Insert a subset insertion op. If the subset is equivalent to an existing
205 /// subset extraction op, pair them up. (If there is already a paired up
206 /// subset insertion op, overwrite the subset insertion op.)
207 void insertInsertionOp(SubsetInsertionOpInterface insertionOp
) {
208 for (auto it
: llvm::enumerate(extractions
)) {
211 auto other
= cast
<SubsetOpInterface
>(it
.value().getOperation());
212 if (other
.operatesOnEquivalentSubset(insertionOp
, isEquivalent
)) {
213 insertions
[it
.index()] = insertionOp
;
217 // There is no known equivalent extraction op. Create a new entry.
218 extractions
.push_back({});
219 insertions
.push_back(insertionOp
);
222 SmallVector
<SubsetExtractionOpInterface
> extractions
;
223 SmallVector
<SubsetInsertionOpInterface
> insertions
;
224 SmallVector
<SubsetOpInterface
> allSubsetOps
;
228 /// If the given value has a single use by an op that is a terminator, return
229 /// that use. Otherwise, return nullptr.
230 static OpOperand
*getSingleTerminatorUse(Value value
) {
231 if (!value
.hasOneUse())
233 OpOperand
&use
= *value
.getUses().begin();
234 if (use
.getOwner()->hasTrait
<OpTrait::IsTerminator
>())
240 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike
,
241 BlockArgument iterArg
,
242 bool collectHoistableOps
) {
243 assert(iterArg
.getOwner()->getParentOp() == loopLike
&& "invalid iter_arg");
244 Value value
= iterArg
;
246 // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
247 // use-def chain starting from the region iter_arg are subset extraction or
248 // subset insertion ops. The chain must terminate at the corresponding yield
249 // operand (e.g., no swapping of iter_args).
250 OpOperand
*yieldedOperand
= nullptr;
251 // Iterate until the single use of the current SSA value is a terminator,
252 // which is expected to be the yielding operation of the loop.
253 while (!(yieldedOperand
= getSingleTerminatorUse(value
))) {
254 Value nextValue
= {};
256 for (OpOperand
&use
: value
.getUses()) {
257 if (auto nestedLoop
= dyn_cast
<LoopLikeOpInterface
>(use
.getOwner())) {
258 // Subset ops in nested loops are collected to check if there are only
259 // disjoint subset ops, but such subset ops are not subject to hoisting.
260 // To hoist subset ops from nested loops, the hoisting transformation
261 // should be run on the nested loop.
262 auto nestedIterArg
= nestedLoop
.getTiedLoopRegionIterArg(&use
);
265 // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
266 // use-def chain starting at `nestedIterArg` and terminating in the
267 // tied, yielding operand.
268 if (failed(populateSubsetOpsAtIterArg(nestedLoop
, nestedIterArg
,
269 /*collectHoistableOps=*/false)))
271 nextValue
= nestedLoop
.getTiedLoopResult(&use
);
275 auto subsetOp
= dyn_cast
<SubsetOpInterface
>(use
.getOwner());
280 if (auto insertionOp
=
281 dyn_cast
<SubsetInsertionOpInterface
>(use
.getOwner())) {
282 // Current implementation expects that the insertionOp implement
283 // the destinationStyleOpInterface as well. Abort if that tha is not
285 if (!isa
<DestinationStyleOpInterface
>(use
.getOwner())) {
289 // The value must be used as a destination. (In case of a source, the
290 // entire tensor would be read, which would prevent any hoisting.)
291 if (&use
!= &insertionOp
.getDestinationOperand())
293 // There must be a single use-def chain from the region iter_arg to the
294 // terminator. I.e., only one insertion op. Branches are not supported.
297 nextValue
= insertionOp
.getUpdatedDestination();
301 // Nothing can be hoisted if the chain does not continue with loop yielding
302 // op or a subset insertion op.
308 // Hoist only if the SSA use-def chain ends in the yielding terminator of the
309 // loop and the yielded value is the `idx`-th operand. (I.e., there is no
311 if (loopLike
.getTiedLoopYieldedValue(iterArg
) != yieldedOperand
)
317 /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
318 /// loop-like op and index into loop-invariant subset locations. Return the
319 /// newly created loop op (that has extra iter_args) or the original loop op if
320 /// nothing was hoisted.
321 static LoopLikeOpInterface
hoistSubsetAtIterArg(RewriterBase
&rewriter
,
322 LoopLikeOpInterface loopLike
,
323 BlockArgument iterArg
) {
324 assert(iterArg
.getOwner()->getParentOp() == loopLike
&& "invalid iter_arg");
325 auto it
= llvm::find(loopLike
.getRegionIterArgs(), iterArg
);
326 int64_t iterArgIdx
= std::distance(loopLike
.getRegionIterArgs().begin(), it
);
327 MatchingSubsets subsets
;
328 if (failed(subsets
.populateSubsetOpsAtIterArg(loopLike
, iterArg
)))
331 // Hoist all matching extraction-insertion pairs one-by-one.
332 for (auto it
: subsets
.getHoistableSubsetOps()) {
333 auto extractionOp
= std::get
<0>(it
);
334 auto insertionOp
= std::get
<1>(it
);
336 // Ops cannot be hoisted if they depend on loop-variant values.
338 if (!canBeHoisted(extractionOp
, [&](OpOperand
&operand
) {
339 return loopLike
.isDefinedOutsideOfLoop(operand
.get()) ||
340 &operand
== &extractionOp
.getSourceOperand();
345 if (!canBeHoisted(insertionOp
, [&](OpOperand
&operand
) {
346 return loopLike
.isDefinedOutsideOfLoop(operand
.get()) ||
347 &operand
== &insertionOp
.getSourceOperand() ||
348 &operand
== &insertionOp
.getDestinationOperand();
353 // Only hoist extraction-insertion pairs for now. Standalone extractions/
354 // insertions that are loop-invariant could be hoisted, but there may be
355 // easier ways to canonicalize the IR.
356 if (extractionOp
&& insertionOp
) {
357 // Create a new loop with an additional iter_arg.
358 NewYieldValuesFn newYieldValuesFn
=
359 [&](OpBuilder
&b
, Location loc
,
360 ArrayRef
<BlockArgument
> innerNewBBArgs
) -> SmallVector
<Value
> {
361 return {insertionOp
.getSourceOperand().get()};
363 FailureOr
<LoopLikeOpInterface
> newLoop
=
364 loopLike
.replaceWithAdditionalYields(
365 rewriter
, extractionOp
.getResult(),
366 /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn
);
371 // Hoist the extraction/insertion ops.
372 iterArg
= loopLike
.getRegionIterArgs()[iterArgIdx
];
373 OpResult loopResult
= loopLike
.getTiedLoopResult(iterArg
);
374 OpResult newLoopResult
= loopLike
.getLoopResults()->back();
375 rewriter
.moveOpBefore(extractionOp
, loopLike
);
376 rewriter
.moveOpAfter(insertionOp
, loopLike
);
377 rewriter
.replaceAllUsesWith(insertionOp
.getUpdatedDestination(),
378 insertionOp
.getDestinationOperand().get());
379 extractionOp
.getSourceOperand().set(
380 loopLike
.getTiedLoopInit(iterArg
)->get());
381 rewriter
.replaceAllUsesWith(loopResult
,
382 insertionOp
.getUpdatedDestination());
383 insertionOp
.getSourceOperand().set(newLoopResult
);
384 insertionOp
.getDestinationOperand().set(loopResult
);
392 mlir::hoistLoopInvariantSubsets(RewriterBase
&rewriter
,
393 LoopLikeOpInterface loopLike
) {
394 // Note: As subset ops are getting hoisted, the number of region iter_args
395 // increases. This can enable further hoisting opportunities on the new
398 i
< static_cast<int64_t>(loopLike
.getRegionIterArgs().size()); ++i
) {
399 loopLike
= hoistSubsetAtIterArg(rewriter
, loopLike
,
400 loopLike
.getRegionIterArgs()[i
]);