1 //===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "llvm/Transforms/Instrumentation/BoundsChecking.h"
10 #include "llvm/ADT/Statistic.h"
11 #include "llvm/ADT/Twine.h"
12 #include "llvm/Analysis/MemoryBuiltins.h"
13 #include "llvm/Analysis/ScalarEvolution.h"
14 #include "llvm/Analysis/TargetFolder.h"
15 #include "llvm/Analysis/TargetLibraryInfo.h"
16 #include "llvm/IR/BasicBlock.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/DataLayout.h"
19 #include "llvm/IR/Function.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/InstIterator.h"
22 #include "llvm/IR/InstrTypes.h"
23 #include "llvm/IR/Instruction.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/Value.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Support/Casting.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/ErrorHandling.h"
33 #include "llvm/Support/raw_ostream.h"
39 #define DEBUG_TYPE "bounds-checking"
41 static cl::opt
<bool> SingleTrapBB("bounds-checking-single-trap",
42 cl::desc("Use one trap block per function"));
44 STATISTIC(ChecksAdded
, "Bounds checks added");
45 STATISTIC(ChecksSkipped
, "Bounds checks skipped");
46 STATISTIC(ChecksUnable
, "Bounds checks unable to add");
48 using BuilderTy
= IRBuilder
<TargetFolder
>;
50 /// Gets the conditions under which memory accessing instructions will overflow.
52 /// \p Ptr is the pointer that will be read/written, and \p InstVal is either
53 /// the result from the load or the value being stored. It is used to determine
54 /// the size of memory block that is touched.
56 /// Returns the condition under which the access will overflow.
57 static Value
*getBoundsCheckCond(Value
*Ptr
, Value
*InstVal
,
58 const DataLayout
&DL
, TargetLibraryInfo
&TLI
,
59 ObjectSizeOffsetEvaluator
&ObjSizeEval
,
60 BuilderTy
&IRB
, ScalarEvolution
&SE
) {
61 uint64_t NeededSize
= DL
.getTypeStoreSize(InstVal
->getType());
62 LLVM_DEBUG(dbgs() << "Instrument " << *Ptr
<< " for " << Twine(NeededSize
)
65 SizeOffsetEvalType SizeOffset
= ObjSizeEval
.compute(Ptr
);
67 if (!ObjSizeEval
.bothKnown(SizeOffset
)) {
72 Value
*Size
= SizeOffset
.first
;
73 Value
*Offset
= SizeOffset
.second
;
74 ConstantInt
*SizeCI
= dyn_cast
<ConstantInt
>(Size
);
76 Type
*IntTy
= DL
.getIntPtrType(Ptr
->getType());
77 Value
*NeededSizeVal
= ConstantInt::get(IntTy
, NeededSize
);
79 auto SizeRange
= SE
.getUnsignedRange(SE
.getSCEV(Size
));
80 auto OffsetRange
= SE
.getUnsignedRange(SE
.getSCEV(Offset
));
81 auto NeededSizeRange
= SE
.getUnsignedRange(SE
.getSCEV(NeededSizeVal
));
83 // three checks are required to ensure safety:
84 // . Offset >= 0 (since the offset is given from the base ptr)
85 // . Size >= Offset (unsigned)
86 // . Size - Offset >= NeededSize (unsigned)
88 // optimization: if Size >= 0 (signed), skip 1st check
89 // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows
90 Value
*ObjSize
= IRB
.CreateSub(Size
, Offset
);
91 Value
*Cmp2
= SizeRange
.getUnsignedMin().uge(OffsetRange
.getUnsignedMax())
92 ? ConstantInt::getFalse(Ptr
->getContext())
93 : IRB
.CreateICmpULT(Size
, Offset
);
94 Value
*Cmp3
= SizeRange
.sub(OffsetRange
)
96 .uge(NeededSizeRange
.getUnsignedMax())
97 ? ConstantInt::getFalse(Ptr
->getContext())
98 : IRB
.CreateICmpULT(ObjSize
, NeededSizeVal
);
99 Value
*Or
= IRB
.CreateOr(Cmp2
, Cmp3
);
100 if ((!SizeCI
|| SizeCI
->getValue().slt(0)) &&
101 !SizeRange
.getSignedMin().isNonNegative()) {
102 Value
*Cmp1
= IRB
.CreateICmpSLT(Offset
, ConstantInt::get(IntTy
, 0));
103 Or
= IRB
.CreateOr(Cmp1
, Or
);
109 /// Adds run-time bounds checks to memory accessing instructions.
111 /// \p Or is the condition that should guard the trap.
113 /// \p GetTrapBB is a callable that returns the trap BB to use on failure.
114 template <typename GetTrapBBT
>
115 static void insertBoundsCheck(Value
*Or
, BuilderTy
&IRB
, GetTrapBBT GetTrapBB
) {
116 // check if the comparison is always false
117 ConstantInt
*C
= dyn_cast_or_null
<ConstantInt
>(Or
);
120 // If non-zero, nothing to do.
121 if (!C
->getZExtValue())
126 BasicBlock::iterator SplitI
= IRB
.GetInsertPoint();
127 BasicBlock
*OldBB
= SplitI
->getParent();
128 BasicBlock
*Cont
= OldBB
->splitBasicBlock(SplitI
);
129 OldBB
->getTerminator()->eraseFromParent();
132 // If we have a constant zero, unconditionally branch.
133 // FIXME: We should really handle this differently to bypass the splitting
135 BranchInst::Create(GetTrapBB(IRB
), OldBB
);
139 // Create the conditional branch.
140 BranchInst::Create(GetTrapBB(IRB
), Cont
, Or
, OldBB
);
143 static bool addBoundsChecking(Function
&F
, TargetLibraryInfo
&TLI
,
144 ScalarEvolution
&SE
) {
145 const DataLayout
&DL
= F
.getParent()->getDataLayout();
146 ObjectSizeOpts EvalOpts
;
147 EvalOpts
.RoundToAlign
= true;
148 ObjectSizeOffsetEvaluator
ObjSizeEval(DL
, &TLI
, F
.getContext(), EvalOpts
);
150 // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory
151 // touching instructions
152 SmallVector
<std::pair
<Instruction
*, Value
*>, 4> TrapInfo
;
153 for (Instruction
&I
: instructions(F
)) {
155 BuilderTy
IRB(I
.getParent(), BasicBlock::iterator(&I
), TargetFolder(DL
));
156 if (LoadInst
*LI
= dyn_cast
<LoadInst
>(&I
)) {
157 if (!LI
->isVolatile())
158 Or
= getBoundsCheckCond(LI
->getPointerOperand(), LI
, DL
, TLI
,
159 ObjSizeEval
, IRB
, SE
);
160 } else if (StoreInst
*SI
= dyn_cast
<StoreInst
>(&I
)) {
161 if (!SI
->isVolatile())
162 Or
= getBoundsCheckCond(SI
->getPointerOperand(), SI
->getValueOperand(),
163 DL
, TLI
, ObjSizeEval
, IRB
, SE
);
164 } else if (AtomicCmpXchgInst
*AI
= dyn_cast
<AtomicCmpXchgInst
>(&I
)) {
165 if (!AI
->isVolatile())
167 getBoundsCheckCond(AI
->getPointerOperand(), AI
->getCompareOperand(),
168 DL
, TLI
, ObjSizeEval
, IRB
, SE
);
169 } else if (AtomicRMWInst
*AI
= dyn_cast
<AtomicRMWInst
>(&I
)) {
170 if (!AI
->isVolatile())
171 Or
= getBoundsCheckCond(AI
->getPointerOperand(), AI
->getValOperand(),
172 DL
, TLI
, ObjSizeEval
, IRB
, SE
);
175 TrapInfo
.push_back(std::make_pair(&I
, Or
));
178 // Create a trapping basic block on demand using a callback. Depending on
179 // flags, this will either create a single block for the entire function or
180 // will create a fresh block every time it is called.
181 BasicBlock
*TrapBB
= nullptr;
182 auto GetTrapBB
= [&TrapBB
](BuilderTy
&IRB
) {
183 if (TrapBB
&& SingleTrapBB
)
186 Function
*Fn
= IRB
.GetInsertBlock()->getParent();
187 // FIXME: This debug location doesn't make a lot of sense in the
188 // `SingleTrapBB` case.
189 auto DebugLoc
= IRB
.getCurrentDebugLocation();
190 IRBuilder
<>::InsertPointGuard
Guard(IRB
);
191 TrapBB
= BasicBlock::Create(Fn
->getContext(), "trap", Fn
);
192 IRB
.SetInsertPoint(TrapBB
);
194 auto *F
= Intrinsic::getDeclaration(Fn
->getParent(), Intrinsic::trap
);
195 CallInst
*TrapCall
= IRB
.CreateCall(F
, {});
196 TrapCall
->setDoesNotReturn();
197 TrapCall
->setDoesNotThrow();
198 TrapCall
->setDebugLoc(DebugLoc
);
199 IRB
.CreateUnreachable();
205 for (const auto &Entry
: TrapInfo
) {
206 Instruction
*Inst
= Entry
.first
;
207 BuilderTy
IRB(Inst
->getParent(), BasicBlock::iterator(Inst
), TargetFolder(DL
));
208 insertBoundsCheck(Entry
.second
, IRB
, GetTrapBB
);
211 return !TrapInfo
.empty();
214 PreservedAnalyses
BoundsCheckingPass::run(Function
&F
, FunctionAnalysisManager
&AM
) {
215 auto &TLI
= AM
.getResult
<TargetLibraryAnalysis
>(F
);
216 auto &SE
= AM
.getResult
<ScalarEvolutionAnalysis
>(F
);
218 if (!addBoundsChecking(F
, TLI
, SE
))
219 return PreservedAnalyses::all();
221 return PreservedAnalyses::none();
225 struct BoundsCheckingLegacyPass
: public FunctionPass
{
228 BoundsCheckingLegacyPass() : FunctionPass(ID
) {
229 initializeBoundsCheckingLegacyPassPass(*PassRegistry::getPassRegistry());
232 bool runOnFunction(Function
&F
) override
{
233 auto &TLI
= getAnalysis
<TargetLibraryInfoWrapperPass
>().getTLI(F
);
234 auto &SE
= getAnalysis
<ScalarEvolutionWrapperPass
>().getSE();
235 return addBoundsChecking(F
, TLI
, SE
);
238 void getAnalysisUsage(AnalysisUsage
&AU
) const override
{
239 AU
.addRequired
<TargetLibraryInfoWrapperPass
>();
240 AU
.addRequired
<ScalarEvolutionWrapperPass
>();
245 char BoundsCheckingLegacyPass::ID
= 0;
246 INITIALIZE_PASS_BEGIN(BoundsCheckingLegacyPass
, "bounds-checking",
247 "Run-time bounds checking", false, false)
248 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass
)
249 INITIALIZE_PASS_END(BoundsCheckingLegacyPass
, "bounds-checking",
250 "Run-time bounds checking", false, false)
252 FunctionPass
*llvm::createBoundsCheckingLegacyPass() {
253 return new BoundsCheckingLegacyPass();