//===- AMDGPUToROCDL.cpp - AMDGPU to ROCDL dialect conversion -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/STLExtras.h"
#include <limits>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

/// Convert an unsigned number `val` to i32.
static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
                                  Location loc, Value val) {
  IntegerType i32 = rewriter.getI32Type();
  // Assert that `val` has an integer type.
  auto valTy = cast<IntegerType>(val.getType());
  if (i32 == valTy)
    return val;
  return valTy.getWidth() > 32
             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
}
static Value createI32Constant(ConversionPatternRewriter &rewriter,
                               Location loc, int32_t value) {
  Type i32 = rewriter.getI32Type();
  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}

static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
                              bool value) {
  Type llvmI1 = rewriter.getI1Type();
  return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}

/// Returns the linear index used to access an element in the memref.
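/// For example, for indices (%i, %j) into a memref with strides [%s, 1], this
/// emits the i32 computation %i * %s + %j.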
static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
                               Location loc, MemRefDescriptor &memRefDescriptor,
                               ValueRange indices, ArrayRef<int64_t> strides) {
  IntegerType i32 = rewriter.getI32Type();
  Value index;
  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
    if (stride != 1) { // Skip if stride is 1.
      Value strideValue =
          ShapedType::isDynamic(stride)
              ? convertUnsignedToI32(rewriter, loc,
                                     memRefDescriptor.stride(rewriter, loc, i))
              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
    }
    index =
        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
  }
  return index ? index : createI32Constant(rewriter, loc, 0);
}

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
constexpr Chipset kGfx90a = Chipset(9, 0, 0xa);
constexpr Chipset kGfx940 = Chipset(9, 4, 0);

/// Define lowering patterns for raw buffer ops
template <typename GpuOp, typename Intrinsic>
struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
  RawBufferOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<GpuOp>(converter), chipset(chipset) {}

  Chipset chipset;
  static constexpr uint32_t maxVectorOpWidth = 128;

  LogicalResult
  matchAndRewrite(GpuOp gpuOp, typename GpuOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = gpuOp.getLoc();
    Value memref = adaptor.getMemref();
    Value unconvertedMemref = gpuOp.getMemref();
    MemRefType memrefType = cast<MemRefType>(unconvertedMemref.getType());

    if (chipset.majorVersion < 9)
      return gpuOp.emitOpError("raw buffer ops require GCN or higher");

    Value storeData = adaptor.getODSOperands(0)[0];
    if (storeData == memref) // no write component to this op
      storeData = Value();
    Type wantedDataType;
    if (storeData)
      wantedDataType = storeData.getType();
    else
      wantedDataType = gpuOp.getODSResults(0)[0].getType();

    Value atomicCmpData = Value();
    // Operand index 1 of a load is the indices; trying to read them as a
    // compare value can crash, so only probe it when the op writes data.
    if (storeData) {
      Value maybeCmpData = adaptor.getODSOperands(1)[0];
      if (maybeCmpData != memref)
        atomicCmpData = maybeCmpData;
    }
    Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);

    Type i32 = rewriter.getI32Type();
    Type i16 = rewriter.getI16Type();

    // Get the type size in bytes.
    DataLayout dataLayout = DataLayout::closest(gpuOp);
    int64_t elementByteWidth =
        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
    Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);

    // If we want to load a vector<NxT> with total size <= 32 bits, use a
    // scalar load and bitcast it. Similarly, if bitsize(T) < 32 and the total
    // load size is >= 32 bits, use a vector load of (N * bitsize(T)) / 32 x
    // i32 and bitcast. Also, the CAS intrinsic requires integer operands, so
    // bitcast any floats to integers.
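    // For example, a load of vector<2xf16> (32 bits in total) becomes a scalar
    // i32 load, while a load of vector<4xf16> (64 bits) becomes a load of
    // vector<2xi32>; each result is then bitcast back to the requested type.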
    Type llvmBufferValType = llvmWantedDataType;
    if (atomicCmpData) {
      if (auto floatType = dyn_cast<FloatType>(wantedDataType))
        llvmBufferValType = this->getTypeConverter()->convertType(
            rewriter.getIntegerType(floatType.getWidth()));
    }
    if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
      uint32_t vecLen = dataVector.getNumElements();
      uint32_t elemBits =
          dataLayout.getTypeSizeInBits(dataVector.getElementType());
      uint32_t totalBits = elemBits * vecLen;
      bool usePackedFp16 =
          isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
      if (totalBits > maxVectorOpWidth)
        return gpuOp.emitOpError(
            "Total width of loads or stores must be no more than " +
            Twine(maxVectorOpWidth) + " bits, but we call for " +
            Twine(totalBits) +
            " bits. This should've been caught in validation");
      if (!usePackedFp16 && elemBits < 32) {
        if (totalBits > 32) {
          if (totalBits % 32 != 0)
            return gpuOp.emitOpError("Load or store of more than 32-bits that "
                                     "doesn't fit into words. Can't happen\n");
          llvmBufferValType = this->typeConverter->convertType(
              VectorType::get(totalBits / 32, i32));
        } else {
          llvmBufferValType = this->typeConverter->convertType(
              rewriter.getIntegerType(totalBits));
        }
      }
    }

    SmallVector<Value, 6> args;
    if (storeData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForStore =
            rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
        args.push_back(castForStore);
      } else {
        args.push_back(storeData);
      }
    }

    if (atomicCmpData) {
      if (llvmBufferValType != llvmWantedDataType) {
        Value castForCmp = rewriter.create<LLVM::BitcastOp>(
            loc, llvmBufferValType, atomicCmpData);
        args.push_back(castForCmp);
      } else {
        args.push_back(atomicCmpData);
      }
    }

    // Construct buffer descriptor from memref, attributes
    int64_t offset = 0;
    SmallVector<int64_t, 5> strides;
    if (failed(getStridesAndOffset(memrefType, strides, offset)))
      return gpuOp.emitOpError("Can't lower non-stride-offset memrefs");

    MemRefDescriptor memrefDescriptor(memref);

    Value ptr = memrefDescriptor.bufferPtr(
        rewriter, loc, *this->getTypeConverter(), memrefType);
    // The stride value is always 0 for raw buffers. This also disables
    // swizzling.
    Value stride = rewriter.create<LLVM::ConstantOp>(
        loc, i16, rewriter.getI16IntegerAttr(0));
    // Get the number of records (with a stride of 0, this is the buffer's
    // extent in bytes).
    Value numRecords;
    if (memrefType.hasStaticShape() &&
        !llvm::any_of(strides, ShapedType::isDynamic)) {
      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
      ArrayRef<int64_t> shape = memrefType.getShape();
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
        size = std::max(shape[i] * strides[i], size);
      size = size * elementByteWidth;
      assert(size < std::numeric_limits<uint32_t>::max() &&
             "the memref buffer is too large");
      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
    } else {
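      // Dynamic shape: compute the largest index reachable through the sizes
      // and strides at runtime, then scale it to bytes.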
      Value maxIndex;
      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
        Value size = memrefDescriptor.size(rewriter, loc, i);
        Value stride = memrefDescriptor.stride(rewriter, loc, i);
        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
        maxIndex =
            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
                     : maxThisDim;
      }
      numRecords = rewriter.create<LLVM::MulOp>(
          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
    }

    // Flag word:
    // bits 0-11: dst sel, ignored by these intrinsics
    // bits 12-14: num format (ignored, must be nonzero, 7=float)
    // bits 15-18: data format (ignored, must be nonzero, 4=32bit)
    // bit 19: In nested heap (0 here)
    // bit 20: Behavior on unmap (0 means "return 0 / ignore")
    // bits 21-22: Index stride for swizzles (N/A)
    // bit 23: Add thread ID (0)
    // bit 24: Reserved to 1 (RDNA) or 0 (CDNA)
    // bits 25-26: Reserved (0)
    // bit 27: Buffer is non-volatile (CDNA only)
    // bits 28-29: Out of bounds select (0 = structured, 1 = check index, 2 =
    //  none, 3 = either swizzles or testing against offset field) RDNA only
    // bits 30-31: Type (must be 0)
    uint32_t flags = (7 << 12) | (4 << 15);
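    // With only these two fields set, flags is 0x27000.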
    if (chipset.majorVersion >= 10) {
      flags |= (1 << 24);
      uint32_t oob = adaptor.getBoundsCheck() ? 3 : 2;
      flags |= (oob << 28);
    }
    Value flagsConst = createI32Constant(rewriter, loc, flags);
    Type rsrcType = LLVM::LLVMPointerType::get(rewriter.getContext(), 8);
    Value resource = rewriter.createOrFold<ROCDL::MakeBufferRsrcOp>(
        loc, rsrcType, ptr, stride, numRecords, flagsConst);
    args.push_back(resource);

    // Indexing (voffset)
    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
                                      adaptor.getIndices(), strides);
    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
        indexOffset && *indexOffset > 0) {
      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
      voffset =
          voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                  : extraOffsetConst;
    }
    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
    args.push_back(voffset);

    // SGPR offset.
    Value sgprOffset = adaptor.getSgprOffset();
    if (!sgprOffset)
      sgprOffset = createI32Constant(rewriter, loc, 0);
    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
    args.push_back(sgprOffset);

    // bit 0: GLC = 0 (atomics drop value, less coherency)
    // bits 1-2: SLC, DLC = 0 (similarly)
    // bit 3: swizzled (0 for raw)
    args.push_back(createI32Constant(rewriter, loc, 0));

    llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
                                           llvmBufferValType);
    Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
                                                    ArrayRef<NamedAttribute>());
    if (lowered->getNumResults() == 1) {
      Value replacement = lowered->getResult(0);
      if (llvmBufferValType != llvmWantedDataType) {
        replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
                                                       replacement);
      }
      rewriter.replaceOp(gpuOp, replacement);
    } else {
      rewriter.eraseOp(gpuOp);
    }
    return success();
  }
};

struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
  LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<LDSBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(LDSBarrierOp op, LDSBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    bool requiresInlineAsm = chipset < kGfx90a || chipset.majorVersion == 11;

    if (requiresInlineAsm) {
      auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
                                                      LLVM::AsmDialect::AD_ATT);
      const char *asmStr =
          ";;;WARNING: BREAKS DEBUG WATCHES\ns_waitcnt lgkmcnt(0)\ns_barrier";
      const char *constraints = "";
      rewriter.replaceOpWithNewOp<LLVM::InlineAsmOp>(
          op,
          /*resultTypes=*/TypeRange(), /*operands=*/ValueRange(),
          /*asm_string=*/asmStr, constraints, /*has_side_effects=*/true,
          /*is_align_stack=*/false, /*asm_dialect=*/asmDialectAttr,
          /*operand_attrs=*/ArrayAttr());
      return success();
    }
    if (chipset.majorVersion < 12) {
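      // These masks zero the lgkmcnt field of the s_waitcnt immediate while
      // leaving the other counters at their maximum, so (as far as the
      // encoding allows) the wait only covers LDS accesses.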
      constexpr int32_t ldsOnlyBitsGfx6789 = ~(0x1f << 8);
      constexpr int32_t ldsOnlyBitsGfx10 = ~(0x3f << 8);
      // Left in place in case someone disables the inline ASM path or future
      // chipsets use the same bit pattern.
      constexpr int32_t ldsOnlyBitsGfx11 = ~(0x3f << 4);

      int32_t ldsOnlyBits;
      if (chipset.majorVersion == 11)
        ldsOnlyBits = ldsOnlyBitsGfx11;
      else if (chipset.majorVersion == 10)
        ldsOnlyBits = ldsOnlyBitsGfx10;
      else if (chipset.majorVersion <= 9)
        ldsOnlyBits = ldsOnlyBitsGfx6789;
      else
        return op.emitOpError(
                   "don't know how to lower this for chipset major version")
               << chipset.majorVersion;

      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitcntOp>(loc, ldsOnlyBits);
      rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
    } else {
      Location loc = op->getLoc();
      rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
      rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
      rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
    }

    return success();
  }
};

struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
  SchedBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<SchedBarrierOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(SchedBarrierOp op, SchedBarrierOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(op,
                                                     (uint32_t)op.getOpts());
    return success();
  }
};
} // namespace

/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
/// 2. If the element type is bfloat16, bitcast it to i16.
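/// For example, vector<4xbf16> is bitcast to vector<4xi16>, and vector<4xi8>
/// is bitcast to i32.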
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    }
  }
  return input;
}

/// Push an input operand. If it is a float type, nothing to do. If it is
/// an integer type, then we need to also push its signedness (1 for signed, 0
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
/// vector. We also need to convert bfloat inputs to i16 to account for the lack
/// of bfloat support in the WMMA intrinsics themselves.
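/// For example, a vector<16xi8> input is pushed as an i1 signedness flag
/// followed by the bytes repacked into a vector<4xi32>.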
static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
                                 Location loc,
                                 const TypeConverter *typeConverter,
                                 bool isUnsigned, Value llvmInput,
                                 Value mlirInput,
                                 SmallVector<Value, 4> &operands) {
  Type inputType = llvmInput.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();

  if (elemType.isBF16())
    llvmInput = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
  if (!elemType.isInteger(8)) {
    operands.push_back(llvmInput);
    return;
  }

  // We need to check the type of the input before conversion to properly test
  // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
  // fp8/int8 information is lost during the conversion process.
  auto mlirInputType = cast<VectorType>(mlirInput.getType());
  bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
  if (isInputInt8) {
    // If the element type is explicitly 8-bit signed or unsigned, it overrides
    // the isUnsigned flag.
    bool localIsUnsigned = isUnsigned;
    if (elemType.isUnsignedInteger(8)) {
      localIsUnsigned = true;
    } else if (elemType.isSignedInteger(8)) {
      localIsUnsigned = false;
    }
    Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
    operands.push_back(sign);
  }

  int64_t numBytes = vectorType.getNumElements();
  Type i32 = rewriter.getI32Type();
  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
      loc, llvmVectorType32bits, llvmInput);
  operands.push_back(result);
}

/// Push the output operand. In many cases this is just pushing the output onto
/// the operand list. But when we have f16 -> f16 or bf16 -> bf16 intrinsics,
/// since the same number of VGPRs is used for both, we need to decide whether
/// to store the result in the upper or the lower 16 bits of the VGPRs. To
/// store the result in the lower 16 bits, set subwordOffset to 1; otherwise
/// the result will be stored in the upper part.
static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const TypeConverter *typeConverter,
                                  Value output, int32_t subwordOffset,
                                  bool clamp, SmallVector<Value, 4> &operands) {
  Type inputType = output.getType();
  auto vectorType = dyn_cast<VectorType>(inputType);
  Type elemType = vectorType.getElementType();
  if (elemType.isBF16())
    output = rewriter.create<LLVM::BitcastOp>(
        loc, vectorType.clone(rewriter.getI16Type()), output);
  operands.push_back(output);
  if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
  } else if (elemType.isInteger(32)) {
    operands.push_back(createI1Constant(rewriter, loc, clamp));
  }
}

/// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
                                                  Chipset chipset) {
  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
           b = mfma.getBlocks();
  Type sourceElem = mfma.getSourceA().getType();
  if (auto sourceType = dyn_cast<VectorType>(sourceElem))
    sourceElem = sourceType.getElementType();
  Type destElem = mfma.getDestC().getType();
  if (auto destType = dyn_cast<VectorType>(destElem))
    destElem = destType.getElementType();

  if (sourceElem.isF32() && destElem.isF32()) {
    if (mfma.getReducePrecision() && chipset >= kGfx940) {
      if (m == 32 && n == 32 && k == 4 && b == 1)
        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
      if (m == 16 && n == 16 && k == 8 && b == 1)
        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
    }
    if (m == 32 && n == 32 && k == 1 && b == 2)
      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
    if (m == 16 && n == 16 && k == 1 && b == 4)
      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
    if (m == 4 && n == 4 && k == 1 && b == 16)
      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
    if (m == 32 && n == 32 && k == 2 && b == 1)
      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
  }

  if (sourceElem.isF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32() && chipset >= kGfx90a) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
  }

  if (sourceElem.isBF16() && destElem.isF32()) {
    if (m == 32 && n == 32 && k == 2 && b == 2)
      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
    if (m == 16 && n == 16 && k == 2 && b == 4)
      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
    if (m == 4 && n == 4 && k == 2 && b == 16)
      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
    if (m == 32 && n == 32 && k == 4 && b == 1)
      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
    if (m == 16 && n == 16 && k == 8 && b == 1)
      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
  }

  if (isa<IntegerType>(sourceElem) && destElem.isInteger(32)) {
    if (m == 32 && n == 32 && k == 4 && b == 2)
      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
    if (m == 16 && n == 16 && k == 4 && b == 4)
      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 16)
      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
    if (m == 32 && n == 32 && k == 8 && b == 1)
      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
    if (m == 16 && n == 16 && k == 16 && b == 1)
      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset >= kGfx940)
      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
  }

  if (sourceElem.isF64() && destElem.isF64() && chipset >= kGfx90a) {
    if (m == 16 && n == 16 && k == 4 && b == 1)
      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
    if (m == 4 && n == 4 && k == 4 && b == 4)
      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
  }

  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() && chipset >= kGfx940) {
    // Known to be correct because there are no scalar f8 instructions and
    // because a length mismatch will have been caught by the verifier.
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
    }
  }

  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() && chipset >= kGfx940) {
    Type sourceBElem =
        cast<VectorType>(mfma.getSourceB().getType()).getElementType();
    if (m == 16 && n == 16 && k == 32 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
    }
    if (m == 32 && n == 32 && k == 16 && b == 1) {
      if (sourceBElem.isFloat8E5M2FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
      if (sourceBElem.isFloat8E4M3FNUZ())
        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
    }
  }

  return std::nullopt;
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                  Chipset chipset) {
  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
  auto elemSourceType = sourceVectorType.getElementType();
  auto elemDestType = destVectorType.getElementType();

  if (elemSourceType.isF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
  if (elemSourceType.isF16() && elemDestType.isF16())
    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
  if (elemSourceType.isBF16() && elemDestType.isBF16())
    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
  if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
  if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
    return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
  return std::nullopt;
}

namespace {
struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
  MFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Type outType = typeConverter->convertType(op.getDestD().getType());
    Type intrinsicOutType = outType;
    if (auto outVecType = dyn_cast<VectorType>(outType))
      if (outVecType.getElementType().isBF16())
        intrinsicOutType = outVecType.clone(rewriter.getI16Type());

    if (chipset.majorVersion != 9 || chipset < kGfx908)
      return op->emitOpError("MFMA only supported on gfx908+");
    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
      if (chipset < kGfx940)
        return op.emitOpError("negation unsupported on older than gfx940");
      getBlgpField |=
          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
    }
    std::optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(intrinsicOutType);
    loweredOp.addOperands(
        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
         createI32Constant(rewriter, loc, op.getAbid()),
         createI32Constant(rewriter, loc, getBlgpField)});
    Value lowered = rewriter.create(loweredOp)->getResult(0);
    if (outType != intrinsicOutType)
      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
    rewriter.replaceOp(op, lowered);
    return success();
  }
};

struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
  WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}

  Chipset chipset;

  LogicalResult
  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto outType =
        typeConverter->convertType<VectorType>(op.getDestD().getType());
    if (!outType)
      return rewriter.notifyMatchFailure(op, "type conversion failed");

    if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
      return op->emitOpError("WMMA only supported on gfx11 and gfx12");

    // The WMMA operations represent vectors of bf16s as vectors of i16s, so we
    // need to bitcast bfloats to i16 and then bitcast them back.
    VectorType rawOutType = outType;
    if (outType.getElementType().isBF16())
      rawOutType = outType.clone(rewriter.getI16Type());

    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

    if (!maybeIntrinsic.has_value())
      return op.emitOpError("no intrinsic matching WMMA on the given chipset");

    OperationState loweredOp(loc, *maybeIntrinsic);
    loweredOp.addTypes(rawOutType);

    SmallVector<Value, 4> operands;
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
                         adaptor.getSourceA(), op.getSourceA(), operands);
    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
                         adaptor.getSourceB(), op.getSourceB(), operands);
    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
                          op.getSubwordOffset(), op.getClamp(), operands);

    loweredOp.addOperands(operands);
    Operation *lowered = rewriter.create(loweredOp);

    Operation *maybeCastBack = lowered;
    if (rawOutType != outType)
      maybeCastBack =
          rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
    rewriter.replaceOp(op, maybeCastBack->getResults());

    return success();
  }
};

namespace {
struct ExtPackedFp8OpLowering final
    : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
  ExtPackedFp8OpLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::ExtPackedFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedTrunc2xFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedTrunc2xFp8Op> {
  PackedTrunc2xFp8OpLowering(const LLVMTypeConverter &converter,
                             Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedTrunc2xFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

struct PackedStochRoundFp8OpLowering final
    : public ConvertOpToLLVMPattern<PackedStochRoundFp8Op> {
  PackedStochRoundFp8OpLowering(const LLVMTypeConverter &converter,
                                Chipset chipset)
      : ConvertOpToLLVMPattern<amdgpu::PackedStochRoundFp8Op>(converter),
        chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(PackedStochRoundFp8Op op,
                  PackedStochRoundFp8OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // end namespace

LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
    ExtPackedFp8Op op, ExtPackedFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type v4i8 =
      getTypeConverter()->convertType(VectorType::get(4, rewriter.getI8Type()));
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());
  Type f32 = getTypeConverter()->convertType(op.getResult().getType());

  Value source = adaptor.getSource();
  auto sourceVecType = dyn_cast<VectorType>(op.getSource().getType());
  Type sourceElemType = getElementTypeOrSelf(op.getSource());
  // Extend to a v4i8
  if (!sourceVecType || sourceVecType.getNumElements() < 4) {
    Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
    if (!sourceVecType) {
      longVec = rewriter.create<LLVM::InsertElementOp>(
          loc, longVec, source, createI32Constant(rewriter, loc, 0));
    } else {
      for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
        Value idx = createI32Constant(rewriter, loc, i);
        Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
        longVec =
            rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
      }
    }
    source = longVec;
  }
  Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
  Value wordSel = createI32Constant(rewriter, loc, op.getIndex());
  if (sourceElemType.isFloat8E5M2FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Bf8Op>(op, f32, i32Source,
                                                    wordSel);
  } else if (sourceElemType.isFloat8E4M3FNUZ()) {
    rewriter.replaceOpWithNewOp<ROCDL::CvtF32Fp8Op>(op, f32, i32Source,
                                                    wordSel);
  }
  return success();
}

LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite(
    PackedTrunc2xFp8Op op, PackedTrunc2xFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value sourceA = adaptor.getSourceA();
  Value sourceB = adaptor.getSourceB();
  if (!sourceB)
    sourceB = rewriter.create<LLVM::UndefOp>(loc, sourceA.getType());
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value wordSel = createI1Constant(rewriter, loc, op.getWordIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtPkBf8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtPkFp8F32Op>(loc, i32, sourceA, sourceB,
                                                   existing, wordSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite(
    PackedStochRoundFp8Op op, PackedStochRoundFp8OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  if (chipset.majorVersion != 9 || chipset < kGfx940)
    return rewriter.notifyMatchFailure(
        loc, "Fp8 conversion instructions are not available on target "
             "architecture and their emulation is not implemented");
  Type i32 = getTypeConverter()->convertType(rewriter.getI32Type());

  Type resultType = op.getResult().getType();
  Type resultElemType = getElementTypeOrSelf(resultType);

  Value source = adaptor.getSource();
  Value stoch = adaptor.getStochiasticParam();
  Value existing = adaptor.getExisting();
  if (existing)
    existing = rewriter.create<LLVM::BitcastOp>(loc, i32, existing);
  else
    existing = rewriter.create<LLVM::UndefOp>(loc, i32);
  Value byteSel = createI32Constant(rewriter, loc, op.getStoreIndex());

  Value result;
  if (resultElemType.isFloat8E5M2FNUZ())
    result = rewriter.create<ROCDL::CvtSrBf8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);
  else if (resultElemType.isFloat8E4M3FNUZ())
    result = rewriter.create<ROCDL::CvtSrFp8F32Op>(loc, i32, source, stoch,
                                                   existing, byteSel);

  result = rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(
      op, getTypeConverter()->convertType(resultType), result);
  return success();
}

// Implement the AMDGPU_DPPLowering class that will convert the amdgpu.dpp
// operation into the corresponding ROCDL instructions.
struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
  AMDGPUDPPLowering(const LLVMTypeConverter &converter, Chipset chipset)
      : ConvertOpToLLVMPattern<DPPOp>(converter), chipset(chipset) {}
  Chipset chipset;

  LogicalResult
  matchAndRewrite(DPPOp DppOp, DPPOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the source operand to the corresponding LLVM type.
    Location loc = DppOp.getLoc();
    Value src = adaptor.getSrc();
    Value old = adaptor.getOld();
    Type srcType = src.getType();
    Type oldType = old.getType();
    Type llvmType = nullptr;
    if (srcType.getIntOrFloatBitWidth() < 32) {
      llvmType = rewriter.getI32Type();
    } else if (isa<FloatType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getF32Type()
                     : rewriter.getF64Type();
    } else if (isa<IntegerType>(srcType)) {
      llvmType = (srcType.getIntOrFloatBitWidth() == 32)
                     ? rewriter.getI32Type()
                     : rewriter.getI64Type();
    }
    auto llvmSrcIntType = typeConverter->convertType(
        rewriter.getIntegerType(srcType.getIntOrFloatBitWidth()));

    // If the operand is narrower than 32 bits, widen it to an i32 by inserting
    // it into a vector and bitcasting.
    auto convertOperand = [&](Value operand, Type operandType) {
      if (operandType.getIntOrFloatBitWidth() <= 16) {
        if (llvm::isa<FloatType>(operandType)) {
          operand =
              rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
        }
        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
        Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
        operand = rewriter.create<LLVM::InsertElementOp>(
            loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
        operand = rewriter.create<LLVM::BitcastOp>(loc, llvmType, operand);
      }
      return operand;
    };

    src = convertOperand(src, srcType);
    old = convertOperand(old, oldType);

    // This enum is taken from llvm/lib/Target/AMDGPU/SIDefines.h.
    enum DppCtrl : unsigned {
      ROW_SHL0 = 0x100,
      ROW_SHR0 = 0x110,
      ROW_ROR0 = 0x120,
      WAVE_SHL1 = 0x130,
      WAVE_ROL1 = 0x134,
      WAVE_SHR1 = 0x138,
      WAVE_ROR1 = 0x13C,
      ROW_MIRROR = 0x140,
      ROW_HALF_MIRROR = 0x141,
      BCAST15 = 0x142,
      BCAST31 = 0x143,
    };

    auto kind = DppOp.getKind();
    auto permArgument = DppOp.getPermArgument();
    uint32_t DppCtrl = 0;
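    // For example, quad_perm [0, 1, 2, 3] (the identity permutation) packs two
    // bits per lane into DppCtrl and encodes to 0xe4.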
    switch (kind) {
    case DPPPerm::quad_perm:
      if (auto quadPermAttr = cast<ArrayAttr>(*permArgument)) {
        int32_t i = 0;
        for (auto elem : quadPermAttr.getAsRange<IntegerAttr>()) {
          uint32_t num = elem.getInt();
          DppCtrl |= num << (i * 2);
          i++;
        }
      }
      break;
    case DPPPerm::row_shl:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHL0;
      }
      break;
    case DPPPerm::row_shr:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_SHR0;
      }
      break;
    case DPPPerm::row_ror:
      if (auto intAttr = cast<IntegerAttr>(*permArgument)) {
        DppCtrl = intAttr.getInt() + DppCtrl::ROW_ROR0;
      }
      break;
    case DPPPerm::wave_shl:
      DppCtrl = DppCtrl::WAVE_SHL1;
      break;
    case DPPPerm::wave_shr:
      DppCtrl = DppCtrl::WAVE_SHR1;
      break;
    case DPPPerm::wave_rol:
      DppCtrl = DppCtrl::WAVE_ROL1;
      break;
    case DPPPerm::wave_ror:
      DppCtrl = DppCtrl::WAVE_ROR1;
      break;
    case DPPPerm::row_mirror:
      DppCtrl = DppCtrl::ROW_MIRROR;
      break;
    case DPPPerm::row_half_mirror:
      DppCtrl = DppCtrl::ROW_HALF_MIRROR;
      break;
    case DPPPerm::row_bcast_15:
      DppCtrl = DppCtrl::BCAST15;
      break;
    case DPPPerm::row_bcast_31:
      DppCtrl = DppCtrl::BCAST31;
      break;
    }

    // Read the row_mask, bank_mask, and bound_ctrl attributes so they can be
    // passed along as constants.
    auto rowMask = DppOp->getAttrOfType<IntegerAttr>("row_mask").getInt();
    auto bankMask = DppOp->getAttrOfType<IntegerAttr>("bank_mask").getInt();
    bool boundCtrl = DppOp->getAttrOfType<BoolAttr>("bound_ctrl").getValue();

    // Create a ROCDL_DPPMovOp instruction with the appropriate attributes.
    auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
        loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl);

    Value result = dppMovOp.getRes();
    if (srcType.getIntOrFloatBitWidth() < 32) {
      result = rewriter.create<LLVM::TruncOp>(loc, llvmSrcIntType, result);
      if (!llvm::isa<IntegerType>(srcType)) {
        result = rewriter.create<LLVM::BitcastOp>(loc, srcType, result);
      }
    }

    // We are replacing the AMDGPU_DPPOp instruction with the new
    // ROCDL_DPPMovOp instruction.
    rewriter.replaceOp(DppOp, ValueRange(result));
    return success();
  }
};

struct ConvertAMDGPUToROCDLPass
    : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
  ConvertAMDGPUToROCDLPass() = default;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
    if (failed(maybeChipset)) {
      emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
      return signalPassFailure();
    }

    RewritePatternSet patterns(ctx);
    LLVMTypeConverter converter(ctx);
    populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
    LLVMConversionTarget target(getContext());
    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::populateAMDGPUToROCDLConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
    Chipset chipset) {
  patterns
      .add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
           RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,
           RawBufferOpLowering<RawBufferAtomicFaddOp,
                               ROCDL::RawPtrBufferAtomicFaddOp>,
           RawBufferOpLowering<RawBufferAtomicFmaxOp,
                               ROCDL::RawPtrBufferAtomicFmaxOp>,
           RawBufferOpLowering<RawBufferAtomicSmaxOp,
                               ROCDL::RawPtrBufferAtomicSmaxOp>,
           RawBufferOpLowering<RawBufferAtomicUminOp,
                               ROCDL::RawPtrBufferAtomicUminOp>,
           RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                               ROCDL::RawPtrBufferAtomicCmpSwap>,
           AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
                                                                      chipset);
}

std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
  return std::make_unique<ConvertAMDGPUToROCDLPass>();
}