[mlir][spirv] NFC: Shuffle code around to better follow convention
[llvm-project.git] / mlir / lib / Conversion / SPIRVToLLVM / ConvertLaunchFuncToLLVMCalls.cpp
blob1724c7044339d5ebc259f93dd43b0538187c6a51
1 //===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements passes to convert `gpu.launch_func` op into a sequence
10 // of LLVM calls that emulate the host and device sides.
12 //===----------------------------------------------------------------------===//
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h"
16 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
18 #include "mlir/Dialect/GPU/GPUDialect.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/SymbolTable.h"
24 #include "mlir/Transforms/DialectConversion.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/Support/FormatVariadic.h"
29 using namespace mlir;
/// Prefix that the GPU to SPIR-V conversion prepends to a GPU kernel module's
/// name when it creates the corresponding SPIR-V module, i.e. the SPIR-V module
/// is named `__spv__{kernel_module_name}`.
static constexpr char kSPIRVModule[] = "__spv__";
33 //===----------------------------------------------------------------------===//
34 // Utility functions
35 //===----------------------------------------------------------------------===//
37 /// Returns the string name of the `DescriptorSet` decoration.
38 static std::string descriptorSetName() {
39 return llvm::convertToSnakeFromCamelCase(
40 stringifyDecoration(spirv::Decoration::DescriptorSet));
43 /// Returns the string name of the `Binding` decoration.
44 static std::string bindingName() {
45 return llvm::convertToSnakeFromCamelCase(
46 stringifyDecoration(spirv::Decoration::Binding));
49 /// Calculates the index of the kernel's operand that is represented by the
50 /// given global variable with the `bind` attribute. We assume that the index of
51 /// each kernel's operand is mapped to (descriptorSet, binding) by the map:
52 /// i -> (0, i)
53 /// which is implemented under `LowerABIAttributesPass`.
54 static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
55 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
56 return binding.getInt();
59 /// Copies the given number of bytes from src to dst pointers.
60 static void copy(Location loc, Value dst, Value src, Value size,
61 OpBuilder &builder) {
62 MLIRContext *context = builder.getContext();
63 auto llvmI1Type = LLVM::LLVMType::getInt1Ty(context);
64 Value isVolatile = builder.create<LLVM::ConstantOp>(
65 loc, llvmI1Type, builder.getBoolAttr(false));
66 builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
69 /// Encodes the binding and descriptor set numbers into a new symbolic name.
70 /// The name is specified by
71 /// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
72 /// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
73 /// binding numbers.
74 static std::string
75 createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
76 StringRef kernelModuleName) {
77 IntegerAttr descriptorSet =
78 op->getAttrOfType<IntegerAttr>(descriptorSetName());
79 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
80 return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
81 kernelModuleName.str(), op.sym_name().str(),
82 std::to_string(descriptorSet.getInt()),
83 std::to_string(binding.getInt()));
86 /// Returns true if the given global variable has both a descriptor set number
87 /// and a binding number.
88 static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
89 IntegerAttr descriptorSet =
90 op->getAttrOfType<IntegerAttr>(descriptorSetName());
91 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
92 return descriptorSet && binding;
95 /// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
96 /// arguments from the given SPIR-V module. We assume that the module contains a
97 /// single entry point function. Hence, all `spv.globalVariable`s with a bind
98 /// attribute are kernel arguments.
99 static LogicalResult getKernelGlobalVariables(
100 spirv::ModuleOp module,
101 DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
102 auto entryPoints = module.getOps<spirv::EntryPointOp>();
103 if (!llvm::hasSingleElement(entryPoints)) {
104 return module.emitError(
105 "The module must contain exactly one entry point function");
107 auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
108 for (auto globalOp : globalVariables) {
109 if (hasDescriptorSetAndBinding(globalOp))
110 globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
112 return success();
115 /// Encodes the SPIR-V module's symbolic name into the name of the entry point
116 /// function.
117 static LogicalResult encodeKernelName(spirv::ModuleOp module) {
118 StringRef spvModuleName = module.sym_name().getValue();
119 // We already know that the module contains exactly one entry point function
120 // based on `getKernelGlobalVariables()` call. Update this function's name
121 // to:
122 // {spv_module_name}_{function_name}
123 auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
124 StringRef funcName = entryPoint.fn();
125 auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName);
126 std::string newFuncName = spvModuleName.str() + "_" + funcName.str();
127 if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
128 return failure();
129 SymbolTable::setSymbolName(funcOp, newFuncName);
130 return success();
133 //===----------------------------------------------------------------------===//
134 // Conversion patterns
135 //===----------------------------------------------------------------------===//
137 namespace {
139 /// Structure to group information about the variables being copied.
140 struct CopyInfo {
141 Value dst;
142 Value src;
143 Value size;
146 /// This pattern emulates a call to the kernel in LLVM dialect. For that, we
147 /// copy the data to the global variable (emulating device side), call the
148 /// kernel as a normal void LLVM function, and copy the data back (emulating the
149 /// host side).
150 class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
151 using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
153 LogicalResult
154 matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
155 ConversionPatternRewriter &rewriter) const override {
156 auto *op = launchOp.getOperation();
157 MLIRContext *context = rewriter.getContext();
158 auto module = launchOp->getParentOfType<ModuleOp>();
160 // Get the SPIR-V module that represents the gpu kernel module. The module
161 // is named:
162 // __spv__{kernel_module_name}
163 // based on GPU to SPIR-V conversion.
164 StringRef kernelModuleName = launchOp.getKernelModuleName();
165 std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
166 auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName);
167 if (!spvModule) {
168 return launchOp.emitOpError("SPIR-V kernel module '")
169 << spvModuleName << "' is not found";
172 // Declare kernel function in the main module so that it later can be linked
173 // with its definition from the kernel module. We know that the kernel
174 // function would have no arguments and the data is passed via global
175 // variables. The name of the kernel will be
176 // {spv_module_name}_{kernel_function_name}
177 // to avoid symbolic name conflicts.
178 StringRef kernelFuncName = launchOp.getKernelName();
179 std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
180 auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName);
181 if (!kernelFunc) {
182 OpBuilder::InsertionGuard guard(rewriter);
183 rewriter.setInsertionPointToStart(module.getBody());
184 kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
185 rewriter.getUnknownLoc(), newKernelFuncName,
186 LLVM::LLVMType::getFunctionTy(LLVM::LLVMType::getVoidTy(context),
187 ArrayRef<LLVM::LLVMType>(),
188 /*isVarArg=*/false));
189 rewriter.setInsertionPoint(launchOp);
192 // Get all global variables associated with the kernel operands.
193 DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
194 if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
195 return failure();
197 // Traverse kernel operands that were converted to MemRefDescriptors. For
198 // each operand, create a global variable and copy data from operand to it.
199 Location loc = launchOp.getLoc();
200 SmallVector<CopyInfo, 4> copyInfo;
201 auto numKernelOperands = launchOp.getNumKernelOperands();
202 auto kernelOperands = operands.take_back(numKernelOperands);
203 for (auto operand : llvm::enumerate(kernelOperands)) {
204 // Check if the kernel's opernad is a ranked memref.
205 auto memRefType = launchOp.getKernelOperand(operand.index())
206 .getType()
207 .dyn_cast<MemRefType>();
208 if (!memRefType)
209 return failure();
211 // Calculate the size of the memref and get the pointer to the allocated
212 // buffer.
213 SmallVector<Value, 4> sizes;
214 SmallVector<Value, 4> strides;
215 Value sizeBytes;
216 getMemRefDescriptorSizes(loc, memRefType, operand.value(), rewriter,
217 sizes, strides, sizeBytes);
218 MemRefDescriptor descriptor(operand.value());
219 Value src = descriptor.allocatedPtr(rewriter, loc);
221 // Get the global variable in the SPIR-V module that is associated with
222 // the kernel operand. Construct its new name and create a corresponding
223 // LLVM dialect global variable.
224 spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
225 auto pointeeType =
226 spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
227 auto dstGlobalType = typeConverter->convertType(pointeeType);
228 if (!dstGlobalType)
229 return failure();
230 std::string name =
231 createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
232 // Check if this variable has already been created.
233 auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
234 if (!dstGlobal) {
235 OpBuilder::InsertionGuard guard(rewriter);
236 rewriter.setInsertionPointToStart(module.getBody());
237 dstGlobal = rewriter.create<LLVM::GlobalOp>(
238 loc, dstGlobalType.cast<LLVM::LLVMType>(),
239 /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute());
240 rewriter.setInsertionPoint(launchOp);
243 // Copy the data from src operand pointer to dst global variable. Save
244 // src, dst and size so that we can copy data back after emulating the
245 // kernel call.
246 Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
247 copy(loc, dst, src, sizeBytes, rewriter);
249 CopyInfo info;
250 info.dst = dst;
251 info.src = src;
252 info.size = sizeBytes;
253 copyInfo.push_back(info);
255 // Create a call to the kernel and copy the data back.
256 rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
257 ArrayRef<Value>());
258 for (CopyInfo info : copyInfo)
259 copy(loc, info.src, info.dst, info.size, rewriter);
260 return success();
264 class LowerHostCodeToLLVM
265 : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
266 public:
267 void runOnOperation() override {
268 ModuleOp module = getOperation();
270 // Erase the GPU module.
271 for (auto gpuModule :
272 llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
273 gpuModule.erase();
275 // Specify options to lower Standard to LLVM and pull in the conversion
276 // patterns.
277 LowerToLLVMOptions options = {
278 /*useBarePtrCallConv=*/false,
279 /*emitCWrappers=*/true,
280 /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
281 auto *context = module.getContext();
282 OwningRewritePatternList patterns;
283 LLVMTypeConverter typeConverter(context, options);
284 populateStdToLLVMConversionPatterns(typeConverter, patterns);
285 patterns.insert<GPULaunchLowering>(typeConverter);
287 // Pull in SPIR-V type conversion patterns to convert SPIR-V global
288 // variable's type to LLVM dialect type.
289 populateSPIRVToLLVMTypeConversion(typeConverter);
291 ConversionTarget target(*context);
292 target.addLegalDialect<LLVM::LLVMDialect>();
293 if (failed(applyPartialConversion(module, target, std::move(patterns))))
294 signalPassFailure();
296 // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
297 // conflicts.
298 for (auto spvModule : module.getOps<spirv::ModuleOp>())
299 encodeKernelName(spvModule);
302 } // namespace
304 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
305 mlir::createLowerHostCodeToLLVMPass() {
306 return std::make_unique<LowerHostCodeToLLVM>();