AMDGPU: Make vector_shuffle legal for v2i32 with v_pk_mov_b32 (#123684)
[llvm-project.git] / polly / lib / Support / SCEVValidator.cpp
blob599d7f9d60802e48a705660a914be7290ba213b9
2 #include "polly/Support/SCEVValidator.h"
3 #include "polly/ScopDetection.h"
4 #include "llvm/Analysis/RegionInfo.h"
5 #include "llvm/Analysis/ScalarEvolution.h"
6 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
7 #include "llvm/Support/Debug.h"
9 using namespace llvm;
10 using namespace polly;
12 #include "polly/Support/PollyDebug.h"
13 #define DEBUG_TYPE "polly-scev-validator"
namespace SCEVType {
/// Classification assigned to a SCEV while validating it.
///
/// Every SCEV is given one of the types INT, PARAM, IV or INVALID. The
/// numeric order of the enumerators is significant: a SCEV of type X may only
/// contain subexpressions whose type is less than or equal to X.
enum TYPE {
  /// An integer value.
  INT = 0,

  /// An expression that is constant during the execution of the Scop,
  /// but that may depend on parameters unknown at compile time.
  PARAM = 1,

  /// An expression that may change during the execution of the SCoP.
  IV = 2,

  /// An invalid expression.
  INVALID = 3
};
} // namespace SCEVType
38 /// The result the validator returns for a SCEV expression.
39 class ValidatorResult final {
40 /// The type of the expression
41 SCEVType::TYPE Type;
43 /// The set of Parameters in the expression.
44 ParameterSetTy Parameters;
46 public:
47 /// The copy constructor
48 ValidatorResult(const ValidatorResult &Source) {
49 Type = Source.Type;
50 Parameters = Source.Parameters;
53 /// Construct a result with a certain type and no parameters.
54 ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
55 assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
58 /// Construct a result with a certain type and a single parameter.
59 ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
60 Parameters.insert(Expr);
63 /// Get the type of the ValidatorResult.
64 SCEVType::TYPE getType() { return Type; }
66 /// Is the analyzed SCEV constant during the execution of the SCoP.
67 bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
69 /// Is the analyzed SCEV valid.
70 bool isValid() { return Type != SCEVType::INVALID; }
72 /// Is the analyzed SCEV of Type IV.
73 bool isIV() { return Type == SCEVType::IV; }
75 /// Is the analyzed SCEV of Type INT.
76 bool isINT() { return Type == SCEVType::INT; }
78 /// Is the analyzed SCEV of Type PARAM.
79 bool isPARAM() { return Type == SCEVType::PARAM; }
81 /// Get the parameters of this validator result.
82 const ParameterSetTy &getParameters() { return Parameters; }
84 /// Add the parameters of Source to this result.
85 void addParamsFrom(const ValidatorResult &Source) {
86 Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
89 /// Merge a result.
90 ///
91 /// This means to merge the parameters and to set the Type to the most
92 /// specific Type that matches both.
93 void merge(const ValidatorResult &ToMerge) {
94 Type = std::max(Type, ToMerge.Type);
95 addParamsFrom(ToMerge);
98 void print(raw_ostream &OS) {
99 switch (Type) {
100 case SCEVType::INT:
101 OS << "SCEVType::INT";
102 break;
103 case SCEVType::PARAM:
104 OS << "SCEVType::PARAM";
105 break;
106 case SCEVType::IV:
107 OS << "SCEVType::IV";
108 break;
109 case SCEVType::INVALID:
110 OS << "SCEVType::INVALID";
111 break;
116 raw_ostream &operator<<(raw_ostream &OS, ValidatorResult &VR) {
117 VR.print(OS);
118 return OS;
121 /// Check if a SCEV is valid in a SCoP.
122 class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
123 private:
124 const Region *R;
125 Loop *Scope;
126 ScalarEvolution &SE;
127 InvariantLoadsSetTy *ILS;
129 public:
130 SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
131 InvariantLoadsSetTy *ILS)
132 : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
134 ValidatorResult visitConstant(const SCEVConstant *Constant) {
135 return ValidatorResult(SCEVType::INT);
138 ValidatorResult visitVScale(const SCEVVScale *VScale) {
139 // We do not support VScale constants.
140 POLLY_DEBUG(dbgs() << "INVALID: VScale is not supported");
141 return ValidatorResult(SCEVType::INVALID);
144 ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
145 const SCEV *Operand) {
146 ValidatorResult Op = visit(Operand);
147 auto Type = Op.getType();
149 // If unsigned operations are allowed return the operand, otherwise
150 // check if we can model the expression without unsigned assumptions.
151 if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
152 return Op;
154 if (Type == SCEVType::IV)
155 return ValidatorResult(SCEVType::INVALID);
156 return ValidatorResult(SCEVType::PARAM, Expr);
159 ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
160 return visit(Expr->getOperand());
163 ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
164 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
167 ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
168 return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
171 ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
172 return visit(Expr->getOperand());
175 ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
176 ValidatorResult Return(SCEVType::INT);
178 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
179 ValidatorResult Op = visit(Expr->getOperand(i));
180 Return.merge(Op);
182 // Early exit.
183 if (!Return.isValid())
184 break;
187 return Return;
190 ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
191 ValidatorResult Return(SCEVType::INT);
193 bool HasMultipleParams = false;
195 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
196 ValidatorResult Op = visit(Expr->getOperand(i));
198 if (Op.isINT())
199 continue;
201 if (Op.isPARAM() && Return.isPARAM()) {
202 HasMultipleParams = true;
203 continue;
206 if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
207 POLLY_DEBUG(
208 dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
209 << "\tExpr: " << *Expr << "\n"
210 << "\tPrevious expression type: " << Return << "\n"
211 << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
212 << "\n");
214 return ValidatorResult(SCEVType::INVALID);
217 Return.merge(Op);
220 if (HasMultipleParams && Return.isValid())
221 return ValidatorResult(SCEVType::PARAM, Expr);
223 return Return;
226 ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
227 if (!Expr->isAffine()) {
228 POLLY_DEBUG(dbgs() << "INVALID: AddRec is not affine");
229 return ValidatorResult(SCEVType::INVALID);
232 ValidatorResult Start = visit(Expr->getStart());
233 ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
235 if (!Start.isValid())
236 return Start;
238 if (!Recurrence.isValid())
239 return Recurrence;
241 auto *L = Expr->getLoop();
242 if (R->contains(L) && (!Scope || !L->contains(Scope))) {
243 POLLY_DEBUG(
244 dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
245 "non-affine subregion or has a non-synthesizable exit "
246 "value.");
247 return ValidatorResult(SCEVType::INVALID);
250 if (R->contains(L)) {
251 if (Recurrence.isINT()) {
252 ValidatorResult Result(SCEVType::IV);
253 Result.addParamsFrom(Start);
254 return Result;
257 POLLY_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
258 "recurrence part");
259 return ValidatorResult(SCEVType::INVALID);
262 assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
264 // Directly generate ValidatorResult for Expr if 'start' is zero.
265 if (Expr->getStart()->isZero())
266 return ValidatorResult(SCEVType::PARAM, Expr);
268 // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
269 // if 'start' is not zero.
270 const SCEV *ZeroStartExpr = SE.getAddRecExpr(
271 SE.getConstant(Expr->getStart()->getType(), 0),
272 Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
274 ValidatorResult ZeroStartResult =
275 ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
276 ZeroStartResult.addParamsFrom(Start);
278 return ZeroStartResult;
281 ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
282 ValidatorResult Return(SCEVType::INT);
284 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
285 ValidatorResult Op = visit(Expr->getOperand(i));
287 if (!Op.isValid())
288 return Op;
290 Return.merge(Op);
293 return Return;
296 ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
297 ValidatorResult Return(SCEVType::INT);
299 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
300 ValidatorResult Op = visit(Expr->getOperand(i));
302 if (!Op.isValid())
303 return Op;
305 Return.merge(Op);
308 return Return;
311 ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
312 // We do not support unsigned max operations. If 'Expr' is constant during
313 // Scop execution we treat this as a parameter, otherwise we bail out.
314 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
315 ValidatorResult Op = visit(Expr->getOperand(i));
317 if (!Op.isConstant()) {
318 POLLY_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
319 return ValidatorResult(SCEVType::INVALID);
323 return ValidatorResult(SCEVType::PARAM, Expr);
326 ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
327 // We do not support unsigned min operations. If 'Expr' is constant during
328 // Scop execution we treat this as a parameter, otherwise we bail out.
329 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
330 ValidatorResult Op = visit(Expr->getOperand(i));
332 if (!Op.isConstant()) {
333 POLLY_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
334 return ValidatorResult(SCEVType::INVALID);
338 return ValidatorResult(SCEVType::PARAM, Expr);
341 ValidatorResult visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
342 // We do not support unsigned min operations. If 'Expr' is constant during
343 // Scop execution we treat this as a parameter, otherwise we bail out.
344 for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
345 ValidatorResult Op = visit(Expr->getOperand(i));
347 if (!Op.isConstant()) {
348 POLLY_DEBUG(
349 dbgs()
350 << "INVALID: SCEVSequentialUMinExpr has a non-constant operand");
351 return ValidatorResult(SCEVType::INVALID);
355 return ValidatorResult(SCEVType::PARAM, Expr);
358 ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
359 if (R->contains(I)) {
360 POLLY_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
361 "within the region\n");
362 return ValidatorResult(SCEVType::INVALID);
365 return ValidatorResult(SCEVType::PARAM, S);
368 ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
369 if (R->contains(I) && ILS) {
370 ILS->insert(cast<LoadInst>(I));
371 return ValidatorResult(SCEVType::PARAM, S);
374 return visitGenericInst(I, S);
377 ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
378 const SCEV *DivExpr,
379 Instruction *SDiv = nullptr) {
381 // First check if we might be able to model the division, thus if the
382 // divisor is constant. If so, check the dividend, otherwise check if
383 // the whole division can be seen as a parameter.
384 if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
385 return visit(Dividend);
387 // For signed divisions use the SDiv instruction to check for a parameter
388 // division, for unsigned divisions check the operands.
389 if (SDiv)
390 return visitGenericInst(SDiv, DivExpr);
392 ValidatorResult LHS = visit(Dividend);
393 ValidatorResult RHS = visit(Divisor);
394 if (LHS.isConstant() && RHS.isConstant())
395 return ValidatorResult(SCEVType::PARAM, DivExpr);
397 POLLY_DEBUG(
398 dbgs() << "INVALID: unsigned division of non-constant expressions");
399 return ValidatorResult(SCEVType::INVALID);
402 ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
403 if (!PollyAllowUnsignedOperations)
404 return ValidatorResult(SCEVType::INVALID);
406 const SCEV *Dividend = Expr->getLHS();
407 const SCEV *Divisor = Expr->getRHS();
408 return visitDivision(Dividend, Divisor, Expr);
411 ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
412 assert(SDiv->getOpcode() == Instruction::SDiv &&
413 "Assumed SDiv instruction!");
415 const SCEV *Dividend = SE.getSCEV(SDiv->getOperand(0));
416 const SCEV *Divisor = SE.getSCEV(SDiv->getOperand(1));
417 return visitDivision(Dividend, Divisor, Expr, SDiv);
420 ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
421 assert(SRem->getOpcode() == Instruction::SRem &&
422 "Assumed SRem instruction!");
424 auto *Divisor = SRem->getOperand(1);
425 auto *CI = dyn_cast<ConstantInt>(Divisor);
426 if (!CI || CI->isZeroValue())
427 return visitGenericInst(SRem, S);
429 auto *Dividend = SRem->getOperand(0);
430 const SCEV *DividendSCEV = SE.getSCEV(Dividend);
431 return visit(DividendSCEV);
434 ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
435 Value *V = Expr->getValue();
437 if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
438 POLLY_DEBUG(
439 dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
440 return ValidatorResult(SCEVType::INVALID);
443 if (isa<UndefValue>(V)) {
444 POLLY_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
445 return ValidatorResult(SCEVType::INVALID);
448 if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
449 switch (I->getOpcode()) {
450 case Instruction::IntToPtr:
451 return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
452 case Instruction::Load:
453 return visitLoadInstruction(I, Expr);
454 case Instruction::SDiv:
455 return visitSDivInstruction(I, Expr);
456 case Instruction::SRem:
457 return visitSRemInstruction(I, Expr);
458 default:
459 return visitGenericInst(I, Expr);
463 if (Expr->getType()->isPointerTy()) {
464 if (isa<ConstantPointerNull>(V))
465 return ValidatorResult(SCEVType::INT); // "int"
468 return ValidatorResult(SCEVType::PARAM, Expr);
472 /// Check whether a SCEV refers to an SSA name defined inside a region.
473 class SCEVInRegionDependences final {
474 const Region *R;
475 Loop *Scope;
476 const InvariantLoadsSetTy &ILS;
477 bool AllowLoops;
478 bool HasInRegionDeps = false;
480 public:
481 SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
482 const InvariantLoadsSetTy &ILS)
483 : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
485 bool follow(const SCEV *S) {
486 if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
487 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
489 if (Inst) {
490 // When we invariant load hoist a load, we first make sure that there
491 // can be no dependences created by it in the Scop region. So, we should
492 // not consider scalar dependences to `LoadInst`s that are invariant
493 // load hoisted.
495 // If this check is not present, then we create data dependences which
496 // are strictly not necessary by tracking the invariant load as a
497 // scalar.
498 LoadInst *LI = dyn_cast<LoadInst>(Inst);
499 if (LI && ILS.contains(LI))
500 return false;
503 // Return true when Inst is defined inside the region R.
504 if (!Inst || !R->contains(Inst))
505 return true;
507 HasInRegionDeps = true;
508 return false;
511 if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
512 if (AllowLoops)
513 return true;
515 auto *L = AddRec->getLoop();
516 if (R->contains(L) && !L->contains(Scope)) {
517 HasInRegionDeps = true;
518 return false;
522 return true;
524 bool isDone() { return false; }
525 bool hasDependences() { return HasInRegionDeps; }
528 /// Find all loops referenced in SCEVAddRecExprs.
529 class SCEVFindLoops final {
530 SetVector<const Loop *> &Loops;
532 public:
533 SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
535 bool follow(const SCEV *S) {
536 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
537 Loops.insert(AddRec->getLoop());
538 return true;
540 bool isDone() { return false; }
543 void polly::findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
544 SCEVFindLoops FindLoops(Loops);
545 SCEVTraversal<SCEVFindLoops> ST(FindLoops);
546 ST.visitAll(Expr);
549 /// Find all values referenced in SCEVUnknowns.
550 class SCEVFindValues final {
551 ScalarEvolution &SE;
552 SetVector<Value *> &Values;
554 public:
555 SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
556 : SE(SE), Values(Values) {}
558 bool follow(const SCEV *S) {
559 const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
560 if (!Unknown)
561 return true;
563 Values.insert(Unknown->getValue());
564 Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
565 if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
566 Inst->getOpcode() != Instruction::SDiv))
567 return false;
569 const SCEV *Dividend = SE.getSCEV(Inst->getOperand(1));
570 if (!isa<SCEVConstant>(Dividend))
571 return false;
573 const SCEV *Divisor = SE.getSCEV(Inst->getOperand(0));
574 SCEVFindValues FindValues(SE, Values);
575 SCEVTraversal<SCEVFindValues> ST(FindValues);
576 ST.visitAll(Dividend);
577 ST.visitAll(Divisor);
579 return false;
581 bool isDone() { return false; }
584 void polly::findValues(const SCEV *Expr, ScalarEvolution &SE,
585 SetVector<Value *> &Values) {
586 SCEVFindValues FindValues(SE, Values);
587 SCEVTraversal<SCEVFindValues> ST(FindValues);
588 ST.visitAll(Expr);
591 bool polly::hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
592 llvm::Loop *Scope, bool AllowLoops,
593 const InvariantLoadsSetTy &ILS) {
594 SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
595 SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
596 ST.visitAll(Expr);
597 return InRegionDeps.hasDependences();
600 bool polly::isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
601 ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
602 if (isa<SCEVCouldNotCompute>(Expr))
603 return false;
605 SCEVValidator Validator(R, Scope, SE, ILS);
606 POLLY_DEBUG({
607 dbgs() << "\n";
608 dbgs() << "Expr: " << *Expr << "\n";
609 dbgs() << "Region: " << R->getNameStr() << "\n";
610 dbgs() << " -> ";
613 ValidatorResult Result = Validator.visit(Expr);
615 POLLY_DEBUG({
616 if (Result.isValid())
617 dbgs() << "VALID\n";
618 dbgs() << "\n";
621 return Result.isValid();
624 static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
625 ScalarEvolution &SE, ParameterSetTy &Params) {
626 const SCEV *E = SE.getSCEV(V);
627 if (isa<SCEVCouldNotCompute>(E))
628 return false;
630 SCEVValidator Validator(R, Scope, SE, nullptr);
631 ValidatorResult Result = Validator.visit(E);
632 if (!Result.isValid())
633 return false;
635 auto ResultParams = Result.getParameters();
636 Params.insert(ResultParams.begin(), ResultParams.end());
638 return true;
641 bool polly::isAffineConstraint(Value *V, const Region *R, Loop *Scope,
642 ScalarEvolution &SE, ParameterSetTy &Params,
643 bool OrExpr) {
644 if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
645 return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
646 true) &&
647 isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
648 } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
649 auto Opcode = BinOp->getOpcode();
650 if (Opcode == Instruction::And || Opcode == Instruction::Or)
651 return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
652 false) &&
653 isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
654 false);
655 /* Fall through */
658 if (!OrExpr)
659 return false;
661 return ::isAffineExpr(V, R, Scope, SE, Params);
664 ParameterSetTy polly::getParamsInAffineExpr(const Region *R, Loop *Scope,
665 const SCEV *Expr,
666 ScalarEvolution &SE) {
667 if (isa<SCEVCouldNotCompute>(Expr))
668 return ParameterSetTy();
670 InvariantLoadsSetTy ILS;
671 SCEVValidator Validator(R, Scope, SE, &ILS);
672 ValidatorResult Result = Validator.visit(Expr);
673 assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
675 return Result.getParameters();
678 std::pair<const SCEVConstant *, const SCEV *>
679 polly::extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
680 auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
682 if (auto *Constant = dyn_cast<SCEVConstant>(S))
683 return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
685 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
686 if (AddRec) {
687 const SCEV *StartExpr = AddRec->getStart();
688 if (StartExpr->isZero()) {
689 auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
690 const SCEV *LeftOverAddRec =
691 SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
692 AddRec->getNoWrapFlags());
693 return std::make_pair(StepPair.first, LeftOverAddRec);
695 return std::make_pair(ConstPart, S);
698 if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
699 SmallVector<const SCEV *, 4> LeftOvers;
700 auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
701 auto *Factor = Op0Pair.first;
702 if (SE.isKnownNegative(Factor)) {
703 Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
704 LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
705 } else {
706 LeftOvers.push_back(Op0Pair.second);
709 for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
710 auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
711 // TODO: Use something smarter than equality here, e.g., gcd.
712 if (Factor == OpUPair.first)
713 LeftOvers.push_back(OpUPair.second);
714 else if (Factor == SE.getNegativeSCEV(OpUPair.first))
715 LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
716 else
717 return std::make_pair(ConstPart, S);
720 const SCEV *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
721 return std::make_pair(Factor, NewAdd);
724 auto *Mul = dyn_cast<SCEVMulExpr>(S);
725 if (!Mul)
726 return std::make_pair(ConstPart, S);
728 SmallVector<const SCEV *, 4> LeftOvers;
729 for (const SCEV *Op : Mul->operands())
730 if (isa<SCEVConstant>(Op))
731 ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
732 else
733 LeftOvers.push_back(Op);
735 return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
738 const SCEV *polly::tryForwardThroughPHI(const SCEV *Expr, Region &R,
739 ScalarEvolution &SE,
740 ScopDetection *SD) {
741 if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) {
742 Value *V = Unknown->getValue();
743 auto *PHI = dyn_cast<PHINode>(V);
744 if (!PHI)
745 return Expr;
747 Value *Final = nullptr;
749 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
750 BasicBlock *Incoming = PHI->getIncomingBlock(i);
751 if (SD->isErrorBlock(*Incoming, R) && R.contains(Incoming))
752 continue;
753 if (Final)
754 return Expr;
755 Final = PHI->getIncomingValue(i);
758 if (Final)
759 return SE.getSCEV(Final);
761 return Expr;
764 Value *polly::getUniqueNonErrorValue(PHINode *PHI, Region *R,
765 ScopDetection *SD) {
766 Value *V = nullptr;
767 for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
768 BasicBlock *BB = PHI->getIncomingBlock(i);
769 if (!SD->isErrorBlock(*BB, *R)) {
770 if (V)
771 return nullptr;
772 V = PHI->getIncomingValue(i);
776 return V;