1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// This file contains the IR transform to lower external or indirect calls for
11 /// the ARM64EC calling convention. Such calls must go through the runtime, so
12 /// we can translate the calling convention for calls into the emulator.
14 /// This subsumes Control Flow Guard handling.
16 //===----------------------------------------------------------------------===//
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/IR/CallingConv.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Instruction.h"
26 #include "llvm/IR/Mangler.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Object/COFF.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/TargetParser/Triple.h"
34 using namespace llvm::COFF
;
36 using OperandBundleDef
= OperandBundleDefT
<Value
*>;
38 #define DEBUG_TYPE "arm64eccalllowering"
// Statistic counter. processFunction increments it right after queuing a
// call site for lowering.
40 STATISTIC(Arm64ECCallsLowered
, "Number of Arm64EC calls lowered");
// Hidden command-line flag, enabled by default. processFunction consults it
// when a call has a known callee.
42 static cl::opt
<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
43 cl::Hidden
, cl::init(true));
// Hidden command-line flag. NOTE(review): the remainder of this declaration
// (including its default) is not visible in this excerpt.
44 static cl::opt
<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden
,
// Module pass that lowers Arm64EC calls. It declares builders for exit,
// entry and guest exit thunks, a per-call-site rewrite (lowerCall), and a set
// of helpers that compute thunk signatures while streaming name text to a
// raw_ostream.
49 class AArch64Arm64ECCallLowering
: public ModulePass
{
52 AArch64Arm64ECCallLowering() : ModulePass(ID
) {
53 initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
// Thunk builders and the per-call / per-function / per-module entry points.
56 Function
*buildExitThunk(FunctionType
*FnTy
, AttributeList Attrs
);
57 Function
*buildEntryThunk(Function
*F
);
58 void lowerCall(CallBase
*CB
);
59 Function
*buildGuestExitThunk(Function
*F
);
60 bool processFunction(Function
&F
, SetVector
<Function
*> &DirectCalledFns
);
61 bool runOnModule(Module
&M
) override
;
// Value of the "cfguard" module flag as read in runOnModule; stays 0 when
// the flag is not present.
64 int cfguard_module_flag
= 0;
// Type of the guard check function and of a pointer to it. Both are assigned
// in runOnModule.
65 FunctionType
*GuardFnType
= nullptr;
66 PointerType
*GuardFnPtrType
= nullptr;
// Guard check globals. lowerCall and buildGuestExitThunk pick GuardFnCFGlobal
// when cfguard_module_flag == 2 and "guard_nocf" is not set, and
// GuardFnGlobal otherwise.
67 Constant
*GuardFnCFGlobal
= nullptr;
68 Constant
*GuardFnGlobal
= nullptr;
// Signature helpers. Each one writes name text to Out and hands the
// Arm64-side and x64-side types back through its reference parameters.
75 void getThunkType(FunctionType
*FT
, AttributeList AttrList
,
76 Arm64ECThunkType TT
, raw_ostream
&Out
,
77 FunctionType
*&Arm64Ty
, FunctionType
*&X64Ty
);
78 void getThunkRetType(FunctionType
*FT
, AttributeList AttrList
,
79 raw_ostream
&Out
, Type
*&Arm64RetTy
, Type
*&X64RetTy
,
80 SmallVectorImpl
<Type
*> &Arm64ArgTypes
,
81 SmallVectorImpl
<Type
*> &X64ArgTypes
, bool &HasSretPtr
);
82 void getThunkArgTypes(FunctionType
*FT
, AttributeList AttrList
,
83 Arm64ECThunkType TT
, raw_ostream
&Out
,
84 SmallVectorImpl
<Type
*> &Arm64ArgTypes
,
85 SmallVectorImpl
<Type
*> &X64ArgTypes
, bool HasSretPtr
);
86 void canonicalizeThunkType(Type
*T
, Align Alignment
, bool Ret
,
87 uint64_t ArgSizeBytes
, raw_ostream
&Out
,
88 Type
*&Arm64Ty
, Type
*&X64Ty
);
91 } // end anonymous namespace
// Computes the Arm64 and x64 function types of a thunk of kind TT for the
// function type FT with attributes AttrList. The thunk name is streamed to
// Out: a prefix chosen here, followed by whatever getThunkRetType and
// getThunkArgTypes emit. Results are returned through Arm64Ty and X64Ty.
93 void AArch64Arm64ECCallLowering::getThunkType(
94 FunctionType
*FT
, AttributeList AttrList
, Arm64ECThunkType TT
,
95 raw_ostream
&Out
, FunctionType
*&Arm64Ty
, FunctionType
*&X64Ty
) {
// Entry thunks get the "$ientry_thunk$cdecl$" prefix; every other kind gets
// the "$iexit_thunk$cdecl$" prefix.
96 Out
<< (TT
== Arm64ECThunkType::Entry
? "$ientry_thunk$cdecl$"
97 : "$iexit_thunk$cdecl$");
102 SmallVector
<Type
*> Arm64ArgTypes
;
103 SmallVector
<Type
*> X64ArgTypes
;
105 // The first argument to a thunk is the called function, stored in x9.
106 // For exit thunks, we pass the called function down to the emulator;
107 // for entry/guest exit thunks, we just call the Arm64 function directly.
// Only the Arm64 push below is conditional (the if has no braces); the x64
// argument list always starts with the callee pointer.
108 if (TT
== Arm64ECThunkType::Exit
)
109 Arm64ArgTypes
.push_back(PtrTy
);
110 X64ArgTypes
.push_back(PtrTy
);
// The return type is handled first because getThunkRetType may append
// pointer arguments and reports HasSretPtr, which getThunkArgTypes consumes.
112 bool HasSretPtr
= false;
113 getThunkRetType(FT
, AttrList
, Out
, Arm64RetTy
, X64RetTy
, Arm64ArgTypes
,
114 X64ArgTypes
, HasSretPtr
);
116 getThunkArgTypes(FT
, AttrList
, TT
, Out
, Arm64ArgTypes
, X64ArgTypes
,
// Neither resulting function type is variadic (isVarArg is false).
119 Arm64Ty
= FunctionType::get(Arm64RetTy
, Arm64ArgTypes
, false);
121 X64Ty
= FunctionType::get(X64RetTy
, X64ArgTypes
, false);
// Appends the thunk's parameter types to Arm64ArgTypes and X64ArgTypes and
// streams name text for them to Out. Variadic function types get the fixed
// layout described below; otherwise each parameter from index I up to
// FT->getNumParams() is passed through canonicalizeThunkType.
// NOTE(review): the declaration and initial value of I are not visible in
// this excerpt.
124 void AArch64Arm64ECCallLowering::getThunkArgTypes(
125 FunctionType
*FT
, AttributeList AttrList
, Arm64ECThunkType TT
,
126 raw_ostream
&Out
, SmallVectorImpl
<Type
*> &Arm64ArgTypes
,
127 SmallVectorImpl
<Type
*> &X64ArgTypes
, bool HasSretPtr
) {
130 if (FT
->isVarArg()) {
131 // We treat the variadic function's thunk as a normal function
132 // with the following type on the ARM side:
133 // rettype exitthunk(
134 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
136 // that can cover all types of variadic functions.
137 // x9, as in a normal exit thunk, stores the called function.
138 // x0-x3 are the arguments stored in registers.
139 // x4 is the address of the arguments on the stack.
140 // x5 is the size of the arguments on the stack.
142 // On the x64 side, it's the same except that x5 isn't set.
144 // If both the ARM and X64 sides are sret, there are only three
145 // arguments in registers.
147 // If the X64 side is sret, but the ARM side isn't, we pass an extra value
148 // to/from the X64 side, and let SelectionDAG transform it into a memory
// One i64 per register slot; the loop starts at 1 when HasSretPtr is set, so
// only three slots are added in that case.
153 for (int i
= HasSretPtr
? 1 : 0; i
< 4; i
++) {
154 Arm64ArgTypes
.push_back(I64Ty
);
155 X64ArgTypes
.push_back(I64Ty
);
// Pointer slot (x4 in the layout above) on both sides.
159 Arm64ArgTypes
.push_back(PtrTy
);
160 X64ArgTypes
.push_back(PtrTy
);
// Size slot (x5 in the layout above): always on the Arm64 side, and on the
// x64 side for every thunk kind except Entry.
162 Arm64ArgTypes
.push_back(I64Ty
);
163 if (TT
!= Arm64ECThunkType::Entry
) {
164 // FIXME: x5 isn't actually used by the x64 side; revisit once we
165 // have proper isel for varargs
166 X64ArgTypes
.push_back(I64Ty
);
175 if (I
== FT
->getNumParams()) {
180 for (unsigned E
= FT
->getNumParams(); I
!= E
; ++I
) {
// NOTE(review): ArgSizeBytes and ParamAlign are each declared twice below;
// presumably the two pairs sit in alternative preprocessor branches whose
// directives are missing from this excerpt.
182 // FIXME: Need more information about argument size; see
183 // https://reviews.llvm.org/D132926
184 uint64_t ArgSizeBytes
= AttrList
.getParamArm64ECArgSizeBytes(I
);
185 Align ParamAlign
= AttrList
.getParamAlignment(I
).valueOrOne();
187 uint64_t ArgSizeBytes
= 0;
188 Align ParamAlign
= Align();
// Canonicalize this parameter and record its type on both sides.
190 Type
*Arm64Ty
, *X64Ty
;
191 canonicalizeThunkType(FT
->getParamType(I
), ParamAlign
,
192 /*Ret*/ false, ArgSizeBytes
, Out
, Arm64Ty
, X64Ty
);
193 Arm64ArgTypes
.push_back(Arm64Ty
);
194 X64ArgTypes
.push_back(X64Ty
);
// Determines the thunk's return types (Arm64RetTy / X64RetTy) for FT and
// streams the corresponding name text to Out. Depending on how the value is
// returned it also appends pointer arguments to Arm64ArgTypes / X64ArgTypes:
// the first parameter when it carries sret, or the x64 return type when that
// was canonicalized to a pointer. HasSretPtr is an output flag.
// NOTE(review): the statements that assign HasSretPtr and the early returns
// are not visible in this excerpt.
198 void AArch64Arm64ECCallLowering::getThunkRetType(
199 FunctionType
*FT
, AttributeList AttrList
, raw_ostream
&Out
,
200 Type
*&Arm64RetTy
, Type
*&X64RetTy
, SmallVectorImpl
<Type
*> &Arm64ArgTypes
,
201 SmallVectorImpl
<Type
*> &X64ArgTypes
, bool &HasSretPtr
) {
202 Type
*T
= FT
->getReturnType();
// NOTE(review): ArgSizeBytes is declared twice below (with different
// signedness); presumably the declarations sit in alternative preprocessor
// branches whose directives are missing from this excerpt.
204 // FIXME: Need more information about argument size; see
205 // https://reviews.llvm.org/D132926
206 uint64_t ArgSizeBytes
= AttrList
.getRetArm64ECArgSizeBytes();
208 int64_t ArgSizeBytes
= 0;
// Only the first parameter is inspected for sret / inreg.
211 if (FT
->getNumParams()) {
212 auto SRetAttr
= AttrList
.getParamAttr(0, Attribute::StructRet
);
213 auto InRegAttr
= AttrList
.getParamAttr(0, Attribute::InReg
);
214 if (SRetAttr
.isValid() && InRegAttr
.isValid()) {
215 // sret+inreg indicates a call that returns a C++ class value. This is
216 // actually equivalent to just passing and returning a void* pointer
217 // as the first argument. Translate it that way, instead of trying
218 // to model "inreg" in the thunk's calling convention, to simplify
219 // the rest of the code.
225 if (SRetAttr
.isValid()) {
226 // FIXME: Sanity-check the sret type; if it's an integer or pointer,
227 // we'll get screwy mangling/codegen.
228 // FIXME: For large struct types, mangle as an integer argument and
229 // integer return, so we can reuse more thunks, instead of "m" syntax.
230 // (MSVC mangles this case as an integer return with no argument, but
231 // that's a miscompile.)
// The sret pointee type is canonicalized as a return value (Ret is true),
// using the alignment recorded on parameter 0.
232 Type
*SRetType
= SRetAttr
.getValueAsType();
233 Align SRetAlign
= AttrList
.getParamAlignment(0).valueOrOne();
234 Type
*Arm64Ty
, *X64Ty
;
235 canonicalizeThunkType(SRetType
, SRetAlign
, /*Ret*/ true, ArgSizeBytes
,
236 Out
, Arm64Ty
, X64Ty
);
// The sret pointer itself is forwarded unchanged on both sides.
239 Arm64ArgTypes
.push_back(FT
->getParamType(0));
240 X64ArgTypes
.push_back(FT
->getParamType(0));
// Non-sret case: canonicalize the declared return type directly.
252 canonicalizeThunkType(T
, Align(), /*Ret*/ true, ArgSizeBytes
, Out
, Arm64RetTy
,
254 if (X64RetTy
->isPointerTy()) {
255 // If the X64 type is canonicalized to a pointer, that means it's
256 // passed/returned indirectly. For a return value, that means it's an
258 X64ArgTypes
.push_back(X64RetTy
);
// Maps a single argument or return type T to its Arm64-side (Arm64Ty) and
// x64-side (X64Ty) thunk types and streams its name encoding to Out. Ret is
// true for return values. In the aggregate paths, an alignment of 16 or more
// adds an "a<alignment>" suffix for non-return values.
// NOTE(review): many branch bodies (scalar float/double/integer handling, the
// indirect cases, and the returns) are missing from this excerpt.
263 void AArch64Arm64ECCallLowering::canonicalizeThunkType(
264 Type
*T
, Align Alignment
, bool Ret
, uint64_t ArgSizeBytes
, raw_ostream
&Out
,
265 Type
*&Arm64Ty
, Type
*&X64Ty
) {
// Scalar float and double are checked first; the floating-point check that
// follows them is associated with the "Only 32 and 64 bit" message below.
266 if (T
->isFloatTy()) {
273 if (T
->isDoubleTy()) {
280 if (T
->isFloatingPointTy()) {
282 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
285 auto &DL
= M
->getDataLayout();
// Look through a struct that has exactly one element and classify that
// element instead.
287 if (auto *StructTy
= dyn_cast
<StructType
>(T
))
288 if (StructTy
->getNumElements() == 1)
289 T
= StructTy
->getElementType(0);
// Arrays: the total size is element count times element size in bytes.
291 if (T
->isArrayTy()) {
292 Type
*ElementTy
= T
->getArrayElementType();
293 uint64_t ElementCnt
= T
->getArrayNumElements();
294 uint64_t ElementSizePerBytes
= DL
.getTypeSizeInBits(ElementTy
) / 8;
295 uint64_t TotalSizeBytes
= ElementCnt
* ElementSizePerBytes
;
// Arrays of float/double are written as "F" or "D" followed by the total
// size in bytes.
296 if (ElementTy
->isFloatTy() || ElementTy
->isDoubleTy()) {
297 Out
<< (ElementTy
->isFloatTy() ? "F" : "D") << TotalSizeBytes
;
298 if (Alignment
.value() >= 16 && !Ret
)
299 Out
<< "a" << Alignment
.value();
301 if (TotalSizeBytes
<= 8) {
302 // Arm64 returns small structs of float/double in float registers;
// The x64 side uses an integer with exactly TotalSizeBytes * 8 bits.
304 X64Ty
= llvm::Type::getIntNTy(M
->getContext(), TotalSizeBytes
* 8);
306 // Struct is passed directly on Arm64, but indirectly on X64.
310 } else if (T
->isFloatingPointTy()) {
311 report_fatal_error("Only 32 and 64 bit floating points are supported for "
// Integers and pointers of at most 64 bits.
316 if ((T
->isIntegerTy() || T
->isPointerTy()) && DL
.getTypeSizeInBits(T
) <= 64) {
// Remaining types: the size starts from ArgSizeBytes.
// NOTE(review): the condition guarding the fallback to the data layout size
// on the next assignment is not visible in this excerpt.
323 unsigned TypeSize
= ArgSizeBytes
;
325 TypeSize
= DL
.getTypeSizeInBits(T
) / 8;
329 if (Alignment
.value() >= 16 && !Ret
)
330 Out
<< "a" << Alignment
.value();
331 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
// Sizes of 1, 2, 4 or 8 bytes become an integer of that width on x64.
333 if (TypeSize
== 1 || TypeSize
== 2 || TypeSize
== 4 || TypeSize
== 8) {
334 // Pass directly in an integer register
335 X64Ty
= llvm::Type::getIntNTy(M
->getContext(), TypeSize
* 8);
337 // Passed directly on Arm64, but indirectly on X64.
342 // This function builds the "exit thunk", a function which translates
343 // arguments and return values when calling x64 code from AArch64 code.
// FT and Attrs describe the call being thunked. The thunk is named by
// getThunkType, has the Arm64-side type, and calls through the pointer loaded
// from __os_arm64x_dispatch_call_no_redirect using the x64-side type.
344 Function
*AArch64Arm64ECCallLowering::buildExitThunk(FunctionType
*FT
,
345 AttributeList Attrs
) {
346 SmallString
<256> ExitThunkName
;
347 llvm::raw_svector_ostream
ExitThunkStream(ExitThunkName
);
348 FunctionType
*Arm64Ty
, *X64Ty
;
349 getThunkType(FT
, Attrs
, Arm64ECThunkType::Exit
, ExitThunkStream
, Arm64Ty
,
// Look for an existing function with this thunk name.
// NOTE(review): the statement taken when one is found is not visible here;
// presumably the existing thunk is returned.
351 if (Function
*F
= M
->getFunction(ExitThunkName
))
354 Function
*F
= Function::Create(Arm64Ty
, GlobalValue::LinkOnceODRLinkage
, 0,
356 F
->setCallingConv(CallingConv::ARM64EC_Thunk_Native
);
357 F
->setSection(".wowthk$aa");
358 F
->setComdat(M
->getOrInsertComdat(ExitThunkName
));
359 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
360 F
->addFnAttr("frame-pointer", "all");
361 // Only copy sret from the first argument. For C++ instance methods, clang can
362 // stick an sret marking on a later argument, but it doesn't actually affect
363 // the ABI, so we can omit it. This avoids triggering a verifier assertion.
// The attribute lands on thunk parameter 1 because thunk parameter 0 is the
// called function.
364 if (FT
->getNumParams()) {
365 auto SRet
= Attrs
.getParamAttr(0, Attribute::StructRet
);
366 auto InReg
= Attrs
.getParamAttr(0, Attribute::InReg
);
367 if (SRet
.isValid() && !InReg
.isValid())
368 F
->addParamAttr(1, SRet
);
370 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
371 // C ABI, but might show up in other cases.
372 BasicBlock
*BB
= BasicBlock::Create(M
->getContext(), "", F
);
// The actual callee of the thunk body is loaded from this global.
375 M
->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy
);
376 Value
*Callee
= IRB
.CreateLoad(PtrTy
, CalleePtr
);
377 auto &DL
= M
->getDataLayout();
378 SmallVector
<Value
*> Args
;
380 // Pass the called function in x9.
381 Args
.push_back(F
->arg_begin());
383 Type
*RetTy
= Arm64Ty
->getReturnType();
384 if (RetTy
!= X64Ty
->getReturnType()) {
385 // If the return type is an array or struct, translate it. Values of size
386 // 8 or less go into RAX; bigger values go into memory, and we pass a
// For the large case, a stack slot is allocated and becomes Args[1].
388 if (DL
.getTypeStoreSize(RetTy
) > 8) {
389 Args
.push_back(IRB
.CreateAlloca(RetTy
));
// Walk the thunk's parameters, skipping parameter 0 (the called function).
393 for (auto &Arg
: make_range(F
->arg_begin() + 1, F
->arg_end())) {
394 // Translate arguments from AArch64 calling convention to x86 calling
397 // For simple types, we don't need to do any translation: they're
398 // represented the same way. (Implicit sign extension is not part of
399 // either convention.)
401 // The big thing we have to worry about is struct types... but
402 // fortunately AArch64 clang is pretty friendly here: the cases that need
403 // translation are always passed as a struct or array. (If we run into
404 // some cases where this doesn't work, we can teach clang to mark it up
405 // with an attribute.)
407 // The first argument is the called function, stored in x9.
// Arrays, structs and anything wider than 8 bytes are first spilled to a
// stack slot.
408 if (Arg
.getType()->isArrayTy() || Arg
.getType()->isStructTy() ||
409 DL
.getTypeStoreSize(Arg
.getType()) > 8) {
410 Value
*Mem
= IRB
.CreateAlloca(Arg
.getType());
411 IRB
.CreateStore(&Arg
, Mem
);
// Spilled values of at most 8 bytes are reloaded as an integer with the same
// store size and passed by value.
412 if (DL
.getTypeStoreSize(Arg
.getType()) <= 8) {
413 Type
*IntTy
= IRB
.getIntNTy(DL
.getTypeStoreSizeInBits(Arg
.getType()));
414 Args
.push_back(IRB
.CreateLoad(IntTy
, IRB
.CreateBitCast(Mem
, PtrTy
)));
418 Args
.push_back(&Arg
);
421 // FIXME: Transfer necessary attributes? sret? anything else?
423 Callee
= IRB
.CreateBitCast(Callee
, PtrTy
);
424 CallInst
*Call
= IRB
.CreateCall(X64Ty
, Callee
, Args
);
425 Call
->setCallingConv(CallingConv::ARM64EC_Thunk_X64
);
427 Value
*RetVal
= Call
;
428 if (RetTy
!= X64Ty
->getReturnType()) {
429 // If we rewrote the return type earlier, convert the return value to
// Large results are read back from the slot passed as Args[1]; small ones
// are stored to a fresh slot and reloaded as RetTy.
431 if (DL
.getTypeStoreSize(RetTy
) > 8) {
432 RetVal
= IRB
.CreateLoad(RetTy
, Args
[1]);
434 Value
*CastAlloca
= IRB
.CreateAlloca(RetTy
);
435 IRB
.CreateStore(Call
, IRB
.CreateBitCast(CastAlloca
, PtrTy
));
436 RetVal
= IRB
.CreateLoad(RetTy
, CastAlloca
);
440 if (RetTy
->isVoidTy())
443 IRB
.CreateRet(RetVal
);
447 // This function builds the "entry thunk", a function which translates
448 // arguments and return values when calling AArch64 code from x64 code.
// The thunk has the x64-side type computed by getThunkType for F. Its
// parameter 0 is the function to call, which is invoked with the Arm64-side
// type.
449 Function
*AArch64Arm64ECCallLowering::buildEntryThunk(Function
*F
) {
450 SmallString
<256> EntryThunkName
;
451 llvm::raw_svector_ostream
EntryThunkStream(EntryThunkName
);
452 FunctionType
*Arm64Ty
, *X64Ty
;
453 getThunkType(F
->getFunctionType(), F
->getAttributes(),
454 Arm64ECThunkType::Entry
, EntryThunkStream
, Arm64Ty
, X64Ty
);
// Look for an existing function with this thunk name (the inner F shadows
// the parameter).
// NOTE(review): the statement taken when one is found is not visible here;
// presumably the existing thunk is returned.
455 if (Function
*F
= M
->getFunction(EntryThunkName
))
458 Function
*Thunk
= Function::Create(X64Ty
, GlobalValue::LinkOnceODRLinkage
, 0,
460 Thunk
->setCallingConv(CallingConv::ARM64EC_Thunk_X64
);
461 Thunk
->setSection(".wowthk$aa");
462 Thunk
->setComdat(M
->getOrInsertComdat(EntryThunkName
));
463 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
464 Thunk
->addFnAttr("frame-pointer", "all");
466 auto &DL
= M
->getDataLayout();
467 BasicBlock
*BB
= BasicBlock::Create(M
->getContext(), "", Thunk
);
470 Type
*RetTy
= Arm64Ty
->getReturnType();
471 Type
*X64RetType
= X64Ty
->getReturnType();
// When the x64 side returns void but the Arm64 side returns a value, the
// result is written through thunk parameter 1 (see the store after the
// call). ThunkArgOffset skips parameter 0 (the callee) and, in that case,
// parameter 1 as well. For variadic functions only parameters below index 5
// are forwarded by the loop.
473 bool TransformDirectToSRet
= X64RetType
->isVoidTy() && !RetTy
->isVoidTy();
474 unsigned ThunkArgOffset
= TransformDirectToSRet
? 2 : 1;
475 unsigned PassthroughArgSize
= F
->isVarArg() ? 5 : Thunk
->arg_size();
477 // Translate arguments to call.
478 SmallVector
<Value
*> Args
;
479 for (unsigned i
= ThunkArgOffset
, e
= PassthroughArgSize
; i
!= e
; ++i
) {
480 Value
*Arg
= Thunk
->getArg(i
);
481 Type
*ArgTy
= Arm64Ty
->getParamType(i
- ThunkArgOffset
);
482 if (ArgTy
->isArrayTy() || ArgTy
->isStructTy() ||
483 DL
.getTypeStoreSize(ArgTy
) > 8) {
484 // Translate array/struct arguments to the expected type.
// At most 8 bytes: the incoming value is stored to a slot of type ArgTy and
// reloaded as ArgTy. Otherwise the incoming value is used as a pointer and
// ArgTy is loaded from it.
485 if (DL
.getTypeStoreSize(ArgTy
) <= 8) {
486 Value
*CastAlloca
= IRB
.CreateAlloca(ArgTy
);
487 IRB
.CreateStore(Arg
, IRB
.CreateBitCast(CastAlloca
, PtrTy
));
488 Arg
= IRB
.CreateLoad(ArgTy
, CastAlloca
);
490 Arg
= IRB
.CreateLoad(ArgTy
, IRB
.CreateBitCast(Arg
, PtrTy
));
497 // The 5th argument to variadic entry thunks is used to model the x64 sp
498 // which is passed to the thunk in x4, this can be passed to the callee as
499 // the variadic argument start address after skipping over the 32 byte
502 // The EC thunk CC will assign any argument marked as InReg to x4.
503 Thunk
->addParamAttr(5, Attribute::InReg
);
504 Value
*Arg
= Thunk
->getArg(5);
505 Arg
= IRB
.CreatePtrAdd(Arg
, IRB
.getInt64(0x20));
508 // Pass in a zero variadic argument size (in x5).
509 Args
.push_back(IRB
.getInt64(0));
512 // Call the function passed to the thunk.
513 Value
*Callee
= Thunk
->getArg(0);
514 Callee
= IRB
.CreateBitCast(Callee
, PtrTy
);
515 Value
*Call
= IRB
.CreateCall(Arm64Ty
, Callee
, Args
);
// Convert the result: either store it through thunk parameter 1, or, when
// the two return types differ, store it to a slot of the x64 return type and
// reload it as that type.
517 Value
*RetVal
= Call
;
518 if (TransformDirectToSRet
) {
519 IRB
.CreateStore(RetVal
, IRB
.CreateBitCast(Thunk
->getArg(1), PtrTy
));
520 } else if (X64RetType
!= RetTy
) {
521 Value
*CastAlloca
= IRB
.CreateAlloca(X64RetType
);
522 IRB
.CreateStore(Call
, IRB
.CreateBitCast(CastAlloca
, PtrTy
));
523 RetVal
= IRB
.CreateLoad(X64RetType
, CastAlloca
);
526 // Return to the caller. Note that the isel has code to translate this
527 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
528 // could emit a tail call here, but that would require a dedicated calling
529 // convention, which seems more complicated overall.)
530 if (X64RetType
->isVoidTy())
533 IRB
.CreateRet(RetVal
);
538 // Builds the "guest exit thunk", a helper to call a function which may or may
539 // not be an exit thunk. (We optimistically assume non-dllimport function
540 // declarations refer to functions defined in AArch64 code; if the linker
541 // can't prove that, we use this routine instead.)
// The thunk has F's Arm64-side type. Its body calls the guard check function
// with F and F's exit thunk, then musttail-calls the pointer that check
// returns, forwarding all of the thunk's own arguments.
542 Function
*AArch64Arm64ECCallLowering::buildGuestExitThunk(Function
*F
) {
// Only the types are needed here, so the name text goes to a null stream.
543 llvm::raw_null_ostream NullThunkName
;
544 FunctionType
*Arm64Ty
, *X64Ty
;
545 getThunkType(F
->getFunctionType(), F
->getAttributes(),
546 Arm64ECThunkType::GuestExit
, NullThunkName
, Arm64Ty
, X64Ty
);
547 auto MangledName
= getArm64ECMangledFunctionName(F
->getName().str());
548 assert(MangledName
&& "Can't guest exit to function that's already native");
// Thunk name: for names that start with '?' and contain '@', "$exit_thunk"
// is inserted before the first '@'; otherwise it is appended.
549 std::string ThunkName
= *MangledName
;
550 if (ThunkName
[0] == '?' && ThunkName
.find("@") != std::string::npos
) {
551 ThunkName
.insert(ThunkName
.find("@"), "$exit_thunk");
553 ThunkName
.append("$exit_thunk");
555 Function
*GuestExit
=
556 Function::Create(Arm64Ty
, GlobalValue::WeakODRLinkage
, 0, ThunkName
, M
);
557 GuestExit
->setComdat(M
->getOrInsertComdat(ThunkName
));
558 GuestExit
->setSection(".wowthk$aa");
// Record both the unmangled and the EC-mangled name of F on the thunk, and
// tag F itself as having a guest exit thunk.
559 GuestExit
->setMetadata(
560 "arm64ec_unmangled_name",
561 MDNode::get(M
->getContext(),
562 MDString::get(M
->getContext(), F
->getName())));
563 GuestExit
->setMetadata(
564 "arm64ec_ecmangled_name",
565 MDNode::get(M
->getContext(),
566 MDString::get(M
->getContext(), *MangledName
)));
567 F
->setMetadata("arm64ec_hasguestexit", MDNode::get(M
->getContext(), {}));
568 BasicBlock
*BB
= BasicBlock::Create(M
->getContext(), "", GuestExit
);
571 // Load the global symbol as a pointer to the check function.
573 if (cfguard_module_flag
== 2 && !F
->hasFnAttribute("guard_nocf"))
574 GuardFn
= GuardFnCFGlobal
;
576 GuardFn
= GuardFnGlobal
;
577 LoadInst
*GuardCheckLoad
= B
.CreateLoad(GuardFnPtrType
, GuardFn
);
579 // Create new call instruction. The CFGuard check should always be a call,
580 // even if the original CallBase is an Invoke or CallBr instruction.
581 Function
*Thunk
= buildExitThunk(F
->getFunctionType(), F
->getAttributes());
582 CallInst
*GuardCheck
= B
.CreateCall(
583 GuardFnType
, GuardCheckLoad
,
584 {B
.CreateBitCast(F
, B
.getPtrTy()), B
.CreateBitCast(Thunk
, B
.getPtrTy())});
586 // Ensure that the first argument is passed in the correct register.
587 GuardCheck
->setCallingConv(CallingConv::CFGuard_Check
);
// The value returned by the check is the pointer that actually gets called.
589 Value
*GuardRetVal
= B
.CreateBitCast(GuardCheck
, PtrTy
);
590 SmallVector
<Value
*> Args
;
591 for (Argument
&Arg
: GuestExit
->args())
592 Args
.push_back(&Arg
);
593 CallInst
*Call
= B
.CreateCall(Arm64Ty
, GuardRetVal
, Args
);
594 Call
->setTailCallKind(llvm::CallInst::TCK_MustTail
);
596 if (Call
->getType()->isVoidTy())
// A plain sret (without inreg) on F's first parameter is mirrored onto both
// the thunk and the forwarded call.
601 auto SRetAttr
= F
->getAttributes().getParamAttr(0, Attribute::StructRet
);
602 auto InRegAttr
= F
->getAttributes().getParamAttr(0, Attribute::InReg
);
603 if (SRetAttr
.isValid() && !InRegAttr
.isValid()) {
604 GuestExit
->addParamAttr(0, SRetAttr
);
605 Call
->addParamAttr(0, SRetAttr
);
611 // Lower an indirect call with inline code.
// The guard check function is loaded from a global and called with the
// original called operand and the exit thunk for this call's signature. The
// value it returns then replaces the called operand of CB.
612 void AArch64Arm64ECCallLowering::lowerCall(CallBase
*CB
) {
613 assert(Triple(CB
->getModule()->getTargetTriple()).isOSWindows() &&
614 "Only applicable for Windows targets");
617 Value
*CalledOperand
= CB
->getCalledOperand();
619 // If the indirect call is called within catchpad or cleanuppad,
620 // we need to copy "funclet" bundle of the call.
621 SmallVector
<llvm::OperandBundleDef
, 1> Bundles
;
622 if (auto Bundle
= CB
->getOperandBundle(LLVMContext::OB_funclet
))
623 Bundles
.push_back(OperandBundleDef(*Bundle
));
625 // Load the global symbol as a pointer to the check function.
// The CF variant is used only when the "cfguard" module flag is 2 and the
// call does not carry "guard_nocf".
627 if (cfguard_module_flag
== 2 && !CB
->hasFnAttr("guard_nocf"))
628 GuardFn
= GuardFnCFGlobal
;
630 GuardFn
= GuardFnGlobal
;
631 LoadInst
*GuardCheckLoad
= B
.CreateLoad(GuardFnPtrType
, GuardFn
);
633 // Create new call instruction. The CFGuard check should always be a call,
634 // even if the original CallBase is an Invoke or CallBr instruction.
635 Function
*Thunk
= buildExitThunk(CB
->getFunctionType(), CB
->getAttributes());
636 CallInst
*GuardCheck
=
637 B
.CreateCall(GuardFnType
, GuardCheckLoad
,
638 {B
.CreateBitCast(CalledOperand
, B
.getPtrTy()),
639 B
.CreateBitCast(Thunk
, B
.getPtrTy())},
642 // Ensure that the first argument is passed in the correct register.
643 GuardCheck
->setCallingConv(CallingConv::CFGuard_Check
);
// Redirect the original call to whatever the check returned.
645 Value
*GuardRetVal
= B
.CreateBitCast(GuardCheck
, CalledOperand
->getType());
646 CB
->setCalledOperand(GuardRetVal
);
// Pass entry point. Reads the "cfguard" module flag, sets up the common types
// and guard check globals, runs processFunction over eligible definitions,
// builds entry / exit / guest exit thunks, and records every (source,
// destination, kind) triple in the "llvm.arm64ec.symbolmap" global.
// NOTE(review): the return statements and several assignments are missing
// from this excerpt.
649 bool AArch64Arm64ECCallLowering::runOnModule(Module
&Mod
) {
655 // Check if this module has the cfguard flag and read its value.
657 mdconst::extract_or_null
<ConstantInt
>(M
->getModuleFlag("cfguard")))
658 cfguard_module_flag
= MD
->getZExtValue();
660 PtrTy
= PointerType::getUnqual(M
->getContext());
661 I64Ty
= Type::getInt64Ty(M
->getContext());
662 VoidTy
= Type::getVoidTy(M
->getContext());
// The guard check function takes two pointers and returns a pointer.
664 GuardFnType
= FunctionType::get(PtrTy
, {PtrTy
, PtrTy
}, false);
665 GuardFnPtrType
= PointerType::get(GuardFnType
, 0);
// NOTE(review): the left-hand sides of the next two statements are not
// visible; presumably they assign GuardFnCFGlobal and GuardFnGlobal.
667 M
->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType
);
669 M
->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType
);
// First sweep: process every definition that is not itself a thunk (judged
// by calling convention), collecting directly called functions.
671 SetVector
<Function
*> DirectCalledFns
;
672 for (Function
&F
: Mod
)
673 if (!F
.isDeclaration() &&
674 F
.getCallingConv() != CallingConv::ARM64EC_Thunk_Native
&&
675 F
.getCallingConv() != CallingConv::ARM64EC_Thunk_X64
)
676 processFunction(F
, DirectCalledFns
);
681 Arm64ECThunkType Kind
;
// Second sweep: an entry thunk for each non-thunk definition that is either
// non-local or has its address taken.
683 SmallVector
<ThunkInfo
> ThunkMapping
;
684 for (Function
&F
: Mod
) {
685 if (!F
.isDeclaration() && (!F
.hasLocalLinkage() || F
.hasAddressTaken()) &&
686 F
.getCallingConv() != CallingConv::ARM64EC_Thunk_Native
&&
687 F
.getCallingConv() != CallingConv::ARM64EC_Thunk_X64
) {
// NOTE(review): a line just before this setComdat call is missing from the
// excerpt and may guard it.
689 F
.setComdat(Mod
.getOrInsertComdat(F
.getName()));
690 ThunkMapping
.push_back(
691 {&F
, buildEntryThunk(&F
), Arm64ECThunkType::Entry
});
// An exit thunk for every directly called function, plus a guest exit thunk
// unless the function is dllimport. For guest exit entries the thunk is the
// first element and the function the second.
694 for (Function
*F
: DirectCalledFns
) {
695 ThunkMapping
.push_back(
696 {F
, buildExitThunk(F
->getFunctionType(), F
->getAttributes()),
697 Arm64ECThunkType::Exit
});
698 if (!F
->hasDLLImportStorageClass())
699 ThunkMapping
.push_back(
700 {buildGuestExitThunk(F
), F
, Arm64ECThunkType::GuestExit
});
// Emit the mapping as an array of anonymous {ptr, ptr, i32} structs.
703 if (!ThunkMapping
.empty()) {
704 SmallVector
<Constant
*> ThunkMappingArrayElems
;
705 for (ThunkInfo
&Thunk
: ThunkMapping
) {
706 ThunkMappingArrayElems
.push_back(ConstantStruct::getAnon(
707 {ConstantExpr::getBitCast(Thunk
.Src
, PtrTy
),
708 ConstantExpr::getBitCast(Thunk
.Dst
, PtrTy
),
709 ConstantInt::get(M
->getContext(), APInt(32, uint8_t(Thunk
.Kind
)))}));
711 Constant
*ThunkMappingArray
= ConstantArray::get(
712 llvm::ArrayType::get(ThunkMappingArrayElems
[0]->getType(),
713 ThunkMappingArrayElems
.size()),
714 ThunkMappingArrayElems
);
// The global is created non-constant with external linkage; the module takes
// ownership through the GlobalVariable constructor's Module argument.
715 new GlobalVariable(Mod
, ThunkMappingArray
->getType(), /*isConstant*/ false,
716 GlobalValue::ExternalLinkage
, ThunkMappingArray
,
717 "llvm.arm64ec.symbolmap");
// Processes one function definition: renames it to its Arm64EC mangled name
// (keeping the original name as metadata), then scans its call sites.
// Callees recorded as direct calls go into DirectCalledFns; other call sites
// are collected in IndirectCalls and handled in the final loop.
// NOTE(review): the continue / return statements and the body of the final
// loop are missing from this excerpt.
723 bool AArch64Arm64ECCallLowering::processFunction(
724 Function
&F
, SetVector
<Function
*> &DirectCalledFns
) {
725 SmallVector
<CallBase
*, 8> IndirectCalls
;
727 // For ARM64EC targets, a function definition's name is mangled differently
728 // from the normal symbol. We currently have no representation of this sort
729 // of symbol in IR, so we change the name to the mangled name, then store
730 // the unmangled name as metadata. Later passes that need the unmangled
731 // name (emitting the definition) can grab it from the metadata.
733 // FIXME: Handle functions with weak linkage?
734 if (!F
.hasLocalLinkage() || F
.hasAddressTaken()) {
735 if (std::optional
<std::string
> MangledName
=
736 getArm64ECMangledFunctionName(F
.getName().str())) {
737 F
.setMetadata("arm64ec_unmangled_name",
738 MDNode::get(M
->getContext(),
739 MDString::get(M
->getContext(), F
.getName())));
// If the function's comdat is named after the function, move every user of
// that comdat to a comdat with the mangled name. The user list is copied
// first so it is not walked while setComdat changes it.
740 if (F
.hasComdat() && F
.getComdat()->getName() == F
.getName()) {
741 Comdat
*MangledComdat
= M
->getOrInsertComdat(MangledName
.value());
742 SmallVector
<GlobalObject
*> ComdatUsers
=
743 to_vector(F
.getComdat()->getUsers());
744 for (GlobalObject
*User
: ComdatUsers
)
745 User
->setComdat(MangledComdat
);
747 F
.setName(MangledName
.value());
751 // Iterate over the instructions to find all indirect call/invoke/callbr
752 // instructions. Make a separate list of pointers to indirect
753 // call/invoke/callbr instructions because the original instructions will be
754 // deleted as the checks are added.
755 for (BasicBlock
&BB
: F
) {
756 for (Instruction
&I
: BB
) {
757 auto *CB
= dyn_cast
<CallBase
>(&I
);
758 if (!CB
|| CB
->getCallingConv() == CallingConv::ARM64EC_Thunk_X64
||
762 // We need to instrument any call that isn't directly calling an
765 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
766 // unprototyped functions in C)
// Calls with a known callee (the inner F shadows the parameter). The
// condition below tests the flag, local linkage, intrinsics and definitions.
// NOTE(review): the statement it guards is not visible in this excerpt.
767 if (Function
*F
= CB
->getCalledFunction()) {
768 if (!LowerDirectToIndirect
|| F
->hasLocalLinkage() ||
769 F
->isIntrinsic() || !F
->isDeclaration())
772 DirectCalledFns
.insert(F
);
// Queue the call site for lowering and count it.
776 IndirectCalls
.push_back(CB
);
777 ++Arm64ECCallsLowered
;
781 if (IndirectCalls
.empty())
784 for (CallBase
*CB
: IndirectCalls
)
// Pass identification and registration. The pass is registered under the
// argument string "Arm64ECCallLowering"; both trailing INITIALIZE_PASS flags
// are false.
790 char AArch64Arm64ECCallLowering::ID
= 0;
791 INITIALIZE_PASS(AArch64Arm64ECCallLowering
, "Arm64ECCallLowering",
792 "AArch64Arm64ECCallLowering", false, false)
// Factory that returns a newly allocated instance of the pass.
794 ModulePass
*llvm::createAArch64Arm64ECCallLoweringPass() {
795 return new AArch64Arm64ECCallLowering
;