//===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "gpu-to-llvm-spv"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//
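/// Look up the SPIR-V builtin function declaration `name` in the enclosing
/// symbol table, creating it with the `spir_funccc` calling convention and
/// the requested attributes if it does not exist yet.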
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                              StringRef name,
                                              ArrayRef<Type> paramTypes,
                                              Type resultType, bool isMemNone,
                                              bool isConvergent) {
  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(symbolTable, name));
  if (!func) {
    OpBuilder b(symbolTable->getRegion(0));
    func = b.create<LLVM::LLVMFuncOp>(
        symbolTable->getLoc(), name,
        LLVM::LLVMFunctionType::get(resultType, paramTypes));
    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
    func.setNoUnwind(true);
    func.setWillReturn(true);

    if (isMemNone) {
      // no externally observable effects
      constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
      auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/noModRef,
          /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
      func.setMemoryEffectsAttr(memAttr);
    }

    func.setConvergent(isConvergent);
  }

  return func;
}
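/// Create a call to `func`, forwarding the callee's calling convention and
/// its convergent, nounwind, willreturn and memory-effects attributes to the
/// call site.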
static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           LLVM::LLVMFuncOp func,
                                           ValueRange args) {
  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
  call.setCConv(func.getCConv());
  call.setConvergentAttr(func.getConvergentAttr());
  call.setNoUnwindAttr(func.getNoUnwindAttr());
  call.setWillReturnAttr(func.getWillReturnAttr());
  call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
  return call;
}
namespace {
//===----------------------------------------------------------------------===//
// Barriers
//===----------------------------------------------------------------------===//
/// Replace `gpu.barrier` with an `llvm.call` to `barrier` with
/// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
/// ```
/// // gpu.barrier
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
/// ```
struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringLiteral funcName = "_Z7barrierj";

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type flagTy = rewriter.getI32Type();
    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
                              /*isMemNone=*/false, /*isConvergent=*/true);

    // Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
    // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
    constexpr int64_t localMemFenceFlag = 1;
    Location loc = op->getLoc();
    Value flag =
        rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
    return success();
  }
};
//===----------------------------------------------------------------------===//
// SPIR-V Builtins
//===----------------------------------------------------------------------===//
/// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin with
/// a constant argument for the `dimension` attribute. Return type will depend
/// on index width option:
/// ```
/// // %thread_id_y = gpu.thread_id y
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
/// ```
struct LaunchConfigConversion : ConvertToLLVMPattern {
  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
                         MLIRContext *context,
                         const LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
        funcName(funcName) {}

  virtual gpu::Dimension getDimension(Operation *op) const = 0;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type dimTy = rewriter.getI32Type();
    Type indexTy = getTypeConverter()->getIndexType();
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy,
                                                  indexTy, /*isMemNone=*/true,
                                                  /*isConvergent=*/false);

    Location loc = op->getLoc();
    gpu::Dimension dim = getDimension(op);
    Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
                                                     static_cast<int64_t>(dim));
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
    return success();
  }

  StringRef funcName;
};

template <typename SourceOp>
struct LaunchConfigOpConversion final : LaunchConfigConversion {
  static StringRef getFuncName();

  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
      : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
                               &typeConverter.getContext(), typeConverter,
                               benefit) {}

  gpu::Dimension getDimension(Operation *op) const final {
    return cast<SourceOp>(op).getDimension();
  }
};

template <>
StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
  return "_Z12get_group_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
  return "_Z14get_num_groupsj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
  return "_Z14get_local_sizej";
}

template <>
StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
  return "_Z12get_local_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
  return "_Z13get_global_idj";
}
//===----------------------------------------------------------------------===//
// Shuffles
//===----------------------------------------------------------------------===//
/// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
/// builtin for `shuffleResult`, keeping `value` and `offset` arguments, and a
/// `true` constant for the `valid` result type. Conversion will only take place
/// if `width` is constant and equal to the `subgroup` pass option:
/// ```
/// // %0 = gpu.shuffle idx %value, %offset, %width : f64
/// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
///     : (f64, i32) -> f64
/// ```
struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  static StringRef getBaseName(gpu::ShuffleMode mode) {
    switch (mode) {
    case gpu::ShuffleMode::IDX:
      return "sub_group_shuffle";
    case gpu::ShuffleMode::XOR:
      return "sub_group_shuffle_xor";
    case gpu::ShuffleMode::UP:
      return "sub_group_shuffle_up";
    case gpu::ShuffleMode::DOWN:
      return "sub_group_shuffle_down";
    }
    llvm_unreachable("Unhandled shuffle mode");
  }

  static std::optional<StringRef> getTypeMangling(Type type) {
    return TypeSwitch<Type, std::optional<StringRef>>(type)
        .Case<Float16Type>([](auto) { return "Dhj"; })
        .Case<Float32Type>([](auto) { return "fj"; })
        .Case<Float64Type>([](auto) { return "dj"; })
        .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
          switch (intTy.getWidth()) {
          case 8:
            return "cj";
          case 16:
            return "sj";
          case 32:
            return "ij";
          case 64:
            return "lj";
          }
          return std::nullopt;
        })
        .Default([](auto) { return std::nullopt; });
  }

  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
    StringRef baseName = getBaseName(op.getMode());
    std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
    if (!typeMangling)
      return std::nullopt;
    return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
                         typeMangling.value());
  }

  /// Get the subgroup size from the target or return a default.
  static int getSubgroupSize(Operation *op) {
    return spirv::lookupTargetEnvOrDefault(op)
        .getResourceLimits()
        .getSubgroupSize();
  }

  static bool hasValidWidth(gpu::ShuffleOp op) {
    APInt val;
    Value width = op.getWidth();
    return matchPattern(width, m_ConstantInt(&val)) &&
           val == getSubgroupSize(op);
  }

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!hasValidWidth(op))
      return rewriter.notifyMatchFailure(
          op, "shuffle width and subgroup size mismatch");

    std::optional<std::string> funcName = getFuncName(op);
    if (!funcName)
      return rewriter.notifyMatchFailure(op, "unsupported value type");

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type valueType = adaptor.getValue().getType();
    Type offsetType = adaptor.getOffset().getType();
    Type resultType = valueType;
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        moduleOp, funcName.value(), {valueType, offsetType}, resultType,
        /*isMemNone=*/false, /*isConvergent=*/true);

    Location loc = op->getLoc();
    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
    Value result =
        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
    Value trueVal =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
    rewriter.replaceOp(op, {result, trueVal});
    return success();
  }
};
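/// Type converter attaching the OpenCL global (`CrossWorkgroup`) address
/// space to memref types that carry no memory-space attribute, recursing into
/// function types so their signatures are updated as well.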
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
public:
  MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
    addConversion([](Type t) { return t; });
    addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
      // Attach global addr space attribute to memrefs with no addr space attr
      Attribute memSpaceAttr = memRefType.getMemorySpace();
      if (memSpaceAttr)
        return std::nullopt;

      unsigned globalAddrspace = storageClassToAddressSpace(
          spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
      Attribute addrSpaceAttr =
          IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
      if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
        return MemRefType::get(memRefType.getShape(),
                               memRefType.getElementType(),
                               rankedType.getLayout(), addrSpaceAttr);
      }
      return UnrankedMemRefType::get(memRefType.getElementType(),
                                     addrSpaceAttr);
    });
    addConversion([this](FunctionType type) {
      auto inputs = llvm::map_to_vector(
          type.getInputs(), [this](Type ty) { return convertType(ty); });
      auto results = llvm::map_to_vector(
          type.getResults(), [this](Type ty) { return convertType(ty); });
      return FunctionType::get(type.getContext(), inputs, results);
    });
  }
};
//===----------------------------------------------------------------------===//
// Subgroup query ops.
//===----------------------------------------------------------------------===//
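/// Replace `gpu.subgroup_id`, `gpu.lane_id`, `gpu.num_subgroups` and
/// `gpu.subgroup_size` with an `llvm.call` to the corresponding SPIR-V
/// builtin, zero-extending the `i32` result to the index type when it is
/// wider. Illustrative lowering (SSA names are only examples):
/// ```
/// // %0 = gpu.subgroup_id : index
/// %1 = llvm.call spir_funccc @_Z16get_sub_group_id() : () -> i32
/// %0 = llvm.zext %1 : i32 to i64
/// ```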
template <typename SubgroupOp>
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
  using ConvertToLLVMPattern::getTypeConverter;

  LogicalResult
  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringRef funcName = [] {
      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
        return "_Z16get_sub_group_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
        return "_Z22get_sub_group_local_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
        return "_Z18get_num_sub_groups";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
        return "_Z18get_sub_group_size";
      }
    }();

    Operation *moduleOp =
        op->template getParentWithTrait<OpTrait::SymbolTable>();
    Type resultTy = rewriter.getI32Type();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
                              /*isMemNone=*/false, /*isConvergent=*/false);

    Location loc = op->getLoc();
    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();

    Type indexTy = getTypeConverter()->getIndexType();
    if (resultTy != indexTy) {
      if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
        return failure();
      }
      result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Pass.
//===----------------------------------------------------------------------===//
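/// Pass lowering GPU dialect operations to LLVM dialect calls of SPIR-V
/// (OpenCL) builtin functions, forcing OpenCL address spaces on memrefs that
/// have no explicit memory space.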
struct GPUToLLVMSPVConversionPass final
    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
  using Base::Base;

  void runOnOperation() final {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    LowerToLLVMOptions options(context);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter converter(context, options);
    LLVMConversionTarget target(*context);

    // Force OpenCL address spaces when they are not present
    {
      MemorySpaceToOpenCLMemorySpaceConverter converter(context);
      AttrTypeReplacer replacer;
      replacer.addReplacement([&converter](BaseMemRefType origType)
                                  -> std::optional<BaseMemRefType> {
        return converter.convertType<BaseMemRefType>(origType);
      });

      replacer.recursivelyReplaceElementsIn(getOperation(),
                                            /*replaceAttrs=*/true,
                                            /*replaceLocs=*/false,
                                            /*replaceTypes=*/true);
    }

    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                        gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
                        gpu::ThreadIdOp>();

    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
    populateGpuMemorySpaceAttributeConversions(converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Patterns.
//===----------------------------------------------------------------------===//
namespace mlir {
namespace {
static unsigned
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
  constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
  return storageClassToAddressSpace(clientAPI,
                                    addressSpaceToStorageClass(addressSpace));
}
} // namespace
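// Registers the conversion patterns defined above and configures
// `GPUFuncOpLowering` to use OpenCL address spaces for attributions and SPIR
// calling conventions for kernels and functions.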
void populateGpuToLLVMSPVConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
               GPUSubgroupOpConversion<gpu::LaneIdOp>,
               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
               LaunchConfigOpConversion<gpu::BlockDimOp>,
               LaunchConfigOpConversion<gpu::BlockIdOp>,
               LaunchConfigOpConversion<gpu::GlobalIdOp>,
               LaunchConfigOpConversion<gpu::GridDimOp>,
               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
  MLIRContext *context = &typeConverter.getContext();
  unsigned privateAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
  unsigned localAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
  OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
  StringAttr kernelBlockSizeAttributeName =
      LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
  patterns.add<GPUFuncOpLowering>(
      typeConverter,
      GPUFuncOpLoweringOptions{
          privateAddressSpace, localAddressSpace,
          /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName,
          LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
          /*encodeWorkgroupAttributionsAsArguments=*/true});
}

void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             gpuAddressSpaceToOCLAddressSpace);
}
} // namespace mlir