//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes to convert `gpu.launch_func` op into a sequence
// of LLVM calls that emulate the host and device sides.
//
//===----------------------------------------------------------------------===//
14 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
16 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
17 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
18 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
19 #include "mlir/Conversion/LLVMCommon/Pattern.h"
20 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
21 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
22 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
27 #include "mlir/IR/BuiltinOps.h"
28 #include "mlir/IR/SymbolTable.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/DialectConversion.h"
31 #include "llvm/ADT/DenseMap.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/FormatVariadic.h"
36 #define GEN_PASS_DEF_LOWERHOSTCODETOLLVMPASS
37 #include "mlir/Conversion/Passes.h.inc"
/// Name prefix of SPIR-V kernel modules. The launch lowering looks up the
/// kernel module as `__spv__{kernel_module_name}`, and the same string is the
/// fallback name for a SPIR-V module that has no symbol name.
static constexpr const char kSPIRVModule[] = "__spv__";
44 //===----------------------------------------------------------------------===//
46 //===----------------------------------------------------------------------===//
48 /// Returns the string name of the `DescriptorSet` decoration.
49 static std::string
descriptorSetName() {
50 return llvm::convertToSnakeFromCamelCase(
51 stringifyDecoration(spirv::Decoration::DescriptorSet
));
54 /// Returns the string name of the `Binding` decoration.
55 static std::string
bindingName() {
56 return llvm::convertToSnakeFromCamelCase(
57 stringifyDecoration(spirv::Decoration::Binding
));
60 /// Calculates the index of the kernel's operand that is represented by the
61 /// given global variable with the `bind` attribute. We assume that the index of
62 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
64 /// which is implemented under `LowerABIAttributesPass`.
65 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op
) {
66 IntegerAttr binding
= op
->getAttrOfType
<IntegerAttr
>(bindingName());
67 return binding
.getInt();
70 /// Copies the given number of bytes from src to dst pointers.
71 static void copy(Location loc
, Value dst
, Value src
, Value size
,
73 builder
.create
<LLVM::MemcpyOp
>(loc
, dst
, src
, size
, /*isVolatile=*/false);
76 /// Encodes the binding and descriptor set numbers into a new symbolic name.
77 /// The name is specified by
78 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
79 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
82 createGlobalVariableWithBindName(spirv::GlobalVariableOp op
,
83 StringRef kernelModuleName
) {
84 IntegerAttr descriptorSet
=
85 op
->getAttrOfType
<IntegerAttr
>(descriptorSetName());
86 IntegerAttr binding
= op
->getAttrOfType
<IntegerAttr
>(bindingName());
87 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
88 kernelModuleName
.str(), op
.getSymName().str(),
89 std::to_string(descriptorSet
.getInt()),
90 std::to_string(binding
.getInt()));
93 /// Returns true if the given global variable has both a descriptor set number
94 /// and a binding number.
95 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op
) {
96 IntegerAttr descriptorSet
=
97 op
->getAttrOfType
<IntegerAttr
>(descriptorSetName());
98 IntegerAttr binding
= op
->getAttrOfType
<IntegerAttr
>(bindingName());
99 return descriptorSet
&& binding
;
102 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
103 /// arguments from the given SPIR-V module. We assume that the module contains a
104 /// single entry point function. Hence, all `spirv.GlobalVariable`s with a bind
105 /// attribute are kernel arguments.
106 static LogicalResult
getKernelGlobalVariables(
107 spirv::ModuleOp module
,
108 DenseMap
<uint32_t, spirv::GlobalVariableOp
> &globalVariableMap
) {
109 auto entryPoints
= module
.getOps
<spirv::EntryPointOp
>();
110 if (!llvm::hasSingleElement(entryPoints
)) {
111 return module
.emitError(
112 "The module must contain exactly one entry point function");
114 auto globalVariables
= module
.getOps
<spirv::GlobalVariableOp
>();
115 for (auto globalOp
: globalVariables
) {
116 if (hasDescriptorSetAndBinding(globalOp
))
117 globalVariableMap
[calculateGlobalIndex(globalOp
)] = globalOp
;
122 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
124 static LogicalResult
encodeKernelName(spirv::ModuleOp module
) {
125 StringRef spvModuleName
= module
.getSymName().value_or(kSPIRVModule
);
126 // We already know that the module contains exactly one entry point function
127 // based on `getKernelGlobalVariables()` call. Update this function's name
129 // {spv_module_name}_{function_name}
130 auto entryPoints
= module
.getOps
<spirv::EntryPointOp
>();
131 if (!llvm::hasSingleElement(entryPoints
)) {
132 return module
.emitError(
133 "The module must contain exactly one entry point function");
135 spirv::EntryPointOp entryPoint
= *entryPoints
.begin();
136 StringRef funcName
= entryPoint
.getFn();
137 auto funcOp
= module
.lookupSymbol
<spirv::FuncOp
>(entryPoint
.getFnAttr());
138 StringAttr newFuncName
=
139 StringAttr::get(module
->getContext(), spvModuleName
+ "_" + funcName
);
140 if (failed(SymbolTable::replaceAllSymbolUses(funcOp
, newFuncName
, module
)))
142 SymbolTable::setSymbolName(funcOp
, newFuncName
);
146 //===----------------------------------------------------------------------===//
147 // Conversion patterns
148 //===----------------------------------------------------------------------===//
152 /// Structure to group information about the variables being copied.
159 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
160 /// copy the data to the global variable (emulating device side), call the
161 /// kernel as a normal void LLVM function, and copy the data back (emulating the
163 class GPULaunchLowering
: public ConvertOpToLLVMPattern
<gpu::LaunchFuncOp
> {
164 using ConvertOpToLLVMPattern
<gpu::LaunchFuncOp
>::ConvertOpToLLVMPattern
;
167 matchAndRewrite(gpu::LaunchFuncOp launchOp
, OpAdaptor adaptor
,
168 ConversionPatternRewriter
&rewriter
) const override
{
169 auto *op
= launchOp
.getOperation();
170 MLIRContext
*context
= rewriter
.getContext();
171 auto module
= launchOp
->getParentOfType
<ModuleOp
>();
173 // Get the SPIR-V module that represents the gpu kernel module. The module
175 // __spv__{kernel_module_name}
176 // based on GPU to SPIR-V conversion.
177 StringRef kernelModuleName
= launchOp
.getKernelModuleName().getValue();
178 std::string spvModuleName
= kSPIRVModule
+ kernelModuleName
.str();
179 auto spvModule
= module
.lookupSymbol
<spirv::ModuleOp
>(
180 StringAttr::get(context
, spvModuleName
));
182 return launchOp
.emitOpError("SPIR-V kernel module '")
183 << spvModuleName
<< "' is not found";
186 // Declare kernel function in the main module so that it later can be linked
187 // with its definition from the kernel module. We know that the kernel
188 // function would have no arguments and the data is passed via global
189 // variables. The name of the kernel will be
190 // {spv_module_name}_{kernel_function_name}
191 // to avoid symbolic name conflicts.
192 StringRef kernelFuncName
= launchOp
.getKernelName().getValue();
193 std::string newKernelFuncName
= spvModuleName
+ "_" + kernelFuncName
.str();
194 auto kernelFunc
= module
.lookupSymbol
<LLVM::LLVMFuncOp
>(
195 StringAttr::get(context
, newKernelFuncName
));
197 OpBuilder::InsertionGuard
guard(rewriter
);
198 rewriter
.setInsertionPointToStart(module
.getBody());
199 kernelFunc
= rewriter
.create
<LLVM::LLVMFuncOp
>(
200 rewriter
.getUnknownLoc(), newKernelFuncName
,
201 LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context
),
203 rewriter
.setInsertionPoint(launchOp
);
206 // Get all global variables associated with the kernel operands.
207 DenseMap
<uint32_t, spirv::GlobalVariableOp
> globalVariableMap
;
208 if (failed(getKernelGlobalVariables(spvModule
, globalVariableMap
)))
211 // Traverse kernel operands that were converted to MemRefDescriptors. For
212 // each operand, create a global variable and copy data from operand to it.
213 Location loc
= launchOp
.getLoc();
214 SmallVector
<CopyInfo
, 4> copyInfo
;
215 auto numKernelOperands
= launchOp
.getNumKernelOperands();
216 auto kernelOperands
= adaptor
.getOperands().take_back(numKernelOperands
);
217 for (const auto &operand
: llvm::enumerate(kernelOperands
)) {
218 // Check if the kernel's operand is a ranked memref.
219 auto memRefType
= dyn_cast
<MemRefType
>(
220 launchOp
.getKernelOperand(operand
.index()).getType());
224 // Calculate the size of the memref and get the pointer to the allocated
226 SmallVector
<Value
, 4> sizes
;
227 SmallVector
<Value
, 4> strides
;
229 getMemRefDescriptorSizes(loc
, memRefType
, {}, rewriter
, sizes
, strides
,
231 MemRefDescriptor
descriptor(operand
.value());
232 Value src
= descriptor
.allocatedPtr(rewriter
, loc
);
234 // Get the global variable in the SPIR-V module that is associated with
235 // the kernel operand. Construct its new name and create a corresponding
236 // LLVM dialect global variable.
237 spirv::GlobalVariableOp spirvGlobal
= globalVariableMap
[operand
.index()];
239 cast
<spirv::PointerType
>(spirvGlobal
.getType()).getPointeeType();
240 auto dstGlobalType
= typeConverter
->convertType(pointeeType
);
244 createGlobalVariableWithBindName(spirvGlobal
, spvModuleName
);
245 // Check if this variable has already been created.
246 auto dstGlobal
= module
.lookupSymbol
<LLVM::GlobalOp
>(name
);
248 OpBuilder::InsertionGuard
guard(rewriter
);
249 rewriter
.setInsertionPointToStart(module
.getBody());
250 dstGlobal
= rewriter
.create
<LLVM::GlobalOp
>(
252 /*isConstant=*/false, LLVM::Linkage::Linkonce
, name
, Attribute(),
254 rewriter
.setInsertionPoint(launchOp
);
257 // Copy the data from src operand pointer to dst global variable. Save
258 // src, dst and size so that we can copy data back after emulating the
260 Value dst
= rewriter
.create
<LLVM::AddressOfOp
>(
261 loc
, typeConverter
->convertType(spirvGlobal
.getType()),
262 dstGlobal
.getSymName());
263 copy(loc
, dst
, src
, sizeBytes
, rewriter
);
268 info
.size
= sizeBytes
;
269 copyInfo
.push_back(info
);
271 // Create a call to the kernel and copy the data back.
272 rewriter
.replaceOpWithNewOp
<LLVM::CallOp
>(op
, kernelFunc
,
274 for (CopyInfo info
: copyInfo
)
275 copy(loc
, info
.src
, info
.dst
, info
.size
, rewriter
);
280 class LowerHostCodeToLLVM
281 : public impl::LowerHostCodeToLLVMPassBase
<LowerHostCodeToLLVM
> {
285 void runOnOperation() override
{
286 ModuleOp module
= getOperation();
288 // Erase the GPU module.
289 for (auto gpuModule
:
290 llvm::make_early_inc_range(module
.getOps
<gpu::GPUModuleOp
>()))
293 // Request C wrapper emission.
294 for (auto func
: module
.getOps
<func::FuncOp
>()) {
295 func
->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
296 UnitAttr::get(&getContext()));
299 // Specify options to lower to LLVM and pull in the conversion patterns.
300 LowerToLLVMOptions
options(module
.getContext());
302 auto *context
= module
.getContext();
303 RewritePatternSet
patterns(context
);
304 LLVMTypeConverter
typeConverter(context
, options
);
305 mlir::arith::populateArithToLLVMConversionPatterns(typeConverter
, patterns
);
306 populateFinalizeMemRefToLLVMConversionPatterns(typeConverter
, patterns
);
307 populateFuncToLLVMConversionPatterns(typeConverter
, patterns
);
308 patterns
.add
<GPULaunchLowering
>(typeConverter
);
310 // Pull in SPIR-V type conversion patterns to convert SPIR-V global
311 // variable's type to LLVM dialect type.
312 populateSPIRVToLLVMTypeConversion(typeConverter
);
314 ConversionTarget
target(*context
);
315 target
.addLegalDialect
<LLVM::LLVMDialect
>();
316 if (failed(applyPartialConversion(module
, target
, std::move(patterns
))))
319 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
321 for (auto spvModule
: module
.getOps
<spirv::ModuleOp
>()) {
322 if (failed(encodeKernelName(spvModule
))) {