//===- GPUToLLVMSPV.cpp - Convert GPU operations to LLVM dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUToLLVMSPV/GPUToLLVMSPVPass.h"

11 #include "../GPUCommon/GPUOpsLowering.h"
12 #include "mlir/Conversion/GPUCommon/AttrToSPIRVConverter.h"
13 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
14 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
16 #include "mlir/Conversion/LLVMCommon/Pattern.h"
17 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
18 #include "mlir/Conversion/SPIRVCommon/AttrToLLVMConverter.h"
19 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
20 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
21 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
23 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
24 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
25 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
26 #include "mlir/IR/BuiltinTypes.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/PatternMatch.h"
29 #include "mlir/IR/SymbolTable.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Support/LLVM.h"
32 #include "mlir/Transforms/DialectConversion.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "gpu-to-llvm-spv"

using namespace mlir;

namespace mlir {
#define GEN_PASS_DEF_CONVERTGPUOPSTOLLVMSPVOPS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// Helper Functions
//===----------------------------------------------------------------------===//

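/// Look up the SPIR-V builtin function `name` in `symbolTable` and, if it is
/// not present yet, create an external `spir_funccc` declaration with the
/// requested signature. Declarations are marked `nounwind` and `willreturn`,
/// and optionally memory-effect free (`isMemNone`) and `convergent`.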
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
                                              StringRef name,
                                              ArrayRef<Type> paramTypes,
                                              Type resultType, bool isMemNone,
                                              bool isConvergent) {
  auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
      SymbolTable::lookupSymbolIn(symbolTable, name));
  if (!func) {
    OpBuilder b(symbolTable->getRegion(0));
    func = b.create<LLVM::LLVMFuncOp>(
        symbolTable->getLoc(), name,
        LLVM::LLVMFunctionType::get(resultType, paramTypes));
    func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
    func.setNoUnwind(true);
    func.setWillReturn(true);

    if (isMemNone) {
      // No externally observable effects.
      constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
      auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/noModRef,
          /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
      func.setMemoryEffectsAttr(memAttr);
    }

    func.setConvergent(isConvergent);
  }

  return func;
}

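/// Create an `llvm.call` to `func` and copy the calling convention and the
/// `convergent`, `nounwind`, `willreturn` and memory-effects attributes from
/// the callee onto the call site.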
static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
                                           ConversionPatternRewriter &rewriter,
                                           LLVM::LLVMFuncOp func,
                                           ValueRange args) {
  auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
  call.setCConv(func.getCConv());
  call.setConvergentAttr(func.getConvergentAttr());
  call.setNoUnwindAttr(func.getNoUnwindAttr());
  call.setWillReturnAttr(func.getWillReturnAttr());
  call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
  return call;
}

namespace {
//===----------------------------------------------------------------------===//
// Barriers
//===----------------------------------------------------------------------===//

/// Replace `gpu.barrier` with an `llvm.call` to `barrier` with
/// `CLK_LOCAL_MEM_FENCE` argument, indicating work-group memory scope:
/// ```
/// // gpu.barrier
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// llvm.call spir_funccc @_Z7barrierj(%c1) : (i32) -> ()
/// ```
struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::BarrierOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringLiteral funcName = "_Z7barrierj";

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type flagTy = rewriter.getI32Type();
    Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
                              /*isMemNone=*/false, /*isConvergent=*/true);

    // Value used by the SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
    // See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
    constexpr int64_t localMemFenceFlag = 1;
    Location loc = op->getLoc();
    Value flag =
        rewriter.create<LLVM::ConstantOp>(loc, flagTy, localMemFenceFlag);
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, flag));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// SPIR-V Builtins
//===----------------------------------------------------------------------===//

/// Replace `gpu.*` with an `llvm.call` to the corresponding SPIR-V builtin
/// with a constant argument for the `dimension` attribute. The return type
/// depends on the index width option:
/// ```
/// // %thread_id_y = gpu.thread_id y
/// %c1 = llvm.mlir.constant(1: i32) : i32
/// %0 = llvm.call spir_funccc @_Z12get_local_idj(%c1) : (i32) -> i64
/// ```
struct LaunchConfigConversion : ConvertToLLVMPattern {
  LaunchConfigConversion(StringRef funcName, StringRef rootOpName,
                         MLIRContext *context,
                         const LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit)
      : ConvertToLLVMPattern(rootOpName, context, typeConverter, benefit),
        funcName(funcName) {}

  virtual gpu::Dimension getDimension(Operation *op) const = 0;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type dimTy = rewriter.getI32Type();
    Type indexTy = getTypeConverter()->getIndexType();
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy,
                                                  indexTy, /*isMemNone=*/true,
                                                  /*isConvergent=*/false);

    Location loc = op->getLoc();
    gpu::Dimension dim = getDimension(op);
    Value dimVal = rewriter.create<LLVM::ConstantOp>(loc, dimTy,
                                                     static_cast<int64_t>(dim));
    rewriter.replaceOp(op, createSPIRVBuiltinCall(loc, rewriter, func, dimVal));
    return success();
  }

  StringRef funcName;
};

template <typename SourceOp>
struct LaunchConfigOpConversion final : LaunchConfigConversion {
  static StringRef getFuncName();

  explicit LaunchConfigOpConversion(const LLVMTypeConverter &typeConverter,
                                    PatternBenefit benefit = 1)
      : LaunchConfigConversion(getFuncName(), SourceOp::getOperationName(),
                               &typeConverter.getContext(), typeConverter,
                               benefit) {}

  gpu::Dimension getDimension(Operation *op) const final {
    return cast<SourceOp>(op).getDimension();
  }
};

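// Builtin names are Itanium-mangled as `_Z<len><name>j`, where the trailing
// `j` encodes the single `unsigned int` dimension argument; e.g.
// `_Z12get_local_idj` is `get_local_id(unsigned int)`.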
template <>
StringRef LaunchConfigOpConversion<gpu::BlockIdOp>::getFuncName() {
  return "_Z12get_group_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GridDimOp>::getFuncName() {
  return "_Z14get_num_groupsj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::BlockDimOp>::getFuncName() {
  return "_Z14get_local_sizej";
}

template <>
StringRef LaunchConfigOpConversion<gpu::ThreadIdOp>::getFuncName() {
  return "_Z12get_local_idj";
}

template <>
StringRef LaunchConfigOpConversion<gpu::GlobalIdOp>::getFuncName() {
  return "_Z13get_global_idj";
}

//===----------------------------------------------------------------------===//
// Shuffles
//===----------------------------------------------------------------------===//

/// Replace `gpu.shuffle` with an `llvm.call` to the corresponding SPIR-V
/// builtin for `shuffleResult`, keeping the `value` and `offset` arguments,
/// and a `true` constant for the `valid` result. The conversion only takes
/// place if `width` is constant and equal to the target's subgroup size:
/// ```
/// // %0 = gpu.shuffle idx %value, %offset, %width : f64
/// %0 = llvm.call spir_funccc @_Z17sub_group_shuffledj(%value, %offset)
///     : (f64, i32) -> f64
/// ```
struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  static StringRef getBaseName(gpu::ShuffleMode mode) {
    switch (mode) {
    case gpu::ShuffleMode::IDX:
      return "sub_group_shuffle";
    case gpu::ShuffleMode::XOR:
      return "sub_group_shuffle_xor";
    case gpu::ShuffleMode::UP:
      return "sub_group_shuffle_up";
    case gpu::ShuffleMode::DOWN:
      return "sub_group_shuffle_down";
    }
    llvm_unreachable("Unhandled shuffle mode");
  }

  static std::optional<StringRef> getTypeMangling(Type type) {
    return TypeSwitch<Type, std::optional<StringRef>>(type)
        .Case<Float16Type>([](auto) { return "Dhj"; })
        .Case<Float32Type>([](auto) { return "fj"; })
        .Case<Float64Type>([](auto) { return "dj"; })
        .Case<IntegerType>([](auto intTy) -> std::optional<StringRef> {
          switch (intTy.getWidth()) {
          case 8:
            return "cj";
          case 16:
            return "sj";
          case 32:
            return "ij";
          case 64:
            return "lj";
          }
          return std::nullopt;
        })
        .Default([](auto) { return std::nullopt; });
  }

  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
    StringRef baseName = getBaseName(op.getMode());
    std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
    if (!typeMangling)
      return std::nullopt;
    return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
                         typeMangling.value());
  }

  /// Get the subgroup size from the target or return a default.
  static int getSubgroupSize(Operation *op) {
    return spirv::lookupTargetEnvOrDefault(op)
        .getResourceLimits()
        .getSubgroupSize();
  }

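  /// The `sub_group_shuffle*` builtins implicitly operate on the whole
  /// subgroup, so only accept a `width` that is a known constant equal to the
  /// target's subgroup size.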
  static bool hasValidWidth(gpu::ShuffleOp op) {
    llvm::APInt val;
    Value width = op.getWidth();
    return matchPattern(width, m_ConstantInt(&val)) &&
           val == getSubgroupSize(op);
  }

  LogicalResult
  matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    if (!hasValidWidth(op))
      return rewriter.notifyMatchFailure(
          op, "shuffle width and subgroup size mismatch");

    std::optional<std::string> funcName = getFuncName(op);
    if (!funcName)
      return rewriter.notifyMatchFailure(op, "unsupported value type");

    Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
    assert(moduleOp && "Expecting module");
    Type valueType = adaptor.getValue().getType();
    Type offsetType = adaptor.getOffset().getType();
    Type resultType = valueType;
    LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
        moduleOp, funcName.value(), {valueType, offsetType}, resultType,
        /*isMemNone=*/false, /*isConvergent=*/true);

    Location loc = op->getLoc();
    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
    Value result =
        createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
    Value trueVal =
        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
    rewriter.replaceOp(op, {result, trueVal});
    return success();
  }
};

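/// Type converter attaching the OpenCL global (`CrossWorkgroup`) address space
/// to memref types that carry no memory space attribute, and propagating the
/// change through function types. Illustrative effect (assuming the usual
/// OpenCL numbering, where global memory is address space 1):
/// `memref<8xf32>` becomes `memref<8xf32, 1>`.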
class MemorySpaceToOpenCLMemorySpaceConverter final : public TypeConverter {
public:
  MemorySpaceToOpenCLMemorySpaceConverter(MLIRContext *ctx) {
    addConversion([](Type t) { return t; });
    addConversion([ctx](BaseMemRefType memRefType) -> std::optional<Type> {
      // Attach a global address space attribute to memrefs that have no
      // address space attribute yet.
      Attribute memSpaceAttr = memRefType.getMemorySpace();
      if (memSpaceAttr)
        return std::nullopt;

      unsigned globalAddrspace = storageClassToAddressSpace(
          spirv::ClientAPI::OpenCL, spirv::StorageClass::CrossWorkgroup);
      Attribute addrSpaceAttr =
          IntegerAttr::get(IntegerType::get(ctx, 64), globalAddrspace);
      if (auto rankedType = dyn_cast<MemRefType>(memRefType)) {
        return MemRefType::get(memRefType.getShape(),
                               memRefType.getElementType(),
                               rankedType.getLayout(), addrSpaceAttr);
      }
      return UnrankedMemRefType::get(memRefType.getElementType(),
                                     addrSpaceAttr);
    });
    addConversion([this](FunctionType type) {
      auto inputs = llvm::map_to_vector(
          type.getInputs(), [this](Type ty) { return convertType(ty); });
      auto results = llvm::map_to_vector(
          type.getResults(), [this](Type ty) { return convertType(ty); });
      return FunctionType::get(type.getContext(), inputs, results);
    });
  }
};

//===----------------------------------------------------------------------===//
// Subgroup query ops.
//===----------------------------------------------------------------------===//

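/// Replace `gpu.subgroup_id`, `gpu.lane_id`, `gpu.num_subgroups` and
/// `gpu.subgroup_size` with calls to the corresponding SPIR-V builtins. The
/// builtins return `i32`, so the result is zero-extended when the converted
/// index type is wider. Illustrative lowering:
/// ```
/// // %id = gpu.subgroup_id : index
/// %0 = llvm.call spir_funccc @_Z16get_sub_group_id() : () -> i32
/// %1 = llvm.zext %0 : i32 to i64
/// ```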
template <typename SubgroupOp>
struct GPUSubgroupOpConversion final : ConvertOpToLLVMPattern<SubgroupOp> {
  using ConvertOpToLLVMPattern<SubgroupOp>::ConvertOpToLLVMPattern;
  using ConvertToLLVMPattern::getTypeConverter;

  LogicalResult
  matchAndRewrite(SubgroupOp op, typename SubgroupOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    constexpr StringRef funcName = [] {
      if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupIdOp>) {
        return "_Z16get_sub_group_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::LaneIdOp>) {
        return "_Z22get_sub_group_local_id";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::NumSubgroupsOp>) {
        return "_Z18get_num_sub_groups";
      } else if constexpr (std::is_same_v<SubgroupOp, gpu::SubgroupSizeOp>) {
        return "_Z18get_sub_group_size";
      }
    }();

    Operation *moduleOp =
        op->template getParentWithTrait<OpTrait::SymbolTable>();
    Type resultTy = rewriter.getI32Type();
    LLVM::LLVMFuncOp func =
        lookupOrCreateSPIRVFn(moduleOp, funcName, {}, resultTy,
                              /*isMemNone=*/false, /*isConvergent=*/false);

    Location loc = op->getLoc();
    Value result = createSPIRVBuiltinCall(loc, rewriter, func, {}).getResult();

    Type indexTy = getTypeConverter()->getIndexType();
    if (resultTy != indexTy) {
      if (indexTy.getIntOrFloatBitWidth() < resultTy.getIntOrFloatBitWidth()) {
        return failure();
      }
      result = rewriter.create<LLVM::ZExtOp>(loc, indexTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Pass.
//===----------------------------------------------------------------------===//

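/// Pass lowering the supported `gpu` dialect operations to LLVM dialect calls
/// to SPIR-V builtin functions. When `indexBitwidth` is set, it overrides the
/// index bitwidth otherwise derived from the data layout.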
struct GPUToLLVMSPVConversionPass final
    : impl::ConvertGpuOpsToLLVMSPVOpsBase<GPUToLLVMSPVConversionPass> {
  using Base::Base;

  void runOnOperation() final {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    LowerToLLVMOptions options(context);
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter converter(context, options);
    LLVMConversionTarget target(*context);

    // Force OpenCL address spaces when they are not present.
    {
      MemorySpaceToOpenCLMemorySpaceConverter converter(context);
      AttrTypeReplacer replacer;
      replacer.addReplacement([&converter](BaseMemRefType origType)
                                  -> std::optional<BaseMemRefType> {
        return converter.convertType<BaseMemRefType>(origType);
      });

      replacer.recursivelyReplaceElementsIn(getOperation(),
                                            /*replaceAttrs=*/true,
                                            /*replaceLocs=*/false,
                                            /*replaceTypes=*/true);
    }

    target.addIllegalOp<gpu::BarrierOp, gpu::BlockDimOp, gpu::BlockIdOp,
                        gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                        gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                        gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
                        gpu::ThreadIdOp>();

    populateGpuToLLVMSPVConversionPatterns(converter, patterns);
    populateGpuMemorySpaceAttributeConversions(converter);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// GPU To LLVM-SPV Patterns.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace {
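// Map a `gpu` address space to the OpenCL address space number by going
// through the corresponding SPIR-V storage class. Assuming the usual OpenCL
// numbering: Private -> Function -> 0, Global -> CrossWorkgroup -> 1,
// Workgroup -> Workgroup -> 3.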
static unsigned
gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace addressSpace) {
  constexpr spirv::ClientAPI clientAPI = spirv::ClientAPI::OpenCL;
  return storageClassToAddressSpace(clientAPI,
                                    addressSpaceToStorageClass(addressSpace));
}
} // namespace

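/// Populate `patterns` with the conversions defined above: barriers, shuffles,
/// subgroup and launch-configuration queries, `gpu.return`, and the lowering
/// of `gpu.func` to `llvm.func` with SPIR_KERNEL/SPIR_FUNC calling
/// conventions.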
void populateGpuToLLVMSPVConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<GPUBarrierConversion, GPUReturnOpLowering, GPUShuffleConversion,
               GPUSubgroupOpConversion<gpu::LaneIdOp>,
               GPUSubgroupOpConversion<gpu::NumSubgroupsOp>,
               GPUSubgroupOpConversion<gpu::SubgroupIdOp>,
               GPUSubgroupOpConversion<gpu::SubgroupSizeOp>,
               LaunchConfigOpConversion<gpu::BlockDimOp>,
               LaunchConfigOpConversion<gpu::BlockIdOp>,
               LaunchConfigOpConversion<gpu::GlobalIdOp>,
               LaunchConfigOpConversion<gpu::GridDimOp>,
               LaunchConfigOpConversion<gpu::ThreadIdOp>>(typeConverter);
  MLIRContext *context = &typeConverter.getContext();
  unsigned privateAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Private);
  unsigned localAddressSpace =
      gpuAddressSpaceToOCLAddressSpace(gpu::AddressSpace::Workgroup);
  OperationName llvmFuncOpName(LLVM::LLVMFuncOp::getOperationName(), context);
  StringAttr kernelBlockSizeAttributeName =
      LLVM::LLVMFuncOp::getReqdWorkGroupSizeAttrName(llvmFuncOpName);
  patterns.add<GPUFuncOpLowering>(
      typeConverter,
      GPUFuncOpLoweringOptions{
          privateAddressSpace, localAddressSpace,
          /*kernelAttributeName=*/{}, kernelBlockSizeAttributeName,
          LLVM::CConv::SPIR_KERNEL, LLVM::CConv::SPIR_FUNC,
          /*encodeWorkgroupAttributionsAsArguments=*/true});
}

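/// Register the mapping from `gpu` memory space attributes to OpenCL address
/// space numbers on `typeConverter`, reusing `gpuAddressSpaceToOCLAddressSpace`.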
void populateGpuMemorySpaceAttributeConversions(TypeConverter &typeConverter) {
  populateGpuMemorySpaceAttributeConversions(typeConverter,
                                             gpuAddressSpaceToOCLAddressSpace);
}
} // namespace mlir