1 //===-- CUFGPUToLLVMConversion.cpp ----------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
10 #include "flang/Common/Fortran.h"
11 #include "flang/Optimizer/CodeGen/TypeConverter.h"
12 #include "flang/Optimizer/Support/DataLayout.h"
13 #include "flang/Runtime/CUDA/common.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/Support/FormatVariadic.h"
22 #define GEN_PASS_DEF_CUFGPUTOLLVMCONVERSION
23 #include "flang/Optimizer/Transforms/Passes.h.inc"
28 using namespace Fortran::runtime
;
// Emits the LLVM-dialect ops that pack `operands` into memory for a kernel
// launch: one alloca holding a literal struct with one field per operand, and
// one alloca holding an array of `operands.size()` pointers whose i-th entry
// is the address of the i-th struct field.
//
// \param loc       location attached to every op created here.
// \param operands  values stored into the struct fields, in order.
// \param rewriter  builder used to create the ops.
//
// NOTE(review): the end of the loop and the return statement are not visible
// in this listing; presumably the pointer array (`argArray`) is what gets
// returned -- confirm against the full file.
32 static mlir::Value
createKernelArgArray(mlir::Location loc
,
33 mlir::ValueRange operands
,
34 mlir::PatternRewriter
&rewriter
) {
36 auto *ctx
= rewriter
.getContext();
// One type slot per operand, filled in by the loop below.
37 llvm::SmallVector
<mlir::Type
> structTypes(operands
.size(), nullptr);
39 for (auto [i
, arg
] : llvm::enumerate(operands
))
40 structTypes
[i
] = arg
.getType();
// Literal LLVM struct type whose fields are the operand types.
42 auto structTy
= mlir::LLVM::LLVMStructType::getLiteral(ctx
, structTypes
);
43 auto ptrTy
= mlir::LLVM::LLVMPointerType::get(rewriter
.getContext());
44 mlir::Type i32Ty
= rewriter
.getI32Type();
// i32 constants: `zero` is the leading GEP index used in the loop, `one` is
// the element count of the struct alloca.
45 auto zero
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
46 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, 0));
47 auto one
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
48 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, 1));
// Alloca for a single instance of the argument struct.
49 mlir::Value argStruct
=
50 rewriter
.create
<mlir::LLVM::AllocaOp
>(loc
, ptrTy
, structTy
, one
);
// Alloca for `structTypes.size()` pointer-typed elements.
51 auto size
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
52 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, structTypes
.size()));
53 mlir::Value argArray
=
54 rewriter
.create
<mlir::LLVM::AllocaOp
>(loc
, ptrTy
, ptrTy
, size
);
// For each operand: store it into its struct field, then store that field's
// address into the matching slot of the pointer array.
56 for (auto [i
, arg
] : llvm::enumerate(operands
)) {
57 auto indice
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
58 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, i
));
// GEP with indices {0, i} into the struct: address of field i.
59 mlir::Value structMember
= rewriter
.create
<LLVM::GEPOp
>(
60 loc
, ptrTy
, structTy
, argStruct
,
61 mlir::ArrayRef
<mlir::Value
>({zero
, indice
}));
62 rewriter
.create
<LLVM::StoreOp
>(loc
, arg
, structMember
);
// GEP with index {i} into the pointer array: address of slot i.
63 mlir::Value arrayMember
= rewriter
.create
<LLVM::GEPOp
>(
64 loc
, ptrTy
, ptrTy
, argArray
, mlir::ArrayRef
<mlir::Value
>({indice
}));
65 rewriter
.create
<LLVM::StoreOp
>(loc
, structMember
, arrayMember
);
// Conversion pattern for mlir::gpu::LaunchFuncOp. The visible code replaces
// the op with an mlir::LLVM::CallOp to the runtime function named by
// RTNAME_STRING(CUFLaunchClusterKernel) when `op.hasClusterSize()` is true;
// a second, parallel sequence targets RTNAME_STRING(CUFLaunchKernel).
//
// NOTE(review): several original lines are not visible in this listing (the
// constructor tail, the guards around the kernel lookup, the return type and
// variadic flag of the function types, the guards around the declaration
// creation, the `else` joining the two launch sequences, closing braces).
// The comments below describe only the statements that are visible.
70 struct GPULaunchKernelConversion
71 : public mlir::ConvertOpToLLVMPattern
<mlir::gpu::LaunchFuncOp
> {
// Forwards the type converter to the base pattern; what happens to
// `benefit` is on lines not visible here.
72 explicit GPULaunchKernelConversion(
73 const fir::LLVMTypeConverter
&typeConverter
, mlir::PatternBenefit benefit
)
74 : mlir::ConvertOpToLLVMPattern
<mlir::gpu::LaunchFuncOp
>(typeConverter
,
77 using OpAdaptor
= typename
mlir::gpu::LaunchFuncOp::Adaptor
;
// Rewrites one gpu.launch_func into a runtime launch call. The visible
// return statements yield mlir::failure() and mlir::success().
80 matchAndRewrite(mlir::gpu::LaunchFuncOp op
, OpAdaptor adaptor
,
81 mlir::ConversionPatternRewriter
&rewriter
) const override
{
82 mlir::Location loc
= op
.getLoc();
83 auto *ctx
= rewriter
.getContext();
84 mlir::ModuleOp mod
= op
->getParentOfType
<mlir::ModuleOp
>();
// Dynamic shared memory size; when the op carries none, an i32 constant 0
// is created and used instead.
// NOTE(review): this value is read from `op`, whereas the grid/block/cluster
// sizes and kernel operands are read from `adaptor` -- confirm intended.
85 mlir::Value dynamicMemorySize
= op
.getDynamicSharedMemorySize();
86 mlir::Type i32Ty
= rewriter
.getI32Type();
87 if (!dynamicMemorySize
)
88 dynamicMemorySize
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
89 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, 0));
// Pack the adaptor's kernel operands via createKernelArgArray.
91 mlir::Value kernelArgs
=
92 createKernelArgArray(loc
, adaptor
.getKernelOperands(), rewriter
);
// Resolve the kernel symbol in the parent module, first as an
// mlir::LLVM::LLVMFuncOp and then as an mlir::func::FuncOp. An AddressOfOp
// is created from whichever name applies.
// NOTE(review): the conditions guarding the second lookup, the failure
// return, and the assignments to `kernelPtr` are not visible here.
94 auto ptrTy
= mlir::LLVM::LLVMPointerType::get(rewriter
.getContext());
95 auto kernel
= mod
.lookupSymbol
<mlir::LLVM::LLVMFuncOp
>(op
.getKernelName());
96 mlir::Value kernelPtr
;
98 auto funcOp
= mod
.lookupSymbol
<mlir::func::FuncOp
>(op
.getKernelName());
100 return mlir::failure();
102 rewriter
.create
<LLVM::AddressOfOp
>(loc
, ptrTy
, funcOp
.getName());
105 rewriter
.create
<LLVM::AddressOfOp
>(loc
, ptrTy
, kernel
.getName());
// Integer type whose width is the type converter's getPointerBitwidth(0);
// used for the size parameters of the runtime function types below.
108 auto llvmIntPtrType
= mlir::IntegerType::get(
109 ctx
, this->getTypeConverter()->getPointerBitwidth(0));
110 auto voidTy
= mlir::LLVM::LLVMVoidType::get(ctx
);
// Null pointer passed as the last argument of both runtime calls.
112 mlir::Value nullPtr
= rewriter
.create
<LLVM::ZeroOp
>(loc
, ptrTy
);
// Cluster launch: target RTNAME_STRING(CUFLaunchClusterKernel).
114 if (op
.hasClusterSize()) {
115 auto funcOp
= mod
.lookupSymbol
<mlir::LLVM::LLVMFuncOp
>(
116 RTNAME_STRING(CUFLaunchClusterKernel
));
// Parameter list: ptr, nine pointer-width integers, i32, ptr, ptr. The
// remaining arguments of LLVMFunctionType::get are not visible here.
117 auto funcTy
= mlir::LLVM::LLVMFunctionType::get(
119 {ptrTy
, llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
,
120 llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
,
121 llvmIntPtrType
, llvmIntPtrType
, i32Ty
, ptrTy
, ptrTy
},
123 auto cufLaunchClusterKernel
= mlir::SymbolRefAttr::get(
124 mod
.getContext(), RTNAME_STRING(CUFLaunchClusterKernel
));
// Create a private LLVMFuncOp for the runtime function at the start of the
// module body; the insertion guard restores the insertion point afterwards.
// NOTE(review): the condition guarding this (presumably that the lookup
// above found nothing) is not visible -- confirm.
126 mlir::OpBuilder::InsertionGuard
insertGuard(rewriter
);
127 rewriter
.setInsertionPointToStart(mod
.getBody());
128 auto launchKernelFuncOp
= rewriter
.create
<mlir::LLVM::LLVMFuncOp
>(
129 loc
, RTNAME_STRING(CUFLaunchClusterKernel
), funcTy
);
130 launchKernelFuncOp
.setVisibility(
131 mlir::SymbolTable::Visibility::Private
);
// Replace the launch op with the call. Argument order: kernel pointer,
// cluster sizes X/Y/Z, grid sizes X/Y/Z, block sizes X/Y/Z, dynamic shared
// memory size, kernel argument array, null pointer.
133 rewriter
.replaceOpWithNewOp
<mlir::LLVM::CallOp
>(
134 op
, funcTy
, cufLaunchClusterKernel
,
135 mlir::ValueRange
{kernelPtr
, adaptor
.getClusterSizeX(),
136 adaptor
.getClusterSizeY(), adaptor
.getClusterSizeZ(),
137 adaptor
.getGridSizeX(), adaptor
.getGridSizeY(),
138 adaptor
.getGridSizeZ(), adaptor
.getBlockSizeX(),
139 adaptor
.getBlockSizeY(), adaptor
.getBlockSizeZ(),
140 dynamicMemorySize
, kernelArgs
, nullPtr
});
// Launch without cluster sizes: target RTNAME_STRING(CUFLaunchKernel).
// NOTE(review): the `else` that presumably introduces this sequence is not
// visible in this listing.
142 auto funcOp
= mod
.lookupSymbol
<mlir::LLVM::LLVMFuncOp
>(
143 RTNAME_STRING(CUFLaunchKernel
));
// Parameter list: ptr, six pointer-width integers, i32, ptr, ptr. The
// remaining arguments of LLVMFunctionType::get are not visible here.
144 auto funcTy
= mlir::LLVM::LLVMFunctionType::get(
146 {ptrTy
, llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
,
147 llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
, i32Ty
, ptrTy
, ptrTy
},
149 auto cufLaunchKernel
= mlir::SymbolRefAttr::get(
150 mod
.getContext(), RTNAME_STRING(CUFLaunchKernel
));
// Same declaration step as in the cluster sequence, for CUFLaunchKernel;
// its guarding condition is likewise not visible.
152 mlir::OpBuilder::InsertionGuard
insertGuard(rewriter
);
153 rewriter
.setInsertionPointToStart(mod
.getBody());
154 auto launchKernelFuncOp
= rewriter
.create
<mlir::LLVM::LLVMFuncOp
>(
155 loc
, RTNAME_STRING(CUFLaunchKernel
), funcTy
);
156 launchKernelFuncOp
.setVisibility(
157 mlir::SymbolTable::Visibility::Private
);
// Replace the launch op with the call. Argument order: kernel pointer,
// grid sizes X/Y/Z, block sizes X/Y/Z, dynamic shared memory size, kernel
// argument array, null pointer.
159 rewriter
.replaceOpWithNewOp
<mlir::LLVM::CallOp
>(
160 op
, funcTy
, cufLaunchKernel
,
161 mlir::ValueRange
{kernelPtr
, adaptor
.getGridSizeX(),
162 adaptor
.getGridSizeY(), adaptor
.getGridSizeZ(),
163 adaptor
.getBlockSizeX(), adaptor
.getBlockSizeY(),
164 adaptor
.getBlockSizeZ(), dynamicMemorySize
,
165 kernelArgs
, nullPtr
});
168 return mlir::success();
// Pass implementation derived from the generated
// fir::impl::CUFGPUToLLVMConversionBase. Its runOnOperation registers the
// patterns from cuf::populateCUFGPUToLLVMConversionPatterns and runs a
// partial dialect conversion in which mlir::gpu::LaunchFuncOp is illegal and
// the LLVM dialect is legal.
//
// NOTE(review): some original lines (access specifier, the guard before the
// early `return signalPassFailure()`, the tail of the class) are not visible
// in this listing.
172 class CUFGPUToLLVMConversion
173 : public fir::impl::CUFGPUToLLVMConversionBase
<CUFGPUToLLVMConversion
> {
175 void runOnOperation() override
{
176 auto *ctx
= &getContext();
177 mlir::RewritePatternSet
patterns(ctx
);
178 mlir::ConversionTarget
target(*ctx
);
// The pass works on a module; the operation is cast with dyn_cast.
// NOTE(review): the condition guarding the failure return below is not
// visible (presumably a null check on `module`) -- confirm.
180 mlir::Operation
*op
= getOperation();
181 mlir::ModuleOp module
= mlir::dyn_cast
<mlir::ModuleOp
>(op
);
183 return signalPassFailure();
// The data layout is requested with allowDefaultLayout=false and handed to
// the type converter.
// NOTE(review): `*dl` dereferences the std::optional with no visible check
// that it holds a value -- confirm a missing layout cannot reach this point.
185 std::optional
<mlir::DataLayout
> dl
=
186 fir::support::getOrSetDataLayout(module
, /*allowDefaultLayout=*/false);
187 fir::LLVMTypeConverter
typeConverter(module
, /*applyTBAA=*/false,
188 /*forceUnifiedTBAATree=*/false, *dl
);
189 cuf::populateCUFGPUToLLVMConversionPatterns(typeConverter
, patterns
);
// gpu.launch_func must be rewritten; ops of the LLVM dialect are accepted.
190 target
.addIllegalOp
<mlir::gpu::LaunchFuncOp
>();
191 target
.addLegalDialect
<mlir::LLVM::LLVMDialect
>();
// On conversion failure an error is emitted at an unknown location; what
// follows the emitError call is not visible in this listing.
192 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target
,
193 std::move(patterns
)))) {
194 mlir::emitError(mlir::UnknownLoc::get(ctx
),
195 "error in CUF GPU op conversion\n");
// Adds the GPULaunchKernelConversion pattern to `patterns`.
//
// \param converter  type converter passed to the pattern's constructor.
// \param patterns   pattern set that receives the new pattern.
// \param benefit    pattern benefit passed to the pattern's constructor.
202 void cuf::populateCUFGPUToLLVMConversionPatterns(
203 const fir::LLVMTypeConverter
&converter
, mlir::RewritePatternSet
&patterns
,
204 mlir::PatternBenefit benefit
) {
205 patterns
.add
<GPULaunchKernelConversion
>(converter
, benefit
);