[Clang] Prevent `mlink-builtin-bitcode` from internalizing the RPC client (#118661)
[llvm-project.git] / mlir / lib / Target / LLVMIR / Dialect / GPU / SelectObjectAttr.cpp
blobade239c526af864c2a46f3d34f1f931dfa748fb8
1 //===- SelectObjectAttr.cpp - Implements the SelectObject attribute -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the `OffloadingLLVMTranslationAttrInterface` for the
10 // `SelectObject` attribute.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
15 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
18 #include "mlir/Target/LLVMIR/Export.h"
19 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/LLVMContext.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/Support/FormatVariadic.h"
27 using namespace mlir;
29 namespace {
30 // Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
31 class SelectObjectAttrImpl
32 : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
33 SelectObjectAttrImpl> {
34 public:
35 // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
36 // global binary string.
37 LogicalResult embedBinary(Attribute attribute, Operation *operation,
38 llvm::IRBuilderBase &builder,
39 LLVM::ModuleTranslation &moduleTranslation) const;
41 // Translates a `gpu.launch_func` to a sequence of LLVM instructions resulting
42 // in a kernel launch call.
43 LogicalResult launchKernel(Attribute attribute,
44 Operation *launchFuncOperation,
45 Operation *binaryOperation,
46 llvm::IRBuilderBase &builder,
47 LLVM::ModuleTranslation &moduleTranslation) const;
49 // Returns the selected object for embedding.
50 gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
52 // Returns an identifier for the global string holding the binary.
53 std::string getBinaryIdentifier(StringRef binaryName) {
54 return binaryName.str() + "_bin_cst";
56 } // namespace
58 void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
59 DialectRegistry &registry) {
60 registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
61 SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
62 });
65 gpu::ObjectAttr
66 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
67 ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
69 // Obtain the index of the object to select.
70 int64_t index = -1;
71 if (Attribute target =
72 cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
73 .getTarget()) {
74 // If the target attribute is a number it is the index. Otherwise compare
75 // the attribute to every target inside the object array to find the index.
76 if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
77 index = indexAttr.getInt();
78 } else {
79 for (auto [i, attr] : llvm::enumerate(objects)) {
80 auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
81 if (obj.getTarget() == target) {
82 index = i;
86 } else {
87 // If the target attribute is null then it's selecting the first object in
88 // the object array.
89 index = 0;
92 if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
93 op->emitError("the requested target object couldn't be found");
94 return nullptr;
96 return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
99 LogicalResult SelectObjectAttrImpl::embedBinary(
100 Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
101 LLVM::ModuleTranslation &moduleTranslation) const {
102 assert(operation && "The binary operation must be non null.");
103 if (!operation)
104 return failure();
106 auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
107 if (!op) {
108 operation->emitError("operation must be a GPU binary");
109 return failure();
112 gpu::ObjectAttr object = getSelectedObject(op);
113 if (!object)
114 return failure();
116 llvm::Module *module = moduleTranslation.getLLVMModule();
118 // Embed the object as a global string.
119 llvm::Constant *binary = llvm::ConstantDataArray::getString(
120 builder.getContext(), object.getObject().getValue(), false);
121 llvm::GlobalVariable *serializedObj =
122 new llvm::GlobalVariable(*module, binary->getType(), true,
123 llvm::GlobalValue::LinkageTypes::InternalLinkage,
124 binary, getBinaryIdentifier(op.getName()));
126 if (object.getProperties()) {
127 if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
128 object.getProperties().get(gpu::elfSectionName))) {
129 serializedObj->setSection(section.getValue());
132 serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
133 serializedObj->setAlignment(llvm::MaybeAlign(8));
134 serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
135 return success();
138 namespace llvm {
139 namespace {
140 class LaunchKernel {
141 public:
142 LaunchKernel(Module &module, IRBuilderBase &builder,
143 mlir::LLVM::ModuleTranslation &moduleTranslation);
144 // Get the kernel launch callee.
145 FunctionCallee getKernelLaunchFn();
147 // Get the kernel launch callee.
148 FunctionCallee getClusterKernelLaunchFn();
150 // Get the module function callee.
151 FunctionCallee getModuleFunctionFn();
153 // Get the module load callee.
154 FunctionCallee getModuleLoadFn();
156 // Get the module load JIT callee.
157 FunctionCallee getModuleLoadJITFn();
159 // Get the module unload callee.
160 FunctionCallee getModuleUnloadFn();
162 // Get the stream create callee.
163 FunctionCallee getStreamCreateFn();
165 // Get the stream destroy callee.
166 FunctionCallee getStreamDestroyFn();
168 // Get the stream sync callee.
169 FunctionCallee getStreamSyncFn();
171 // Ger or create the function name global string.
172 Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);
174 // Create the void* kernel array for passing the arguments.
175 Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);
177 // Create the full kernel launch.
178 llvm::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
179 mlir::gpu::ObjectAttr object);
181 private:
182 Module &module;
183 IRBuilderBase &builder;
184 mlir::LLVM::ModuleTranslation &moduleTranslation;
185 Type *i32Ty{};
186 Type *i64Ty{};
187 Type *voidTy{};
188 Type *intPtrTy{};
189 PointerType *ptrTy{};
191 } // namespace
192 } // namespace llvm
194 LogicalResult SelectObjectAttrImpl::launchKernel(
195 Attribute attribute, Operation *launchFuncOperation,
196 Operation *binaryOperation, llvm::IRBuilderBase &builder,
197 LLVM::ModuleTranslation &moduleTranslation) const {
199 assert(launchFuncOperation && "The launch func operation must be non null.");
200 if (!launchFuncOperation)
201 return failure();
203 auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
204 if (!launchFuncOp) {
205 launchFuncOperation->emitError("operation must be a GPU launch func Op.");
206 return failure();
209 auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
210 if (!binOp) {
211 binaryOperation->emitError("operation must be a GPU binary.");
212 return failure();
214 gpu::ObjectAttr object = getSelectedObject(binOp);
215 if (!object)
216 return failure();
218 return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
219 moduleTranslation)
220 .createKernelLaunch(launchFuncOp, object);
223 llvm::LaunchKernel::LaunchKernel(
224 Module &module, IRBuilderBase &builder,
225 mlir::LLVM::ModuleTranslation &moduleTranslation)
226 : module(module), builder(builder), moduleTranslation(moduleTranslation) {
227 i32Ty = builder.getInt32Ty();
228 i64Ty = builder.getInt64Ty();
229 ptrTy = builder.getPtrTy(0);
230 voidTy = builder.getVoidTy();
231 intPtrTy = builder.getIntPtrTy(module.getDataLayout());
234 llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
235 return module.getOrInsertFunction(
236 "mgpuLaunchKernel",
237 FunctionType::get(voidTy,
238 ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
239 intPtrTy, intPtrTy, intPtrTy, i32Ty,
240 ptrTy, ptrTy, ptrTy, i64Ty}),
241 false));
244 llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
245 return module.getOrInsertFunction(
246 "mgpuLaunchClusterKernel",
247 FunctionType::get(
248 voidTy,
249 ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
250 intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
251 i32Ty, ptrTy, ptrTy, ptrTy}),
252 false));
255 llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
256 return module.getOrInsertFunction(
257 "mgpuModuleGetFunction",
258 FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
261 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
262 return module.getOrInsertFunction(
263 "mgpuModuleLoad",
264 FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
267 llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
268 return module.getOrInsertFunction(
269 "mgpuModuleLoadJIT",
270 FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
273 llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
274 return module.getOrInsertFunction(
275 "mgpuModuleUnload",
276 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
279 llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
280 return module.getOrInsertFunction("mgpuStreamCreate",
281 FunctionType::get(ptrTy, false));
284 llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
285 return module.getOrInsertFunction(
286 "mgpuStreamDestroy",
287 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
290 llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
291 return module.getOrInsertFunction(
292 "mgpuStreamSynchronize",
293 FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
296 // Generates an LLVM IR dialect global that contains the name of the given
297 // kernel function as a C string, and returns a pointer to its beginning.
298 llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
299 StringRef kernelName) {
300 std::string globalName =
301 std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
303 if (GlobalVariable *gv = module.getGlobalVariable(globalName))
304 return gv;
306 return builder.CreateGlobalString(kernelName, globalName);
309 // Creates a struct containing all kernel parameters on the stack and returns
310 // an array of type-erased pointers to the fields of the struct. The array can
311 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
312 // The generated code is essentially as follows:
314 // %struct = alloca(sizeof(struct { Parameters... }))
315 // %array = alloca(NumParameters * sizeof(void *))
316 // for (i : [0, NumParameters))
317 // %fieldPtr = llvm.getelementptr %struct[0, i]
318 // llvm.store parameters[i], %fieldPtr
319 // %elementPtr = llvm.getelementptr %array[i]
320 // llvm.store %fieldPtr, %elementPtr
321 // return %array
322 llvm::Value *
323 llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
324 SmallVector<Value *> args =
325 moduleTranslation.lookupValues(op.getKernelOperands());
326 SmallVector<Type *> structTypes(args.size(), nullptr);
328 for (auto [i, arg] : llvm::enumerate(args))
329 structTypes[i] = arg->getType();
331 Type *structTy = StructType::create(module.getContext(), structTypes);
332 Value *argStruct = builder.CreateAlloca(structTy, 0u);
333 Value *argArray = builder.CreateAlloca(
334 ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));
336 for (auto [i, arg] : enumerate(args)) {
337 Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
338 builder.CreateStore(arg, structMember);
339 Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
340 builder.CreateStore(structMember, arrayMember);
342 return argArray;
345 // Emits LLVM IR to launch a kernel function:
346 // %0 = call %binarygetter
347 // %1 = call %moduleLoad(%0)
348 // %2 = <see generateKernelNameConstant>
349 // %3 = call %moduleGetFunction(%1, %2)
350 // %4 = call %streamCreate()
351 // %5 = <see generateParamsArray>
352 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
353 // call %streamSynchronize(%4)
354 // call %streamDestroy(%4)
355 // call %moduleUnload(%1)
356 llvm::LogicalResult
357 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
358 mlir::gpu::ObjectAttr object) {
359 auto llvmValue = [&](mlir::Value value) -> Value * {
360 Value *v = moduleTranslation.lookupValue(value);
361 assert(v && "Value has not been translated.");
362 return v;
365 // Get grid dimensions.
366 mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
367 Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
368 *gz = llvmValue(grid.z);
370 // Get block dimensions.
371 mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
372 Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
373 *bz = llvmValue(block.z);
375 // Get dynamic shared memory size.
376 Value *dynamicMemorySize = nullptr;
377 if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
378 dynamicMemorySize = llvmValue(dynSz);
379 else
380 dynamicMemorySize = ConstantInt::get(i32Ty, 0);
382 // Create the argument array.
383 Value *argArray = createKernelArgArray(op);
385 // Default JIT optimization level.
386 llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
387 // Check if there's an optimization level embedded in the object.
388 DictionaryAttr objectProps = object.getProperties();
389 mlir::Attribute optAttr;
390 if (objectProps && (optAttr = objectProps.get("O"))) {
391 auto optLevel = dyn_cast<IntegerAttr>(optAttr);
392 if (!optLevel)
393 return op.emitError("the optimization level must be an integer");
394 optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
397 // Load the kernel module.
398 StringRef moduleName = op.getKernelModuleName().getValue();
399 std::string binaryIdentifier = getBinaryIdentifier(moduleName);
400 Value *binary = module.getGlobalVariable(binaryIdentifier, true);
401 if (!binary)
402 return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
404 auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
405 if (!binaryVar)
406 return op.emitError() << "Binary is not a global variable: "
407 << binaryIdentifier;
408 llvm::Constant *binaryInit = binaryVar->getInitializer();
409 auto binaryDataSeq =
410 dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
411 if (!binaryDataSeq)
412 return op.emitError() << "Couldn't find binary data array: "
413 << binaryIdentifier;
414 llvm::Constant *binarySize =
415 llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
416 binaryDataSeq->getElementByteSize());
418 Value *moduleObject =
419 object.getFormat() == gpu::CompilationTarget::Assembly
420 ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
421 : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
423 // Load the kernel function.
424 Value *moduleFunction = builder.CreateCall(
425 getModuleFunctionFn(),
426 {moduleObject,
427 getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
429 // Get the stream to use for execution. If there's no async object then create
430 // a stream to make a synchronous kernel launch.
431 Value *stream = nullptr;
432 bool handleStream = false;
433 if (mlir::Value asyncObject = op.getAsyncObject()) {
434 stream = llvmValue(asyncObject);
435 } else {
436 handleStream = true;
437 stream = builder.CreateCall(getStreamCreateFn(), {});
440 llvm::Constant *paramsCount =
441 llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());
443 // Create the launch call.
444 Value *nullPtr = ConstantPointerNull::get(ptrTy);
446 // Launch kernel with clusters if cluster size is specified.
447 if (op.hasClusterSize()) {
448 mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
449 Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
450 *cz = llvmValue(cluster.z);
451 builder.CreateCall(
452 getClusterKernelLaunchFn(),
453 ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
454 dynamicMemorySize, stream, argArray, nullPtr}));
455 } else {
456 builder.CreateCall(getKernelLaunchFn(),
457 ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
458 bz, dynamicMemorySize, stream,
459 argArray, nullPtr, paramsCount}));
462 // Sync & destroy the stream, for synchronous launches.
463 if (handleStream) {
464 builder.CreateCall(getStreamSyncFn(), {stream});
465 builder.CreateCall(getStreamDestroyFn(), {stream});
468 // Unload the kernel module.
469 builder.CreateCall(getModuleUnloadFn(), {moduleObject});
471 return success();