//===- OffloadWrapper.cpp ---------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "OffloadWrapper.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Object/OffloadBinary.h"
#include "llvm/Support/Error.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace llvm;
/// Magic numbers stored in the first field of the fatbinary wrapper object
/// for CUDA and HIP respectively. The HIP value's bytes spell "HIPF" in ASCII.
constexpr unsigned CudaFatMagic = 0x466243b1;
constexpr unsigned HIPFatMagic = 0x48495046;

/// Kind of an offloading entry, stored in its `flags` field.
/// Copied from clang/CGCudaRuntime.h.
enum OffloadEntryKindFlag : uint32_t {
  /// Mark the entry as a global entry. This indicates the presence of a
  /// kernel if the size field is zero and a variable otherwise.
  OffloadGlobalEntry = 0x0,
  /// Mark the entry as a managed global variable.
  OffloadGlobalManagedEntry = 0x1,
  /// Mark the entry as a surface variable.
  OffloadGlobalSurfaceEntry = 0x2,
  /// Mark the entry as a texture variable.
  OffloadGlobalTextureEntry = 0x3,
};
41 IntegerType
*getSizeTTy(Module
&M
) {
42 LLVMContext
&C
= M
.getContext();
43 switch (M
.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C
))) {
45 return Type::getInt32Ty(C
);
47 return Type::getInt64Ty(C
);
49 llvm_unreachable("unsupported pointer type size");
52 // struct __tgt_offload_entry {
59 StructType
*getEntryTy(Module
&M
) {
60 LLVMContext
&C
= M
.getContext();
61 StructType
*EntryTy
= StructType::getTypeByName(C
, "__tgt_offload_entry");
63 EntryTy
= StructType::create("__tgt_offload_entry", Type::getInt8PtrTy(C
),
64 Type::getInt8PtrTy(C
), getSizeTTy(M
),
65 Type::getInt32Ty(C
), Type::getInt32Ty(C
));
69 PointerType
*getEntryPtrTy(Module
&M
) {
70 return PointerType::getUnqual(getEntryTy(M
));
73 // struct __tgt_device_image {
76 // __tgt_offload_entry *EntriesBegin;
77 // __tgt_offload_entry *EntriesEnd;
79 StructType
*getDeviceImageTy(Module
&M
) {
80 LLVMContext
&C
= M
.getContext();
81 StructType
*ImageTy
= StructType::getTypeByName(C
, "__tgt_device_image");
83 ImageTy
= StructType::create("__tgt_device_image", Type::getInt8PtrTy(C
),
84 Type::getInt8PtrTy(C
), getEntryPtrTy(M
),
89 PointerType
*getDeviceImagePtrTy(Module
&M
) {
90 return PointerType::getUnqual(getDeviceImageTy(M
));
93 // struct __tgt_bin_desc {
94 // int32_t NumDeviceImages;
95 // __tgt_device_image *DeviceImages;
96 // __tgt_offload_entry *HostEntriesBegin;
97 // __tgt_offload_entry *HostEntriesEnd;
99 StructType
*getBinDescTy(Module
&M
) {
100 LLVMContext
&C
= M
.getContext();
101 StructType
*DescTy
= StructType::getTypeByName(C
, "__tgt_bin_desc");
103 DescTy
= StructType::create("__tgt_bin_desc", Type::getInt32Ty(C
),
104 getDeviceImagePtrTy(M
), getEntryPtrTy(M
),
109 PointerType
*getBinDescPtrTy(Module
&M
) {
110 return PointerType::getUnqual(getBinDescTy(M
));
113 /// Creates binary descriptor for the given device images. Binary descriptor
114 /// is an object that is passed to the offloading runtime at program startup
115 /// and it describes all device images available in the executable or shared
116 /// library. It is defined as follows
118 /// __attribute__((visibility("hidden")))
119 /// extern __tgt_offload_entry *__start_omp_offloading_entries;
120 /// __attribute__((visibility("hidden")))
121 /// extern __tgt_offload_entry *__stop_omp_offloading_entries;
123 /// static const char Image0[] = { <Bufs.front() contents> };
125 /// static const char ImageN[] = { <Bufs.back() contents> };
127 /// static const __tgt_device_image Images[] = {
129 /// Image0, /*ImageStart*/
130 /// Image0 + sizeof(Image0), /*ImageEnd*/
131 /// __start_omp_offloading_entries, /*EntriesBegin*/
132 /// __stop_omp_offloading_entries /*EntriesEnd*/
136 /// ImageN, /*ImageStart*/
137 /// ImageN + sizeof(ImageN), /*ImageEnd*/
138 /// __start_omp_offloading_entries, /*EntriesBegin*/
139 /// __stop_omp_offloading_entries /*EntriesEnd*/
143 /// static const __tgt_bin_desc BinDesc = {
144 /// sizeof(Images) / sizeof(Images[0]), /*NumDeviceImages*/
145 /// Images, /*DeviceImages*/
146 /// __start_omp_offloading_entries, /*HostEntriesBegin*/
147 /// __stop_omp_offloading_entries /*HostEntriesEnd*/
150 /// Global variable that represents BinDesc is returned.
151 GlobalVariable
*createBinDesc(Module
&M
, ArrayRef
<ArrayRef
<char>> Bufs
) {
152 LLVMContext
&C
= M
.getContext();
153 // Create external begin/end symbols for the offload entries table.
154 auto *EntriesB
= new GlobalVariable(
155 M
, getEntryTy(M
), /*isConstant*/ true, GlobalValue::ExternalLinkage
,
156 /*Initializer*/ nullptr, "__start_omp_offloading_entries");
157 EntriesB
->setVisibility(GlobalValue::HiddenVisibility
);
158 auto *EntriesE
= new GlobalVariable(
159 M
, getEntryTy(M
), /*isConstant*/ true, GlobalValue::ExternalLinkage
,
160 /*Initializer*/ nullptr, "__stop_omp_offloading_entries");
161 EntriesE
->setVisibility(GlobalValue::HiddenVisibility
);
163 // We assume that external begin/end symbols that we have created above will
164 // be defined by the linker. But linker will do that only if linker inputs
165 // have section with "omp_offloading_entries" name which is not guaranteed.
166 // So, we just create dummy zero sized object in the offload entries section
167 // to force linker to define those symbols.
169 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M
), 0u));
170 auto *DummyEntry
= new GlobalVariable(
171 M
, DummyInit
->getType(), true, GlobalVariable::ExternalLinkage
, DummyInit
,
172 "__dummy.omp_offloading.entry");
173 DummyEntry
->setSection("omp_offloading_entries");
174 DummyEntry
->setVisibility(GlobalValue::HiddenVisibility
);
176 auto *Zero
= ConstantInt::get(getSizeTTy(M
), 0u);
177 Constant
*ZeroZero
[] = {Zero
, Zero
};
179 // Create initializer for the images array.
180 SmallVector
<Constant
*, 4u> ImagesInits
;
181 ImagesInits
.reserve(Bufs
.size());
182 for (ArrayRef
<char> Buf
: Bufs
) {
183 auto *Data
= ConstantDataArray::get(C
, Buf
);
184 auto *Image
= new GlobalVariable(M
, Data
->getType(), /*isConstant*/ true,
185 GlobalVariable::InternalLinkage
, Data
,
186 ".omp_offloading.device_image");
187 Image
->setUnnamedAddr(GlobalValue::UnnamedAddr::Global
);
188 Image
->setSection(".llvm.offloading");
189 Image
->setAlignment(Align(object::OffloadBinary::getAlignment()));
191 auto *Size
= ConstantInt::get(getSizeTTy(M
), Buf
.size());
192 Constant
*ZeroSize
[] = {Zero
, Size
};
195 ConstantExpr::getGetElementPtr(Image
->getValueType(), Image
, ZeroZero
);
197 ConstantExpr::getGetElementPtr(Image
->getValueType(), Image
, ZeroSize
);
199 ImagesInits
.push_back(ConstantStruct::get(getDeviceImageTy(M
), ImageB
,
200 ImageE
, EntriesB
, EntriesE
));
203 // Then create images array.
204 auto *ImagesData
= ConstantArray::get(
205 ArrayType::get(getDeviceImageTy(M
), ImagesInits
.size()), ImagesInits
);
208 new GlobalVariable(M
, ImagesData
->getType(), /*isConstant*/ true,
209 GlobalValue::InternalLinkage
, ImagesData
,
210 ".omp_offloading.device_images");
211 Images
->setUnnamedAddr(GlobalValue::UnnamedAddr::Global
);
214 ConstantExpr::getGetElementPtr(Images
->getValueType(), Images
, ZeroZero
);
216 // And finally create the binary descriptor object.
217 auto *DescInit
= ConstantStruct::get(
219 ConstantInt::get(Type::getInt32Ty(C
), ImagesInits
.size()), ImagesB
,
222 return new GlobalVariable(M
, DescInit
->getType(), /*isConstant*/ true,
223 GlobalValue::InternalLinkage
, DescInit
,
224 ".omp_offloading.descriptor");
227 void createRegisterFunction(Module
&M
, GlobalVariable
*BinDesc
) {
228 LLVMContext
&C
= M
.getContext();
229 auto *FuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
230 auto *Func
= Function::Create(FuncTy
, GlobalValue::InternalLinkage
,
231 ".omp_offloading.descriptor_reg", &M
);
232 Func
->setSection(".text.startup");
234 // Get __tgt_register_lib function declaration.
235 auto *RegFuncTy
= FunctionType::get(Type::getVoidTy(C
), getBinDescPtrTy(M
),
237 FunctionCallee RegFuncC
=
238 M
.getOrInsertFunction("__tgt_register_lib", RegFuncTy
);
240 // Construct function body
241 IRBuilder
<> Builder(BasicBlock::Create(C
, "entry", Func
));
242 Builder
.CreateCall(RegFuncC
, BinDesc
);
243 Builder
.CreateRetVoid();
245 // Add this function to constructors.
246 // Set priority to 1 so that __tgt_register_lib is executed AFTER
247 // __tgt_register_requires (we want to know what requirements have been
248 // asked for before we load a libomptarget plugin so that by the time the
249 // plugin is loaded it can report how many devices there are which can
250 // satisfy these requirements).
251 appendToGlobalCtors(M
, Func
, /*Priority*/ 1);
254 void createUnregisterFunction(Module
&M
, GlobalVariable
*BinDesc
) {
255 LLVMContext
&C
= M
.getContext();
256 auto *FuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
257 auto *Func
= Function::Create(FuncTy
, GlobalValue::InternalLinkage
,
258 ".omp_offloading.descriptor_unreg", &M
);
259 Func
->setSection(".text.startup");
261 // Get __tgt_unregister_lib function declaration.
262 auto *UnRegFuncTy
= FunctionType::get(Type::getVoidTy(C
), getBinDescPtrTy(M
),
264 FunctionCallee UnRegFuncC
=
265 M
.getOrInsertFunction("__tgt_unregister_lib", UnRegFuncTy
);
267 // Construct function body
268 IRBuilder
<> Builder(BasicBlock::Create(C
, "entry", Func
));
269 Builder
.CreateCall(UnRegFuncC
, BinDesc
);
270 Builder
.CreateRetVoid();
272 // Add this function to global destructors.
273 // Match priority of __tgt_register_lib
274 appendToGlobalDtors(M
, Func
, /*Priority*/ 1);
277 // struct fatbin_wrapper {
283 StructType
*getFatbinWrapperTy(Module
&M
) {
284 LLVMContext
&C
= M
.getContext();
285 StructType
*FatbinTy
= StructType::getTypeByName(C
, "fatbin_wrapper");
287 FatbinTy
= StructType::create("fatbin_wrapper", Type::getInt32Ty(C
),
288 Type::getInt32Ty(C
), Type::getInt8PtrTy(C
),
289 Type::getInt8PtrTy(C
));
293 /// Embed the image \p Image into the module \p M so it can be found by the
295 GlobalVariable
*createFatbinDesc(Module
&M
, ArrayRef
<char> Image
, bool IsHIP
) {
296 LLVMContext
&C
= M
.getContext();
297 llvm::Type
*Int8PtrTy
= Type::getInt8PtrTy(C
);
298 llvm::Triple Triple
= llvm::Triple(M
.getTargetTriple());
300 // Create the global string containing the fatbinary.
301 StringRef FatbinConstantSection
=
302 IsHIP
? ".hip_fatbin"
303 : (Triple
.isMacOSX() ? "__NV_CUDA,__nv_fatbin" : ".nv_fatbin");
304 auto *Data
= ConstantDataArray::get(C
, Image
);
305 auto *Fatbin
= new GlobalVariable(M
, Data
->getType(), /*isConstant*/ true,
306 GlobalVariable::InternalLinkage
, Data
,
308 Fatbin
->setSection(FatbinConstantSection
);
310 // Create the fatbinary wrapper
311 StringRef FatbinWrapperSection
= IsHIP
? ".hipFatBinSegment"
312 : Triple
.isMacOSX() ? "__NV_CUDA,__fatbin"
313 : ".nvFatBinSegment";
314 Constant
*FatbinWrapper
[] = {
315 ConstantInt::get(Type::getInt32Ty(C
), IsHIP
? HIPFatMagic
: CudaFatMagic
),
316 ConstantInt::get(Type::getInt32Ty(C
), 1),
317 ConstantExpr::getPointerBitCastOrAddrSpaceCast(Fatbin
, Int8PtrTy
),
318 ConstantPointerNull::get(Type::getInt8PtrTy(C
))};
320 Constant
*FatbinInitializer
=
321 ConstantStruct::get(getFatbinWrapperTy(M
), FatbinWrapper
);
324 new GlobalVariable(M
, getFatbinWrapperTy(M
),
325 /*isConstant*/ true, GlobalValue::InternalLinkage
,
326 FatbinInitializer
, ".fatbin_wrapper");
327 FatbinDesc
->setSection(FatbinWrapperSection
);
328 FatbinDesc
->setAlignment(Align(8));
330 // We create a dummy entry to ensure the linker will define the begin / end
331 // symbols. The CUDA runtime should ignore the null address if we attempt to
334 ConstantAggregateZero::get(ArrayType::get(getEntryTy(M
), 0u));
335 auto *DummyEntry
= new GlobalVariable(
336 M
, DummyInit
->getType(), true, GlobalVariable::ExternalLinkage
, DummyInit
,
337 IsHIP
? "__dummy.hip_offloading.entry" : "__dummy.cuda_offloading.entry");
338 DummyEntry
->setVisibility(GlobalValue::HiddenVisibility
);
339 DummyEntry
->setSection(IsHIP
? "hip_offloading_entries"
340 : "cuda_offloading_entries");
345 /// Create the register globals function. We will iterate all of the offloading
346 /// entries stored at the begin / end symbols and register them according to
347 /// their type. This creates the following function in IR:
349 /// extern struct __tgt_offload_entry __start_cuda_offloading_entries;
350 /// extern struct __tgt_offload_entry __stop_cuda_offloading_entries;
352 /// extern void __cudaRegisterFunction(void **, void *, void *, void *, int,
353 /// void *, void *, void *, void *, int *);
354 /// extern void __cudaRegisterVar(void **, void *, void *, void *, int32_t,
355 /// int64_t, int32_t, int32_t);
357 /// void __cudaRegisterTest(void **fatbinHandle) {
358 /// for (struct __tgt_offload_entry *entry = &__start_cuda_offloading_entries;
359 /// entry != &__stop_cuda_offloading_entries; ++entry) {
360 /// if (!entry->size)
361 /// __cudaRegisterFunction(fatbinHandle, entry->addr, entry->name,
362 /// entry->name, -1, 0, 0, 0, 0, 0);
364 /// __cudaRegisterVar(fatbinHandle, entry->addr, entry->name, entry->name,
365 /// 0, entry->size, 0, 0);
368 Function
*createRegisterGlobalsFunction(Module
&M
, bool IsHIP
) {
369 LLVMContext
&C
= M
.getContext();
370 // Get the __cudaRegisterFunction function declaration.
371 PointerType
*Int8PtrTy
= PointerType::get(C
, 0);
372 PointerType
*Int8PtrPtrTy
= PointerType::get(C
, 0);
373 PointerType
*Int32PtrTy
= PointerType::get(C
, 0);
374 auto *RegFuncTy
= FunctionType::get(
376 {Int8PtrPtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Type::getInt32Ty(C
),
377 Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Int32PtrTy
},
379 FunctionCallee RegFunc
= M
.getOrInsertFunction(
380 IsHIP
? "__hipRegisterFunction" : "__cudaRegisterFunction", RegFuncTy
);
382 // Get the __cudaRegisterVar function declaration.
383 auto *RegVarTy
= FunctionType::get(
385 {Int8PtrPtrTy
, Int8PtrTy
, Int8PtrTy
, Int8PtrTy
, Type::getInt32Ty(C
),
386 getSizeTTy(M
), Type::getInt32Ty(C
), Type::getInt32Ty(C
)},
388 FunctionCallee RegVar
= M
.getOrInsertFunction(
389 IsHIP
? "__hipRegisterVar" : "__cudaRegisterVar", RegVarTy
);
391 // Create the references to the start / stop symbols defined by the linker.
393 new GlobalVariable(M
, ArrayType::get(getEntryTy(M
), 0),
394 /*isConstant*/ true, GlobalValue::ExternalLinkage
,
395 /*Initializer*/ nullptr,
396 IsHIP
? "__start_hip_offloading_entries"
397 : "__start_cuda_offloading_entries");
398 EntriesB
->setVisibility(GlobalValue::HiddenVisibility
);
400 new GlobalVariable(M
, ArrayType::get(getEntryTy(M
), 0),
401 /*isConstant*/ true, GlobalValue::ExternalLinkage
,
402 /*Initializer*/ nullptr,
403 IsHIP
? "__stop_hip_offloading_entries"
404 : "__stop_cuda_offloading_entries");
405 EntriesE
->setVisibility(GlobalValue::HiddenVisibility
);
407 auto *RegGlobalsTy
= FunctionType::get(Type::getVoidTy(C
), Int8PtrPtrTy
,
410 Function::Create(RegGlobalsTy
, GlobalValue::InternalLinkage
,
411 IsHIP
? ".hip.globals_reg" : ".cuda.globals_reg", &M
);
412 RegGlobalsFn
->setSection(".text.startup");
414 // Create the loop to register all the entries.
415 IRBuilder
<> Builder(BasicBlock::Create(C
, "entry", RegGlobalsFn
));
416 auto *EntryBB
= BasicBlock::Create(C
, "while.entry", RegGlobalsFn
);
417 auto *IfThenBB
= BasicBlock::Create(C
, "if.then", RegGlobalsFn
);
418 auto *IfElseBB
= BasicBlock::Create(C
, "if.else", RegGlobalsFn
);
419 auto *SwGlobalBB
= BasicBlock::Create(C
, "sw.global", RegGlobalsFn
);
420 auto *SwManagedBB
= BasicBlock::Create(C
, "sw.managed", RegGlobalsFn
);
421 auto *SwSurfaceBB
= BasicBlock::Create(C
, "sw.surface", RegGlobalsFn
);
422 auto *SwTextureBB
= BasicBlock::Create(C
, "sw.texture", RegGlobalsFn
);
423 auto *IfEndBB
= BasicBlock::Create(C
, "if.end", RegGlobalsFn
);
424 auto *ExitBB
= BasicBlock::Create(C
, "while.end", RegGlobalsFn
);
426 auto *EntryCmp
= Builder
.CreateICmpNE(EntriesB
, EntriesE
);
427 Builder
.CreateCondBr(EntryCmp
, EntryBB
, ExitBB
);
428 Builder
.SetInsertPoint(EntryBB
);
429 auto *Entry
= Builder
.CreatePHI(getEntryPtrTy(M
), 2, "entry");
431 Builder
.CreateInBoundsGEP(getEntryTy(M
), Entry
,
432 {ConstantInt::get(getSizeTTy(M
), 0),
433 ConstantInt::get(Type::getInt32Ty(C
), 0)});
434 auto *Addr
= Builder
.CreateLoad(Int8PtrTy
, AddrPtr
, "addr");
436 Builder
.CreateInBoundsGEP(getEntryTy(M
), Entry
,
437 {ConstantInt::get(getSizeTTy(M
), 0),
438 ConstantInt::get(Type::getInt32Ty(C
), 1)});
439 auto *Name
= Builder
.CreateLoad(Int8PtrTy
, NamePtr
, "name");
441 Builder
.CreateInBoundsGEP(getEntryTy(M
), Entry
,
442 {ConstantInt::get(getSizeTTy(M
), 0),
443 ConstantInt::get(Type::getInt32Ty(C
), 2)});
444 auto *Size
= Builder
.CreateLoad(getSizeTTy(M
), SizePtr
, "size");
446 Builder
.CreateInBoundsGEP(getEntryTy(M
), Entry
,
447 {ConstantInt::get(getSizeTTy(M
), 0),
448 ConstantInt::get(Type::getInt32Ty(C
), 3)});
449 auto *Flags
= Builder
.CreateLoad(Type::getInt32Ty(C
), FlagsPtr
, "flag");
451 Builder
.CreateICmpEQ(Size
, ConstantInt::getNullValue(getSizeTTy(M
)));
452 Builder
.CreateCondBr(FnCond
, IfThenBB
, IfElseBB
);
454 // Create kernel registration code.
455 Builder
.SetInsertPoint(IfThenBB
);
456 Builder
.CreateCall(RegFunc
, {RegGlobalsFn
->arg_begin(), Addr
, Name
, Name
,
457 ConstantInt::get(Type::getInt32Ty(C
), -1),
458 ConstantPointerNull::get(Int8PtrTy
),
459 ConstantPointerNull::get(Int8PtrTy
),
460 ConstantPointerNull::get(Int8PtrTy
),
461 ConstantPointerNull::get(Int8PtrTy
),
462 ConstantPointerNull::get(Int32PtrTy
)});
463 Builder
.CreateBr(IfEndBB
);
464 Builder
.SetInsertPoint(IfElseBB
);
466 auto *Switch
= Builder
.CreateSwitch(Flags
, IfEndBB
);
467 // Create global variable registration code.
468 Builder
.SetInsertPoint(SwGlobalBB
);
469 Builder
.CreateCall(RegVar
, {RegGlobalsFn
->arg_begin(), Addr
, Name
, Name
,
470 ConstantInt::get(Type::getInt32Ty(C
), 0), Size
,
471 ConstantInt::get(Type::getInt32Ty(C
), 0),
472 ConstantInt::get(Type::getInt32Ty(C
), 0)});
473 Builder
.CreateBr(IfEndBB
);
474 Switch
->addCase(Builder
.getInt32(OffloadGlobalEntry
), SwGlobalBB
);
476 // Create managed variable registration code.
477 Builder
.SetInsertPoint(SwManagedBB
);
478 Builder
.CreateBr(IfEndBB
);
479 Switch
->addCase(Builder
.getInt32(OffloadGlobalManagedEntry
), SwManagedBB
);
481 // Create surface variable registration code.
482 Builder
.SetInsertPoint(SwSurfaceBB
);
483 Builder
.CreateBr(IfEndBB
);
484 Switch
->addCase(Builder
.getInt32(OffloadGlobalSurfaceEntry
), SwSurfaceBB
);
486 // Create texture variable registration code.
487 Builder
.SetInsertPoint(SwTextureBB
);
488 Builder
.CreateBr(IfEndBB
);
489 Switch
->addCase(Builder
.getInt32(OffloadGlobalTextureEntry
), SwTextureBB
);
491 Builder
.SetInsertPoint(IfEndBB
);
492 auto *NewEntry
= Builder
.CreateInBoundsGEP(
493 getEntryTy(M
), Entry
, ConstantInt::get(getSizeTTy(M
), 1));
494 auto *Cmp
= Builder
.CreateICmpEQ(
496 ConstantExpr::getInBoundsGetElementPtr(
497 ArrayType::get(getEntryTy(M
), 0), EntriesE
,
498 ArrayRef
<Constant
*>({ConstantInt::get(getSizeTTy(M
), 0),
499 ConstantInt::get(getSizeTTy(M
), 0)})));
501 ConstantExpr::getInBoundsGetElementPtr(
502 ArrayType::get(getEntryTy(M
), 0), EntriesB
,
503 ArrayRef
<Constant
*>({ConstantInt::get(getSizeTTy(M
), 0),
504 ConstantInt::get(getSizeTTy(M
), 0)})),
505 &RegGlobalsFn
->getEntryBlock());
506 Entry
->addIncoming(NewEntry
, IfEndBB
);
507 Builder
.CreateCondBr(Cmp
, ExitBB
, EntryBB
);
508 Builder
.SetInsertPoint(ExitBB
);
509 Builder
.CreateRetVoid();
514 // Create the constructor and destructor to register the fatbinary with the CUDA
516 void createRegisterFatbinFunction(Module
&M
, GlobalVariable
*FatbinDesc
,
518 LLVMContext
&C
= M
.getContext();
519 auto *CtorFuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
521 Function::Create(CtorFuncTy
, GlobalValue::InternalLinkage
,
522 IsHIP
? ".hip.fatbin_reg" : ".cuda.fatbin_reg", &M
);
523 CtorFunc
->setSection(".text.startup");
525 auto *DtorFuncTy
= FunctionType::get(Type::getVoidTy(C
), /*isVarArg*/ false);
527 Function::Create(DtorFuncTy
, GlobalValue::InternalLinkage
,
528 IsHIP
? ".hip.fatbin_unreg" : ".cuda.fatbin_unreg", &M
);
529 DtorFunc
->setSection(".text.startup");
531 // Get the __cudaRegisterFatBinary function declaration.
532 auto *RegFatTy
= FunctionType::get(Type::getInt8PtrTy(C
)->getPointerTo(),
533 Type::getInt8PtrTy(C
),
535 FunctionCallee RegFatbin
= M
.getOrInsertFunction(
536 IsHIP
? "__hipRegisterFatBinary" : "__cudaRegisterFatBinary", RegFatTy
);
537 // Get the __cudaRegisterFatBinaryEnd function declaration.
538 auto *RegFatEndTy
= FunctionType::get(Type::getVoidTy(C
),
539 Type::getInt8PtrTy(C
)->getPointerTo(),
541 FunctionCallee RegFatbinEnd
=
542 M
.getOrInsertFunction("__cudaRegisterFatBinaryEnd", RegFatEndTy
);
543 // Get the __cudaUnregisterFatBinary function declaration.
544 auto *UnregFatTy
= FunctionType::get(Type::getVoidTy(C
),
545 Type::getInt8PtrTy(C
)->getPointerTo(),
547 FunctionCallee UnregFatbin
= M
.getOrInsertFunction(
548 IsHIP
? "__hipUnregisterFatBinary" : "__cudaUnregisterFatBinary",
552 FunctionType::get(Type::getInt32Ty(C
), DtorFuncTy
->getPointerTo(),
554 FunctionCallee AtExit
= M
.getOrInsertFunction("atexit", AtExitTy
);
556 auto *BinaryHandleGlobal
= new llvm::GlobalVariable(
557 M
, Type::getInt8PtrTy(C
)->getPointerTo(), false,
558 llvm::GlobalValue::InternalLinkage
,
559 llvm::ConstantPointerNull::get(Type::getInt8PtrTy(C
)->getPointerTo()),
560 IsHIP
? ".hip.binary_handle" : ".cuda.binary_handle");
562 // Create the constructor to register this image with the runtime.
563 IRBuilder
<> CtorBuilder(BasicBlock::Create(C
, "entry", CtorFunc
));
564 CallInst
*Handle
= CtorBuilder
.CreateCall(
565 RegFatbin
, ConstantExpr::getPointerBitCastOrAddrSpaceCast(
566 FatbinDesc
, Type::getInt8PtrTy(C
)));
567 CtorBuilder
.CreateAlignedStore(
568 Handle
, BinaryHandleGlobal
,
569 Align(M
.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C
))));
570 CtorBuilder
.CreateCall(createRegisterGlobalsFunction(M
, IsHIP
), Handle
);
572 CtorBuilder
.CreateCall(RegFatbinEnd
, Handle
);
573 CtorBuilder
.CreateCall(AtExit
, DtorFunc
);
574 CtorBuilder
.CreateRetVoid();
576 // Create the destructor to unregister the image with the runtime. We cannot
577 // use a standard global destructor after CUDA 9.2 so this must be called by
578 // `atexit()` intead.
579 IRBuilder
<> DtorBuilder(BasicBlock::Create(C
, "entry", DtorFunc
));
580 LoadInst
*BinaryHandle
= DtorBuilder
.CreateAlignedLoad(
581 Type::getInt8PtrTy(C
)->getPointerTo(), BinaryHandleGlobal
,
582 Align(M
.getDataLayout().getPointerTypeSize(Type::getInt8PtrTy(C
))));
583 DtorBuilder
.CreateCall(UnregFatbin
, BinaryHandle
);
584 DtorBuilder
.CreateRetVoid();
586 // Add this function to constructors.
587 appendToGlobalCtors(M
, CtorFunc
, /*Priority*/ 1);
592 Error
wrapOpenMPBinaries(Module
&M
, ArrayRef
<ArrayRef
<char>> Images
) {
593 GlobalVariable
*Desc
= createBinDesc(M
, Images
);
595 return createStringError(inconvertibleErrorCode(),
596 "No binary descriptors created.");
597 createRegisterFunction(M
, Desc
);
598 createUnregisterFunction(M
, Desc
);
599 return Error::success();
602 Error
wrapCudaBinary(Module
&M
, ArrayRef
<char> Image
) {
603 GlobalVariable
*Desc
= createFatbinDesc(M
, Image
, /* IsHIP */ false);
605 return createStringError(inconvertibleErrorCode(),
606 "No fatinbary section created.");
608 createRegisterFatbinFunction(M
, Desc
, /* IsHIP */ false);
609 return Error::success();
612 Error
wrapHIPBinary(Module
&M
, ArrayRef
<char> Image
) {
613 GlobalVariable
*Desc
= createFatbinDesc(M
, Image
, /* IsHIP */ true);
615 return createStringError(inconvertibleErrorCode(),
616 "No fatinbary section created.");
618 createRegisterFatbinFunction(M
, Desc
, /* IsHIP */ true);
619 return Error::success();