//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <limits>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;
/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Force check that `val` is of int type.
  auto valTy = cast<IntegerType>(val.getType());
  if (valTy == i32)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}
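// For example, an i64 value is truncated to i32 here, while an i8 or i16 one
// is zero-extended; an i32 value is returned unchanged.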

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter,
                              Location loc, bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index = index ? rewriter.create<LLVM::AddOp>(loc, index, increment)
                  : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}
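// In other words, this emits i32 arithmetic for sum_i(indices[i] * strides[i]);
// e.g. for a row-major memref<8x16xf32>, element (i, j) maps to i * 16 + j.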

// Define commonly used chipsets versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx940 = Chipset(9, 4, 0);
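// Chipset comparisons are lexicographic on (major, minor, stepping), so e.g.
// `chipset >= kGfx90a` holds for gfx90a and gfx94x but not for gfx908.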

namespace {
/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices, trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
    Type i32 = rewriter.getI32Type();
    Type i16 = rewriter.getI16Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32
    // bits, use a scalar load and bitcast it. Similarly, if bitsize(T) < 32
    // and the total load size is >= 32, use a vector load of N / (bitsize(T) /
    // 32) x i32 and bitcast. Also, the CAS intrinsic requires integer operands,
    // so bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }
    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(0));
    // Get the number of elements.
    Value numRecords;
    if (memrefType.hasStaticShape() &&
        !llvm::any_of(strides, ShapedType::isDynamic)) {
      // Static shape: compute the buffer size in bytes at compile time.
      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
      ArrayRef<int64_t> shape = memrefType.getShape();
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
        size = std::max(shape[i] * strides[i], size);
      size = size * elementByteWidth;
      assert(size < std::numeric_limits<uint32_t>::max() &&
             "the memref buffer is too large");
      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
    } else {
      // Dynamic shape: compute max_i(size[i] * stride[i]) from the descriptor.
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex,
                                                            maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::MulOp>(
          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
    }
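    // For example, a static memref<16x4xf32> with row-major strides [4, 1]
    // yields numRecords = max(16 * 4, 4 * 1) * 4 = 256 bytes at compile time.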

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: data format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
    Value flagsConst = createI32Constant(rewriter, loc, flags);
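    // For example, on gfx10+ with boundsCheck = true this yields
    // flags = (7 << 12) | (4 << 15) | (1 << 24) | (3 << 28): out-of-bounds
    // mode 3 plus the RDNA reserved bit.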
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};
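// As a rough illustration (assuming the usual amdgpu assembly syntax), a load
// such as
//   %v = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i]
//     : memref<64xf32>, i32 -> f32
// becomes a rocdl.make.buffer.rsrc (numRecords = 256 bytes) feeding a
// rocdl.raw.ptr.buffer.load whose voffset is the index scaled to bytes.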

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }
    return success();
  }
};
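// Note: on gfx12 the single s_barrier is split into separate signal and wait
// steps; the -1 immediate is understood to name the default workgroup barrier.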

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};
} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
/// 2. If the element type is bfloat16, bitcast it to i16.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8),
          input);
    }
  }
  return input;
}
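// For example, vector<4xi8> becomes i32, vector<8xi8> becomes i64, and
// vector<4xbf16> becomes vector<4xi16>; all other operands pass through.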

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the
/// lack of bfloat support in the WMMA intrinsics themselves.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
  if (isInputInt8) {
    // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger(8)) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger(8)) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
}
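// For example, a vector<16xi8> input contributes two operands: an i1 sign flag
// and the data repacked as vector<4xi32> (16 bytes = 4 words).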

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper 16 bits of the VGPRs or in the lower part. To store
/// the result in the lower 16 bits, set subwordOffset to 1; otherwise the
/// result will be stored in the upper part.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}
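// For example, with a vector<8xf16> accumulator the trailing i1 selects
// whether each result half lands in the low or high 16 bits of its VGPR.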

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx940) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
      chipset >= kGfx940) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
      chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
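// For example, a 32x32x8 f16 MFMA with one block maps to
// ROCDL::mfma_f32_32x32x8f16 on any gfx9 target, while the 16x16x32 i8
// variant is only selected on gfx940 and newer.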

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  return std::nullopt;
}
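// For example, an f16 x f16 -> f32 WMMA selects wmma_f32_16x16x16_f16; every
// intrinsic listed here is a 16x16x16 shape.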

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack = rewriter.create<LLVM::BitcastOp>(
          loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 = getTypeConverter()->convertType(
      VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend scalar and short-vector sources to a full v4i8.
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (sourceElemType.isFloat8E5M2FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}
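// Each cvt.sr.*.f32 intrinsic produces one stochastically-rounded fp8 byte;
// byteSel picks which of the four bytes of `existing` that byte replaces.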

namespace {
// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the source operand to the corresponding LLVM type
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the source type is narrower than 32 bits, insert it into an undef
    // vector and bitcast the result to i32.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // This is taken from the following file llvm/lib/Target/AMDGPU/SIDefines.h
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {
    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }
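    // For example, quad_perm = [0, 1, 2, 3] (the identity shuffle) packs two
    // bits per lane: DppCtrl = 0 | (1 << 2) | (2 << 4) | (3 << 6) = 0xe4.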

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // the corresponding operands from them.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // create a ROCDL_DPPMovOp instruction with the appropriate attributes
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(
          converter, chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}
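// With the pass registered, this conversion can be exercised as, for example
// (assuming the standard generated registration):
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx90a input.mlir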