1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Transforms/RegionUtils.h"
10 #include "mlir/Analysis/TopologicalSortUtils.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/RegionGraphTraits.h"
17 #include "mlir/IR/Value.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/DepthFirstIterator.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallSet.h"
32 void mlir::replaceAllUsesInRegionWith(Value orig
, Value replacement
,
34 for (auto &use
: llvm::make_early_inc_range(orig
.getUses())) {
35 if (region
.isAncestor(use
.getOwner()->getParentRegion()))
40 void mlir::visitUsedValuesDefinedAbove(
41 Region
®ion
, Region
&limit
, function_ref
<void(OpOperand
*)> callback
) {
42 assert(limit
.isAncestor(®ion
) &&
43 "expected isolation limit to be an ancestor of the given region");
45 // Collect proper ancestors of `limit` upfront to avoid traversing the region
46 // tree for every value.
47 SmallPtrSet
<Region
*, 4> properAncestors
;
48 for (auto *reg
= limit
.getParentRegion(); reg
!= nullptr;
49 reg
= reg
->getParentRegion()) {
50 properAncestors
.insert(reg
);
53 region
.walk([callback
, &properAncestors
](Operation
*op
) {
54 for (OpOperand
&operand
: op
->getOpOperands())
55 // Callback on values defined in a proper ancestor of region.
56 if (properAncestors
.count(operand
.get().getParentRegion()))
61 void mlir::visitUsedValuesDefinedAbove(
62 MutableArrayRef
<Region
> regions
, function_ref
<void(OpOperand
*)> callback
) {
63 for (Region
®ion
: regions
)
64 visitUsedValuesDefinedAbove(region
, region
, callback
);
67 void mlir::getUsedValuesDefinedAbove(Region
®ion
, Region
&limit
,
68 SetVector
<Value
> &values
) {
69 visitUsedValuesDefinedAbove(region
, limit
, [&](OpOperand
*operand
) {
70 values
.insert(operand
->get());
74 void mlir::getUsedValuesDefinedAbove(MutableArrayRef
<Region
> regions
,
75 SetVector
<Value
> &values
) {
76 for (Region
®ion
: regions
)
77 getUsedValuesDefinedAbove(region
, region
, values
);
80 //===----------------------------------------------------------------------===//
// Make region isolated from above.
82 //===----------------------------------------------------------------------===//
84 SmallVector
<Value
> mlir::makeRegionIsolatedFromAbove(
85 RewriterBase
&rewriter
, Region
®ion
,
86 llvm::function_ref
<bool(Operation
*)> cloneOperationIntoRegion
) {
88 // Get initial list of values used within region but defined above.
89 llvm::SetVector
<Value
> initialCapturedValues
;
90 mlir::getUsedValuesDefinedAbove(region
, initialCapturedValues
);
92 std::deque
<Value
> worklist(initialCapturedValues
.begin(),
93 initialCapturedValues
.end());
94 llvm::DenseSet
<Value
> visited
;
95 llvm::DenseSet
<Operation
*> visitedOps
;
97 llvm::SetVector
<Value
> finalCapturedValues
;
98 SmallVector
<Operation
*> clonedOperations
;
99 while (!worklist
.empty()) {
100 Value currValue
= worklist
.front();
101 worklist
.pop_front();
102 if (visited
.count(currValue
))
104 visited
.insert(currValue
);
106 Operation
*definingOp
= currValue
.getDefiningOp();
107 if (!definingOp
|| visitedOps
.count(definingOp
)) {
108 finalCapturedValues
.insert(currValue
);
111 visitedOps
.insert(definingOp
);
113 if (!cloneOperationIntoRegion(definingOp
)) {
114 // Defining operation isnt cloned, so add the current value to final
115 // captured values list.
116 finalCapturedValues
.insert(currValue
);
120 // Add all operands of the operation to the worklist and mark the op as to
122 for (Value operand
: definingOp
->getOperands()) {
123 if (visited
.count(operand
))
125 worklist
.push_back(operand
);
127 clonedOperations
.push_back(definingOp
);
130 // The operations to be cloned need to be ordered in topological order
131 // so that they can be cloned into the region without violating use-def
133 mlir::computeTopologicalSorting(clonedOperations
);
135 OpBuilder::InsertionGuard
g(rewriter
);
136 // Collect types of existing block
137 Block
*entryBlock
= ®ion
.front();
138 SmallVector
<Type
> newArgTypes
=
139 llvm::to_vector(entryBlock
->getArgumentTypes());
140 SmallVector
<Location
> newArgLocs
= llvm::to_vector(llvm::map_range(
141 entryBlock
->getArguments(), [](BlockArgument b
) { return b
.getLoc(); }));
143 // Append the types of the captured values.
144 for (auto value
: finalCapturedValues
) {
145 newArgTypes
.push_back(value
.getType());
146 newArgLocs
.push_back(value
.getLoc());
149 // Create a new entry block.
150 Block
*newEntryBlock
=
151 rewriter
.createBlock(®ion
, region
.begin(), newArgTypes
, newArgLocs
);
152 auto newEntryBlockArgs
= newEntryBlock
->getArguments();
154 // Create a mapping between the captured values and the new arguments added.
156 auto replaceIfFn
= [&](OpOperand
&use
) {
157 return use
.getOwner()->getBlock()->getParent() == ®ion
;
159 for (auto [arg
, capturedVal
] :
160 llvm::zip(newEntryBlockArgs
.take_back(finalCapturedValues
.size()),
161 finalCapturedValues
)) {
162 map
.map(capturedVal
, arg
);
163 rewriter
.replaceUsesWithIf(capturedVal
, arg
, replaceIfFn
);
165 rewriter
.setInsertionPointToStart(newEntryBlock
);
166 for (auto *clonedOp
: clonedOperations
) {
167 Operation
*newOp
= rewriter
.clone(*clonedOp
, map
);
168 rewriter
.replaceOpUsesWithIf(clonedOp
, newOp
->getResults(), replaceIfFn
);
170 rewriter
.mergeBlocks(
171 entryBlock
, newEntryBlock
,
172 newEntryBlock
->getArguments().take_front(entryBlock
->getNumArguments()));
173 return llvm::to_vector(finalCapturedValues
);
176 //===----------------------------------------------------------------------===//
177 // Unreachable Block Elimination
178 //===----------------------------------------------------------------------===//
180 /// Erase the unreachable blocks within the provided regions. Returns success
181 /// if any blocks were erased, failure otherwise.
182 // TODO: We could likely merge this with the DCE algorithm below.
183 LogicalResult
mlir::eraseUnreachableBlocks(RewriterBase
&rewriter
,
184 MutableArrayRef
<Region
> regions
) {
185 // Set of blocks found to be reachable within a given region.
186 llvm::df_iterator_default_set
<Block
*, 16> reachable
;
187 // If any blocks were found to be dead.
188 bool erasedDeadBlocks
= false;
190 SmallVector
<Region
*, 1> worklist
;
191 worklist
.reserve(regions
.size());
192 for (Region
®ion
: regions
)
193 worklist
.push_back(®ion
);
194 while (!worklist
.empty()) {
195 Region
*region
= worklist
.pop_back_val();
199 // If this is a single block region, just collect the nested regions.
200 if (std::next(region
->begin()) == region
->end()) {
201 for (Operation
&op
: region
->front())
202 for (Region
®ion
: op
.getRegions())
203 worklist
.push_back(®ion
);
207 // Mark all reachable blocks.
209 for (Block
*block
: depth_first_ext(®ion
->front(), reachable
))
210 (void)block
/* Mark all reachable blocks */;
212 // Collect all of the dead blocks and push the live regions onto the
214 for (Block
&block
: llvm::make_early_inc_range(*region
)) {
215 if (!reachable
.count(&block
)) {
216 block
.dropAllDefinedValueUses();
217 rewriter
.eraseBlock(&block
);
218 erasedDeadBlocks
= true;
222 // Walk any regions within this block.
223 for (Operation
&op
: block
)
224 for (Region
®ion
: op
.getRegions())
225 worklist
.push_back(®ion
);
229 return success(erasedDeadBlocks
);
232 //===----------------------------------------------------------------------===//
233 // Dead Code Elimination
234 //===----------------------------------------------------------------------===//
237 /// Data structure used to track which values have already been proved live.
239 /// Because Operation's can have multiple results, this data structure tracks
240 /// liveness for both Value's and Operation's to avoid having to look through
241 /// all Operation results when analyzing a use.
243 /// This data structure essentially tracks the dataflow lattice.
244 /// The set of values/ops proved live increases monotonically to a fixed-point.
248 bool wasProvenLive(Value value
) {
249 // TODO: For results that are removable, e.g. for region based control flow,
250 // we could allow for these values to be tracked independently.
251 if (OpResult result
= dyn_cast
<OpResult
>(value
))
252 return wasProvenLive(result
.getOwner());
253 return wasProvenLive(cast
<BlockArgument
>(value
));
255 bool wasProvenLive(BlockArgument arg
) { return liveValues
.count(arg
); }
256 void setProvedLive(Value value
) {
257 // TODO: For results that are removable, e.g. for region based control flow,
258 // we could allow for these values to be tracked independently.
259 if (OpResult result
= dyn_cast
<OpResult
>(value
))
260 return setProvedLive(result
.getOwner());
261 setProvedLive(cast
<BlockArgument
>(value
));
263 void setProvedLive(BlockArgument arg
) {
264 changed
|= liveValues
.insert(arg
).second
;
267 /// Operation methods.
268 bool wasProvenLive(Operation
*op
) { return liveOps
.count(op
); }
269 void setProvedLive(Operation
*op
) { changed
|= liveOps
.insert(op
).second
; }
271 /// Methods for tracking if we have reached a fixed-point.
272 void resetChanged() { changed
= false; }
273 bool hasChanged() { return changed
; }
276 bool changed
= false;
277 DenseSet
<Value
> liveValues
;
278 DenseSet
<Operation
*> liveOps
;
282 static bool isUseSpeciallyKnownDead(OpOperand
&use
, LiveMap
&liveMap
) {
283 Operation
*owner
= use
.getOwner();
284 unsigned operandIndex
= use
.getOperandNumber();
285 // This pass generally treats all uses of an op as live if the op itself is
286 // considered live. However, for successor operands to terminators we need a
287 // finer-grained notion where we deduce liveness for operands individually.
288 // The reason for this is easiest to think about in terms of a classical phi
289 // node based SSA IR, where each successor operand is really an operand to a
290 // *separate* phi node, rather than all operands to the branch itself as with
291 // the block argument representation that MLIR uses.
293 // And similarly, because each successor operand is really an operand to a phi
294 // node, rather than to the terminator op itself, a terminator op can't e.g.
295 // "print" the value of a successor operand.
296 if (owner
->hasTrait
<OpTrait::IsTerminator
>()) {
297 if (BranchOpInterface branchInterface
= dyn_cast
<BranchOpInterface
>(owner
))
298 if (auto arg
= branchInterface
.getSuccessorBlockArgument(operandIndex
))
299 return !liveMap
.wasProvenLive(*arg
);
305 static void processValue(Value value
, LiveMap
&liveMap
) {
306 bool provedLive
= llvm::any_of(value
.getUses(), [&](OpOperand
&use
) {
307 if (isUseSpeciallyKnownDead(use
, liveMap
))
309 return liveMap
.wasProvenLive(use
.getOwner());
312 liveMap
.setProvedLive(value
);
315 static void propagateLiveness(Region
®ion
, LiveMap
&liveMap
);
317 static void propagateTerminatorLiveness(Operation
*op
, LiveMap
&liveMap
) {
318 // Terminators are always live.
319 liveMap
.setProvedLive(op
);
321 // Check to see if we can reason about the successor operands and mutate them.
322 BranchOpInterface branchInterface
= dyn_cast
<BranchOpInterface
>(op
);
323 if (!branchInterface
) {
324 for (Block
*successor
: op
->getSuccessors())
325 for (BlockArgument arg
: successor
->getArguments())
326 liveMap
.setProvedLive(arg
);
330 // If we can't reason about the operand to a successor, conservatively mark
332 for (unsigned i
= 0, e
= op
->getNumSuccessors(); i
!= e
; ++i
) {
333 SuccessorOperands successorOperands
=
334 branchInterface
.getSuccessorOperands(i
);
335 for (unsigned opI
= 0, opE
= successorOperands
.getProducedOperandCount();
337 liveMap
.setProvedLive(op
->getSuccessor(i
)->getArgument(opI
));
341 static void propagateLiveness(Operation
*op
, LiveMap
&liveMap
) {
342 // Recurse on any regions the op has.
343 for (Region
®ion
: op
->getRegions())
344 propagateLiveness(region
, liveMap
);
346 // Process terminator operations.
347 if (op
->hasTrait
<OpTrait::IsTerminator
>())
348 return propagateTerminatorLiveness(op
, liveMap
);
350 // Don't reprocess live operations.
351 if (liveMap
.wasProvenLive(op
))
354 // Process the op itself.
355 if (!wouldOpBeTriviallyDead(op
))
356 return liveMap
.setProvedLive(op
);
358 // If the op isn't intrinsically alive, check it's results.
359 for (Value value
: op
->getResults())
360 processValue(value
, liveMap
);
363 static void propagateLiveness(Region
®ion
, LiveMap
&liveMap
) {
367 for (Block
*block
: llvm::post_order(®ion
.front())) {
368 // We process block arguments after the ops in the block, to promote
369 // faster convergence to a fixed point (we try to visit uses before defs).
370 for (Operation
&op
: llvm::reverse(block
->getOperations()))
371 propagateLiveness(&op
, liveMap
);
373 // We currently do not remove entry block arguments, so there is no need to
374 // track their liveness.
375 // TODO: We could track these and enable removing dead operands/arguments
376 // from region control flow operations.
377 if (block
->isEntryBlock())
380 for (Value value
: block
->getArguments()) {
381 if (!liveMap
.wasProvenLive(value
))
382 processValue(value
, liveMap
);
387 static void eraseTerminatorSuccessorOperands(Operation
*terminator
,
389 BranchOpInterface branchOp
= dyn_cast
<BranchOpInterface
>(terminator
);
393 for (unsigned succI
= 0, succE
= terminator
->getNumSuccessors();
394 succI
< succE
; succI
++) {
395 // Iterating successors in reverse is not strictly needed, since we
396 // aren't erasing any successors. But it is slightly more efficient
397 // since it will promote later operands of the terminator being erased
398 // first, reducing the quadratic-ness.
399 unsigned succ
= succE
- succI
- 1;
400 SuccessorOperands succOperands
= branchOp
.getSuccessorOperands(succ
);
401 Block
*successor
= terminator
->getSuccessor(succ
);
403 for (unsigned argI
= 0, argE
= succOperands
.size(); argI
< argE
; ++argI
) {
404 // Iterating args in reverse is needed for correctness, to avoid
405 // shifting later args when earlier args are erased.
406 unsigned arg
= argE
- argI
- 1;
407 if (!liveMap
.wasProvenLive(successor
->getArgument(arg
)))
408 succOperands
.erase(arg
);
413 static LogicalResult
deleteDeadness(RewriterBase
&rewriter
,
414 MutableArrayRef
<Region
> regions
,
416 bool erasedAnything
= false;
417 for (Region
®ion
: regions
) {
420 bool hasSingleBlock
= llvm::hasSingleElement(region
);
422 // Delete every operation that is not live. Graph regions may have cycles
423 // in the use-def graph, so we must explicitly dropAllUses() from each
424 // operation as we erase it. Visiting the operations in post-order
425 // guarantees that in SSA CFG regions value uses are removed before defs,
426 // which makes dropAllUses() a no-op.
427 for (Block
*block
: llvm::post_order(®ion
.front())) {
429 eraseTerminatorSuccessorOperands(block
->getTerminator(), liveMap
);
430 for (Operation
&childOp
:
431 llvm::make_early_inc_range(llvm::reverse(block
->getOperations()))) {
432 if (!liveMap
.wasProvenLive(&childOp
)) {
433 erasedAnything
= true;
434 childOp
.dropAllUses();
435 rewriter
.eraseOp(&childOp
);
437 erasedAnything
|= succeeded(
438 deleteDeadness(rewriter
, childOp
.getRegions(), liveMap
));
442 // Delete block arguments.
443 // The entry block has an unknown contract with their enclosing block, so
445 for (Block
&block
: llvm::drop_begin(region
.getBlocks(), 1)) {
446 block
.eraseArguments(
447 [&](BlockArgument arg
) { return !liveMap
.wasProvenLive(arg
); });
450 return success(erasedAnything
);
453 // This function performs a simple dead code elimination algorithm over the
456 // The overall goal is to prove that Values are dead, which allows deleting ops
457 // and block arguments.
459 // This uses an optimistic algorithm that assumes everything is dead until
460 // proved otherwise, allowing it to delete recursively dead cycles.
462 // This is a simple fixed-point dataflow analysis algorithm on a lattice
463 // {Dead,Alive}. Because liveness flows backward, we generally try to
464 // iterate everything backward to speed up convergence to the fixed-point. This
465 // allows for being able to delete recursively dead cycles of the use-def graph,
466 // including block arguments.
468 // This function returns success if any operations or arguments were deleted,
469 // failure otherwise.
470 LogicalResult
mlir::runRegionDCE(RewriterBase
&rewriter
,
471 MutableArrayRef
<Region
> regions
) {
474 liveMap
.resetChanged();
476 for (Region
®ion
: regions
)
477 propagateLiveness(region
, liveMap
);
478 } while (liveMap
.hasChanged());
480 return deleteDeadness(rewriter
, regions
, liveMap
);
483 //===----------------------------------------------------------------------===//
485 //===----------------------------------------------------------------------===//
487 //===----------------------------------------------------------------------===//
488 // BlockEquivalenceData
491 /// This class contains the information for comparing the equivalencies of two
492 /// blocks. Blocks are considered equivalent if they contain the same operations
493 /// in the same order. The only allowed divergence is for operands that come
494 /// from sources outside of the parent block, i.e. the uses of values produced
495 /// within the block must be equivalent.
499 /// return %arg0, %foo : i32, i32
501 /// return %arg1, %bar : i32, i32
504 /// return %foo, %arg0 : i32, i32
506 /// return %arg1, %bar : i32, i32
507 struct BlockEquivalenceData
{
508 BlockEquivalenceData(Block
*block
);
510 /// Return the order index for the given value that is within the block of
512 unsigned getOrderOf(Value value
) const;
514 /// The block this data refers to.
516 /// A hash value for this block.
517 llvm::hash_code hash
;
518 /// A map of result producing operations to their relative orders within this
519 /// block. The order of an operation is the number of defined values that are
520 /// produced within the block before this operation.
521 DenseMap
<Operation
*, unsigned> opOrderIndex
;
525 BlockEquivalenceData::BlockEquivalenceData(Block
*block
)
526 : block(block
), hash(0) {
527 unsigned orderIt
= block
->getNumArguments();
528 for (Operation
&op
: *block
) {
529 if (unsigned numResults
= op
.getNumResults()) {
530 opOrderIndex
.try_emplace(&op
, orderIt
);
531 orderIt
+= numResults
;
533 auto opHash
= OperationEquivalence::computeHash(
534 &op
, OperationEquivalence::ignoreHashValue
,
535 OperationEquivalence::ignoreHashValue
,
536 OperationEquivalence::IgnoreLocations
);
537 hash
= llvm::hash_combine(hash
, opHash
);
541 unsigned BlockEquivalenceData::getOrderOf(Value value
) const {
542 assert(value
.getParentBlock() == block
&& "expected value of this block");
544 // Arguments use the argument number as the order index.
545 if (BlockArgument arg
= dyn_cast
<BlockArgument
>(value
))
546 return arg
.getArgNumber();
548 // Otherwise, the result order is offset from the parent op's order.
549 OpResult result
= cast
<OpResult
>(value
);
550 auto opOrderIt
= opOrderIndex
.find(result
.getDefiningOp());
551 assert(opOrderIt
!= opOrderIndex
.end() && "expected op to have an order");
552 return opOrderIt
->second
+ result
.getResultNumber();
555 //===----------------------------------------------------------------------===//
559 /// This class represents a cluster of blocks to be merged together.
560 class BlockMergeCluster
{
562 BlockMergeCluster(BlockEquivalenceData
&&leaderData
)
563 : leaderData(std::move(leaderData
)) {}
565 /// Attempt to add the given block to this cluster. Returns success if the
566 /// block was merged, failure otherwise.
567 LogicalResult
addToCluster(BlockEquivalenceData
&blockData
);
569 /// Try to merge all of the blocks within this cluster into the leader block.
570 LogicalResult
merge(RewriterBase
&rewriter
);
573 /// The equivalence data for the leader of the cluster.
574 BlockEquivalenceData leaderData
;
576 /// The set of blocks that can be merged into the leader.
577 llvm::SmallSetVector
<Block
*, 1> blocksToMerge
;
579 /// A set of operand+index pairs that correspond to operands that need to be
580 /// replaced by arguments when the cluster gets merged.
581 std::set
<std::pair
<int, int>> operandsToMerge
;
585 LogicalResult
BlockMergeCluster::addToCluster(BlockEquivalenceData
&blockData
) {
586 if (leaderData
.hash
!= blockData
.hash
)
588 Block
*leaderBlock
= leaderData
.block
, *mergeBlock
= blockData
.block
;
589 if (leaderBlock
->getArgumentTypes() != mergeBlock
->getArgumentTypes())
592 // A set of operands that mismatch between the leader and the new block.
593 SmallVector
<std::pair
<int, int>, 8> mismatchedOperands
;
594 auto lhsIt
= leaderBlock
->begin(), lhsE
= leaderBlock
->end();
595 auto rhsIt
= blockData
.block
->begin(), rhsE
= blockData
.block
->end();
596 for (int opI
= 0; lhsIt
!= lhsE
&& rhsIt
!= rhsE
; ++lhsIt
, ++rhsIt
, ++opI
) {
597 // Check that the operations are equivalent.
598 if (!OperationEquivalence::isEquivalentTo(
599 &*lhsIt
, &*rhsIt
, OperationEquivalence::ignoreValueEquivalence
,
600 /*markEquivalent=*/nullptr,
601 OperationEquivalence::Flags::IgnoreLocations
))
604 // Compare the operands of the two operations. If the operand is within
605 // the block, it must refer to the same operation.
606 auto lhsOperands
= lhsIt
->getOperands(), rhsOperands
= rhsIt
->getOperands();
607 for (int operand
: llvm::seq
<int>(0, lhsIt
->getNumOperands())) {
608 Value lhsOperand
= lhsOperands
[operand
];
609 Value rhsOperand
= rhsOperands
[operand
];
610 if (lhsOperand
== rhsOperand
)
612 // Check that the types of the operands match.
613 if (lhsOperand
.getType() != rhsOperand
.getType())
616 // Check that these uses are both external, or both internal.
617 bool lhsIsInBlock
= lhsOperand
.getParentBlock() == leaderBlock
;
618 bool rhsIsInBlock
= rhsOperand
.getParentBlock() == mergeBlock
;
619 if (lhsIsInBlock
!= rhsIsInBlock
)
621 // Let the operands differ if they are defined in a different block. These
622 // will become new arguments if the blocks get merged.
625 // Check whether the operands aren't the result of an immediate
626 // predecessors terminator. In that case we are not able to use it as a
627 // successor operand when branching to the merged block as it does not
628 // dominate its producing operation.
629 auto isValidSuccessorArg
= [](Block
*block
, Value operand
) {
630 if (operand
.getDefiningOp() !=
631 operand
.getParentBlock()->getTerminator())
633 return !llvm::is_contained(block
->getPredecessors(),
634 operand
.getParentBlock());
637 if (!isValidSuccessorArg(leaderBlock
, lhsOperand
) ||
638 !isValidSuccessorArg(mergeBlock
, rhsOperand
))
641 mismatchedOperands
.emplace_back(opI
, operand
);
645 // Otherwise, these operands must have the same logical order within the
647 if (leaderData
.getOrderOf(lhsOperand
) != blockData
.getOrderOf(rhsOperand
))
651 // If the lhs or rhs has external uses, the blocks cannot be merged as the
652 // merged version of this operation will not be either the lhs or rhs
653 // alone (thus semantically incorrect), but some mix dependending on which
654 // block preceeded this.
655 // TODO allow merging of operations when one block does not dominate the
657 if (rhsIt
->isUsedOutsideOfBlock(mergeBlock
) ||
658 lhsIt
->isUsedOutsideOfBlock(leaderBlock
)) {
662 // Make sure that the block sizes are equivalent.
663 if (lhsIt
!= lhsE
|| rhsIt
!= rhsE
)
666 // If we get here, the blocks are equivalent and can be merged.
667 operandsToMerge
.insert(mismatchedOperands
.begin(), mismatchedOperands
.end());
668 blocksToMerge
.insert(blockData
.block
);
672 /// Returns true if the predecessor terminators of the given block can not have
673 /// their operands updated.
674 static bool ableToUpdatePredOperands(Block
*block
) {
675 for (auto it
= block
->pred_begin(), e
= block
->pred_end(); it
!= e
; ++it
) {
676 if (!isa
<BranchOpInterface
>((*it
)->getTerminator()))
682 /// Prunes the redundant list of new arguments. E.g., if we are passing an
683 /// argument list like [x, y, z, x] this would return [x, y, z] and it would
684 /// update the `block` (to whom the argument are passed to) accordingly. The new
685 /// arguments are passed as arguments at the back of the block, hence we need to
686 /// know how many `numOldArguments` were before, in order to correctly replace
687 /// the new arguments in the block
688 static SmallVector
<SmallVector
<Value
, 8>, 2> pruneRedundantArguments(
689 const SmallVector
<SmallVector
<Value
, 8>, 2> &newArguments
,
690 RewriterBase
&rewriter
, unsigned numOldArguments
, Block
*block
) {
692 SmallVector
<SmallVector
<Value
, 8>, 2> newArgumentsPruned(
693 newArguments
.size(), SmallVector
<Value
, 8>());
695 if (newArguments
.empty())
698 // `newArguments` is a 2D array of size `numLists` x `numArgs`
699 unsigned numLists
= newArguments
.size();
700 unsigned numArgs
= newArguments
[0].size();
702 // Map that for each arg index contains the index that we can use in place of
703 // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
704 // idxToReplacement[3] = 0
705 llvm::DenseMap
<unsigned, unsigned> idxToReplacement
;
707 // This is a useful data structure to track the first appearance of a Value
708 // on a given list of arguments
709 DenseMap
<Value
, unsigned> firstValueToIdx
;
710 for (unsigned j
= 0; j
< numArgs
; ++j
) {
711 Value newArg
= newArguments
[0][j
];
712 firstValueToIdx
.try_emplace(newArg
, j
);
715 // Go through the first list of arguments (list 0).
716 for (unsigned j
= 0; j
< numArgs
; ++j
) {
717 // Look back to see if there are possible redundancies in list 0. Please
718 // note that we are using a map to annotate when an argument was seen first
719 // to avoid a O(N^2) algorithm. This has the drawback that if we have two
721 // list0: [%a, %a, %a]
722 // list1: [%c, %b, %b]
723 // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
724 // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
725 // the number of arguments can be potentially unbounded we cannot afford a
726 // O(N^2) algorithm (to search to all the possible pairs) and we need to
727 // accept the trade-off.
728 unsigned k
= firstValueToIdx
[newArguments
[0][j
]];
732 bool shouldReplaceJ
= true;
733 unsigned replacement
= k
;
734 // If a possible redundancy is found, then scan the other lists: we
735 // can prune the arguments if and only if they are redundant in every
737 for (unsigned i
= 1; i
< numLists
; ++i
)
739 shouldReplaceJ
&& (newArguments
[i
][k
] == newArguments
[i
][j
]);
740 // Save the replacement.
742 idxToReplacement
[j
] = replacement
;
745 // Populate the pruned argument list.
746 for (unsigned i
= 0; i
< numLists
; ++i
)
747 for (unsigned j
= 0; j
< numArgs
; ++j
)
748 if (!idxToReplacement
.contains(j
))
749 newArgumentsPruned
[i
].push_back(newArguments
[i
][j
]);
751 // Replace the block's redundant arguments.
752 SmallVector
<unsigned> toErase
;
753 for (auto [idx
, arg
] : llvm::enumerate(block
->getArguments())) {
754 if (idxToReplacement
.contains(idx
)) {
755 Value oldArg
= block
->getArgument(numOldArguments
+ idx
);
757 block
->getArgument(numOldArguments
+ idxToReplacement
[idx
]);
758 rewriter
.replaceAllUsesWith(oldArg
, newArg
);
759 toErase
.push_back(numOldArguments
+ idx
);
763 // Erase the block's redundant arguments.
764 for (unsigned idxToErase
: llvm::reverse(toErase
))
765 block
->eraseArgument(idxToErase
);
766 return newArgumentsPruned
;
769 LogicalResult
BlockMergeCluster::merge(RewriterBase
&rewriter
) {
770 // Don't consider clusters that don't have blocks to merge.
771 if (blocksToMerge
.empty())
774 Block
*leaderBlock
= leaderData
.block
;
775 if (!operandsToMerge
.empty()) {
776 // If the cluster has operands to merge, verify that the predecessor
777 // terminators of each of the blocks can have their successor operands
779 // TODO: We could try and sub-partition this cluster if only some blocks
780 // cause the mismatch.
781 if (!ableToUpdatePredOperands(leaderBlock
) ||
782 !llvm::all_of(blocksToMerge
, ableToUpdatePredOperands
))
785 // Collect the iterators for each of the blocks to merge. We will walk all
786 // of the iterators at once to avoid operand index invalidation.
787 SmallVector
<Block::iterator
, 2> blockIterators
;
788 blockIterators
.reserve(blocksToMerge
.size() + 1);
789 blockIterators
.push_back(leaderBlock
->begin());
790 for (Block
*mergeBlock
: blocksToMerge
)
791 blockIterators
.push_back(mergeBlock
->begin());
793 // Update each of the predecessor terminators with the new arguments.
794 SmallVector
<SmallVector
<Value
, 8>, 2> newArguments(
795 1 + blocksToMerge
.size(),
796 SmallVector
<Value
, 8>(operandsToMerge
.size()));
797 unsigned curOpIndex
= 0;
798 unsigned numOldArguments
= leaderBlock
->getNumArguments();
799 for (const auto &it
: llvm::enumerate(operandsToMerge
)) {
800 unsigned nextOpOffset
= it
.value().first
- curOpIndex
;
801 curOpIndex
= it
.value().first
;
803 // Process the operand for each of the block iterators.
804 for (unsigned i
= 0, e
= blockIterators
.size(); i
!= e
; ++i
) {
805 Block::iterator
&blockIter
= blockIterators
[i
];
806 std::advance(blockIter
, nextOpOffset
);
807 auto &operand
= blockIter
->getOpOperand(it
.value().second
);
808 newArguments
[i
][it
.index()] = operand
.get();
810 // Update the operand and insert an argument if this is the leader.
812 Value operandVal
= operand
.get();
813 operand
.set(leaderBlock
->addArgument(operandVal
.getType(),
814 operandVal
.getLoc()));
819 // Prune redundant arguments and update the leader block argument list
820 newArguments
= pruneRedundantArguments(newArguments
, rewriter
,
821 numOldArguments
, leaderBlock
);
823 // Update the predecessors for each of the blocks.
824 auto updatePredecessors
= [&](Block
*block
, unsigned clusterIndex
) {
825 for (auto predIt
= block
->pred_begin(), predE
= block
->pred_end();
826 predIt
!= predE
; ++predIt
) {
827 auto branch
= cast
<BranchOpInterface
>((*predIt
)->getTerminator());
828 unsigned succIndex
= predIt
.getSuccessorIndex();
829 branch
.getSuccessorOperands(succIndex
).append(
830 newArguments
[clusterIndex
]);
833 updatePredecessors(leaderBlock
, /*clusterIndex=*/0);
834 for (unsigned i
= 0, e
= blocksToMerge
.size(); i
!= e
; ++i
)
835 updatePredecessors(blocksToMerge
[i
], /*clusterIndex=*/i
+ 1);
838 // Replace all uses of the merged blocks with the leader and erase them.
839 for (Block
*block
: blocksToMerge
) {
840 block
->replaceAllUsesWith(leaderBlock
);
841 rewriter
.eraseBlock(block
);
846 /// Identify identical blocks within the given region and merge them, inserting
847 /// new block arguments as necessary. Returns success if any blocks were merged,
848 /// failure otherwise.
849 static LogicalResult
mergeIdenticalBlocks(RewriterBase
&rewriter
,
851 if (region
.empty() || llvm::hasSingleElement(region
))
854 // Identify sets of blocks, other than the entry block, that branch to the
855 // same successors. We will use these groups to create clusters of equivalent
857 DenseMap
<SuccessorRange
, SmallVector
<Block
*, 1>> matchingSuccessors
;
858 for (Block
&block
: llvm::drop_begin(region
, 1))
859 matchingSuccessors
[block
.getSuccessors()].push_back(&block
);
861 bool mergedAnyBlocks
= false;
862 for (ArrayRef
<Block
*> blocks
: llvm::make_second_range(matchingSuccessors
)) {
863 if (blocks
.size() == 1)
866 SmallVector
<BlockMergeCluster
, 1> clusters
;
867 for (Block
*block
: blocks
) {
868 BlockEquivalenceData
data(block
);
870 // Don't allow merging if this block has any regions.
871 // TODO: Add support for regions if necessary.
872 bool hasNonEmptyRegion
= llvm::any_of(*block
, [](Operation
&op
) {
873 return llvm::any_of(op
.getRegions(),
874 [](Region
®ion
) { return !region
.empty(); });
876 if (hasNonEmptyRegion
)
879 // Don't allow merging if this block's arguments are used outside of the
881 bool argHasExternalUsers
= llvm::any_of(
882 block
->getArguments(), [block
](mlir::BlockArgument
&arg
) {
883 return arg
.isUsedOutsideOfBlock(block
);
885 if (argHasExternalUsers
)
888 // Try to add this block to an existing cluster.
889 bool addedToCluster
= false;
890 for (auto &cluster
: clusters
)
891 if ((addedToCluster
= succeeded(cluster
.addToCluster(data
))))
894 clusters
.emplace_back(std::move(data
));
896 for (auto &cluster
: clusters
)
897 mergedAnyBlocks
|= succeeded(cluster
.merge(rewriter
));
900 return success(mergedAnyBlocks
);
903 /// Identify identical blocks within the given regions and merge them, inserting
904 /// new block arguments as necessary.
905 static LogicalResult
mergeIdenticalBlocks(RewriterBase
&rewriter
,
906 MutableArrayRef
<Region
> regions
) {
907 llvm::SmallSetVector
<Region
*, 1> worklist
;
908 for (auto ®ion
: regions
)
909 worklist
.insert(®ion
);
910 bool anyChanged
= false;
911 while (!worklist
.empty()) {
912 Region
*region
= worklist
.pop_back_val();
913 if (succeeded(mergeIdenticalBlocks(rewriter
, *region
))) {
914 worklist
.insert(region
);
918 // Add any nested regions to the worklist.
919 for (Block
&block
: *region
)
920 for (auto &op
: block
)
921 for (auto &nestedRegion
: op
.getRegions())
922 worklist
.insert(&nestedRegion
);
925 return success(anyChanged
);
928 /// If a block's argument is always the same across different invocations, then
929 /// drop the argument and use the value directly inside the block
930 static LogicalResult
dropRedundantArguments(RewriterBase
&rewriter
,
932 SmallVector
<size_t> argsToErase
;
934 // Go through the arguments of the block.
935 for (auto [argIdx
, blockOperand
] : llvm::enumerate(block
.getArguments())) {
939 // Go through the block predecessor and flag if they pass to the block
940 // different values for the same argument.
941 for (Block::pred_iterator predIt
= block
.pred_begin(),
942 predE
= block
.pred_end();
943 predIt
!= predE
; ++predIt
) {
944 auto branch
= dyn_cast
<BranchOpInterface
>((*predIt
)->getTerminator());
949 unsigned succIndex
= predIt
.getSuccessorIndex();
950 SuccessorOperands succOperands
= branch
.getSuccessorOperands(succIndex
);
951 auto branchOperands
= succOperands
.getForwardedOperands();
953 commonValue
= branchOperands
[argIdx
];
956 if (branchOperands
[argIdx
] != commonValue
) {
962 // If they are passing the same value, drop the argument.
963 if (commonValue
&& sameArg
) {
964 argsToErase
.push_back(argIdx
);
966 // Remove the argument from the block.
967 rewriter
.replaceAllUsesWith(blockOperand
, commonValue
);
971 // Remove the arguments.
972 for (size_t argIdx
: llvm::reverse(argsToErase
)) {
973 block
.eraseArgument(argIdx
);
975 // Remove the argument from the branch ops.
976 for (auto predIt
= block
.pred_begin(), predE
= block
.pred_end();
977 predIt
!= predE
; ++predIt
) {
978 auto branch
= cast
<BranchOpInterface
>((*predIt
)->getTerminator());
979 unsigned succIndex
= predIt
.getSuccessorIndex();
980 SuccessorOperands succOperands
= branch
.getSuccessorOperands(succIndex
);
981 succOperands
.erase(argIdx
);
984 return success(!argsToErase
.empty());
987 /// This optimization drops redundant argument to blocks. I.e., if a given
988 /// argument to a block receives the same value from each of the block
989 /// predecessors, we can remove the argument from the block and use directly the
990 /// original value. This is a simple example:
992 /// %cond = llvm.call @rand() : () -> i1
993 /// %val0 = llvm.mlir.constant(1 : i64) : i64
994 /// %val1 = llvm.mlir.constant(2 : i64) : i64
995 /// %val2 = llvm.mlir.constant(3 : i64) : i64
996 /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
999 /// ^bb1(%arg0 : i64, %arg1 : i64):
1000 /// llvm.call @foo(%arg0, %arg1)
1002 /// The previous IR can be rewritten as:
1003 /// %cond = llvm.call @rand() : () -> i1
1004 /// %val0 = llvm.mlir.constant(1 : i64) : i64
1005 /// %val1 = llvm.mlir.constant(2 : i64) : i64
1006 /// %val2 = llvm.mlir.constant(3 : i64) : i64
1007 /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
1009 /// ^bb1(%arg0 : i64):
1010 /// llvm.call @foo(%val0, %arg0)
1012 static LogicalResult
dropRedundantArguments(RewriterBase
&rewriter
,
1013 MutableArrayRef
<Region
> regions
) {
1014 llvm::SmallSetVector
<Region
*, 1> worklist
;
1015 for (Region
®ion
: regions
)
1016 worklist
.insert(®ion
);
1017 bool anyChanged
= false;
1018 while (!worklist
.empty()) {
1019 Region
*region
= worklist
.pop_back_val();
1021 // Add any nested regions to the worklist.
1022 for (Block
&block
: *region
) {
1024 succeeded(dropRedundantArguments(rewriter
, block
)) || anyChanged
;
1026 for (Operation
&op
: block
)
1027 for (Region
&nestedRegion
: op
.getRegions())
1028 worklist
.insert(&nestedRegion
);
1031 return success(anyChanged
);
1034 //===----------------------------------------------------------------------===//
1035 // Region Simplification
1036 //===----------------------------------------------------------------------===//
1038 /// Run a set of structural simplifications over the given regions. This
1039 /// includes transformations like unreachable block elimination, dead argument
1040 /// elimination, as well as some other DCE. This function returns success if any
1041 /// of the regions were simplified, failure otherwise.
1042 LogicalResult
mlir::simplifyRegions(RewriterBase
&rewriter
,
1043 MutableArrayRef
<Region
> regions
,
1045 bool eliminatedBlocks
= succeeded(eraseUnreachableBlocks(rewriter
, regions
));
1046 bool eliminatedOpsOrArgs
= succeeded(runRegionDCE(rewriter
, regions
));
1047 bool mergedIdenticalBlocks
= false;
1048 bool droppedRedundantArguments
= false;
1050 mergedIdenticalBlocks
= succeeded(mergeIdenticalBlocks(rewriter
, regions
));
1051 droppedRedundantArguments
=
1052 succeeded(dropRedundantArguments(rewriter
, regions
));
1054 return success(eliminatedBlocks
|| eliminatedOpsOrArgs
||
1055 mergedIdenticalBlocks
|| droppedRedundantArguments
);