1 //===- LoopVR.cpp - Value Range analysis driven by loop information -------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
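10 // LoopVR computes ConstantRanges for integer values defined inside loops by
// evaluating their ScalarEvolution expressions over the loop's backedge-taken
// count. Results are cached per Value and can be tightened with narrow().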
12 //===----------------------------------------------------------------------===//
14 #define DEBUG_TYPE "loopvr"
15 #include "llvm/Analysis/LoopVR.h"
16 #include "llvm/Constants.h"
17 #include "llvm/Instructions.h"
18 #include "llvm/LLVMContext.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
21 #include "llvm/Assembly/Writer.h"
22 #include "llvm/Support/CFG.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/raw_ostream.h"
28 static RegisterPass
<LoopVR
> X("loopvr", "Loop Value Ranges", false, true);
30 /// getRange - determine the range for a particular SCEV within a given Loop
31 ConstantRange
LoopVR::getRange(const SCEV
*S
, Loop
*L
, ScalarEvolution
&SE
) {
32 const SCEV
*T
= SE
.getBackedgeTakenCount(L
);
33 if (isa
<SCEVCouldNotCompute
>(T
))
34 return ConstantRange(cast
<IntegerType
>(S
->getType())->getBitWidth(), true);
36 T
= SE
.getTruncateOrZeroExtend(T
, S
->getType());
37 return getRange(S
, T
, SE
);
40 /// getRange - determine the range for a particular SCEV with a given trip count
/// Returns a ConstantRange for S, where T is the count used to evaluate loop
/// recurrences. In this file, the Loop overload passes the loop's
/// backedge-taken count, coerced to S's type. Each handled SCEV kind derives
/// its range from the ranges of its operands. The cases that detect possible
/// wrap-around return FullSet, the full set for S's bit width.
/// NOTE(review): this listing omits blank lines and most closing-brace lines,
/// and also some statements. Gaps that appear to drop statements are flagged
/// below. The final fallback (original lines 219-222) is not visible;
/// presumably it returns FullSet for unhandled SCEV kinds.
41 ConstantRange
LoopVR::getRange(const SCEV
*S
, const SCEV
*T
, ScalarEvolution
&SE
){
// A constant has exactly one possible value.
43 if (const SCEVConstant
*C
= dyn_cast
<SCEVConstant
>(S
))
44 return ConstantRange(C
->getValue()->getValue());
// The safe answer: every value of S's bit width.
46 ConstantRange
FullSet(cast
<IntegerType
>(S
->getType())->getBitWidth(), true);
48 // {x,+,y,+,...z}. We detect overflow by checking the size of the set after
49 // summing the upper and lower.
50 if (const SCEVAddExpr
*Add
= dyn_cast
<SCEVAddExpr
>(S
)) {
51 ConstantRange X
= getRange(Add
->getOperand(0), T
, SE
);
52 if (X
.isFullSet()) return FullSet
;
53 for (unsigned i
= 1, e
= Add
->getNumOperands(); i
!= e
; ++i
) {
54 ConstantRange Y
= getRange(Add
->getOperand(i
), T
, SE
);
55 if (Y
.isFullSet()) return FullSet
;
// Add the operand ranges bound by bound. If the resulting set is smaller
// than either input set, the addition wrapped around.
57 APInt Spread_X
= X
.getSetSize(), Spread_Y
= Y
.getSetSize();
58 APInt NewLower
= X
.getLower() + Y
.getLower();
59 APInt NewUpper
= X
.getUpper() + Y
.getUpper() - 1;
60 if (NewLower
== NewUpper
)
// NOTE(review): this listing skips original lines 61-62, so the statement
// guarded by `if (NewLower == NewUpper)` is not visible. It is presumably
// `return FullSet;`. Confirm against the upstream file.
63 X
= ConstantRange(NewLower
, NewUpper
);
64 if (X
.getSetSize().ult(Spread_X
) || X
.getSetSize().ult(Spread_Y
))
65 return FullSet
; // we've wrapped, therefore, full set.
// NOTE(review): original lines 66-69 are missing from this listing. They
// presumably close the loop and return the accumulated range X.
70 // {x,*,y,*,...,z}. In order to detect overflow, we use k*bitwidth where
71 // k is the number of terms being multiplied.
72 if (const SCEVMulExpr
*Mul
= dyn_cast
<SCEVMulExpr
>(S
)) {
73 ConstantRange X
= getRange(Mul
->getOperand(0), T
, SE
);
74 if (X
.isFullSet()) return FullSet
;
76 const IntegerType
*Ty
= IntegerType::get(SE
.getContext(), X
.getBitWidth());
77 const IntegerType
*ExTy
= IntegerType::get(SE
.getContext(),
78 X
.getBitWidth() * Mul
->getNumOperands());
// Multiply in ExTy, which is wide enough to hold the product of all operands,
// then truncate the result back to the original width (Ty).
79 ConstantRange XExt
= X
.zeroExtend(ExTy
->getBitWidth());
81 for (unsigned i
= 1, e
= Mul
->getNumOperands(); i
!= e
; ++i
) {
82 ConstantRange Y
= getRange(Mul
->getOperand(i
), T
, SE
);
83 if (Y
.isFullSet()) return FullSet
;
85 ConstantRange YExt
= Y
.zeroExtend(ExTy
->getBitWidth());
// NOTE(review): the product below uses getLower() and getUpper()-1 as the
// minimum and maximum. That is only valid when neither operand range wraps
// around. TODO confirm wrapped ranges cannot reach this point.
86 XExt
= ConstantRange(XExt
.getLower() * YExt
.getLower(),
87 ((XExt
.getUpper()-1) * (YExt
.getUpper()-1)) + 1);
89 return XExt
.truncate(Ty
->getBitWidth());
92 // X smax Y smax ... Z is: range(smax(X_smin, Y_smin, ..., Z_smin),
93 // smax(X_smax, Y_smax, ..., Z_smax))
94 // It doesn't matter if one of the SCEVs has FullSet because we're taking
95 // a maximum of the minimums across all of them.
96 if (const SCEVSMaxExpr
*SMax
= dyn_cast
<SCEVSMaxExpr
>(S
)) {
97 ConstantRange X
= getRange(SMax
->getOperand(0), T
, SE
);
98 if (X
.isFullSet()) return FullSet
;
100 APInt smin
= X
.getSignedMin(), smax
= X
.getSignedMax();
101 for (unsigned i
= 1, e
= SMax
->getNumOperands(); i
!= e
; ++i
) {
102 ConstantRange Y
= getRange(SMax
->getOperand(i
), T
, SE
);
103 smin
= APIntOps::smax(smin
, Y
.getSignedMin());
104 smax
= APIntOps::smax(smax
, Y
.getSignedMax());
// smax + 1 can equal smin only when smin is the signed minimum and smax is the
// signed maximum, i.e. every value is possible. Report that as FullSet rather
// than building a range with equal bounds.
106 if (smax
+ 1 == smin
) return FullSet
;
107 return ConstantRange(smin
, smax
+ 1);
110 // X umax Y umax ... Z is: range(umax(X_umin, Y_umin, ..., Z_umin),
111 // umax(X_umax, Y_umax, ..., Z_umax))
112 // It doesn't matter if one of the SCEVs has FullSet because we're taking
113 // a maximum of the minimums across all of them.
114 if (const SCEVUMaxExpr
*UMax
= dyn_cast
<SCEVUMaxExpr
>(S
)) {
115 ConstantRange X
= getRange(UMax
->getOperand(0), T
, SE
);
116 if (X
.isFullSet()) return FullSet
;
118 APInt umin
= X
.getUnsignedMin(), umax
= X
.getUnsignedMax();
119 for (unsigned i
= 1, e
= UMax
->getNumOperands(); i
!= e
; ++i
) {
120 ConstantRange Y
= getRange(UMax
->getOperand(i
), T
, SE
);
121 umin
= APIntOps::umax(umin
, Y
.getUnsignedMin());
122 umax
= APIntOps::umax(umax
, Y
.getUnsignedMax());
// umax + 1 can equal umin only when umin is zero and umax is all-ones, i.e.
// every value is possible. Report that as FullSet.
124 if (umax
+ 1 == umin
) return FullSet
;
125 return ConstantRange(umin
, umax
+ 1);
128 // L udiv R. Luckily, there's only ever 2 sides to a udiv.
129 if (const SCEVUDivExpr
*UDiv
= dyn_cast
<SCEVUDivExpr
>(S
)) {
130 ConstantRange L
= getRange(UDiv
->getLHS(), T
, SE
);
131 ConstantRange R
= getRange(UDiv
->getRHS(), T
, SE
);
132 if (L
.isFullSet() && R
.isFullSet()) return FullSet
;
134 if (R
.getUnsignedMax() == 0) {
135 // RHS must be single-element zero. Return an empty set.
136 return ConstantRange(R
.getBitWidth(), false);
// Smallest quotient: smallest dividend divided by largest divisor.
139 APInt Lower
= L
.getUnsignedMin().udiv(R
.getUnsignedMax());
// NOTE(review): original lines 140-142 are missing from this listing. The
// declaration of `Upper`, which is assigned below, is presumably among them.
143 if (R
.getUnsignedMin() == 0) {
144 // Just because it contains zero, doesn't mean it will also contain one.
// [1, 0) is the wrapped range containing every non-zero value.
145 ConstantRange
NotZero(APInt(L
.getBitWidth(), 1),
146 APInt::getNullValue(L
.getBitWidth()));
147 R
= R
.intersectWith(NotZero
);
150 // But, the intersection might still include zero. If it does, then we know
151 // it also included one.
152 if (R
.contains(APInt::getNullValue(L
.getBitWidth())))
153 Upper
= L
.getUnsignedMax();
// NOTE(review): original line 154 is missing. It is presumably `else`, so
// that this division runs only when R excludes zero. Otherwise
// R.getUnsignedMin() could be zero here.
155 Upper
= L
.getUnsignedMax().udiv(R
.getUnsignedMin());
// NOTE(review): ConstantRange's upper bound is exclusive; the SMax, UMax and
// AddRec cases pass `max + 1`. This call therefore excludes the largest
// possible quotient. In addition, Lower == Upper is only accepted by
// ConstantRange for its full/empty encodings. This should likely be
// `Upper + 1` with a wrap check. TODO confirm.
157 return ConstantRange(Lower
, Upper
);
160 // ConstantRange already implements the cast operators.
// NOTE(review): in the zext and sext cases, getTruncateOrZeroExtend truncates
// T to the narrower operand type, which can understate the trip count for the
// operand's recurrences. TODO confirm this cannot yield a too-narrow range.
162 if (const SCEVZeroExtendExpr
*ZExt
= dyn_cast
<SCEVZeroExtendExpr
>(S
)) {
163 T
= SE
.getTruncateOrZeroExtend(T
, ZExt
->getOperand()->getType());
164 ConstantRange X
= getRange(ZExt
->getOperand(), T
, SE
);
165 return X
.zeroExtend(cast
<IntegerType
>(ZExt
->getType())->getBitWidth());
168 if (const SCEVSignExtendExpr
*SExt
= dyn_cast
<SCEVSignExtendExpr
>(S
)) {
169 T
= SE
.getTruncateOrZeroExtend(T
, SExt
->getOperand()->getType());
170 ConstantRange X
= getRange(SExt
->getOperand(), T
, SE
);
171 return X
.signExtend(cast
<IntegerType
>(SExt
->getType())->getBitWidth());
174 if (const SCEVTruncateExpr
*Trunc
= dyn_cast
<SCEVTruncateExpr
>(S
)) {
175 T
= SE
.getTruncateOrZeroExtend(T
, Trunc
->getOperand()->getType());
176 ConstantRange X
= getRange(Trunc
->getOperand(), T
, SE
);
177 if (X
.isFullSet()) return FullSet
;
178 return X
.truncate(cast
<IntegerType
>(Trunc
->getType())->getBitWidth());
// Affine recurrence {Start,+,Step} with a constant count T. The values run
// from Start to End = Start + T*Step. The result is [Start, End], or
// [End, Start] for a negative step, provided T*Step cannot overflow. Only a
// constant count, step and start are handled; anything else yields FullSet.
181 if (const SCEVAddRecExpr
*AddRec
= dyn_cast
<SCEVAddRecExpr
>(S
)) {
182 const SCEVConstant
*Trip
= dyn_cast
<SCEVConstant
>(T
);
183 if (!Trip
) return FullSet
;
185 if (AddRec
->isAffine()) {
186 const SCEV
*StartHandle
= AddRec
->getStart();
187 const SCEV
*StepHandle
= AddRec
->getOperand(1);
189 const SCEVConstant
*Step
= dyn_cast
<SCEVConstant
>(StepHandle
);
190 if (!Step
) return FullSet
;
192 uint32_t ExWidth
= 2 * Trip
->getValue()->getBitWidth();
// NOTE(review): the standalone `TripExt.zext(ExWidth);` form relies on
// APInt::zext widening the variable in place. In APInt versions where zext
// returns a new value instead, these statements would have no effect.
193 APInt TripExt
= Trip
->getValue()->getValue(); TripExt
.zext(ExWidth
);
194 APInt StepExt
= Step
->getValue()->getValue(); StepExt
.zext(ExWidth
);
// Overflow check: Trip*Step, computed unsigned in double width, must not
// exceed the largest value of the original width.
// NOTE(review): StepExt is zero-extended, so a negative step is treated as a
// large unsigned value and will usually fail this check.
195 if ((TripExt
* StepExt
).ugt(APInt::getLowBitsSet(ExWidth
, ExWidth
>> 1)))
// NOTE(review): original lines 196-197 are missing from this listing. The
// statement guarded by the overflow check above is presumably
// `return FullSet;`.
198 const SCEV
*EndHandle
= SE
.getAddExpr(StartHandle
,
199 SE
.getMulExpr(T
, StepHandle
));
200 const SCEVConstant
*Start
= dyn_cast
<SCEVConstant
>(StartHandle
);
201 const SCEVConstant
*End
= dyn_cast
<SCEVConstant
>(EndHandle
);
202 if (!Start
|| !End
) return FullSet
;
204 const APInt
&StartInt
= Start
->getValue()->getValue();
205 const APInt
&EndInt
= End
->getValue()->getValue();
206 const APInt
&StepInt
= Step
->getValue()->getValue();
// A negative step walks downwards, so End is the lower bound.
208 if (StepInt
.isNegative()) {
209 if (EndInt
== StartInt
+ 1) return FullSet
;
210 return ConstantRange(EndInt
, StartInt
+ 1);
212 if (StartInt
== EndInt
+ 1) return FullSet
;
213 return ConstantRange(StartInt
, EndInt
+ 1);
218 // TODO: non-affine addrec, udiv, SCEVUnknown (narrowed from elsewhere)?
// NOTE(review): original lines 214-217 and 219-222 are not visible in this
// listing. They hold closing braces and the function's final return,
// presumably `return FullSet;`.
223 void LoopVR::getAnalysisUsage(AnalysisUsage
&AU
) const {
224 AU
.addRequiredTransitive
<LoopInfo
>();
225 AU
.addRequiredTransitive
<ScalarEvolution
>();
226 AU
.setPreservesAll();
229 bool LoopVR::runOnFunction(Function
&F
) { Map
.clear(); return false; }
231 void LoopVR::print(raw_ostream
&OS
, const Module
*) const {
232 for (std::map
<Value
*, ConstantRange
*>::const_iterator I
= Map
.begin(),
233 E
= Map
.end(); I
!= E
; ++I
) {
234 OS
<< *I
->first
<< ": " << *I
->second
<< '\n';
238 void LoopVR::releaseMemory() {
// releaseMemory - walk every cached (Value, ConstantRange*) entry in Map.
// NOTE(review): the loop body and the rest of the function (original lines
// 241-246) are not visible in this listing. Presumably each ConstantRange is
// deleted and Map is cleared; confirm against the upstream file.
239 for (std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.begin(),
240 E
= Map
.end(); I
!= E
; ++I
) {
/// compute - build a range for V without consulting the cache; get() calls
/// this on a cache miss.
/// - Integer constants yield their single value.
/// - Values that are not in a loop, are loop-invariant, or whose SCEV is
///   SCEVUnknown or SCEVCouldNotCompute yield ConstantRange(BitWidth, false).
/// - Everything else is evaluated with getRange(S, L, SE).
/// NOTE(review): ConstantRange(BitWidth, false) is the *empty* set (compare
/// the "Return an empty set" case in getRange), not the full set that getRange
/// uses as its safe answer. Confirm that callers treat an empty range as
/// "unknown" rather than "no possible value".
247 ConstantRange
LoopVR::compute(Value
*V
) {
248 if (ConstantInt
*CI
= dyn_cast
<ConstantInt
>(V
))
249 return ConstantRange(CI
->getValue());
251 Instruction
*I
= dyn_cast
<Instruction
>(V
);
// NOTE(review): original line 252 is missing from this listing. It is
// presumably `if (!I)`, so that only non-instructions take this early return.
253 return ConstantRange(cast
<IntegerType
>(V
->getType())->getBitWidth(), false);
255 LoopInfo
&LI
= getAnalysis
<LoopInfo
>();
// Only values computed inside a loop, and varying with it, are analysed.
257 Loop
*L
= LI
.getLoopFor(I
->getParent());
258 if (!L
|| L
->isLoopInvariant(I
))
259 return ConstantRange(cast
<IntegerType
>(V
->getType())->getBitWidth(), false);
261 ScalarEvolution
&SE
= getAnalysis
<ScalarEvolution
>();
263 const SCEV
*S
= SE
.getSCEV(I
);
264 if (isa
<SCEVUnknown
>(S
) || isa
<SCEVCouldNotCompute
>(S
))
265 return ConstantRange(cast
<IntegerType
>(V
->getType())->getBitWidth(), false);
267 return ConstantRange(getRange(S
, L
, SE
));
/// get - return the range for V, computing it with compute() when Map has no
/// entry for V. Map holds heap-allocated ConstantRange objects keyed by Value.
270 ConstantRange
LoopVR::get(Value
*V
) {
271 std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.find(V
);
272 if (I
== Map
.end()) {
273 ConstantRange
*CR
= new ConstantRange(compute(V
));
// NOTE(review): original lines 274-280 are not visible in this listing. They
// presumably store CR in Map and return the cached range on both the miss
// and hit paths.
281 void LoopVR::remove(Value
*V
) {
// remove - look up V in the range cache (Map).
// NOTE(review): the handling of a hit (original lines 284-288) is not visible
// in this listing. Presumably it deletes the cached ConstantRange and erases
// the entry; confirm against the upstream file.
282 std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.find(V
);
283 if (I
!= Map
.end()) {
289 void LoopVR::narrow(Value
*V
, const ConstantRange
&CR
) {
290 if (CR
.isFullSet()) return;
292 std::map
<Value
*, ConstantRange
*>::iterator I
= Map
.find(V
);
294 Map
[V
] = new ConstantRange(CR
);
296 Map
[V
] = new ConstantRange(Map
[V
]->intersectWith(CR
));