1 //===- ObjectHandler.cpp - Implements base ObjectManager attributes -------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the `OffloadingLLVMTranslationAttrInterface` for the
10 // `SelectObject` attribute.
12 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
30 // Implementation of the `OffloadingLLVMTranslationAttrInterface` model.
31 class SelectObjectAttrImpl
32 : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel
<
33 SelectObjectAttrImpl
> {
35 // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
36 // global binary string.
37 LogicalResult
embedBinary(Attribute attribute
, Operation
*operation
,
38 llvm::IRBuilderBase
&builder
,
39 LLVM::ModuleTranslation
&moduleTranslation
) const;
41 // Translates a `gpu.launch_func` to a sequence of LLVM instructions resulting
42 // in a kernel launch call.
43 LogicalResult
launchKernel(Attribute attribute
,
44 Operation
*launchFuncOperation
,
45 Operation
*binaryOperation
,
46 llvm::IRBuilderBase
&builder
,
47 LLVM::ModuleTranslation
&moduleTranslation
) const;
49 // Returns the selected object for embedding.
50 gpu::ObjectAttr
getSelectedObject(gpu::BinaryOp op
) const;
52 // Returns an identifier for the global string holding the binary.
53 std::string
getBinaryIdentifier(StringRef binaryName
) {
54 return binaryName
.str() + "_bin_cst";
58 void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
59 DialectRegistry
®istry
) {
60 registry
.addExtension(+[](MLIRContext
*ctx
, gpu::GPUDialect
*dialect
) {
61 SelectObjectAttr::attachInterface
<SelectObjectAttrImpl
>(*ctx
);
66 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op
) const {
67 ArrayRef
<Attribute
> objects
= op
.getObjectsAttr().getValue();
69 // Obtain the index of the object to select.
71 if (Attribute target
=
72 cast
<gpu::SelectObjectAttr
>(op
.getOffloadingHandlerAttr())
74 // If the target attribute is a number it is the index. Otherwise compare
75 // the attribute to every target inside the object array to find the index.
76 if (auto indexAttr
= mlir::dyn_cast
<IntegerAttr
>(target
)) {
77 index
= indexAttr
.getInt();
79 for (auto [i
, attr
] : llvm::enumerate(objects
)) {
80 auto obj
= mlir::dyn_cast
<gpu::ObjectAttr
>(attr
);
81 if (obj
.getTarget() == target
) {
87 // If the target attribute is null then it's selecting the first object in
92 if (index
< 0 || index
>= static_cast<int64_t>(objects
.size())) {
93 op
->emitError("the requested target object couldn't be found");
96 return mlir::dyn_cast
<gpu::ObjectAttr
>(objects
[index
]);
99 LogicalResult
SelectObjectAttrImpl::embedBinary(
100 Attribute attribute
, Operation
*operation
, llvm::IRBuilderBase
&builder
,
101 LLVM::ModuleTranslation
&moduleTranslation
) const {
102 assert(operation
&& "The binary operation must be non null.");
106 auto op
= mlir::dyn_cast
<gpu::BinaryOp
>(operation
);
108 operation
->emitError("operation must be a GPU binary");
112 gpu::ObjectAttr object
= getSelectedObject(op
);
116 llvm::Module
*module
= moduleTranslation
.getLLVMModule();
118 // Embed the object as a global string.
119 llvm::Constant
*binary
= llvm::ConstantDataArray::getString(
120 builder
.getContext(), object
.getObject().getValue(), false);
121 llvm::GlobalVariable
*serializedObj
=
122 new llvm::GlobalVariable(*module
, binary
->getType(), true,
123 llvm::GlobalValue::LinkageTypes::InternalLinkage
,
124 binary
, getBinaryIdentifier(op
.getName()));
126 if (object
.getProperties()) {
127 if (auto section
= mlir::dyn_cast_or_null
<mlir::StringAttr
>(
128 object
.getProperties().get(gpu::elfSectionName
))) {
129 serializedObj
->setSection(section
.getValue());
132 serializedObj
->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage
);
133 serializedObj
->setAlignment(llvm::MaybeAlign(8));
134 serializedObj
->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None
);
142 LaunchKernel(Module
&module
, IRBuilderBase
&builder
,
143 mlir::LLVM::ModuleTranslation
&moduleTranslation
);
144 // Get the kernel launch callee.
145 FunctionCallee
getKernelLaunchFn();
147 // Get the kernel launch callee.
148 FunctionCallee
getClusterKernelLaunchFn();
150 // Get the module function callee.
151 FunctionCallee
getModuleFunctionFn();
153 // Get the module load callee.
154 FunctionCallee
getModuleLoadFn();
156 // Get the module load JIT callee.
157 FunctionCallee
getModuleLoadJITFn();
159 // Get the module unload callee.
160 FunctionCallee
getModuleUnloadFn();
162 // Get the stream create callee.
163 FunctionCallee
getStreamCreateFn();
165 // Get the stream destroy callee.
166 FunctionCallee
getStreamDestroyFn();
168 // Get the stream sync callee.
169 FunctionCallee
getStreamSyncFn();
171 // Ger or create the function name global string.
172 Value
*getOrCreateFunctionName(StringRef moduleName
, StringRef kernelName
);
174 // Create the void* kernel array for passing the arguments.
175 Value
*createKernelArgArray(mlir::gpu::LaunchFuncOp op
);
177 // Create the full kernel launch.
178 llvm::LogicalResult
createKernelLaunch(mlir::gpu::LaunchFuncOp op
,
179 mlir::gpu::ObjectAttr object
);
183 IRBuilderBase
&builder
;
184 mlir::LLVM::ModuleTranslation
&moduleTranslation
;
189 PointerType
*ptrTy
{};
194 LogicalResult
SelectObjectAttrImpl::launchKernel(
195 Attribute attribute
, Operation
*launchFuncOperation
,
196 Operation
*binaryOperation
, llvm::IRBuilderBase
&builder
,
197 LLVM::ModuleTranslation
&moduleTranslation
) const {
199 assert(launchFuncOperation
&& "The launch func operation must be non null.");
200 if (!launchFuncOperation
)
203 auto launchFuncOp
= mlir::dyn_cast
<gpu::LaunchFuncOp
>(launchFuncOperation
);
205 launchFuncOperation
->emitError("operation must be a GPU launch func Op.");
209 auto binOp
= mlir::dyn_cast
<gpu::BinaryOp
>(binaryOperation
);
211 binaryOperation
->emitError("operation must be a GPU binary.");
214 gpu::ObjectAttr object
= getSelectedObject(binOp
);
218 return llvm::LaunchKernel(*moduleTranslation
.getLLVMModule(), builder
,
220 .createKernelLaunch(launchFuncOp
, object
);
223 llvm::LaunchKernel::LaunchKernel(
224 Module
&module
, IRBuilderBase
&builder
,
225 mlir::LLVM::ModuleTranslation
&moduleTranslation
)
226 : module(module
), builder(builder
), moduleTranslation(moduleTranslation
) {
227 i32Ty
= builder
.getInt32Ty();
228 i64Ty
= builder
.getInt64Ty();
229 ptrTy
= builder
.getPtrTy(0);
230 voidTy
= builder
.getVoidTy();
231 intPtrTy
= builder
.getIntPtrTy(module
.getDataLayout());
234 llvm::FunctionCallee
llvm::LaunchKernel::getKernelLaunchFn() {
235 return module
.getOrInsertFunction(
237 FunctionType::get(voidTy
,
238 ArrayRef
<Type
*>({ptrTy
, intPtrTy
, intPtrTy
, intPtrTy
,
239 intPtrTy
, intPtrTy
, intPtrTy
, i32Ty
,
240 ptrTy
, ptrTy
, ptrTy
, i64Ty
}),
244 llvm::FunctionCallee
llvm::LaunchKernel::getClusterKernelLaunchFn() {
245 return module
.getOrInsertFunction(
246 "mgpuLaunchClusterKernel",
249 ArrayRef
<Type
*>({ptrTy
, intPtrTy
, intPtrTy
, intPtrTy
, intPtrTy
,
250 intPtrTy
, intPtrTy
, intPtrTy
, intPtrTy
, intPtrTy
,
251 i32Ty
, ptrTy
, ptrTy
, ptrTy
}),
255 llvm::FunctionCallee
llvm::LaunchKernel::getModuleFunctionFn() {
256 return module
.getOrInsertFunction(
257 "mgpuModuleGetFunction",
258 FunctionType::get(ptrTy
, ArrayRef
<Type
*>({ptrTy
, ptrTy
}), false));
261 llvm::FunctionCallee
llvm::LaunchKernel::getModuleLoadFn() {
262 return module
.getOrInsertFunction(
264 FunctionType::get(ptrTy
, ArrayRef
<Type
*>({ptrTy
, i64Ty
}), false));
267 llvm::FunctionCallee
llvm::LaunchKernel::getModuleLoadJITFn() {
268 return module
.getOrInsertFunction(
270 FunctionType::get(ptrTy
, ArrayRef
<Type
*>({ptrTy
, i32Ty
}), false));
273 llvm::FunctionCallee
llvm::LaunchKernel::getModuleUnloadFn() {
274 return module
.getOrInsertFunction(
276 FunctionType::get(voidTy
, ArrayRef
<Type
*>({ptrTy
}), false));
279 llvm::FunctionCallee
llvm::LaunchKernel::getStreamCreateFn() {
280 return module
.getOrInsertFunction("mgpuStreamCreate",
281 FunctionType::get(ptrTy
, false));
284 llvm::FunctionCallee
llvm::LaunchKernel::getStreamDestroyFn() {
285 return module
.getOrInsertFunction(
287 FunctionType::get(voidTy
, ArrayRef
<Type
*>({ptrTy
}), false));
290 llvm::FunctionCallee
llvm::LaunchKernel::getStreamSyncFn() {
291 return module
.getOrInsertFunction(
292 "mgpuStreamSynchronize",
293 FunctionType::get(voidTy
, ArrayRef
<Type
*>({ptrTy
}), false));
296 // Generates an LLVM IR dialect global that contains the name of the given
297 // kernel function as a C string, and returns a pointer to its beginning.
298 llvm::Value
*llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName
,
299 StringRef kernelName
) {
300 std::string globalName
=
301 std::string(formatv("{0}_{1}_kernel_name", moduleName
, kernelName
));
303 if (GlobalVariable
*gv
= module
.getGlobalVariable(globalName
))
306 return builder
.CreateGlobalString(kernelName
, globalName
);
309 // Creates a struct containing all kernel parameters on the stack and returns
310 // an array of type-erased pointers to the fields of the struct. The array can
311 // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
312 // The generated code is essentially as follows:
314 // %struct = alloca(sizeof(struct { Parameters... }))
315 // %array = alloca(NumParameters * sizeof(void *))
316 // for (i : [0, NumParameters))
317 // %fieldPtr = llvm.getelementptr %struct[0, i]
318 // llvm.store parameters[i], %fieldPtr
319 // %elementPtr = llvm.getelementptr %array[i]
320 // llvm.store %fieldPtr, %elementPtr
323 llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op
) {
324 SmallVector
<Value
*> args
=
325 moduleTranslation
.lookupValues(op
.getKernelOperands());
326 SmallVector
<Type
*> structTypes(args
.size(), nullptr);
328 for (auto [i
, arg
] : llvm::enumerate(args
))
329 structTypes
[i
] = arg
->getType();
331 Type
*structTy
= StructType::create(module
.getContext(), structTypes
);
332 Value
*argStruct
= builder
.CreateAlloca(structTy
, 0u);
333 Value
*argArray
= builder
.CreateAlloca(
334 ptrTy
, ConstantInt::get(intPtrTy
, structTypes
.size()));
336 for (auto [i
, arg
] : enumerate(args
)) {
337 Value
*structMember
= builder
.CreateStructGEP(structTy
, argStruct
, i
);
338 builder
.CreateStore(arg
, structMember
);
339 Value
*arrayMember
= builder
.CreateConstGEP1_32(ptrTy
, argArray
, i
);
340 builder
.CreateStore(structMember
, arrayMember
);
345 // Emits LLVM IR to launch a kernel function:
346 // %0 = call %binarygetter
347 // %1 = call %moduleLoad(%0)
348 // %2 = <see generateKernelNameConstant>
349 // %3 = call %moduleGetFunction(%1, %2)
350 // %4 = call %streamCreate()
351 // %5 = <see generateParamsArray>
352 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
353 // call %streamSynchronize(%4)
354 // call %streamDestroy(%4)
355 // call %moduleUnload(%1)
357 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op
,
358 mlir::gpu::ObjectAttr object
) {
359 auto llvmValue
= [&](mlir::Value value
) -> Value
* {
360 Value
*v
= moduleTranslation
.lookupValue(value
);
361 assert(v
&& "Value has not been translated.");
365 // Get grid dimensions.
366 mlir::gpu::KernelDim3 grid
= op
.getGridSizeOperandValues();
367 Value
*gx
= llvmValue(grid
.x
), *gy
= llvmValue(grid
.y
),
368 *gz
= llvmValue(grid
.z
);
370 // Get block dimensions.
371 mlir::gpu::KernelDim3 block
= op
.getBlockSizeOperandValues();
372 Value
*bx
= llvmValue(block
.x
), *by
= llvmValue(block
.y
),
373 *bz
= llvmValue(block
.z
);
375 // Get dynamic shared memory size.
376 Value
*dynamicMemorySize
= nullptr;
377 if (mlir::Value dynSz
= op
.getDynamicSharedMemorySize())
378 dynamicMemorySize
= llvmValue(dynSz
);
380 dynamicMemorySize
= ConstantInt::get(i32Ty
, 0);
382 // Create the argument array.
383 Value
*argArray
= createKernelArgArray(op
);
385 // Default JIT optimization level.
386 llvm::Constant
*optV
= llvm::ConstantInt::get(i32Ty
, 0);
387 // Check if there's an optimization level embedded in the object.
388 DictionaryAttr objectProps
= object
.getProperties();
389 mlir::Attribute optAttr
;
390 if (objectProps
&& (optAttr
= objectProps
.get("O"))) {
391 auto optLevel
= dyn_cast
<IntegerAttr
>(optAttr
);
393 return op
.emitError("the optimization level must be an integer");
394 optV
= llvm::ConstantInt::get(i32Ty
, optLevel
.getValue());
397 // Load the kernel module.
398 StringRef moduleName
= op
.getKernelModuleName().getValue();
399 std::string binaryIdentifier
= getBinaryIdentifier(moduleName
);
400 Value
*binary
= module
.getGlobalVariable(binaryIdentifier
, true);
402 return op
.emitError() << "Couldn't find the binary: " << binaryIdentifier
;
404 auto binaryVar
= dyn_cast
<llvm::GlobalVariable
>(binary
);
406 return op
.emitError() << "Binary is not a global variable: "
408 llvm::Constant
*binaryInit
= binaryVar
->getInitializer();
410 dyn_cast_if_present
<llvm::ConstantDataSequential
>(binaryInit
);
412 return op
.emitError() << "Couldn't find binary data array: "
414 llvm::Constant
*binarySize
=
415 llvm::ConstantInt::get(i64Ty
, binaryDataSeq
->getNumElements() *
416 binaryDataSeq
->getElementByteSize());
418 Value
*moduleObject
=
419 object
.getFormat() == gpu::CompilationTarget::Assembly
420 ? builder
.CreateCall(getModuleLoadJITFn(), {binary
, optV
})
421 : builder
.CreateCall(getModuleLoadFn(), {binary
, binarySize
});
423 // Load the kernel function.
424 Value
*moduleFunction
= builder
.CreateCall(
425 getModuleFunctionFn(),
427 getOrCreateFunctionName(moduleName
, op
.getKernelName().getValue())});
429 // Get the stream to use for execution. If there's no async object then create
430 // a stream to make a synchronous kernel launch.
431 Value
*stream
= nullptr;
432 bool handleStream
= false;
433 if (mlir::Value asyncObject
= op
.getAsyncObject()) {
434 stream
= llvmValue(asyncObject
);
437 stream
= builder
.CreateCall(getStreamCreateFn(), {});
440 llvm::Constant
*paramsCount
=
441 llvm::ConstantInt::get(i64Ty
, op
.getNumKernelOperands());
443 // Create the launch call.
444 Value
*nullPtr
= ConstantPointerNull::get(ptrTy
);
446 // Launch kernel with clusters if cluster size is specified.
447 if (op
.hasClusterSize()) {
448 mlir::gpu::KernelDim3 cluster
= op
.getClusterSizeOperandValues();
449 Value
*cx
= llvmValue(cluster
.x
), *cy
= llvmValue(cluster
.y
),
450 *cz
= llvmValue(cluster
.z
);
452 getClusterKernelLaunchFn(),
453 ArrayRef
<Value
*>({moduleFunction
, cx
, cy
, cz
, gx
, gy
, gz
, bx
, by
, bz
,
454 dynamicMemorySize
, stream
, argArray
, nullPtr
}));
456 builder
.CreateCall(getKernelLaunchFn(),
457 ArrayRef
<Value
*>({moduleFunction
, gx
, gy
, gz
, bx
, by
,
458 bz
, dynamicMemorySize
, stream
,
459 argArray
, nullPtr
, paramsCount
}));
462 // Sync & destroy the stream, for synchronous launches.
464 builder
.CreateCall(getStreamSyncFn(), {stream
});
465 builder
.CreateCall(getStreamDestroyFn(), {stream
});
468 // Unload the kernel module.
469 builder
.CreateCall(getModuleUnloadFn(), {moduleObject
});