//===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Instrumentation/BoundsChecking.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <utility>

using namespace llvm;
38 #define DEBUG_TYPE "bounds-checking"
40 static cl::opt
<bool> SingleTrapBB("bounds-checking-single-trap",
41 cl::desc("Use one trap block per function"));
43 STATISTIC(ChecksAdded
, "Bounds checks added");
44 STATISTIC(ChecksSkipped
, "Bounds checks skipped");
45 STATISTIC(ChecksUnable
, "Bounds checks unable to add");
47 using BuilderTy
= IRBuilder
<TargetFolder
>;
49 /// Gets the conditions under which memory accessing instructions will overflow.
51 /// \p Ptr is the pointer that will be read/written, and \p InstVal is either
52 /// the result from the load or the value being stored. It is used to determine
53 /// the size of memory block that is touched.
55 /// Returns the condition under which the access will overflow.
56 static Value
*getBoundsCheckCond(Value
*Ptr
, Value
*InstVal
,
57 const DataLayout
&DL
, TargetLibraryInfo
&TLI
,
58 ObjectSizeOffsetEvaluator
&ObjSizeEval
,
59 BuilderTy
&IRB
, ScalarEvolution
&SE
) {
60 uint64_t NeededSize
= DL
.getTypeStoreSize(InstVal
->getType());
61 LLVM_DEBUG(dbgs() << "Instrument " << *Ptr
<< " for " << Twine(NeededSize
)
64 SizeOffsetEvalType SizeOffset
= ObjSizeEval
.compute(Ptr
);
66 if (!ObjSizeEval
.bothKnown(SizeOffset
)) {
71 Value
*Size
= SizeOffset
.first
;
72 Value
*Offset
= SizeOffset
.second
;
73 ConstantInt
*SizeCI
= dyn_cast
<ConstantInt
>(Size
);
75 Type
*IntTy
= DL
.getIntPtrType(Ptr
->getType());
76 Value
*NeededSizeVal
= ConstantInt::get(IntTy
, NeededSize
);
78 auto SizeRange
= SE
.getUnsignedRange(SE
.getSCEV(Size
));
79 auto OffsetRange
= SE
.getUnsignedRange(SE
.getSCEV(Offset
));
80 auto NeededSizeRange
= SE
.getUnsignedRange(SE
.getSCEV(NeededSizeVal
));
82 // three checks are required to ensure safety:
83 // . Offset >= 0 (since the offset is given from the base ptr)
84 // . Size >= Offset (unsigned)
85 // . Size - Offset >= NeededSize (unsigned)
87 // optimization: if Size >= 0 (signed), skip 1st check
88 // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows
89 Value
*ObjSize
= IRB
.CreateSub(Size
, Offset
);
90 Value
*Cmp2
= SizeRange
.getUnsignedMin().uge(OffsetRange
.getUnsignedMax())
91 ? ConstantInt::getFalse(Ptr
->getContext())
92 : IRB
.CreateICmpULT(Size
, Offset
);
93 Value
*Cmp3
= SizeRange
.sub(OffsetRange
)
95 .uge(NeededSizeRange
.getUnsignedMax())
96 ? ConstantInt::getFalse(Ptr
->getContext())
97 : IRB
.CreateICmpULT(ObjSize
, NeededSizeVal
);
98 Value
*Or
= IRB
.CreateOr(Cmp2
, Cmp3
);
99 if ((!SizeCI
|| SizeCI
->getValue().slt(0)) &&
100 !SizeRange
.getSignedMin().isNonNegative()) {
101 Value
*Cmp1
= IRB
.CreateICmpSLT(Offset
, ConstantInt::get(IntTy
, 0));
102 Or
= IRB
.CreateOr(Cmp1
, Or
);
108 /// Adds run-time bounds checks to memory accessing instructions.
110 /// \p Or is the condition that should guard the trap.
112 /// \p GetTrapBB is a callable that returns the trap BB to use on failure.
113 template <typename GetTrapBBT
>
114 static void insertBoundsCheck(Value
*Or
, BuilderTy IRB
, GetTrapBBT GetTrapBB
) {
115 // check if the comparison is always false
116 ConstantInt
*C
= dyn_cast_or_null
<ConstantInt
>(Or
);
119 // If non-zero, nothing to do.
120 if (!C
->getZExtValue())
125 BasicBlock::iterator SplitI
= IRB
.GetInsertPoint();
126 BasicBlock
*OldBB
= SplitI
->getParent();
127 BasicBlock
*Cont
= OldBB
->splitBasicBlock(SplitI
);
128 OldBB
->getTerminator()->eraseFromParent();
131 // If we have a constant zero, unconditionally branch.
132 // FIXME: We should really handle this differently to bypass the splitting
134 BranchInst::Create(GetTrapBB(IRB
), OldBB
);
138 // Create the conditional branch.
139 BranchInst::Create(GetTrapBB(IRB
), Cont
, Or
, OldBB
);
142 static bool addBoundsChecking(Function
&F
, TargetLibraryInfo
&TLI
,
143 ScalarEvolution
&SE
) {
144 const DataLayout
&DL
= F
.getParent()->getDataLayout();
145 ObjectSizeOpts EvalOpts
;
146 EvalOpts
.RoundToAlign
= true;
147 ObjectSizeOffsetEvaluator
ObjSizeEval(DL
, &TLI
, F
.getContext(), EvalOpts
);
149 // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory
150 // touching instructions
151 SmallVector
<std::pair
<Instruction
*, Value
*>, 4> TrapInfo
;
152 for (Instruction
&I
: instructions(F
)) {
154 BuilderTy
IRB(I
.getParent(), BasicBlock::iterator(&I
), TargetFolder(DL
));
155 if (LoadInst
*LI
= dyn_cast
<LoadInst
>(&I
)) {
156 Or
= getBoundsCheckCond(LI
->getPointerOperand(), LI
, DL
, TLI
,
157 ObjSizeEval
, IRB
, SE
);
158 } else if (StoreInst
*SI
= dyn_cast
<StoreInst
>(&I
)) {
159 Or
= getBoundsCheckCond(SI
->getPointerOperand(), SI
->getValueOperand(),
160 DL
, TLI
, ObjSizeEval
, IRB
, SE
);
161 } else if (AtomicCmpXchgInst
*AI
= dyn_cast
<AtomicCmpXchgInst
>(&I
)) {
162 Or
= getBoundsCheckCond(AI
->getPointerOperand(), AI
->getCompareOperand(),
163 DL
, TLI
, ObjSizeEval
, IRB
, SE
);
164 } else if (AtomicRMWInst
*AI
= dyn_cast
<AtomicRMWInst
>(&I
)) {
165 Or
= getBoundsCheckCond(AI
->getPointerOperand(), AI
->getValOperand(), DL
,
166 TLI
, ObjSizeEval
, IRB
, SE
);
169 TrapInfo
.push_back(std::make_pair(&I
, Or
));
172 // Create a trapping basic block on demand using a callback. Depending on
173 // flags, this will either create a single block for the entire function or
174 // will create a fresh block every time it is called.
175 BasicBlock
*TrapBB
= nullptr;
176 auto GetTrapBB
= [&TrapBB
](BuilderTy
&IRB
) {
177 if (TrapBB
&& SingleTrapBB
)
180 Function
*Fn
= IRB
.GetInsertBlock()->getParent();
181 // FIXME: This debug location doesn't make a lot of sense in the
182 // `SingleTrapBB` case.
183 auto DebugLoc
= IRB
.getCurrentDebugLocation();
184 IRBuilder
<>::InsertPointGuard
Guard(IRB
);
185 TrapBB
= BasicBlock::Create(Fn
->getContext(), "trap", Fn
);
186 IRB
.SetInsertPoint(TrapBB
);
188 auto *F
= Intrinsic::getDeclaration(Fn
->getParent(), Intrinsic::trap
);
189 CallInst
*TrapCall
= IRB
.CreateCall(F
, {});
190 TrapCall
->setDoesNotReturn();
191 TrapCall
->setDoesNotThrow();
192 TrapCall
->setDebugLoc(DebugLoc
);
193 IRB
.CreateUnreachable();
199 for (const auto &Entry
: TrapInfo
) {
200 Instruction
*Inst
= Entry
.first
;
201 BuilderTy
IRB(Inst
->getParent(), BasicBlock::iterator(Inst
), TargetFolder(DL
));
202 insertBoundsCheck(Entry
.second
, IRB
, GetTrapBB
);
205 return !TrapInfo
.empty();
208 PreservedAnalyses
BoundsCheckingPass::run(Function
&F
, FunctionAnalysisManager
&AM
) {
209 auto &TLI
= AM
.getResult
<TargetLibraryAnalysis
>(F
);
210 auto &SE
= AM
.getResult
<ScalarEvolutionAnalysis
>(F
);
212 if (!addBoundsChecking(F
, TLI
, SE
))
213 return PreservedAnalyses::all();
215 return PreservedAnalyses::none();
219 struct BoundsCheckingLegacyPass
: public FunctionPass
{
222 BoundsCheckingLegacyPass() : FunctionPass(ID
) {
223 initializeBoundsCheckingLegacyPassPass(*PassRegistry::getPassRegistry());
226 bool runOnFunction(Function
&F
) override
{
227 auto &TLI
= getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(F
);
228 auto &SE
= getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
229 return addBoundsChecking(F
, TLI
, SE
);
232 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
233 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
234 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
239 char BoundsCheckingLegacyPass::ID
= 0;
240 INITIALIZE_PASS_BEGIN(BoundsCheckingLegacyPass
, "bounds-checking",
241 "Run-time bounds checking", false, false)
242 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass
)
243 INITIALIZE_PASS_END(BoundsCheckingLegacyPass
, "bounds-checking",
244 "Run-time bounds checking", false, false)
246 FunctionPass
*llvm::createBoundsCheckingLegacyPass() {
247 return new BoundsCheckingLegacyPass();