1 //===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10 #include "llvm/Analysis/LoopAccessAnalysis.h"
11 #include "llvm/Analysis/LoopAnalysisManager.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/LoopIterator.h"
14 #include "llvm/Analysis/LoopPass.h"
15 #include "llvm/Analysis/MemorySSA.h"
16 #include "llvm/Analysis/MemorySSAUpdater.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/IR/PatternMatch.h"
20 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
21 #include "llvm/Transforms/Utils/Cloning.h"
22 #include "llvm/Transforms/Utils/LoopSimplify.h"
23 #include "llvm/Transforms/Utils/LoopUtils.h"
24 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
26 #define DEBUG_TYPE "loop-bound-split"
30 using namespace PatternMatch
;
/// Record describing one compare-and-branch condition handled by this file.
/// The constructor initializer list below names the members BI, ICmp, Pred,
/// AddRecValue, BoundValue and AddRecSCEV.
/// NOTE(review): several member declarations and the end of the constructor
/// are not visible in this excerpt.
33 struct ConditionInfo
{
34 /// Branch instruction with this condition
36 /// ICmp instruction with this condition
// Predicate of the comparison. analyzeICmp captures it from the matched ICmp
// and may replace it with the swapped predicate.
39 ICmpInst::Predicate Pred
;
// SCEV of AddRecValue, computed by analyzeICmp through SE.getSCEV.
45 const SCEV
*AddRecSCEV
;
// SCEV of BoundValue, computed by analyzeICmp through SE.getSCEV.
// calculateUpperBound may overwrite it later.
47 const SCEV
*BoundSCEV
;
// Default state: the listed pointer members are null and Pred is
// BAD_ICMP_PREDICATE.
50 : BI(nullptr), ICmp(nullptr), Pred(ICmpInst::BAD_ICMP_PREDICATE
),
51 AddRecValue(nullptr), BoundValue(nullptr), AddRecSCEV(nullptr),
/// Fill \p Cond from the compare instruction \p ICmp.
///
/// When \p ICmp matches an integer compare, the predicate and both operands
/// are captured into Cond.Pred, Cond.AddRecValue and Cond.BoundValue, and the
/// SCEV of each operand is stored in Cond.AddRecSCEV and Cond.BoundSCEV.
/// The operands are then normalized so that the add-recurrence sits on the
/// AddRec side.
///
/// \param SE   ScalarEvolution used to compute the operand SCEVs.
/// \param ICmp Compare instruction to analyze.
/// \param Cond Output record, updated in place.
56 static void analyzeICmp(ScalarEvolution
&SE
, ICmpInst
*ICmp
,
57 ConditionInfo
&Cond
) {
// Capture predicate, first operand and second operand in one match.
59 if (match(ICmp
, m_ICmp(Cond
.Pred
, m_Value(Cond
.AddRecValue
),
60 m_Value(Cond
.BoundValue
)))) {
61 Cond
.AddRecSCEV
= SE
.getSCEV(Cond
.AddRecValue
);
62 Cond
.BoundSCEV
= SE
.getSCEV(Cond
.BoundValue
);
63 // Keep the AddRec on the left-hand side and the bound on the right-hand
// side: if only the second operand is an AddRec, swap the values, the SCEVs
// and the predicate so the comparison keeps its meaning.
64 if (isa
<SCEVAddRecExpr
>(Cond
.BoundSCEV
) &&
65 !isa
<SCEVAddRecExpr
>(Cond
.AddRecSCEV
)) {
66 std::swap(Cond
.AddRecValue
, Cond
.BoundValue
);
67 std::swap(Cond
.AddRecSCEV
, Cond
.BoundSCEV
);
68 Cond
.Pred
= ICmpInst::getSwappedPredicate(Cond
.Pred
);
/// Compute the upper bound for \p Cond and store it in Cond.BoundSCEV.
///
/// The visible code handles two situations. First, the exit count that SE
/// computes for the block containing the compare is stored as the bound.
/// Second, for an LE predicate the bound is rewritten as Bound + 1 when SE
/// can prove that Bound is strictly below the maximum value of its integer
/// type.
///
/// \param L          Loop being analyzed.
/// \param SE         ScalarEvolution used for exit counts and proofs.
/// \param Cond       Condition record, BoundSCEV is updated in place.
/// \param IsExitCond Whether \p Cond is the loop exit condition.
///
/// NOTE(review): the guarded statements and the return statements are not
/// visible in this excerpt. hasProcessableCondition tests the negated result,
/// so presumably false means the bound could not be computed. Confirm.
73 static bool calculateUpperBound(const Loop
&L
, ScalarEvolution
&SE
,
74 ConditionInfo
&Cond
, bool IsExitCond
) {
// Exit count for the block that holds the compare. Presumably this part only
// applies when IsExitCond is set (the guard is not visible here).
76 const SCEV
*ExitCount
= SE
.getExitCount(&L
, Cond
.ICmp
->getParent());
77 if (isa
<SCEVCouldNotCompute
>(ExitCount
))
80 Cond
.BoundSCEV
= ExitCount
;
84 // For non-exit condition, if pred is LT, keep existing bound.
85 if (Cond
.Pred
== ICmpInst::ICMP_SLT
|| Cond
.Pred
== ICmpInst::ICMP_ULT
)
88 // For non-exit condition, if pred is LE, try to convert it to LT.
90 // AddRec <= Bound --> AddRec < Bound + 1
91 if (Cond
.Pred
!= ICmpInst::ICMP_ULE
&& Cond
.Pred
!= ICmpInst::ICMP_SLE
)
// The conversion is only attempted for integer-typed bounds.
94 if (IntegerType
*BoundSCEVIntType
=
95 dyn_cast
<IntegerType
>(Cond
.BoundSCEV
->getType())) {
96 unsigned BitWidth
= BoundSCEVIntType
->getBitWidth();
// Largest representable value, signed or unsigned to match the predicate.
97 APInt Max
= ICmpInst::isSigned(Cond
.Pred
)
98 ? APInt::getSignedMaxValue(BitWidth
)
99 : APInt::getMaxValue(BitWidth
);
100 const SCEV
*MaxSCEV
= SE
.getConstant(Max
);
101 // Check Bound < max value of its type, so that Bound + 1 cannot wrap.
102 ICmpInst::Predicate Pred
=
103 ICmpInst::isSigned(Cond
.Pred
) ? ICmpInst::ICMP_SLT
: ICmpInst::ICMP_ULT
;
104 if (SE
.isKnownPredicate(Pred
, Cond
.BoundSCEV
, MaxSCEV
)) {
105 const SCEV
*BoundPlusOneSCEV
=
106 SE
.getAddExpr(Cond
.BoundSCEV
, SE
.getOne(BoundSCEVIntType
));
107 Cond
.BoundSCEV
= BoundPlusOneSCEV
;
113 // ToDo: Support ICMP_NE/EQ.
/// Analyze \p ICmp into \p Cond and check that the pass can work with it.
///
/// The visible checks, in order: the bound SCEV must be available at loop
/// entry, the AddRec must be affine, its step must be a constant, the step
/// must be neither negative nor zero, and calculateUpperBound must succeed.
///
/// \param L    Loop that contains the compare.
/// \param SE   ScalarEvolution used for all queries.
/// \param ICmp Compare instruction to analyze.
/// \param Cond Output record filled by analyzeICmp and calculateUpperBound.
///
/// NOTE(review): the end of the parameter list and the statements guarded by
/// each check are not visible in this excerpt. Callers pass a fifth argument
/// tagged IsExitCond and test the negated result.
118 static bool hasProcessableCondition(const Loop
&L
, ScalarEvolution
&SE
,
119 ICmpInst
*ICmp
, ConditionInfo
&Cond
,
121 analyzeICmp(SE
, ICmp
, Cond
);
123 // The BoundSCEV should be evaluated at loop entry.
124 if (!SE
.isAvailableAtLoopEntry(Cond
.BoundSCEV
, &L
))
127 const SCEVAddRecExpr
*AddRecSCEV
= dyn_cast
<SCEVAddRecExpr
>(Cond
.AddRecSCEV
);
128 // Only an AddRec is accepted as the induction variable.
132 if (!AddRecSCEV
->isAffine())
135 const SCEV
*StepRecSCEV
= AddRecSCEV
->getStepRecurrence(SE
);
136 // Only a constant step is accepted.
137 if (!isa
<SCEVConstant
>(StepRecSCEV
))
140 ConstantInt
*StepCI
= cast
<SCEVConstant
>(StepRecSCEV
)->getValue();
141 // Only a positive step is accepted for now.
142 // TODO: Support negative step.
143 if (StepCI
->isNegative() || StepCI
->isZero())
146 // Calculate upper bound.
147 if (!calculateUpperBound(L
, SE
, Cond
, IsExitCond
))
/// Check the shape of branch \p BI before its condition is analyzed.
///
/// The visible checks: \p BI must match a conditional branch whose condition
/// is an integer compare, the type of the compare LHS must be SCEVable, and
/// the two successors are compared for equality.
///
/// \param SE ScalarEvolution used for the SCEVable query.
/// \param BI Branch instruction to inspect.
///
/// NOTE(review): the declarations of LHS and RHS and the statements guarded
/// by each check are not visible in this excerpt.
153 static bool isProcessableCondBI(const ScalarEvolution
&SE
,
154 const BranchInst
*BI
) {
155 BasicBlock
*TrueSucc
= nullptr;
156 BasicBlock
*FalseSucc
= nullptr;
157 ICmpInst::Predicate Pred
;
// One match extracts predicate, both compare operands and both successors.
159 if (!match(BI
, m_Br(m_ICmp(Pred
, m_Value(LHS
), m_Value(RHS
)),
160 m_BasicBlock(TrueSucc
), m_BasicBlock(FalseSucc
))))
163 if (!SE
.isSCEVable(LHS
->getType()))
// Only the LHS type is tested at run time. The RHS type is asserted.
165 assert(SE
.isSCEVable(RHS
->getType()) && "Expected RHS's type is SCEVable");
// A branch with identical successors is singled out here.
167 if (TrueSucc
== FalseSucc
)
/// Run the legality checks for splitting the bound of \p L and fill \p Cond
/// with the exit condition.
///
/// The visible checks, in order: the enclosing function has no optsize
/// attribute, the loop is innermost, it is in loop-simplify form, it is in
/// LCSSA form, it is safe to clone, the exiting branch passes
/// isProcessableCondBI, and the exit compare passes hasProcessableCondition
/// with IsExitCond set to true.
///
/// \param L    Loop to inspect.
/// \param DT   Dominator tree, used for the LCSSA check.
/// \param SE   ScalarEvolution passed to the condition helpers.
/// \param Cond Output record for the exit condition.
///
/// NOTE(review): the statements guarded by each check are not visible in this
/// excerpt. splitLoopBound tests the negated result.
173 static bool canSplitLoopBound(const Loop
&L
, const DominatorTree
&DT
,
174 ScalarEvolution
&SE
, ConditionInfo
&Cond
) {
175 // Skip function with optsize.
176 if (L
.getHeader()->getParent()->hasOptSize())
179 // Split only innermost loop.
180 if (!L
.isInnermost())
183 // Check loop is in simplified form.
184 if (!L
.isLoopSimplifyForm())
187 // Check loop is in LCSSA form.
188 if (!L
.isLCSSAForm(DT
))
191 // Skip loop that cannot be cloned.
192 if (!L
.isSafeToClone())
195 BasicBlock
*ExitingBB
= L
.getExitingBlock();
196 // Only a single exiting block is assumed.
200 BranchInst
*ExitingBI
= dyn_cast
<BranchInst
>(ExitingBB
->getTerminator());
204 // Only a conditional branch on an ICmp is accepted.
205 if (!isProcessableCondBI(SE
, ExitingBI
))
208 // Check the condition is processable.
// The cast relies on isProcessableCondBI having matched an ICmp condition.
209 ICmpInst
*ICmp
= cast
<ICmpInst
>(ExitingBI
->getCondition());
210 if (!hasProcessableCondition(L
, SE
, ICmp
, Cond
, /*IsExitCond*/ true))
/// Profitability heuristic for splitting on branch \p BI.
///
/// The visible code looks at both successors of \p BI and requires each to
/// have a single successor, and those two single successors to be the same
/// block. That is the diamond shape named in the comment below.
///
/// \param L  Loop that contains \p BI. It is not used in the visible lines.
/// \param BI Candidate conditional branch. Both successors are read.
///
/// NOTE(review): the return statements are not visible in this excerpt.
/// splitLoopBound tests the negated result.
217 static bool isProfitableToTransform(const Loop
&L
, const BranchInst
*BI
) {
218 // If the conditional branch splits a loop into two halves, we could
219 // generally say it is profitable.
221 // ToDo: Add more profitable cases here.
223 // Check this branch causes diamond CFG.
224 BasicBlock
*Succ0
= BI
->getSuccessor(0);
225 BasicBlock
*Succ1
= BI
->getSuccessor(1);
// getSingleSuccessor yields null when a block has no unique successor.
227 BasicBlock
*Succ0Succ
= Succ0
->getSingleSuccessor();
228 BasicBlock
*Succ1Succ
= Succ1
->getSingleSuccessor();
229 if (!Succ0Succ
|| !Succ1Succ
|| Succ0Succ
!= Succ1Succ
)
232 // ToDo: Calculate each successor's instruction cost.
/// Search the blocks of \p L for a branch whose condition can be used to
/// split the loop bound.
///
/// For each block the visible code singles out the latch, takes the
/// terminator as a BranchInst, requires isProcessableCondBI, checks whether
/// the condition is loop invariant, analyzes the compare with
/// hasProcessableCondition (IsExitCond false), and compares the bound type
/// with the type of the exit bound. The chosen branch is recorded in
/// SplitCandidateCond.BI.
///
/// \param L                  Loop whose blocks are scanned.
/// \param SE                 ScalarEvolution passed to the helpers.
/// \param ExitingCond        Exit condition. Only its BoundSCEV type is read.
/// \param SplitCandidateCond Output record for the candidate condition.
///
/// NOTE(review): the statements guarded by each check and the return
/// statements are not visible in this excerpt. splitLoopBound treats a null
/// result as no candidate.
237 static BranchInst
*findSplitCandidate(const Loop
&L
, ScalarEvolution
&SE
,
238 ConditionInfo
&ExitingCond
,
239 ConditionInfo
&SplitCandidateCond
) {
240 for (auto *BB
: L
.blocks()) {
241 // Skip condition of backedge.
242 if (L
.getLoopLatch() == BB
)
245 auto *BI
= dyn_cast
<BranchInst
>(BB
->getTerminator());
249 // Check conditional branch with ICmp.
250 if (!isProcessableCondBI(SE
, BI
))
253 // Skip loop invariant condition.
254 if (L
.isLoopInvariant(BI
->getCondition()))
257 // Check the condition is processable.
258 ICmpInst
*ICmp
= cast
<ICmpInst
>(BI
->getCondition());
259 if (!hasProcessableCondition(L
, SE
, ICmp
, SplitCandidateCond
,
260 /*IsExitCond*/ false))
// Both bounds are later combined in one min expression by splitLoopBound, so
// their types are compared here.
263 if (ExitingCond
.BoundSCEV
->getType() !=
264 SplitCandidateCond
.BoundSCEV
->getType())
267 SplitCandidateCond
.BI
= BI
;
/// Split the iteration space of \p L into a pre-loop (the original loop with
/// a new, smaller bound) followed by a cloned post-loop.
///
/// Visible steps: run canSplitLoopBound, findSplitCandidate and
/// isProfitableToTransform; split the preheader edge; clone the loop as the
/// post-loop; add a conditional branch in the post-loop preheader; expand
/// min(exit bound, split bound) before the pre-loop and install it as
/// operand 1 of the exit compare; set the candidate branch condition to true
/// in the pre-loop and to false in the clone; redirect the pre-loop exit edge
/// to the post-loop preheader; update the dominator tree; re-form LCSSA and
/// simplify both loops; register the post-loop with \p U.
///
/// \param L  Loop to transform.
/// \param DT Dominator tree, updated in place.
/// \param LI Loop info, passed to the cloning and canonicalization utilities.
/// \param SE ScalarEvolution used for bounds and code expansion.
/// \param U  Loop pass manager updater that receives the new sibling loop.
///
/// NOTE(review): several declarations and statements are not visible in this
/// excerpt, including the early returns, the declarations of PostLoop and
/// Cond, the SE invalidation calls and the final return.
/// LoopBoundSplitPass::run treats a false result as nothing changed.
274 static bool splitLoopBound(Loop
&L
, DominatorTree
&DT
, LoopInfo
&LI
,
275 ScalarEvolution
&SE
, LPMUpdater
&U
) {
276 ConditionInfo SplitCandidateCond
;
277 ConditionInfo ExitingCond
;
279 // Check we can split this loop's bound.
280 if (!canSplitLoopBound(L
, DT
, SE
, ExitingCond
))
283 if (!findSplitCandidate(L
, SE
, ExitingCond
, SplitCandidateCond
))
286 if (!isProfitableToTransform(L
, SplitCandidateCond
.BI
))
289 // Now, we have a split candidate. Let's build a form as below.
290 // +--------------------+
292 // | set up newbound |
293 // +--------------------+
294 // | /----------------\
295 // +--------v----v------+ |
297 // | with true condition| | |
298 // +--------------------+ | |
300 // +--------v-----------+ | |
301 // | if.then.BB | | |
302 // +--------------------+ | |
304 // +--------v-----------<---/ |
305 // | latch >----------/
307 // +--------------------+
309 // +--------v-----------+
310 // | preheader2 |--------------\
311 // | if (AddRec i != | |
313 // +--------------------+ |
314 // | /----------------\ |
315 // +--------v----v------+ | |
316 // | header2 |---\ | |
317 // | conditional branch | | | |
318 // |with false condition| | | |
319 // +--------------------+ | | |
321 // +--------v-----------+ | | |
322 // | if.then.BB2 | | | |
323 // +--------------------+ | | |
325 // +--------v-----------<---/ | |
326 // | latch2 >----------/ |
327 // | with org bound | |
328 // +--------v-----------+ |
330 // | +---------------+ |
331 // +--> exit <-------/
334 // Let's create post loop.
335 SmallVector
<BasicBlock
*, 8> PostLoopBlocks
;
337 ValueToValueMapTy VMap
;
338 BasicBlock
*PreHeader
= L
.getLoopPreheader();
// Split the preheader-to-header edge. The terminator of the new block is the
// insertion point for the expanded bound further down.
339 BasicBlock
*SplitLoopPH
= SplitEdge(PreHeader
, L
.getHeader(), &DT
, &LI
);
// Clone the loop together with a preheader. The cloned blocks are collected
// in PostLoopBlocks and VMap is used below to look up cloned instructions.
340 PostLoop
= cloneLoopWithPreheader(L
.getExitBlock(), SplitLoopPH
, &L
, VMap
,
341 ".split", &LI
, &DT
, PostLoopBlocks
);
342 remapInstructionsInBlocks(PostLoopBlocks
, VMap
);
344 // Add conditional branch to check we can skip post-loop in its preheader.
345 BasicBlock
*PostLoopPreHeader
= PostLoop
->getLoopPreheader();
346 IRBuilder
<> Builder(PostLoopPreHeader
);
// The existing terminator is remembered here and erased once the new
// conditional branch has been created.
347 Instruction
*OrigBI
= PostLoopPreHeader
->getTerminator();
348 ICmpInst::Predicate Pred
= ICmpInst::ICMP_NE
;
350 Builder
.CreateICmp(Pred
, ExitingCond
.AddRecValue
, ExitingCond
.BoundValue
);
// Branch to the post-loop header or straight to the post-loop exit block.
351 Builder
.CreateCondBr(Cond
, PostLoop
->getHeader(), PostLoop
->getExitBlock());
352 OrigBI
->eraseFromParent();
354 // Create new loop bound and add it into preheader of pre-loop.
355 const SCEV
*NewBoundSCEV
= ExitingCond
.BoundSCEV
;
356 const SCEV
*SplitBoundSCEV
= SplitCandidateCond
.BoundSCEV
;
// New bound is the smaller of the exit bound and the split bound. Signed or
// unsigned min is chosen from the exit predicate.
357 NewBoundSCEV
= ICmpInst::isSigned(ExitingCond
.Pred
)
358 ? SE
.getSMinExpr(NewBoundSCEV
, SplitBoundSCEV
)
359 : SE
.getUMinExpr(NewBoundSCEV
, SplitBoundSCEV
);
361 SCEVExpander
Expander(
362 SE
, L
.getHeader()->getParent()->getParent()->getDataLayout(), "split");
// Materialize the bound as IR right before the terminator of SplitLoopPH.
363 Instruction
*InsertPt
= SplitLoopPH
->getTerminator();
364 Value
*NewBoundValue
=
365 Expander
.expandCodeFor(NewBoundSCEV
, NewBoundSCEV
->getType(), InsertPt
);
366 NewBoundValue
->setName("new.bound");
368 // Replace operand 1 of the pre-loop exit compare with NewBound.
369 ExitingCond
.ICmp
->setOperand(1, NewBoundValue
);
371 // Replace IV's start value of post-loop by NewBound.
372 for (PHINode
&PN
: L
.getHeader()->phis()) {
373 // Find PHI with exiting condition from pre-loop.
374 if (SE
.isSCEVable(PN
.getType()) && isa
<SCEVAddRecExpr
>(SE
.getSCEV(&PN
))) {
375 for (Value
*Op
: PN
.incoming_values()) {
376 if (Op
== ExitingCond
.AddRecValue
) {
377 // Find cloned PHI for post-loop.
378 PHINode
*PostLoopPN
= cast
<PHINode
>(VMap
[&PN
]);
379 PostLoopPN
->setIncomingValueForBlock(PostLoopPreHeader
,
386 // Replace SplitCandidateCond.BI's condition of pre-loop by True.
387 LLVMContext
&Context
= PreHeader
->getContext();
388 SplitCandidateCond
.BI
->setCondition(ConstantInt::getTrue(Context
));
390 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
391 BranchInst
*ClonedSplitCandidateBI
=
392 cast
<BranchInst
>(VMap
[SplitCandidateCond
.BI
]);
393 ClonedSplitCandidateBI
->setCondition(ConstantInt::getFalse(Context
));
395 // Replace exit branch target of pre-loop by post-loop's preheader.
// NOTE(review): presumably an else separates the two setSuccessor calls. It
// is not visible in this excerpt.
396 if (L
.getExitBlock() == ExitingCond
.BI
->getSuccessor(0))
397 ExitingCond
.BI
->setSuccessor(0, PostLoopPreHeader
);
399 ExitingCond
.BI
->setSuccessor(1, PostLoopPreHeader
);
401 // Update dominator tree.
402 DT
.changeImmediateDominator(PostLoopPreHeader
, L
.getExitingBlock());
403 DT
.changeImmediateDominator(PostLoop
->getExitBlock(), PostLoopPreHeader
);
405 // Invalidate cached SE information.
408 // Canonicalize loops.
409 // TODO: Try to update LCSSA information according to above change.
410 formLCSSA(L
, DT
, &LI
, &SE
);
411 simplifyLoop(&L
, &DT
, &LI
, &SE
, nullptr, nullptr, true);
412 formLCSSA(*PostLoop
, DT
, &LI
, &SE
);
413 simplifyLoop(PostLoop
, &DT
, &LI
, &SE
, nullptr, nullptr, true);
415 // Add new post-loop to loop pass manager.
416 U
.addSiblingLoops(PostLoop
);
/// New pass manager entry point for the loop-bound-split pass.
///
/// Calls splitLoopBound with the dominator tree, loop info and scalar
/// evolution taken from \p AR. When it reports no change, all analyses are
/// preserved. Otherwise the dominator tree is verified in an assert and the
/// standard loop pass preserved set is returned.
///
/// \param L  Loop to process.
/// \param AM Loop analysis manager. It is not used in the visible lines.
/// \param AR Standard loop analysis results (DT, LI and SE are read here).
///
/// NOTE(review): the end of the parameter list (the updater U) and the tail
/// of the debug statement are not visible in this excerpt.
421 PreservedAnalyses
LoopBoundSplitPass::run(Loop
&L
, LoopAnalysisManager
&AM
,
422 LoopStandardAnalysisResults
&AR
,
424 Function
&F
= *L
.getHeader()->getParent();
427 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F
.getName() << ": " << L
// Nothing was transformed, so every analysis stays valid.
430 if (!splitLoopBound(L
, AR
.DT
, AR
.LI
, AR
.SE
, U
))
431 return PreservedAnalyses::all();
// splitLoopBound edits the dominator tree by hand, so it is checked here.
433 assert(AR
.DT
.verify(DominatorTree::VerificationLevel::Fast
));
436 return getLoopPassPreservedAnalyses();
439 } // end namespace llvm