1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Check IR and adjust IR for verifier friendly codes.
10 // The following are done for IR checking:
11 // - no relocation globals in PHI node.
12 // The following are done for IR adjustment:
13 // - remove __builtin_bpf_passthrough builtins. Target independent IR
14 // optimizations are done and those builtins can be removed.
15 // - remove llvm.bpf.getelementptr.and.load builtins.
16 // - remove llvm.bpf.getelementptr.and.store builtins.
17 // - for loads and stores with base addresses from non-zero address space
18 // cast base address to zero address space (support for BPF address spaces).
20 //===----------------------------------------------------------------------===//
24 #include "BPFTargetMachine.h"
25 #include "llvm/Analysis/LoopInfo.h"
26 #include "llvm/IR/DebugInfoMetadata.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicsBPF.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/User.h"
35 #include "llvm/IR/Value.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
39 #define DEBUG_TYPE "bpf-check-and-opt-ir"
45 class BPFCheckAndAdjustIR final
: public ModulePass
{
46 bool runOnModule(Module
&F
) override
;
50 BPFCheckAndAdjustIR() : ModulePass(ID
) {}
51 virtual void getAnalysisUsage(AnalysisUsage
&AU
) const override
;
54 void checkIR(Module
&M
);
55 bool adjustIR(Module
&M
);
56 bool removePassThroughBuiltin(Module
&M
);
57 bool removeCompareBuiltin(Module
&M
);
58 bool sinkMinMax(Module
&M
);
59 bool removeGEPBuiltins(Module
&M
);
60 bool insertASpaceCasts(Module
&M
);
62 } // End anonymous namespace
64 char BPFCheckAndAdjustIR::ID
= 0;
65 INITIALIZE_PASS(BPFCheckAndAdjustIR
, DEBUG_TYPE
, "BPF Check And Adjust IR",
68 ModulePass
*llvm::createBPFCheckAndAdjustIR() {
69 return new BPFCheckAndAdjustIR();
72 void BPFCheckAndAdjustIR::checkIR(Module
&M
) {
73 // Ensure relocation global won't appear in PHI node
74 // This may happen if the compiler generated the following code:
76 // g1 = @llvm.skb_buff:0:1...
80 // g2 = @llvm.skb_buff:0:2...
87 // If anything like the above "g = PHI(g1, g2)" appears, issue a fatal error.
91 PHINode
*PN
= dyn_cast
<PHINode
>(&I
);
92 if (!PN
|| PN
->use_empty())
94 for (int i
= 0, e
= PN
->getNumIncomingValues(); i
< e
; ++i
) {
95 auto *GV
= dyn_cast
<GlobalVariable
>(PN
->getIncomingValue(i
));
98 if (GV
->hasAttribute(BPFCoreSharedInfo::AmaAttr
) ||
99 GV
->hasAttribute(BPFCoreSharedInfo::TypeIdAttr
))
100 report_fatal_error("relocation global in PHI node");
105 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module
&M
) {
106 // Remove __builtin_bpf_passthrough()'s which are used to prevent
107 // certain IR optimizations. Now major IR optimizations are done,
109 bool Changed
= false;
110 CallInst
*ToBeDeleted
= nullptr;
111 for (Function
&F
: M
)
115 ToBeDeleted
->eraseFromParent();
116 ToBeDeleted
= nullptr;
119 auto *Call
= dyn_cast
<CallInst
>(&I
);
122 auto *GV
= dyn_cast
<GlobalValue
>(Call
->getCalledOperand());
125 if (!GV
->getName().starts_with("llvm.bpf.passthrough"))
128 Value
*Arg
= Call
->getArgOperand(1);
129 Call
->replaceAllUsesWith(Arg
);
135 bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module
&M
) {
136 // Remove __builtin_bpf_compare()'s which are used to prevent
137 // certain IR optimizations. Now major IR optimizations are done,
139 bool Changed
= false;
140 CallInst
*ToBeDeleted
= nullptr;
141 for (Function
&F
: M
)
145 ToBeDeleted
->eraseFromParent();
146 ToBeDeleted
= nullptr;
149 auto *Call
= dyn_cast
<CallInst
>(&I
);
152 auto *GV
= dyn_cast
<GlobalValue
>(Call
->getCalledOperand());
155 if (!GV
->getName().starts_with("llvm.bpf.compare"))
159 Value
*Arg0
= Call
->getArgOperand(0);
160 Value
*Arg1
= Call
->getArgOperand(1);
161 Value
*Arg2
= Call
->getArgOperand(2);
163 auto OpVal
= cast
<ConstantInt
>(Arg0
)->getValue().getZExtValue();
164 CmpInst::Predicate Opcode
= (CmpInst::Predicate
)OpVal
;
166 auto *ICmp
= new ICmpInst(Opcode
, Arg1
, Arg2
);
167 ICmp
->insertBefore(Call
);
169 Call
->replaceAllUsesWith(ICmp
);
175 struct MinMaxSinkInfo
{
178 ICmpInst::Predicate Predicate
;
183 MinMaxSinkInfo(ICmpInst
*ICmp
, Value
*Other
, ICmpInst::Predicate Predicate
)
184 : ICmp(ICmp
), Other(Other
), Predicate(Predicate
), MinMax(nullptr),
185 ZExt(nullptr), SExt(nullptr) {}
188 static bool sinkMinMaxInBB(BasicBlock
&BB
,
189 const std::function
<bool(Instruction
*)> &Filter
) {
191 // (fn %a %b) or (ext (fn %a %b))
193 // ext := sext | zext
194 // fn := smin | umin | smax | umax
195 auto IsMinMaxCall
= [=](Value
*V
, MinMaxSinkInfo
&Info
) {
196 if (auto *ZExt
= dyn_cast
<ZExtInst
>(V
)) {
197 V
= ZExt
->getOperand(0);
199 } else if (auto *SExt
= dyn_cast
<SExtInst
>(V
)) {
200 V
= SExt
->getOperand(0);
204 auto *Call
= dyn_cast
<CallInst
>(V
);
208 auto *Called
= dyn_cast
<Function
>(Call
->getCalledOperand());
212 switch (Called
->getIntrinsicID()) {
213 case Intrinsic::smin
:
214 case Intrinsic::umin
:
215 case Intrinsic::smax
:
216 case Intrinsic::umax
:
230 auto ZeroOrSignExtend
= [](IRBuilder
<> &Builder
, Value
*V
,
231 MinMaxSinkInfo
&Info
) {
233 if (Info
.SExt
->getType() == V
->getType())
235 return Builder
.CreateSExt(V
, Info
.SExt
->getType());
238 if (Info
.ZExt
->getType() == V
->getType())
240 return Builder
.CreateZExt(V
, Info
.ZExt
->getType());
245 bool Changed
= false;
246 SmallVector
<MinMaxSinkInfo
, 2> SinkList
;
248 // Check BB for instructions like:
249 // insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
252 // fn := min | max | (sext (min ...)) | (sext (max ...))
254 // Put such instructions to SinkList.
255 for (Instruction
&I
: BB
) {
256 ICmpInst
*ICmp
= dyn_cast
<ICmpInst
>(&I
);
259 if (!ICmp
->isRelational())
261 MinMaxSinkInfo
First(ICmp
, ICmp
->getOperand(1),
262 ICmpInst::getSwappedPredicate(ICmp
->getPredicate()));
263 MinMaxSinkInfo
Second(ICmp
, ICmp
->getOperand(0), ICmp
->getPredicate());
264 bool FirstMinMax
= IsMinMaxCall(ICmp
->getOperand(0), First
);
265 bool SecondMinMax
= IsMinMaxCall(ICmp
->getOperand(1), Second
);
266 if (!(FirstMinMax
^ SecondMinMax
))
268 SinkList
.push_back(FirstMinMax
? First
: Second
);
271 // Iterate SinkList and replace each (icmp ...) with corresponding
272 // `x < a && x < b` or similar expression.
273 for (auto &Info
: SinkList
) {
274 ICmpInst
*ICmp
= Info
.ICmp
;
275 CallInst
*MinMax
= Info
.MinMax
;
276 Intrinsic::ID IID
= MinMax
->getCalledFunction()->getIntrinsicID();
277 ICmpInst::Predicate P
= Info
.Predicate
;
278 if (ICmpInst::isSigned(P
) && IID
!= Intrinsic::smin
&&
279 IID
!= Intrinsic::smax
)
282 IRBuilder
<> Builder(ICmp
);
283 Value
*X
= Info
.Other
;
284 Value
*A
= ZeroOrSignExtend(Builder
, MinMax
->getArgOperand(0), Info
);
285 Value
*B
= ZeroOrSignExtend(Builder
, MinMax
->getArgOperand(1), Info
);
286 bool IsMin
= IID
== Intrinsic::smin
|| IID
== Intrinsic::umin
;
287 bool IsMax
= IID
== Intrinsic::smax
|| IID
== Intrinsic::umax
;
288 bool IsLess
= ICmpInst::isLE(P
) || ICmpInst::isLT(P
);
289 bool IsGreater
= ICmpInst::isGE(P
) || ICmpInst::isGT(P
);
290 assert(IsMin
^ IsMax
);
291 assert(IsLess
^ IsGreater
);
294 Value
*LHS
= Builder
.CreateICmp(P
, X
, A
);
295 Value
*RHS
= Builder
.CreateICmp(P
, X
, B
);
296 if ((IsLess
&& IsMin
) || (IsGreater
&& IsMax
))
297 // x < min(a, b) -> x < a && x < b
298 // x > max(a, b) -> x > a && x > b
299 Replacement
= Builder
.CreateLogicalAnd(LHS
, RHS
);
301 // x > min(a, b) -> x > a || x > b
302 // x < max(a, b) -> x < a || x < b
303 Replacement
= Builder
.CreateLogicalOr(LHS
, RHS
);
305 ICmp
->replaceAllUsesWith(Replacement
);
307 Instruction
*ToRemove
[] = {ICmp
, Info
.ZExt
, Info
.SExt
, MinMax
};
308 for (Instruction
*I
: ToRemove
)
309 if (I
&& I
->use_empty())
310 I
->eraseFromParent();
318 // Do the following transformation:
320 // x < min(a, b) -> x < a && x < b
321 // x > min(a, b) -> x > a || x > b
322 // x < max(a, b) -> x < a || x < b
323 // x > max(a, b) -> x > a && x > b
325 // Such patterns are introduced by LICM.cpp:hoistMinMax()
326 // transformation and might lead to BPF verification failures for
329 // To minimize "collateral" changes only do it for icmp + min/max
330 // calls when icmp is inside a loop and min/max is outside of that
333 // Verification failure happens when:
334 // - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
335 // - verifier can recognize RHS as a constant scalar in some context;
336 // - verifier can't recognize RHS1 as a constant scalar in the same
339 // The "constant scalar" is not a compile time constant, but a register
340 // that holds a scalar value known to verifier at some point in time
341 // during abstract interpretation.
344 // https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
345 bool BPFCheckAndAdjustIR::sinkMinMax(Module
&M
) {
346 bool Changed
= false;
348 for (Function
&F
: M
) {
349 if (F
.isDeclaration())
352 LoopInfo
&LI
= getAnalysis
<LoopInfoWrapperPass
>(F
).getLoopInfo();
354 for (BasicBlock
*BB
: L
->blocks()) {
355 // Filter out instructions coming from the same loop
356 Loop
*BBLoop
= LI
.getLoopFor(BB
);
357 auto OtherLoopFilter
= [&](Instruction
*I
) {
358 return LI
.getLoopFor(I
->getParent()) != BBLoop
;
360 Changed
|= sinkMinMaxInBB(*BB
, OtherLoopFilter
);
367 void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage
&AU
) const {
368 AU
.addRequired
<LoopInfoWrapperPass
>();
371 static void unrollGEPLoad(CallInst
*Call
) {
372 auto [GEP
, Load
] = BPFPreserveStaticOffsetPass::reconstructLoad(Call
);
373 GEP
->insertBefore(Call
);
374 Load
->insertBefore(Call
);
375 Call
->replaceAllUsesWith(Load
);
376 Call
->eraseFromParent();
379 static void unrollGEPStore(CallInst
*Call
) {
380 auto [GEP
, Store
] = BPFPreserveStaticOffsetPass::reconstructStore(Call
);
381 GEP
->insertBefore(Call
);
382 Store
->insertBefore(Call
);
383 Call
->eraseFromParent();
386 static bool removeGEPBuiltinsInFunc(Function
&F
) {
387 SmallVector
<CallInst
*> GEPLoads
;
388 SmallVector
<CallInst
*> GEPStores
;
390 for (auto &Insn
: BB
)
391 if (auto *Call
= dyn_cast
<CallInst
>(&Insn
))
392 if (auto *Called
= Call
->getCalledFunction())
393 switch (Called
->getIntrinsicID()) {
394 case Intrinsic::bpf_getelementptr_and_load
:
395 GEPLoads
.push_back(Call
);
397 case Intrinsic::bpf_getelementptr_and_store
:
398 GEPStores
.push_back(Call
);
402 if (GEPLoads
.empty() && GEPStores
.empty())
405 for_each(GEPLoads
, unrollGEPLoad
);
406 for_each(GEPStores
, unrollGEPStore
);
411 // Rewrites the following builtins:
412 // - llvm.bpf.getelementptr.and.load
413 // - llvm.bpf.getelementptr.and.store
414 // As (load (getelementptr ...)) or (store (getelementptr ...)).
415 bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module
&M
) {
416 bool Changed
= false;
418 Changed
= removeGEPBuiltinsInFunc(F
) || Changed
;
422 // Wrap ToWrap with cast to address space zero:
423 // - if ToWrap is a getelementptr,
424 // wrap its base pointer instead and return a copy;
425 // - if ToWrap is Instruction, insert address space cast
426 // immediately after ToWrap;
427 // - if ToWrap is not an Instruction (function parameter
428 // or a global value), insert address space cast at the
429 // beginning of the Function F;
430 // - use Cache to avoid inserting too many casts;
431 static Value
*aspaceWrapValue(DenseMap
<Value
*, Value
*> &Cache
, Function
*F
,
433 auto It
= Cache
.find(ToWrap
);
434 if (It
!= Cache
.end())
435 return It
->getSecond();
437 if (auto *GEP
= dyn_cast
<GetElementPtrInst
>(ToWrap
)) {
438 Value
*Ptr
= GEP
->getPointerOperand();
439 Value
*WrappedPtr
= aspaceWrapValue(Cache
, F
, Ptr
);
440 auto *GEPTy
= cast
<PointerType
>(GEP
->getType());
441 auto *NewGEP
= GEP
->clone();
442 NewGEP
->insertAfter(GEP
);
443 NewGEP
->mutateType(GEPTy
->getPointerTo(0));
444 NewGEP
->setOperand(GEP
->getPointerOperandIndex(), WrappedPtr
);
445 NewGEP
->setName(GEP
->getName());
446 Cache
[ToWrap
] = NewGEP
;
450 IRBuilder
IB(F
->getContext());
451 if (Instruction
*InsnPtr
= dyn_cast
<Instruction
>(ToWrap
))
452 IB
.SetInsertPoint(*InsnPtr
->getInsertionPointAfterDef());
454 IB
.SetInsertPoint(F
->getEntryBlock().getFirstInsertionPt());
455 auto *PtrTy
= cast
<PointerType
>(ToWrap
->getType());
456 auto *ASZeroPtrTy
= PtrTy
->getPointerTo(0);
457 auto *ACast
= IB
.CreateAddrSpaceCast(ToWrap
, ASZeroPtrTy
, ToWrap
->getName());
458 Cache
[ToWrap
] = ACast
;
462 // Wrap a pointer operand OpNum of instruction I
463 // with cast to address space zero
464 static void aspaceWrapOperand(DenseMap
<Value
*, Value
*> &Cache
, Instruction
*I
,
466 Value
*OldOp
= I
->getOperand(OpNum
);
467 if (OldOp
->getType()->getPointerAddressSpace() == 0)
470 Value
*NewOp
= aspaceWrapValue(Cache
, I
->getFunction(), OldOp
);
471 I
->setOperand(OpNum
, NewOp
);
472 // Check if there are any remaining users of old GEP,
473 // delete those w/o users
475 auto *OldGEP
= dyn_cast
<GetElementPtrInst
>(OldOp
);
478 if (!OldGEP
->use_empty())
480 OldOp
= OldGEP
->getPointerOperand();
481 OldGEP
->eraseFromParent();
485 // Support for BPF address spaces:
486 // - for each function in the module M, update pointer operand of
487 // each memory access instruction (load/store/cmpxchg/atomicrmw)
488 // by casting it from non-zero address space to zero address space, e.g.:
490 // (load (ptr addrspace (N) %p) ...)
491 // -> (load (addrspacecast ptr addrspace (N) %p to ptr))
493 // - assign section with name .addr_space.N for globals defined in
494 // non-zero address space N
495 bool BPFCheckAndAdjustIR::insertASpaceCasts(Module
&M
) {
496 bool Changed
= false;
497 for (Function
&F
: M
) {
498 DenseMap
<Value
*, Value
*> CastsCache
;
499 for (BasicBlock
&BB
: F
) {
500 for (Instruction
&I
: BB
) {
503 if (auto *LD
= dyn_cast
<LoadInst
>(&I
))
504 PtrOpNum
= LD
->getPointerOperandIndex();
505 else if (auto *ST
= dyn_cast
<StoreInst
>(&I
))
506 PtrOpNum
= ST
->getPointerOperandIndex();
507 else if (auto *CmpXchg
= dyn_cast
<AtomicCmpXchgInst
>(&I
))
508 PtrOpNum
= CmpXchg
->getPointerOperandIndex();
509 else if (auto *RMW
= dyn_cast
<AtomicRMWInst
>(&I
))
510 PtrOpNum
= RMW
->getPointerOperandIndex();
514 aspaceWrapOperand(CastsCache
, &I
, PtrOpNum
);
517 Changed
|= !CastsCache
.empty();
519 // Merge all globals within same address space into single
520 // .addr_space.<addr space no> section
521 for (GlobalVariable
&G
: M
.globals()) {
522 if (G
.getAddressSpace() == 0 || G
.hasSection())
524 SmallString
<16> SecName
;
525 raw_svector_ostream
OS(SecName
);
526 OS
<< ".addr_space." << G
.getAddressSpace();
527 G
.setSection(SecName
);
528 // Prevent having separate section for constants
529 G
.setConstant(false);
534 bool BPFCheckAndAdjustIR::adjustIR(Module
&M
) {
535 bool Changed
= removePassThroughBuiltin(M
);
536 Changed
= removeCompareBuiltin(M
) || Changed
;
537 Changed
= sinkMinMax(M
) || Changed
;
538 Changed
= removeGEPBuiltins(M
) || Changed
;
539 Changed
= insertASpaceCasts(M
) || Changed
;
543 bool BPFCheckAndAdjustIR::runOnModule(Module
&M
) {