//===- SelectObjectAttr.cpp - Implements the SelectObject attribute ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the `OffloadingLLVMTranslationAttrInterface` for the
// `SelectObject` attribute.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

namespace {
// Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
class SelectObjectAttrImpl
    : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
          SelectObjectAttrImpl> {
public:
  // Translates a `gpu.binary`, embedding the binary into a host LLVM module
  // as a global binary string.
  LogicalResult embedBinary(Attribute attribute, Operation *operation,
                            llvm::IRBuilderBase &builder,
                            LLVM::ModuleTranslation &moduleTranslation) const;

  // Translates a `gpu.launch_func` to a sequence of LLVM instructions
  // resulting in a kernel launch call.
  LogicalResult launchKernel(Attribute attribute,
                             Operation *launchFuncOperation,
                             Operation *binaryOperation,
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) const;

  // Returns the selected object for embedding.
  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
};

// Returns an identifier for the global string holding the binary.
std::string getBinaryIdentifier(StringRef binaryName) {
  return binaryName.str() + "_bin_cst";
}
} // namespace

void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
  });
}

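// Typical usage of the registration above (a sketch of the standard MLIR
// external-model setup; names other than the registration function are
// illustrative):
//   DialectRegistry registry;
//   gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(registry);
//   MLIRContext context(registry);
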
gpu::ObjectAttr
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
  ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();

  // Obtain the index of the object to select.
  int64_t index = -1;
  if (Attribute target =
          cast<gpu::SelectObjectAttr>(op.getOffloadingHandlerAttr())
              .getTarget()) {
    // If the target attribute is a number it is the index. Otherwise compare
    // the attribute to every target inside the object array to find the
    // index.
    if (auto indexAttr = mlir::dyn_cast<IntegerAttr>(target)) {
      index = indexAttr.getInt();
    } else {
      for (auto [i, attr] : llvm::enumerate(objects)) {
        auto obj = mlir::dyn_cast<gpu::ObjectAttr>(attr);
        if (obj.getTarget() == target) {
          index = i;
        }
      }
    }
  } else {
    // If the target attribute is null then the first object in the object
    // array is selected.
    index = 0;
  }

  if (index < 0 || index >= static_cast<int64_t>(objects.size())) {
    op->emitError("the requested target object couldn't be found");
    return nullptr;
  }
  return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
}

LogicalResult SelectObjectAttrImpl::embedBinary(
    Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {
  assert(operation && "The binary operation must be non null.");
  if (!operation)
    return failure();

  auto op = mlir::dyn_cast<gpu::BinaryOp>(operation);
  if (!op) {
    operation->emitError("operation must be a GPU binary");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(op);
  if (!object)
    return failure();

  llvm::Module *module = moduleTranslation.getLLVMModule();

  // Embed the object as a global string.
  llvm::Constant *binary = llvm::ConstantDataArray::getString(
      builder.getContext(), object.getObject().getValue(), false);
  llvm::GlobalVariable *serializedObj =
      new llvm::GlobalVariable(*module, binary->getType(), true,
                               llvm::GlobalValue::LinkageTypes::InternalLinkage,
                               binary, getBinaryIdentifier(op.getName()));
  serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
  serializedObj->setAlignment(llvm::MaybeAlign(8));
  serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
  return success();
}

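// For illustration (a sketch of the emitted IR, not verbatim output): a
// `gpu.binary` named `kernel_module` is embedded roughly as
//   @kernel_module_bin_cst = internal constant [N x i8] c"...", align 8
// where N is the size in bytes of the selected object.
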
namespace llvm {
namespace {
class LaunchKernel {
public:
  LaunchKernel(Module &module, IRBuilderBase &builder,
               mlir::LLVM::ModuleTranslation &moduleTranslation);
  // Get the kernel launch callee.
  FunctionCallee getKernelLaunchFn();

  // Get the cluster kernel launch callee.
  FunctionCallee getClusterKernelLaunchFn();

  // Get the module function callee.
  FunctionCallee getModuleFunctionFn();

  // Get the module load callee.
  FunctionCallee getModuleLoadFn();

  // Get the module load JIT callee.
  FunctionCallee getModuleLoadJITFn();

  // Get the module unload callee.
  FunctionCallee getModuleUnloadFn();

  // Get the stream create callee.
  FunctionCallee getStreamCreateFn();

  // Get the stream destroy callee.
  FunctionCallee getStreamDestroyFn();

  // Get the stream sync callee.
  FunctionCallee getStreamSyncFn();

  // Get or create the function name global string.
  Value *getOrCreateFunctionName(StringRef moduleName, StringRef kernelName);

  // Create the void* kernel array for passing the arguments.
  Value *createKernelArgArray(mlir::gpu::LaunchFuncOp op);

  // Create the full kernel launch.
  mlir::LogicalResult createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                         mlir::gpu::ObjectAttr object);

private:
  Module &module;
  IRBuilderBase &builder;
  mlir::LLVM::ModuleTranslation &moduleTranslation;
  Type *i32Ty{};
  Type *i64Ty{};
  Type *voidTy{};
  Type *intPtrTy{};
  PointerType *ptrTy{};
};
} // namespace
} // namespace llvm

LogicalResult SelectObjectAttrImpl::launchKernel(
    Attribute attribute, Operation *launchFuncOperation,
    Operation *binaryOperation, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation) const {

  assert(launchFuncOperation && "The launch func operation must be non null.");
  if (!launchFuncOperation)
    return failure();

  auto launchFuncOp = mlir::dyn_cast<gpu::LaunchFuncOp>(launchFuncOperation);
  if (!launchFuncOp) {
    launchFuncOperation->emitError("operation must be a GPU launch func Op.");
    return failure();
  }

  auto binOp = mlir::dyn_cast<gpu::BinaryOp>(binaryOperation);
  if (!binOp) {
    binaryOperation->emitError("operation must be a GPU binary.");
    return failure();
  }

  gpu::ObjectAttr object = getSelectedObject(binOp);
  if (!object)
    return failure();

  return llvm::LaunchKernel(*moduleTranslation.getLLVMModule(), builder,
                            moduleTranslation)
      .createKernelLaunch(launchFuncOp, object);
}

llvm::LaunchKernel::LaunchKernel(
    Module &module, IRBuilderBase &builder,
    mlir::LLVM::ModuleTranslation &moduleTranslation)
    : module(module), builder(builder), moduleTranslation(moduleTranslation) {
  i32Ty = builder.getInt32Ty();
  i64Ty = builder.getInt64Ty();
  ptrTy = builder.getPtrTy(0);
  voidTy = builder.getVoidTy();
  intPtrTy = builder.getIntPtrTy(module.getDataLayout());
}

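// The callees below are resolved by name against the GPU runtime wrapper
// library linked into the host program. As a rough sketch (parameter names
// are illustrative, derived from the function type built below), the callee
// returned by getKernelLaunchFn() corresponds to a C function of the form:
//   void mgpuLaunchKernel(void *function,
//                         intptr_t gridX, intptr_t gridY, intptr_t gridZ,
//                         intptr_t blockX, intptr_t blockY, intptr_t blockZ,
//                         int32_t dynamicSharedMemory, void *stream,
//                         void **kernelParams, void **extra,
//                         int64_t paramCount);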
llvm::FunctionCallee llvm::LaunchKernel::getKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchKernel",
      FunctionType::get(voidTy,
                        ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy,
                                          intPtrTy, intPtrTy, intPtrTy, i32Ty,
                                          ptrTy, ptrTy, ptrTy, i64Ty}),
                        false));
}

llvm::FunctionCallee llvm::LaunchKernel::getClusterKernelLaunchFn() {
  return module.getOrInsertFunction(
      "mgpuLaunchClusterKernel",
      FunctionType::get(
          voidTy,
          ArrayRef<Type *>({ptrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            intPtrTy, intPtrTy, intPtrTy, intPtrTy, intPtrTy,
                            i32Ty, ptrTy, ptrTy, ptrTy}),
          false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
  return module.getOrInsertFunction(
      "mgpuModuleGetFunction",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoad",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
  return module.getOrInsertFunction(
      "mgpuModuleLoadJIT",
      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
  return module.getOrInsertFunction(
      "mgpuModuleUnload",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
  return module.getOrInsertFunction("mgpuStreamCreate",
                                    FunctionType::get(ptrTy, false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamDestroyFn() {
  return module.getOrInsertFunction(
      "mgpuStreamDestroy",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
  return module.getOrInsertFunction(
      "mgpuStreamSynchronize",
      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
}

// Generates an LLVM IR global that contains the name of the given kernel
// function as a C string, and returns a pointer to its beginning.
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                         StringRef kernelName) {
  std::string globalName =
      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));

  if (GlobalVariable *gv = module.getGlobalVariable(globalName))
    return gv;

  return builder.CreateGlobalString(kernelName, globalName);
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
llvm::Value *
llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
  SmallVector<Value *> args =
      moduleTranslation.lookupValues(op.getKernelOperands());
  SmallVector<Type *> structTypes(args.size(), nullptr);

  for (auto [i, arg] : llvm::enumerate(args))
    structTypes[i] = arg->getType();

  Type *structTy = StructType::create(module.getContext(), structTypes);
  Value *argStruct = builder.CreateAlloca(structTy, 0u);
  Value *argArray = builder.CreateAlloca(
      ptrTy, ConstantInt::get(intPtrTy, structTypes.size()));

  for (auto [i, arg] : enumerate(args)) {
    Value *structMember = builder.CreateStructGEP(structTy, argStruct, i);
    builder.CreateStore(arg, structMember);
    Value *arrayMember = builder.CreateConstGEP1_32(ptrTy, argArray, i);
    builder.CreateStore(structMember, arrayMember);
  }
  return argArray;
}

// Emits LLVM IR to launch a kernel function:
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see getOrCreateFunctionName>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see createKernelArgArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
mlir::LogicalResult
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                       mlir::gpu::ObjectAttr object) {
  auto llvmValue = [&](mlir::Value value) -> Value * {
    Value *v = moduleTranslation.lookupValue(value);
    assert(v && "Value has not been translated.");
    return v;
  };

  // Get grid dimensions.
  mlir::gpu::KernelDim3 grid = op.getGridSizeOperandValues();
  Value *gx = llvmValue(grid.x), *gy = llvmValue(grid.y),
        *gz = llvmValue(grid.z);

  // Get block dimensions.
  mlir::gpu::KernelDim3 block = op.getBlockSizeOperandValues();
  Value *bx = llvmValue(block.x), *by = llvmValue(block.y),
        *bz = llvmValue(block.z);

  // Get the dynamic shared memory size.
  Value *dynamicMemorySize = nullptr;
  if (mlir::Value dynSz = op.getDynamicSharedMemorySize())
    dynamicMemorySize = llvmValue(dynSz);
  else
    dynamicMemorySize = ConstantInt::get(i32Ty, 0);

  // Create the argument array.
  Value *argArray = createKernelArgArray(op);

  // Default JIT optimization level.
  llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
  // Check if there's an optimization level embedded in the object.
  DictionaryAttr objectProps = object.getProperties();
  mlir::Attribute optAttr;
  if (objectProps && (optAttr = objectProps.get("O"))) {
    auto optLevel = dyn_cast<IntegerAttr>(optAttr);
    if (!optLevel)
      return op.emitError("the optimization level must be an integer");
    optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
  }

  // Load the kernel module.
  StringRef moduleName = op.getKernelModuleName().getValue();
  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
  if (!binary)
    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;

  auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
  if (!binaryVar)
    return op.emitError() << "Binary is not a global variable: "
                          << binaryIdentifier;
  llvm::Constant *binaryInit = binaryVar->getInitializer();
  auto binaryDataSeq =
      dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
  if (!binaryDataSeq)
    return op.emitError() << "Couldn't find binary data array: "
                          << binaryIdentifier;
  llvm::Constant *binarySize =
      llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
                                        binaryDataSeq->getElementByteSize());

  Value *moduleObject =
      object.getFormat() == gpu::CompilationTarget::Assembly
          ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
          : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});

  // Load the kernel function.
  Value *moduleFunction = builder.CreateCall(
      getModuleFunctionFn(),
      {moduleObject,
       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});

  // Get the stream to use for execution. If there's no async object then
  // create a stream to make a synchronous kernel launch.
  Value *stream = nullptr;
  bool handleStream = false;
  if (mlir::Value asyncObject = op.getAsyncObject()) {
    stream = llvmValue(asyncObject);
  } else {
    handleStream = true;
    stream = builder.CreateCall(getStreamCreateFn(), {});
  }

  llvm::Constant *paramsCount =
      llvm::ConstantInt::get(i64Ty, op.getNumKernelOperands());

  // Create the launch call.
  Value *nullPtr = ConstantPointerNull::get(ptrTy);

  // Launch the kernel with clusters if a cluster size was specified.
  if (op.hasClusterSize()) {
    mlir::gpu::KernelDim3 cluster = op.getClusterSizeOperandValues();
    Value *cx = llvmValue(cluster.x), *cy = llvmValue(cluster.y),
          *cz = llvmValue(cluster.z);
    builder.CreateCall(
        getClusterKernelLaunchFn(),
        ArrayRef<Value *>({moduleFunction, cx, cy, cz, gx, gy, gz, bx, by, bz,
                           dynamicMemorySize, stream, argArray, nullPtr}));
  } else {
    builder.CreateCall(getKernelLaunchFn(),
                       ArrayRef<Value *>({moduleFunction, gx, gy, gz, bx, by,
                                          bz, dynamicMemorySize, stream,
                                          argArray, nullPtr, paramsCount}));
  }

  // Sync and destroy the stream, for synchronous launches.
  if (handleStream) {
    builder.CreateCall(getStreamSyncFn(), {stream});
    builder.CreateCall(getStreamDestroyFn(), {stream});
  }

  // Unload the kernel module.
  builder.CreateCall(getModuleUnloadFn(), {moduleObject});

  return success();
}