[SLP] Make getSameOpcode support different instructions if they have same semantics...
[llvm-project.git] / mlir / lib / Transforms / Utils / RegionUtils.cpp
blobe55ef6eb66b9c7d6751dec90c7e8407c5581c159
1 //===- RegionUtils.cpp - Region-related transformation utilities ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Transforms/RegionUtils.h"
10 #include "mlir/Analysis/TopologicalSortUtils.h"
11 #include "mlir/IR/Block.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "mlir/IR/Operation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/RegionGraphTraits.h"
17 #include "mlir/IR/Value.h"
18 #include "mlir/Interfaces/ControlFlowInterfaces.h"
19 #include "mlir/Interfaces/SideEffectInterfaces.h"
20 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/DepthFirstIterator.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallSet.h"
27 #include <deque>
28 #include <iterator>
30 using namespace mlir;
32 void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
33 Region &region) {
34 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
35 if (region.isAncestor(use.getOwner()->getParentRegion()))
36 use.set(replacement);
40 void mlir::visitUsedValuesDefinedAbove(
41 Region &region, Region &limit, function_ref<void(OpOperand *)> callback) {
42 assert(limit.isAncestor(&region) &&
43 "expected isolation limit to be an ancestor of the given region");
45 // Collect proper ancestors of `limit` upfront to avoid traversing the region
46 // tree for every value.
47 SmallPtrSet<Region *, 4> properAncestors;
48 for (auto *reg = limit.getParentRegion(); reg != nullptr;
49 reg = reg->getParentRegion()) {
50 properAncestors.insert(reg);
53 region.walk([callback, &properAncestors](Operation *op) {
54 for (OpOperand &operand : op->getOpOperands())
55 // Callback on values defined in a proper ancestor of region.
56 if (properAncestors.count(operand.get().getParentRegion()))
57 callback(&operand);
58 });
61 void mlir::visitUsedValuesDefinedAbove(
62 MutableArrayRef<Region> regions, function_ref<void(OpOperand *)> callback) {
63 for (Region &region : regions)
64 visitUsedValuesDefinedAbove(region, region, callback);
67 void mlir::getUsedValuesDefinedAbove(Region &region, Region &limit,
68 SetVector<Value> &values) {
69 visitUsedValuesDefinedAbove(region, limit, [&](OpOperand *operand) {
70 values.insert(operand->get());
71 });
74 void mlir::getUsedValuesDefinedAbove(MutableArrayRef<Region> regions,
75 SetVector<Value> &values) {
76 for (Region &region : regions)
77 getUsedValuesDefinedAbove(region, region, values);
80 //===----------------------------------------------------------------------===//
81 // Make block isolated from above.
82 //===----------------------------------------------------------------------===//
84 SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
85 RewriterBase &rewriter, Region &region,
86 llvm::function_ref<bool(Operation *)> cloneOperationIntoRegion) {
88 // Get initial list of values used within region but defined above.
89 llvm::SetVector<Value> initialCapturedValues;
90 mlir::getUsedValuesDefinedAbove(region, initialCapturedValues);
92 std::deque<Value> worklist(initialCapturedValues.begin(),
93 initialCapturedValues.end());
94 llvm::DenseSet<Value> visited;
95 llvm::DenseSet<Operation *> visitedOps;
97 llvm::SetVector<Value> finalCapturedValues;
98 SmallVector<Operation *> clonedOperations;
99 while (!worklist.empty()) {
100 Value currValue = worklist.front();
101 worklist.pop_front();
102 if (visited.count(currValue))
103 continue;
104 visited.insert(currValue);
106 Operation *definingOp = currValue.getDefiningOp();
107 if (!definingOp || visitedOps.count(definingOp)) {
108 finalCapturedValues.insert(currValue);
109 continue;
111 visitedOps.insert(definingOp);
113 if (!cloneOperationIntoRegion(definingOp)) {
114 // Defining operation isnt cloned, so add the current value to final
115 // captured values list.
116 finalCapturedValues.insert(currValue);
117 continue;
120 // Add all operands of the operation to the worklist and mark the op as to
121 // be cloned.
122 for (Value operand : definingOp->getOperands()) {
123 if (visited.count(operand))
124 continue;
125 worklist.push_back(operand);
127 clonedOperations.push_back(definingOp);
130 // The operations to be cloned need to be ordered in topological order
131 // so that they can be cloned into the region without violating use-def
132 // chains.
133 mlir::computeTopologicalSorting(clonedOperations);
135 OpBuilder::InsertionGuard g(rewriter);
136 // Collect types of existing block
137 Block *entryBlock = &region.front();
138 SmallVector<Type> newArgTypes =
139 llvm::to_vector(entryBlock->getArgumentTypes());
140 SmallVector<Location> newArgLocs = llvm::to_vector(llvm::map_range(
141 entryBlock->getArguments(), [](BlockArgument b) { return b.getLoc(); }));
143 // Append the types of the captured values.
144 for (auto value : finalCapturedValues) {
145 newArgTypes.push_back(value.getType());
146 newArgLocs.push_back(value.getLoc());
149 // Create a new entry block.
150 Block *newEntryBlock =
151 rewriter.createBlock(&region, region.begin(), newArgTypes, newArgLocs);
152 auto newEntryBlockArgs = newEntryBlock->getArguments();
154 // Create a mapping between the captured values and the new arguments added.
155 IRMapping map;
156 auto replaceIfFn = [&](OpOperand &use) {
157 return use.getOwner()->getBlock()->getParent() == &region;
159 for (auto [arg, capturedVal] :
160 llvm::zip(newEntryBlockArgs.take_back(finalCapturedValues.size()),
161 finalCapturedValues)) {
162 map.map(capturedVal, arg);
163 rewriter.replaceUsesWithIf(capturedVal, arg, replaceIfFn);
165 rewriter.setInsertionPointToStart(newEntryBlock);
166 for (auto *clonedOp : clonedOperations) {
167 Operation *newOp = rewriter.clone(*clonedOp, map);
168 rewriter.replaceOpUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
170 rewriter.mergeBlocks(
171 entryBlock, newEntryBlock,
172 newEntryBlock->getArguments().take_front(entryBlock->getNumArguments()));
173 return llvm::to_vector(finalCapturedValues);
176 //===----------------------------------------------------------------------===//
177 // Unreachable Block Elimination
178 //===----------------------------------------------------------------------===//
180 /// Erase the unreachable blocks within the provided regions. Returns success
181 /// if any blocks were erased, failure otherwise.
182 // TODO: We could likely merge this with the DCE algorithm below.
183 LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
184 MutableArrayRef<Region> regions) {
185 // Set of blocks found to be reachable within a given region.
186 llvm::df_iterator_default_set<Block *, 16> reachable;
187 // If any blocks were found to be dead.
188 bool erasedDeadBlocks = false;
190 SmallVector<Region *, 1> worklist;
191 worklist.reserve(regions.size());
192 for (Region &region : regions)
193 worklist.push_back(&region);
194 while (!worklist.empty()) {
195 Region *region = worklist.pop_back_val();
196 if (region->empty())
197 continue;
199 // If this is a single block region, just collect the nested regions.
200 if (std::next(region->begin()) == region->end()) {
201 for (Operation &op : region->front())
202 for (Region &region : op.getRegions())
203 worklist.push_back(&region);
204 continue;
207 // Mark all reachable blocks.
208 reachable.clear();
209 for (Block *block : depth_first_ext(&region->front(), reachable))
210 (void)block /* Mark all reachable blocks */;
212 // Collect all of the dead blocks and push the live regions onto the
213 // worklist.
214 for (Block &block : llvm::make_early_inc_range(*region)) {
215 if (!reachable.count(&block)) {
216 block.dropAllDefinedValueUses();
217 rewriter.eraseBlock(&block);
218 erasedDeadBlocks = true;
219 continue;
222 // Walk any regions within this block.
223 for (Operation &op : block)
224 for (Region &region : op.getRegions())
225 worklist.push_back(&region);
229 return success(erasedDeadBlocks);
232 //===----------------------------------------------------------------------===//
233 // Dead Code Elimination
234 //===----------------------------------------------------------------------===//
236 namespace {
237 /// Data structure used to track which values have already been proved live.
239 /// Because Operation's can have multiple results, this data structure tracks
240 /// liveness for both Value's and Operation's to avoid having to look through
241 /// all Operation results when analyzing a use.
243 /// This data structure essentially tracks the dataflow lattice.
244 /// The set of values/ops proved live increases monotonically to a fixed-point.
245 class LiveMap {
246 public:
247 /// Value methods.
248 bool wasProvenLive(Value value) {
249 // TODO: For results that are removable, e.g. for region based control flow,
250 // we could allow for these values to be tracked independently.
251 if (OpResult result = dyn_cast<OpResult>(value))
252 return wasProvenLive(result.getOwner());
253 return wasProvenLive(cast<BlockArgument>(value));
255 bool wasProvenLive(BlockArgument arg) { return liveValues.count(arg); }
256 void setProvedLive(Value value) {
257 // TODO: For results that are removable, e.g. for region based control flow,
258 // we could allow for these values to be tracked independently.
259 if (OpResult result = dyn_cast<OpResult>(value))
260 return setProvedLive(result.getOwner());
261 setProvedLive(cast<BlockArgument>(value));
263 void setProvedLive(BlockArgument arg) {
264 changed |= liveValues.insert(arg).second;
267 /// Operation methods.
268 bool wasProvenLive(Operation *op) { return liveOps.count(op); }
269 void setProvedLive(Operation *op) { changed |= liveOps.insert(op).second; }
271 /// Methods for tracking if we have reached a fixed-point.
272 void resetChanged() { changed = false; }
273 bool hasChanged() { return changed; }
275 private:
276 bool changed = false;
277 DenseSet<Value> liveValues;
278 DenseSet<Operation *> liveOps;
280 } // namespace
282 static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
283 Operation *owner = use.getOwner();
284 unsigned operandIndex = use.getOperandNumber();
285 // This pass generally treats all uses of an op as live if the op itself is
286 // considered live. However, for successor operands to terminators we need a
287 // finer-grained notion where we deduce liveness for operands individually.
288 // The reason for this is easiest to think about in terms of a classical phi
289 // node based SSA IR, where each successor operand is really an operand to a
290 // *separate* phi node, rather than all operands to the branch itself as with
291 // the block argument representation that MLIR uses.
293 // And similarly, because each successor operand is really an operand to a phi
294 // node, rather than to the terminator op itself, a terminator op can't e.g.
295 // "print" the value of a successor operand.
296 if (owner->hasTrait<OpTrait::IsTerminator>()) {
297 if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
298 if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
299 return !liveMap.wasProvenLive(*arg);
300 return false;
302 return false;
305 static void processValue(Value value, LiveMap &liveMap) {
306 bool provedLive = llvm::any_of(value.getUses(), [&](OpOperand &use) {
307 if (isUseSpeciallyKnownDead(use, liveMap))
308 return false;
309 return liveMap.wasProvenLive(use.getOwner());
311 if (provedLive)
312 liveMap.setProvedLive(value);
315 static void propagateLiveness(Region &region, LiveMap &liveMap);
317 static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
318 // Terminators are always live.
319 liveMap.setProvedLive(op);
321 // Check to see if we can reason about the successor operands and mutate them.
322 BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
323 if (!branchInterface) {
324 for (Block *successor : op->getSuccessors())
325 for (BlockArgument arg : successor->getArguments())
326 liveMap.setProvedLive(arg);
327 return;
330 // If we can't reason about the operand to a successor, conservatively mark
331 // it as live.
332 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
333 SuccessorOperands successorOperands =
334 branchInterface.getSuccessorOperands(i);
335 for (unsigned opI = 0, opE = successorOperands.getProducedOperandCount();
336 opI != opE; ++opI)
337 liveMap.setProvedLive(op->getSuccessor(i)->getArgument(opI));
341 static void propagateLiveness(Operation *op, LiveMap &liveMap) {
342 // Recurse on any regions the op has.
343 for (Region &region : op->getRegions())
344 propagateLiveness(region, liveMap);
346 // Process terminator operations.
347 if (op->hasTrait<OpTrait::IsTerminator>())
348 return propagateTerminatorLiveness(op, liveMap);
350 // Don't reprocess live operations.
351 if (liveMap.wasProvenLive(op))
352 return;
354 // Process the op itself.
355 if (!wouldOpBeTriviallyDead(op))
356 return liveMap.setProvedLive(op);
358 // If the op isn't intrinsically alive, check it's results.
359 for (Value value : op->getResults())
360 processValue(value, liveMap);
363 static void propagateLiveness(Region &region, LiveMap &liveMap) {
364 if (region.empty())
365 return;
367 for (Block *block : llvm::post_order(&region.front())) {
368 // We process block arguments after the ops in the block, to promote
369 // faster convergence to a fixed point (we try to visit uses before defs).
370 for (Operation &op : llvm::reverse(block->getOperations()))
371 propagateLiveness(&op, liveMap);
373 // We currently do not remove entry block arguments, so there is no need to
374 // track their liveness.
375 // TODO: We could track these and enable removing dead operands/arguments
376 // from region control flow operations.
377 if (block->isEntryBlock())
378 continue;
380 for (Value value : block->getArguments()) {
381 if (!liveMap.wasProvenLive(value))
382 processValue(value, liveMap);
387 static void eraseTerminatorSuccessorOperands(Operation *terminator,
388 LiveMap &liveMap) {
389 BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
390 if (!branchOp)
391 return;
393 for (unsigned succI = 0, succE = terminator->getNumSuccessors();
394 succI < succE; succI++) {
395 // Iterating successors in reverse is not strictly needed, since we
396 // aren't erasing any successors. But it is slightly more efficient
397 // since it will promote later operands of the terminator being erased
398 // first, reducing the quadratic-ness.
399 unsigned succ = succE - succI - 1;
400 SuccessorOperands succOperands = branchOp.getSuccessorOperands(succ);
401 Block *successor = terminator->getSuccessor(succ);
403 for (unsigned argI = 0, argE = succOperands.size(); argI < argE; ++argI) {
404 // Iterating args in reverse is needed for correctness, to avoid
405 // shifting later args when earlier args are erased.
406 unsigned arg = argE - argI - 1;
407 if (!liveMap.wasProvenLive(successor->getArgument(arg)))
408 succOperands.erase(arg);
413 static LogicalResult deleteDeadness(RewriterBase &rewriter,
414 MutableArrayRef<Region> regions,
415 LiveMap &liveMap) {
416 bool erasedAnything = false;
417 for (Region &region : regions) {
418 if (region.empty())
419 continue;
420 bool hasSingleBlock = llvm::hasSingleElement(region);
422 // Delete every operation that is not live. Graph regions may have cycles
423 // in the use-def graph, so we must explicitly dropAllUses() from each
424 // operation as we erase it. Visiting the operations in post-order
425 // guarantees that in SSA CFG regions value uses are removed before defs,
426 // which makes dropAllUses() a no-op.
427 for (Block *block : llvm::post_order(&region.front())) {
428 if (!hasSingleBlock)
429 eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
430 for (Operation &childOp :
431 llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
432 if (!liveMap.wasProvenLive(&childOp)) {
433 erasedAnything = true;
434 childOp.dropAllUses();
435 rewriter.eraseOp(&childOp);
436 } else {
437 erasedAnything |= succeeded(
438 deleteDeadness(rewriter, childOp.getRegions(), liveMap));
442 // Delete block arguments.
443 // The entry block has an unknown contract with their enclosing block, so
444 // skip it.
445 for (Block &block : llvm::drop_begin(region.getBlocks(), 1)) {
446 block.eraseArguments(
447 [&](BlockArgument arg) { return !liveMap.wasProvenLive(arg); });
450 return success(erasedAnything);
453 // This function performs a simple dead code elimination algorithm over the
454 // given regions.
456 // The overall goal is to prove that Values are dead, which allows deleting ops
457 // and block arguments.
459 // This uses an optimistic algorithm that assumes everything is dead until
460 // proved otherwise, allowing it to delete recursively dead cycles.
462 // This is a simple fixed-point dataflow analysis algorithm on a lattice
463 // {Dead,Alive}. Because liveness flows backward, we generally try to
464 // iterate everything backward to speed up convergence to the fixed-point. This
465 // allows for being able to delete recursively dead cycles of the use-def graph,
466 // including block arguments.
468 // This function returns success if any operations or arguments were deleted,
469 // failure otherwise.
470 LogicalResult mlir::runRegionDCE(RewriterBase &rewriter,
471 MutableArrayRef<Region> regions) {
472 LiveMap liveMap;
473 do {
474 liveMap.resetChanged();
476 for (Region &region : regions)
477 propagateLiveness(region, liveMap);
478 } while (liveMap.hasChanged());
480 return deleteDeadness(rewriter, regions, liveMap);
483 //===----------------------------------------------------------------------===//
484 // Block Merging
485 //===----------------------------------------------------------------------===//
487 //===----------------------------------------------------------------------===//
488 // BlockEquivalenceData
490 namespace {
491 /// This class contains the information for comparing the equivalencies of two
492 /// blocks. Blocks are considered equivalent if they contain the same operations
493 /// in the same order. The only allowed divergence is for operands that come
494 /// from sources outside of the parent block, i.e. the uses of values produced
495 /// within the block must be equivalent.
496 /// e.g.,
497 /// Equivalent:
498 /// ^bb1(%arg0: i32)
499 /// return %arg0, %foo : i32, i32
500 /// ^bb2(%arg1: i32)
501 /// return %arg1, %bar : i32, i32
502 /// Not Equivalent:
503 /// ^bb1(%arg0: i32)
504 /// return %foo, %arg0 : i32, i32
505 /// ^bb2(%arg1: i32)
506 /// return %arg1, %bar : i32, i32
507 struct BlockEquivalenceData {
508 BlockEquivalenceData(Block *block);
510 /// Return the order index for the given value that is within the block of
511 /// this data.
512 unsigned getOrderOf(Value value) const;
514 /// The block this data refers to.
515 Block *block;
516 /// A hash value for this block.
517 llvm::hash_code hash;
518 /// A map of result producing operations to their relative orders within this
519 /// block. The order of an operation is the number of defined values that are
520 /// produced within the block before this operation.
521 DenseMap<Operation *, unsigned> opOrderIndex;
523 } // namespace
525 BlockEquivalenceData::BlockEquivalenceData(Block *block)
526 : block(block), hash(0) {
527 unsigned orderIt = block->getNumArguments();
528 for (Operation &op : *block) {
529 if (unsigned numResults = op.getNumResults()) {
530 opOrderIndex.try_emplace(&op, orderIt);
531 orderIt += numResults;
533 auto opHash = OperationEquivalence::computeHash(
534 &op, OperationEquivalence::ignoreHashValue,
535 OperationEquivalence::ignoreHashValue,
536 OperationEquivalence::IgnoreLocations);
537 hash = llvm::hash_combine(hash, opHash);
541 unsigned BlockEquivalenceData::getOrderOf(Value value) const {
542 assert(value.getParentBlock() == block && "expected value of this block");
544 // Arguments use the argument number as the order index.
545 if (BlockArgument arg = dyn_cast<BlockArgument>(value))
546 return arg.getArgNumber();
548 // Otherwise, the result order is offset from the parent op's order.
549 OpResult result = cast<OpResult>(value);
550 auto opOrderIt = opOrderIndex.find(result.getDefiningOp());
551 assert(opOrderIt != opOrderIndex.end() && "expected op to have an order");
552 return opOrderIt->second + result.getResultNumber();
555 //===----------------------------------------------------------------------===//
556 // BlockMergeCluster
558 namespace {
559 /// This class represents a cluster of blocks to be merged together.
560 class BlockMergeCluster {
561 public:
562 BlockMergeCluster(BlockEquivalenceData &&leaderData)
563 : leaderData(std::move(leaderData)) {}
565 /// Attempt to add the given block to this cluster. Returns success if the
566 /// block was merged, failure otherwise.
567 LogicalResult addToCluster(BlockEquivalenceData &blockData);
569 /// Try to merge all of the blocks within this cluster into the leader block.
570 LogicalResult merge(RewriterBase &rewriter);
572 private:
573 /// The equivalence data for the leader of the cluster.
574 BlockEquivalenceData leaderData;
576 /// The set of blocks that can be merged into the leader.
577 llvm::SmallSetVector<Block *, 1> blocksToMerge;
579 /// A set of operand+index pairs that correspond to operands that need to be
580 /// replaced by arguments when the cluster gets merged.
581 std::set<std::pair<int, int>> operandsToMerge;
583 } // namespace
585 LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
586 if (leaderData.hash != blockData.hash)
587 return failure();
588 Block *leaderBlock = leaderData.block, *mergeBlock = blockData.block;
589 if (leaderBlock->getArgumentTypes() != mergeBlock->getArgumentTypes())
590 return failure();
592 // A set of operands that mismatch between the leader and the new block.
593 SmallVector<std::pair<int, int>, 8> mismatchedOperands;
594 auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
595 auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
596 for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
597 // Check that the operations are equivalent.
598 if (!OperationEquivalence::isEquivalentTo(
599 &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
600 /*markEquivalent=*/nullptr,
601 OperationEquivalence::Flags::IgnoreLocations))
602 return failure();
604 // Compare the operands of the two operations. If the operand is within
605 // the block, it must refer to the same operation.
606 auto lhsOperands = lhsIt->getOperands(), rhsOperands = rhsIt->getOperands();
607 for (int operand : llvm::seq<int>(0, lhsIt->getNumOperands())) {
608 Value lhsOperand = lhsOperands[operand];
609 Value rhsOperand = rhsOperands[operand];
610 if (lhsOperand == rhsOperand)
611 continue;
612 // Check that the types of the operands match.
613 if (lhsOperand.getType() != rhsOperand.getType())
614 return failure();
616 // Check that these uses are both external, or both internal.
617 bool lhsIsInBlock = lhsOperand.getParentBlock() == leaderBlock;
618 bool rhsIsInBlock = rhsOperand.getParentBlock() == mergeBlock;
619 if (lhsIsInBlock != rhsIsInBlock)
620 return failure();
621 // Let the operands differ if they are defined in a different block. These
622 // will become new arguments if the blocks get merged.
623 if (!lhsIsInBlock) {
625 // Check whether the operands aren't the result of an immediate
626 // predecessors terminator. In that case we are not able to use it as a
627 // successor operand when branching to the merged block as it does not
628 // dominate its producing operation.
629 auto isValidSuccessorArg = [](Block *block, Value operand) {
630 if (operand.getDefiningOp() !=
631 operand.getParentBlock()->getTerminator())
632 return true;
633 return !llvm::is_contained(block->getPredecessors(),
634 operand.getParentBlock());
637 if (!isValidSuccessorArg(leaderBlock, lhsOperand) ||
638 !isValidSuccessorArg(mergeBlock, rhsOperand))
639 return failure();
641 mismatchedOperands.emplace_back(opI, operand);
642 continue;
645 // Otherwise, these operands must have the same logical order within the
646 // parent block.
647 if (leaderData.getOrderOf(lhsOperand) != blockData.getOrderOf(rhsOperand))
648 return failure();
651 // If the lhs or rhs has external uses, the blocks cannot be merged as the
652 // merged version of this operation will not be either the lhs or rhs
653 // alone (thus semantically incorrect), but some mix dependending on which
654 // block preceeded this.
655 // TODO allow merging of operations when one block does not dominate the
656 // other
657 if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
658 lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
659 return failure();
662 // Make sure that the block sizes are equivalent.
663 if (lhsIt != lhsE || rhsIt != rhsE)
664 return failure();
666 // If we get here, the blocks are equivalent and can be merged.
667 operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
668 blocksToMerge.insert(blockData.block);
669 return success();
672 /// Returns true if the predecessor terminators of the given block can not have
673 /// their operands updated.
674 static bool ableToUpdatePredOperands(Block *block) {
675 for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
676 if (!isa<BranchOpInterface>((*it)->getTerminator()))
677 return false;
679 return true;
682 /// Prunes the redundant list of new arguments. E.g., if we are passing an
683 /// argument list like [x, y, z, x] this would return [x, y, z] and it would
684 /// update the `block` (to whom the argument are passed to) accordingly. The new
685 /// arguments are passed as arguments at the back of the block, hence we need to
686 /// know how many `numOldArguments` were before, in order to correctly replace
687 /// the new arguments in the block
688 static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
689 const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
690 RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
692 SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
693 newArguments.size(), SmallVector<Value, 8>());
695 if (newArguments.empty())
696 return newArguments;
698 // `newArguments` is a 2D array of size `numLists` x `numArgs`
699 unsigned numLists = newArguments.size();
700 unsigned numArgs = newArguments[0].size();
702 // Map that for each arg index contains the index that we can use in place of
703 // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
704 // idxToReplacement[3] = 0
705 llvm::DenseMap<unsigned, unsigned> idxToReplacement;
707 // This is a useful data structure to track the first appearance of a Value
708 // on a given list of arguments
709 DenseMap<Value, unsigned> firstValueToIdx;
710 for (unsigned j = 0; j < numArgs; ++j) {
711 Value newArg = newArguments[0][j];
712 firstValueToIdx.try_emplace(newArg, j);
715 // Go through the first list of arguments (list 0).
716 for (unsigned j = 0; j < numArgs; ++j) {
717 // Look back to see if there are possible redundancies in list 0. Please
718 // note that we are using a map to annotate when an argument was seen first
719 // to avoid a O(N^2) algorithm. This has the drawback that if we have two
720 // lists like:
721 // list0: [%a, %a, %a]
722 // list1: [%c, %b, %b]
723 // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
724 // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
725 // the number of arguments can be potentially unbounded we cannot afford a
726 // O(N^2) algorithm (to search to all the possible pairs) and we need to
727 // accept the trade-off.
728 unsigned k = firstValueToIdx[newArguments[0][j]];
729 if (k == j)
730 continue;
732 bool shouldReplaceJ = true;
733 unsigned replacement = k;
734 // If a possible redundancy is found, then scan the other lists: we
735 // can prune the arguments if and only if they are redundant in every
736 // list.
737 for (unsigned i = 1; i < numLists; ++i)
738 shouldReplaceJ =
739 shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
740 // Save the replacement.
741 if (shouldReplaceJ)
742 idxToReplacement[j] = replacement;
745 // Populate the pruned argument list.
746 for (unsigned i = 0; i < numLists; ++i)
747 for (unsigned j = 0; j < numArgs; ++j)
748 if (!idxToReplacement.contains(j))
749 newArgumentsPruned[i].push_back(newArguments[i][j]);
751 // Replace the block's redundant arguments.
752 SmallVector<unsigned> toErase;
753 for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
754 if (idxToReplacement.contains(idx)) {
755 Value oldArg = block->getArgument(numOldArguments + idx);
756 Value newArg =
757 block->getArgument(numOldArguments + idxToReplacement[idx]);
758 rewriter.replaceAllUsesWith(oldArg, newArg);
759 toErase.push_back(numOldArguments + idx);
763 // Erase the block's redundant arguments.
764 for (unsigned idxToErase : llvm::reverse(toErase))
765 block->eraseArgument(idxToErase);
766 return newArgumentsPruned;
769 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
770 // Don't consider clusters that don't have blocks to merge.
771 if (blocksToMerge.empty())
772 return failure();
774 Block *leaderBlock = leaderData.block;
775 if (!operandsToMerge.empty()) {
776 // If the cluster has operands to merge, verify that the predecessor
777 // terminators of each of the blocks can have their successor operands
778 // updated.
779 // TODO: We could try and sub-partition this cluster if only some blocks
780 // cause the mismatch.
781 if (!ableToUpdatePredOperands(leaderBlock) ||
782 !llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
783 return failure();
785 // Collect the iterators for each of the blocks to merge. We will walk all
786 // of the iterators at once to avoid operand index invalidation.
787 SmallVector<Block::iterator, 2> blockIterators;
788 blockIterators.reserve(blocksToMerge.size() + 1);
789 blockIterators.push_back(leaderBlock->begin());
790 for (Block *mergeBlock : blocksToMerge)
791 blockIterators.push_back(mergeBlock->begin());
793 // Update each of the predecessor terminators with the new arguments.
794 SmallVector<SmallVector<Value, 8>, 2> newArguments(
795 1 + blocksToMerge.size(),
796 SmallVector<Value, 8>(operandsToMerge.size()));
797 unsigned curOpIndex = 0;
798 unsigned numOldArguments = leaderBlock->getNumArguments();
799 for (const auto &it : llvm::enumerate(operandsToMerge)) {
800 unsigned nextOpOffset = it.value().first - curOpIndex;
801 curOpIndex = it.value().first;
803 // Process the operand for each of the block iterators.
804 for (unsigned i = 0, e = blockIterators.size(); i != e; ++i) {
805 Block::iterator &blockIter = blockIterators[i];
806 std::advance(blockIter, nextOpOffset);
807 auto &operand = blockIter->getOpOperand(it.value().second);
808 newArguments[i][it.index()] = operand.get();
810 // Update the operand and insert an argument if this is the leader.
811 if (i == 0) {
812 Value operandVal = operand.get();
813 operand.set(leaderBlock->addArgument(operandVal.getType(),
814 operandVal.getLoc()));
819 // Prune redundant arguments and update the leader block argument list
820 newArguments = pruneRedundantArguments(newArguments, rewriter,
821 numOldArguments, leaderBlock);
823 // Update the predecessors for each of the blocks.
824 auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
825 for (auto predIt = block->pred_begin(), predE = block->pred_end();
826 predIt != predE; ++predIt) {
827 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
828 unsigned succIndex = predIt.getSuccessorIndex();
829 branch.getSuccessorOperands(succIndex).append(
830 newArguments[clusterIndex]);
833 updatePredecessors(leaderBlock, /*clusterIndex=*/0);
834 for (unsigned i = 0, e = blocksToMerge.size(); i != e; ++i)
835 updatePredecessors(blocksToMerge[i], /*clusterIndex=*/i + 1);
838 // Replace all uses of the merged blocks with the leader and erase them.
839 for (Block *block : blocksToMerge) {
840 block->replaceAllUsesWith(leaderBlock);
841 rewriter.eraseBlock(block);
843 return success();
846 /// Identify identical blocks within the given region and merge them, inserting
847 /// new block arguments as necessary. Returns success if any blocks were merged,
848 /// failure otherwise.
849 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
850 Region &region) {
851 if (region.empty() || llvm::hasSingleElement(region))
852 return failure();
854 // Identify sets of blocks, other than the entry block, that branch to the
855 // same successors. We will use these groups to create clusters of equivalent
856 // blocks.
857 DenseMap<SuccessorRange, SmallVector<Block *, 1>> matchingSuccessors;
858 for (Block &block : llvm::drop_begin(region, 1))
859 matchingSuccessors[block.getSuccessors()].push_back(&block);
861 bool mergedAnyBlocks = false;
862 for (ArrayRef<Block *> blocks : llvm::make_second_range(matchingSuccessors)) {
863 if (blocks.size() == 1)
864 continue;
866 SmallVector<BlockMergeCluster, 1> clusters;
867 for (Block *block : blocks) {
868 BlockEquivalenceData data(block);
870 // Don't allow merging if this block has any regions.
871 // TODO: Add support for regions if necessary.
872 bool hasNonEmptyRegion = llvm::any_of(*block, [](Operation &op) {
873 return llvm::any_of(op.getRegions(),
874 [](Region &region) { return !region.empty(); });
876 if (hasNonEmptyRegion)
877 continue;
879 // Don't allow merging if this block's arguments are used outside of the
880 // original block.
881 bool argHasExternalUsers = llvm::any_of(
882 block->getArguments(), [block](mlir::BlockArgument &arg) {
883 return arg.isUsedOutsideOfBlock(block);
885 if (argHasExternalUsers)
886 continue;
888 // Try to add this block to an existing cluster.
889 bool addedToCluster = false;
890 for (auto &cluster : clusters)
891 if ((addedToCluster = succeeded(cluster.addToCluster(data))))
892 break;
893 if (!addedToCluster)
894 clusters.emplace_back(std::move(data));
896 for (auto &cluster : clusters)
897 mergedAnyBlocks |= succeeded(cluster.merge(rewriter));
900 return success(mergedAnyBlocks);
903 /// Identify identical blocks within the given regions and merge them, inserting
904 /// new block arguments as necessary.
905 static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
906 MutableArrayRef<Region> regions) {
907 llvm::SmallSetVector<Region *, 1> worklist;
908 for (auto &region : regions)
909 worklist.insert(&region);
910 bool anyChanged = false;
911 while (!worklist.empty()) {
912 Region *region = worklist.pop_back_val();
913 if (succeeded(mergeIdenticalBlocks(rewriter, *region))) {
914 worklist.insert(region);
915 anyChanged = true;
918 // Add any nested regions to the worklist.
919 for (Block &block : *region)
920 for (auto &op : block)
921 for (auto &nestedRegion : op.getRegions())
922 worklist.insert(&nestedRegion);
925 return success(anyChanged);
928 /// If a block's argument is always the same across different invocations, then
929 /// drop the argument and use the value directly inside the block
930 static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
931 Block &block) {
932 SmallVector<size_t> argsToErase;
934 // Go through the arguments of the block.
935 for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
936 bool sameArg = true;
937 Value commonValue;
939 // Go through the block predecessor and flag if they pass to the block
940 // different values for the same argument.
941 for (Block::pred_iterator predIt = block.pred_begin(),
942 predE = block.pred_end();
943 predIt != predE; ++predIt) {
944 auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
945 if (!branch) {
946 sameArg = false;
947 break;
949 unsigned succIndex = predIt.getSuccessorIndex();
950 SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
951 auto branchOperands = succOperands.getForwardedOperands();
952 if (!commonValue) {
953 commonValue = branchOperands[argIdx];
954 continue;
956 if (branchOperands[argIdx] != commonValue) {
957 sameArg = false;
958 break;
962 // If they are passing the same value, drop the argument.
963 if (commonValue && sameArg) {
964 argsToErase.push_back(argIdx);
966 // Remove the argument from the block.
967 rewriter.replaceAllUsesWith(blockOperand, commonValue);
971 // Remove the arguments.
972 for (size_t argIdx : llvm::reverse(argsToErase)) {
973 block.eraseArgument(argIdx);
975 // Remove the argument from the branch ops.
976 for (auto predIt = block.pred_begin(), predE = block.pred_end();
977 predIt != predE; ++predIt) {
978 auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
979 unsigned succIndex = predIt.getSuccessorIndex();
980 SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
981 succOperands.erase(argIdx);
984 return success(!argsToErase.empty());
987 /// This optimization drops redundant argument to blocks. I.e., if a given
988 /// argument to a block receives the same value from each of the block
989 /// predecessors, we can remove the argument from the block and use directly the
990 /// original value. This is a simple example:
992 /// %cond = llvm.call @rand() : () -> i1
993 /// %val0 = llvm.mlir.constant(1 : i64) : i64
994 /// %val1 = llvm.mlir.constant(2 : i64) : i64
995 /// %val2 = llvm.mlir.constant(3 : i64) : i64
996 /// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
997 /// : i64)
999 /// ^bb1(%arg0 : i64, %arg1 : i64):
1000 /// llvm.call @foo(%arg0, %arg1)
1002 /// The previous IR can be rewritten as:
1003 /// %cond = llvm.call @rand() : () -> i1
1004 /// %val0 = llvm.mlir.constant(1 : i64) : i64
1005 /// %val1 = llvm.mlir.constant(2 : i64) : i64
1006 /// %val2 = llvm.mlir.constant(3 : i64) : i64
1007 /// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
1009 /// ^bb1(%arg0 : i64):
1010 /// llvm.call @foo(%val0, %arg0)
1012 static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
1013 MutableArrayRef<Region> regions) {
1014 llvm::SmallSetVector<Region *, 1> worklist;
1015 for (Region &region : regions)
1016 worklist.insert(&region);
1017 bool anyChanged = false;
1018 while (!worklist.empty()) {
1019 Region *region = worklist.pop_back_val();
1021 // Add any nested regions to the worklist.
1022 for (Block &block : *region) {
1023 anyChanged =
1024 succeeded(dropRedundantArguments(rewriter, block)) || anyChanged;
1026 for (Operation &op : block)
1027 for (Region &nestedRegion : op.getRegions())
1028 worklist.insert(&nestedRegion);
1031 return success(anyChanged);
1034 //===----------------------------------------------------------------------===//
1035 // Region Simplification
1036 //===----------------------------------------------------------------------===//
1038 /// Run a set of structural simplifications over the given regions. This
1039 /// includes transformations like unreachable block elimination, dead argument
1040 /// elimination, as well as some other DCE. This function returns success if any
1041 /// of the regions were simplified, failure otherwise.
1042 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
1043 MutableArrayRef<Region> regions,
1044 bool mergeBlocks) {
1045 bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
1046 bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
1047 bool mergedIdenticalBlocks = false;
1048 bool droppedRedundantArguments = false;
1049 if (mergeBlocks) {
1050 mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
1051 droppedRedundantArguments =
1052 succeeded(dropRedundantArguments(rewriter, regions));
1054 return success(eliminatedBlocks || eliminatedOpsOrArgs ||
1055 mergedIdenticalBlocks || droppedRedundantArguments);