//===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Check IR and adjust IR for verifier friendly codes.
// The following are done for IR checking:
//   - no relocation globals in PHI node.
// The following are done for IR adjustment:
//   - remove __builtin_bpf_passthrough builtins. Target independent IR
//     optimizations are done and those builtins can be removed.
//   - remove llvm.bpf.getelementptr.and.load builtins.
//   - remove llvm.bpf.getelementptr.and.store builtins.
//   - for loads and stores with base addresses from non-zero address space
//     cast base address to zero address space (support for BPF address
//     spaces).
//
//===----------------------------------------------------------------------===//

#include "BPF.h"
#include "BPFCORE.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsBPF.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

#define DEBUG_TYPE "bpf-check-and-opt-ir"

using namespace llvm;

namespace {

class BPFCheckAndAdjustIR final : public ModulePass {
  bool runOnModule(Module &F) override;

public:
  static char ID;
  BPFCheckAndAdjustIR() : ModulePass(ID) {}
  virtual void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  void checkIR(Module &M);
  bool adjustIR(Module &M);
  bool removePassThroughBuiltin(Module &M);
  bool removeCompareBuiltin(Module &M);
  bool sinkMinMax(Module &M);
  bool removeGEPBuiltins(Module &M);
  bool insertASpaceCasts(Module &M);
};
} // End anonymous namespace

char BPFCheckAndAdjustIR::ID = 0;
INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR",
                false, false)

ModulePass *llvm::createBPFCheckAndAdjustIR() {
  return new BPFCheckAndAdjustIR();
}
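
// Note (illustrative usage, not part of this file): the factory above is the
// hook the BPF target uses to add this legacy pass to its codegen pipeline,
// roughly along the lines of
//   addPass(createBPFCheckAndAdjustIR());
// in the target's pass configuration.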

void BPFCheckAndAdjustIR::checkIR(Module &M) {
  // Ensure relocation global won't appear in PHI node
  // This may happen if the compiler generated the following code:
  //   B1:
  //      g1 = @llvm.skb_buff:0:1...
  //      ...
  //      goto B_COMMON
  //   B2:
  //      g2 = @llvm.skb_buff:0:2...
  //      ...
  //      goto B_COMMON
  //   B_COMMON:
  //      g = PHI(g1, g2)
  //      x = load g
  //      ...
  // If anything like the above "g = PHI(g1, g2)" appears, issue a fatal error.
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        PHINode *PN = dyn_cast<PHINode>(&I);
        if (!PN || PN->use_empty())
          continue;
        for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) {
          auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i));
          if (!GV)
            continue;
          if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) ||
              GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr))
            report_fatal_error("relocation global in PHI node");
        }
      }
}

bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) {
  // Remove __builtin_bpf_passthrough()'s which are used to prevent
  // certain IR optimizations. Now major IR optimizations are done,
  // remove them.
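  //
  // For example (illustrative only; the exact overloaded intrinsic name and
  // types are hypothetical):
  //   %res = call i32 @llvm.bpf.passthrough.i32.i32(i32 %seq, i32 %val)
  // All uses of %res are replaced by %val and the call itself is erased.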
  bool Changed = false;
  CallInst *ToBeDeleted = nullptr;
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        if (ToBeDeleted) {
          ToBeDeleted->eraseFromParent();
          ToBeDeleted = nullptr;
        }

        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
        if (!GV)
          continue;
        if (!GV->getName().starts_with("llvm.bpf.passthrough"))
          continue;
        Changed = true;
        Value *Arg = Call->getArgOperand(1);
        Call->replaceAllUsesWith(Arg);
        ToBeDeleted = Call;
      }
  return Changed;
}

bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
  // Remove __builtin_bpf_compare()'s which are used to prevent
  // certain IR optimizations. Now major IR optimizations are done,
  // remove them.
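  //
  // For example (illustrative only; the overloaded intrinsic name and types
  // are hypothetical), a call such as
  //   %c = call i1 @llvm.bpf.compare.i32.i32(i32 <pred>, i32 %a, i32 %b)
  // where <pred> is the numeric value of a CmpInst::Predicate, is rewritten as
  //   %c = icmp <pred> i32 %a, %b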
  bool Changed = false;
  CallInst *ToBeDeleted = nullptr;
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        if (ToBeDeleted) {
          ToBeDeleted->eraseFromParent();
          ToBeDeleted = nullptr;
        }

        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
        if (!GV)
          continue;
        if (!GV->getName().starts_with("llvm.bpf.compare"))
          continue;

        Changed = true;
        Value *Arg0 = Call->getArgOperand(0);
        Value *Arg1 = Call->getArgOperand(1);
        Value *Arg2 = Call->getArgOperand(2);

        auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
        CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;

        auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
        ICmp->insertBefore(Call);

        Call->replaceAllUsesWith(ICmp);
        ToBeDeleted = Call;
      }
  return Changed;
}

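// Bookkeeping for one icmp + min/max pair collected by sinkMinMaxInBB():
// the comparison, its non-min/max operand ("Other"), the predicate to use with
// "Other" on the left-hand side, the min/max call itself, and an optional
// zext/sext that wraps the call.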
struct MinMaxSinkInfo {
  ICmpInst *ICmp;
  Value *Other;
  ICmpInst::Predicate Predicate;
  CallInst *MinMax;
  ZExtInst *ZExt;
  SExtInst *SExt;

  MinMaxSinkInfo(ICmpInst *ICmp, Value *Other, ICmpInst::Predicate Predicate)
      : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
        ZExt(nullptr), SExt(nullptr) {}
};

static bool sinkMinMaxInBB(BasicBlock &BB,
                           const std::function<bool(Instruction *)> &Filter) {
  // Check if V is:
  //   (fn %a %b) or (ext (fn %a %b))
  // Where:
  //   ext := sext | zext
  //   fn  := smin | umin | smax | umax
  auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
      V = ZExt->getOperand(0);
      Info.ZExt = ZExt;
    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
      V = SExt->getOperand(0);
      Info.SExt = SExt;
    }

    auto *Call = dyn_cast<CallInst>(V);
    if (!Call)
      return false;

    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
    if (!Called)
      return false;

    switch (Called->getIntrinsicID()) {
    case Intrinsic::smin:
    case Intrinsic::umin:
    case Intrinsic::smax:
    case Intrinsic::umax:
      break;
    default:
      return false;
    }

    if (!Filter(Call))
      return false;

    Info.MinMax = Call;

    return true;
  };

  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
                             MinMaxSinkInfo &Info) {
    if (Info.SExt) {
      if (Info.SExt->getType() == V->getType())
        return V;
      return Builder.CreateSExt(V, Info.SExt->getType());
    }
    if (Info.ZExt) {
      if (Info.ZExt->getType() == V->getType())
        return V;
      return Builder.CreateZExt(V, Info.ZExt->getType());
    }
    return V;
  };

  bool Changed = false;
  SmallVector<MinMaxSinkInfo, 2> SinkList;

  // Check BB for instructions like:
  //   insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
  //
  // Where:
  //   fn := min | max | (sext (min ...)) | (sext (max ...))
  //
  // Put such instructions to SinkList.
  for (Instruction &I : BB) {
    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
    if (!ICmp)
      continue;
    if (!ICmp->isRelational())
      continue;
    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
    if (!(FirstMinMax ^ SecondMinMax))
      continue;
    SinkList.push_back(FirstMinMax ? First : Second);
  }

  // Iterate SinkList and replace each (icmp ...) with corresponding
  // `x < a && x < b` or similar expression.
  for (auto &Info : SinkList) {
    ICmpInst *ICmp = Info.ICmp;
    CallInst *MinMax = Info.MinMax;
    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
    ICmpInst::Predicate P = Info.Predicate;
    if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
        IID != Intrinsic::smax)
      continue;

    IRBuilder<> Builder(ICmp);
    Value *X = Info.Other;
    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
    assert(IsMin ^ IsMax);
    assert(IsLess ^ IsGreater);

    Value *Replacement;
    Value *LHS = Builder.CreateICmp(P, X, A);
    Value *RHS = Builder.CreateICmp(P, X, B);
    if ((IsLess && IsMin) || (IsGreater && IsMax))
      // x < min(a, b) -> x < a && x < b
      // x > max(a, b) -> x > a && x > b
      Replacement = Builder.CreateLogicalAnd(LHS, RHS);
    else
      // x > min(a, b) -> x > a || x > b
      // x < max(a, b) -> x < a || x < b
      Replacement = Builder.CreateLogicalOr(LHS, RHS);

    ICmp->replaceAllUsesWith(Replacement);

    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
    for (Instruction *I : ToRemove)
      if (I && I->use_empty())
        I->eraseFromParent();

    Changed = true;
  }

  return Changed;
}

// Do the following transformation:
//
//   x < min(a, b) -> x < a && x < b
//   x > min(a, b) -> x > a || x > b
//   x < max(a, b) -> x < a || x < b
//   x > max(a, b) -> x > a && x > b
//
// Such patterns are introduced by LICM.cpp:hoistMinMax()
// transformation and might lead to BPF verification failures for
// older kernels.
//
// To minimize "collateral" changes only do it for icmp + min/max
// calls when icmp is inside a loop and min/max is outside of that
// loop.
//
// Verification failure happens when:
// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
// - verifier can recognize RHS as a constant scalar in some context;
// - verifier can't recognize RHS1 as a constant scalar in the same
//   context;
//
// The "constant scalar" is not a compile time constant, but a register
// that holds a scalar value known to verifier at some point in time
// during abstract interpretation.
//
// See also:
//   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
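//
// For example (illustrative IR; names and the use of i32 are made up), with
// the min/max hoisted out of the loop:
//   %m = call i32 @llvm.smin.i32(i32 %a, i32 %b)
//   ...
//   %c = icmp slt i32 %x, %m          ; inside the loop
// the comparison inside the loop becomes:
//   %c1 = icmp slt i32 %x, %a
//   %c2 = icmp slt i32 %x, %b
//   %c  = select i1 %c1, i1 %c2, i1 false   ; logical and of %c1 and %c2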
bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
  bool Changed = false;

  for (Function &F : M) {
    if (F.isDeclaration())
      continue;

    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
    for (Loop *L : LI)
      for (BasicBlock *BB : L->blocks()) {
        // Filter out instructions coming from the same loop
        Loop *BBLoop = LI.getLoopFor(BB);
        auto OtherLoopFilter = [&](Instruction *I) {
          return LI.getLoopFor(I->getParent()) != BBLoop;
        };
        Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
      }
  }

  return Changed;
}

void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<LoopInfoWrapperPass>();
}

static void unrollGEPLoad(CallInst *Call) {
  auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call);
  GEP->insertBefore(Call);
  Load->insertBefore(Call);
  Call->replaceAllUsesWith(Load);
  Call->eraseFromParent();
}

static void unrollGEPStore(CallInst *Call) {
  auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call);
  GEP->insertBefore(Call);
  Store->insertBefore(Call);
  Call->eraseFromParent();
}

static bool removeGEPBuiltinsInFunc(Function &F) {
  SmallVector<CallInst *> GEPLoads;
  SmallVector<CallInst *> GEPStores;
  for (auto &BB : F)
    for (auto &Insn : BB)
      if (auto *Call = dyn_cast<CallInst>(&Insn))
        if (auto *Called = Call->getCalledFunction())
          switch (Called->getIntrinsicID()) {
          case Intrinsic::bpf_getelementptr_and_load:
            GEPLoads.push_back(Call);
            break;
          case Intrinsic::bpf_getelementptr_and_store:
            GEPStores.push_back(Call);
            break;
          }

  if (GEPLoads.empty() && GEPStores.empty())
    return false;

  for_each(GEPLoads, unrollGEPLoad);
  for_each(GEPStores, unrollGEPStore);

  return true;
}

// Rewrites the following builtins:
// - llvm.bpf.getelementptr.and.load
// - llvm.bpf.getelementptr.and.store
// As (load (getelementptr ...)) or (store (getelementptr ...)).
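//
// For example (illustrative only; the intrinsic's trailing arguments encode
// alignment, volatility and GEP indices and are not spelled out here):
//   (call llvm.bpf.getelementptr.and.load %base ...)
// becomes
//   %p = getelementptr ... %base ...
//   %v = load ... %p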
bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
  bool Changed = false;
  for (auto &F : M)
    Changed = removeGEPBuiltinsInFunc(F) || Changed;
  return Changed;
}

// Wrap ToWrap with cast to address space zero:
// - if ToWrap is a getelementptr,
//   wrap its base pointer instead and return a copy;
// - if ToWrap is an Instruction, insert the address space cast
//   immediately after ToWrap;
// - if ToWrap is not an Instruction (function parameter
//   or a global value), insert the address space cast at the
//   beginning of the Function F;
// - use Cache to avoid inserting too many casts;
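//
// For example (illustrative; address space 1 is an arbitrary choice), for
//   %p = getelementptr i8, ptr addrspace(1) %base, i64 8
// a zero-address-space copy is created:
//   %base.cast = addrspacecast ptr addrspace(1) %base to ptr
//   %p1 = getelementptr i8, ptr %base.cast, i64 8
// and %p1 is returned; the original %p is cleaned up later if it has no users.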
static Value *aspaceWrapValue(DenseMap<Value *, Value *> &Cache, Function *F,
                              Value *ToWrap) {
  auto It = Cache.find(ToWrap);
  if (It != Cache.end())
    return It->getSecond();

  if (auto *GEP = dyn_cast<GetElementPtrInst>(ToWrap)) {
    Value *Ptr = GEP->getPointerOperand();
    Value *WrappedPtr = aspaceWrapValue(Cache, F, Ptr);
    auto *GEPTy = cast<PointerType>(GEP->getType());
    auto *NewGEP = GEP->clone();
    NewGEP->insertAfter(GEP);
    NewGEP->mutateType(PointerType::getUnqual(GEPTy->getContext()));
    NewGEP->setOperand(GEP->getPointerOperandIndex(), WrappedPtr);
    NewGEP->setName(GEP->getName());
    Cache[ToWrap] = NewGEP;
    return NewGEP;
  }

  IRBuilder IB(F->getContext());
  if (Instruction *InsnPtr = dyn_cast<Instruction>(ToWrap))
    IB.SetInsertPoint(*InsnPtr->getInsertionPointAfterDef());
  else
    IB.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
  auto *ASZeroPtrTy = IB.getPtrTy(0);
  auto *ACast = IB.CreateAddrSpaceCast(ToWrap, ASZeroPtrTy, ToWrap->getName());
  Cache[ToWrap] = ACast;
  return ACast;
}

// Wrap a pointer operand OpNum of instruction I
// with cast to address space zero
static void aspaceWrapOperand(DenseMap<Value *, Value *> &Cache, Instruction *I,
                              unsigned OpNum) {
  Value *OldOp = I->getOperand(OpNum);
  if (OldOp->getType()->getPointerAddressSpace() == 0)
    return;

  Value *NewOp = aspaceWrapValue(Cache, I->getFunction(), OldOp);
  I->setOperand(OpNum, NewOp);
  // Check if there are any remaining users of old GEP,
  // delete those w/o users
  for (;;) {
    auto *OldGEP = dyn_cast<GetElementPtrInst>(OldOp);
    if (!OldGEP)
      break;
    if (!OldGEP->use_empty())
      break;
    OldOp = OldGEP->getPointerOperand();
    OldGEP->eraseFromParent();
  }
}

// Support for BPF address spaces:
// - for each function in the module M, update pointer operand of
//   each memory access instruction (load/store/cmpxchg/atomicrmw)
//   by casting it from non-zero address space to zero address space, e.g:
//
//   (load (ptr addrspace (N) %p) ...)
//     -> (load (addrspacecast ptr addrspace (N) %p to ptr))
//
// - assign section with name .addr_space.N for globals defined in
//   non-zero address space N
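//
// For example (illustrative): a global declared as
//   @g = addrspace(1) global i32 0
// gets the section ".addr_space.1" assigned below.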
bool BPFCheckAndAdjustIR::insertASpaceCasts(Module &M) {
  bool Changed = false;
  for (Function &F : M) {
    DenseMap<Value *, Value *> CastsCache;
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        unsigned PtrOpNum;

        if (auto *LD = dyn_cast<LoadInst>(&I))
          PtrOpNum = LD->getPointerOperandIndex();
        else if (auto *ST = dyn_cast<StoreInst>(&I))
          PtrOpNum = ST->getPointerOperandIndex();
        else if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(&I))
          PtrOpNum = CmpXchg->getPointerOperandIndex();
        else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
          PtrOpNum = RMW->getPointerOperandIndex();
        else
          continue;

        aspaceWrapOperand(CastsCache, &I, PtrOpNum);
      }
    }
    Changed |= !CastsCache.empty();
  }
  // Merge all globals within same address space into single
  // .addr_space.<addr space no> section
  for (GlobalVariable &G : M.globals()) {
    if (G.getAddressSpace() == 0 || G.hasSection())
      continue;
    SmallString<16> SecName;
    raw_svector_ostream OS(SecName);
    OS << ".addr_space." << G.getAddressSpace();
    G.setSection(SecName);
    // Prevent having separate section for constants
    G.setConstant(false);
  }
  return Changed;
}

bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
  bool Changed = removePassThroughBuiltin(M);
  Changed = removeCompareBuiltin(M) || Changed;
  Changed = sinkMinMax(M) || Changed;
  Changed = removeGEPBuiltins(M) || Changed;
  Changed = insertASpaceCasts(M) || Changed;
  return Changed;
}

bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
  checkIR(M);
  return adjustIR(M);
}