//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes to convert `gpu.launch_func` op into a sequence
// of LLVM calls that emulate the host and device sides.
//
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
/// Prefix prepended to a GPU kernel module's name to form the symbol name of
/// the corresponding `spv.module`, i.e. `__spv__{kernel_module_name}`.
static constexpr char kSPIRVModule[] = "__spv__";
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
37 /// Returns the string name of the `DescriptorSet` decoration.
38 static std::string
descriptorSetName() {
39 return llvm::convertToSnakeFromCamelCase(
40 stringifyDecoration(spirv::Decoration::DescriptorSet
));
43 /// Returns the string name of the `Binding` decoration.
44 static std::string
bindingName() {
45 return llvm::convertToSnakeFromCamelCase(
46 stringifyDecoration(spirv::Decoration::Binding
));
49 /// Calculates the index of the kernel's operand that is represented by the
50 /// given global variable with the `bind` attribute. We assume that the index of
51 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
53 /// which is implemented under `LowerABIAttributesPass`.
54 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op
) {
55 IntegerAttr binding
= op
->getAttrOfType
<IntegerAttr
>(bindingName());
56 return binding
.getInt();
59 /// Copies the given number of bytes from src to dst pointers.
60 static void copy(Location loc
, Value dst
, Value src
, Value size
,
62 MLIRContext
*context
= builder
.getContext();
63 auto llvmI1Type
= LLVM::LLVMType::getInt1Ty(context
);
64 Value isVolatile
= builder
.create
<LLVM::ConstantOp
>(
65 loc
, llvmI1Type
, builder
.getBoolAttr(false));
66 builder
.create
<LLVM::MemcpyOp
>(loc
, dst
, src
, size
, isVolatile
);
69 /// Encodes the binding and descriptor set numbers into a new symbolic name.
70 /// The name is specified by
71 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
72 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
75 createGlobalVariableWithBindName(spirv::GlobalVariableOp op
,
76 StringRef kernelModuleName
) {
77 IntegerAttr descriptorSet
=
78 op
->getAttrOfType
<IntegerAttr
>(descriptorSetName());
79 IntegerAttr binding
= op
->getAttrOfType
<IntegerAttr
>(bindingName());
80 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
81 kernelModuleName
.str(), op
.sym_name().str(),
82 std::to_string(descriptorSet
.getInt()),
83 std::to_string(binding
.getInt()));
86 /// Returns true if the given global variable has both a descriptor set number
87 /// and a binding number.
88 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op
) {
89 IntegerAttr descriptorSet
=
90 op
->getAttrOfType
<IntegerAttr
>(descriptorSetName());
91 IntegerAttr binding
= op
->getAttrOfType
<IntegerAttr
>(bindingName());
92 return descriptorSet
&& binding
;
95 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
96 /// arguments from the given SPIR-V module. We assume that the module contains a
97 /// single entry point function. Hence, all `spv.globalVariable`s with a bind
98 /// attribute are kernel arguments.
99 static LogicalResult
getKernelGlobalVariables(
100 spirv::ModuleOp module
,
101 DenseMap
<uint32_t, spirv::GlobalVariableOp
> &globalVariableMap
) {
102 auto entryPoints
= module
.getOps
<spirv::EntryPointOp
>();
103 if (!llvm::hasSingleElement(entryPoints
)) {
104 return module
.emitError(
105 "The module must contain exactly one entry point function");
107 auto globalVariables
= module
.getOps
<spirv::GlobalVariableOp
>();
108 for (auto globalOp
: globalVariables
) {
109 if (hasDescriptorSetAndBinding(globalOp
))
110 globalVariableMap
[calculateGlobalIndex(globalOp
)] = globalOp
;
115 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
117 static LogicalResult
encodeKernelName(spirv::ModuleOp module
) {
118 StringRef spvModuleName
= module
.sym_name().getValue();
119 // We already know that the module contains exactly one entry point function
120 // based on `getKernelGlobalVariables()` call. Update this function's name
122 // {spv_module_name}_{function_name}
123 auto entryPoint
= *module
.getOps
<spirv::EntryPointOp
>().begin();
124 StringRef funcName
= entryPoint
.fn();
125 auto funcOp
= module
.lookupSymbol
<spirv::FuncOp
>(funcName
);
126 std::string newFuncName
= spvModuleName
.str() + "_" + funcName
.str();
127 if (failed(SymbolTable::replaceAllSymbolUses(funcOp
, newFuncName
, module
)))
129 SymbolTable::setSymbolName(funcOp
, newFuncName
);
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
139 /// Structure to group information about the variables being copied.
146 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
147 /// copy the data to the global variable (emulating device side), call the
148 /// kernel as a normal void LLVM function, and copy the data back (emulating the
150 class GPULaunchLowering
: public ConvertOpToLLVMPattern
<gpu::LaunchFuncOp
> {
151 using ConvertOpToLLVMPattern
<gpu::LaunchFuncOp
>::ConvertOpToLLVMPattern
;
154 matchAndRewrite(gpu::LaunchFuncOp launchOp
, ArrayRef
<Value
> operands
,
155 ConversionPatternRewriter
&rewriter
) const override
{
156 auto *op
= launchOp
.getOperation();
157 MLIRContext
*context
= rewriter
.getContext();
158 auto module
= launchOp
->getParentOfType
<ModuleOp
>();
160 // Get the SPIR-V module that represents the gpu kernel module. The module
162 // __spv__{kernel_module_name}
163 // based on GPU to SPIR-V conversion.
164 StringRef kernelModuleName
= launchOp
.getKernelModuleName();
165 std::string spvModuleName
= kSPIRVModule
+ kernelModuleName
.str();
166 auto spvModule
= module
.lookupSymbol
<spirv::ModuleOp
>(spvModuleName
);
168 return launchOp
.emitOpError("SPIR-V kernel module '")
169 << spvModuleName
<< "' is not found";
172 // Declare kernel function in the main module so that it later can be linked
173 // with its definition from the kernel module. We know that the kernel
174 // function would have no arguments and the data is passed via global
175 // variables. The name of the kernel will be
176 // {spv_module_name}_{kernel_function_name}
177 // to avoid symbolic name conflicts.
178 StringRef kernelFuncName
= launchOp
.getKernelName();
179 std::string newKernelFuncName
= spvModuleName
+ "_" + kernelFuncName
.str();
180 auto kernelFunc
= module
.lookupSymbol
<LLVM::LLVMFuncOp
>(newKernelFuncName
);
182 OpBuilder::InsertionGuard
guard(rewriter
);
183 rewriter
.setInsertionPointToStart(module
.getBody());
184 kernelFunc
= rewriter
.create
<LLVM::LLVMFuncOp
>(
185 rewriter
.getUnknownLoc(), newKernelFuncName
,
186 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context
),
187 ArrayRef
<LLVM::LLVMType
>(),
188 /*isVarArg=*/false));
189 rewriter
.setInsertionPoint(launchOp
);
192 // Get all global variables associated with the kernel operands.
193 DenseMap
<uint32_t, spirv::GlobalVariableOp
> globalVariableMap
;
194 if (failed(getKernelGlobalVariables(spvModule
, globalVariableMap
)))
197 // Traverse kernel operands that were converted to MemRefDescriptors. For
198 // each operand, create a global variable and copy data from operand to it.
199 Location loc
= launchOp
.getLoc();
200 SmallVector
<CopyInfo
, 4> copyInfo
;
201 auto numKernelOperands
= launchOp
.getNumKernelOperands();
202 auto kernelOperands
= operands
.take_back(numKernelOperands
);
203 for (auto operand
: llvm::enumerate(kernelOperands
)) {
204 // Check if the kernel's opernad is a ranked memref.
205 auto memRefType
= launchOp
.getKernelOperand(operand
.index())
207 .dyn_cast
<MemRefType
>();
211 // Calculate the size of the memref and get the pointer to the allocated
213 SmallVector
<Value
, 4> sizes
;
214 SmallVector
<Value
, 4> strides
;
216 getMemRefDescriptorSizes(loc
, memRefType
, operand
.value(), rewriter
,
217 sizes
, strides
, sizeBytes
);
218 MemRefDescriptor
descriptor(operand
.value());
219 Value src
= descriptor
.allocatedPtr(rewriter
, loc
);
221 // Get the global variable in the SPIR-V module that is associated with
222 // the kernel operand. Construct its new name and create a corresponding
223 // LLVM dialect global variable.
224 spirv::GlobalVariableOp spirvGlobal
= globalVariableMap
[operand
.index()];
226 spirvGlobal
.type().cast
<spirv::PointerType
>().getPointeeType();
227 auto dstGlobalType
= typeConverter
->convertType(pointeeType
);
231 createGlobalVariableWithBindName(spirvGlobal
, spvModuleName
);
232 // Check if this variable has already been created.
233 auto dstGlobal
= module
.lookupSymbol
<LLVM::GlobalOp
>(name
);
235 OpBuilder::InsertionGuard
guard(rewriter
);
236 rewriter
.setInsertionPointToStart(module
.getBody());
237 dstGlobal
= rewriter
.create
<LLVM::GlobalOp
>(
238 loc
, dstGlobalType
.cast
<LLVM::LLVMType
>(),
239 /*isConstant=*/false, LLVM::Linkage::Linkonce
, name
, Attribute());
240 rewriter
.setInsertionPoint(launchOp
);
243 // Copy the data from src operand pointer to dst global variable. Save
244 // src, dst and size so that we can copy data back after emulating the
246 Value dst
= rewriter
.create
<LLVM::AddressOfOp
>(loc
, dstGlobal
);
247 copy(loc
, dst
, src
, sizeBytes
, rewriter
);
252 info
.size
= sizeBytes
;
253 copyInfo
.push_back(info
);
255 // Create a call to the kernel and copy the data back.
256 rewriter
.replaceOpWithNewOp
<LLVM::CallOp
>(op
, kernelFunc
,
258 for (CopyInfo info
: copyInfo
)
259 copy(loc
, info
.src
, info
.dst
, info
.size
, rewriter
);
264 class LowerHostCodeToLLVM
265 : public LowerHostCodeToLLVMBase
<LowerHostCodeToLLVM
> {
267 void runOnOperation() override
{
268 ModuleOp module
= getOperation();
270 // Erase the GPU module.
271 for (auto gpuModule
:
272 llvm::make_early_inc_range(module
.getOps
<gpu::GPUModuleOp
>()))
275 // Specify options to lower Standard to LLVM and pull in the conversion
277 LowerToLLVMOptions options
= {
278 /*useBarePtrCallConv=*/false,
279 /*emitCWrappers=*/true,
280 /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout
};
281 auto *context
= module
.getContext();
282 OwningRewritePatternList patterns
;
283 LLVMTypeConverter
typeConverter(context
, options
);
284 populateStdToLLVMConversionPatterns(typeConverter
, patterns
);
285 patterns
.insert
<GPULaunchLowering
>(typeConverter
);
287 // Pull in SPIR-V type conversion patterns to convert SPIR-V global
288 // variable's type to LLVM dialect type.
289 populateSPIRVToLLVMTypeConversion(typeConverter
);
291 ConversionTarget
target(*context
);
292 target
.addLegalDialect
<LLVM::LLVMDialect
>();
293 if (failed(applyPartialConversion(module
, target
, std::move(patterns
))))
296 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
298 for (auto spvModule
: module
.getOps
<spirv::ModuleOp
>())
299 encodeKernelName(spvModule
);
304 std::unique_ptr
<mlir::OperationPass
<mlir::ModuleOp
>>
305 mlir::createLowerHostCodeToLLVMPass() {
306 return std::make_unique
<LowerHostCodeToLLVM
>();