//===- ControlFlowSinkUtils.cpp - Code to perform control-flow sinking ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for control-flow sinking. Control-flow
// sinking moves operations whose only uses are in conditionally-executed blocks
// into those blocks so that they aren't executed on paths where their results
// are not needed.
//
// Control-flow sinking is not implemented on BranchOpInterface because
// sinking ops into the successors of branch operations may move ops into loops.
// It is idiomatic MLIR to perform optimizations at IR levels that readily
// provide the necessary information.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/ControlFlowSinkUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include <vector>

#define DEBUG_TYPE "cf-sink"

using namespace mlir;
32 /// A helper struct for control-flow sinking.
35 /// Create an operation sinker with given dominance info.
36 Sinker(function_ref
<bool(Operation
*, Region
*)> shouldMoveIntoRegion
,
37 function_ref
<void(Operation
*, Region
*)> moveIntoRegion
,
38 DominanceInfo
&domInfo
)
39 : shouldMoveIntoRegion(shouldMoveIntoRegion
),
40 moveIntoRegion(moveIntoRegion
), domInfo(domInfo
) {}
42 /// Given a list of regions, find operations to sink and sink them. Return the
43 /// number of operations sunk.
44 size_t sinkRegions(RegionRange regions
);
47 /// Given a region and an op which dominates the region, returns true if all
48 /// users of the given op are dominated by the entry block of the region, and
49 /// thus the operation can be sunk into the region.
50 bool allUsersDominatedBy(Operation
*op
, Region
*region
);
52 /// Given a region and a top-level op (an op whose parent region is the given
53 /// region), determine whether the defining ops of the op's operands can be
54 /// sunk into the region.
56 /// Add moved ops to the work queue.
57 void tryToSinkPredecessors(Operation
*user
, Region
*region
,
58 std::vector
<Operation
*> &stack
);
60 /// Iterate over all the ops in a region and try to sink their predecessors.
61 /// Recurse on subgraphs using a work queue.
62 void sinkRegion(Region
*region
);
64 /// The callback to determine whether an op should be moved in to a region.
65 function_ref
<bool(Operation
*, Region
*)> shouldMoveIntoRegion
;
66 /// The calback to move an operation into the region.
67 function_ref
<void(Operation
*, Region
*)> moveIntoRegion
;
68 /// Dominance info to determine op user dominance with respect to regions.
69 DominanceInfo
&domInfo
;
70 /// The number of operations sunk.
73 } // end anonymous namespace
75 bool Sinker::allUsersDominatedBy(Operation
*op
, Region
*region
) {
76 assert(region
->findAncestorOpInRegion(*op
) == nullptr &&
77 "expected op to be defined outside the region");
78 return llvm::all_of(op
->getUsers(), [&](Operation
*user
) {
79 // The user is dominated by the region if its containing block is dominated
80 // by the region's entry block.
81 return domInfo
.dominates(®ion
->front(), user
->getBlock());
85 void Sinker::tryToSinkPredecessors(Operation
*user
, Region
*region
,
86 std::vector
<Operation
*> &stack
) {
87 LLVM_DEBUG(user
->print(llvm::dbgs() << "\nContained op:\n"));
88 for (Value value
: user
->getOperands()) {
89 Operation
*op
= value
.getDefiningOp();
90 // Ignore block arguments and ops that are already inside the region.
91 if (!op
|| op
->getParentRegion() == region
)
93 LLVM_DEBUG(op
->print(llvm::dbgs() << "\nTry to sink:\n"));
95 // If the op's users are all in the region and it can be moved, then do so.
96 if (allUsersDominatedBy(op
, region
) && shouldMoveIntoRegion(op
, region
)) {
97 moveIntoRegion(op
, region
);
99 // Add the op to the work queue.
105 void Sinker::sinkRegion(Region
*region
) {
106 // Initialize the work queue with all the ops in the region.
107 std::vector
<Operation
*> stack
;
108 for (Operation
&op
: region
->getOps())
109 stack
.push_back(&op
);
111 // Process all the ops depth-first. This ensures that nodes of subgraphs are
112 // sunk in the correct order.
113 while (!stack
.empty()) {
114 Operation
*op
= stack
.back();
116 tryToSinkPredecessors(op
, region
, stack
);
120 size_t Sinker::sinkRegions(RegionRange regions
) {
121 for (Region
*region
: regions
)
122 if (!region
->empty())
127 size_t mlir::controlFlowSink(
128 RegionRange regions
, DominanceInfo
&domInfo
,
129 function_ref
<bool(Operation
*, Region
*)> shouldMoveIntoRegion
,
130 function_ref
<void(Operation
*, Region
*)> moveIntoRegion
) {
131 return Sinker(shouldMoveIntoRegion
, moveIntoRegion
, domInfo
)
132 .sinkRegions(regions
);
135 void mlir::getSinglyExecutedRegionsToSink(RegionBranchOpInterface branch
,
136 SmallVectorImpl
<Region
*> ®ions
) {
137 // Collect constant operands.
138 SmallVector
<Attribute
> operands(branch
->getNumOperands(), Attribute());
139 for (auto [idx
, operand
] : llvm::enumerate(branch
->getOperands()))
140 (void)matchPattern(operand
, m_Constant(&operands
[idx
]));
142 // Get the invocation bounds.
143 SmallVector
<InvocationBounds
> bounds
;
144 branch
.getRegionInvocationBounds(operands
, bounds
);
146 // For a simple control-flow sink, only consider regions that are executed at
148 for (auto it
: llvm::zip(branch
->getRegions(), bounds
)) {
149 const InvocationBounds
&bound
= std::get
<1>(it
);
150 if (bound
.getUpperBound() && *bound
.getUpperBound() <= 1)
151 regions
.push_back(&std::get
<0>(it
));