//===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Correlated Value Propagation pass.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LazyValueInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <utility>

using namespace llvm;

53 #define DEBUG_TYPE "correlated-value-propagation"
55 STATISTIC(NumPhis
, "Number of phis propagated");
56 STATISTIC(NumPhiCommon
, "Number of phis deleted via common incoming value");
57 STATISTIC(NumSelects
, "Number of selects propagated");
58 STATISTIC(NumMemAccess
, "Number of memory access targets propagated");
59 STATISTIC(NumCmps
, "Number of comparisons propagated");
60 STATISTIC(NumReturns
, "Number of return values propagated");
61 STATISTIC(NumDeadCases
, "Number of switch cases removed");
62 STATISTIC(NumSDivSRemsNarrowed
,
63 "Number of sdivs/srems whose width was decreased");
64 STATISTIC(NumSDivs
, "Number of sdiv converted to udiv");
65 STATISTIC(NumUDivURemsNarrowed
,
66 "Number of udivs/urems whose width was decreased");
67 STATISTIC(NumAShrs
, "Number of ashr converted to lshr");
68 STATISTIC(NumSRems
, "Number of srem converted to urem");
69 STATISTIC(NumSExt
, "Number of sext converted to zext");
70 STATISTIC(NumAnd
, "Number of ands removed");
71 STATISTIC(NumNW
, "Number of no-wrap deductions");
72 STATISTIC(NumNSW
, "Number of no-signed-wrap deductions");
73 STATISTIC(NumNUW
, "Number of no-unsigned-wrap deductions");
74 STATISTIC(NumAddNW
, "Number of no-wrap deductions for add");
75 STATISTIC(NumAddNSW
, "Number of no-signed-wrap deductions for add");
76 STATISTIC(NumAddNUW
, "Number of no-unsigned-wrap deductions for add");
77 STATISTIC(NumSubNW
, "Number of no-wrap deductions for sub");
78 STATISTIC(NumSubNSW
, "Number of no-signed-wrap deductions for sub");
79 STATISTIC(NumSubNUW
, "Number of no-unsigned-wrap deductions for sub");
80 STATISTIC(NumMulNW
, "Number of no-wrap deductions for mul");
81 STATISTIC(NumMulNSW
, "Number of no-signed-wrap deductions for mul");
82 STATISTIC(NumMulNUW
, "Number of no-unsigned-wrap deductions for mul");
83 STATISTIC(NumShlNW
, "Number of no-wrap deductions for shl");
84 STATISTIC(NumShlNSW
, "Number of no-signed-wrap deductions for shl");
85 STATISTIC(NumShlNUW
, "Number of no-unsigned-wrap deductions for shl");
86 STATISTIC(NumAbs
, "Number of llvm.abs intrinsics removed");
87 STATISTIC(NumOverflows
, "Number of overflow checks removed");
88 STATISTIC(NumSaturating
,
89 "Number of saturating arithmetics converted to normal arithmetics");
90 STATISTIC(NumNonNull
, "Number of function pointer arguments marked non-null");
91 STATISTIC(NumMinMax
, "Number of llvm.[us]{min,max} intrinsics removed");
95 class CorrelatedValuePropagation
: public FunctionPass
{
99 CorrelatedValuePropagation(): FunctionPass(ID
) {
100 initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry());
103 bool runOnFunction(Function
&F
) override
;
105 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
106 AU
.addRequired
<DominatorTreeWrapperPass
>();
107 AU
.addRequired
<LazyValueInfoWrapperPass
>();
108 AU
.addPreserved
<GlobalsAAWrapperPass
>();
109 AU
.addPreserved
<DominatorTreeWrapperPass
>();
110 AU
.addPreserved
<LazyValueInfoWrapperPass
>();
114 } // end anonymous namespace
116 char CorrelatedValuePropagation::ID
= 0;
118 INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation
, "correlated-propagation",
119 "Value Propagation", false, false)
120 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass
)
121 INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass
)
122 INITIALIZE_PASS_END(CorrelatedValuePropagation
, "correlated-propagation",
123 "Value Propagation", false, false)
125 // Public interface to the Value Propagation pass
126 Pass
*llvm::createCorrelatedValuePropagationPass() {
127 return new CorrelatedValuePropagation();
130 static bool processSelect(SelectInst
*S
, LazyValueInfo
*LVI
) {
131 if (S
->getType()->isVectorTy()) return false;
132 if (isa
<Constant
>(S
->getCondition())) return false;
134 Constant
*C
= LVI
->getConstant(S
->getCondition(), S
);
135 if (!C
) return false;
137 ConstantInt
*CI
= dyn_cast
<ConstantInt
>(C
);
138 if (!CI
) return false;
140 Value
*ReplaceWith
= CI
->isOne() ? S
->getTrueValue() : S
->getFalseValue();
141 S
->replaceAllUsesWith(ReplaceWith
);
142 S
->eraseFromParent();
149 /// Try to simplify a phi with constant incoming values that match the edge
150 /// values of a non-constant value on all other edges:
152 /// %isnull = icmp eq i8* %x, null
153 /// br i1 %isnull, label %bb2, label %bb1
157 /// %r = phi i8* [ %x, %bb1 ], [ null, %bb0 ]
160 static bool simplifyCommonValuePhi(PHINode
*P
, LazyValueInfo
*LVI
,
162 // Collect incoming constants and initialize possible common value.
163 SmallVector
<std::pair
<Constant
*, unsigned>, 4> IncomingConstants
;
164 Value
*CommonValue
= nullptr;
165 for (unsigned i
= 0, e
= P
->getNumIncomingValues(); i
!= e
; ++i
) {
166 Value
*Incoming
= P
->getIncomingValue(i
);
167 if (auto *IncomingConstant
= dyn_cast
<Constant
>(Incoming
)) {
168 IncomingConstants
.push_back(std::make_pair(IncomingConstant
, i
));
169 } else if (!CommonValue
) {
170 // The potential common value is initialized to the first non-constant.
171 CommonValue
= Incoming
;
172 } else if (Incoming
!= CommonValue
) {
173 // There can be only one non-constant common value.
178 if (!CommonValue
|| IncomingConstants
.empty())
181 // The common value must be valid in all incoming blocks.
182 BasicBlock
*ToBB
= P
->getParent();
183 if (auto *CommonInst
= dyn_cast
<Instruction
>(CommonValue
))
184 if (!DT
->dominates(CommonInst
, ToBB
))
187 // We have a phi with exactly 1 variable incoming value and 1 or more constant
188 // incoming values. See if all constant incoming values can be mapped back to
189 // the same incoming variable value.
190 for (auto &IncomingConstant
: IncomingConstants
) {
191 Constant
*C
= IncomingConstant
.first
;
192 BasicBlock
*IncomingBB
= P
->getIncomingBlock(IncomingConstant
.second
);
193 if (C
!= LVI
->getConstantOnEdge(CommonValue
, IncomingBB
, ToBB
, P
))
197 // LVI only guarantees that the value matches a certain constant if the value
198 // is not poison. Make sure we don't replace a well-defined value with poison.
199 // This is usually satisfied due to a prior branch on the value.
200 if (!isGuaranteedNotToBePoison(CommonValue
, nullptr, P
, DT
))
203 // All constant incoming values map to the same variable along the incoming
204 // edges of the phi. The phi is unnecessary.
205 P
->replaceAllUsesWith(CommonValue
);
206 P
->eraseFromParent();
211 static bool processPHI(PHINode
*P
, LazyValueInfo
*LVI
, DominatorTree
*DT
,
212 const SimplifyQuery
&SQ
) {
213 bool Changed
= false;
215 BasicBlock
*BB
= P
->getParent();
216 for (unsigned i
= 0, e
= P
->getNumIncomingValues(); i
< e
; ++i
) {
217 Value
*Incoming
= P
->getIncomingValue(i
);
218 if (isa
<Constant
>(Incoming
)) continue;
220 Value
*V
= LVI
->getConstantOnEdge(Incoming
, P
->getIncomingBlock(i
), BB
, P
);
222 // Look if the incoming value is a select with a scalar condition for which
223 // LVI can tells us the value. In that case replace the incoming value with
224 // the appropriate value of the select. This often allows us to remove the
227 SelectInst
*SI
= dyn_cast
<SelectInst
>(Incoming
);
230 Value
*Condition
= SI
->getCondition();
231 if (!Condition
->getType()->isVectorTy()) {
232 if (Constant
*C
= LVI
->getConstantOnEdge(
233 Condition
, P
->getIncomingBlock(i
), BB
, P
)) {
234 if (C
->isOneValue()) {
235 V
= SI
->getTrueValue();
236 } else if (C
->isZeroValue()) {
237 V
= SI
->getFalseValue();
239 // Once LVI learns to handle vector types, we could also add support
240 // for vector type constants that are not all zeroes or all ones.
244 // Look if the select has a constant but LVI tells us that the incoming
245 // value can never be that constant. In that case replace the incoming
246 // value with the other value of the select. This often allows us to
247 // remove the select later.
249 Constant
*C
= dyn_cast
<Constant
>(SI
->getFalseValue());
252 if (LVI
->getPredicateOnEdge(ICmpInst::ICMP_EQ
, SI
, C
,
253 P
->getIncomingBlock(i
), BB
, P
) !=
254 LazyValueInfo::False
)
256 V
= SI
->getTrueValue();
259 LLVM_DEBUG(dbgs() << "CVP: Threading PHI over " << *SI
<< '\n');
262 P
->setIncomingValue(i
, V
);
266 if (Value
*V
= SimplifyInstruction(P
, SQ
)) {
267 P
->replaceAllUsesWith(V
);
268 P
->eraseFromParent();
273 Changed
= simplifyCommonValuePhi(P
, LVI
, DT
);
281 static bool processMemAccess(Instruction
*I
, LazyValueInfo
*LVI
) {
282 Value
*Pointer
= nullptr;
283 if (LoadInst
*L
= dyn_cast
<LoadInst
>(I
))
284 Pointer
= L
->getPointerOperand();
286 Pointer
= cast
<StoreInst
>(I
)->getPointerOperand();
288 if (isa
<Constant
>(Pointer
)) return false;
290 Constant
*C
= LVI
->getConstant(Pointer
, I
);
291 if (!C
) return false;
294 I
->replaceUsesOfWith(Pointer
, C
);
298 /// See if LazyValueInfo's ability to exploit edge conditions or range
299 /// information is sufficient to prove this comparison. Even for local
300 /// conditions, this can sometimes prove conditions instcombine can't by
301 /// exploiting range information.
302 static bool processCmp(CmpInst
*Cmp
, LazyValueInfo
*LVI
) {
303 Value
*Op0
= Cmp
->getOperand(0);
304 auto *C
= dyn_cast
<Constant
>(Cmp
->getOperand(1));
308 LazyValueInfo::Tristate Result
=
309 LVI
->getPredicateAt(Cmp
->getPredicate(), Op0
, C
, Cmp
,
310 /*UseBlockValue=*/true);
311 if (Result
== LazyValueInfo::Unknown
)
315 Constant
*TorF
= ConstantInt::get(Type::getInt1Ty(Cmp
->getContext()), Result
);
316 Cmp
->replaceAllUsesWith(TorF
);
317 Cmp
->eraseFromParent();
321 /// Simplify a switch instruction by removing cases which can never fire. If the
322 /// uselessness of a case could be determined locally then constant propagation
323 /// would already have figured it out. Instead, walk the predecessors and
324 /// statically evaluate cases based on information available on that edge. Cases
325 /// that cannot fire no matter what the incoming edge can safely be removed. If
326 /// a case fires on every incoming edge then the entire switch can be removed
327 /// and replaced with a branch to the case destination.
328 static bool processSwitch(SwitchInst
*I
, LazyValueInfo
*LVI
,
330 DomTreeUpdater
DTU(*DT
, DomTreeUpdater::UpdateStrategy::Lazy
);
331 Value
*Cond
= I
->getCondition();
332 BasicBlock
*BB
= I
->getParent();
334 // Analyse each switch case in turn.
335 bool Changed
= false;
336 DenseMap
<BasicBlock
*, int> SuccessorsCount
;
337 for (auto *Succ
: successors(BB
))
338 SuccessorsCount
[Succ
]++;
340 { // Scope for SwitchInstProfUpdateWrapper. It must not live during
341 // ConstantFoldTerminator() as the underlying SwitchInst can be changed.
342 SwitchInstProfUpdateWrapper
SI(*I
);
344 for (auto CI
= SI
->case_begin(), CE
= SI
->case_end(); CI
!= CE
;) {
345 ConstantInt
*Case
= CI
->getCaseValue();
346 LazyValueInfo::Tristate State
=
347 LVI
->getPredicateAt(CmpInst::ICMP_EQ
, Cond
, Case
, I
,
348 /* UseBlockValue */ true);
350 if (State
== LazyValueInfo::False
) {
351 // This case never fires - remove it.
352 BasicBlock
*Succ
= CI
->getCaseSuccessor();
353 Succ
->removePredecessor(BB
);
354 CI
= SI
.removeCase(CI
);
357 // The condition can be modified by removePredecessor's PHI simplification
359 Cond
= SI
->getCondition();
363 if (--SuccessorsCount
[Succ
] == 0)
364 DTU
.applyUpdatesPermissive({{DominatorTree::Delete
, BB
, Succ
}});
367 if (State
== LazyValueInfo::True
) {
368 // This case always fires. Arrange for the switch to be turned into an
369 // unconditional branch by replacing the switch condition with the case
371 SI
->setCondition(Case
);
372 NumDeadCases
+= SI
->getNumCases();
377 // Increment the case iterator since we didn't delete it.
383 // If the switch has been simplified to the point where it can be replaced
384 // by a branch then do so now.
385 ConstantFoldTerminator(BB
, /*DeleteDeadConditions = */ false,
386 /*TLI = */ nullptr, &DTU
);
390 // See if we can prove that the given binary op intrinsic will not overflow.
391 static bool willNotOverflow(BinaryOpIntrinsic
*BO
, LazyValueInfo
*LVI
) {
392 ConstantRange LRange
= LVI
->getConstantRange(BO
->getLHS(), BO
);
393 ConstantRange RRange
= LVI
->getConstantRange(BO
->getRHS(), BO
);
394 ConstantRange NWRegion
= ConstantRange::makeGuaranteedNoWrapRegion(
395 BO
->getBinaryOp(), RRange
, BO
->getNoWrapKind());
396 return NWRegion
.contains(LRange
);
399 static void setDeducedOverflowingFlags(Value
*V
, Instruction::BinaryOps Opcode
,
400 bool NewNSW
, bool NewNUW
) {
401 Statistic
*OpcNW
, *OpcNSW
, *OpcNUW
;
403 case Instruction::Add
:
408 case Instruction::Sub
:
413 case Instruction::Mul
:
418 case Instruction::Shl
:
424 llvm_unreachable("Will not be called with other binops");
427 auto *Inst
= dyn_cast
<Instruction
>(V
);
434 Inst
->setHasNoSignedWrap();
442 Inst
->setHasNoUnsignedWrap();
446 static bool processBinOp(BinaryOperator
*BinOp
, LazyValueInfo
*LVI
);
448 // See if @llvm.abs argument is alays positive/negative, and simplify.
449 // Notably, INT_MIN can belong to either range, regardless of the NSW,
450 // because it is negation-invariant.
451 static bool processAbsIntrinsic(IntrinsicInst
*II
, LazyValueInfo
*LVI
) {
452 Value
*X
= II
->getArgOperand(0);
453 bool IsIntMinPoison
= cast
<ConstantInt
>(II
->getArgOperand(1))->isOne();
455 Type
*Ty
= X
->getType();
457 ConstantInt::get(Ty
, APInt::getSignedMinValue(Ty
->getScalarSizeInBits()));
458 LazyValueInfo::Tristate Result
;
460 // Is X in [0, IntMin]? NOTE: INT_MIN is fine!
461 Result
= LVI
->getPredicateAt(CmpInst::Predicate::ICMP_ULE
, X
, IntMin
, II
,
462 /*UseBlockValue=*/true);
463 if (Result
== LazyValueInfo::True
) {
465 II
->replaceAllUsesWith(X
);
466 II
->eraseFromParent();
470 // Is X in [IntMin, 0]? NOTE: INT_MIN is fine!
471 Constant
*Zero
= ConstantInt::getNullValue(Ty
);
472 Result
= LVI
->getPredicateAt(CmpInst::Predicate::ICMP_SLE
, X
, Zero
, II
,
473 /*UseBlockValue=*/true);
474 assert(Result
!= LazyValueInfo::False
&& "Should have been handled already.");
476 if (Result
== LazyValueInfo::Unknown
) {
477 // Argument's range crosses zero.
478 bool Changed
= false;
479 if (!IsIntMinPoison
) {
480 // Can we at least tell that the argument is never INT_MIN?
481 Result
= LVI
->getPredicateAt(CmpInst::Predicate::ICMP_NE
, X
, IntMin
, II
,
482 /*UseBlockValue=*/true);
483 if (Result
== LazyValueInfo::True
) {
486 II
->setArgOperand(1, ConstantInt::getTrue(II
->getContext()));
494 Value
*NegX
= B
.CreateNeg(X
, II
->getName(), /*HasNUW=*/false,
495 /*HasNSW=*/IsIntMinPoison
);
497 II
->replaceAllUsesWith(NegX
);
498 II
->eraseFromParent();
500 // See if we can infer some no-wrap flags.
501 if (auto *BO
= dyn_cast
<BinaryOperator
>(NegX
))
502 processBinOp(BO
, LVI
);
507 // See if this min/max intrinsic always picks it's one specific operand.
508 static bool processMinMaxIntrinsic(MinMaxIntrinsic
*MM
, LazyValueInfo
*LVI
) {
509 CmpInst::Predicate Pred
= CmpInst::getNonStrictPredicate(MM
->getPredicate());
510 LazyValueInfo::Tristate Result
= LVI
->getPredicateAt(
511 Pred
, MM
->getLHS(), MM
->getRHS(), MM
, /*UseBlockValue=*/true);
512 if (Result
== LazyValueInfo::Unknown
)
516 MM
->replaceAllUsesWith(MM
->getOperand(!Result
));
517 MM
->eraseFromParent();
521 // Rewrite this with.overflow intrinsic as non-overflowing.
522 static bool processOverflowIntrinsic(WithOverflowInst
*WO
, LazyValueInfo
*LVI
) {
524 Instruction::BinaryOps Opcode
= WO
->getBinaryOp();
525 bool NSW
= WO
->isSigned();
526 bool NUW
= !WO
->isSigned();
529 B
.CreateBinOp(Opcode
, WO
->getLHS(), WO
->getRHS(), WO
->getName());
530 setDeducedOverflowingFlags(NewOp
, Opcode
, NSW
, NUW
);
532 StructType
*ST
= cast
<StructType
>(WO
->getType());
533 Constant
*Struct
= ConstantStruct::get(ST
,
534 { UndefValue::get(ST
->getElementType(0)),
535 ConstantInt::getFalse(ST
->getElementType(1)) });
536 Value
*NewI
= B
.CreateInsertValue(Struct
, NewOp
, 0);
537 WO
->replaceAllUsesWith(NewI
);
538 WO
->eraseFromParent();
541 // See if we can infer the other no-wrap too.
542 if (auto *BO
= dyn_cast
<BinaryOperator
>(NewOp
))
543 processBinOp(BO
, LVI
);
548 static bool processSaturatingInst(SaturatingInst
*SI
, LazyValueInfo
*LVI
) {
549 Instruction::BinaryOps Opcode
= SI
->getBinaryOp();
550 bool NSW
= SI
->isSigned();
551 bool NUW
= !SI
->isSigned();
552 BinaryOperator
*BinOp
= BinaryOperator::Create(
553 Opcode
, SI
->getLHS(), SI
->getRHS(), SI
->getName(), SI
);
554 BinOp
->setDebugLoc(SI
->getDebugLoc());
555 setDeducedOverflowingFlags(BinOp
, Opcode
, NSW
, NUW
);
557 SI
->replaceAllUsesWith(BinOp
);
558 SI
->eraseFromParent();
561 // See if we can infer the other no-wrap too.
562 if (auto *BO
= dyn_cast
<BinaryOperator
>(BinOp
))
563 processBinOp(BO
, LVI
);
568 /// Infer nonnull attributes for the arguments at the specified callsite.
569 static bool processCallSite(CallBase
&CB
, LazyValueInfo
*LVI
) {
571 if (CB
.getIntrinsicID() == Intrinsic::abs
) {
572 return processAbsIntrinsic(&cast
<IntrinsicInst
>(CB
), LVI
);
575 if (auto *MM
= dyn_cast
<MinMaxIntrinsic
>(&CB
)) {
576 return processMinMaxIntrinsic(MM
, LVI
);
579 if (auto *WO
= dyn_cast
<WithOverflowInst
>(&CB
)) {
580 if (WO
->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO
, LVI
)) {
581 return processOverflowIntrinsic(WO
, LVI
);
585 if (auto *SI
= dyn_cast
<SaturatingInst
>(&CB
)) {
586 if (SI
->getType()->isIntegerTy() && willNotOverflow(SI
, LVI
)) {
587 return processSaturatingInst(SI
, LVI
);
591 bool Changed
= false;
593 // Deopt bundle operands are intended to capture state with minimal
594 // perturbance of the code otherwise. If we can find a constant value for
595 // any such operand and remove a use of the original value, that's
596 // desireable since it may allow further optimization of that value (e.g. via
597 // single use rules in instcombine). Since deopt uses tend to,
598 // idiomatically, appear along rare conditional paths, it's reasonable likely
599 // we may have a conditional fact with which LVI can fold.
600 if (auto DeoptBundle
= CB
.getOperandBundle(LLVMContext::OB_deopt
)) {
601 for (const Use
&ConstU
: DeoptBundle
->Inputs
) {
602 Use
&U
= const_cast<Use
&>(ConstU
);
604 if (V
->getType()->isVectorTy()) continue;
605 if (isa
<Constant
>(V
)) continue;
607 Constant
*C
= LVI
->getConstant(V
, &CB
);
614 SmallVector
<unsigned, 4> ArgNos
;
617 for (Value
*V
: CB
.args()) {
618 PointerType
*Type
= dyn_cast
<PointerType
>(V
->getType());
619 // Try to mark pointer typed parameters as non-null. We skip the
620 // relatively expensive analysis for constants which are obviously either
621 // null or non-null to start with.
622 if (Type
&& !CB
.paramHasAttr(ArgNo
, Attribute::NonNull
) &&
624 LVI
->getPredicateAt(ICmpInst::ICMP_EQ
, V
,
625 ConstantPointerNull::get(Type
), &CB
,
626 /*UseBlockValue=*/false) == LazyValueInfo::False
)
627 ArgNos
.push_back(ArgNo
);
631 assert(ArgNo
== CB
.arg_size() && "sanity check");
636 NumNonNull
+= ArgNos
.size();
637 AttributeList AS
= CB
.getAttributes();
638 LLVMContext
&Ctx
= CB
.getContext();
639 AS
= AS
.addParamAttribute(Ctx
, ArgNos
,
640 Attribute::get(Ctx
, Attribute::NonNull
));
641 CB
.setAttributes(AS
);
646 static bool isNonNegative(Value
*V
, LazyValueInfo
*LVI
, Instruction
*CxtI
) {
647 Constant
*Zero
= ConstantInt::get(V
->getType(), 0);
648 auto Result
= LVI
->getPredicateAt(ICmpInst::ICMP_SGE
, V
, Zero
, CxtI
,
649 /*UseBlockValue=*/true);
650 return Result
== LazyValueInfo::True
;
653 static bool isNonPositive(Value
*V
, LazyValueInfo
*LVI
, Instruction
*CxtI
) {
654 Constant
*Zero
= ConstantInt::get(V
->getType(), 0);
655 auto Result
= LVI
->getPredicateAt(ICmpInst::ICMP_SLE
, V
, Zero
, CxtI
,
656 /*UseBlockValue=*/true);
657 return Result
== LazyValueInfo::True
;
/// Sign domain of a value as far as LVI can prove it.
enum class Domain { NonNegative, NonPositive, Unknown };

662 Domain
getDomain(Value
*V
, LazyValueInfo
*LVI
, Instruction
*CxtI
) {
663 if (isNonNegative(V
, LVI
, CxtI
))
664 return Domain::NonNegative
;
665 if (isNonPositive(V
, LVI
, CxtI
))
666 return Domain::NonPositive
;
667 return Domain::Unknown
;
670 /// Try to shrink a sdiv/srem's width down to the smallest power of two that's
671 /// sufficient to contain its operands.
672 static bool narrowSDivOrSRem(BinaryOperator
*Instr
, LazyValueInfo
*LVI
) {
673 assert(Instr
->getOpcode() == Instruction::SDiv
||
674 Instr
->getOpcode() == Instruction::SRem
);
675 if (Instr
->getType()->isVectorTy())
678 // Find the smallest power of two bitwidth that's sufficient to hold Instr's
680 unsigned OrigWidth
= Instr
->getType()->getIntegerBitWidth();
682 // What is the smallest bit width that can accomodate the entire value ranges
683 // of both of the operands?
684 std::array
<Optional
<ConstantRange
>, 2> CRs
;
685 unsigned MinSignedBits
= 0;
686 for (auto I
: zip(Instr
->operands(), CRs
)) {
687 std::get
<1>(I
) = LVI
->getConstantRange(std::get
<0>(I
), Instr
);
688 MinSignedBits
= std::max(std::get
<1>(I
)->getMinSignedBits(), MinSignedBits
);
691 // sdiv/srem is UB if divisor is -1 and divident is INT_MIN, so unless we can
692 // prove that such a combination is impossible, we need to bump the bitwidth.
693 if (CRs
[1]->contains(APInt::getAllOnesValue(OrigWidth
)) &&
695 APInt::getSignedMinValue(MinSignedBits
).sextOrSelf(OrigWidth
)))
698 // Don't shrink below 8 bits wide.
699 unsigned NewWidth
= std::max
<unsigned>(PowerOf2Ceil(MinSignedBits
), 8);
701 // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
703 if (NewWidth
>= OrigWidth
)
706 ++NumSDivSRemsNarrowed
;
707 IRBuilder
<> B
{Instr
};
708 auto *TruncTy
= Type::getIntNTy(Instr
->getContext(), NewWidth
);
709 auto *LHS
= B
.CreateTruncOrBitCast(Instr
->getOperand(0), TruncTy
,
710 Instr
->getName() + ".lhs.trunc");
711 auto *RHS
= B
.CreateTruncOrBitCast(Instr
->getOperand(1), TruncTy
,
712 Instr
->getName() + ".rhs.trunc");
713 auto *BO
= B
.CreateBinOp(Instr
->getOpcode(), LHS
, RHS
, Instr
->getName());
714 auto *Sext
= B
.CreateSExt(BO
, Instr
->getType(), Instr
->getName() + ".sext");
715 if (auto *BinOp
= dyn_cast
<BinaryOperator
>(BO
))
716 if (BinOp
->getOpcode() == Instruction::SDiv
)
717 BinOp
->setIsExact(Instr
->isExact());
719 Instr
->replaceAllUsesWith(Sext
);
720 Instr
->eraseFromParent();
724 /// Try to shrink a udiv/urem's width down to the smallest power of two that's
725 /// sufficient to contain its operands.
726 static bool processUDivOrURem(BinaryOperator
*Instr
, LazyValueInfo
*LVI
) {
727 assert(Instr
->getOpcode() == Instruction::UDiv
||
728 Instr
->getOpcode() == Instruction::URem
);
729 if (Instr
->getType()->isVectorTy())
732 // Find the smallest power of two bitwidth that's sufficient to hold Instr's
735 // What is the smallest bit width that can accomodate the entire value ranges
736 // of both of the operands?
737 unsigned MaxActiveBits
= 0;
738 for (Value
*Operand
: Instr
->operands()) {
739 ConstantRange CR
= LVI
->getConstantRange(Operand
, Instr
);
740 MaxActiveBits
= std::max(CR
.getActiveBits(), MaxActiveBits
);
742 // Don't shrink below 8 bits wide.
743 unsigned NewWidth
= std::max
<unsigned>(PowerOf2Ceil(MaxActiveBits
), 8);
745 // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
747 if (NewWidth
>= Instr
->getType()->getIntegerBitWidth())
750 ++NumUDivURemsNarrowed
;
751 IRBuilder
<> B
{Instr
};
752 auto *TruncTy
= Type::getIntNTy(Instr
->getContext(), NewWidth
);
753 auto *LHS
= B
.CreateTruncOrBitCast(Instr
->getOperand(0), TruncTy
,
754 Instr
->getName() + ".lhs.trunc");
755 auto *RHS
= B
.CreateTruncOrBitCast(Instr
->getOperand(1), TruncTy
,
756 Instr
->getName() + ".rhs.trunc");
757 auto *BO
= B
.CreateBinOp(Instr
->getOpcode(), LHS
, RHS
, Instr
->getName());
758 auto *Zext
= B
.CreateZExt(BO
, Instr
->getType(), Instr
->getName() + ".zext");
759 if (auto *BinOp
= dyn_cast
<BinaryOperator
>(BO
))
760 if (BinOp
->getOpcode() == Instruction::UDiv
)
761 BinOp
->setIsExact(Instr
->isExact());
763 Instr
->replaceAllUsesWith(Zext
);
764 Instr
->eraseFromParent();
768 static bool processSRem(BinaryOperator
*SDI
, LazyValueInfo
*LVI
) {
769 assert(SDI
->getOpcode() == Instruction::SRem
);
770 if (SDI
->getType()->isVectorTy())
777 std::array
<Operand
, 2> Ops
;
779 for (const auto I
: zip(Ops
, SDI
->operands())) {
780 Operand
&Op
= std::get
<0>(I
);
781 Op
.V
= std::get
<1>(I
);
782 Op
.D
= getDomain(Op
.V
, LVI
, SDI
);
783 if (Op
.D
== Domain::Unknown
)
787 // We know domains of both of the operands!
790 // We need operands to be non-negative, so negate each one that isn't.
791 for (Operand
&Op
: Ops
) {
792 if (Op
.D
== Domain::NonNegative
)
795 BinaryOperator::CreateNeg(Op
.V
, Op
.V
->getName() + ".nonneg", SDI
);
796 BO
->setDebugLoc(SDI
->getDebugLoc());
801 BinaryOperator::CreateURem(Ops
[0].V
, Ops
[1].V
, SDI
->getName(), SDI
);
802 URem
->setDebugLoc(SDI
->getDebugLoc());
806 // If the divident was non-positive, we need to negate the result.
807 if (Ops
[0].D
== Domain::NonPositive
)
808 Res
= BinaryOperator::CreateNeg(Res
, Res
->getName() + ".neg", SDI
);
810 SDI
->replaceAllUsesWith(Res
);
811 SDI
->eraseFromParent();
813 // Try to simplify our new urem.
814 processUDivOrURem(URem
, LVI
);
819 /// See if LazyValueInfo's ability to exploit edge conditions or range
820 /// information is sufficient to prove the signs of both operands of this SDiv.
821 /// If this is the case, replace the SDiv with a UDiv. Even for local
822 /// conditions, this can sometimes prove conditions instcombine can't by
823 /// exploiting range information.
824 static bool processSDiv(BinaryOperator
*SDI
, LazyValueInfo
*LVI
) {
825 assert(SDI
->getOpcode() == Instruction::SDiv
);
826 if (SDI
->getType()->isVectorTy())
833 std::array
<Operand
, 2> Ops
;
835 for (const auto I
: zip(Ops
, SDI
->operands())) {
836 Operand
&Op
= std::get
<0>(I
);
837 Op
.V
= std::get
<1>(I
);
838 Op
.D
= getDomain(Op
.V
, LVI
, SDI
);
839 if (Op
.D
== Domain::Unknown
)
843 // We know domains of both of the operands!
846 // We need operands to be non-negative, so negate each one that isn't.
847 for (Operand
&Op
: Ops
) {
848 if (Op
.D
== Domain::NonNegative
)
851 BinaryOperator::CreateNeg(Op
.V
, Op
.V
->getName() + ".nonneg", SDI
);
852 BO
->setDebugLoc(SDI
->getDebugLoc());
857 BinaryOperator::CreateUDiv(Ops
[0].V
, Ops
[1].V
, SDI
->getName(), SDI
);
858 UDiv
->setDebugLoc(SDI
->getDebugLoc());
859 UDiv
->setIsExact(SDI
->isExact());
863 // If the operands had two different domains, we need to negate the result.
864 if (Ops
[0].D
!= Ops
[1].D
)
865 Res
= BinaryOperator::CreateNeg(Res
, Res
->getName() + ".neg", SDI
);
867 SDI
->replaceAllUsesWith(Res
);
868 SDI
->eraseFromParent();
870 // Try to simplify our new udiv.
871 processUDivOrURem(UDiv
, LVI
);
876 static bool processSDivOrSRem(BinaryOperator
*Instr
, LazyValueInfo
*LVI
) {
877 assert(Instr
->getOpcode() == Instruction::SDiv
||
878 Instr
->getOpcode() == Instruction::SRem
);
879 if (Instr
->getType()->isVectorTy())
882 if (Instr
->getOpcode() == Instruction::SDiv
)
883 if (processSDiv(Instr
, LVI
))
886 if (Instr
->getOpcode() == Instruction::SRem
)
887 if (processSRem(Instr
, LVI
))
890 return narrowSDivOrSRem(Instr
, LVI
);
893 static bool processAShr(BinaryOperator
*SDI
, LazyValueInfo
*LVI
) {
894 if (SDI
->getType()->isVectorTy())
897 if (!isNonNegative(SDI
->getOperand(0), LVI
, SDI
))
901 auto *BO
= BinaryOperator::CreateLShr(SDI
->getOperand(0), SDI
->getOperand(1),
902 SDI
->getName(), SDI
);
903 BO
->setDebugLoc(SDI
->getDebugLoc());
904 BO
->setIsExact(SDI
->isExact());
905 SDI
->replaceAllUsesWith(BO
);
906 SDI
->eraseFromParent();
911 static bool processSExt(SExtInst
*SDI
, LazyValueInfo
*LVI
) {
912 if (SDI
->getType()->isVectorTy())
915 Value
*Base
= SDI
->getOperand(0);
917 if (!isNonNegative(Base
, LVI
, SDI
))
922 CastInst::CreateZExtOrBitCast(Base
, SDI
->getType(), SDI
->getName(), SDI
);
923 ZExt
->setDebugLoc(SDI
->getDebugLoc());
924 SDI
->replaceAllUsesWith(ZExt
);
925 SDI
->eraseFromParent();
930 static bool processBinOp(BinaryOperator
*BinOp
, LazyValueInfo
*LVI
) {
931 using OBO
= OverflowingBinaryOperator
;
933 if (BinOp
->getType()->isVectorTy())
936 bool NSW
= BinOp
->hasNoSignedWrap();
937 bool NUW
= BinOp
->hasNoUnsignedWrap();
941 Instruction::BinaryOps Opcode
= BinOp
->getOpcode();
942 Value
*LHS
= BinOp
->getOperand(0);
943 Value
*RHS
= BinOp
->getOperand(1);
945 ConstantRange LRange
= LVI
->getConstantRange(LHS
, BinOp
);
946 ConstantRange RRange
= LVI
->getConstantRange(RHS
, BinOp
);
948 bool Changed
= false;
949 bool NewNUW
= false, NewNSW
= false;
951 ConstantRange NUWRange
= ConstantRange::makeGuaranteedNoWrapRegion(
952 Opcode
, RRange
, OBO::NoUnsignedWrap
);
953 NewNUW
= NUWRange
.contains(LRange
);
957 ConstantRange NSWRange
= ConstantRange::makeGuaranteedNoWrapRegion(
958 Opcode
, RRange
, OBO::NoSignedWrap
);
959 NewNSW
= NSWRange
.contains(LRange
);
963 setDeducedOverflowingFlags(BinOp
, Opcode
, NewNSW
, NewNUW
);
968 static bool processAnd(BinaryOperator
*BinOp
, LazyValueInfo
*LVI
) {
969 if (BinOp
->getType()->isVectorTy())
972 // Pattern match (and lhs, C) where C includes a superset of bits which might
973 // be set in lhs. This is a common truncation idiom created by instcombine.
974 Value
*LHS
= BinOp
->getOperand(0);
975 ConstantInt
*RHS
= dyn_cast
<ConstantInt
>(BinOp
->getOperand(1));
976 if (!RHS
|| !RHS
->getValue().isMask())
979 // We can only replace the AND with LHS based on range info if the range does
980 // not include undef.
981 ConstantRange LRange
=
982 LVI
->getConstantRange(LHS
, BinOp
, /*UndefAllowed=*/false);
983 if (!LRange
.getUnsignedMax().ule(RHS
->getValue()))
986 BinOp
->replaceAllUsesWith(LHS
);
987 BinOp
->eraseFromParent();
993 static Constant
*getConstantAt(Value
*V
, Instruction
*At
, LazyValueInfo
*LVI
) {
994 if (Constant
*C
= LVI
->getConstant(V
, At
))
997 // TODO: The following really should be sunk inside LVI's core algorithm, or
998 // at least the outer shims around such.
999 auto *C
= dyn_cast
<CmpInst
>(V
);
1000 if (!C
) return nullptr;
1002 Value
*Op0
= C
->getOperand(0);
1003 Constant
*Op1
= dyn_cast
<Constant
>(C
->getOperand(1));
1004 if (!Op1
) return nullptr;
1006 LazyValueInfo::Tristate Result
= LVI
->getPredicateAt(
1007 C
->getPredicate(), Op0
, Op1
, At
, /*UseBlockValue=*/false);
1008 if (Result
== LazyValueInfo::Unknown
)
1011 return (Result
== LazyValueInfo::True
) ?
1012 ConstantInt::getTrue(C
->getContext()) :
1013 ConstantInt::getFalse(C
->getContext());
1016 static bool runImpl(Function
&F
, LazyValueInfo
*LVI
, DominatorTree
*DT
,
1017 const SimplifyQuery
&SQ
) {
1018 bool FnChanged
= false;
1019 // Visiting in a pre-order depth-first traversal causes us to simplify early
1020 // blocks before querying later blocks (which require us to analyze early
1021 // blocks). Eagerly simplifying shallow blocks means there is strictly less
1022 // work to do for deep blocks. This also means we don't visit unreachable
1024 for (BasicBlock
*BB
: depth_first(&F
.getEntryBlock())) {
1025 bool BBChanged
= false;
1026 for (BasicBlock::iterator BI
= BB
->begin(), BE
= BB
->end(); BI
!= BE
;) {
1027 Instruction
*II
= &*BI
++;
1028 switch (II
->getOpcode()) {
1029 case Instruction::Select
:
1030 BBChanged
|= processSelect(cast
<SelectInst
>(II
), LVI
);
1032 case Instruction::PHI
:
1033 BBChanged
|= processPHI(cast
<PHINode
>(II
), LVI
, DT
, SQ
);
1035 case Instruction::ICmp
:
1036 case Instruction::FCmp
:
1037 BBChanged
|= processCmp(cast
<CmpInst
>(II
), LVI
);
1039 case Instruction::Load
:
1040 case Instruction::Store
:
1041 BBChanged
|= processMemAccess(II
, LVI
);
1043 case Instruction::Call
:
1044 case Instruction::Invoke
:
1045 BBChanged
|= processCallSite(cast
<CallBase
>(*II
), LVI
);
1047 case Instruction::SRem
:
1048 case Instruction::SDiv
:
1049 BBChanged
|= processSDivOrSRem(cast
<BinaryOperator
>(II
), LVI
);
1051 case Instruction::UDiv
:
1052 case Instruction::URem
:
1053 BBChanged
|= processUDivOrURem(cast
<BinaryOperator
>(II
), LVI
);
1055 case Instruction::AShr
:
1056 BBChanged
|= processAShr(cast
<BinaryOperator
>(II
), LVI
);
1058 case Instruction::SExt
:
1059 BBChanged
|= processSExt(cast
<SExtInst
>(II
), LVI
);
1061 case Instruction::Add
:
1062 case Instruction::Sub
:
1063 case Instruction::Mul
:
1064 case Instruction::Shl
:
1065 BBChanged
|= processBinOp(cast
<BinaryOperator
>(II
), LVI
);
1067 case Instruction::And
:
1068 BBChanged
|= processAnd(cast
<BinaryOperator
>(II
), LVI
);
1073 Instruction
*Term
= BB
->getTerminator();
1074 switch (Term
->getOpcode()) {
1075 case Instruction::Switch
:
1076 BBChanged
|= processSwitch(cast
<SwitchInst
>(Term
), LVI
, DT
);
1078 case Instruction::Ret
: {
1079 auto *RI
= cast
<ReturnInst
>(Term
);
1080 // Try to determine the return value if we can. This is mainly here to
1081 // simplify the writing of unit tests, but also helps to enable IPO by
1082 // constant folding the return values of callees.
1083 auto *RetVal
= RI
->getReturnValue();
1084 if (!RetVal
) break; // handle "ret void"
1085 if (isa
<Constant
>(RetVal
)) break; // nothing to do
1086 if (auto *C
= getConstantAt(RetVal
, RI
, LVI
)) {
1088 RI
->replaceUsesOfWith(RetVal
, C
);
1094 FnChanged
|= BBChanged
;
1100 bool CorrelatedValuePropagation::runOnFunction(Function
&F
) {
1101 if (skipFunction(F
))
1104 LazyValueInfo
*LVI
= &getAnalysis
<LazyValueInfoWrapperPass
>().getLVI();
1105 DominatorTree
*DT
= &getAnalysis
<DominatorTreeWrapperPass
>().getDomTree();
1107 return runImpl(F
, LVI
, DT
, getBestSimplifyQuery(*this, F
));
1111 CorrelatedValuePropagationPass::run(Function
&F
, FunctionAnalysisManager
&AM
) {
1112 LazyValueInfo
*LVI
= &AM
.getResult
<LazyValueAnalysis
>(F
);
1113 DominatorTree
*DT
= &AM
.getResult
<DominatorTreeAnalysis
>(F
);
1115 bool Changed
= runImpl(F
, LVI
, DT
, getBestSimplifyQuery(AM
, F
));
1117 PreservedAnalyses PA
;
1119 PA
= PreservedAnalyses::all();
1121 PA
.preserve
<DominatorTreeAnalysis
>();
1122 PA
.preserve
<LazyValueAnalysis
>();
1125 // Keeping LVI alive is expensive, both because it uses a lot of memory, and
1126 // because invalidating values in LVI is expensive. While CVP does preserve
1127 // LVI, we know that passes after JumpThreading+CVP will not need the result
1128 // of this analysis, so we forcefully discard it early.
1129 PA
.abandon
<LazyValueAnalysis
>();