//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
11 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
14 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
15 #include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Pass/Pass.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include <optional>
namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type llvmI32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx940 = Chipset(9, 4, 0);

/// Define lowering patterns for raw buffer ops.
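/// As an illustration (IR sketched from memory, not verbatim), these
/// patterns rewrite something like
///   %v = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i]
///       : memref<64xf32>, i32 -> f32
/// into a rocdl.make.buffer.rsrc on the memref's base pointer followed by a
/// rocdl.raw.ptr.buffer.load at an explicitly computed byte offset.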
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them can crash.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }

    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type llvmI32 = this->typeConverter->convertType(i32);
    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());

    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32 bits, use a vector load of (N * bitsize(T) / 32) x
    // i32 and bitcast. Also, the CAS intrinsic requires integer operands, so
    // bitcast any floats to integers.
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits = dataVector.getElementTypeBitWidth();
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }
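    // For example, under the rules above a 128-bit vector<8xf16> access is
    // reinterpreted as vector<4xi32> and a vector<2xf16> access as a single
    // i32, while a vector<2xf16> operand of a buffer atomic fadd keeps its
    // type because the packed-fp16 intrinsic takes it directly.
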
    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }
    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, llvmI16, rewriter.getI16IntegerAttr(0));
    Value numRecords;
    if (memrefType.hasStaticShape()) {
      numRecords = createI32Constant(
          rewriter, loc,
          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
    } else {
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
                                                               maxThisDim)
                            : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
    }

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: numeric format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
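    // Concretely: on gfx9 the flag word is (7 << 12) | (4 << 15) = 0x27000;
    // on gfx10+ with bounds checks enabled it becomes
    // 0x27000 | (1 << 24) | (3u << 28) = 0x31027000.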
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = createI32Constant(rewriter, loc, 0);
    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
      size_t i = pair.index();
      Value index = pair.value();
      Value strideOp;
      if (ShapedType::isDynamic(strides[i])) {
        strideOp = rewriter.create<LLVM::MulOp>(
            loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
      } else {
        strideOp =
            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
      }
      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
    }
    if (adaptor.getIndexOffset()) {
      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    args.push_back(voffset);
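
    // For example, indexing a memref<?x?xf32> at (%i, %j) with strides
    // [s0, 1] yields voffset = %i * (s0 * 4) + %j * 4 bytes, with the
    // dynamic stride s0 read out of the memref descriptor.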

    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    if (ShapedType::isDynamic(offset))
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
    else if (offset > 0)
      sgprOffset = rewriter.create<LLVM::AddOp>(
          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);
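      // These masks clear (roughly) the lgkmcnt field, which counts
      // outstanding LDS operations, while leaving the other wait counters
      // saturated, so the s_waitcnt below waits on LDS traffic only. The
      // field's width and position vary by generation, hence the per-chipset
      // constants.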

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};
} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
/// 2. If the element type is bfloat16, bitcast it to i16.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    }
  }
  return input;
}
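
// For example, a vector<4xi8> operand becomes an i32, and a vector<4xbf16>
// operand becomes a vector<4xi16>.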

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the
/// lack of bfloat support in the WMMA intrinsics themselves.
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, the fp8 type is converted to int8, so
  // the fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
  if (isInputInt8) {
    // If the element type is 8-bit signed or unsigned, ignore the isUnsigned
    // flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger(8)) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger(8)) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
}
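
// For example, a vector<16xi8> input contributes two operands: an i1 sign
// flag and the data repacked as a vector<4xi32>.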

/// Push the output operand. For many cases this is only pushing the output in
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used, we need to decide whether to store
/// the result in the upper or lower 16 bits of the VGPRs. To store the result
/// in the lower 16 bits, set subwordOffset to 1; otherwise the result is
/// stored in the upper part.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx940) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}
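
// For instance, an amdgpu.mfma with m = n = 32, k = 8, blocks = 1, f16
// sources, and an f32 accumulator maps to ROCDL::mfma_f32_32x32x8f16
// (the llvm.amdgcn.mfma.f32.32x32x8f16 intrinsic).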

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};
} // namespace

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (sourceElemType.isFloat8E5M2FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the source operand to the corresponding LLVM type
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the source type is narrower than 32 bits, widen it to an i32 via an
    // insertion into a vector followed by a bitcast.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // These values are taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;

    switch (kind) {
    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Check for row_mask, bank_mask, bound_ctrl if they exist and create
    // constants.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL_DPPMovOp instruction with the appropriate attributes.
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction.
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
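
// Typical invocation (a sketch; the option spelling is assumed from the
// pass definition and should be checked against the generated docs):
//   mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx90a input.mlir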

void mlir::populateAMDGPUToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
                                                                      chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}