//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx940 = Chipset(9, 4, 0);
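
// Note: Chipset comparisons are lexicographic over (major, minor, stepping),
// so e.g. gfx942 = Chipset(9, 4, 2) satisfies `chipset >= kGfx940`, while
// gfx90a = Chipset(9, 0, 0xa) does not.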
namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);
    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of
    // (N * bitsize(T)) / 32 x i32 and bitcast. Also, the CAS intrinsic
    // requires integer operands, so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI16, rewriter.getI16IntegerAttr(0));
    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        // Combine the byte extents with an unsigned integer max.
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }

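    // Worked example (sketch): a static memref<16x4xf32> has 64 elements of
    // 4 bytes each, so numRecords = 256; for dynamic shapes, the loop above
    // takes the maximum byte extent over all dimensions instead.
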
    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: data format (ignored, must be nonzero, 7=float)
    // bits 15-18: num format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = createI32Constant(rewriter, loc, 0);
    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamic(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
    }
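    // Worked example (sketch): for a row-major memref<MxNxf32> indexed by
    // (%i, %j), the loop above computes voffset = %i * (N * 4) + %j * 4 bytes.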
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamic(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset != 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

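// A sketch of the overall rewrite (IR abbreviated, not verified verbatim):
//   %v = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i]
//       : memref<64xf32>, i32 -> f32
// becomes a rocdl.make.buffer.rsrc that packs the pointer, stride, numRecords,
// and flag word computed above, followed by a rocdl.raw.ptr.buffer.load taking
// the resource, voffset, sgprOffset, and cache-policy arguments.
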
struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }
    return success();
  }
};

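// For example, on gfx12 the else-branch above emits the split-barrier sequence
// (roughly s_wait_dscnt 0; s_barrier_signal -1; s_barrier_wait -1), while
// gfx9/gfx10 use s_waitcnt with an LDS-only mask followed by s_barrier.
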
struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};

} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
/// 2. If the element type is bfloat16, bitcast it to i16.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    }
  }
  return input;
}

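// For example, a vector<4xi8> operand becomes an i32 and a vector<4xbf16>
// becomes a vector<4xi16>; other operands are passed through unchanged.
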
/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the
/// lack of bfloat support in the WMMA intrinsics themselves.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
  if (isInputInt8) {
    // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger(8)) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger(8)) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
}

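// For example, a vector<16xi8> WMMA input is preceded by an i1 signedness bit
// and bitcast to vector<4xi32> before being pushed.
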
/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result will be
/// stored in the upper part.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

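// For example, an f16 or bf16 accumulator is followed by the subwordOffset
// flag, while an i32 accumulator is followed by the clamp flag.
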
/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx940) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

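// A sketch of the rewrite (attribute syntax abbreviated, not verified
// verbatim): with m = n = 32, k = 8, blocks = 1 and f16 sources,
//   %d = amdgpu.mfma %a * %b + %c
//       : vector<4xf16>, vector<4xf16>, vector<16xf32>
// lowers to rocdl.mfma.f32.32x32x8f16 with the converted operands plus the
// cbsz, abid, and blgp immediates built above.
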
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());
    return success();
  }
};

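// A sketch of the rewrite (syntax abbreviated, not verified verbatim): with
// f16 sources and an f32 accumulator,
//   %d = amdgpu.wmma %a * %b + %c
//       : vector<16xf16>, vector<16xf16>, vector<8xf32>
// lowers to rocdl.wmma.f32.16x16x16.f16 on the operand list assembled above.
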
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());

  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
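  // For example, a scalar f8 source or a vector<2xf8> source is widened to a
  // full 4-byte vector here so the bitcast to i32 below is always valid.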
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (sourceElemType.isFloat8E5M2FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

namespace {
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Convert the source operand to the corresponding LLVM type
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the source type is narrower than 32 bits, bitcast it up to an i32.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {

    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        uint32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
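      // (For example, quad_perm = [1, 0, 3, 2] above packs into DppCtrl as
      // 0b10110001 = 0xB1: two bits per lane, lane 0 in the low bits.)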
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

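// Usage note (assuming the standard pass registration generated from
// Passes.h.inc): this pass is exposed to mlir-opt as
// --convert-amdgpu-to-rocdl with a chipset option, e.g.
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx90a input.mlir
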
void mlir::populateAMDGPUToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
                                                                      chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}