[TableGen] Add TreePatternNode::children and use it in for loops (NFC) (#119877)
[llvm-project.git] / mlir / lib / Transforms / Utils / LoopInvariantCodeMotionUtils.cpp
blob7460746934a78c751dbc37ca544cad5af93ee543
1 //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the core LICM algorithm.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Interfaces/LoopLikeInterface.h"
18 #include "mlir/Interfaces/SideEffectInterfaces.h"
19 #include "mlir/Interfaces/SubsetOpInterface.h"
20 #include "llvm/Support/Debug.h"
21 #include <queue>
23 #define DEBUG_TYPE "licm"
25 using namespace mlir;
27 /// Checks whether the given op can be hoisted by checking that
28 /// - the op and none of its contained operations depend on values inside of the
29 /// loop (by means of calling definedOutside).
30 /// - the op has no side-effects.
31 static bool canBeHoisted(Operation *op,
32 function_ref<bool(OpOperand &)> condition) {
33 // Do not move terminators.
34 if (op->hasTrait<OpTrait::IsTerminator>())
35 return false;
37 // Walk the nested operations and check that all used values are either
38 // defined outside of the loop or in a nested region, but not at the level of
39 // the loop body.
40 auto walkFn = [&](Operation *child) {
41 for (OpOperand &operand : child->getOpOperands()) {
42 // Ignore values defined in a nested region.
43 if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
44 continue;
45 if (!condition(operand))
46 return WalkResult::interrupt();
48 return WalkResult::advance();
50 return !op->walk(walkFn).wasInterrupted();
53 static bool canBeHoisted(Operation *op,
54 function_ref<bool(Value)> definedOutside) {
55 return canBeHoisted(
56 op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
59 size_t mlir::moveLoopInvariantCode(
60 ArrayRef<Region *> regions,
61 function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
62 function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63 function_ref<void(Operation *, Region *)> moveOutOfRegion) {
64 size_t numMoved = 0;
66 for (Region *region : regions) {
67 LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
68 << *region->getParentOp() << "\n");
70 std::queue<Operation *> worklist;
71 // Add top-level operations in the loop body to the worklist.
72 for (Operation &op : region->getOps())
73 worklist.push(&op);
75 auto definedOutside = [&](Value value) {
76 return isDefinedOutsideRegion(value, region);
79 while (!worklist.empty()) {
80 Operation *op = worklist.front();
81 worklist.pop();
82 // Skip ops that have already been moved. Check if the op can be hoisted.
83 if (op->getParentRegion() != region)
84 continue;
86 LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
87 if (!shouldMoveOutOfRegion(op, region) ||
88 !canBeHoisted(op, definedOutside))
89 continue;
91 LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
92 moveOutOfRegion(op, region);
93 ++numMoved;
95 // Since the op has been moved, we need to check its users within the
96 // top-level of the loop body.
97 for (Operation *user : op->getUsers())
98 if (user->getParentRegion() == region)
99 worklist.push(user);
103 return numMoved;
106 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
107 return moveLoopInvariantCode(
108 loopLike.getLoopRegions(),
109 [&](Value value, Region *) {
110 return loopLike.isDefinedOutsideOfLoop(value);
112 [&](Operation *op, Region *) {
113 return isMemoryEffectFree(op) && isSpeculatable(op);
115 [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
118 namespace {
119 /// Helper data structure that keeps track of equivalent/disjoint subset ops.
120 class MatchingSubsets {
121 public:
122 /// Insert a subset op.
123 void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
124 allSubsetOps.push_back(op);
125 if (!collectHoistableOps)
126 return;
127 if (auto extractionOp =
128 dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
129 insertExtractionOp(extractionOp);
130 if (auto insertionOp =
131 dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
132 insertInsertionOp(insertionOp);
135 /// Return a range of matching extraction-insertion subset ops. If there is no
136 /// matching extraction/insertion op, the respective value is empty. Ops are
137 /// skipped if there are other subset ops that are not guaranteed to operate
138 /// on disjoint subsets.
139 auto getHoistableSubsetOps() {
140 return llvm::make_filter_range(
141 llvm::zip(extractions, insertions), [&](auto pair) {
142 auto [extractionOp, insertionOp] = pair;
143 // Hoist only if the extracted and inserted values have the same type.
144 if (extractionOp && insertionOp &&
145 extractionOp->getResult(0).getType() !=
146 insertionOp.getSourceOperand().get().getType())
147 return false;
148 // Hoist only if there are no conflicting subset ops.
149 return allDisjoint(extractionOp, insertionOp);
153 /// Populate subset ops starting from the given region iter_arg. Return
154 /// "failure" if non-subset ops are found along the path to the loop yielding
155 /// op or if there is no single path to the tied yielded operand. If
156 /// `collectHoistableOps` is set to "false", subset ops are gathered
157 /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
158 LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
159 BlockArgument iterArg,
160 bool collectHoistableOps = true);
162 private:
163 /// Helper function for equivalence of tensor values. Since only insertion
164 /// subset ops (that are also destination style ops) are followed when
165 /// traversing the SSA use-def chain, all tensor values are equivalent.
166 static bool isEquivalent(Value v1, Value v2) { return true; }
168 /// Return "true" if the subsets of the given extraction and insertion ops
169 /// are operating disjoint from the subsets that all other known subset ops
170 /// are operating on.
171 bool allDisjoint(SubsetExtractionOpInterface extractionOp,
172 SubsetInsertionOpInterface insertionOp) const {
173 for (SubsetOpInterface other : allSubsetOps) {
174 if (other == extractionOp || other == insertionOp)
175 continue;
176 if (extractionOp &&
177 !other.operatesOnDisjointSubset(extractionOp, isEquivalent))
178 return false;
179 if (insertionOp &&
180 !other.operatesOnDisjointSubset(insertionOp, isEquivalent))
181 return false;
183 return true;
186 /// Insert a subset extraction op. If the subset is equivalent to an existing
187 /// subset insertion op, pair them up. (If there is already a paired up subset
188 /// extraction op, overwrite the subset extraction op.)
189 void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
190 for (auto it : llvm::enumerate(insertions)) {
191 if (!it.value())
192 continue;
193 auto other = cast<SubsetOpInterface>(it.value().getOperation());
194 if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
195 extractions[it.index()] = extractionOp;
196 return;
199 // There is no known equivalent insertion op. Create a new entry.
200 extractions.push_back(extractionOp);
201 insertions.push_back({});
204 /// Insert a subset insertion op. If the subset is equivalent to an existing
205 /// subset extraction op, pair them up. (If there is already a paired up
206 /// subset insertion op, overwrite the subset insertion op.)
207 void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
208 for (auto it : llvm::enumerate(extractions)) {
209 if (!it.value())
210 continue;
211 auto other = cast<SubsetOpInterface>(it.value().getOperation());
212 if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
213 insertions[it.index()] = insertionOp;
214 return;
217 // There is no known equivalent extraction op. Create a new entry.
218 extractions.push_back({});
219 insertions.push_back(insertionOp);
222 SmallVector<SubsetExtractionOpInterface> extractions;
223 SmallVector<SubsetInsertionOpInterface> insertions;
224 SmallVector<SubsetOpInterface> allSubsetOps;
226 } // namespace
228 /// If the given value has a single use by an op that is a terminator, return
229 /// that use. Otherwise, return nullptr.
230 static OpOperand *getSingleTerminatorUse(Value value) {
231 if (!value.hasOneUse())
232 return nullptr;
233 OpOperand &use = *value.getUses().begin();
234 if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
235 return &use;
236 return nullptr;
239 LogicalResult
240 MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
241 BlockArgument iterArg,
242 bool collectHoistableOps) {
243 assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
244 Value value = iterArg;
246 // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
247 // use-def chain starting from the region iter_arg are subset extraction or
248 // subset insertion ops. The chain must terminate at the corresponding yield
249 // operand (e.g., no swapping of iter_args).
250 OpOperand *yieldedOperand = nullptr;
251 // Iterate until the single use of the current SSA value is a terminator,
252 // which is expected to be the yielding operation of the loop.
253 while (!(yieldedOperand = getSingleTerminatorUse(value))) {
254 Value nextValue = {};
256 for (OpOperand &use : value.getUses()) {
257 if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
258 // Subset ops in nested loops are collected to check if there are only
259 // disjoint subset ops, but such subset ops are not subject to hoisting.
260 // To hoist subset ops from nested loops, the hoisting transformation
261 // should be run on the nested loop.
262 auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
263 if (!nestedIterArg)
264 return failure();
265 // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
266 // use-def chain starting at `nestedIterArg` and terminating in the
267 // tied, yielding operand.
268 if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
269 /*collectHoistableOps=*/false)))
270 return failure();
271 nextValue = nestedLoop.getTiedLoopResult(&use);
272 continue;
275 auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
276 if (!subsetOp)
277 return failure();
278 insert(subsetOp);
280 if (auto insertionOp =
281 dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
282 // Current implementation expects that the insertionOp implement
283 // the destinationStyleOpInterface as well. Abort if that tha is not
284 // the case
285 if (!isa<DestinationStyleOpInterface>(use.getOwner())) {
286 return failure();
289 // The value must be used as a destination. (In case of a source, the
290 // entire tensor would be read, which would prevent any hoisting.)
291 if (&use != &insertionOp.getDestinationOperand())
292 return failure();
293 // There must be a single use-def chain from the region iter_arg to the
294 // terminator. I.e., only one insertion op. Branches are not supported.
295 if (nextValue)
296 return failure();
297 nextValue = insertionOp.getUpdatedDestination();
301 // Nothing can be hoisted if the chain does not continue with loop yielding
302 // op or a subset insertion op.
303 if (!nextValue)
304 return failure();
305 value = nextValue;
308 // Hoist only if the SSA use-def chain ends in the yielding terminator of the
309 // loop and the yielded value is the `idx`-th operand. (I.e., there is no
310 // swapping yield.)
311 if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
312 return failure();
314 return success();
317 /// Hoist all subset ops that operate on the idx-th region iter_arg of the given
318 /// loop-like op and index into loop-invariant subset locations. Return the
319 /// newly created loop op (that has extra iter_args) or the original loop op if
320 /// nothing was hoisted.
321 static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
322 LoopLikeOpInterface loopLike,
323 BlockArgument iterArg) {
324 assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
325 auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
326 int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
327 MatchingSubsets subsets;
328 if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
329 return loopLike;
331 // Hoist all matching extraction-insertion pairs one-by-one.
332 for (auto it : subsets.getHoistableSubsetOps()) {
333 auto extractionOp = std::get<0>(it);
334 auto insertionOp = std::get<1>(it);
336 // Ops cannot be hoisted if they depend on loop-variant values.
337 if (extractionOp) {
338 if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
339 return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
340 &operand == &extractionOp.getSourceOperand();
342 extractionOp = {};
344 if (insertionOp) {
345 if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
346 return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
347 &operand == &insertionOp.getSourceOperand() ||
348 &operand == &insertionOp.getDestinationOperand();
350 insertionOp = {};
353 // Only hoist extraction-insertion pairs for now. Standalone extractions/
354 // insertions that are loop-invariant could be hoisted, but there may be
355 // easier ways to canonicalize the IR.
356 if (extractionOp && insertionOp) {
357 // Create a new loop with an additional iter_arg.
358 NewYieldValuesFn newYieldValuesFn =
359 [&](OpBuilder &b, Location loc,
360 ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
361 return {insertionOp.getSourceOperand().get()};
363 FailureOr<LoopLikeOpInterface> newLoop =
364 loopLike.replaceWithAdditionalYields(
365 rewriter, extractionOp.getResult(),
366 /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
367 if (failed(newLoop))
368 return loopLike;
369 loopLike = *newLoop;
371 // Hoist the extraction/insertion ops.
372 iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
373 OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
374 OpResult newLoopResult = loopLike.getLoopResults()->back();
375 rewriter.moveOpBefore(extractionOp, loopLike);
376 rewriter.moveOpAfter(insertionOp, loopLike);
377 rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
378 insertionOp.getDestinationOperand().get());
379 extractionOp.getSourceOperand().set(
380 loopLike.getTiedLoopInit(iterArg)->get());
381 rewriter.replaceAllUsesWith(loopResult,
382 insertionOp.getUpdatedDestination());
383 insertionOp.getSourceOperand().set(newLoopResult);
384 insertionOp.getDestinationOperand().set(loopResult);
388 return loopLike;
391 LoopLikeOpInterface
392 mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
393 LoopLikeOpInterface loopLike) {
394 // Note: As subset ops are getting hoisted, the number of region iter_args
395 // increases. This can enable further hoisting opportunities on the new
396 // iter_args.
397 for (int64_t i = 0;
398 i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
399 loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
400 loopLike.getRegionIterArgs()[i]);
402 return loopLike;