//===-- CUFGPUToLLVMConversion.cpp ----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
#include "flang/Common/Fortran.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Runtime/CUDA/common.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"

namespace fir {
#define GEN_PASS_DEF_CUFGPUTOLLVMCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

using namespace fir;
using namespace mlir;
using namespace Fortran::runtime;

namespace {
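
/// Pack the kernel operands into a stack-allocated struct and build an array
/// of pointers to its fields, i.e. the `void **` argument block consumed by
/// the CUF kernel-launch runtime calls generated below.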
static mlir::Value createKernelArgArray(mlir::Location loc,
                                        mlir::ValueRange operands,
                                        mlir::PatternRewriter &rewriter) {

  auto *ctx = rewriter.getContext();
  llvm::SmallVector<mlir::Type> structTypes(operands.size(), nullptr);

  for (auto [i, arg] : llvm::enumerate(operands))
    structTypes[i] = arg.getType();

  auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes);
  auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
  mlir::Type i32Ty = rewriter.getI32Type();
  auto zero = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));
  auto one = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1));
  mlir::Value argStruct =
      rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, structTy, one);
  auto size = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, i32Ty, rewriter.getIntegerAttr(i32Ty, structTypes.size()));
  mlir::Value argArray =
      rewriter.create<mlir::LLVM::AllocaOp>(loc, ptrTy, ptrTy, size);

  // Store each operand into its struct field and record that field's address
  // in the argument array.
  for (auto [i, arg] : llvm::enumerate(operands)) {
    auto indice = rewriter.create<mlir::LLVM::ConstantOp>(
        loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i));
    mlir::Value structMember = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, structTy, argStruct,
        mlir::ArrayRef<mlir::Value>({zero, indice}));
    rewriter.create<LLVM::StoreOp>(loc, arg, structMember);
    mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>(
        loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice}));
    rewriter.create<LLVM::StoreOp>(loc, structMember, arrayMember);
  }
  return argArray;
}
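
/// Convert a `gpu.launch_func` operation into a call to the matching CUF
/// runtime entry point (CUFLaunchKernel, CUFLaunchCooperativeKernel, or
/// CUFLaunchClusterKernel), forwarding the launch configuration and the packed
/// kernel arguments.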
struct GPULaunchKernelConversion
    : public mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp> {
  explicit GPULaunchKernelConversion(
      const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<mlir::gpu::LaunchFuncOp>(typeConverter,
                                                              benefit) {}

  using OpAdaptor = typename mlir::gpu::LaunchFuncOp::Adaptor;

  mlir::LogicalResult
  matchAndRewrite(mlir::gpu::LaunchFuncOp op, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    mlir::Location loc = op.getLoc();
    auto *ctx = rewriter.getContext();
    mlir::ModuleOp mod = op->getParentOfType<mlir::ModuleOp>();
    mlir::Value dynamicMemorySize = op.getDynamicSharedMemorySize();
    mlir::Type i32Ty = rewriter.getI32Type();
    if (!dynamicMemorySize)
      dynamicMemorySize = rewriter.create<mlir::LLVM::ConstantOp>(
          loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0));

    mlir::Value kernelArgs =
        createKernelArgArray(loc, adaptor.getKernelOperands(), rewriter);

    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
    auto kernel = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(op.getKernelName());
    mlir::Value kernelPtr;
    if (!kernel) {
      auto funcOp = mod.lookupSymbol<mlir::func::FuncOp>(op.getKernelName());
      if (!funcOp)
        return mlir::failure();
      kernelPtr =
          rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, funcOp.getName());
    } else {
      kernelPtr =
          rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, kernel.getName());
    }

    auto llvmIntPtrType = mlir::IntegerType::get(
        ctx, this->getTypeConverter()->getPointerBitwidth(0));
    auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);

    mlir::Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
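
    // Launch through CUFLaunchClusterKernel when a cluster size is present;
    // otherwise use CUFLaunchKernel, or CUFLaunchCooperativeKernel for
    // grid_global procedures.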
    if (op.hasClusterSize()) {
      auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(
          RTNAME_STRING(CUFLaunchClusterKernel));
      auto funcTy = mlir::LLVM::LLVMFunctionType::get(
          voidTy,
          {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
           llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
           llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
          /*isVarArg=*/false);
      auto cufLaunchClusterKernel = mlir::SymbolRefAttr::get(
          mod.getContext(), RTNAME_STRING(CUFLaunchClusterKernel));
      if (!funcOp) {
        mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(mod.getBody());
        auto launchKernelFuncOp = rewriter.create<mlir::LLVM::LLVMFuncOp>(
            loc, RTNAME_STRING(CUFLaunchClusterKernel), funcTy);
        launchKernelFuncOp.setVisibility(
            mlir::SymbolTable::Visibility::Private);
      }
      rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
          op, funcTy, cufLaunchClusterKernel,
          mlir::ValueRange{kernelPtr, adaptor.getClusterSizeX(),
                           adaptor.getClusterSizeY(), adaptor.getClusterSizeZ(),
                           adaptor.getGridSizeX(), adaptor.getGridSizeY(),
                           adaptor.getGridSizeZ(), adaptor.getBlockSizeX(),
                           adaptor.getBlockSizeY(), adaptor.getBlockSizeZ(),
                           dynamicMemorySize, kernelArgs, nullPtr});
    } else {
      auto procAttr =
          op->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName());
      bool isGridGlobal =
          procAttr && procAttr.getValue() == cuf::ProcAttribute::GridGlobal;
      llvm::StringRef fctName = isGridGlobal
                                    ? RTNAME_STRING(CUFLaunchCooperativeKernel)
                                    : RTNAME_STRING(CUFLaunchKernel);
      auto funcOp = mod.lookupSymbol<mlir::LLVM::LLVMFuncOp>(fctName);
      auto funcTy = mlir::LLVM::LLVMFunctionType::get(
          voidTy,
          {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
           llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
          /*isVarArg=*/false);
      auto cufLaunchKernel =
          mlir::SymbolRefAttr::get(mod.getContext(), fctName);
      if (!funcOp) {
        mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
        rewriter.setInsertionPointToStart(mod.getBody());
        auto launchKernelFuncOp =
            rewriter.create<mlir::LLVM::LLVMFuncOp>(loc, fctName, funcTy);
        launchKernelFuncOp.setVisibility(
            mlir::SymbolTable::Visibility::Private);
      }
      rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
          op, funcTy, cufLaunchKernel,
          mlir::ValueRange{kernelPtr, adaptor.getGridSizeX(),
                           adaptor.getGridSizeY(), adaptor.getGridSizeZ(),
                           adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
                           adaptor.getBlockSizeZ(), dynamicMemorySize,
                           kernelArgs, nullPtr});
    }

    return mlir::success();
  }
};
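
/// Pass that lowers `gpu.launch_func` operations in a module to CUF runtime
/// calls by applying the conversion patterns registered below.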
class CUFGPUToLLVMConversion
    : public fir::impl::CUFGPUToLLVMConversionBase<CUFGPUToLLVMConversion> {
public:
  void runOnOperation() override {
    auto *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);
    mlir::ConversionTarget target(*ctx);

    mlir::Operation *op = getOperation();
    mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
    if (!module)
      return signalPassFailure();

    std::optional<mlir::DataLayout> dl = fir::support::getOrSetMLIRDataLayout(
        module, /*allowDefaultLayout=*/false);
    fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                         /*forceUnifiedTBAATree=*/false, *dl);
    cuf::populateCUFGPUToLLVMConversionPatterns(typeConverter, patterns);
    target.addIllegalOp<mlir::gpu::LaunchFuncOp>();
    target.addLegalDialect<mlir::LLVM::LLVMDialect>();
    if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
                                                  std::move(patterns)))) {
      mlir::emitError(mlir::UnknownLoc::get(ctx),
                      "error in CUF GPU op conversion\n");
      signalPassFailure();
    }
  }
};
} // namespace
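
/// Populate \p patterns with the pattern that converts `gpu.launch_func` to
/// CUF runtime calls.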
void cuf::populateCUFGPUToLLVMConversionPatterns(
    const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
    mlir::PatternBenefit benefit) {
  patterns.add<GPULaunchKernelConversion>(converter, benefit);
}