1 //===- PointerTracking.cpp - Pointer Bounds Tracking ------------*- C++ -*-===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 // This file implements tracking of pointer bounds.
12 //===----------------------------------------------------------------------===//
13 #include "llvm/Analysis/ConstantFolding.h"
14 #include "llvm/Analysis/Dominators.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/PointerTracking.h"
17 #include "llvm/Analysis/ScalarEvolution.h"
18 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
19 #include "llvm/Constants.h"
20 #include "llvm/Module.h"
21 #include "llvm/Value.h"
22 #include "llvm/Support/CallSite.h"
23 #include "llvm/Support/InstIterator.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "llvm/Target/TargetData.h"
28 char PointerTracking::ID
= 0;
29 PointerTracking::PointerTracking() : FunctionPass(&ID
) {}
31 bool PointerTracking::runOnFunction(Function
&F
) {
33 assert(analyzing
.empty());
35 TD
= getAnalysisIfAvailable
<TargetData
>();
36 SE
= &getAnalysis
<ScalarEvolution
>();
37 LI
= &getAnalysis
<LoopInfo
>();
38 DT
= &getAnalysis
<DominatorTree
>();
42 void PointerTracking::getAnalysisUsage(AnalysisUsage
&AU
) const {
43 AU
.addRequiredTransitive
<DominatorTree
>();
44 AU
.addRequiredTransitive
<LoopInfo
>();
45 AU
.addRequiredTransitive
<ScalarEvolution
>();
49 bool PointerTracking::doInitialization(Module
&M
) {
50 const Type
*PTy
= PointerType::getUnqual(Type::getInt8Ty(M
.getContext()));
52 // Find calloc(i64, i64) or calloc(i32, i32).
53 callocFunc
= M
.getFunction("calloc");
55 const FunctionType
*Ty
= callocFunc
->getFunctionType();
57 std::vector
<const Type
*> args
, args2
;
58 args
.push_back(Type::getInt64Ty(M
.getContext()));
59 args
.push_back(Type::getInt64Ty(M
.getContext()));
60 args2
.push_back(Type::getInt32Ty(M
.getContext()));
61 args2
.push_back(Type::getInt32Ty(M
.getContext()));
62 const FunctionType
*Calloc1Type
=
63 FunctionType::get(PTy
, args
, false);
64 const FunctionType
*Calloc2Type
=
65 FunctionType::get(PTy
, args2
, false);
66 if (Ty
!= Calloc1Type
&& Ty
!= Calloc2Type
)
67 callocFunc
= 0; // Give up
70 // Find realloc(i8*, i64) or realloc(i8*, i32).
71 reallocFunc
= M
.getFunction("realloc");
73 const FunctionType
*Ty
= reallocFunc
->getFunctionType();
74 std::vector
<const Type
*> args
, args2
;
76 args
.push_back(Type::getInt64Ty(M
.getContext()));
78 args2
.push_back(Type::getInt32Ty(M
.getContext()));
80 const FunctionType
*Realloc1Type
=
81 FunctionType::get(PTy
, args
, false);
82 const FunctionType
*Realloc2Type
=
83 FunctionType::get(PTy
, args2
, false);
84 if (Ty
!= Realloc1Type
&& Ty
!= Realloc2Type
)
85 reallocFunc
= 0; // Give up
90 // Calculates the number of elements allocated for pointer P,
91 // the type of the element is stored in Ty.
92 const SCEV
*PointerTracking::computeAllocationCount(Value
*P
,
93 const Type
*&Ty
) const {
94 Value
*V
= P
->stripPointerCasts();
95 if (AllocationInst
*AI
= dyn_cast
<AllocationInst
>(V
)) {
96 Value
*arraySize
= AI
->getArraySize();
97 Ty
= AI
->getAllocatedType();
98 // arraySize elements of type Ty.
99 return SE
->getSCEV(arraySize
);
102 if (GlobalVariable
*GV
= dyn_cast
<GlobalVariable
>(V
)) {
103 if (GV
->hasDefinitiveInitializer()) {
104 Constant
*C
= GV
->getInitializer();
105 if (const ArrayType
*ATy
= dyn_cast
<ArrayType
>(C
->getType())) {
106 Ty
= ATy
->getElementType();
107 return SE
->getConstant(Type::getInt32Ty(P
->getContext()),
108 ATy
->getNumElements());
112 return SE
->getConstant(Type::getInt32Ty(P
->getContext()), 1);
113 //TODO: implement more tracking for globals
116 if (CallInst
*CI
= dyn_cast
<CallInst
>(V
)) {
118 Function
*F
= dyn_cast
<Function
>(CS
.getCalledValue()->stripPointerCasts());
119 const Loop
*L
= LI
->getLoopFor(CI
->getParent());
120 if (F
== callocFunc
) {
121 Ty
= Type::getInt8Ty(P
->getContext());
122 // calloc allocates arg0*arg1 bytes.
123 return SE
->getSCEVAtScope(SE
->getMulExpr(SE
->getSCEV(CS
.getArgument(0)),
124 SE
->getSCEV(CS
.getArgument(1))),
126 } else if (F
== reallocFunc
) {
127 Ty
= Type::getInt8Ty(P
->getContext());
128 // realloc allocates arg1 bytes.
129 return SE
->getSCEVAtScope(CS
.getArgument(1), L
);
133 return SE
->getCouldNotCompute();
136 // Calculates the number of elements of type Ty allocated for P.
137 const SCEV
*PointerTracking::computeAllocationCountForType(Value
*P
,
140 const Type
*elementTy
;
141 const SCEV
*Count
= computeAllocationCount(P
, elementTy
);
142 if (isa
<SCEVCouldNotCompute
>(Count
))
147 if (!TD
) // need TargetData from this point forward
148 return SE
->getCouldNotCompute();
150 uint64_t elementSize
= TD
->getTypeAllocSize(elementTy
);
151 uint64_t wantSize
= TD
->getTypeAllocSize(Ty
);
152 if (elementSize
== wantSize
)
154 if (elementSize
% wantSize
) //fractional counts not possible
155 return SE
->getCouldNotCompute();
156 return SE
->getMulExpr(Count
, SE
->getConstant(Count
->getType(),
157 elementSize
/wantSize
));
160 const SCEV
*PointerTracking::getAllocationElementCount(Value
*V
) const {
161 // We only deal with pointers.
162 const PointerType
*PTy
= cast
<PointerType
>(V
->getType());
163 return computeAllocationCountForType(V
, PTy
->getElementType());
166 const SCEV
*PointerTracking::getAllocationSizeInBytes(Value
*V
) const {
167 return computeAllocationCountForType(V
, Type::getInt8Ty(V
->getContext()));
170 // Helper for isLoopGuardedBy that checks the swapped and inverted predicate too
171 enum SolverResult
PointerTracking::isLoopGuardedBy(const Loop
*L
,
174 const SCEV
*B
) const {
175 if (SE
->isLoopGuardedByCond(L
, Pred
, A
, B
))
177 Pred
= ICmpInst::getSwappedPredicate(Pred
);
178 if (SE
->isLoopGuardedByCond(L
, Pred
, B
, A
))
181 Pred
= ICmpInst::getInversePredicate(Pred
);
182 if (SE
->isLoopGuardedByCond(L
, Pred
, B
, A
))
184 Pred
= ICmpInst::getSwappedPredicate(Pred
);
185 if (SE
->isLoopGuardedByCond(L
, Pred
, A
, B
))
190 enum SolverResult
PointerTracking::checkLimits(const SCEV
*Offset
,
194 //FIXME: merge implementation
198 void PointerTracking::getPointerOffset(Value
*Pointer
, Value
*&Base
,
200 const SCEV
*&Offset
) const
202 Pointer
= Pointer
->stripPointerCasts();
203 Base
= Pointer
->getUnderlyingObject();
204 Limit
= getAllocationSizeInBytes(Base
);
205 if (isa
<SCEVCouldNotCompute
>(Limit
)) {
211 Offset
= SE
->getMinusSCEV(SE
->getSCEV(Pointer
), SE
->getSCEV(Base
));
212 if (isa
<SCEVCouldNotCompute
>(Offset
)) {
218 void PointerTracking::print(raw_ostream
&OS
, const Module
* M
) const {
219 // Calling some PT methods may cause caches to be updated, however
220 // this should be safe for the same reason its safe for SCEV.
221 PointerTracking
&PT
= *const_cast<PointerTracking
*>(this);
222 for (inst_iterator I
=inst_begin(*FF
), E
=inst_end(*FF
); I
!= E
; ++I
) {
223 if (!isa
<PointerType
>(I
->getType()))
226 const SCEV
*Limit
, *Offset
;
227 getPointerOffset(&*I
, Base
, Limit
, Offset
);
232 const SCEV
*S
= getAllocationElementCount(Base
);
233 OS
<< *Base
<< " ==> " << *S
<< " elements, ";
234 OS
<< *Limit
<< " bytes allocated\n";
237 OS
<< &*I
<< " -- base: " << *Base
;
238 OS
<< " offset: " << *Offset
;
240 enum SolverResult res
= PT
.checkLimits(Offset
, Limit
, I
->getParent());
243 OS
<< " always safe\n";
246 OS
<< " always unsafe\n";
249 OS
<< " <<unknown>>\n";
255 static RegisterPass
<PointerTracking
> X("pointertracking",
256 "Track pointer bounds", false, true);