//===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "OffloadWrapper.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Object/OffloadBinary.h"
#include "llvm/Support/Error.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace llvm;

/// Magic number that begins the section containing the CUDA fatbinary.
constexpr unsigned CudaFatMagic = 0x466243b1;
/// Magic number written into the fatbinary wrapper in place of CudaFatMagic
/// when the wrapped image is a HIP fatbinary.
constexpr unsigned HIPFatMagic = 0x48495046;

29 IntegerType
*getSizeTTy(Module
&M
) {
30 return M
.getDataLayout().getIntPtrType(M
.getContext());
33 // struct __tgt_device_image {
36 // __tgt_offload_entry *EntriesBegin;
37 // __tgt_offload_entry *EntriesEnd;
39 StructType
*getDeviceImageTy(Module
&M
) {
40 LLVMContext
&C
= M
.getContext();
41 StructType
*ImageTy
= StructType::getTypeByName(C
, "__tgt_device_image");
44 StructType::create("__tgt_device_image", PointerType::getUnqual(C
),
45 PointerType::getUnqual(C
), PointerType::getUnqual(C
),
46 PointerType::getUnqual(C
));
50 PointerType
*getDeviceImagePtrTy(Module
&M
) {
51 return PointerType::getUnqual(getDeviceImageTy(M
));
54 // struct __tgt_bin_desc {
55 // int32_t NumDeviceImages;
56 // __tgt_device_image *DeviceImages;
57 // __tgt_offload_entry *HostEntriesBegin;
58 // __tgt_offload_entry *HostEntriesEnd;
60 StructType
*getBinDescTy(Module
&M
) {
61 LLVMContext
&C
= M
.getContext();
62 StructType
*DescTy
= StructType::getTypeByName(C
, "__tgt_bin_desc");
64 DescTy
= StructType::create(
65 "__tgt_bin_desc", Type::getInt32Ty(C
), getDeviceImagePtrTy(M
),
66 PointerType::getUnqual(C
), PointerType::getUnqual(C
));
70 PointerType
*getBinDescPtrTy(Module
&M
) {
71 return PointerType::getUnqual(getBinDescTy(M
));
74 /// Creates binary descriptor for the given device images. Binary descriptor
75 /// is an object that is passed to the offloading runtime at program startup
76 /// and it describes all device images available in the executable or shared
77 /// library. It is defined as follows
79 /// __attribute__((visibility("hidden")))
80 /// extern __tgt_offload_entry *__start_omp_offloading_entries;
81 /// __attribute__((visibility("hidden")))
82 /// extern __tgt_offload_entry *__stop_omp_offloading_entries;
84 /// static const char Image0[] = { <Bufs.front() contents> };
86 /// static const char ImageN[] = { <Bufs.back() contents> };
88 /// static const __tgt_device_image Images[] = {
90 /// Image0, /*ImageStart*/
91 /// Image0 + sizeof(Image0), /*ImageEnd*/
92 /// __start_omp_offloading_entries, /*EntriesBegin*/
93 /// __stop_omp_offloading_entries /*EntriesEnd*/
97 /// ImageN, /*ImageStart*/
98 /// ImageN + sizeof(ImageN), /*ImageEnd*/
99 /// __start_omp_offloading_entries, /*EntriesBegin*/
100 /// __stop_omp_offloading_entries /*EntriesEnd*/
104 /// static const __tgt_bin_desc BinDesc = {
105 /// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/
106 /// Images, /*DeviceImages*/
107 /// __start_omp_offloading_entries, /*HostEntriesBegin*/
108 /// __stop_omp_offloading_entries /*HostEntriesEnd*/
111 /// Global variable that represents BinDesc is returned.
112 GlobalVariable
*createBinDesc(Module
&M
, ArrayRef
<ArrayRef
<char>> Bufs
) {
113 LLVMContext
&C
= M
.getContext();
114 auto [EntriesB
, EntriesE
] =
115 offloading::getOffloadEntryArray(M
, "omp_offloading_entries");
117 auto *Zero
= ConstantInt::get(getSizeTTy(M
), 0u);
118 Constant
*ZeroZero
[] = {Zero
, Zero
};
120 // Create initializer for the images array.
121 SmallVector
<Constant
*, 4u> ImagesInits
;
122 ImagesInits
.reserve(Bufs
.size());
123 for (ArrayRef
<char> Buf
: Bufs
) {
124 auto *Data
= ConstantDataArray::get(C
, Buf
);
125 auto *Image
= new GlobalVariable(M
, Data
->getType(), /*isConstant*/ true,
126 GlobalVariable::InternalLinkage
, Data
,
127 ".omp_offloading.device_image");
128 Image
->setUnnamedAddr(GlobalValue::UnnamedAddr::Global
);
129 Image
->setSection(".llvm.offloading");
130 Image
->setAlignment(Align(object::OffloadBinary::getAlignment()));
132 auto *Size
= ConstantInt::get(getSizeTTy(M
), Buf
.size());
133 Constant
*ZeroSize
[] = {Zero
, Size
};
136 ConstantExpr::getGetElementPtr(Image
->getValueType(), Image
, ZeroZero
);
138 ConstantExpr::getGetElementPtr(Image
->getValueType(), Image
, ZeroSize
);
140 ImagesInits
.push_back(ConstantStruct::get(getDeviceImageTy(M
), ImageB
,
141 ImageE
, EntriesB
, EntriesE
));
144 // Then create images array.
145 auto *ImagesData
= ConstantArray::get(
146 ArrayType::get(getDeviceImageTy(M
), ImagesInits
.size()), ImagesInits
);
149 new GlobalVariable(M
, ImagesData
->getType(), /*isConstant*/ true,
150 GlobalValue::InternalLinkage
, ImagesData
,
151 ".omp_offloading.device_images");
152 Images
->setUnnamedAddr(GlobalValue::UnnamedAddr::Global
);
155 ConstantExpr::getGetElementPtr(Images
->getValueType(), Images
, ZeroZero
);
157 // And finally create the binary descriptor object.
158 auto *DescInit
= ConstantStruct::get(
160 ConstantInt::get(Type::getInt32Ty(C
), ImagesInits
.size()), ImagesB
,
163 return new GlobalVariable(M
, DescInit
->getType(), /*isConstant*/ true,
164 GlobalValue::InternalLinkage
, DescInit
,
165 ".omp_offloading.descriptor");
168 void createRegisterFunction(Module
&M
, GlobalVariable
*BinDesc
) {
169 LLVMContext
&C
= M
.getContext();
170 auto *FuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
171 auto *Func
= Function::Create(FuncTy
, GlobalValue::InternalLinkage
,
172 ".omp_offloading.descriptor_reg", &M
);
173 Func
->setSection(".text.startup");
175 // Get __tgt_register_lib function declaration.
176 auto *RegFuncTy
= FunctionType::get(Type::getVoidTy(C
), getBinDescPtrTy(M
),
178 FunctionCallee RegFuncC
=
179 M
.getOrInsertFunction("__tgt_register_lib", RegFuncTy
);
181 // Construct function body
182 IRBuilder
<> Builder(BasicBlock::Create(C
, "entry", Func
));
183 Builder
.CreateCall(RegFuncC
, BinDesc
);
184 Builder
.CreateRetVoid();
186 // Add this function to constructors.
187 // Set priority to 1 so that __tgt_register_lib is executed AFTER
188 // __tgt_register_requires (we want to know what requirements have been
189 // asked for before we load a libomptarget plugin so that by the time the
190 // plugin is loaded it can report how many devices there are which can
191 // satisfy these requirements).
192 appendToGlobalCtors(M
, Func
, /*Priority*/ 1);
195 void createUnregisterFunction(Module
&M
, GlobalVariable
*BinDesc
) {
196 LLVMContext
&C
= M
.getContext();
197 auto *FuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
198 auto *Func
= Function::Create(FuncTy
, GlobalValue::InternalLinkage
,
199 ".omp_offloading.descriptor_unreg", &M
);
200 Func
->setSection(".text.startup");
202 // Get __tgt_unregister_lib function declaration.
203 auto *UnRegFuncTy
= FunctionType::get(Type::getVoidTy(C
), getBinDescPtrTy(M
),
205 FunctionCallee UnRegFuncC
=
206 M
.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy
);
208 // Construct function body
209 IRBuilder
<> Builder(BasicBlock::Create(C
, "entry", Func
));
210 Builder
.CreateCall(UnRegFuncC
, BinDesc
);
211 Builder
.CreateRetVoid();
213 // Add this function to global destructors.
214 // Match priority of __tgt_register_lib
215 appendToGlobalDtors(M
, Func
, /*Priority*/ 1);
218 // struct fatbin_wrapper {
224 StructType
*getFatbinWrapperTy(Module
&M
) {
225 LLVMContext
&C
= M
.getContext();
226 StructType
*FatbinTy
= StructType::getTypeByName(C
, "fatbin_wrapper");
228 FatbinTy
= StructType::create(
229 "fatbin_wrapper", Type::getInt32Ty(C
), Type::getInt32Ty(C
),
230 PointerType::getUnqual(C
), PointerType::getUnqual(C
));
234 /// Embed the image \p Image into the module \p M so it can be found by the
236 GlobalVariable
*createFatbinDesc(Module
&M
, ArrayRef
<char> Image
, bool IsHIP
) {
237 LLVMContext
&C
= M
.getContext();
238 llvm::Type
*Int8PtrTy
= PointerType::getUnqual(C
);
239 llvm::Triple Triple
= llvm::Triple(M
.getTargetTriple());
241 // Create the global string containing the fatbinary.
242 StringRef FatbinConstantSection
=
243 IsHIP
? ".hip_fatbin"
244 : (Triple
.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
245 auto *Data
= ConstantDataArray::get(C
, Image
);
246 auto *Fatbin
= new GlobalVariable(M
, Data
->getType(), /*isConstant*/ true,
247 GlobalVariable::InternalLinkage
, Data
,
249 Fatbin
->setSection(FatbinConstantSection
);
251 // Create the fatbinary wrapper
252 StringRef FatbinWrapperSection
= IsHIP
? ".hipFatBinSegment"
253 : Triple
.isMacOSX() ? "__NV_CUDA,__fatbin"
254 : ".nvFatBinSegment";
255 Constant
*FatbinWrapper
[] = {
256 ConstantInt::get(Type::getInt32Ty(C
), IsHIP
? HIPFatMagic
: CudaFatMagic
),
257 ConstantInt::get(Type::getInt32Ty(C
), 1),
258 ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin
, Int8PtrTy
),
259 ConstantPointerNull::get(PointerType::getUnqual(C
))};
261 Constant
*FatbinInitializer
=
262 ConstantStruct::get(getFatbinWrapperTy(M
), FatbinWrapper
);
265 new GlobalVariable(M
, getFatbinWrapperTy(M
),
266 /*isConstant*/ true, GlobalValue::InternalLinkage
,
267 FatbinInitializer
, ".fatbin_wrapper");
268 FatbinDesc
->setSection(FatbinWrapperSection
);
269 FatbinDesc
->setAlignment(Align(8));
274 /// Create the register globals function. We will iterate all of the offloading
275 /// entries stored at the begin / end symbols and register them according to
276 /// their type. This creates the following function in IR:
278 /// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
279 /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
281 /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
282 /// void *, void *, void *, void *, int *);
283 /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
284 /// int64_t, int32_t, int32_t);
286 /// void __cudaRegisterTest(void **fatbinHandle) {
287 /// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
288 /// entry != &__stop_cuda_offloading_entries; ++entry) {
289 /// if (!entry->size)
290 /// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
291 /// entry->name, -1, 0, 0, 0, 0, 0);
293 /// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
294 /// 0, entry->size, 0, 0);
297 Function
*createRegisterGlobalsFunction(Module
&M
, bool IsHIP
) {
298 LLVMContext
&C
= M
.getContext();
299 auto [EntriesB
, EntriesE
] = offloading::getOffloadEntryArray(
300 M
, IsHIP
? "hip_offloading_entries" : "cuda_offloading_entries");
302 // Get the __cudaRegisterFunction function declaration.
303 PointerType
*Int8PtrTy
= PointerType::get(C
, 0);
304 PointerType
*Int8PtrPtrTy
= PointerType::get(C
, 0);
305 PointerType
*Int32PtrTy
= PointerType::get(C
, 0);
306 auto *RegFuncTy
= FunctionType::get(
308 {Int8PtrPtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Type::getInt32Ty(C
),
309 Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Int32PtrTy
},
311 FunctionCallee RegFunc
= M
.getOrInsertFunction(
312 IsHIP
? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy
);
314 // Get the __cudaRegisterVar function declaration.
315 auto *RegVarTy
= FunctionType::get(
317 {Int8PtrPtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Type::getInt32Ty(C
),
318 getSizeTTy(M
), Type::getInt32Ty(C
), Type::getInt32Ty(C
)},
320 FunctionCallee RegVar
= M
.getOrInsertFunction(
321 IsHIP
? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy
);
323 // Get the __cudaRegisterSurface function declaration.
325 FunctionType::get(Type::getVoidTy(C
),
326 {Int8PtrPtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
,
327 Type::getInt32Ty(C
), Type::getInt32Ty(C
)},
329 FunctionCallee RegSurface
= M
.getOrInsertFunction(
330 IsHIP
? "__hipRegisterSurface" : "__cudaRegisterSurface", RegSurfaceTy
);
332 // Get the __cudaRegisterTexture function declaration.
333 auto *RegTextureTy
= FunctionType::get(
335 {Int8PtrPtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Type::getInt32Ty(C
),
336 Type::getInt32Ty(C
), Type::getInt32Ty(C
)},
338 FunctionCallee RegTexture
= M
.getOrInsertFunction(
339 IsHIP
? "__hipRegisterTexture" : "__cudaRegisterTexture", RegTextureTy
);
341 auto *RegGlobalsTy
= FunctionType::get(Type::getVoidTy(C
), Int8PtrPtrTy
,
344 Function::Create(RegGlobalsTy
, GlobalValue::InternalLinkage
,
345 IsHIP
? ".hip.globals_reg" : ".cuda.globals_reg", &M
);
346 RegGlobalsFn
->setSection(".text.startup");
348 // Create the loop to register all the entries.
349 IRBuilder
<> Builder(BasicBlock::Create(C
, "entry", RegGlobalsFn
));
350 auto *EntryBB
= BasicBlock::Create(C
, "while.entry", RegGlobalsFn
);
351 auto *IfThenBB
= BasicBlock::Create(C
, "if.then", RegGlobalsFn
);
352 auto *IfElseBB
= BasicBlock::Create(C
, "if.else", RegGlobalsFn
);
353 auto *SwGlobalBB
= BasicBlock::Create(C
, "sw.global", RegGlobalsFn
);
354 auto *SwManagedBB
= BasicBlock::Create(C
, "sw.managed", RegGlobalsFn
);
355 auto *SwSurfaceBB
= BasicBlock::Create(C
, "sw.surface", RegGlobalsFn
);
356 auto *SwTextureBB
= BasicBlock::Create(C
, "sw.texture", RegGlobalsFn
);
357 auto *IfEndBB
= BasicBlock::Create(C
, "if.end", RegGlobalsFn
);
358 auto *ExitBB
= BasicBlock::Create(C
, "while.end", RegGlobalsFn
);
360 auto *EntryCmp
= Builder
.CreateICmpNE(EntriesB
, EntriesE
);
361 Builder
.CreateCondBr(EntryCmp
, EntryBB
, ExitBB
);
362 Builder
.SetInsertPoint(EntryBB
);
363 auto *Entry
= Builder
.CreatePHI(PointerType::getUnqual(C
), 2, "entry");
365 Builder
.CreateInBoundsGEP(offloading::getEntryTy(M
), Entry
,
366 {ConstantInt::get(getSizeTTy(M
), 0),
367 ConstantInt::get(Type::getInt32Ty(C
), 0)});
368 auto *Addr
= Builder
.CreateLoad(Int8PtrTy
, AddrPtr
, "addr");
370 Builder
.CreateInBoundsGEP(offloading::getEntryTy(M
), Entry
,
371 {ConstantInt::get(getSizeTTy(M
), 0),
372 ConstantInt::get(Type::getInt32Ty(C
), 1)});
373 auto *Name
= Builder
.CreateLoad(Int8PtrTy
, NamePtr
, "name");
375 Builder
.CreateInBoundsGEP(offloading::getEntryTy(M
), Entry
,
376 {ConstantInt::get(getSizeTTy(M
), 0),
377 ConstantInt::get(Type::getInt32Ty(C
), 2)});
378 auto *Size
= Builder
.CreateLoad(getSizeTTy(M
), SizePtr
, "size");
380 Builder
.CreateInBoundsGEP(offloading::getEntryTy(M
), Entry
,
381 {ConstantInt::get(getSizeTTy(M
), 0),
382 ConstantInt::get(Type::getInt32Ty(C
), 3)});
383 auto *Flags
= Builder
.CreateLoad(Type::getInt32Ty(C
), FlagsPtr
, "flags");
385 Builder
.CreateInBoundsGEP(offloading::getEntryTy(M
), Entry
,
386 {ConstantInt::get(getSizeTTy(M
), 0),
387 ConstantInt::get(Type::getInt32Ty(C
), 4)});
388 auto *Data
= Builder
.CreateLoad(Type::getInt32Ty(C
), DataPtr
, "textype");
389 auto *Kind
= Builder
.CreateAnd(
390 Flags
, ConstantInt::get(Type::getInt32Ty(C
), 0x7), "type");
392 // Extract the flags stored in the bit-field and convert them to C booleans.
393 auto *ExternBit
= Builder
.CreateAnd(
394 Flags
, ConstantInt::get(Type::getInt32Ty(C
),
395 llvm::offloading::OffloadGlobalExtern
));
396 auto *Extern
= Builder
.CreateLShr(
397 ExternBit
, ConstantInt::get(Type::getInt32Ty(C
), 3), "extern");
398 auto *ConstantBit
= Builder
.CreateAnd(
399 Flags
, ConstantInt::get(Type::getInt32Ty(C
),
400 llvm::offloading::OffloadGlobalConstant
));
401 auto *Const
= Builder
.CreateLShr(
402 ConstantBit
, ConstantInt::get(Type::getInt32Ty(C
), 4), "constant");
403 auto *NormalizedBit
= Builder
.CreateAnd(
404 Flags
, ConstantInt::get(Type::getInt32Ty(C
),
405 llvm::offloading::OffloadGlobalNormalized
));
406 auto *Normalized
= Builder
.CreateLShr(
407 NormalizedBit
, ConstantInt::get(Type::getInt32Ty(C
), 5), "normalized");
409 Builder
.CreateICmpEQ(Size
, ConstantInt::getNullValue(getSizeTTy(M
)));
410 Builder
.CreateCondBr(FnCond
, IfThenBB
, IfElseBB
);
412 // Create kernel registration code.
413 Builder
.SetInsertPoint(IfThenBB
);
414 Builder
.CreateCall(RegFunc
, {RegGlobalsFn
->arg_begin(), Addr
, Name
, Name
,
415 ConstantInt::get(Type::getInt32Ty(C
), -1),
416 ConstantPointerNull::get(Int8PtrTy
),
417 ConstantPointerNull::get(Int8PtrTy
),
418 ConstantPointerNull::get(Int8PtrTy
),
419 ConstantPointerNull::get(Int8PtrTy
),
420 ConstantPointerNull::get(Int32PtrTy
)});
421 Builder
.CreateBr(IfEndBB
);
422 Builder
.SetInsertPoint(IfElseBB
);
424 auto *Switch
= Builder
.CreateSwitch(Kind
, IfEndBB
);
425 // Create global variable registration code.
426 Builder
.SetInsertPoint(SwGlobalBB
);
427 Builder
.CreateCall(RegVar
,
428 {RegGlobalsFn
->arg_begin(), Addr
, Name
, Name
, Extern
, Size
,
429 Const
, ConstantInt::get(Type::getInt32Ty(C
), 0)});
430 Builder
.CreateBr(IfEndBB
);
431 Switch
->addCase(Builder
.getInt32(llvm::offloading::OffloadGlobalEntry
),
434 // Create managed variable registration code.
435 Builder
.SetInsertPoint(SwManagedBB
);
436 Builder
.CreateBr(IfEndBB
);
437 Switch
->addCase(Builder
.getInt32(llvm::offloading::OffloadGlobalManagedEntry
),
440 // Create surface variable registration code.
441 Builder
.SetInsertPoint(SwSurfaceBB
);
443 RegSurface
, {RegGlobalsFn
->arg_begin(), Addr
, Name
, Name
, Data
, Extern
});
444 Builder
.CreateBr(IfEndBB
);
445 Switch
->addCase(Builder
.getInt32(llvm::offloading::OffloadGlobalSurfaceEntry
),
448 // Create texture variable registration code.
449 Builder
.SetInsertPoint(SwTextureBB
);
450 Builder
.CreateCall(RegTexture
, {RegGlobalsFn
->arg_begin(), Addr
, Name
, Name
,
451 Data
, Normalized
, Extern
});
452 Builder
.CreateBr(IfEndBB
);
453 Switch
->addCase(Builder
.getInt32(llvm::offloading::OffloadGlobalTextureEntry
),
456 Builder
.SetInsertPoint(IfEndBB
);
457 auto *NewEntry
= Builder
.CreateInBoundsGEP(
458 offloading::getEntryTy(M
), Entry
, ConstantInt::get(getSizeTTy(M
), 1));
459 auto *Cmp
= Builder
.CreateICmpEQ(
461 ConstantExpr::getInBoundsGetElementPtr(
462 ArrayType::get(offloading::getEntryTy(M
), 0), EntriesE
,
463 ArrayRef
<Constant
*>({ConstantInt::get(getSizeTTy(M
), 0),
464 ConstantInt::get(getSizeTTy(M
), 0)})));
466 ConstantExpr::getInBoundsGetElementPtr(
467 ArrayType::get(offloading::getEntryTy(M
), 0), EntriesB
,
468 ArrayRef
<Constant
*>({ConstantInt::get(getSizeTTy(M
), 0),
469 ConstantInt::get(getSizeTTy(M
), 0)})),
470 &RegGlobalsFn
->getEntryBlock());
471 Entry
->addIncoming(NewEntry
, IfEndBB
);
472 Builder
.CreateCondBr(Cmp
, ExitBB
, EntryBB
);
473 Builder
.SetInsertPoint(ExitBB
);
474 Builder
.CreateRetVoid();
479 // Create the constructor and destructor to register the fatbinary with the CUDA
481 void createRegisterFatbinFunction(Module
&M
, GlobalVariable
*FatbinDesc
,
483 LLVMContext
&C
= M
.getContext();
484 auto *CtorFuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
486 Function::Create(CtorFuncTy
, GlobalValue::InternalLinkage
,
487 IsHIP
? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M
);
488 CtorFunc
->setSection(".text.startup");
490 auto *DtorFuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
492 Function::Create(DtorFuncTy
, GlobalValue::InternalLinkage
,
493 IsHIP
? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M
);
494 DtorFunc
->setSection(".text.startup");
496 auto *PtrTy
= PointerType::getUnqual(C
);
498 // Get the __cudaRegisterFatBinary function declaration.
499 auto *RegFatTy
= FunctionType::get(PtrTy
, PtrTy
, /*isVarArg=*/false);
500 FunctionCallee RegFatbin
= M
.getOrInsertFunction(
501 IsHIP
? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy
);
502 // Get the __cudaRegisterFatBinaryEnd function declaration.
504 FunctionType::get(Type::getVoidTy(C
), PtrTy
, /*isVarArg=*/false);
505 FunctionCallee RegFatbinEnd
=
506 M
.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy
);
507 // Get the __cudaUnregisterFatBinary function declaration.
509 FunctionType::get(Type::getVoidTy(C
), PtrTy
, /*isVarArg=*/false);
510 FunctionCallee UnregFatbin
= M
.getOrInsertFunction(
511 IsHIP
? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
515 FunctionType::get(Type::getInt32Ty(C
), PtrTy
, /*isVarArg=*/false);
516 FunctionCallee AtExit
= M
.getOrInsertFunction("atexit", AtExitTy
);
518 auto *BinaryHandleGlobal
= new llvm::GlobalVariable(
519 M
, PtrTy
, false, llvm::GlobalValue::InternalLinkage
,
520 llvm::ConstantPointerNull::get(PtrTy
),
521 IsHIP
? ".hip.binary_handle" : ".cuda.binary_handle");
523 // Create the constructor to register this image with the runtime.
524 IRBuilder
<> CtorBuilder(BasicBlock::Create(C
, "entry", CtorFunc
));
525 CallInst
*Handle
= CtorBuilder
.CreateCall(
527 ConstantExpr::getPointerBitCastOrAddrSpaceCast(FatbinDesc
, PtrTy
));
528 CtorBuilder
.CreateAlignedStore(
529 Handle
, BinaryHandleGlobal
,
530 Align(M
.getDataLayout().getPointerTypeSize(PtrTy
)));
531 CtorBuilder
.CreateCall(createRegisterGlobalsFunction(M
, IsHIP
), Handle
);
533 CtorBuilder
.CreateCall(RegFatbinEnd
, Handle
);
534 CtorBuilder
.CreateCall(AtExit
, DtorFunc
);
535 CtorBuilder
.CreateRetVoid();
537 // Create the destructor to unregister the image with the runtime. We cannot
538 // use a standard global destructor after CUDA 9.2 so this must be called by
539 // `atexit()` intead.
540 IRBuilder
<> DtorBuilder(BasicBlock::Create(C
, "entry", DtorFunc
));
541 LoadInst
*BinaryHandle
= DtorBuilder
.CreateAlignedLoad(
542 PtrTy
, BinaryHandleGlobal
,
543 Align(M
.getDataLayout().getPointerTypeSize(PtrTy
)));
544 DtorBuilder
.CreateCall(UnregFatbin
, BinaryHandle
);
545 DtorBuilder
.CreateRetVoid();
547 // Add this function to constructors.
548 appendToGlobalCtors(M
, CtorFunc
, /*Priority*/ 1);
553 Error
wrapOpenMPBinaries(Module
&M
, ArrayRef
<ArrayRef
<char>> Images
) {
554 GlobalVariable
*Desc
= createBinDesc(M
, Images
);
556 return createStringError(inconvertibleErrorCode(),
557 "No binary descriptors created.");
558 createRegisterFunction(M
, Desc
);
559 createUnregisterFunction(M
, Desc
);
560 return Error::success();
563 Error
wrapCudaBinary(Module
&M
, ArrayRef
<char> Image
) {
564 GlobalVariable
*Desc
= createFatbinDesc(M
, Image
, /* IsHIP */ false);
566 return createStringError(inconvertibleErrorCode(),
567 "No fatinbary section created.");
569 createRegisterFatbinFunction(M
, Desc
, /* IsHIP */ false);
570 return Error::success();
573 Error
wrapHIPBinary(Module
&M
, ArrayRef
<char> Image
) {
574 GlobalVariable
*Desc
= createFatbinDesc(M
, Image
, /* IsHIP */ true);
576 return createStringError(inconvertibleErrorCode(),
577 "No fatinbary section created.");
579 createRegisterFatbinFunction(M
, Desc
, /* IsHIP */ true);
580 return Error::success();