1 //===- LoopVR.cpp - Value Range analysis driven by loop information -------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
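// LoopVR computes ConstantRange bounds for integer values defined inside
// loops, using ScalarEvolution expressions together with each loop's
// backedge-taken count. Results are cached per Value; clients may narrow or
// remove cached entries. Sub-expressions that cannot be bounded yield the
// full set.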
12 //===----------------------------------------------------------------------===//
14 #define DEBUG_TYPE "loopvr"
15 #include "llvm/Analysis/LoopVR.h"
16 #include "llvm/Constants.h"
17 #include "llvm/Instructions.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/Assembly/Writer.h"
20 #include "llvm/Support/CFG.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
26 static RegisterPass
<LoopVR
> X("loopvr", "Loop Value Ranges", false, true);
28 /// getRange - determine the range for a particular SCEV within a given Loop
29 ConstantRange
LoopVR::getRange(SCEVHandle S
, Loop
*L
, ScalarEvolution
&SE
) {
30 SCEVHandle T
= SE
.getBackedgeTakenCount(L
);
31 if (isa
<SCEVCouldNotCompute
>(T
))
32 return ConstantRange(cast
<IntegerType
>(S
->getType())->getBitWidth(), true);
34 T
= SE
.getTruncateOrZeroExtend(T
, S
->getType());
35 return getRange(S
, T
, SE
);
38 /// getRange - determine the range for a particular SCEV with a given trip count
39 ConstantRange
LoopVR::getRange(SCEVHandle S
, SCEVHandle T
, ScalarEvolution
&SE
){
41 if (const SCEVConstant
*C
= dyn_cast
<SCEVConstant
>(S
))
42 return ConstantRange(C
->getValue()->getValue());
44 ConstantRange
FullSet(cast
<IntegerType
>(S
->getType())->getBitWidth(), true);
46 // {x,+,y,+,...z}. We detect overflow by checking the size of the set after
47 // summing the upper and lower.
48 if (const SCEVAddExpr
*Add
= dyn_cast
<SCEVAddExpr
>(S
)) {
49 ConstantRange X
= getRange(Add
->getOperand(0), T
, SE
);
50 if (X
.isFullSet()) return FullSet
;
51 for (unsigned i
= 1, e
= Add
->getNumOperands(); i
!= e
; ++i
) {
52 ConstantRange Y
= getRange(Add
->getOperand(i
), T
, SE
);
53 if (Y
.isFullSet()) return FullSet
;
55 APInt Spread_X
= X
.getSetSize(), Spread_Y
= Y
.getSetSize();
56 APInt NewLower
= X
.getLower() + Y
.getLower();
57 APInt NewUpper
= X
.getUpper() + Y
.getUpper() - 1;
58 if (NewLower
== NewUpper
)
61 X
= ConstantRange(NewLower
, NewUpper
);
62 if (X
.getSetSize().ult(Spread_X
) || X
.getSetSize().ult(Spread_Y
))
63 return FullSet
; // we've wrapped, therefore, full set.
68 // {x,*,y,*,...,z}. In order to detect overflow, we use k*bitwidth where
69 // k is the number of terms being multiplied.
70 if (const SCEVMulExpr
*Mul
= dyn_cast
<SCEVMulExpr
>(S
)) {
71 ConstantRange X
= getRange(Mul
->getOperand(0), T
, SE
);
72 if (X
.isFullSet()) return FullSet
;
74 const IntegerType
*Ty
= IntegerType::get(X
.getBitWidth());
75 const IntegerType
*ExTy
= IntegerType::get(X
.getBitWidth() *
76 Mul
->getNumOperands());
77 ConstantRange XExt
= X
.zeroExtend(ExTy
->getBitWidth());
79 for (unsigned i
= 1, e
= Mul
->getNumOperands(); i
!= e
; ++i
) {
80 ConstantRange Y
= getRange(Mul
->getOperand(i
), T
, SE
);
81 if (Y
.isFullSet()) return FullSet
;
83 ConstantRange YExt
= Y
.zeroExtend(ExTy
->getBitWidth());
84 XExt
= ConstantRange(XExt
.getLower() * YExt
.getLower(),
85 ((XExt
.getUpper()-1) * (YExt
.getUpper()-1)) + 1);
87 return XExt
.truncate(Ty
->getBitWidth());
90 // X smax Y smax ... Z is: range(smax(X_smin, Y_smin, ..., Z_smin),
91 // smax(X_smax, Y_smax, ..., Z_smax))
92 // It doesn't matter if one of the SCEVs has FullSet because we're taking
93 // a maximum of the minimums across all of them.
94 if (const SCEVSMaxExpr
*SMax
= dyn_cast
<SCEVSMaxExpr
>(S
)) {
95 ConstantRange X
= getRange(SMax
->getOperand(0), T
, SE
);
96 if (X
.isFullSet()) return FullSet
;
98 APInt smin
= X
.getSignedMin(), smax
= X
.getSignedMax();
99 for (unsigned i
= 1, e
= SMax
->getNumOperands(); i
!= e
; ++i
) {
100 ConstantRange Y
= getRange(SMax
->getOperand(i
), T
, SE
);
101 smin
= APIntOps::smax(smin
, Y
.getSignedMin());
102 smax
= APIntOps::smax(smax
, Y
.getSignedMax());
104 if (smax
+ 1 == smin
) return FullSet
;
105 return ConstantRange(smin
, smax
+ 1);
108 // X umax Y umax ... Z is: range(umax(X_umin, Y_umin, ..., Z_umin),
109 // umax(X_umax, Y_umax, ..., Z_umax))
110 // It doesn't matter if one of the SCEVs has FullSet because we're taking
111 // a maximum of the minimums across all of them.
112 if (const SCEVUMaxExpr
*UMax
= dyn_cast
<SCEVUMaxExpr
>(S
)) {
113 ConstantRange X
= getRange(UMax
->getOperand(0), T
, SE
);
114 if (X
.isFullSet()) return FullSet
;
116 APInt umin
= X
.getUnsignedMin(), umax
= X
.getUnsignedMax();
117 for (unsigned i
= 1, e
= UMax
->getNumOperands(); i
!= e
; ++i
) {
118 ConstantRange Y
= getRange(UMax
->getOperand(i
), T
, SE
);
119 umin
= APIntOps::umax(umin
, Y
.getUnsignedMin());
120 umax
= APIntOps::umax(umax
, Y
.getUnsignedMax());
122 if (umax
+ 1 == umin
) return FullSet
;
123 return ConstantRange(umin
, umax
+ 1);
126 // L udiv R. Luckily, there's only ever 2 sides to a udiv.
127 if (const SCEVUDivExpr
*UDiv
= dyn_cast
<SCEVUDivExpr
>(S
)) {
128 ConstantRange L
= getRange(UDiv
->getLHS(), T
, SE
);
129 ConstantRange R
= getRange(UDiv
->getRHS(), T
, SE
);
130 if (L
.isFullSet() && R
.isFullSet()) return FullSet
;
132 if (R
.getUnsignedMax() == 0) {
133 // RHS must be single-element zero. Return an empty set.
134 return ConstantRange(R
.getBitWidth(), false);
137 APInt Lower
= L
.getUnsignedMin().udiv(R
.getUnsignedMax());
141 if (R
.getUnsignedMin() == 0) {
142 // Just because it contains zero, doesn't mean it will also contain one.
143 // Use maximalIntersectWith to get the right behaviour.
144 ConstantRange
NotZero(APInt(L
.getBitWidth(), 1),
145 APInt::getNullValue(L
.getBitWidth()));
146 R
= R
.maximalIntersectWith(NotZero
);
149 // But, the maximal intersection might still include zero. If it does, then
150 // we know it also included one.
151 if (R
.contains(APInt::getNullValue(L
.getBitWidth())))
152 Upper
= L
.getUnsignedMax();
154 Upper
= L
.getUnsignedMax().udiv(R
.getUnsignedMin());
156 return ConstantRange(Lower
, Upper
);
159 // ConstantRange already implements the cast operators.
161 if (const SCEVZeroExtendExpr
*ZExt
= dyn_cast
<SCEVZeroExtendExpr
>(S
)) {
162 T
= SE
.getTruncateOrZeroExtend(T
, ZExt
->getOperand()->getType());
163 ConstantRange X
= getRange(ZExt
->getOperand(), T
, SE
);
164 return X
.zeroExtend(cast
<IntegerType
>(ZExt
->getType())->getBitWidth());
167 if (const SCEVSignExtendExpr
*SExt
= dyn_cast
<SCEVSignExtendExpr
>(S
)) {
168 T
= SE
.getTruncateOrZeroExtend(T
, SExt
->getOperand()->getType());
169 ConstantRange X
= getRange(SExt
->getOperand(), T
, SE
);
170 return X
.signExtend(cast
<IntegerType
>(SExt
->getType())->getBitWidth());
173 if (const SCEVTruncateExpr
*Trunc
= dyn_cast
<SCEVTruncateExpr
>(S
)) {
174 T
= SE
.getTruncateOrZeroExtend(T
, Trunc
->getOperand()->getType());
175 ConstantRange X
= getRange(Trunc
->getOperand(), T
, SE
);
176 if (X
.isFullSet()) return FullSet
;
177 return X
.truncate(cast
<IntegerType
>(Trunc
->getType())->getBitWidth());
180 if (const SCEVAddRecExpr
*AddRec
= dyn_cast
<SCEVAddRecExpr
>(S
)) {
181 const SCEVConstant
*Trip
= dyn_cast
<SCEVConstant
>(T
);
182 if (!Trip
) return FullSet
;
184 if (AddRec
->isAffine()) {
185 SCEVHandle StartHandle
= AddRec
->getStart();
186 SCEVHandle StepHandle
= AddRec
->getOperand(1);
188 const SCEVConstant
*Step
= dyn_cast
<SCEVConstant
>(StepHandle
);
189 if (!Step
) return FullSet
;
191 uint32_t ExWidth
= 2 * Trip
->getValue()->getBitWidth();
192 APInt TripExt
= Trip
->getValue()->getValue(); TripExt
.zext(ExWidth
);
193 APInt StepExt
= Step
->getValue()->getValue(); StepExt
.zext(ExWidth
);
194 if ((TripExt
* StepExt
).ugt(APInt::getLowBitsSet(ExWidth
, ExWidth
>> 1)))
197 SCEVHandle EndHandle
= SE
.getAddExpr(StartHandle
,
198 SE
.getMulExpr(T
, StepHandle
));
199 const SCEVConstant
*Start
= dyn_cast
<SCEVConstant
>(StartHandle
);
200 const SCEVConstant
*End
= dyn_cast
<SCEVConstant
>(EndHandle
);
201 if (!Start
|| !End
) return FullSet
;
203 const APInt
&StartInt
= Start
->getValue()->getValue();
204 const APInt
&EndInt
= End
->getValue()->getValue();
205 const APInt
&StepInt
= Step
->getValue()->getValue();
207 if (StepInt
.isNegative()) {
208 if (EndInt
== StartInt
+ 1) return FullSet
;
209 return ConstantRange(EndInt
, StartInt
+ 1);
211 if (StartInt
== EndInt
+ 1) return FullSet
;
212 return ConstantRange(StartInt
, EndInt
+ 1);
217 // TODO: non-affine addrec, udiv, SCEVUnknown (narrowed from elsewhere)?
222 bool LoopVR::runOnFunction(Function
&F
) { Map
.clear(); return false; }
224 void LoopVR::print(std::ostream
&os
, const Module
*) const {
225 raw_os_ostream
OS(os
);
226 for (std::map
<Value
*, ConstantRange
*>::const_iterator I
= Map
.begin(),
227 E
= Map
.end(); I
!= E
; ++I
) {
228 OS
<< *I
->first
<< ": " << *I
->second
<< '\n';
232 void LoopVR::releaseMemory() {
233 for (std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.begin(),
234 E
= Map
.end(); I
!= E
; ++I
) {
241 ConstantRange
LoopVR::compute(Value
*V
) {
242 if (ConstantInt
*CI
= dyn_cast
<ConstantInt
>(V
))
243 return ConstantRange(CI
->getValue());
245 Instruction
*I
= dyn_cast
<Instruction
>(V
);
247 return ConstantRange(cast
<IntegerType
>(V
->getType())->getBitWidth(), false);
249 LoopInfo
&LI
= getAnalysis
<LoopInfo
>();
251 Loop
*L
= LI
.getLoopFor(I
->getParent());
252 if (!L
|| L
->isLoopInvariant(I
))
253 return ConstantRange(cast
<IntegerType
>(V
->getType())->getBitWidth(), false);
255 ScalarEvolution
&SE
= getAnalysis
<ScalarEvolution
>();
257 SCEVHandle S
= SE
.getSCEV(I
);
258 if (isa
<SCEVUnknown
>(S
) || isa
<SCEVCouldNotCompute
>(S
))
259 return ConstantRange(cast
<IntegerType
>(V
->getType())->getBitWidth(), false);
261 return ConstantRange(getRange(S
, L
, SE
));
264 ConstantRange
LoopVR::get(Value
*V
) {
265 std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.find(V
);
266 if (I
== Map
.end()) {
267 ConstantRange
*CR
= new ConstantRange(compute(V
));
275 void LoopVR::remove(Value
*V
) {
276 std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.find(V
);
277 if (I
!= Map
.end()) {
283 void LoopVR::narrow(Value
*V
, const ConstantRange
&CR
) {
284 if (CR
.isFullSet()) return;
286 std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.find(V
);
288 Map
[V
] = new ConstantRange(CR
);
290 Map
[V
] = new ConstantRange(Map
[V
]->maximalIntersectWith(CR
));