1 //===-- SPIRVMergeRegionExitTargets.cpp ----------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Merge the multiple exit targets of a convergence region into a single block.
10 // Each exit target will be assigned a constant value, and a phi node + switch
11 // will allow the new exit target to re-route to the correct basic block.
13 //===----------------------------------------------------------------------===//
15 #include "Analysis/SPIRVConvergenceRegionAnalysis.h"
17 #include "SPIRVSubtarget.h"
18 #include "SPIRVTargetMachine.h"
19 #include "SPIRVUtils.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/SmallPtrSet.h"
22 #include "llvm/Analysis/LoopInfo.h"
23 #include "llvm/CodeGen/IntrinsicLowering.h"
24 #include "llvm/IR/CFG.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/IntrinsicsSPIRV.h"
30 #include "llvm/InitializePasses.h"
31 #include "llvm/Transforms/Utils/Cloning.h"
32 #include "llvm/Transforms/Utils/LoopSimplify.h"
33 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
38 void initializeSPIRVMergeRegionExitTargetsPass(PassRegistry
&);
40 class SPIRVMergeRegionExitTargets
: public FunctionPass
{
44 SPIRVMergeRegionExitTargets() : FunctionPass(ID
) {
45 initializeSPIRVMergeRegionExitTargetsPass(*PassRegistry::getPassRegistry());
48 // Gather all the successors of |BB|.
49 // This function asserts if the terminator neither a branch, switch or return.
50 std::unordered_set
<BasicBlock
*> gatherSuccessors(BasicBlock
*BB
) {
51 std::unordered_set
<BasicBlock
*> output
;
52 auto *T
= BB
->getTerminator();
54 if (auto *BI
= dyn_cast
<BranchInst
>(T
)) {
55 output
.insert(BI
->getSuccessor(0));
56 if (BI
->isConditional())
57 output
.insert(BI
->getSuccessor(1));
61 if (auto *SI
= dyn_cast
<SwitchInst
>(T
)) {
62 output
.insert(SI
->getDefaultDest());
63 for (auto &Case
: SI
->cases())
64 output
.insert(Case
.getCaseSuccessor());
68 assert(isa
<ReturnInst
>(T
) && "Unhandled terminator type.");
72 /// Create a value in BB set to the value associated with the branch the block
73 /// terminator will take.
74 llvm::Value
*createExitVariable(
76 const DenseMap
<BasicBlock
*, ConstantInt
*> &TargetToValue
) {
77 auto *T
= BB
->getTerminator();
78 if (isa
<ReturnInst
>(T
))
81 IRBuilder
<> Builder(BB
);
82 Builder
.SetInsertPoint(T
);
84 if (auto *BI
= dyn_cast
<BranchInst
>(T
)) {
86 BasicBlock
*LHSTarget
= BI
->getSuccessor(0);
87 BasicBlock
*RHSTarget
=
88 BI
->isConditional() ? BI
->getSuccessor(1) : nullptr;
90 Value
*LHS
= TargetToValue
.count(LHSTarget
) != 0
91 ? TargetToValue
.at(LHSTarget
)
93 Value
*RHS
= TargetToValue
.count(RHSTarget
) != 0
94 ? TargetToValue
.at(RHSTarget
)
97 if (LHS
== nullptr || RHS
== nullptr)
98 return LHS
== nullptr ? RHS
: LHS
;
99 return Builder
.CreateSelect(BI
->getCondition(), LHS
, RHS
);
102 // TODO: add support for switch cases.
103 llvm_unreachable("Unhandled terminator type.");
106 /// Replaces |BB|'s branch targets present in |ToReplace| with |NewTarget|.
107 void replaceBranchTargets(BasicBlock
*BB
,
108 const SmallPtrSet
<BasicBlock
*, 4> &ToReplace
,
109 BasicBlock
*NewTarget
) {
110 auto *T
= BB
->getTerminator();
111 if (isa
<ReturnInst
>(T
))
114 if (auto *BI
= dyn_cast
<BranchInst
>(T
)) {
115 for (size_t i
= 0; i
< BI
->getNumSuccessors(); i
++) {
116 if (ToReplace
.count(BI
->getSuccessor(i
)) != 0)
117 BI
->setSuccessor(i
, NewTarget
);
122 if (auto *SI
= dyn_cast
<SwitchInst
>(T
)) {
123 for (size_t i
= 0; i
< SI
->getNumSuccessors(); i
++) {
124 if (ToReplace
.count(SI
->getSuccessor(i
)) != 0)
125 SI
->setSuccessor(i
, NewTarget
);
130 assert(false && "Unhandled terminator type.");
133 // Run the pass on the given convergence region, ignoring the sub-regions.
134 // Returns true if the CFG changed, false otherwise.
135 bool runOnConvergenceRegionNoRecurse(LoopInfo
&LI
,
136 const SPIRV::ConvergenceRegion
*CR
) {
137 // Gather all the exit targets for this region.
138 SmallPtrSet
<BasicBlock
*, 4> ExitTargets
;
139 for (BasicBlock
*Exit
: CR
->Exits
) {
140 for (BasicBlock
*Target
: gatherSuccessors(Exit
)) {
141 if (CR
->Blocks
.count(Target
) == 0)
142 ExitTargets
.insert(Target
);
146 // If we have zero or one exit target, nothing do to.
147 if (ExitTargets
.size() <= 1)
150 // Create the new single exit target.
151 auto F
= CR
->Entry
->getParent();
152 auto NewExitTarget
= BasicBlock::Create(F
->getContext(), "new.exit", F
);
153 IRBuilder
<> Builder(NewExitTarget
);
155 // CodeGen output needs to be stable. Using the set as-is would order
156 // the targets differently depending on the allocation pattern.
157 // Sorting per basic-block ordering in the function.
158 std::vector
<BasicBlock
*> SortedExitTargets
;
159 std::vector
<BasicBlock
*> SortedExits
;
160 for (BasicBlock
&BB
: *F
) {
161 if (ExitTargets
.count(&BB
) != 0)
162 SortedExitTargets
.push_back(&BB
);
163 if (CR
->Exits
.count(&BB
) != 0)
164 SortedExits
.push_back(&BB
);
167 // Creating one constant per distinct exit target. This will be route to the
169 DenseMap
<BasicBlock
*, ConstantInt
*> TargetToValue
;
170 for (BasicBlock
*Target
: SortedExitTargets
)
171 TargetToValue
.insert(
172 std::make_pair(Target
, Builder
.getInt32(TargetToValue
.size())));
174 // Creating one variable per exit node, set to the constant matching the
175 // targeted external block.
176 std::vector
<std::pair
<BasicBlock
*, Value
*>> ExitToVariable
;
177 for (auto Exit
: SortedExits
) {
178 llvm::Value
*Value
= createExitVariable(Exit
, TargetToValue
);
179 ExitToVariable
.emplace_back(std::make_pair(Exit
, Value
));
182 // Gather the correct value depending on the exit we came from.
183 llvm::PHINode
*node
=
184 Builder
.CreatePHI(Builder
.getInt32Ty(), ExitToVariable
.size());
185 for (auto [BB
, Value
] : ExitToVariable
) {
186 node
->addIncoming(Value
, BB
);
189 // Creating the switch to jump to the correct exit target.
190 llvm::SwitchInst
*Sw
= Builder
.CreateSwitch(node
, SortedExitTargets
[0],
191 SortedExitTargets
.size() - 1);
192 for (size_t i
= 1; i
< SortedExitTargets
.size(); i
++) {
193 BasicBlock
*BB
= SortedExitTargets
[i
];
194 Sw
->addCase(TargetToValue
[BB
], BB
);
197 // Fix exit branches to redirect to the new exit.
198 for (auto Exit
: CR
->Exits
)
199 replaceBranchTargets(Exit
, ExitTargets
, NewExitTarget
);
204 /// Run the pass on the given convergence region and sub-regions (DFS).
205 /// Returns true if a region/sub-region was modified, false otherwise.
206 /// This returns as soon as one region/sub-region has been modified.
207 bool runOnConvergenceRegion(LoopInfo
&LI
,
208 const SPIRV::ConvergenceRegion
*CR
) {
209 for (auto *Child
: CR
->Children
)
210 if (runOnConvergenceRegion(LI
, Child
))
213 return runOnConvergenceRegionNoRecurse(LI
, CR
);
217 /// Validates each edge exiting the region has the same destination basic
219 void validateRegionExits(const SPIRV::ConvergenceRegion
*CR
) {
220 for (auto *Child
: CR
->Children
)
221 validateRegionExits(Child
);
223 std::unordered_set
<BasicBlock
*> ExitTargets
;
224 for (auto *Exit
: CR
->Exits
) {
225 auto Set
= gatherSuccessors(Exit
);
226 for (auto *BB
: Set
) {
227 if (CR
->Blocks
.count(BB
) == 0)
228 ExitTargets
.insert(BB
);
232 assert(ExitTargets
.size() <= 1);
236 virtual bool runOnFunction(Function
&F
) override
{
237 LoopInfo
&LI
= getAnalysis
<LoopInfoWrapperPass
>().getLoopInfo();
238 const auto *TopLevelRegion
=
239 getAnalysis
<SPIRVConvergenceRegionAnalysisWrapperPass
>()
241 .getTopLevelRegion();
243 // FIXME: very inefficient method: each time a region is modified, we bubble
244 // back up, and recompute the whole convergence region tree. Once the
245 // algorithm is completed and test coverage good enough, rewrite this pass
246 // to be efficient instead of simple.
247 bool modified
= false;
248 while (runOnConvergenceRegion(LI
, TopLevelRegion
)) {
249 TopLevelRegion
= getAnalysis
<SPIRVConvergenceRegionAnalysisWrapperPass
>()
251 .getTopLevelRegion();
255 #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
256 validateRegionExits(TopLevelRegion
);
261 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
262 AU
.addRequired
<DominatorTreeWrapperPass
>();
263 AU
.addRequired
<LoopInfoWrapperPass
>();
264 AU
.addRequired
<SPIRVConvergenceRegionAnalysisWrapperPass
>();
265 FunctionPass::getAnalysisUsage(AU
);
270 char SPIRVMergeRegionExitTargets::ID
= 0;
272 INITIALIZE_PASS_BEGIN(SPIRVMergeRegionExitTargets
, "split-region-exit-blocks",
273 "SPIRV split region exit blocks", false, false)
274 INITIALIZE_PASS_DEPENDENCY(LoopSimplify
)
275 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass
)
276 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass
)
277 INITIALIZE_PASS_DEPENDENCY(SPIRVConvergenceRegionAnalysisWrapperPass
)
279 INITIALIZE_PASS_END(SPIRVMergeRegionExitTargets
, "split-region-exit-blocks",
280 "SPIRV split region exit blocks", false, false)
282 FunctionPass
*llvm::createSPIRVMergeRegionExitTargetsPass() {
283 return new SPIRVMergeRegionExitTargets();