1 //===-- CUFGPUToLLVMConversion.cpp ----------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
10 #include "flang/Common/Fortran.h"
11 #include "flang/Optimizer/CodeGen/TypeConverter.h"
12 #include "flang/Optimizer/Support/DataLayout.h"
13 #include "flang/Runtime/CUDA/common.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 #include "llvm/Support/FormatVariadic.h"
22 #define GEN_PASS_DEF_CUFGPUTOLLVMCONVERSION
23 #include "flang/Optimizer/Transforms/Passes.h.inc"
28 using namespace Fortran::runtime
;
/// Packs the kernel launch operands into stack storage and builds an array of
/// pointers to them.
///
/// Two allocas are emitted: one for a literal LLVM struct whose field types
/// are the types of \p operands, and one for `operands.size()` pointers. Each
/// operand is stored into its struct field and the address of that field is
/// stored into the matching slot of the pointer array.
///
/// \param loc      Location attached to every operation created here.
/// \param operands Values to pack; each gets one struct field and one slot.
/// \param rewriter Rewriter used to create the operations.
///
/// NOTE(review): the end of this function (loop close and return statement)
/// is not visible in this view; presumably the pointer array is what gets
/// returned -- confirm against the full file.
32 static mlir::Value
createKernelArgArray(mlir::Location loc
,
33 mlir::ValueRange operands
,
34 mlir::PatternRewriter
&rewriter
) {
36 auto *ctx
= rewriter
.getContext();
37 llvm::SmallVector
<mlir::Type
> structTypes(operands
.size(), nullptr);
// Collect the type of every operand; these become the struct field types.
39 for (auto [i
, arg
] : llvm::enumerate(operands
))
40 structTypes
[i
] = arg
.getType();
42 auto structTy
= mlir::LLVM::LLVMStructType::getLiteral(ctx
, structTypes
);
43 auto ptrTy
= mlir::LLVM::LLVMPointerType::get(rewriter
.getContext());
44 mlir::Type i32Ty
= rewriter
.getI32Type();
// i32 constants: `zero` is the leading GEP index into the struct, `one` is
// the element count of the struct alloca.
45 auto zero
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
46 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, 0));
47 auto one
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
48 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, 1));
// Storage for a single struct that holds all operand values.
49 mlir::Value argStruct
=
50 rewriter
.create
<mlir::LLVM::AllocaOp
>(loc
, ptrTy
, structTy
, one
);
// Storage for one pointer per operand.
51 auto size
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
52 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, structTypes
.size()));
53 mlir::Value argArray
=
54 rewriter
.create
<mlir::LLVM::AllocaOp
>(loc
, ptrTy
, ptrTy
, size
);
// For each operand: store the value into its struct field, then record the
// address of that field in the pointer array.
56 for (auto [i
, arg
] : llvm::enumerate(operands
)) {
57 auto indice
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
58 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, i
));
// Address of field `i` of the argument struct.
59 mlir::Value structMember
= rewriter
.create
<LLVM::GEPOp
>(
60 loc
, ptrTy
, structTy
, argStruct
,
61 mlir::ArrayRef
<mlir::Value
>({zero
, indice
}));
62 rewriter
.create
<LLVM::StoreOp
>(loc
, arg
, structMember
);
// Address of slot `i` of the pointer array.
63 mlir::Value arrayMember
= rewriter
.create
<LLVM::GEPOp
>(
64 loc
, ptrTy
, ptrTy
, argArray
, mlir::ArrayRef
<mlir::Value
>({indice
}));
65 rewriter
.create
<LLVM::StoreOp
>(loc
, structMember
, arrayMember
);
/// Conversion pattern that replaces a `gpu.launch_func` operation with an
/// LLVM dialect call to one of the CUF runtime launch entry points
/// (CUFLaunchClusterKernel, CUFLaunchCooperativeKernel or CUFLaunchKernel).
///
/// NOTE(review): several lines of this definition are not visible in this
/// view (branch headers, closing braces, the `kernelPtr` assignments, the
/// return-type/vararg arguments of the function types). Comments below only
/// describe what the visible statements do.
70 struct GPULaunchKernelConversion
71 : public mlir::ConvertOpToLLVMPattern
<mlir::gpu::LaunchFuncOp
> {
// Forwards the type converter (and, presumably, `benefit` -- the rest of the
// initializer is not visible here) to the base pattern.
72 explicit GPULaunchKernelConversion(
73 const fir::LLVMTypeConverter
&typeConverter
, mlir::PatternBenefit benefit
)
74 : mlir::ConvertOpToLLVMPattern
<mlir::gpu::LaunchFuncOp
>(typeConverter
,
77 using OpAdaptor
= typename
mlir::gpu::LaunchFuncOp::Adaptor
;
// Rewrites `op` into a runtime launch call. `adaptor` provides the converted
// operands (grid/block/cluster sizes and kernel operands).
80 matchAndRewrite(mlir::gpu::LaunchFuncOp op
, OpAdaptor adaptor
,
81 mlir::ConversionPatternRewriter
&rewriter
) const override
{
82 mlir::Location loc
= op
.getLoc();
83 auto *ctx
= rewriter
.getContext();
84 mlir::ModuleOp mod
= op
->getParentOfType
<mlir::ModuleOp
>();
85 mlir::Value dynamicMemorySize
= op
.getDynamicSharedMemorySize();
86 mlir::Type i32Ty
= rewriter
.getI32Type();
// No dynamic shared memory size on the op: fall back to an i32 constant 0.
87 if (!dynamicMemorySize
)
88 dynamicMemorySize
= rewriter
.create
<mlir::LLVM::ConstantOp
>(
89 loc
, i32Ty
, rewriter
.getIntegerAttr(i32Ty
, 0));
// Build the kernel argument array from the converted kernel operands.
91 mlir::Value kernelArgs
=
92 createKernelArgArray(loc
, adaptor
.getKernelOperands(), rewriter
);
94 auto ptrTy
= mlir::LLVM::LLVMPointerType::get(rewriter
.getContext());
// Look the kernel symbol up as an LLVM function first.
95 auto kernel
= mod
.lookupSymbol
<mlir::LLVM::LLVMFuncOp
>(op
.getKernelName());
96 mlir::Value kernelPtr
;
// Alternative lookup of the kernel symbol as a func.func.
// NOTE(review): the conditions selecting between the two lookups, the guard
// on the failure return, and the assignments of the address-of results to
// `kernelPtr` are not visible here -- confirm against the full file.
98 auto funcOp
= mod
.lookupSymbol
<mlir::func::FuncOp
>(op
.getKernelName());
100 return mlir::failure();
102 rewriter
.create
<LLVM::AddressOfOp
>(loc
, ptrTy
, funcOp
.getName());
105 rewriter
.create
<LLVM::AddressOfOp
>(loc
, ptrTy
, kernel
.getName());
// Integer type whose width is getPointerBitwidth(0); used for the size
// parameters of the runtime function types below.
108 auto llvmIntPtrType
= mlir::IntegerType::get(
109 ctx
, this->getTypeConverter()->getPointerBitwidth(0));
// NOTE(review): the use of `voidTy` is not visible here; presumably it is the
// return type of the runtime function types -- confirm.
110 auto voidTy
= mlir::LLVM::LLVMVoidType::get(ctx
);
// Null pointer passed as the last argument of the runtime call.
112 mlir::Value nullPtr
= rewriter
.create
<LLVM::ZeroOp
>(loc
, ptrTy
);
// Launch with a cluster size: target CUFLaunchClusterKernel.
114 if (op
.hasClusterSize()) {
115 auto funcOp
= mod
.lookupSymbol
<mlir::LLVM::LLVMFuncOp
>(
116 RTNAME_STRING(CUFLaunchClusterKernel
));
// Parameters: kernel pointer, nine pointer-width integers (cluster, grid and
// block sizes in x/y/z), i32 dynamic shared memory size, and two pointers.
117 auto funcTy
= mlir::LLVM::LLVMFunctionType::get(
119 {ptrTy
, llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
,
120 llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
,
121 llvmIntPtrType
, llvmIntPtrType
, i32Ty
, ptrTy
, ptrTy
},
123 auto cufLaunchClusterKernel
= mlir::SymbolRefAttr::get(
124 mod
.getContext(), RTNAME_STRING(CUFLaunchClusterKernel
));
// Declare the runtime function as a private symbol at the start of the
// module; the insertion guard restores the insertion point afterwards.
// NOTE(review): the guard condition is not visible; presumably this only
// happens when the lookup above found nothing -- confirm.
126 mlir::OpBuilder::InsertionGuard
insertGuard(rewriter
);
127 rewriter
.setInsertionPointToStart(mod
.getBody());
128 auto launchKernelFuncOp
= rewriter
.create
<mlir::LLVM::LLVMFuncOp
>(
129 loc
, RTNAME_STRING(CUFLaunchClusterKernel
), funcTy
);
130 launchKernelFuncOp
.setVisibility(
131 mlir::SymbolTable::Visibility::Private
);
// Replace the launch op with the runtime call; argument order follows funcTy.
133 rewriter
.replaceOpWithNewOp
<mlir::LLVM::CallOp
>(
134 op
, funcTy
, cufLaunchClusterKernel
,
135 mlir::ValueRange
{kernelPtr
, adaptor
.getClusterSizeX(),
136 adaptor
.getClusterSizeY(), adaptor
.getClusterSizeZ(),
137 adaptor
.getGridSizeX(), adaptor
.getGridSizeY(),
138 adaptor
.getGridSizeZ(), adaptor
.getBlockSizeX(),
139 adaptor
.getBlockSizeY(), adaptor
.getBlockSizeZ(),
140 dynamicMemorySize
, kernelArgs
, nullPtr
});
// Read the CUF proc attribute from the op.
// NOTE(review): the beginning of this statement and of the next one (the
// declarations of `procAttr` and `isGridGlobal`) is not visible here.
143 op
->getAttrOfType
<cuf::ProcAttributeAttr
>(cuf::getProcAttrName());
145 procAttr
&& procAttr
.getValue() == cuf::ProcAttribute::GridGlobal
;
// Grid-global kernels use the cooperative launch entry point, all others the
// plain one.
146 llvm::StringRef fctName
= isGridGlobal
147 ? RTNAME_STRING(CUFLaunchCooperativeKernel
)
148 : RTNAME_STRING(CUFLaunchKernel
);
149 auto funcOp
= mod
.lookupSymbol
<mlir::LLVM::LLVMFuncOp
>(fctName
);
// Parameters: kernel pointer, six pointer-width integers (grid and block
// sizes in x/y/z), i32 dynamic shared memory size, and two pointers.
150 auto funcTy
= mlir::LLVM::LLVMFunctionType::get(
152 {ptrTy
, llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
,
153 llvmIntPtrType
, llvmIntPtrType
, llvmIntPtrType
, i32Ty
, ptrTy
, ptrTy
},
155 auto cufLaunchKernel
=
156 mlir::SymbolRefAttr::get(mod
.getContext(), fctName
);
// Declare the selected runtime function as a private symbol at the start of
// the module (guard condition not visible here -- see the note above).
158 mlir::OpBuilder::InsertionGuard
insertGuard(rewriter
);
159 rewriter
.setInsertionPointToStart(mod
.getBody());
160 auto launchKernelFuncOp
=
161 rewriter
.create
<mlir::LLVM::LLVMFuncOp
>(loc
, fctName
, funcTy
);
162 launchKernelFuncOp
.setVisibility(
163 mlir::SymbolTable::Visibility::Private
);
// Replace the launch op with the runtime call; argument order follows funcTy.
165 rewriter
.replaceOpWithNewOp
<mlir::LLVM::CallOp
>(
166 op
, funcTy
, cufLaunchKernel
,
167 mlir::ValueRange
{kernelPtr
, adaptor
.getGridSizeX(),
168 adaptor
.getGridSizeY(), adaptor
.getGridSizeZ(),
169 adaptor
.getBlockSizeX(), adaptor
.getBlockSizeY(),
170 adaptor
.getBlockSizeZ(), dynamicMemorySize
,
171 kernelArgs
, nullPtr
});
174 return mlir::success();
/// Pass that converts `gpu.launch_func` operations into LLVM dialect
/// operations using the patterns registered by
/// cuf::populateCUFGPUToLLVMConversionPatterns.
///
/// NOTE(review): some lines of this definition (access specifier, guard on the
/// early return, closing braces) are not visible in this view.
178 class CUFGPUToLLVMConversion
179 : public fir::impl::CUFGPUToLLVMConversionBase
<CUFGPUToLLVMConversion
> {
// Sets up the conversion target and patterns, then runs a partial conversion
// on the operation the pass is applied to.
181 void runOnOperation() override
{
182 auto *ctx
= &getContext();
183 mlir::RewritePatternSet
patterns(ctx
);
184 mlir::ConversionTarget
target(*ctx
);
186 mlir::Operation
*op
= getOperation();
187 mlir::ModuleOp module
= mlir::dyn_cast
<mlir::ModuleOp
>(op
);
// NOTE(review): the condition guarding this early return is not visible
// here; presumably it is taken when the dyn_cast above yields a null module.
189 return signalPassFailure();
// NOTE(review): `dl` is dereferenced below with no visible check that it
// holds a value -- confirm getOrSetMLIRDataLayout cannot return std::nullopt
// with allowDefaultLayout=false, or add a guard.
191 std::optional
<mlir::DataLayout
> dl
= fir::support::getOrSetMLIRDataLayout(
192 module
, /*allowDefaultLayout=*/false);
193 fir::LLVMTypeConverter
typeConverter(module
, /*applyTBAA=*/false,
194 /*forceUnifiedTBAATree=*/false, *dl
);
195 cuf::populateCUFGPUToLLVMConversionPatterns(typeConverter
, patterns
);
// gpu.launch_func must be rewritten; everything in the LLVM dialect is legal.
196 target
.addIllegalOp
<mlir::gpu::LaunchFuncOp
>();
197 target
.addLegalDialect
<mlir::LLVM::LLVMDialect
>();
// On conversion failure, report a generic error at an unknown location.
198 if (mlir::failed(mlir::applyPartialConversion(getOperation(), target
,
199 std::move(patterns
)))) {
200 mlir::emitError(mlir::UnknownLoc::get(ctx
),
201 "error in CUF GPU op conversion\n");
/// Adds the GPULaunchKernelConversion pattern to \p patterns.
///
/// \param converter Type converter handed to the pattern's constructor.
/// \param patterns  Pattern set that receives the pattern.
/// \param benefit   Pattern benefit handed to the pattern's constructor.
208 void cuf::populateCUFGPUToLLVMConversionPatterns(
209 const fir::LLVMTypeConverter
&converter
, mlir::RewritePatternSet
&patterns
,
210 mlir::PatternBenefit benefit
) {
211 patterns
.add
<GPULaunchKernelConversion
>(converter
, benefit
);