[SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (#90217)
[llvm-project.git] / llvm / lib / Target / AArch64 / AArch64Arm64ECCallLowering.cpp
blob: dddc181b031444f24c038f777866d2b0bbe8199d
1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file contains the IR transform to lower external or indirect calls for
11 /// the ARM64EC calling convention. Such calls must go through the runtime, so
12 /// we can translate the calling convention for calls into the emulator.
13 ///
14 /// This subsumes Control Flow Guard handling.
15 ///
16 //===----------------------------------------------------------------------===//
18 #include "AArch64.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallString.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/IR/CallingConv.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Instruction.h"
26 #include "llvm/IR/Mangler.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Object/COFF.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/CommandLine.h"
31 #include "llvm/TargetParser/Triple.h"
33 using namespace llvm;
34 using namespace llvm::COFF;
36 using OperandBundleDef = OperandBundleDefT<Value *>;
38 #define DEBUG_TYPE "arm64eccalllowering"
40 STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
42 static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
43 cl::Hidden, cl::init(true));
44 static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
45 cl::init(true));
47 namespace {
49 class AArch64Arm64ECCallLowering : public ModulePass {
50 public:
51 static char ID;
52 AArch64Arm64ECCallLowering() : ModulePass(ID) {
53 initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
56 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
57 Function *buildEntryThunk(Function *F);
58 void lowerCall(CallBase *CB);
59 Function *buildGuestExitThunk(Function *F);
60 bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
61 bool runOnModule(Module &M) override;
63 private:
64 int cfguard_module_flag = 0;
65 FunctionType *GuardFnType = nullptr;
66 PointerType *GuardFnPtrType = nullptr;
67 Constant *GuardFnCFGlobal = nullptr;
68 Constant *GuardFnGlobal = nullptr;
69 Module *M = nullptr;
71 Type *PtrTy;
72 Type *I64Ty;
73 Type *VoidTy;
75 void getThunkType(FunctionType *FT, AttributeList AttrList,
76 Arm64ECThunkType TT, raw_ostream &Out,
77 FunctionType *&Arm64Ty, FunctionType *&X64Ty);
78 void getThunkRetType(FunctionType *FT, AttributeList AttrList,
79 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
80 SmallVectorImpl<Type *> &Arm64ArgTypes,
81 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
82 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
83 Arm64ECThunkType TT, raw_ostream &Out,
84 SmallVectorImpl<Type *> &Arm64ArgTypes,
85 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
86 void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
87 uint64_t ArgSizeBytes, raw_ostream &Out,
88 Type *&Arm64Ty, Type *&X64Ty);
91 } // end anonymous namespace
93 void AArch64Arm64ECCallLowering::getThunkType(
94 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
95 raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty) {
96 Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
97 : "$iexit_thunk$cdecl$");
99 Type *Arm64RetTy;
100 Type *X64RetTy;
102 SmallVector<Type *> Arm64ArgTypes;
103 SmallVector<Type *> X64ArgTypes;
105 // The first argument to a thunk is the called function, stored in x9.
106 // For exit thunks, we pass the called function down to the emulator;
107 // for entry/guest exit thunks, we just call the Arm64 function directly.
108 if (TT == Arm64ECThunkType::Exit)
109 Arm64ArgTypes.push_back(PtrTy);
110 X64ArgTypes.push_back(PtrTy);
112 bool HasSretPtr = false;
113 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
114 X64ArgTypes, HasSretPtr);
116 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
117 HasSretPtr);
119 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
121 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
124 void AArch64Arm64ECCallLowering::getThunkArgTypes(
125 FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
126 raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
127 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {
129 Out << "$";
130 if (FT->isVarArg()) {
131 // We treat the variadic function's thunk as a normal function
132 // with the following type on the ARM side:
133 // rettype exitthunk(
134 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
136 // that can coverage all types of variadic function.
137 // x9 is similar to normal exit thunk, store the called function.
138 // x0-x3 is the arguments be stored in registers.
139 // x4 is the address of the arguments on the stack.
140 // x5 is the size of the arguments on the stack.
142 // On the x64 side, it's the same except that x5 isn't set.
144 // If both the ARM and X64 sides are sret, there are only three
145 // arguments in registers.
147 // If the X64 side is sret, but the ARM side isn't, we pass an extra value
148 // to/from the X64 side, and let SelectionDAG transform it into a memory
149 // location.
150 Out << "varargs";
152 // x0-x3
153 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
154 Arm64ArgTypes.push_back(I64Ty);
155 X64ArgTypes.push_back(I64Ty);
158 // x4
159 Arm64ArgTypes.push_back(PtrTy);
160 X64ArgTypes.push_back(PtrTy);
161 // x5
162 Arm64ArgTypes.push_back(I64Ty);
163 if (TT != Arm64ECThunkType::Entry) {
164 // FIXME: x5 isn't actually used by the x64 side; revisit once we
165 // have proper isel for varargs
166 X64ArgTypes.push_back(I64Ty);
168 return;
171 unsigned I = 0;
172 if (HasSretPtr)
173 I++;
175 if (I == FT->getNumParams()) {
176 Out << "v";
177 return;
180 for (unsigned E = FT->getNumParams(); I != E; ++I) {
181 #if 0
182 // FIXME: Need more information about argument size; see
183 // https://reviews.llvm.org/D132926
184 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
185 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
186 #else
187 uint64_t ArgSizeBytes = 0;
188 Align ParamAlign = Align();
189 #endif
190 Type *Arm64Ty, *X64Ty;
191 canonicalizeThunkType(FT->getParamType(I), ParamAlign,
192 /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
193 Arm64ArgTypes.push_back(Arm64Ty);
194 X64ArgTypes.push_back(X64Ty);
198 void AArch64Arm64ECCallLowering::getThunkRetType(
199 FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
200 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
201 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
202 Type *T = FT->getReturnType();
203 #if 0
204 // FIXME: Need more information about argument size; see
205 // https://reviews.llvm.org/D132926
206 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
207 #else
208 int64_t ArgSizeBytes = 0;
209 #endif
210 if (T->isVoidTy()) {
211 if (FT->getNumParams()) {
212 auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
213 auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
214 if (SRetAttr.isValid() && InRegAttr.isValid()) {
215 // sret+inreg indicates a call that returns a C++ class value. This is
216 // actually equivalent to just passing and returning a void* pointer
217 // as the first argument. Translate it that way, instead of trying
218 // to model "inreg" in the thunk's calling convention, to simplify
219 // the rest of the code.
220 Out << "i8";
221 Arm64RetTy = I64Ty;
222 X64RetTy = I64Ty;
223 return;
225 if (SRetAttr.isValid()) {
226 // FIXME: Sanity-check the sret type; if it's an integer or pointer,
227 // we'll get screwy mangling/codegen.
228 // FIXME: For large struct types, mangle as an integer argument and
229 // integer return, so we can reuse more thunks, instead of "m" syntax.
230 // (MSVC mangles this case as an integer return with no argument, but
231 // that's a miscompile.)
232 Type *SRetType = SRetAttr.getValueAsType();
233 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
234 Type *Arm64Ty, *X64Ty;
235 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
236 Out, Arm64Ty, X64Ty);
237 Arm64RetTy = VoidTy;
238 X64RetTy = VoidTy;
239 Arm64ArgTypes.push_back(FT->getParamType(0));
240 X64ArgTypes.push_back(FT->getParamType(0));
241 HasSretPtr = true;
242 return;
246 Out << "v";
247 Arm64RetTy = VoidTy;
248 X64RetTy = VoidTy;
249 return;
252 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
253 X64RetTy);
254 if (X64RetTy->isPointerTy()) {
255 // If the X64 type is canonicalized to a pointer, that means it's
256 // passed/returned indirectly. For a return value, that means it's an
257 // sret pointer.
258 X64ArgTypes.push_back(X64RetTy);
259 X64RetTy = VoidTy;
263 void AArch64Arm64ECCallLowering::canonicalizeThunkType(
264 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
265 Type *&Arm64Ty, Type *&X64Ty) {
266 if (T->isFloatTy()) {
267 Out << "f";
268 Arm64Ty = T;
269 X64Ty = T;
270 return;
273 if (T->isDoubleTy()) {
274 Out << "d";
275 Arm64Ty = T;
276 X64Ty = T;
277 return;
280 if (T->isFloatingPointTy()) {
281 report_fatal_error(
282 "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
285 auto &DL = M->getDataLayout();
287 if (auto *StructTy = dyn_cast<StructType>(T))
288 if (StructTy->getNumElements() == 1)
289 T = StructTy->getElementType(0);
291 if (T->isArrayTy()) {
292 Type *ElementTy = T->getArrayElementType();
293 uint64_t ElementCnt = T->getArrayNumElements();
294 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
295 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
296 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
297 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
298 if (Alignment.value() >= 16 && !Ret)
299 Out << "a" << Alignment.value();
300 Arm64Ty = T;
301 if (TotalSizeBytes <= 8) {
302 // Arm64 returns small structs of float/double in float registers;
303 // X64 uses RAX.
304 X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
305 } else {
306 // Struct is passed directly on Arm64, but indirectly on X64.
307 X64Ty = PtrTy;
309 return;
310 } else if (T->isFloatingPointTy()) {
311 report_fatal_error("Only 32 and 64 bit floating points are supported for "
312 "ARM64EC thunks");
316 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
317 Out << "i8";
318 Arm64Ty = I64Ty;
319 X64Ty = I64Ty;
320 return;
323 unsigned TypeSize = ArgSizeBytes;
324 if (TypeSize == 0)
325 TypeSize = DL.getTypeSizeInBits(T) / 8;
326 Out << "m";
327 if (TypeSize != 4)
328 Out << TypeSize;
329 if (Alignment.value() >= 16 && !Ret)
330 Out << "a" << Alignment.value();
331 // FIXME: Try to canonicalize Arm64Ty more thoroughly?
332 Arm64Ty = T;
333 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
334 // Pass directly in an integer register
335 X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
336 } else {
337 // Passed directly on Arm64, but indirectly on X64.
338 X64Ty = PtrTy;
342 // This function builds the "exit thunk", a function which translates
343 // arguments and return values when calling x64 code from AArch64 code.
344 Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
345 AttributeList Attrs) {
// Mangle the thunk name from the function type; thunks are shared between
// calls with identical signatures, so reuse an existing one if present.
346 SmallString<256> ExitThunkName;
347 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
348 FunctionType *Arm64Ty, *X64Ty;
349 getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
350 X64Ty);
351 if (Function *F = M->getFunction(ExitThunkName))
352 return F;
// Create the thunk as a native-CC Arm64 function in the .wowthk$aa section;
// linkonce_odr plus a comdat keyed on the mangled name deduplicates
// identical thunks across translation units.
354 Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
355 ExitThunkName, M);
356 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
357 F->setSection(".wowthk$aa");
358 F->setComdat(M->getOrInsertComdat(ExitThunkName));
359 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
360 F->addFnAttr("frame-pointer", "all");
361 // Only copy sret from the first argument. For C++ instance methods, clang can
362 // stick an sret marking on a later argument, but it doesn't actually affect
363 // the ABI, so we can omit it. This avoids triggering a verifier assertion.
364 if (FT->getNumParams()) {
365 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
366 auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
367 if (SRet.isValid() && !InReg.isValid())
368 F->addParamAttr(1, SRet);
370 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
371 // C ABI, but might show up in other cases.
372 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
373 IRBuilder<> IRB(BB);
// The transfer to x64 goes through an OS-provided dispatcher whose address
// is loaded from this global.
374 Value *CalleePtr =
375 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
376 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
377 auto &DL = M->getDataLayout();
378 SmallVector<Value *> Args;
380 // Pass the called function in x9.
381 Args.push_back(F->arg_begin());
383 Type *RetTy = Arm64Ty->getReturnType();
384 if (RetTy != X64Ty->getReturnType()) {
385 // If the return type is an array or struct, translate it. Values of size
386 // 8 or less go into RAX; bigger values go into memory, and we pass a
387 // pointer.
388 if (DL.getTypeStoreSize(RetTy) > 8) {
389 Args.push_back(IRB.CreateAlloca(RetTy));
393 for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
394 // Translate arguments from AArch64 calling convention to x86 calling
395 // convention.
397 // For simple types, we don't need to do any translation: they're
398 // represented the same way. (Implicit sign extension is not part of
399 // either convention.)
401 // The big thing we have to worry about is struct types... but
402 // fortunately AArch64 clang is pretty friendly here: the cases that need
403 // translation are always passed as a struct or array. (If we run into
404 // some cases where this doesn't work, we can teach clang to mark it up
405 // with an attribute.)
407 // The first argument is the called function, stored in x9.
408 if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
409 DL.getTypeStoreSize(Arg.getType()) > 8) {
// Spill the aggregate to a stack slot: values of 8 bytes or less are
// reloaded as an integer of the same store size; larger values are
// passed by pointer to the slot.
410 Value *Mem = IRB.CreateAlloca(Arg.getType());
411 IRB.CreateStore(&Arg, Mem);
412 if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
413 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
414 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
415 } else
416 Args.push_back(Mem);
417 } else {
418 Args.push_back(&Arg);
421 // FIXME: Transfer necessary attributes? sret? anything else?
423 Callee = IRB.CreateBitCast(Callee, PtrTy);
424 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
425 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
427 Value *RetVal = Call;
428 if (RetTy != X64Ty->getReturnType()) {
429 // If we rewrote the return type earlier, convert the return value to
430 // the proper type.
431 if (DL.getTypeStoreSize(RetTy) > 8) {
// Large return: the value was written through the sret slot pushed as
// Args[1] above; read it back.
432 RetVal = IRB.CreateLoad(RetTy, Args[1]);
433 } else {
// Small return: reinterpret the integer return bits as RetTy via a
// stack slot.
434 Value *CastAlloca = IRB.CreateAlloca(RetTy);
435 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
436 RetVal = IRB.CreateLoad(RetTy, CastAlloca);
440 if (RetTy->isVoidTy())
441 IRB.CreateRetVoid();
442 else
443 IRB.CreateRet(RetVal);
444 return F;
447 // This function builds the "entry thunk", a function which translates
448 // arguments and return values when calling AArch64 code from x64 code.
449 Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
// Mangle the thunk name from F's type; identical signatures share a thunk.
450 SmallString<256> EntryThunkName;
451 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
452 FunctionType *Arm64Ty, *X64Ty;
453 getThunkType(F->getFunctionType(), F->getAttributes(),
454 Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty);
455 if (Function *F = M->getFunction(EntryThunkName))
456 return F;
// The thunk runs with the x64 thunk calling convention and lives in the
// .wowthk$aa section; the comdat deduplicates identical thunks.
458 Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
459 EntryThunkName, M);
460 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
461 Thunk->setSection(".wowthk$aa");
462 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
463 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
464 Thunk->addFnAttr("frame-pointer", "all");
466 auto &DL = M->getDataLayout();
467 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
468 IRBuilder<> IRB(BB);
470 Type *RetTy = Arm64Ty->getReturnType();
471 Type *X64RetType = X64Ty->getReturnType();
// If the x64 side returns void while the Arm64 side does not, the x64
// caller supplied an sret pointer; it occupies thunk argument slot 1, so
// the passthrough arguments start one slot later.
473 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
474 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
475 unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();
477 // Translate arguments to call.
478 SmallVector<Value *> Args;
479 for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
480 Value *Arg = Thunk->getArg(i);
481 Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
482 if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
483 DL.getTypeStoreSize(ArgTy) > 8) {
484 // Translate array/struct arguments to the expected type.
485 if (DL.getTypeStoreSize(ArgTy) <= 8) {
// Small aggregate: reinterpret the incoming bits as the aggregate type
// by round-tripping through a stack slot.
486 Value *CastAlloca = IRB.CreateAlloca(ArgTy);
487 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
488 Arg = IRB.CreateLoad(ArgTy, CastAlloca);
489 } else {
// Large aggregate: the x64 side passed a pointer; load the value.
490 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
493 Args.push_back(Arg);
496 if (F->isVarArg()) {
497 // The 5th argument to variadic entry thunks is used to model the x64 sp
498 // which is passed to the thunk in x4, this can be passed to the callee as
499 // the variadic argument start address after skipping over the 32 byte
500 // shadow store.
502 // The EC thunk CC will assign any argument marked as InReg to x4.
503 Thunk->addParamAttr(5, Attribute::InReg);
504 Value *Arg = Thunk->getArg(5);
505 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
506 Args.push_back(Arg);
508 // Pass in a zero variadic argument size (in x5).
509 Args.push_back(IRB.getInt64(0));
512 // Call the function passed to the thunk.
513 Value *Callee = Thunk->getArg(0);
514 Callee = IRB.CreateBitCast(Callee, PtrTy);
515 Value *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
517 Value *RetVal = Call;
518 if (TransformDirectToSRet) {
// Store the direct Arm64 return value through the x64 caller's sret
// pointer (thunk argument 1).
519 IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
520 } else if (X64RetType != RetTy) {
// Reinterpret the Arm64 return bits as the x64 return type via a stack
// slot.
521 Value *CastAlloca = IRB.CreateAlloca(X64RetType);
522 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
523 RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
526 // Return to the caller. Note that the isel has code to translate this
527 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
528 // could emit a tail call here, but that would require a dedicated calling
529 // convention, which seems more complicated overall.)
530 if (X64RetType->isVoidTy())
531 IRB.CreateRetVoid();
532 else
533 IRB.CreateRet(RetVal);
535 return Thunk;
538 // Builds the "guest exit thunk", a helper to call a function which may or may
539 // not be an exit thunk. (We optimistically assume non-dllimport function
540 // declarations refer to functions defined in AArch64 code; if the linker
541 // can't prove that, we use this routine instead.)
542 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
// Compute the thunk's function types; the mangled type string itself is
// not needed here, so it is streamed to a null sink.
543 llvm::raw_null_ostream NullThunkName;
544 FunctionType *Arm64Ty, *X64Ty;
545 getThunkType(F->getFunctionType(), F->getAttributes(),
546 Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
547 auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
548 assert(MangledName && "Can't guest exit to function that's already native");
549 std::string ThunkName = *MangledName;
// For "?"-prefixed (C++-mangled) names, splice "$exit_thunk" in before the
// first '@'; otherwise just append it.
550 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
551 ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
552 } else {
553 ThunkName.append("$exit_thunk");
555 Function *GuestExit =
556 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
557 GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
558 GuestExit->setSection(".wowthk$aa");
// Record both the unmangled and EC-mangled names as metadata; presumably
// read back when the definition is emitted — TODO confirm the consumer.
559 GuestExit->setMetadata(
560 "arm64ec_unmangled_name",
561 MDNode::get(M->getContext(),
562 MDString::get(M->getContext(), F->getName())));
563 GuestExit->setMetadata(
564 "arm64ec_ecmangled_name",
565 MDNode::get(M->getContext(),
566 MDString::get(M->getContext(), *MangledName)));
567 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
568 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
569 IRBuilder<> B(BB);
571 // Load the global symbol as a pointer to the check function.
// cfguard flag value 2 selects the CFG-checking checker unless this
// function opted out with "guard_nocf".
572 Value *GuardFn;
573 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
574 GuardFn = GuardFnCFGlobal;
575 else
576 GuardFn = GuardFnGlobal;
577 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
579 // Create new call instruction. The CFGuard check should always be a call,
580 // even if the original CallBase is an Invoke or CallBr instruction.
581 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
582 CallInst *GuardCheck = B.CreateCall(
583 GuardFnType, GuardCheckLoad,
584 {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
586 // Ensure that the first argument is passed in the correct register.
587 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
// Forward all of this thunk's own arguments to whatever pointer the
// checker returned, as a musttail call.
589 Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
590 SmallVector<Value *> Args;
591 for (Argument &Arg : GuestExit->args())
592 Args.push_back(&Arg);
593 CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
594 Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
596 if (Call->getType()->isVoidTy())
597 B.CreateRetVoid();
598 else
599 B.CreateRet(Call);
// Copy a plain (non-inreg) sret from F's first parameter onto both the
// thunk's parameter and the forwarded call's argument.
601 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
602 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
603 if (SRetAttr.isValid() && !InRegAttr.isValid()) {
604 GuestExit->addParamAttr(0, SRetAttr);
605 Call->addParamAttr(0, SRetAttr);
608 return GuestExit;
611 // Lower an indirect call with inline code.
612 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
613 assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
614 "Only applicable for Windows targets");
616 IRBuilder<> B(CB);
617 Value *CalledOperand = CB->getCalledOperand();
619 // If the indirect call is called within catchpad or cleanuppad,
620 // we need to copy "funclet" bundle of the call.
621 SmallVector<llvm::OperandBundleDef, 1> Bundles;
622 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
623 Bundles.push_back(OperandBundleDef(*Bundle));
625 // Load the global symbol as a pointer to the check function.
626 Value *GuardFn;
627 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
628 GuardFn = GuardFnCFGlobal;
629 else
630 GuardFn = GuardFnGlobal;
631 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
633 // Create new call instruction. The CFGuard check should always be a call,
634 // even if the original CallBase is an Invoke or CallBr instruction.
635 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
636 CallInst *GuardCheck =
637 B.CreateCall(GuardFnType, GuardCheckLoad,
638 {B.CreateBitCast(CalledOperand, B.getPtrTy()),
639 B.CreateBitCast(Thunk, B.getPtrTy())},
640 Bundles);
642 // Ensure that the first argument is passed in the correct register.
643 GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
645 Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
646 CB->setCalledOperand(GuardRetVal);
// Main entry point: caches commonly-used types and the __os_arm64x checker
// globals, rewrites each function, then emits entry/exit/guest-exit thunks
// and the thunk symbol map.
649 bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
650 if (!GenerateThunks)
651 return false;
653 M = &Mod;
655 // Check if this module has the cfguard flag and read its value.
656 if (auto *MD =
657 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
658 cfguard_module_flag = MD->getZExtValue();
660 PtrTy = PointerType::getUnqual(M->getContext());
661 I64Ty = Type::getInt64Ty(M->getContext());
662 VoidTy = Type::getVoidTy(M->getContext());
// The checker is called as: ptr check(ptr target, ptr exit_thunk).
664 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
665 GuardFnPtrType = PointerType::get(GuardFnType, 0);
666 GuardFnCFGlobal =
667 M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
668 GuardFnGlobal =
669 M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
// Process every definition, collecting the declarations it calls directly
// (these may resolve to x64 code at link time). Functions already carrying
// a thunk calling convention are themselves thunks and are skipped.
671 SetVector<Function *> DirectCalledFns;
672 for (Function &F : Mod)
673 if (!F.isDeclaration() &&
674 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
675 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
676 processFunction(F, DirectCalledFns);
678 struct ThunkInfo {
679 Constant *Src;
680 Constant *Dst;
681 Arm64ECThunkType Kind;
// Entry thunks: one per externally-reachable (or address-taken) definition.
// Each such function gets its own comdat, keyed on its name, if it doesn't
// already have one.
683 SmallVector<ThunkInfo> ThunkMapping;
684 for (Function &F : Mod) {
685 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
686 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
687 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
688 if (!F.hasComdat())
689 F.setComdat(Mod.getOrInsertComdat(F.getName()));
690 ThunkMapping.push_back(
691 {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
// Exit thunks (plus guest-exit thunks for non-dllimport targets) for every
// directly-called declaration recorded by processFunction.
694 for (Function *F : DirectCalledFns) {
695 ThunkMapping.push_back(
696 {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
697 Arm64ECThunkType::Exit});
698 if (!F->hasDLLImportStorageClass())
699 ThunkMapping.push_back(
700 {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
// Publish the accumulated (src, dst, kind) triples as the special global
// "llvm.arm64ec.symbolmap"; presumably consumed during object emission —
// TODO confirm the consumer.
703 if (!ThunkMapping.empty()) {
704 SmallVector<Constant *> ThunkMappingArrayElems;
705 for (ThunkInfo &Thunk : ThunkMapping) {
706 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
707 {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
708 ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
709 ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));
711 Constant *ThunkMappingArray = ConstantArray::get(
712 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
713 ThunkMappingArrayElems.size()),
714 ThunkMappingArrayElems);
715 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
716 GlobalValue::ExternalLinkage, ThunkMappingArray,
717 "llvm.arm64ec.symbolmap");
// Always reported as modified: types and checker globals are inserted
// above even when no calls were rewritten.
720 return true;
// Rewrites one function: renames it to its EC-mangled name (stashing the
// unmangled name in metadata), records directly-called declarations into
// DirectCalledFns, and lowers indirect calls through the call checker.
// Returns true iff any indirect call was rewritten.
723 bool AArch64Arm64ECCallLowering::processFunction(
724 Function &F, SetVector<Function *> &DirectCalledFns) {
725 SmallVector<CallBase *, 8> IndirectCalls;
727 // For ARM64EC targets, a function definition's name is mangled differently
728 // from the normal symbol. We currently have no representation of this sort
729 // of symbol in IR, so we change the name to the mangled name, then store
730 // the unmangled name as metadata. Later passes that need the unmangled
731 // name (emitting the definition) can grab it from the metadata.
733 // FIXME: Handle functions with weak linkage?
734 if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
735 if (std::optional<std::string> MangledName =
736 getArm64ECMangledFunctionName(F.getName().str())) {
737 F.setMetadata("arm64ec_unmangled_name",
738 MDNode::get(M->getContext(),
739 MDString::get(M->getContext(), F.getName())));
// If F anchors a comdat named after it, create the renamed comdat and
// repoint every member of the old comdat at it before renaming F.
740 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
741 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
742 SmallVector<GlobalObject *> ComdatUsers =
743 to_vector(F.getComdat()->getUsers());
744 for (GlobalObject *User : ComdatUsers)
745 User->setComdat(MangledComdat);
747 F.setName(MangledName.value());
751 // Iterate over the instructions to find all indirect call/invoke/callbr
752 // instructions. Make a separate list of pointers to indirect
753 // call/invoke/callbr instructions because the original instructions will be
754 // deleted as the checks are added.
755 for (BasicBlock &BB : F) {
756 for (Instruction &I : BB) {
757 auto *CB = dyn_cast<CallBase>(&I);
// Skip non-call instructions, calls already using the x64 thunk calling
// convention, and inline asm.
758 if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
759 CB->isInlineAsm())
760 continue;
762 // We need to instrument any call that isn't directly calling an
763 // ARM64 function.
765 // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
766 // unprototyped functions in C)
767 if (Function *F = CB->getCalledFunction()) {
// Direct call: only non-local, non-intrinsic declarations might turn
// out to be x64 code; remember those for thunk generation.
768 if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
769 F->isIntrinsic() || !F->isDeclaration())
770 continue;
772 DirectCalledFns.insert(F);
773 continue;
776 IndirectCalls.push_back(CB);
777 ++Arm64ECCallsLowered;
781 if (IndirectCalls.empty())
782 return false;
784 for (CallBase *CB : IndirectCalls)
785 lowerCall(CB);
787 return true;
790 char AArch64Arm64ECCallLowering::ID = 0;
791 INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
792 "AArch64Arm64ECCallLowering", false, false)
// Factory used by the AArch64 target to create this pass; presumably
// declared in AArch64.h (included above) — confirm against that header.
794 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
795 return new AArch64Arm64ECCallLowering;