//===------ WmmaOpsToNVVM.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
// NVVM Dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

namespace {
/// Checks if all the operands of the op being lowered are of LLVM types. The
/// types are expected to be converted by the `LLVMTypeConverter` before the op
/// is actually lowered. If the type of an operand is not already converted, it
/// hints at a missing type conversion, and failure is returned in that case.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "cannot convert if operands aren't of LLVM type.");
  return success();
}
/// Error string to emit when an unimplemented WMMA variant is encountered.
static constexpr StringRef kInvalidCaseStr = "Unsupported WMMA variant.";
static NVVM::MMAFrag convertOperand(StringRef operandName) {
  if (operandName == "AOp")
    return NVVM::MMAFrag::a;
  if (operandName == "BOp")
    return NVVM::MMAFrag::b;
  if (operandName == "COp")
    return NVVM::MMAFrag::c;
  llvm_unreachable("Unknown operand name");
}
static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) {
  if (type.getElementType().isF16())
    return NVVM::MMATypes::f16;
  if (type.getElementType().isF32())
    return type.getOperand() == "COp" ? NVVM::MMATypes::f32
                                      : NVVM::MMATypes::tf32;

  if (type.getElementType().isSignedInteger(8))
    return NVVM::MMATypes::s8;
  if (type.getElementType().isUnsignedInteger(8))
    return NVVM::MMATypes::u8;
  // Accumulator type is signless and implies signed.
  if (type.getElementType().isInteger(32))
    return NVVM::MMATypes::s32;
  llvm_unreachable("Unsupported type");
}
/// This class implements the conversion of GPU MMA loadOp to wmma.load op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits the code necessary to compute the address in the source memref
/// from which the data is loaded.
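///
/// For example (illustrative only; the exact op assembly syntax may differ
/// across MLIR versions), a load such as
///
///   %0 = gpu.subgroup_mma_load_matrix %src[%i, %j]
///        {leadDimension = 32 : index}
///        : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// becomes an nvvm.wmma.load of the f16 "a" fragment, taking a raw element
/// pointer computed from the lowered memref descriptor and the leadDimension
/// attribute materialized as an i32 constant.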
struct WmmaLoadOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaLoadMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    // Get the shape of the MMAMatrix type being returned. The shape will
    // choose which intrinsic this op will be lowered to.
    NVVM::MMALayout layout = subgroupMmaLoadMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    gpu::MMAMatrixType retType =
        cast<gpu::MMAMatrixType>(subgroupMmaLoadMatrixOp.getRes().getType());
    ArrayRef<int64_t> retTypeShape = retType.getShape();
    int64_t m = 0;
    int64_t n = 0;
    int64_t k = 0;
    NVVM::MMATypes eltype = getElementType(retType);
    // NVVM intrinsics require mxnxk dimensions to be given; infer the missing
    // dimension based on the valid intrinsics available.
    if (retType.getOperand() == "AOp") {
      m = retTypeShape[0];
      k = retTypeShape[1];
      n = NVVM::WMMALoadOp::inferNDimension(m, k, eltype);
    } else if (retType.getOperand() == "BOp") {
      k = retTypeShape[0];
      n = retTypeShape[1];
      m = NVVM::WMMALoadOp::inferMDimension(k, n, eltype);
    } else if (retType.getOperand() == "COp") {
      m = retTypeShape[0];
      n = retTypeShape[1];
      k = NVVM::WMMALoadOp::inferKDimension(m, n, eltype);
    }
    NVVM::MMAFrag frag = convertOperand(retType.getOperand());
    // Check that there is an existing instruction for the combination we need.
    if (NVVM::WMMALoadOp::getIntrinsicID(m, n, k, layout, eltype, frag) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    Type resType = convertMMAToLLVMType(retType);
    Location loc = op->getLoc();

    // Create nvvm.mma_load op according to the operand types.
    Value dataPtr = getStridedElementPtr(
        loc, cast<MemRefType>(subgroupMmaLoadMatrixOp.getSrcMemref().getType()),
        adaptor.getSrcMemref(), adaptor.getIndices(), rewriter);

    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaLoadMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMALoadOp>(
        op, resType, dataPtr, leadingDim, m, n, k, layout, eltype, frag);
    return success();
  }
};
/// This class implements the conversion of GPU MMA storeOp to wmma.store op
/// in the NVVM dialect. The conversion not only emits the NVVM op but also
/// emits the code necessary to unpack the data in the source and convert it
/// into the format needed by the NVVM op.
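///
/// For example (illustrative only; the exact op assembly syntax may differ
/// across MLIR versions), a store such as
///
///   gpu.subgroup_mma_store_matrix %val, %dst[%i, %j]
///       {leadDimension = 32 : index}
///       : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>
///
/// has its lowered source struct unpacked with llvm.extractvalue so that each
/// element can be passed to nvvm.wmma.store as an individual operand.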
struct WmmaStoreOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaStoreMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    SmallVector<Value, 4> storeOpOperands;
    // Get the shape of the MMAMatrix type being stored. The shape will
    // choose which intrinsic this op will be lowered to.
    gpu::MMAMatrixType srcType =
        cast<gpu::MMAMatrixType>(subgroupMmaStoreMatrixOp.getSrc().getType());
    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
    NVVM::MMALayout layout = subgroupMmaStoreMatrixOp.getTranspose()
                                 ? NVVM::MMALayout::col
                                 : NVVM::MMALayout::row;
    NVVM::MMATypes eltype = getElementType(srcType);
    int64_t m = srcTypeShape[0];
    int64_t n = srcTypeShape[1];
    int64_t k = NVVM::WMMAStoreOp::inferKDimension(m, n, eltype);
    if (NVVM::WMMAStoreOp::getIntrinsicID(m, n, k, layout, eltype) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    auto matrixType = cast<LLVM::LLVMStructType>(adaptor.getSrc().getType());
    for (unsigned i = 0, e = matrixType.getBody().size(); i < e; ++i) {
      Value toUse =
          rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getSrc(), i);
      storeOpOperands.push_back(toUse);
    }

    Value dataPtr = getStridedElementPtr(
        loc,
        cast<MemRefType>(subgroupMmaStoreMatrixOp.getDstMemref().getType()),
        adaptor.getDstMemref(), adaptor.getIndices(), rewriter);
    Value leadingDim = rewriter.create<LLVM::ConstantOp>(
        loc, rewriter.getI32Type(),
        subgroupMmaStoreMatrixOp.getLeadDimensionAttr());
    rewriter.replaceOpWithNewOp<NVVM::WMMAStoreOp>(
        op, dataPtr, m, n, k, layout, eltype, storeOpOperands, leadingDim);
    return success();
  }
};
/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
/// in the NVVM dialect.
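///
/// For example (illustrative only; the exact op assembly syntax may differ
/// across MLIR versions), a compute such as
///
///   %d = gpu.subgroup_mma_compute %a, %b, %c
///        : !gpu.mma_matrix<16x16xf16, "AOp">,
///          !gpu.mma_matrix<16x16xf16, "BOp">
///        -> !gpu.mma_matrix<16x16xf16, "COp">
///
/// has all three lowered operand structs unpacked into one flat value list
/// that is handed to a single nvvm.wmma.mma op.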
struct WmmaMmaOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaComputeOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *op = subgroupMmaComputeOp.getOperation();
    if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = op->getLoc();

    // The wmma.mma intrinsic in llvm requires the operands as individual
    // values. So individual elements from the memrefs need to be extracted and
    // then passed on to the intrinsic call. Emit llvm ops to extract individual
    // values from the lowered memrefs.
    SmallVector<Value> unpackedOps;

    auto unpackOp = [&](Value operand) {
      auto structType = cast<LLVM::LLVMStructType>(operand.getType());
      for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) {
        Value toUse = rewriter.create<LLVM::ExtractValueOp>(loc, operand, i);
        unpackedOps.push_back(toUse);
      }
    };

    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    gpu::MMAMatrixType aType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpA().getType());
    ArrayRef<int64_t> aTypeShape = aType.getShape();
    gpu::MMAMatrixType cType =
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpC().getType());
    ArrayRef<int64_t> cTypeShape = cType.getShape();
    int64_t m = cTypeShape[0];
    int64_t n = cTypeShape[1];
    int64_t k = aTypeShape[1];
    NVVM::MMALayout aLayout = subgroupMmaComputeOp.getATranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMALayout bLayout = subgroupMmaComputeOp.getBTranspose()
                                  ? NVVM::MMALayout::col
                                  : NVVM::MMALayout::row;
    NVVM::MMATypes sourceType = getElementType(aType);
    NVVM::MMATypes destType = getElementType(cType);
    if (NVVM::WMMAMmaOp::getIntrinsicID(m, n, k, aLayout, bLayout, sourceType,
                                        destType) == 0)
      return rewriter.notifyMatchFailure(op, kInvalidCaseStr);

    NVVM::MMATypes bElementType = getElementType(
        cast<gpu::MMAMatrixType>(subgroupMmaComputeOp.getOpB().getType()));
    if (bElementType != sourceType)
      return rewriter.notifyMatchFailure(
          op, "WMMA compute op input matrix element types must match.");

    unpackOp(adaptor.getOpA());
    unpackOp(adaptor.getOpB());
    unpackOp(adaptor.getOpC());

    rewriter.replaceOpWithNewOp<NVVM::WMMAMmaOp>(
        op, adaptor.getOpC().getType(), m, n, k, aLayout, bLayout, sourceType,
        destType, unpackedOps);
    return success();
  }
};
/// Convert GPU MMA ConstantMatrixOp to a chain of InsertValueOp.
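///
/// For example (illustrative only; the exact op assembly syntax may differ
/// across MLIR versions), a constant such as
///
///   %m = gpu.subgroup_mma_constant_matrix %cst
///        : !gpu.mma_matrix<16x16xf16, "COp">
///
/// is materialized by splatting %cst into every (possibly vector-typed)
/// element of the result struct via llvm.insertvalue.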
struct WmmaConstantOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaConstantMatrixOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaConstantMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaConstantOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = subgroupMmaConstantOp.getLoc();
    Value cst = adaptor.getOperands()[0];
    LLVM::LLVMStructType type = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaConstantOp.getType()));
    // If the element type is a vector, create a vector from the operand.
    if (auto vecType = dyn_cast<VectorType>(type.getBody()[0])) {
      Value vecCst = rewriter.create<LLVM::UndefOp>(loc, vecType);
      for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) {
        Value idx = rewriter.create<LLVM::ConstantOp>(
            loc, rewriter.getI32Type(), vecEl);
        vecCst = rewriter.create<LLVM::InsertElementOp>(loc, vecType, vecCst,
                                                        cst, idx);
      }
      cst = vecCst;
    }
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, type);
    for (size_t i : llvm::seq(size_t(0), type.getBody().size())) {
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, cst, i);
    }
    rewriter.replaceOp(subgroupMmaConstantOp, matrixStruct);
    return success();
  }
};
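/// Creates an LLVM float min/max as an ordered compare plus select, then
/// guards the result with an unordered compare of the inputs so that a NaN in
/// either operand produces a quiet NaN instead of whatever the select chose.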
static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
                           Value rhs, bool isMin) {
  auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
  Type i1Type = builder.getI1Type();
  if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
    i1Type = VectorType::get(vecType.getShape(), i1Type);
  Value cmp = builder.create<LLVM::FCmpOp>(
      loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
      lhs, rhs);
  Value sel = builder.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
  Value isNan = builder.create<LLVM::FCmpOp>(
      loc, i1Type, LLVM::FCmpPredicate::uno, lhs, rhs);
  Value nan = builder.create<LLVM::ConstantOp>(
      loc, floatType,
      builder.getFloatAttr(floatType,
                           APFloat::getQNaN(floatType.getFloatSemantics())));
  return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
}
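/// Creates the LLVM arithmetic op corresponding to the given MMA elementwise
/// op kind, dispatching min/max to the NaN-propagating helper above.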
static Value createScalarOp(OpBuilder &builder, Location loc,
                            gpu::MMAElementwiseOp op,
                            ArrayRef<Value> operands) {
  switch (op) {
  case gpu::MMAElementwiseOp::ADDF:
    return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MULF:
    return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::DIVF:
    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
  case gpu::MMAElementwiseOp::MAXF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/false);
  case gpu::MMAElementwiseOp::MINF:
    return createMinMaxF(builder, loc, operands[0], operands[1],
                         /*isMin=*/true);
  }
  llvm_unreachable("unknown op");
}
/// Convert GPU MMA elementwise ops to extract + op + insert.
struct WmmaElementwiseOpToNVVMLowering
    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaElementwiseOp> {
  using ConvertOpToLLVMPattern<
      gpu::SubgroupMmaElementwiseOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
                  OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(areAllLLVMTypes(subgroupMmaElementwiseOp.getOperation(),
                               adaptor.getOperands(), rewriter)))
      return failure();

    Location loc = subgroupMmaElementwiseOp.getLoc();
    size_t numOperands = adaptor.getOperands().size();
    LLVM::LLVMStructType destType = convertMMAToLLVMType(
        cast<gpu::MMAMatrixType>(subgroupMmaElementwiseOp.getType()));
    Value matrixStruct = rewriter.create<LLVM::UndefOp>(loc, destType);
    for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) {
      SmallVector<Value> extractedOperands;
      for (size_t opIdx = 0; opIdx < numOperands; opIdx++) {
        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
            loc, adaptor.getOperands()[opIdx], i));
      }
      Value element =
          createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(),
                         extractedOperands);
      matrixStruct =
          rewriter.create<LLVM::InsertValueOp>(loc, matrixStruct, element, i);
    }
    rewriter.replaceOp(subgroupMmaElementwiseOp, matrixStruct);
    return success();
  }
};

} // namespace
/// Return the LLVMStructType corresponding to the MMAMatrixType `type`.
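///
/// For example (a sketch; the element type and count are whatever
/// NVVM::inferMMAType returns for the fragment), a 16x16 f16 "COp" matrix
/// maps to a literal struct of four vector<2xf16> values.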
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
  NVVM::MMAFrag frag = convertOperand(type.getOperand());
  NVVM::MMATypes eltType = getElementType(type);
  auto nRow = type.getShape()[0];
  auto nCol = type.getShape()[1];
  std::pair<Type, unsigned> typeInfo =
      NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext());
  return LLVM::LLVMStructType::getLiteral(
      type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
void mlir::populateGpuWMMAToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
               WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
               WmmaElementwiseOpToNVVMLowering>(converter);
}