[RISCV] Fix mgather -> riscv.masked.strided.load combine not extending indices (...
[llvm-project.git] / llvm / lib / Transforms / Utils / LoopConstrainer.cpp
blobea6d952cfa7d4f38c9cc5201b0dc352471bc45e4
1 #include "llvm/Transforms/Utils/LoopConstrainer.h"
2 #include "llvm/Analysis/LoopInfo.h"
3 #include "llvm/Analysis/ScalarEvolution.h"
4 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
5 #include "llvm/IR/Dominators.h"
6 #include "llvm/Transforms/Utils/Cloning.h"
7 #include "llvm/Transforms/Utils/LoopSimplify.h"
8 #include "llvm/Transforms/Utils/LoopUtils.h"
9 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
11 using namespace llvm;
/// Name of the metadata kind this file attaches to the latch terminator of a
/// cloned loop and later looks up to recognise loops that were already cloned.
/// Both the characters and the pointer are const so the tag cannot be
/// re-targeted by accident.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";

// Debug category used by the LLVM_DEBUG statements in this file.
#define DEBUG_TYPE "loop-constrainer"
17 /// Given a loop with an deccreasing induction variable, is it possible to
18 /// safely calculate the bounds of a new loop using the given Predicate.
19 static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
20 const SCEV *Step, ICmpInst::Predicate Pred,
21 unsigned LatchBrExitIdx, Loop *L,
22 ScalarEvolution &SE) {
23 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
24 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
25 return false;
27 if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
28 return false;
30 assert(SE.isKnownNegative(Step) && "expecting negative step");
32 LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
34 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
35 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
36 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
37 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
39 bool IsSigned = ICmpInst::isSigned(Pred);
40 // The predicate that we need to check that the induction variable lies
41 // within bounds.
42 ICmpInst::Predicate BoundPred =
43 IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
45 if (LatchBrExitIdx == 1)
46 return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
48 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
50 const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
51 unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
52 APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
53 : APInt::getMinValue(BitWidth);
54 const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
56 const SCEV *MinusOne =
57 SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
59 return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
60 SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
63 /// Given a loop with an increasing induction variable, is it possible to
64 /// safely calculate the bounds of a new loop using the given Predicate.
65 static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
66 const SCEV *Step, ICmpInst::Predicate Pred,
67 unsigned LatchBrExitIdx, Loop *L,
68 ScalarEvolution &SE) {
69 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
70 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
71 return false;
73 if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
74 return false;
76 LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
77 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
78 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
79 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
80 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
81 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
83 bool IsSigned = ICmpInst::isSigned(Pred);
84 // The predicate that we need to check that the induction variable lies
85 // within bounds.
86 ICmpInst::Predicate BoundPred =
87 IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
89 if (LatchBrExitIdx == 1)
90 return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
92 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
94 const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
95 unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
96 APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
97 : APInt::getMaxValue(BitWidth);
98 const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
100 return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
101 SE.getAddExpr(BoundSCEV, Step)) &&
102 SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
105 /// Returns estimate for max latch taken count of the loop of the narrowest
106 /// available type. If the latch block has such estimate, it is returned.
107 /// Otherwise, we use max exit count of whole loop (that is potentially of wider
108 /// type than latch check itself), which is still better than no estimate.
109 static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
110 const Loop &L) {
111 const SCEV *FromBlock =
112 SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
113 if (isa<SCEVCouldNotCompute>(FromBlock))
114 return SE.getSymbolicMaxBackedgeTakenCount(&L);
115 return FromBlock;
/// Analyse the latch of \p L and, when it has a shape this file can work with,
/// describe it as a LoopStructure.
///
/// The loop is accepted only if it is in LoopSimplify form, has a preheader,
/// has not been tagged as a clone, and its latch is an exiting block ending in
/// a conditional branch on an integer icmp. One icmp operand must be an affine
/// add recurrence of \p L with a constant step. The resulting predicate, after
/// the eq/ne rewrites below, must be slt/ult/sgt/ugt in the direction matching
/// the latch exit index, and the bounds must pass isSafeIncreasingBound or
/// isSafeDecreasingBound.
///
/// \param SE ScalarEvolution used for all SCEV queries and expansion.
/// \param L  The loop to inspect.
/// \param AllowUnsignedLatchCond When false, unsigned latch predicates are
///        rejected.
/// \param FailureReason Set to a description of the first failed check when
///        std::nullopt is returned, and to nullptr on success.
/// \returns The parsed structure, or std::nullopt on failure.
///
/// On success this emits code before the preheader terminator: the induction
/// variable start value, and the adjusted right-hand side when one is needed.
118 std::optional<LoopStructure>
119 LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
120 bool AllowUnsignedLatchCond,
121 const char *&FailureReason) {
122 if (!L.isLoopSimplifyForm()) {
123 FailureReason = "loop not in LoopSimplify form";
124 return std::nullopt;
127 BasicBlock *Latch = L.getLoopLatch();
128 assert(Latch && "Simplified loops only have one latch!");
// Loops whose latch terminator carries ClonedLoopTag are rejected.
130 if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
131 FailureReason = "loop has already been cloned";
132 return std::nullopt;
135 if (!L.isLoopExiting(Latch)) {
136 FailureReason = "no loop latch";
137 return std::nullopt;
140 BasicBlock *Header = L.getHeader();
141 BasicBlock *Preheader = L.getLoopPreheader();
142 if (!Preheader) {
143 FailureReason = "no preheader";
144 return std::nullopt;
147 BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
148 if (!LatchBr || LatchBr->isUnconditional()) {
149 FailureReason = "latch terminator not conditional branch";
150 return std::nullopt;
// Index of the latch successor that is not the header, i.e. the exit edge:
// 1 when successor 0 is the header, 0 otherwise.
153 unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
155 ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
156 if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
157 FailureReason = "latch terminator branch not conditional on integral icmp";
158 return std::nullopt;
161 const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
162 if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
163 FailureReason = "could not compute latch count";
164 return std::nullopt;
166 assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
167 ScalarEvolution::LoopInvariant &&
168 "loop variant exit count doesn't make sense!");
170 ICmpInst::Predicate Pred = ICI->getPredicate();
171 Value *LeftValue = ICI->getOperand(0);
172 const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
173 IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
175 Value *RightValue = ICI->getOperand(1);
176 const SCEV *RightSCEV = SE.getSCEV(RightValue);
178 // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
179 if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
180 if (isa<SCEVAddRecExpr>(RightSCEV)) {
181 std::swap(LeftSCEV, RightSCEV);
182 std::swap(LeftValue, RightValue);
183 Pred = ICmpInst::getSwappedPredicate(Pred);
184 } else {
185 FailureReason = "no add recurrences in the icmp";
186 return std::nullopt;
// Returns true when AR is known not to wrap in the signed sense: either it
// already carries the NSW flag, or sign-extending it to twice its bit width
// yields an add recurrence whose start and step equal the sign-extended start
// and step of AR.
190 auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
191 if (AR->getNoWrapFlags(SCEV::FlagNSW))
192 return true;
194 IntegerType *Ty = cast<IntegerType>(AR->getType());
195 IntegerType *WideTy =
196 IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
198 const SCEVAddRecExpr *ExtendAfterOp =
199 dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
200 if (ExtendAfterOp) {
201 const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
202 const SCEV *ExtendedStep =
203 SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
205 bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
206 ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
208 if (NoSignedWrap)
209 return true;
212 // We may have proved this when computing the sign extension above.
213 return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
216 // `ICI` is interpreted as taking the backedge if the *next* value of the
217 // induction variable satisfies some constraint.
219 const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
220 if (IndVarBase->getLoop() != &L) {
221 FailureReason = "LHS in cmp is not an AddRec for this loop";
222 return std::nullopt;
224 if (!IndVarBase->isAffine()) {
225 FailureReason = "LHS in icmp not induction variable";
226 return std::nullopt;
228 const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
229 if (!isa<SCEVConstant>(StepRec)) {
230 FailureReason = "LHS in icmp not induction variable";
231 return std::nullopt;
233 ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
235 if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
236 FailureReason = "LHS in icmp needs nsw for equality predicates";
237 return std::nullopt;
240 assert(!StepCI->isZero() && "Zero step?");
241 bool IsIncreasing = !StepCI->isNegative();
242 bool IsSignedPredicate;
// The add recurrence in the icmp starts at the "next" value of the induction
// variable, so the start value used from here on is that start minus the step.
243 const SCEV *StartNext = IndVarBase->getStart();
244 const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
245 const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
246 const SCEV *Step = SE.getSCEV(StepCI);
// When non-null, this SCEV is expanded in the preheader and replaces
// RightValue as the loop exit bound.
248 const SCEV *FixedRightSCEV = nullptr;
250 // If RightValue resides within loop (but still being loop invariant),
251 // regenerate it in the preheader.
252 if (auto *I = dyn_cast<Instruction>(RightValue))
253 if (L.contains(I->getParent()))
254 FixedRightSCEV = RightSCEV;
256 if (IsIncreasing) {
257 bool DecreasedRightValueByOne = false;
258 if (StepCI->isOne()) {
259 // Try to turn eq/ne predicates to those we can work with.
260 if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
261 // while (++i != len) { while (++i < len) {
262 // ... ---> ...
263 // } }
264 // If both parts are known non-negative, it is profitable to use
265 // unsigned comparison in increasing loop. This allows us to make the
266 // comparison check against "RightSCEV + 1" more optimistic.
267 if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
268 isKnownNonNegativeInLoop(RightSCEV, &L, SE))
269 Pred = ICmpInst::ICMP_ULT;
270 else
271 Pred = ICmpInst::ICMP_SLT;
272 else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
273 // while (true) { while (true) {
274 // if (++i == len) ---> if (++i > len - 1)
275 // break; break;
276 // ... ...
277 // } }
278 if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
279 cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
280 Pred = ICmpInst::ICMP_UGT;
281 RightSCEV =
282 SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
283 DecreasedRightValueByOne = true;
284 } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
285 Pred = ICmpInst::ICMP_SGT;
286 RightSCEV =
287 SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
288 DecreasedRightValueByOne = true;
// An increasing loop must continue on "less than" (exit index 1) or leave on
// "greater than" (exit index 0).
293 bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
294 bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
295 bool FoundExpectedPred =
296 (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
298 if (!FoundExpectedPred) {
299 FailureReason = "expected icmp slt semantically, found something else";
300 return std::nullopt;
303 IsSignedPredicate = ICmpInst::isSigned(Pred);
304 if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
305 FailureReason = "unsigned latch conditions are explicitly prohibited";
306 return std::nullopt;
309 if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
310 LatchBrExitIdx, &L, SE)) {
311 FailureReason = "Unsafe loop bounds";
312 return std::nullopt;
314 if (LatchBrExitIdx == 0) {
315 // We need to increase the right value unless we have already decreased
316 // it virtually when we replaced EQ with SGT.
317 if (!DecreasedRightValueByOne)
318 FixedRightSCEV =
319 SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
320 } else {
321 assert(!DecreasedRightValueByOne &&
322 "Right value can be decreased only for LatchBrExitIdx == 0!");
324 } else {
325 bool IncreasedRightValueByOne = false;
326 if (StepCI->isMinusOne()) {
327 // Try to turn eq/ne predicates to those we can work with.
328 if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
329 // while (--i != len) { while (--i > len) {
330 // ... ---> ...
331 // } }
332 // We intentionally don't turn the predicate into UGT even if we know
333 // that both operands are non-negative, because it will only pessimize
334 // our check against "RightSCEV - 1".
335 Pred = ICmpInst::ICMP_SGT;
336 else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
337 // while (true) { while (true) {
338 // if (--i == len) ---> if (--i < len + 1)
339 // break; break;
340 // ... ...
341 // } }
342 if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
343 cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
344 Pred = ICmpInst::ICMP_ULT;
345 RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
346 IncreasedRightValueByOne = true;
347 } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
348 Pred = ICmpInst::ICMP_SLT;
349 RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
350 IncreasedRightValueByOne = true;
// A decreasing loop must continue on "greater than" (exit index 1) or leave on
// "less than" (exit index 0).
355 bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
356 bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
358 bool FoundExpectedPred =
359 (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
361 if (!FoundExpectedPred) {
362 FailureReason = "expected icmp sgt semantically, found something else";
363 return std::nullopt;
366 IsSignedPredicate =
367 Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
369 if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
370 FailureReason = "unsigned latch conditions are explicitly prohibited";
371 return std::nullopt;
374 if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
375 LatchBrExitIdx, &L, SE)) {
376 FailureReason = "Unsafe bounds";
377 return std::nullopt;
380 if (LatchBrExitIdx == 0) {
381 // We need to decrease the right value unless we have already increased
382 // it virtually when we replaced EQ with SLT.
383 if (!IncreasedRightValueByOne)
384 FixedRightSCEV =
385 SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
386 } else {
387 assert(!IncreasedRightValueByOne &&
388 "Right value can be increased only for LatchBrExitIdx == 0!");
391 BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
393 assert(!L.contains(LatchExit) && "expected an exit block!");
// Everything expanded below is inserted before the preheader terminator.
394 const DataLayout &DL = Preheader->getModule()->getDataLayout();
395 SCEVExpander Expander(SE, DL, "loop-constrainer");
396 Instruction *Ins = Preheader->getTerminator();
398 if (FixedRightSCEV)
399 RightValue =
400 Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
402 Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
403 IndVarStartV->setName("indvar.start");
405 LoopStructure Result;
407 Result.Tag = "main";
408 Result.Header = Header;
409 Result.Latch = Latch;
410 Result.LatchBr = LatchBr;
411 Result.LatchExit = LatchExit;
412 Result.LatchBrExitIdx = LatchBrExitIdx;
413 Result.IndVarStart = IndVarStartV;
414 Result.IndVarStep = StepCI;
415 Result.IndVarBase = LeftValue;
416 Result.IndVarIncreasing = IsIncreasing;
417 Result.LoopExitAt = RightValue;
418 Result.IsSignedPredicate = IsSignedPredicate;
419 Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
421 FailureReason = nullptr;
423 return Result;
426 // Add metadata to the loop L to disable loop optimizations. Callers need to
427 // confirm that optimizing loop L is not beneficial.
428 static void DisableAllLoopOptsOnLoop(Loop &L) {
429 // We do not care about any existing loopID related metadata for L, since we
430 // are setting all loop metadata to false.
431 LLVMContext &Context = L.getHeader()->getContext();
432 // Reserve first location for self reference to the LoopID metadata node.
433 MDNode *Dummy = MDNode::get(Context, {});
434 MDNode *DisableUnroll = MDNode::get(
435 Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
436 Metadata *FalseVal =
437 ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
438 MDNode *DisableVectorize = MDNode::get(
439 Context,
440 {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
441 MDNode *DisableLICMVersioning = MDNode::get(
442 Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
443 MDNode *DisableDistribution = MDNode::get(
444 Context,
445 {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
446 MDNode *NewLoopID =
447 MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
448 DisableLICMVersioning, DisableDistribution});
449 // Set operand 0 to refer to the loop id itself.
450 NewLoopID->replaceOperandWith(0, NewLoopID);
451 L.setLoopID(NewLoopID);
454 LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
455 function_ref<void(Loop *, bool)> LPMAddNewLoop,
456 const LoopStructure &LS, ScalarEvolution &SE,
457 DominatorTree &DT, Type *T, SubRanges SR)
458 : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
459 DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
460 MainLoopStructure(LS), SR(SR) {}
462 void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
463 const char *Tag) const {
464 for (BasicBlock *BB : OriginalLoop.getBlocks()) {
465 BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
466 Result.Blocks.push_back(Clone);
467 Result.Map[BB] = Clone;
470 auto GetClonedValue = [&Result](Value *V) {
471 assert(V && "null values not in domain!");
472 auto It = Result.Map.find(V);
473 if (It == Result.Map.end())
474 return V;
475 return static_cast<Value *>(It->second);
478 auto *ClonedLatch =
479 cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
480 ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
481 MDNode::get(Ctx, {}));
483 Result.Structure = MainLoopStructure.map(GetClonedValue);
484 Result.Structure.Tag = Tag;
486 for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
487 BasicBlock *ClonedBB = Result.Blocks[i];
488 BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
490 assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
492 for (Instruction &I : *ClonedBB)
493 RemapInstruction(&I, Result.Map,
494 RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
496 // Exit blocks will now have one more predecessor and their PHI nodes need
497 // to be edited to reflect that. No phi nodes need to be introduced because
498 // the loop is in LCSSA.
500 for (auto *SBB : successors(OriginalBB)) {
501 if (OriginalLoop.contains(SBB))
502 continue; // not an exit block
504 for (PHINode &PN : SBB->phis()) {
505 Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
506 PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
507 SE.forgetValue(&PN);
/// Rewrite the loop described by \p LS so that it stops iterating once its
/// induction variable reaches \p ExitSubloopAt, and route the remaining
/// iterations to \p ContinuationBlock.
///
/// Two blocks are created: "<tag>.exit.selector", which the latch now exits
/// to, and "<tag>.pseudo.exit", which branches to \p ContinuationBlock. The
/// terminator of \p Preheader is replaced by a conditional branch that enters
/// the loop only if the start value already satisfies the loop condition.
///
/// \returns The two new blocks, the PHI copies created in the pseudo exit for
/// every header PHI, and the PHI holding the final induction variable value.
513 LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
514 const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
515 BasicBlock *ContinuationBlock) const {
516 // We start with a loop with a single latch:
518 // +--------------------+
519 // | |
520 // | preheader |
521 // | |
522 // +--------+-----------+
523 // | ----------------\
524 // | / |
525 // +--------v----v------+ |
526 // | | |
527 // | header | |
528 // | | |
529 // +--------------------+ |
530 // |
531 // ..... |
532 // |
533 // +--------------------+ |
534 // | | |
535 // | latch >----------/
536 // | |
537 // +-------v------------+
538 // |
539 // |
540 // | +--------------------+
541 // | | |
542 // +---> original exit |
543 // | |
544 // +--------------------+
546 // We change the control flow to look like
549 // +--------------------+
550 // | |
551 // | preheader >-------------------------+
552 // | | |
553 // +--------v-----------+ |
554 // | /-------------+ |
555 // | / | |
556 // +--------v--v--------+ | |
557 // | | | |
558 // | header | | +--------+ |
559 // | | | | | |
560 // +--------------------+ | | +-----v-----v-----------+
561 // | | | |
562 // | | | .pseudo.exit |
563 // | | | |
564 // | | +-----------v-----------+
565 // | | |
566 // ..... | | |
567 // | | +--------v-------------+
568 // +--------------------+ | | | |
569 // | | | | | ContinuationBlock |
570 // | latch >------+ | | |
571 // | | | +----------------------+
572 // +---------v----------+ |
573 // | |
574 // | |
575 // | +---------------^-----+
576 // | | |
577 // +-----> .exit.selector |
578 // | |
579 // +----------v----------+
580 // |
581 // +--------------------+ |
582 // | | |
583 // | original exit <----+
584 // | |
585 // +--------------------+
587 RewrittenRangeInfo RRI;
// Both new blocks are inserted right after the latch in the function layout.
589 BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
590 RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
591 &F, BBInsertLocation);
592 RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
593 BBInsertLocation);
595 BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
596 bool Increasing = LS.IndVarIncreasing;
597 bool IsSignedPredicate = LS.IsSignedPredicate;
599 IRBuilder<> B(PreheaderJump);
// Returns V unchanged if it already has type RangeTy; otherwise sign- or
// zero-extends it to RangeTy depending on IsSignedPredicate.
600 auto NoopOrExt = [&](Value *V) {
601 if (V->getType() == RangeTy)
602 return V;
603 return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
604 : B.CreateZExt(V, RangeTy, "wide." + V->getName());
607 // EnterLoopCond - is it okay to start executing this `LS'?
608 Value *EnterLoopCond = nullptr;
// "Less than" for an increasing induction variable, "greater than" for a
// decreasing one; signedness follows the latch predicate.
609 auto Pred =
610 Increasing
611 ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
612 : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
613 Value *IndVarStart = NoopOrExt(LS.IndVarStart);
614 EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
616 B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
617 PreheaderJump->eraseFromParent();
// Redirect the exit edge of the latch to the exit selector and replace the
// latch condition with a comparison against ExitSubloopAt.
619 LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
620 B.SetInsertPoint(LS.LatchBr);
621 Value *IndVarBase = NoopOrExt(LS.IndVarBase);
622 Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
// The branch condition is negated when the exit edge is successor 0, so that
// a true condition then leaves the loop.
624 Value *CondForBranch = LS.LatchBrExitIdx == 1
625 ? TakeBackedgeLoopCond
626 : B.CreateNot(TakeBackedgeLoopCond);
628 LS.LatchBr->setCondition(CondForBranch);
630 B.SetInsertPoint(RRI.ExitSelector);
632 // IterationsLeft - are there any more iterations left, given the original
633 // upper bound on the induction variable? If not, we branch to the "real"
634 // exit.
635 Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
636 Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
637 B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
639 BranchInst *BranchToContinuation =
640 BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
642 // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
643 // each of the PHI nodes in the loop header. This feeds into the initial
644 // value of the same PHI nodes if/when we continue execution.
645 for (PHINode &PN : LS.Header->phis()) {
646 PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
647 BranchToContinuation);
649 NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
650 NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
651 RRI.ExitSelector);
652 RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
// Value of the (possibly widened) induction variable on reaching the pseudo
// exit: the start value when coming from the preheader, the latch value when
// coming from the exit selector.
655 RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
656 BranchToContinuation);
657 RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
658 RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
660 // The latch exit now has a branch from `RRI.ExitSelector' instead of
661 // `LS.Latch'. The PHI nodes need to be updated to reflect that.
662 LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
664 return RRI;
667 void LoopConstrainer::rewriteIncomingValuesForPHIs(
668 LoopStructure &LS, BasicBlock *ContinuationBlock,
669 const LoopConstrainer::RewrittenRangeInfo &RRI) const {
670 unsigned PHIIndex = 0;
671 for (PHINode &PN : LS.Header->phis())
672 PN.setIncomingValueForBlock(ContinuationBlock,
673 RRI.PHIValuesAtPseudoExit[PHIIndex++]);
675 LS.IndVarStart = RRI.IndVarEnd;
678 BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
679 BasicBlock *OldPreheader,
680 const char *Tag) const {
681 BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
682 BranchInst::Create(LS.Header, Preheader);
684 LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
686 return Preheader;
689 void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
690 Loop *ParentLoop = OriginalLoop.getParentLoop();
691 if (!ParentLoop)
692 return;
694 for (BasicBlock *BB : BBs)
695 ParentLoop->addBasicBlockToLoop(BB, LI);
698 Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
699 ValueToValueMapTy &VM,
700 bool IsSubloop) {
701 Loop &New = *LI.AllocateLoop();
702 if (Parent)
703 Parent->addChildLoop(&New);
704 else
705 LI.addTopLevelLoop(&New);
706 LPMAddNewLoop(&New, IsSubloop);
708 // Add all of the blocks in Original to the new loop.
709 for (auto *BB : Original->blocks())
710 if (LI.getLoopFor(BB) == Original)
711 New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
713 // Add all of the subloops to the new loop.
714 for (Loop *SubLoop : *Original)
715 createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
717 return &New;
/// Perform the loop-constraining transformation on the original loop.
///
/// Depending on which of SR.LowLimit / SR.HighLimit are set and on the
/// direction of the induction variable, the loop is cloned into a pre-loop
/// and/or a post-loop, and the iteration space of the pre-loop and of the main
/// loop is cut at limits expanded in the original preheader. New blocks are
/// then added to the parent loop, the dominator tree is recalculated, Loop
/// objects are created for the clones, and all involved loops are brought back
/// into LCSSA and LoopSimplify form. The clones additionally get metadata that
/// disables further loop optimizations.
///
/// \returns false, before any cloning is done, if a required exit limit cannot
/// be shown not to overflow or is not safe to expand at the preheader
/// terminator; true once the transformation has been applied.
720 bool LoopConstrainer::run() {
721 BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
722 assert(Preheader != nullptr && "precondition!");
724 OriginalPreheader = Preheader;
725 MainLoopPreheader = Preheader;
726 bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
727 bool Increasing = MainLoopStructure.IndVarIncreasing;
728 IntegerType *IVTy = cast<IntegerType>(RangeTy);
730 SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer");
731 Instruction *InsertPt = OriginalPreheader->getTerminator();
733 // It would have been better to make `PreLoop' and `PostLoop'
734 // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
735 // constructor.
736 ClonedLoop PreLoop, PostLoop;
// For an increasing induction variable the pre-loop is governed by the low
// limit and the post-loop by the high limit; for a decreasing one the roles
// are swapped.
737 bool NeedsPreLoop =
738 Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
739 bool NeedsPostLoop =
740 Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
742 Value *ExitPreLoopAt = nullptr;
743 Value *ExitMainLoopAt = nullptr;
744 const SCEVConstant *MinusOneS =
745 cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
747 if (NeedsPreLoop) {
748 const SCEV *ExitPreLoopAtSCEV = nullptr;
// In the decreasing case the exit limit is HighLimit - 1, which is only used
// when HighLimit is known not to be the minimum value.
750 if (Increasing)
751 ExitPreLoopAtSCEV = *SR.LowLimit;
752 else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
753 IsSignedPredicate))
754 ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
755 else {
756 LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
757 << "preloop exit limit. HighLimit = "
758 << *(*SR.HighLimit) << "\n");
759 return false;
762 if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
763 LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
764 << " preloop exit limit " << *ExitPreLoopAtSCEV
765 << " at block " << InsertPt->getParent()->getName()
766 << "\n");
767 return false;
770 ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
771 ExitPreLoopAt->setName("exit.preloop.at");
774 if (NeedsPostLoop) {
775 const SCEV *ExitMainLoopAtSCEV = nullptr;
// In the decreasing case the exit limit is LowLimit - 1, which is only used
// when LowLimit is known not to be the minimum value.
777 if (Increasing)
778 ExitMainLoopAtSCEV = *SR.HighLimit;
779 else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
780 IsSignedPredicate))
781 ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
782 else {
783 LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
784 << "mainloop exit limit. LowLimit = "
785 << *(*SR.LowLimit) << "\n");
786 return false;
789 if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
790 LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
791 << " main loop exit limit " << *ExitMainLoopAtSCEV
792 << " at block " << InsertPt->getParent()->getName()
793 << "\n");
794 return false;
797 ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
798 ExitMainLoopAt->setName("exit.mainloop.at");
801 // We clone these ahead of time so that we don't have to deal with changing
802 // and temporarily invalid IR as we transform the loops.
803 if (NeedsPreLoop)
804 cloneLoop(PreLoop, "preloop");
805 if (NeedsPostLoop)
806 cloneLoop(PostLoop, "postloop");
808 RewrittenRangeInfo PreLoopRRI;
// With a pre-loop, the original preheader is redirected to the pre-loop
// header, a fresh "mainloop" preheader is created for the main loop, and the
// pre-loop is made to stop at ExitPreLoopAt and continue into that preheader.
810 if (NeedsPreLoop) {
811 Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
812 PreLoop.Structure.Header);
814 MainLoopPreheader =
815 createPreheader(MainLoopStructure, Preheader, "mainloop");
816 PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
817 ExitPreLoopAt, MainLoopPreheader);
818 rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
819 PreLoopRRI);
822 BasicBlock *PostLoopPreheader = nullptr;
823 RewrittenRangeInfo PostLoopRRI;
// With a post-loop, the main loop is made to stop at ExitMainLoopAt and
// continue into a new "postloop" preheader.
825 if (NeedsPostLoop) {
826 PostLoopPreheader =
827 createPreheader(PostLoop.Structure, Preheader, "postloop");
828 PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
829 ExitMainLoopAt, PostLoopPreheader);
830 rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
831 PostLoopRRI);
834 BasicBlock *NewMainLoopPreheader =
835 MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
836 BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit,
837 PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit,
838 PostLoopRRI.ExitSelector, NewMainLoopPreheader};
840 // Some of the above may be nullptr, filter them out before passing to
841 // addToParentLoopIfNeeded.
842 auto NewBlocksEnd =
843 std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
845 addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
847 DT.recalculate(F);
849 // We need to first add all the pre and post loop blocks into the loop
850 // structures (as part of createClonedLoopStructure), and then update the
851 // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
852 // LI when LoopSimplifyForm is generated.
853 Loop *PreL = nullptr, *PostL = nullptr;
854 if (!PreLoop.Blocks.empty()) {
855 PreL = createClonedLoopStructure(&OriginalLoop,
856 OriginalLoop.getParentLoop(), PreLoop.Map,
857 /* IsSubLoop */ false);
860 if (!PostLoop.Blocks.empty()) {
861 PostL =
862 createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
863 PostLoop.Map, /* IsSubLoop */ false);
866 // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
867 auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
868 formLCSSARecursively(*L, DT, &LI, &SE);
869 simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
870 // Pre/post loops are slow paths, we do not need to perform any loop
871 // optimizations on them.
872 if (!IsOriginalLoop)
873 DisableAllLoopOptsOnLoop(*L);
875 if (PreL)
876 CanonicalizeLoop(PreL, false);
877 if (PostL)
878 CanonicalizeLoop(PostL, false);
879 CanonicalizeLoop(&OriginalLoop, true);
881 /// At this point:
882 /// - We've broken a "main loop" out of the loop in a way that the "main loop"
883 /// runs with the induction variable in a subset of [Begin, End).
884 /// - There is no overflow when computing "main loop" exit limit.
885 /// - Max latch taken count of the loop is limited.
886 /// It guarantees that induction variable will not overflow iterating in the
887 /// "main loop".
// Only the signed case is handled: the nsw flag is set on the main loop's
// induction variable update when that value is an overflowing binary operator.
888 if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
889 if (IsSignedPredicate)
890 cast<BinaryOperator>(MainLoopStructure.IndVarBase)
891 ->setHasNoSignedWrap(true);
892 /// TODO: support unsigned predicate.
893 /// To add NUW flag we need to prove that both operands of BO are
894 /// non-negative. E.g:
895 /// ...
896 /// %iv.next = add nsw i32 %iv, -1
897 /// %cmp = icmp ult i32 %iv.next, %n
898 /// br i1 %cmp, label %loopexit, label %loop
900 /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
901 /// overflow, therefore NUW flag is not legal here.
903 return true;