#include "llvm/Transforms/Utils/LoopConstrainer.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <optional>
13 static const char *ClonedLoopTag
= "loop_constrainer.loop.clone";
15 #define DEBUG_TYPE "loop-constrainer"
17 /// Given a loop with an deccreasing induction variable, is it possible to
18 /// safely calculate the bounds of a new loop using the given Predicate.
19 static bool isSafeDecreasingBound(const SCEV
*Start
, const SCEV
*BoundSCEV
,
20 const SCEV
*Step
, ICmpInst::Predicate Pred
,
21 unsigned LatchBrExitIdx
, Loop
*L
,
22 ScalarEvolution
&SE
) {
23 if (Pred
!= ICmpInst::ICMP_SLT
&& Pred
!= ICmpInst::ICMP_SGT
&&
24 Pred
!= ICmpInst::ICMP_ULT
&& Pred
!= ICmpInst::ICMP_UGT
)
27 if (!SE
.isAvailableAtLoopEntry(BoundSCEV
, L
))
30 assert(SE
.isKnownNegative(Step
) && "expecting negative step");
32 LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33 LLVM_DEBUG(dbgs() << "Start: " << *Start
<< "\n");
34 LLVM_DEBUG(dbgs() << "Step: " << *Step
<< "\n");
35 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV
<< "\n");
36 LLVM_DEBUG(dbgs() << "Pred: " << Pred
<< "\n");
37 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx
<< "\n");
39 bool IsSigned
= ICmpInst::isSigned(Pred
);
40 // The predicate that we need to check that the induction variable lies
42 ICmpInst::Predicate BoundPred
=
43 IsSigned
? CmpInst::ICMP_SGT
: CmpInst::ICMP_UGT
;
45 if (LatchBrExitIdx
== 1)
46 return SE
.isLoopEntryGuardedByCond(L
, BoundPred
, Start
, BoundSCEV
);
48 assert(LatchBrExitIdx
== 0 && "LatchBrExitIdx should be either 0 or 1");
50 const SCEV
*StepPlusOne
= SE
.getAddExpr(Step
, SE
.getOne(Step
->getType()));
51 unsigned BitWidth
= cast
<IntegerType
>(BoundSCEV
->getType())->getBitWidth();
52 APInt Min
= IsSigned
? APInt::getSignedMinValue(BitWidth
)
53 : APInt::getMinValue(BitWidth
);
54 const SCEV
*Limit
= SE
.getMinusSCEV(SE
.getConstant(Min
), StepPlusOne
);
56 const SCEV
*MinusOne
=
57 SE
.getMinusSCEV(BoundSCEV
, SE
.getOne(BoundSCEV
->getType()));
59 return SE
.isLoopEntryGuardedByCond(L
, BoundPred
, Start
, MinusOne
) &&
60 SE
.isLoopEntryGuardedByCond(L
, BoundPred
, BoundSCEV
, Limit
);
63 /// Given a loop with an increasing induction variable, is it possible to
64 /// safely calculate the bounds of a new loop using the given Predicate.
65 static bool isSafeIncreasingBound(const SCEV
*Start
, const SCEV
*BoundSCEV
,
66 const SCEV
*Step
, ICmpInst::Predicate Pred
,
67 unsigned LatchBrExitIdx
, Loop
*L
,
68 ScalarEvolution
&SE
) {
69 if (Pred
!= ICmpInst::ICMP_SLT
&& Pred
!= ICmpInst::ICMP_SGT
&&
70 Pred
!= ICmpInst::ICMP_ULT
&& Pred
!= ICmpInst::ICMP_UGT
)
73 if (!SE
.isAvailableAtLoopEntry(BoundSCEV
, L
))
76 LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
77 LLVM_DEBUG(dbgs() << "Start: " << *Start
<< "\n");
78 LLVM_DEBUG(dbgs() << "Step: " << *Step
<< "\n");
79 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV
<< "\n");
80 LLVM_DEBUG(dbgs() << "Pred: " << Pred
<< "\n");
81 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx
<< "\n");
83 bool IsSigned
= ICmpInst::isSigned(Pred
);
84 // The predicate that we need to check that the induction variable lies
86 ICmpInst::Predicate BoundPred
=
87 IsSigned
? CmpInst::ICMP_SLT
: CmpInst::ICMP_ULT
;
89 if (LatchBrExitIdx
== 1)
90 return SE
.isLoopEntryGuardedByCond(L
, BoundPred
, Start
, BoundSCEV
);
92 assert(LatchBrExitIdx
== 0 && "LatchBrExitIdx should be 0 or 1");
94 const SCEV
*StepMinusOne
= SE
.getMinusSCEV(Step
, SE
.getOne(Step
->getType()));
95 unsigned BitWidth
= cast
<IntegerType
>(BoundSCEV
->getType())->getBitWidth();
96 APInt Max
= IsSigned
? APInt::getSignedMaxValue(BitWidth
)
97 : APInt::getMaxValue(BitWidth
);
98 const SCEV
*Limit
= SE
.getMinusSCEV(SE
.getConstant(Max
), StepMinusOne
);
100 return (SE
.isLoopEntryGuardedByCond(L
, BoundPred
, Start
,
101 SE
.getAddExpr(BoundSCEV
, Step
)) &&
102 SE
.isLoopEntryGuardedByCond(L
, BoundPred
, BoundSCEV
, Limit
));
105 /// Returns estimate for max latch taken count of the loop of the narrowest
106 /// available type. If the latch block has such estimate, it is returned.
107 /// Otherwise, we use max exit count of whole loop (that is potentially of wider
108 /// type than latch check itself), which is still better than no estimate.
109 static const SCEV
*getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution
&SE
,
111 const SCEV
*FromBlock
=
112 SE
.getExitCount(&L
, L
.getLoopLatch(), ScalarEvolution::SymbolicMaximum
);
113 if (isa
<SCEVCouldNotCompute
>(FromBlock
))
114 return SE
.getSymbolicMaxBackedgeTakenCount(&L
);
118 std::optional
<LoopStructure
>
119 LoopStructure::parseLoopStructure(ScalarEvolution
&SE
, Loop
&L
,
120 bool AllowUnsignedLatchCond
,
121 const char *&FailureReason
) {
122 if (!L
.isLoopSimplifyForm()) {
123 FailureReason
= "loop not in LoopSimplify form";
127 BasicBlock
*Latch
= L
.getLoopLatch();
128 assert(Latch
&& "Simplified loops only have one latch!");
130 if (Latch
->getTerminator()->getMetadata(ClonedLoopTag
)) {
131 FailureReason
= "loop has already been cloned";
135 if (!L
.isLoopExiting(Latch
)) {
136 FailureReason
= "no loop latch";
140 BasicBlock
*Header
= L
.getHeader();
141 BasicBlock
*Preheader
= L
.getLoopPreheader();
143 FailureReason
= "no preheader";
147 BranchInst
*LatchBr
= dyn_cast
<BranchInst
>(Latch
->getTerminator());
148 if (!LatchBr
|| LatchBr
->isUnconditional()) {
149 FailureReason
= "latch terminator not conditional branch";
153 unsigned LatchBrExitIdx
= LatchBr
->getSuccessor(0) == Header
? 1 : 0;
155 ICmpInst
*ICI
= dyn_cast
<ICmpInst
>(LatchBr
->getCondition());
156 if (!ICI
|| !isa
<IntegerType
>(ICI
->getOperand(0)->getType())) {
157 FailureReason
= "latch terminator branch not conditional on integral icmp";
161 const SCEV
*MaxBETakenCount
= getNarrowestLatchMaxTakenCountEstimate(SE
, L
);
162 if (isa
<SCEVCouldNotCompute
>(MaxBETakenCount
)) {
163 FailureReason
= "could not compute latch count";
166 assert(SE
.getLoopDisposition(MaxBETakenCount
, &L
) ==
167 ScalarEvolution::LoopInvariant
&&
168 "loop variant exit count doesn't make sense!");
170 ICmpInst::Predicate Pred
= ICI
->getPredicate();
171 Value
*LeftValue
= ICI
->getOperand(0);
172 const SCEV
*LeftSCEV
= SE
.getSCEV(LeftValue
);
173 IntegerType
*IndVarTy
= cast
<IntegerType
>(LeftValue
->getType());
175 Value
*RightValue
= ICI
->getOperand(1);
176 const SCEV
*RightSCEV
= SE
.getSCEV(RightValue
);
178 // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
179 if (!isa
<SCEVAddRecExpr
>(LeftSCEV
)) {
180 if (isa
<SCEVAddRecExpr
>(RightSCEV
)) {
181 std::swap(LeftSCEV
, RightSCEV
);
182 std::swap(LeftValue
, RightValue
);
183 Pred
= ICmpInst::getSwappedPredicate(Pred
);
185 FailureReason
= "no add recurrences in the icmp";
190 auto HasNoSignedWrap
= [&](const SCEVAddRecExpr
*AR
) {
191 if (AR
->getNoWrapFlags(SCEV::FlagNSW
))
194 IntegerType
*Ty
= cast
<IntegerType
>(AR
->getType());
195 IntegerType
*WideTy
=
196 IntegerType::get(Ty
->getContext(), Ty
->getBitWidth() * 2);
198 const SCEVAddRecExpr
*ExtendAfterOp
=
199 dyn_cast
<SCEVAddRecExpr
>(SE
.getSignExtendExpr(AR
, WideTy
));
201 const SCEV
*ExtendedStart
= SE
.getSignExtendExpr(AR
->getStart(), WideTy
);
202 const SCEV
*ExtendedStep
=
203 SE
.getSignExtendExpr(AR
->getStepRecurrence(SE
), WideTy
);
205 bool NoSignedWrap
= ExtendAfterOp
->getStart() == ExtendedStart
&&
206 ExtendAfterOp
->getStepRecurrence(SE
) == ExtendedStep
;
212 // We may have proved this when computing the sign extension above.
213 return AR
->getNoWrapFlags(SCEV::FlagNSW
) != SCEV::FlagAnyWrap
;
216 // `ICI` is interpreted as taking the backedge if the *next* value of the
217 // induction variable satisfies some constraint.
219 const SCEVAddRecExpr
*IndVarBase
= cast
<SCEVAddRecExpr
>(LeftSCEV
);
220 if (IndVarBase
->getLoop() != &L
) {
221 FailureReason
= "LHS in cmp is not an AddRec for this loop";
224 if (!IndVarBase
->isAffine()) {
225 FailureReason
= "LHS in icmp not induction variable";
228 const SCEV
*StepRec
= IndVarBase
->getStepRecurrence(SE
);
229 if (!isa
<SCEVConstant
>(StepRec
)) {
230 FailureReason
= "LHS in icmp not induction variable";
233 ConstantInt
*StepCI
= cast
<SCEVConstant
>(StepRec
)->getValue();
235 if (ICI
->isEquality() && !HasNoSignedWrap(IndVarBase
)) {
236 FailureReason
= "LHS in icmp needs nsw for equality predicates";
240 assert(!StepCI
->isZero() && "Zero step?");
241 bool IsIncreasing
= !StepCI
->isNegative();
242 bool IsSignedPredicate
;
243 const SCEV
*StartNext
= IndVarBase
->getStart();
244 const SCEV
*Addend
= SE
.getNegativeSCEV(IndVarBase
->getStepRecurrence(SE
));
245 const SCEV
*IndVarStart
= SE
.getAddExpr(StartNext
, Addend
);
246 const SCEV
*Step
= SE
.getSCEV(StepCI
);
248 const SCEV
*FixedRightSCEV
= nullptr;
250 // If RightValue resides within loop (but still being loop invariant),
251 // regenerate it as preheader.
252 if (auto *I
= dyn_cast
<Instruction
>(RightValue
))
253 if (L
.contains(I
->getParent()))
254 FixedRightSCEV
= RightSCEV
;
257 bool DecreasedRightValueByOne
= false;
258 if (StepCI
->isOne()) {
259 // Try to turn eq/ne predicates to those we can work with.
260 if (Pred
== ICmpInst::ICMP_NE
&& LatchBrExitIdx
== 1)
261 // while (++i != len) { while (++i < len) {
264 // If both parts are known non-negative, it is profitable to use
265 // unsigned comparison in increasing loop. This allows us to make the
266 // comparison check against "RightSCEV + 1" more optimistic.
267 if (isKnownNonNegativeInLoop(IndVarStart
, &L
, SE
) &&
268 isKnownNonNegativeInLoop(RightSCEV
, &L
, SE
))
269 Pred
= ICmpInst::ICMP_ULT
;
271 Pred
= ICmpInst::ICMP_SLT
;
272 else if (Pred
== ICmpInst::ICMP_EQ
&& LatchBrExitIdx
== 0) {
273 // while (true) { while (true) {
274 // if (++i == len) ---> if (++i > len - 1)
278 if (IndVarBase
->getNoWrapFlags(SCEV::FlagNUW
) &&
279 cannotBeMinInLoop(RightSCEV
, &L
, SE
, /*Signed*/ false)) {
280 Pred
= ICmpInst::ICMP_UGT
;
282 SE
.getMinusSCEV(RightSCEV
, SE
.getOne(RightSCEV
->getType()));
283 DecreasedRightValueByOne
= true;
284 } else if (cannotBeMinInLoop(RightSCEV
, &L
, SE
, /*Signed*/ true)) {
285 Pred
= ICmpInst::ICMP_SGT
;
287 SE
.getMinusSCEV(RightSCEV
, SE
.getOne(RightSCEV
->getType()));
288 DecreasedRightValueByOne
= true;
293 bool LTPred
= (Pred
== ICmpInst::ICMP_SLT
|| Pred
== ICmpInst::ICMP_ULT
);
294 bool GTPred
= (Pred
== ICmpInst::ICMP_SGT
|| Pred
== ICmpInst::ICMP_UGT
);
295 bool FoundExpectedPred
=
296 (LTPred
&& LatchBrExitIdx
== 1) || (GTPred
&& LatchBrExitIdx
== 0);
298 if (!FoundExpectedPred
) {
299 FailureReason
= "expected icmp slt semantically, found something else";
303 IsSignedPredicate
= ICmpInst::isSigned(Pred
);
304 if (!IsSignedPredicate
&& !AllowUnsignedLatchCond
) {
305 FailureReason
= "unsigned latch conditions are explicitly prohibited";
309 if (!isSafeIncreasingBound(IndVarStart
, RightSCEV
, Step
, Pred
,
310 LatchBrExitIdx
, &L
, SE
)) {
311 FailureReason
= "Unsafe loop bounds";
314 if (LatchBrExitIdx
== 0) {
315 // We need to increase the right value unless we have already decreased
316 // it virtually when we replaced EQ with SGT.
317 if (!DecreasedRightValueByOne
)
319 SE
.getAddExpr(RightSCEV
, SE
.getOne(RightSCEV
->getType()));
321 assert(!DecreasedRightValueByOne
&&
322 "Right value can be decreased only for LatchBrExitIdx == 0!");
325 bool IncreasedRightValueByOne
= false;
326 if (StepCI
->isMinusOne()) {
327 // Try to turn eq/ne predicates to those we can work with.
328 if (Pred
== ICmpInst::ICMP_NE
&& LatchBrExitIdx
== 1)
329 // while (--i != len) { while (--i > len) {
332 // We intentionally don't turn the predicate into UGT even if we know
333 // that both operands are non-negative, because it will only pessimize
334 // our check against "RightSCEV - 1".
335 Pred
= ICmpInst::ICMP_SGT
;
336 else if (Pred
== ICmpInst::ICMP_EQ
&& LatchBrExitIdx
== 0) {
337 // while (true) { while (true) {
338 // if (--i == len) ---> if (--i < len + 1)
342 if (IndVarBase
->getNoWrapFlags(SCEV::FlagNUW
) &&
343 cannotBeMaxInLoop(RightSCEV
, &L
, SE
, /* Signed */ false)) {
344 Pred
= ICmpInst::ICMP_ULT
;
345 RightSCEV
= SE
.getAddExpr(RightSCEV
, SE
.getOne(RightSCEV
->getType()));
346 IncreasedRightValueByOne
= true;
347 } else if (cannotBeMaxInLoop(RightSCEV
, &L
, SE
, /* Signed */ true)) {
348 Pred
= ICmpInst::ICMP_SLT
;
349 RightSCEV
= SE
.getAddExpr(RightSCEV
, SE
.getOne(RightSCEV
->getType()));
350 IncreasedRightValueByOne
= true;
355 bool LTPred
= (Pred
== ICmpInst::ICMP_SLT
|| Pred
== ICmpInst::ICMP_ULT
);
356 bool GTPred
= (Pred
== ICmpInst::ICMP_SGT
|| Pred
== ICmpInst::ICMP_UGT
);
358 bool FoundExpectedPred
=
359 (GTPred
&& LatchBrExitIdx
== 1) || (LTPred
&& LatchBrExitIdx
== 0);
361 if (!FoundExpectedPred
) {
362 FailureReason
= "expected icmp sgt semantically, found something else";
367 Pred
== ICmpInst::ICMP_SLT
|| Pred
== ICmpInst::ICMP_SGT
;
369 if (!IsSignedPredicate
&& !AllowUnsignedLatchCond
) {
370 FailureReason
= "unsigned latch conditions are explicitly prohibited";
374 if (!isSafeDecreasingBound(IndVarStart
, RightSCEV
, Step
, Pred
,
375 LatchBrExitIdx
, &L
, SE
)) {
376 FailureReason
= "Unsafe bounds";
380 if (LatchBrExitIdx
== 0) {
381 // We need to decrease the right value unless we have already increased
382 // it virtually when we replaced EQ with SLT.
383 if (!IncreasedRightValueByOne
)
385 SE
.getMinusSCEV(RightSCEV
, SE
.getOne(RightSCEV
->getType()));
387 assert(!IncreasedRightValueByOne
&&
388 "Right value can be increased only for LatchBrExitIdx == 0!");
391 BasicBlock
*LatchExit
= LatchBr
->getSuccessor(LatchBrExitIdx
);
393 assert(!L
.contains(LatchExit
) && "expected an exit block!");
394 const DataLayout
&DL
= Preheader
->getModule()->getDataLayout();
395 SCEVExpander
Expander(SE
, DL
, "loop-constrainer");
396 Instruction
*Ins
= Preheader
->getTerminator();
400 Expander
.expandCodeFor(FixedRightSCEV
, FixedRightSCEV
->getType(), Ins
);
402 Value
*IndVarStartV
= Expander
.expandCodeFor(IndVarStart
, IndVarTy
, Ins
);
403 IndVarStartV
->setName("indvar.start");
405 LoopStructure Result
;
408 Result
.Header
= Header
;
409 Result
.Latch
= Latch
;
410 Result
.LatchBr
= LatchBr
;
411 Result
.LatchExit
= LatchExit
;
412 Result
.LatchBrExitIdx
= LatchBrExitIdx
;
413 Result
.IndVarStart
= IndVarStartV
;
414 Result
.IndVarStep
= StepCI
;
415 Result
.IndVarBase
= LeftValue
;
416 Result
.IndVarIncreasing
= IsIncreasing
;
417 Result
.LoopExitAt
= RightValue
;
418 Result
.IsSignedPredicate
= IsSignedPredicate
;
419 Result
.ExitCountTy
= cast
<IntegerType
>(MaxBETakenCount
->getType());
421 FailureReason
= nullptr;
426 // Add metadata to the loop L to disable loop optimizations. Callers need to
427 // confirm that optimizing loop L is not beneficial.
428 static void DisableAllLoopOptsOnLoop(Loop
&L
) {
429 // We do not care about any existing loopID related metadata for L, since we
430 // are setting all loop metadata to false.
431 LLVMContext
&Context
= L
.getHeader()->getContext();
432 // Reserve first location for self reference to the LoopID metadata node.
433 MDNode
*Dummy
= MDNode::get(Context
, {});
434 MDNode
*DisableUnroll
= MDNode::get(
435 Context
, {MDString::get(Context
, "llvm.loop.unroll.disable")});
437 ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context
), 0));
438 MDNode
*DisableVectorize
= MDNode::get(
440 {MDString::get(Context
, "llvm.loop.vectorize.enable"), FalseVal
});
441 MDNode
*DisableLICMVersioning
= MDNode::get(
442 Context
, {MDString::get(Context
, "llvm.loop.licm_versioning.disable")});
443 MDNode
*DisableDistribution
= MDNode::get(
445 {MDString::get(Context
, "llvm.loop.distribute.enable"), FalseVal
});
447 MDNode::get(Context
, {Dummy
, DisableUnroll
, DisableVectorize
,
448 DisableLICMVersioning
, DisableDistribution
});
449 // Set operand 0 to refer to the loop id itself.
450 NewLoopID
->replaceOperandWith(0, NewLoopID
);
451 L
.setLoopID(NewLoopID
);
454 LoopConstrainer::LoopConstrainer(Loop
&L
, LoopInfo
&LI
,
455 function_ref
<void(Loop
*, bool)> LPMAddNewLoop
,
456 const LoopStructure
&LS
, ScalarEvolution
&SE
,
457 DominatorTree
&DT
, Type
*T
, SubRanges SR
)
458 : F(*L
.getHeader()->getParent()), Ctx(L
.getHeader()->getContext()), SE(SE
),
459 DT(DT
), LI(LI
), LPMAddNewLoop(LPMAddNewLoop
), OriginalLoop(L
), RangeTy(T
),
460 MainLoopStructure(LS
), SR(SR
) {}
462 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop
&Result
,
463 const char *Tag
) const {
464 for (BasicBlock
*BB
: OriginalLoop
.getBlocks()) {
465 BasicBlock
*Clone
= CloneBasicBlock(BB
, Result
.Map
, Twine(".") + Tag
, &F
);
466 Result
.Blocks
.push_back(Clone
);
467 Result
.Map
[BB
] = Clone
;
470 auto GetClonedValue
= [&Result
](Value
*V
) {
471 assert(V
&& "null values not in domain!");
472 auto It
= Result
.Map
.find(V
);
473 if (It
== Result
.Map
.end())
475 return static_cast<Value
*>(It
->second
);
479 cast
<BasicBlock
>(GetClonedValue(OriginalLoop
.getLoopLatch()));
480 ClonedLatch
->getTerminator()->setMetadata(ClonedLoopTag
,
481 MDNode::get(Ctx
, {}));
483 Result
.Structure
= MainLoopStructure
.map(GetClonedValue
);
484 Result
.Structure
.Tag
= Tag
;
486 for (unsigned i
= 0, e
= Result
.Blocks
.size(); i
!= e
; ++i
) {
487 BasicBlock
*ClonedBB
= Result
.Blocks
[i
];
488 BasicBlock
*OriginalBB
= OriginalLoop
.getBlocks()[i
];
490 assert(Result
.Map
[OriginalBB
] == ClonedBB
&& "invariant!");
492 for (Instruction
&I
: *ClonedBB
)
493 RemapInstruction(&I
, Result
.Map
,
494 RF_NoModuleLevelChanges
| RF_IgnoreMissingLocals
);
496 // Exit blocks will now have one more predecessor and their PHI nodes need
497 // to be edited to reflect that. No phi nodes need to be introduced because
498 // the loop is in LCSSA.
500 for (auto *SBB
: successors(OriginalBB
)) {
501 if (OriginalLoop
.contains(SBB
))
502 continue; // not an exit block
504 for (PHINode
&PN
: SBB
->phis()) {
505 Value
*OldIncoming
= PN
.getIncomingValueForBlock(OriginalBB
);
506 PN
.addIncoming(GetClonedValue(OldIncoming
), ClonedBB
);
513 LoopConstrainer::RewrittenRangeInfo
LoopConstrainer::changeIterationSpaceEnd(
514 const LoopStructure
&LS
, BasicBlock
*Preheader
, Value
*ExitSubloopAt
,
515 BasicBlock
*ContinuationBlock
) const {
516 // We start with a loop with a single latch:
518 // +--------------------+
522 // +--------+-----------+
523 // | ----------------\
525 // +--------v----v------+ |
529 // +--------------------+ |
533 // +--------------------+ |
535 // | latch >----------/
537 // +-------v------------+
540 // | +--------------------+
542 // +---> original exit |
544 // +--------------------+
546 // We change the control flow to look like
549 // +--------------------+
551 // | preheader >-------------------------+
553 // +--------v-----------+ |
554 // | /-------------+ |
556 // +--------v--v--------+ | |
558 // | header | | +--------+ |
560 // +--------------------+ | | +-----v-----v-----------+
562 // | | | .pseudo.exit |
564 // | | +-----------v-----------+
567 // | | +--------v-------------+
568 // +--------------------+ | | | |
569 // | | | | | ContinuationBlock |
570 // | latch >------+ | | |
571 // | | | +----------------------+
572 // +---------v----------+ |
575 // | +---------------^-----+
577 // +-----> .exit.selector |
579 // +----------v----------+
581 // +--------------------+ |
583 // | original exit <----+
585 // +--------------------+
587 RewrittenRangeInfo RRI
;
589 BasicBlock
*BBInsertLocation
= LS
.Latch
->getNextNode();
590 RRI
.ExitSelector
= BasicBlock::Create(Ctx
, Twine(LS
.Tag
) + ".exit.selector",
591 &F
, BBInsertLocation
);
592 RRI
.PseudoExit
= BasicBlock::Create(Ctx
, Twine(LS
.Tag
) + ".pseudo.exit", &F
,
595 BranchInst
*PreheaderJump
= cast
<BranchInst
>(Preheader
->getTerminator());
596 bool Increasing
= LS
.IndVarIncreasing
;
597 bool IsSignedPredicate
= LS
.IsSignedPredicate
;
599 IRBuilder
<> B(PreheaderJump
);
600 auto NoopOrExt
= [&](Value
*V
) {
601 if (V
->getType() == RangeTy
)
603 return IsSignedPredicate
? B
.CreateSExt(V
, RangeTy
, "wide." + V
->getName())
604 : B
.CreateZExt(V
, RangeTy
, "wide." + V
->getName());
607 // EnterLoopCond - is it okay to start executing this `LS'?
608 Value
*EnterLoopCond
= nullptr;
611 ? (IsSignedPredicate
? ICmpInst::ICMP_SLT
: ICmpInst::ICMP_ULT
)
612 : (IsSignedPredicate
? ICmpInst::ICMP_SGT
: ICmpInst::ICMP_UGT
);
613 Value
*IndVarStart
= NoopOrExt(LS
.IndVarStart
);
614 EnterLoopCond
= B
.CreateICmp(Pred
, IndVarStart
, ExitSubloopAt
);
616 B
.CreateCondBr(EnterLoopCond
, LS
.Header
, RRI
.PseudoExit
);
617 PreheaderJump
->eraseFromParent();
619 LS
.LatchBr
->setSuccessor(LS
.LatchBrExitIdx
, RRI
.ExitSelector
);
620 B
.SetInsertPoint(LS
.LatchBr
);
621 Value
*IndVarBase
= NoopOrExt(LS
.IndVarBase
);
622 Value
*TakeBackedgeLoopCond
= B
.CreateICmp(Pred
, IndVarBase
, ExitSubloopAt
);
624 Value
*CondForBranch
= LS
.LatchBrExitIdx
== 1
625 ? TakeBackedgeLoopCond
626 : B
.CreateNot(TakeBackedgeLoopCond
);
628 LS
.LatchBr
->setCondition(CondForBranch
);
630 B
.SetInsertPoint(RRI
.ExitSelector
);
632 // IterationsLeft - are there any more iterations left, given the original
633 // upper bound on the induction variable? If not, we branch to the "real"
635 Value
*LoopExitAt
= NoopOrExt(LS
.LoopExitAt
);
636 Value
*IterationsLeft
= B
.CreateICmp(Pred
, IndVarBase
, LoopExitAt
);
637 B
.CreateCondBr(IterationsLeft
, RRI
.PseudoExit
, LS
.LatchExit
);
639 BranchInst
*BranchToContinuation
=
640 BranchInst::Create(ContinuationBlock
, RRI
.PseudoExit
);
642 // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
643 // each of the PHI nodes in the loop header. This feeds into the initial
644 // value of the same PHI nodes if/when we continue execution.
645 for (PHINode
&PN
: LS
.Header
->phis()) {
646 PHINode
*NewPHI
= PHINode::Create(PN
.getType(), 2, PN
.getName() + ".copy",
647 BranchToContinuation
);
649 NewPHI
->addIncoming(PN
.getIncomingValueForBlock(Preheader
), Preheader
);
650 NewPHI
->addIncoming(PN
.getIncomingValueForBlock(LS
.Latch
),
652 RRI
.PHIValuesAtPseudoExit
.push_back(NewPHI
);
655 RRI
.IndVarEnd
= PHINode::Create(IndVarBase
->getType(), 2, "indvar.end",
656 BranchToContinuation
);
657 RRI
.IndVarEnd
->addIncoming(IndVarStart
, Preheader
);
658 RRI
.IndVarEnd
->addIncoming(IndVarBase
, RRI
.ExitSelector
);
660 // The latch exit now has a branch from `RRI.ExitSelector' instead of
661 // `LS.Latch'. The PHI nodes need to be updated to reflect that.
662 LS
.LatchExit
->replacePhiUsesWith(LS
.Latch
, RRI
.ExitSelector
);
667 void LoopConstrainer::rewriteIncomingValuesForPHIs(
668 LoopStructure
&LS
, BasicBlock
*ContinuationBlock
,
669 const LoopConstrainer::RewrittenRangeInfo
&RRI
) const {
670 unsigned PHIIndex
= 0;
671 for (PHINode
&PN
: LS
.Header
->phis())
672 PN
.setIncomingValueForBlock(ContinuationBlock
,
673 RRI
.PHIValuesAtPseudoExit
[PHIIndex
++]);
675 LS
.IndVarStart
= RRI
.IndVarEnd
;
678 BasicBlock
*LoopConstrainer::createPreheader(const LoopStructure
&LS
,
679 BasicBlock
*OldPreheader
,
680 const char *Tag
) const {
681 BasicBlock
*Preheader
= BasicBlock::Create(Ctx
, Tag
, &F
, LS
.Header
);
682 BranchInst::Create(LS
.Header
, Preheader
);
684 LS
.Header
->replacePhiUsesWith(OldPreheader
, Preheader
);
689 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef
<BasicBlock
*> BBs
) {
690 Loop
*ParentLoop
= OriginalLoop
.getParentLoop();
694 for (BasicBlock
*BB
: BBs
)
695 ParentLoop
->addBasicBlockToLoop(BB
, LI
);
698 Loop
*LoopConstrainer::createClonedLoopStructure(Loop
*Original
, Loop
*Parent
,
699 ValueToValueMapTy
&VM
,
701 Loop
&New
= *LI
.AllocateLoop();
703 Parent
->addChildLoop(&New
);
705 LI
.addTopLevelLoop(&New
);
706 LPMAddNewLoop(&New
, IsSubloop
);
708 // Add all of the blocks in Original to the new loop.
709 for (auto *BB
: Original
->blocks())
710 if (LI
.getLoopFor(BB
) == Original
)
711 New
.addBasicBlockToLoop(cast
<BasicBlock
>(VM
[BB
]), LI
);
713 // Add all of the subloops to the new loop.
714 for (Loop
*SubLoop
: *Original
)
715 createClonedLoopStructure(SubLoop
, &New
, VM
, /* IsSubloop */ true);
720 bool LoopConstrainer::run() {
721 BasicBlock
*Preheader
= OriginalLoop
.getLoopPreheader();
722 assert(Preheader
!= nullptr && "precondition!");
724 OriginalPreheader
= Preheader
;
725 MainLoopPreheader
= Preheader
;
726 bool IsSignedPredicate
= MainLoopStructure
.IsSignedPredicate
;
727 bool Increasing
= MainLoopStructure
.IndVarIncreasing
;
728 IntegerType
*IVTy
= cast
<IntegerType
>(RangeTy
);
730 SCEVExpander
Expander(SE
, F
.getParent()->getDataLayout(), "loop-constrainer");
731 Instruction
*InsertPt
= OriginalPreheader
->getTerminator();
733 // It would have been better to make `PreLoop' and `PostLoop'
734 // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
736 ClonedLoop PreLoop
, PostLoop
;
738 Increasing
? SR
.LowLimit
.has_value() : SR
.HighLimit
.has_value();
740 Increasing
? SR
.HighLimit
.has_value() : SR
.LowLimit
.has_value();
742 Value
*ExitPreLoopAt
= nullptr;
743 Value
*ExitMainLoopAt
= nullptr;
744 const SCEVConstant
*MinusOneS
=
745 cast
<SCEVConstant
>(SE
.getConstant(IVTy
, -1, true /* isSigned */));
748 const SCEV
*ExitPreLoopAtSCEV
= nullptr;
751 ExitPreLoopAtSCEV
= *SR
.LowLimit
;
752 else if (cannotBeMinInLoop(*SR
.HighLimit
, &OriginalLoop
, SE
,
754 ExitPreLoopAtSCEV
= SE
.getAddExpr(*SR
.HighLimit
, MinusOneS
);
756 LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
757 << "preloop exit limit. HighLimit = "
758 << *(*SR
.HighLimit
) << "\n");
762 if (!Expander
.isSafeToExpandAt(ExitPreLoopAtSCEV
, InsertPt
)) {
763 LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
764 << " preloop exit limit " << *ExitPreLoopAtSCEV
765 << " at block " << InsertPt
->getParent()->getName()
770 ExitPreLoopAt
= Expander
.expandCodeFor(ExitPreLoopAtSCEV
, IVTy
, InsertPt
);
771 ExitPreLoopAt
->setName("exit.preloop.at");
775 const SCEV
*ExitMainLoopAtSCEV
= nullptr;
778 ExitMainLoopAtSCEV
= *SR
.HighLimit
;
779 else if (cannotBeMinInLoop(*SR
.LowLimit
, &OriginalLoop
, SE
,
781 ExitMainLoopAtSCEV
= SE
.getAddExpr(*SR
.LowLimit
, MinusOneS
);
783 LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
784 << "mainloop exit limit. LowLimit = "
785 << *(*SR
.LowLimit
) << "\n");
789 if (!Expander
.isSafeToExpandAt(ExitMainLoopAtSCEV
, InsertPt
)) {
790 LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
791 << " main loop exit limit " << *ExitMainLoopAtSCEV
792 << " at block " << InsertPt
->getParent()->getName()
797 ExitMainLoopAt
= Expander
.expandCodeFor(ExitMainLoopAtSCEV
, IVTy
, InsertPt
);
798 ExitMainLoopAt
->setName("exit.mainloop.at");
801 // We clone these ahead of time so that we don't have to deal with changing
802 // and temporarily invalid IR as we transform the loops.
804 cloneLoop(PreLoop
, "preloop");
806 cloneLoop(PostLoop
, "postloop");
808 RewrittenRangeInfo PreLoopRRI
;
811 Preheader
->getTerminator()->replaceUsesOfWith(MainLoopStructure
.Header
,
812 PreLoop
.Structure
.Header
);
815 createPreheader(MainLoopStructure
, Preheader
, "mainloop");
816 PreLoopRRI
= changeIterationSpaceEnd(PreLoop
.Structure
, Preheader
,
817 ExitPreLoopAt
, MainLoopPreheader
);
818 rewriteIncomingValuesForPHIs(MainLoopStructure
, MainLoopPreheader
,
822 BasicBlock
*PostLoopPreheader
= nullptr;
823 RewrittenRangeInfo PostLoopRRI
;
827 createPreheader(PostLoop
.Structure
, Preheader
, "postloop");
828 PostLoopRRI
= changeIterationSpaceEnd(MainLoopStructure
, MainLoopPreheader
,
829 ExitMainLoopAt
, PostLoopPreheader
);
830 rewriteIncomingValuesForPHIs(PostLoop
.Structure
, PostLoopPreheader
,
834 BasicBlock
*NewMainLoopPreheader
=
835 MainLoopPreheader
!= Preheader
? MainLoopPreheader
: nullptr;
836 BasicBlock
*NewBlocks
[] = {PostLoopPreheader
, PreLoopRRI
.PseudoExit
,
837 PreLoopRRI
.ExitSelector
, PostLoopRRI
.PseudoExit
,
838 PostLoopRRI
.ExitSelector
, NewMainLoopPreheader
};
840 // Some of the above may be nullptr, filter them out before passing to
841 // addToParentLoopIfNeeded.
843 std::remove(std::begin(NewBlocks
), std::end(NewBlocks
), nullptr);
845 addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks
), NewBlocksEnd
));
849 // We need to first add all the pre and post loop blocks into the loop
850 // structures (as part of createClonedLoopStructure), and then update the
851 // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
852 // LI when LoopSimplifyForm is generated.
853 Loop
*PreL
= nullptr, *PostL
= nullptr;
854 if (!PreLoop
.Blocks
.empty()) {
855 PreL
= createClonedLoopStructure(&OriginalLoop
,
856 OriginalLoop
.getParentLoop(), PreLoop
.Map
,
857 /* IsSubLoop */ false);
860 if (!PostLoop
.Blocks
.empty()) {
862 createClonedLoopStructure(&OriginalLoop
, OriginalLoop
.getParentLoop(),
863 PostLoop
.Map
, /* IsSubLoop */ false);
866 // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
867 auto CanonicalizeLoop
= [&](Loop
*L
, bool IsOriginalLoop
) {
868 formLCSSARecursively(*L
, DT
, &LI
, &SE
);
869 simplifyLoop(L
, &DT
, &LI
, &SE
, nullptr, nullptr, true);
870 // Pre/post loops are slow paths, we do not need to perform any loop
871 // optimizations on them.
873 DisableAllLoopOptsOnLoop(*L
);
876 CanonicalizeLoop(PreL
, false);
878 CanonicalizeLoop(PostL
, false);
879 CanonicalizeLoop(&OriginalLoop
, true);
882 /// - We've broken a "main loop" out of the loop in a way that the "main loop"
883 /// runs with the induction variable in a subset of [Begin, End).
884 /// - There is no overflow when computing "main loop" exit limit.
885 /// - Max latch taken count of the loop is limited.
886 /// It guarantees that induction variable will not overflow iterating in the
888 if (isa
<OverflowingBinaryOperator
>(MainLoopStructure
.IndVarBase
))
889 if (IsSignedPredicate
)
890 cast
<BinaryOperator
>(MainLoopStructure
.IndVarBase
)
891 ->setHasNoSignedWrap(true);
892 /// TODO: support unsigned predicate.
893 /// To add NUW flag we need to prove that both operands of BO are
894 /// non-negative. E.g:
896 /// %iv.next = add nsw i32 %iv, -1
897 /// %cmp = icmp ult i32 %iv.next, %n
898 /// br i1 %cmp, label %loopexit, label %loop
900 /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
901 /// overflow, therefore NUW flag is not legal here.