//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "nvgpu-to-nvvm"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
#define DBGSE() (llvm::dbgs())

namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
/// Number of bits that needs to be excluded when building matrix descriptor
/// for wgmma operations.
constexpr int exclude4LSB = 4;

/// GPU has 32 bit registers, this function truncates values when larger width
/// is not needed.
static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
  Type type = value.getType();
  assert(llvm::isa<IntegerType>(type) && "expected an integer Value");
  if (type.getIntOrFloatBitWidth() <= 32)
    return value;
  return b.create<LLVM::TruncOp>(b.getI32Type(), value);
}
/// Returns the type for the intrinsic given the vectorResultType of the
/// `gpu.mma.sync` operation.
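/// For example, `!llvm.array<2 x vector<2xf16>>` maps to
/// `!llvm.struct<(vector<2xf16>, vector<2xf16>)>`, while
/// `!llvm.array<2 x vector<2xi32>>` maps to a struct of four i32 scalars.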
static Type inferIntrinsicResultType(Type vectorResultType) {
  MLIRContext *ctx = vectorResultType.getContext();
  auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
  auto i32Ty = IntegerType::get(ctx, 32);
  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64Ty = Float64Type::get(ctx);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32Ty = Float32Type::get(ctx);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  if (a.getElementType() == f16x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
  }
  if (a.getElementType() == i32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, i32Ty));
  }
  if (a.getElementType() == f64x2Ty) {
    return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty});
  }
  if (a.getElementType() == f32x2Ty) {
    return LLVM::LLVMStructType::getLiteral(
        ctx,
        SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
  }
  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
    return LLVM::LLVMStructType::getLiteral(
        ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
  }
  return vectorResultType;
}
/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is
/// always an LLVM struct) into a fragment that is compatible with the vector
/// type of this operation. This involves extracting elements from the struct
/// and inserting them into an LLVM array. These extra data-movement
/// operations should be canonicalized away by the LLVM backend.
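/// For example, a struct of four f32 scalars produced for a `vector<2x2xf32>`
/// fragment is repacked into `!llvm.array<2 x vector<2xf32>>`.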
static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
                                    Type resultType, Value intrinsicResult,
                                    RewriterBase &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  auto structType = dyn_cast<LLVM::LLVMStructType>(intrinsicResultType);
  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType);
  Type i32Ty = rewriter.getI32Type();
  Type f32Ty = rewriter.getF32Type();
  Type f64Ty = rewriter.getF64Type();
  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);

  auto makeConst = [&](int32_t index) -> Value {
    return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
                                             rewriter.getI32IntegerAttr(index));
  };

  if (arrayType) {
    SmallVector<Value, 4> elements;

    // The intrinsic returns 32-bit wide elements in a form which can be
    // directly bitcasted and inserted into the result vector.
    if (arrayType.getElementType() == f16x2Ty ||
        arrayType.getElementType() == f32x1Ty) {
      for (unsigned i = 0; i < structType.getBody().size(); i++) {
        Value el =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i);
        el = rewriter.createOrFold<LLVM::BitcastOp>(
            loc, arrayType.getElementType(), el);
        elements.push_back(el);
      }
    }

    // The intrinsic returns i32, f64, and f32 values as individual scalars,
    // even when the result is notionally a 64-bit wide element (e.g. f32x2).
    // We need to extract them from the struct and pack them into the 64-bit
    // wide rows of the vector result.
    if (arrayType.getElementType() == i32x2Ty ||
        arrayType.getElementType() == f64x2Ty ||
        arrayType.getElementType() == f32x2Ty) {

      for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) {
        Value vec =
            rewriter.create<LLVM::UndefOp>(loc, arrayType.getElementType());
        Value x1 =
            rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult, i * 2);
        Value x2 = rewriter.create<LLVM::ExtractValueOp>(loc, intrinsicResult,
                                                         i * 2 + 1);
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x1, makeConst(0));
        vec = rewriter.create<LLVM::InsertElementOp>(loc, vec.getType(), vec,
                                                     x2, makeConst(1));
        elements.push_back(vec);
      }
    }

    // Create the final vectorized result.
    Value result = rewriter.create<LLVM::UndefOp>(loc, arrayType);
    for (const auto &el : llvm::enumerate(elements)) {
      result = rewriter.create<LLVM::InsertValueOp>(loc, result, el.value(),
                                                    el.index());
    }
    return result;
  }

  return intrinsicResult;
}
/// The `gpu.mma.sync` converter below expects matrix fragment operands to be
/// given as 2D `vectors` where the rows are 32b or 64b wide. The
/// `nvvm.mma.sync` op expects these arguments to be given as a long list of
/// scalars of certain types. This function helps unpack the `vector` arguments
/// and cast them to the types expected by `nvvm.mma.sync`.
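/// For example, an f16 operand `!llvm.array<2 x vector<2xf16>>` is unpacked
/// into two `vector<2xf16>` values, while a tf32 operand given as
/// `!llvm.array<4 x vector<1xf32>>` is bitcast into four i32 scalars.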
static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
                                              Value operand,
                                              NVVM::MMATypes operandPtxType) {
  SmallVector<Value> result;
  Type i32Ty = b.getI32Type();
  Type f64Ty = b.getF64Type();
  Type f32Ty = b.getF32Type();
  Type i64Ty = b.getI64Type();
  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
  auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());

  for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
    Value toUse = b.create<LLVM::ExtractValueOp>(operand, i);

    // For 4xi8 vectors, the intrinsic expects these to be provided as i32
    // scalar types.
    if (arrayTy.getElementType() == i8x4Ty ||
        arrayTy.getElementType() == i4x8Ty ||
        (arrayTy.getElementType() == f32x1Ty &&
         operandPtxType == NVVM::MMATypes::tf32)) {
      result.push_back(b.create<LLVM::BitcastOp>(i32Ty, toUse));
      continue;
    }

    // For some element types (i32, f32, f64), we need to unpack the inner
    // vector/array type as well because the intrinsic expects individual
    // scalars to be provided.
    VectorType innerArrayTy = dyn_cast<VectorType>(arrayTy.getElementType());
    if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty ||
                         innerArrayTy.getElementType() == f64Ty ||
                         innerArrayTy.getElementType() == f32Ty)) {
      for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements();
           idx < innerSize; idx++) {
        result.push_back(b.create<LLVM::ExtractElementOp>(
            toUse,
            b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(idx))));
      }
      continue;
    }
    result.push_back(toUse);
  }
  return result;
}
/// Returns whether mbarrier object has shared memory address space.
static bool isMbarrierShared(nvgpu::MBarrierGroupType barrierType) {
  return (mlir::nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(
      barrierType.getMemorySpace()));
}

/// Returns the memory space attribute of the mbarrier object.
Attribute nvgpu::getMbarrierMemorySpace(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = {};
  if (isMbarrierShared(barrierType)) {
    memorySpace =
        IntegerAttr::get(IntegerType::get(context, 64),
                         nvgpu::NVGPUDialect::kSharedMemoryAddressSpace);
  }
  return memorySpace;
}
/// Returns memref type of the mbarrier object. The type is defined in the
/// MBarrierGroupType.
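/// For example, an mbarrier group with 4 barriers in shared memory lowers to
/// roughly `memref<4xi64, 3>` (one 64-bit mbarrier object per barrier, placed
/// in the shared memory address space).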
MemRefType nvgpu::getMBarrierMemrefType(MLIRContext *context,
                                        nvgpu::MBarrierGroupType barrierType) {
  Attribute memorySpace = nvgpu::getMbarrierMemorySpace(context, barrierType);
  MemRefLayoutAttrInterface layout;
  return MemRefType::get({barrierType.getNumBarriers()},
                         IntegerType::get(context, 64), layout, memorySpace);
}
namespace {

struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
  using ConvertOpToLLVMPattern<nvgpu::LdMatrixOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MLIRContext *ctx = getContext();
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The result type of ldmatrix will always be a struct of 32bit integer
    // registers if more than one 32bit value is returned. Otherwise, the
    // result is a single i32. The result type of the GPU operation is always
    // a vector of shape (NumRegisters, VectorRegister) where VectorRegister is
    // the vector type of the result and always 32 bits long. We bitcast the
    // result of the NVVM::LdMatrix to this vector type.
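    // For example, a `vector<4x2xf16>` result lowers to an NVVM ldmatrix
    // returning `!llvm.struct<(i32, i32, i32, i32)>`, and each i32 register
    // is bitcast back to `vector<2xf16>`.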
    auto vectorResultType = dyn_cast<VectorType>(op->getResultTypes()[0]);
    if (!vectorResultType) {
      return failure();
    }
    Type innerVectorType = LLVM::getFixedVectorType(
        vectorResultType.getElementType(), vectorResultType.getDimSize(1));

    int64_t num32BitRegs = vectorResultType.getDimSize(0);

    Type ldMatrixResultType;
    if (num32BitRegs > 1) {
      ldMatrixResultType = LLVM::LLVMStructType::getLiteral(
          ctx, SmallVector<Type>(num32BitRegs, rewriter.getI32Type()));
    } else {
      ldMatrixResultType = rewriter.getI32Type();
    }

    auto srcMemrefType = cast<MemRefType>(op.getSrcMemref().getType());
    Value srcPtr =
        getStridedElementPtr(b.getLoc(), srcMemrefType, adaptor.getSrcMemref(),
                             adaptor.getIndices(), rewriter);
    Value ldMatrixResult = b.create<NVVM::LdMatrixOp>(
        ldMatrixResultType, srcPtr,
        /*num=*/op.getNumTiles(),
        /*layout=*/op.getTranspose() ? NVVM::MMALayout::col
                                     : NVVM::MMALayout::row);

    // The ldmatrix operation returns either a single i32 value or a struct of
    // i32 values. Here we unpack those values and cast them back to their
    // actual vector type (still of width 32b) and repack them into a result
    // vector.
    Type finalResultType = typeConverter->convertType(vectorResultType);
    Value result = b.create<LLVM::UndefOp>(finalResultType);
    for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) {
      Value i32Register =
          num32BitRegs > 1 ? b.create<LLVM::ExtractValueOp>(ldMatrixResult, i)
                           : ldMatrixResult;
      Value casted = b.create<LLVM::BitcastOp>(innerVectorType, i32Register);
      result = b.create<LLVM::InsertValueOp>(result, casted, i);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
/// Convert the given type into the corresponding PTX type (NVVM::MMATypes
/// enum).
static FailureOr<NVVM::MMATypes> getNvvmMmaType(Type t) {
  Type elType = getElementTypeOrSelf(t);
  if (elType.isInteger(8))
    return NVVM::MMATypes::s8;
  if (elType.isInteger(4))
    return NVVM::MMATypes::s4;
  if (elType.isF16())
    return NVVM::MMATypes::f16;
  if (elType.isF64())
    return NVVM::MMATypes::f64;
  if (elType.isF32())
    return NVVM::MMATypes::tf32;
  return failure();
}
struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    std::array<int64_t, 3> gemmShape = op.getMmaShapeAsArray();

    // Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));
    Value intrinsicResult = b.create<NVVM::MmaOp>(
        intrinsicResTy, matA, matB, matC,
        /*shape=*/gemmShape,
        /*b1Op=*/std::nullopt,
        /*intOverflow=*/overflow,
        /*multiplicandPtxTypes=*/
        std::array<NVVM::MMATypes, 2>{*ptxTypeA, *ptxTypeB},
        /*multiplicandLayouts=*/
        std::array<NVVM::MMALayout, 2>{NVVM::MMALayout::row,
                                       NVVM::MMALayout::col});
    rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy,
                                                  desiredRetTy, intrinsicResult,
                                                  rewriter));
    return success();
  }
};
struct ConvertNVGPUToNVVMPass
    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
  using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect,
                    NVVM::NVVMDialect, arith::ArithDialect>();
  }

  void runOnOperation() override {
    LowerToLLVMOptions options(&getContext());
    RewritePatternSet patterns(&getContext());
    LLVMTypeConverter converter(&getContext(), options);
    IRRewriter rewriter(&getContext());
    populateGpuMemorySpaceAttributeConversions(
        converter, [](gpu::AddressSpace space) -> unsigned {
          switch (space) {
          case gpu::AddressSpace::Global:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kGlobalMemorySpace);
          case gpu::AddressSpace::Workgroup:
            return static_cast<unsigned>(
                NVVM::NVVMMemorySpace::kSharedMemorySpace);
          case gpu::AddressSpace::Private:
            return 0;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });
    /// device-side async tokens cannot be materialized in nvvm. We just
    /// convert them to a dummy i32 type in order to easily drop them during
    /// conversion.
    converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 32));
    });
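    // The warpgroup accumulator is fragmented into one inner struct per
    // kWgmmaSizeM (64) rows; e.g., a 128x128xf32 accumulator becomes a struct
    // of two inner structs, each holding sizeN / 2 = 64 f32 members.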
    converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type {
      Type elemType = type.getFragmented().getElementType();
      int64_t sizeM = type.getFragmented().getDimSize(0);
      int64_t sizeN = type.getFragmented().getDimSize(1);

      unsigned numMembers;
      if (elemType.isF32() || elemType.isInteger(32))
        numMembers = sizeN / 2;
      else if (elemType.isF16())
        numMembers = sizeN / 4;
      else
        llvm_unreachable("unsupported type for warpgroup accumulator");

      SmallVector<Type> innerStructBody;
      for (unsigned i = 0; i < numMembers; i++)
        innerStructBody.push_back(elemType);
      auto innerStructType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);

      SmallVector<Type> structBody;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        structBody.push_back(innerStructType);

      auto convertedType =
          LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
      return converter.convertType(convertedType);
    });
    converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
      return converter.convertType(IntegerType::get(type.getContext(), 64));
    });
    converter.addConversion(
        [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
          return converter.convertType(IntegerType::get(type.getContext(), 64));
        });
    converter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
      return converter.convertType(
          nvgpu::getMBarrierMemrefType(rewriter.getContext(), type));
    });
    converter.addConversion([&](nvgpu::TensorMapDescriptorType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
    populateNVGPUToNVVMConversionPatterns(converter, patterns);
    LLVMConversionTarget target(getContext());
    target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
    target.addLegalDialect<::mlir::arith::ArithDialect>();
    target.addLegalDialect<::mlir::memref::MemRefDialect>();
    target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
        converter, patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
/// Returns the constraints for the sparse MMA inline assembly instruction.
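/// For matASize = 4, matBSize = 2, matCSize = 2 (illustrative operand counts)
/// this produces "=r,=r,r,r,r,r,r,r,r,r,r".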
static std::string buildMmaSparseAsmConstraintString(unsigned matASize,
                                                     unsigned matBSize,
                                                     unsigned matCSize) {
  std::string str;
  llvm::raw_string_ostream ss(str);
  for (unsigned i = 0; i < matCSize; i++)
    ss << "=r,";
  for (unsigned i = 0; i < matASize + matBSize + matCSize; i++)
    ss << "r,";
  // The final operand is for the sparsity metadata.
  // The sparsity selector appears as direct literal.
  ss << "r";
  return str;
}
/// Returns the string for the `mma.sp.sync` instruction that corresponds to
/// the given parameters. Note that this function doesn't do any validation,
/// it's expected that the provided parameters correspond to a valid
/// instruction.
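/// For example, with shape 16x8x32, s8 operands, an s32 accumulator,
/// satfinite overflow, and illustrative operand counts matASize = 2,
/// matBSize = 2, matCSize = 4, the generated string is roughly:
///   mma.sp.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32
///   {$0,$1,$2,$3},{$4,$5},{$6,$7},{$8,$9,$10,$11},$12,0x0;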
static std::string buildMmaSparseAsmString(
    const std::array<int64_t, 3> &shape, unsigned matASize, unsigned matBSize,
    unsigned matCSize, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, unsigned metaDataSelector) {
  auto ptxTypeStr = [](NVVM::MMATypes ptxType) {
    return NVVM::stringifyMMATypes(ptxType);
  };

  std::string asmStr;
  llvm::raw_string_ostream ss(asmStr);
  ss << "mma.sp.sync.aligned.m" << shape[0] << "n" << shape[1] << "k"
     << shape[2] << ".row.col.";

  if (overflow)
    ss << NVVM::stringifyMMAIntOverflow(*overflow) << ".";

  ss << ptxTypeStr(ptxTypeD) << "." << ptxTypeStr(ptxTypeA) << "."
     << ptxTypeStr(ptxTypeB) << "." << ptxTypeStr(ptxTypeC) << " ";
  unsigned asmArgIdx = 0;

  // The operand string is structured into sections `{matC elements...},
  // {matA elements...}, {matB elements...}, {matC elements}`.
  for (const auto arrSize : {matCSize, matASize, matBSize, matCSize}) {
    ss << "{";
    for (unsigned i = 0; i < arrSize; i++)
      ss << "$" << asmArgIdx++ << (i < arrSize - 1 ? "," : "");
    ss << "},";
  }
  ss << "$" << asmArgIdx++ << ",";
  assert(metaDataSelector <= 1);
  ss << "0x" << metaDataSelector << ";";
  return asmStr;
}
/// Builds an inline assembly operation corresponding to the specified MMA
/// sparse sync operation.
static FailureOr<LLVM::InlineAsmOp> emitMmaSparseSyncOpAsm(
    ImplicitLocOpBuilder &b, NVVM::MMATypes ptxTypeA, NVVM::MMATypes ptxTypeB,
    NVVM::MMATypes ptxTypeC, NVVM::MMATypes ptxTypeD,
    std::optional<NVVM::MMAIntOverflow> overflow, ArrayRef<Value> unpackedAData,
    ArrayRef<Value> unpackedB, ArrayRef<Value> unpackedC, Value indexData,
    int64_t metadataSelector, const std::array<int64_t, 3> &shape,
    Type intrinsicResultType) {
  auto asmDialectAttr =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_ATT);

  const unsigned matASize = unpackedAData.size();
  const unsigned matBSize = unpackedB.size();
  const unsigned matCSize = unpackedC.size();

  std::string asmStr = buildMmaSparseAsmString(
      shape, matASize, matBSize, matCSize, ptxTypeA, ptxTypeB, ptxTypeC,
      ptxTypeD, overflow, metadataSelector);
  std::string constraintStr =
      buildMmaSparseAsmConstraintString(matASize, matBSize, matCSize);

  SmallVector<Value> asmVals;
  asmVals.reserve(matASize + matBSize + matCSize + 1);
  for (ArrayRef<Value> args : {unpackedAData, unpackedB, unpackedC})
    llvm::append_range(asmVals, args);
  asmVals.push_back(indexData);

  return b.create<LLVM::InlineAsmOp>(
      /*resultTypes=*/intrinsicResultType,
      /*operands=*/asmVals,
      /*asm_string=*/asmStr,
      /*constraints=*/constraintStr,
      /*has_side_effects=*/true,
      /*is_align_stack=*/false,
      /*asm_dialect=*/asmDialectAttr,
      /*operand_attrs=*/ArrayAttr());
}
/// Lowers `nvgpu.mma.sp.sync` to inline assembly.
struct NVGPUMmaSparseSyncLowering
    : public ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp> {
  using ConvertOpToLLVMPattern<nvgpu::MmaSparseSyncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::MmaSparseSyncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    // Get the shapes of the MMAMatrix type being used. The shapes will
    // choose which intrinsic this op will be lowered to.
    VectorType aType = op.getMatrixA().getType();
    VectorType bType = op.getMatrixB().getType();
    VectorType cType = op.getMatrixC().getType();

    FailureOr<NVVM::MMATypes> ptxTypeA = getNvvmMmaType(aType);
    if (failed(ptxTypeA))
      return op->emitOpError("failed to deduce operand PTX types");
    FailureOr<NVVM::MMATypes> ptxTypeB = getNvvmMmaType(bType);
    if (failed(ptxTypeB))
      return op->emitOpError("failed to deduce operand PTX types");
    std::optional<NVVM::MMATypes> ptxTypeC =
        NVVM::MmaOp::inferOperandMMAType(cType.getElementType(),
                                         /*isAccumulator=*/true);
    if (!ptxTypeC)
      return op->emitError(
          "could not infer the PTX type for the accumulator/result");

    // Same as `mma.sync`, F32 works only with TensorFloat32 (TF32).
    bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
    if (aType.getElementType().isF32() && !tf32Enabled)
      return failure();

    // TODO: add an attribute to the op to customize this behavior.
    std::optional<NVVM::MMAIntOverflow> overflow(std::nullopt);
    if (isa<IntegerType>(aType.getElementType()))
      overflow = NVVM::MMAIntOverflow::satfinite;

    SmallVector<Value> matA =
        unpackOperandVector(b, adaptor.getMatrixA(), *ptxTypeA);
    SmallVector<Value> matB =
        unpackOperandVector(b, adaptor.getMatrixB(), *ptxTypeB);
    SmallVector<Value> matC =
        unpackOperandVector(b, adaptor.getMatrixC(), *ptxTypeC);

    Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]);
    Type intrinsicResTy = inferIntrinsicResultType(
        typeConverter->convertType(op->getResultTypes()[0]));

    // Bitcast the sparse metadata from vector<2xi16> to an i32.
    Value sparseMetadata = adaptor.getSparseMetadata();
    if (sparseMetadata.getType() !=
        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
      return op->emitOpError() << "Expected metadata type to be LLVM "
                                  "VectorType of 2 i16 elements";
    sparseMetadata =
        b.create<LLVM::BitcastOp>(rewriter.getI32Type(), sparseMetadata);

    FailureOr<LLVM::InlineAsmOp> intrinsicResult = emitMmaSparseSyncOpAsm(
        b, *ptxTypeA, *ptxTypeB, *ptxTypeC, *ptxTypeC, overflow, matA, matB,
        matC, sparseMetadata, op.getSparsitySelector(), op.getMmaShapeAsArray(),
        intrinsicResTy);
    if (failed(intrinsicResult))
      return failure();

    assert((*intrinsicResult).getNumResults() == 1 &&
           "expected inline asm op returns a single LLVM struct type");
    rewriter.replaceOp(
        op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, desiredRetTy,
                                   (*intrinsicResult)->getResult(0), rewriter));
    return success();
  }
};
struct NVGPUAsyncCopyLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCopyOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    Location loc = op.getLoc();
    auto dstMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dstPtr =
        getStridedElementPtr(b.getLoc(), dstMemrefType, adaptor.getDst(),
                             adaptor.getDstIndices(), rewriter);
    FailureOr<unsigned> dstAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(dstMemrefType);
    if (failed(dstAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "destination memref address space not convertible to integer");

    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    FailureOr<unsigned> srcAddressSpace =
        getTypeConverter()->getMemRefAddressSpace(srcMemrefType);
    if (failed(srcAddressSpace))
      return rewriter.notifyMatchFailure(
          loc, "source memref address space not convertible to integer");

    Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                        adaptor.getSrcIndices(), rewriter);
    // Intrinsics takes a global pointer so we need an address space cast.
    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
        op->getContext(), NVVM::NVVMMemorySpace::kGlobalMemorySpace);
    scrPtr = b.create<LLVM::AddrSpaceCastOp>(srcPointerGlobalType, scrPtr);
    int64_t dstElements = adaptor.getDstElements().getZExtValue();
    int64_t sizeInBytes =
        (dstMemrefType.getElementTypeBitWidth() * dstElements) / 8;
    // When the optional SrcElements argument is *not* present, the regular
    // CpAsyncOp is generated. CopyAsyncOp reads bytes from source (global
    // memory) to fill DstElements number of elements in the destination
    // (shared memory).
    Value srcBytes = adaptor.getSrcElements();
    if (srcBytes) {
      // When the optional SrcElements argument is present, the source (global
      // memory) of CpAsyncOp is read only for SrcElements number of elements.
      // The rest of the DstElements in the destination (shared memory) are
      // filled with zeros.
      Value c3I32 =
          b.create<LLVM::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(3));
      Value bitwidth = b.create<LLVM::ConstantOp>(
          b.getI32Type(),
          b.getI32IntegerAttr(srcMemrefType.getElementTypeBitWidth()));
      Value srcElementsI32 = b.create<LLVM::TruncOp>(b.getI32Type(), srcBytes);
      srcBytes = b.create<LLVM::LShrOp>(
          b.create<LLVM::MulOp>(bitwidth, srcElementsI32), c3I32);
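      // I.e. srcBytes = (element bitwidth * srcElements) / 8, computed in
      // 32-bit arithmetic.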
    }
    // Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other
    // than 16 dst bytes.
    NVVM::LoadCacheModifierKind cacheModifier =
        (op.getBypassL1().value_or(false) && sizeInBytes == 16)
            ? NVVM::LoadCacheModifierKind::CG
            : NVVM::LoadCacheModifierKind::CA;

    b.create<NVVM::CpAsyncOp>(
        dstPtr, scrPtr, rewriter.getI32IntegerAttr(sizeInBytes),
        NVVM::LoadCacheModifierKindAttr::get(op->getContext(), cacheModifier),
        srcBytes);

    // Drop the result token.
    Value zero = b.create<LLVM::ConstantOp>(
        IntegerType::get(op.getContext(), 32), rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
struct NVGPUAsyncCreateGroupLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncCreateGroupOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
    // Drop the result token.
    Value zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), IntegerType::get(op.getContext(), 32),
        rewriter.getI32IntegerAttr(0));
    rewriter.replaceOp(op, zero);
    return success();
  }
};
struct NVGPUAsyncWaitLowering
    : public ConvertOpToLLVMPattern<nvgpu::DeviceAsyncWaitOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::DeviceAsyncWaitOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::DeviceAsyncWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // If numGroup is not present pick 0 as a conservative correct value.
    int32_t numGroups = adaptor.getNumGroups().value_or(0);
    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
    rewriter.eraseOp(op);
    return success();
  }
};
/// Creates mbarrier object in shared memory
struct NVGPUMBarrierCreateLowering
    : public ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp> {
  using ConvertOpToLLVMPattern<nvgpu::MBarrierCreateOp>::ConvertOpToLLVMPattern;

  template <typename moduleT>
  memref::GlobalOp generateGlobalBarrier(ConversionPatternRewriter &rewriter,
                                         Operation *funcOp, moduleT moduleOp,
                                         MemRefType barrierType) const {
    SymbolTable symbolTable(moduleOp);
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(&moduleOp.front());
    auto global = rewriter.create<memref::GlobalOp>(
        funcOp->getLoc(), "__mbarrier",
        /*sym_visibility=*/rewriter.getStringAttr("private"),
        /*type=*/barrierType,
        /*initial_value=*/ElementsAttr(),
        /*constant=*/false,
        /*alignment=*/rewriter.getI64IntegerAttr(8));
    symbolTable.insert(global);
    return global;
  }

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierCreateOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *funcOp = op->getParentOp();
    MemRefType barrierType = nvgpu::getMBarrierMemrefType(
        rewriter.getContext(), op.getBarriers().getType());

    memref::GlobalOp global;
    if (auto moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);
    else if (auto moduleOp = funcOp->getParentOfType<ModuleOp>())
      global = generateGlobalBarrier(rewriter, funcOp, moduleOp, barrierType);

    rewriter.setInsertionPoint(op);
    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, barrierType,
                                                     global.getName());
    return success();
  }
};
/// Base class for lowering mbarrier operations to nvvm intrinsics.
template <typename SourceOp>
struct MBarrierBasePattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  /// Returns the base pointer of the mbarrier object.
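  /// The mbarrier group lowers to a memref of i64 mbarrier objects (see
  /// getMBarrierMemrefType); `mbarId` selects the requested barrier within
  /// that memref.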
  Value getMbarrierPtr(ImplicitLocOpBuilder &b,
                       nvgpu::MBarrierGroupType mbarType, Value memrefDesc,
                       Value mbarId,
                       ConversionPatternRewriter &rewriter) const {
    MemRefType mbarrierMemrefType =
        nvgpu::getMBarrierMemrefType(rewriter.getContext(), mbarType);
    return ConvertToLLVMPattern::getStridedElementPtr(
        b.getLoc(), mbarrierMemrefType, memrefDesc, {mbarId}, rewriter);
  }
};
/// Lowers `nvgpu.mbarrier.init` to `nvvm.mbarrier.init`
struct NVGPUMBarrierInitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierInitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierInitOp>::MBarrierBasePattern;

  LogicalResult
  matchAndRewrite(nvgpu::MBarrierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    nvgpu::MBarrierGroupType mbarrierType = op.getBarriers().getType();
    rewriter.setInsertionPoint(op);
    Value barrier = getMbarrierPtr(b, mbarrierType, adaptor.getBarriers(),
                                   adaptor.getMbarId(), rewriter);
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(mbarrierType)) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitSharedOp>(
          op, barrier, count, adaptor.getPredicate());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierInitOp>(
          op, barrier, count, adaptor.getPredicate());
    }
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.arrive` to `nvvm.mbarrier.arrive`
struct NVGPUMBarrierArriveLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveOp> {
  using MBarrierBasePattern<nvgpu::MBarrierArriveOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveSharedOp>(op, tokenType,
                                                                barrier);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType,
                                                          barrier);
    }
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.arrive.nocomplete` to
/// `nvvm.mbarrier.arrive.nocomplete`
struct NVGPUMBarrierArriveNoCompleteLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveNoCompleteOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveNoCompleteOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveNoCompleteOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type tokenType = getTypeConverter()->convertType(
        nvgpu::MBarrierTokenType::get(op->getContext()));
    Value count = truncToI32(b, adaptor.getCount());
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteSharedOp>(
          op, tokenType, barrier, count);
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveNocompleteOp>(
          op, tokenType, barrier, count);
    }
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.test.wait` to `nvvm.mbarrier.test.wait`
struct NVGPUMBarrierTestWaitLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTestWaitOp> {
  using MBarrierBasePattern<nvgpu::MBarrierTestWaitOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTestWaitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Type retType = rewriter.getI1Type();
    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitSharedOp>(
          op, retType, barrier, adaptor.getToken());
    } else {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTestWaitOp>(
          op, retType, barrier, adaptor.getToken());
    }
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.arrive.expect_tx` to
/// `nvvm.mbarrier.arrive.expect_tx`
struct NVGPUMBarrierArriveExpectTxLowering
    : public MBarrierBasePattern<nvgpu::MBarrierArriveExpectTxOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierArriveExpectTxOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierArriveExpectTxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value txcount = truncToI32(b, adaptor.getTxcount());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
          op, barrier, txcount, adaptor.getPredicate());
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
        op, barrier, txcount, adaptor.getPredicate());
    return success();
  }
};
/// Lowers `nvgpu.mbarrier.try_wait.parity` to `nvvm.mbarrier.try_wait.parity`
struct NVGPUMBarrierTryWaitParityLowering
    : public MBarrierBasePattern<nvgpu::MBarrierTryWaitParityOp> {
  using MBarrierBasePattern<
      nvgpu::MBarrierTryWaitParityOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::MBarrierTryWaitParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);
    Value ticks = truncToI32(b, adaptor.getTicks());
    Value phase =
        b.create<LLVM::ZExtOp>(b.getI32Type(), adaptor.getPhaseParity());

    if (isMbarrierShared(op.getBarriers().getType())) {
      rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
          op, barrier, phase, ticks);
      return success();
    }

    rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
                                                               phase, ticks);
    return success();
  }
};
struct NVGPUTmaAsyncLoadOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncLoadOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncLoadOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncLoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getDst().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getDst(), {}, rewriter);
    Value barrier =
        getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                       adaptor.getMbarId(), rewriter);

    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }
    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
        op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
        ValueRange{}, adaptor.getMulticastMask(), Value{},
        adaptor.getPredicate());
    return success();
  }
};
struct NVGPUTmaAsyncStoreOpLowering
    : public MBarrierBasePattern<nvgpu::TmaAsyncStoreOp> {
  using MBarrierBasePattern<nvgpu::TmaAsyncStoreOp>::MBarrierBasePattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaAsyncStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto srcMemrefType = cast<MemRefType>(op.getSrc().getType());
    Value dest = getStridedElementPtr(op->getLoc(), srcMemrefType,
                                      adaptor.getSrc(), {}, rewriter);
    SmallVector<Value> coords = adaptor.getCoordinates();
    for (auto [index, value] : llvm::enumerate(coords)) {
      coords[index] = truncToI32(b, value);
    }

    rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(
        op, adaptor.getTensorMapDescriptor(), dest, coords,
        adaptor.getPredicate());
    return success();
  }
};
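/// Lowers `nvgpu.warpgroup.generate.descriptor` to a 64-bit wgmma matrix
/// descriptor. The bit layout assembled below is: [0,14) start address,
/// [16,30) leading dimension, [32,46) stride, [49,52) base offset, and
/// [62,64) swizzle type.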
struct NVGPUGenerateWarpgroupDescriptorLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupGenerateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupGenerateDescriptorOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupGenerateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    nvgpu::TensorMapSwizzleKind swizzleKind =
        op.getTensorMap().getType().getSwizzle();

    unsigned layout =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 128
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 64
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 32
                                                                    : 1;
    unsigned swizzle =
        (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_128B)  ? 1
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_64B) ? 2
        : (swizzleKind == nvgpu::TensorMapSwizzleKind::SWIZZLE_32B) ? 3
                                                                    : 0;

    auto ti64 = b.getIntegerType(64);
    auto makeConst = [&](uint64_t index) -> Value {
      return b.create<LLVM::ConstantOp>(ti64, b.getI64IntegerAttr(index));
    };
    auto shiftLeft = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::ShlOp>(ti64, value, makeConst(shift));
    };
    auto shiftRight = [&](Value value, unsigned shift) -> Value {
      return b.create<LLVM::LShrOp>(ti64, value, makeConst(shift));
    };
    auto insertBit = [&](Value desc, Value val, int startBit) {
      return b.create<LLVM::OrOp>(ti64, desc, shiftLeft(val, startBit));
    };

    int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0);
    uint64_t strideDimVal = (layout << 3) >> exclude4LSB;
    uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB;
    uint64_t offsetVal = 0;

    Value strideDim = makeConst(strideDimVal);
    Value leadDim = makeConst(leadDimVal);

    Value baseAddr = getStridedElementPtr(
        op->getLoc(), cast<MemRefType>(op.getTensor().getType()),
        adaptor.getTensor(), {}, rewriter);
    Value basePtr = b.create<LLVM::PtrToIntOp>(ti64, baseAddr);
    // Just use 14 bits for base address
    Value basePtr14bit = shiftRight(shiftLeft(basePtr, 46), 50);

    int startSwizzleBit = 62, startOffsetBit = 49, startStrideBit = 32,
        startLeadBit = 16, startBaseAddrBit = 0;
    Value dsc = makeConst(0);
    // [62,64) swizzle type
    dsc = insertBit(dsc, makeConst(swizzle), startSwizzleBit);
    // [49,52) base_offset
    dsc = insertBit(dsc, makeConst(offsetVal), startOffsetBit);
    // [32,46) stride
    dsc = insertBit(dsc, strideDim, startStrideBit);
    // [16,30) leading dimension
    dsc = insertBit(dsc, leadDim, startLeadBit);
    // [0,14) start_address
    dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);

    LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
                      << "leading_off:" << leadDimVal << "\t"
                      << "stride_off :" << strideDimVal << "\t"
                      << "base_offset:" << offsetVal << "\t"
                      << "layout_type:" << swizzle << " ("
                      << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
                      << ")\n start_addr : " << baseAddr << "\n");

    rewriter.replaceOp(op, dsc);
    return success();
  }
};
static Value makeI64Const(ImplicitLocOpBuilder &b, int32_t index) {
  return b.create<LLVM::ConstantOp>(b.getIntegerType(64),
                                    b.getI32IntegerAttr(index));
}
/// Returns a Value that holds data type enum that is expected by CUDA driver.
static Value elementTypeAsLLVMConstant(ImplicitLocOpBuilder &b, Type type) {
  // Enum is from CUDA driver API
  // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
  enum CUtensorMapDataTypeEnum {
    CU_TENSOR_MAP_DATA_TYPE_UINT8 = 0,
    CU_TENSOR_MAP_DATA_TYPE_UINT16,
    CU_TENSOR_MAP_DATA_TYPE_UINT32,
    CU_TENSOR_MAP_DATA_TYPE_INT32,
    CU_TENSOR_MAP_DATA_TYPE_UINT64,
    CU_TENSOR_MAP_DATA_TYPE_INT64,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT64,
    CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32,
    CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ
  };

  if (type.isUnsignedInteger(8))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT8);
  if (type.isUnsignedInteger(16))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT16);
  if (type.isUnsignedInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT32);
  if (type.isUnsignedInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_UINT64);
  if (type.isSignlessInteger(32))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT32);
  if (type.isSignlessInteger(64))
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_INT64);
  if (type.isF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
  if (type.isF32())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
  if (type.isF64())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_FLOAT64);
  if (type.isBF16())
    return makeI64Const(b, CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);

  llvm_unreachable("Not supported data type");
}
struct NVGPUTmaCreateDescriptorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaCreateDescriptorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::TmaCreateDescriptorOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaCreateDescriptorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto llvmPointerType = LLVM::LLVMPointerType::get(op->getContext());
    Type llvmInt64Type = IntegerType::get(op->getContext(), 64);

    Value tensorElementType = elementTypeAsLLVMConstant(
        b, op.getTensor().getType().getElementType());
    auto promotedOperands = getTypeConverter()->promoteOperands(
        b.getLoc(), op->getOperands(), adaptor.getOperands(), b);

    Value boxArrayPtr = b.create<LLVM::AllocaOp>(llvmPointerType, llvmInt64Type,
                                                 makeI64Const(b, 5));
    for (auto [index, value] : llvm::enumerate(adaptor.getBoxDimensions())) {
      Value gep = b.create<LLVM::GEPOp>(llvmPointerType, llvmPointerType,
                                        boxArrayPtr, makeI64Const(b, index));
      b.create<LLVM::StoreOp>(value, gep);
    }

    nvgpu::TensorMapDescriptorType desc = op.getTensorMap().getType();
    // Set Arguments for the function call
    SmallVector<Value> arguments;
    arguments.push_back(promotedOperands[0]); // rank
    arguments.push_back(promotedOperands[1]); // descriptor
    arguments.push_back(tensorElementType);   // data type
    arguments.push_back(
        makeI64Const(b, (int)desc.getInterleave()));              // interleave
    arguments.push_back(makeI64Const(b, (int)desc.getSwizzle())); // swizzle
    arguments.push_back(makeI64Const(b, (int)desc.getL2promo())); // l2promo
    arguments.push_back(makeI64Const(b, (int)desc.getOob()));     // oob
    arguments.push_back(boxArrayPtr); // box dimensions

    // Set data types of the arguments
    SmallVector<Type> argTypes = {
        llvmInt64Type,   /* int64_t tensorRank */
        llvmPointerType, /* ptr */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmInt64Type,   /* int64_t */
        llvmPointerType  /* ptr */
    };
    FunctionCallBuilder hostRegisterCallBuilder = {
        "mgpuTensorMapEncodeTiledMemref", llvmPointerType, argTypes};
    Value tensorMap =
        hostRegisterCallBuilder.create(b.getLoc(), b, arguments).getResult();

    rewriter.replaceOp(op, tensorMap);
    return success();
  }
};
struct NVGPUWarpgroupMmaOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp> {
  using ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaOp>::ConvertOpToLLVMPattern;

  /// This is a helper class to generate required NVVM Ops for warp-group level
  /// matrix multiplication.
  /// When the given GEMM shape is larger than the shape of a wgmma instruction
  /// in PTX, it can generate multiple NVVM::WgmmaMmaAsyncOp Op(s), group and
  /// execute them asynchronously. The class also handles waiting for
  /// completion and iterates through WarpgroupMatrixDescriptor to create
  /// descriptors for each instruction.
  ///
  /// For example this is the case when the shape of GEMM is 128x128x128:
  ///
  ///    nvvm.wgmma.fence.aligned
  ///
  ///    nvvm.wgmma.mma.async descA, descB
  ///    iterate(descA, descB)
  ///    nvvm.wgmma.mma.async descA, descB
  ///    (repeated for the remaining tiles)
  ///
  ///    nvvm.wgmma.group.sync.aligned
  ///    nvvm.wgmma.wait.group.sync [groupId]
  ///
  class WarpgroupGemm {
    nvgpu::WarpgroupMmaOp op;
    ImplicitLocOpBuilder b;
    OpAdaptor adaptor;

    // Entire shape of the given Op
    int64_t totalM, totalN, totalK;

    // Shape of one wgmma instruction
    int wgmmaM = 0, wgmmaN = 0, wgmmaK = 0;

    // Iteration counts for GEMM
    int iterationM = 0, iterationN = 0, iterationK = 0;

    /// The function returns the shape of wgmma instruction that is defined in
    /// PTX programming guide.
    /// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape
    void findWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType) {
      wgmmaM = 64;
      wgmmaN = sizeN;
      if (inputElemType.isTF32()) {
        wgmmaK = 8;
      } else if (inputElemType.isF16() || inputElemType.isBF16()) {
        wgmmaK = 16;
      } else if (inputElemType.isFloat8E4M3FN() ||
                 inputElemType.isFloat8E5M2() || inputElemType.isInteger(16)) {
        wgmmaK = 32;
      } else if (inputElemType.isInteger(1)) {
        wgmmaK = 256;
      } else {
        llvm_unreachable("msg: not supported K shape");
      }
      LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
                        << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
    }

    /// Generates WGMMATypesAttr from MLIR Type
    NVVM::WGMMATypesAttr generateWgmmaType(Type type,
                                           bool useF32 = false) const {
      auto getWgmmaType = [=](Type elemType) {
        if (elemType.isF32() || elemType.isTF32())
          return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
        if (elemType.isF16())
          return NVVM::WGMMATypes::f16;
        if (elemType.isBF16())
          return NVVM::WGMMATypes::bf16;
        if (elemType.isFloat8E4M3FN())
          return NVVM::WGMMATypes::e4m3;
        if (elemType.isFloat8E5M2())
          return NVVM::WGMMATypes::e5m2;
        if (elemType.isInteger(1))
          return NVVM::WGMMATypes::b1;
        if (elemType.isInteger(8))
          return NVVM::WGMMATypes::s8;
        if (elemType.isUnsignedInteger(8))
          return NVVM::WGMMATypes::u8;
        if (elemType.isInteger(32))
          return NVVM::WGMMATypes::s32;
        llvm_unreachable("unsupported type");
      };
      return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
    }

    /// Generates layout attribute for the input matrix for wgmma instruction
    NVVM::MMALayoutAttr
    generateWgmmaLayout(std::optional<bool> transpose) const {
      if (transpose.value_or(false))
        return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::col);
      return NVVM::MMALayoutAttr::get(op->getContext(), NVVM::MMALayout::row);
    }

    /// Generates shape attribute for wgmma instruction
    NVVM::MMAShapeAttr generateWgmmaShape() const {
      return NVVM::MMAShapeAttr::get(op->getContext(), wgmmaM, wgmmaN, wgmmaK);
    }

    /// Generates scale attributes of output matrix for wgmma instruction
    NVVM::WGMMAScaleOutAttr generateScaleOut() const {
      return NVVM::WGMMAScaleOutAttr::get(op->getContext(),
                                          NVVM::WGMMAScaleOut::one);
    }

    /// Generates scale attributes of input matrix for wgmma instruction
    NVVM::WGMMAScaleInAttr generateScaleIn() const {
      return NVVM::WGMMAScaleInAttr::get(op->getContext(),
                                         NVVM::WGMMAScaleIn::one);
    }

    /// Basic function to generate Add
    Value makeAdd(Value lhs, Value rhs) {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    }

    /// Moves the descriptor pointer of matrix-A for the next wgmma instruction.
    /// Currently, it only handles row-major.
    ///
    /// It moves the pointer like below for [128][64] size:
    ///
    /// descA     ---> +--+--+--+--+
    ///                |  |  |  |  |
    ///                |  |  |  |  |
    /// descA+512 ---> +-----------+
    ///                |  |  |  |  |
    ///                +-----------+
    ///
    Value iterateDescriptorA(Value desc, int i, int j, int k) {
      MemRefType matrixTypeA = op.getDescriptorA().getType().getTensor();
      Type elemA = matrixTypeA.getElementType();
      int byte = elemA.getIntOrFloatBitWidth() / 8;
      int tileShapeA = matrixTypeA.getDimSize(1);
      int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
                        << "] [wgmma descriptors] Descriptor A + "
                        << incrementVal << " | \t ");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// Moves the descriptor pointer of matrix-B for the next wgmma instruction.
    /// Currently, it only handles column-major.
    ///
    /// It moves the pointer like below for [128][64] size:
    /// descB ---> +--+--+--+--+--+--+--+--+
    ///            |↓ |  |  |  |  |  |  |  |
    ///            |↓ |  |  |  |  |  |  |  |
    ///            |↓ |  |  |  |  |  |  |  |
    ///            |↓ |  |  |  |  |  |  |  |
    ///            +--+--+--+--+--+--+--+--+
    ///
    Value iterateDescriptorB(Value desc, int i, int j, int k) {
      MemRefType matrixTypeB = op.getDescriptorB().getType().getTensor();
      Type elemB = matrixTypeB.getElementType();
      int byte = elemB.getIntOrFloatBitWidth() / 8;
      int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
      incrementVal = incrementVal >> exclude4LSB;
      LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
      if (!incrementVal)
        return desc;
      return makeAdd(desc, makeI64Const(b, incrementVal));
    }

    /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
    /// descriptors and arranges them based on induction variables: i, j, and k.
    Value generateWgmma(int i, int j, int k, Value matrixC) {
      LLVM_DEBUG(DBGS() << "\t wgmma."
                        << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
                        << "(A[" << (iterationM * wgmmaM) << ":"
                        << (iterationM * wgmmaM) + wgmmaM << "]["
                        << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "] * "
                        << " B[" << (iterationK * wgmmaK) << ":"
                        << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
                        << wgmmaN << "])\n");

      Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
      Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);

      Type elemA = op.getDescriptorA().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeA = generateWgmmaType(elemA);

      Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
      NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);

      Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
      NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);

      NVVM::MMAShapeAttr shape = generateWgmmaShape();
      NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
      NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
      NVVM::MMALayoutAttr layoutA = generateWgmmaLayout(op.getTransposeA());
      NVVM::MMALayoutAttr layoutB = generateWgmmaLayout(!op.getTransposeB());

      auto overflow = NVVM::MMAIntOverflowAttr::get(
          op->getContext(), NVVM::MMAIntOverflow::wrapped);

      return b.create<NVVM::WgmmaMmaAsyncOp>(
          matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
          itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
          overflow);
    }

    /// Generates multiple wgmma instructions to complete the given GEMM shape
    Value generateWgmmaGroup() {
      Value wgmmaResult =
          b.create<LLVM::UndefOp>(adaptor.getMatrixC().getType());

      // Perform GEMM
      SmallVector<Value> wgmmaResults;
      for (int i = 0; i < iterationM; ++i) {
        Value matrixC = b.create<LLVM::ExtractValueOp>(adaptor.getMatrixC(), i);
        for (int j = 0; j < iterationN; ++j)
          for (int k = 0; k < iterationK; ++k)
            matrixC = generateWgmma(i, j, k, matrixC);
        wgmmaResults.push_back(matrixC);
      }
      for (auto [idx, matrix] : llvm::enumerate(wgmmaResults)) {
        wgmmaResult = b.create<LLVM::InsertValueOp>(wgmmaResult.getType(),
                                                    wgmmaResult, matrix, idx);
      }
      return wgmmaResult;
    }

  public:
    WarpgroupGemm(nvgpu::WarpgroupMmaOp op, ImplicitLocOpBuilder &b,
                  OpAdaptor adaptor)
        : op(op), b(b), adaptor(adaptor) {
      // Find the entire GEMM Shape
      totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
      totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
      totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
      LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
                        << "] += A[" << totalM << "][" << totalK << "] * B["
                        << totalK << "][" << totalN << "] ---===\n");

      // Find the shape for one wgmma instruction
      findWgmmaShape(
          totalM, totalN,
          op.getDescriptorA().getType().getTensor().getElementType());

      // Iterations counts to complete the given shape with wgmma shape
      iterationM = totalM / wgmmaM;
      iterationN = totalN / wgmmaN;
      iterationK = totalK / wgmmaK;
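      // For example, a 128x128x64 f16 GEMM uses the 64x128x16 wgmma shape,
      // so iterationM = 2, iterationN = 1, and iterationK = 4.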

    /// Generates WgmmaMmaAsync Ops to complete the specified GEMM shape. It
    /// emits a fence (WgmmaFenceAlignedOp) before the instructions, commits
    /// them as a group (WgmmaGroupSyncAlignedOp), and waits for the group to
    /// complete (WgmmaWaitGroupSyncOp) after the instructions.
    Value generateWarpgroupMma() {
      b.create<NVVM::WgmmaFenceAlignedOp>();
      Value wgmmaResult = generateWgmmaGroup();
      b.create<NVVM::WgmmaGroupSyncAlignedOp>();
      b.create<NVVM::WgmmaWaitGroupSyncOp>(op.getWaitGroup());
      return wgmmaResult;
    }
  };
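
  /// Lowers nvgpu.warpgroup.mma. The resulting NVVM sequence is roughly
  /// (illustrative):
  ///   nvvm.wgmma.fence.aligned
  ///   ... one nvvm.wgmma.mma_async per tile ...
  ///   nvvm.wgmma.commit.group.sync.aligned
  ///   nvvm.wgmma.wait.group.sync.aligned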
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);

    // Step 1. Build the helper class that plans the wgmma tiling.
    WarpgroupGemm warpgroupGemm(op, b, adaptor);

    // Step 2. Generate the NVVM ops covering the entire GEMM shape.
    Value wgmmaResult = warpgroupGemm.generateWarpgroupMma();

    // Step 3. Replace the op with the fragmented result struct.
    rewriter.replaceOp(op, wgmmaResult);
    return success();
  }
};

struct NVGPUWarpgroupMmaStoreOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;

  /// This function stores a fragmented register matrix owned by a warp group
  /// (128 threads) into a memref. Each thread holds 64 registers, each 32 bits
  /// wide.
  /// Below is what each thread (T) holds; each `d` is a struct value with a
  /// number.
  ///
  /// Threads in the warp-group (128 threads) and what they own in MatrixD:
  /// 0-31   Warp-0  -> MatrixD[0:15 ][0:N]
  /// 32-63  Warp-1  -> MatrixD[16:31][0:N]
  /// 64-95  Warp-2  -> MatrixD[32:47][0:N]
  /// 96-127 Warp-3  -> MatrixD[48:63][0:N]
  ///
  /// Matrix-D:
  ///  +______________________________________________________________________+
  ///  |    0-1   |   2-3   |   4-5   |   6-7   |  8-9   |  10-11 |..|N-8,N-7  |
  /// 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
  /// 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
  /// 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
  /// 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  /// 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
  /// ..| .........|.........|.........|.........|........|...........|........|
  ///  +______________________________________________________________________+
  ///
  /// \param b: The ImplicitLocOpBuilder used to create the store ops.
  /// \param matrixD: Result of the warp-group MMA operation (fragmented
  /// matrix). It is held by each thread as a struct with 64 elements.
  /// \param dstMemref: The memref where the registers will be stored.
  /// \param offset: The row offset within the memref where the registers will
  /// be stored.
  void storeFragmentedMatrix(ImplicitLocOpBuilder &b, Value matrixD,
                             TypedValue<MemRefType> dstMemref,
                             int offset) const {
    Type i32 = b.getI32Type();

    auto makeConst = [&](int32_t index) -> Value {
      return b.create<LLVM::ConstantOp>(i32, b.getI32IntegerAttr(index));
    };
    Value c1 = makeConst(1);
    Value c2 = makeConst(2);
    Value c4 = makeConst(4);
    Value c8 = makeConst(8);
    Value c16 = makeConst(16);
    Value warpSize = makeConst(kWarpSize);

    auto makeMul = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::MulOp>(lhs.getType(), lhs, rhs);
    };
    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
      return b.create<LLVM::AddOp>(lhs.getType(), lhs, rhs);
    };

    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
                                   TypedValue<::mlir::MemRefType> memref) {
      Type it = b.getIndexType();
      Value idx = b.create<arith::IndexCastOp>(it, x);
      Value idy0 = b.create<arith::IndexCastOp>(it, y);
      Value idy1 = b.create<arith::IndexCastOp>(it, makeAdd(y, c1));
      Value d0 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i);
      Value d1 = b.create<LLVM::ExtractValueOp>(wgmmaResult, i + 1);
      b.create<memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
      b.create<memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
    };
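
    // Decompose the thread id (see the layout above): each warp owns a 16-row
    // slab of MatrixD; within a warp, lanes are grouped in fours per row
    // (laneId / 4 selects the row) and each lane in a group starts at column
    // (laneId % 4) * 2.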
    Value tidx = b.create<NVVM::ThreadIdXOp>(i32);
    Value laneId = b.create<LLVM::URemOp>(i32, tidx, warpSize);
    Value warpId = b.create<LLVM::UDivOp>(i32, tidx, warpSize);
    Value lane4Id = b.create<LLVM::UDivOp>(i32, laneId, c4);
    Value lane4modId = b.create<LLVM::URemOp>(i32, laneId, c4);

    Value tj = makeMul(lane4modId, c2);
    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
    if (offset)
      ti = makeAdd(ti, makeConst(offset));

    auto structType = cast<LLVM::LLVMStructType>(matrixD.getType());

    // Number of 32-bit registers owned per thread
    constexpr unsigned numAdjacentRegisters = 2;
    // Number of 8x8 matrices stacked one below another per warp
    constexpr unsigned numStackedMatrices = 2;

    size_t storeCount = (structType.getBody().size() /
                         (numStackedMatrices * numAdjacentRegisters));
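    // For the 64-element fragment described above this yields storeCount = 16:
    // 2 stacked 8x8 matrices x 16 column groups x 2 adjacent registers = 64
    // values stored by each thread.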

    for (size_t i = 0; i < numStackedMatrices; ++i) {
      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
      for (size_t j = 0; j < storeCount; ++j) {
        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
        size_t structIndex = (i * numAdjacentRegisters) +
                             (j * (numStackedMatrices * numAdjacentRegisters));
        makeExtractAndStore(structIndex, matrixD, idx, idy, dstMemref);
      }
    }
  }
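
  /// Lowers nvgpu.warpgroup.mma.store by storing every inner fragment struct
  /// with storeFragmentedMatrix, advancing the row offset by the number of
  /// elements in each fragment.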
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int offset = 0;
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    Value matriDValue = adaptor.getMatrixD();
    auto stype = cast<LLVM::LLVMStructType>(matriDValue.getType());
    for (auto [idx, matrixD] : llvm::enumerate(stype.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(matrixD);
      Value innerStructValue = b.create<LLVM::ExtractValueOp>(matriDValue, idx);
      storeFragmentedMatrix(b, innerStructValue, op.getDstMemref(), offset);
      offset += structType.getBody().size();
    }
    rewriter.eraseOp(op);
    return success();
  }
};

struct NVGPUWarpgroupMmaInitAccumulatorOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaInitAccumulatorOp> {
  using ConvertOpToLLVMPattern<
      nvgpu::WarpgroupMmaInitAccumulatorOp>::ConvertOpToLLVMPattern;
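
  /// Lowers nvgpu.warpgroup.mma.init.accumulator by materializing the
  /// converted accumulator type (an LLVM struct of structs) and
  /// zero-initializing every scalar element.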
  LogicalResult
  matchAndRewrite(nvgpu::WarpgroupMmaInitAccumulatorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    LLVM::LLVMStructType packStructType = cast<LLVM::LLVMStructType>(
        getTypeConverter()->convertType(op.getMatrixC().getType()));
    Type elemType = cast<LLVM::LLVMStructType>(packStructType.getBody().front())
                        .getBody()
                        .front();
    Value zero = b.create<LLVM::ConstantOp>(elemType, b.getZeroAttr(elemType));
    Value packStruct = b.create<LLVM::UndefOp>(packStructType);
    SmallVector<Value> innerStructs;
    // Unpack the structs and set all values to zero
    for (auto [idx, s] : llvm::enumerate(packStructType.getBody())) {
      auto structType = cast<LLVM::LLVMStructType>(s);
      Value structValue = b.create<LLVM::ExtractValueOp>(packStruct, idx);
      for (unsigned i = 0; i < structType.getBody().size(); ++i) {
        structValue = b.create<LLVM::InsertValueOp>(
            structType, structValue, zero, ArrayRef<int64_t>({i}));
      }
      innerStructs.push_back(structValue);
    }
    // Pack the inner structs into a single struct
    for (auto [idx, matrix] : llvm::enumerate(innerStructs)) {
      packStruct = b.create<LLVM::InsertValueOp>(packStruct.getType(),
                                                 packStruct, matrix, idx);
    }
    rewriter.replaceOp(op, packStruct);
    return success();
  }
};

struct NVGPUTmaPrefetchOpLowering
    : public ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp> {
  using ConvertOpToLLVMPattern<nvgpu::TmaPrefetchOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
        op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
    return success();
  }
};

struct NVGPURcpOpLowering : public ConvertOpToLLVMPattern<nvgpu::RcpOp> {
  using ConvertOpToLLVMPattern<nvgpu::RcpOp>::ConvertOpToLLVMPattern;
  LogicalResult
  matchAndRewrite(nvgpu::RcpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
    auto i64Ty = b.getI64Type();
    auto f32Ty = b.getF32Type();
    VectorType inTy = op.getIn().getType();
    // Apply rcp.approx.ftz.f32 to each element of the vector.
    auto convert1DVec = [&](Type llvm1DVectorTy, Value inVec) {
      Value ret1DVec = b.create<LLVM::UndefOp>(llvm1DVectorTy);
      int numElems = llvm::cast<VectorType>(llvm1DVectorTy).getNumElements();
      for (int i = 0; i < numElems; i++) {
        Value idx = b.create<LLVM::ConstantOp>(i64Ty, b.getI64IntegerAttr(i));
        Value elem = b.create<LLVM::ExtractElementOp>(inVec, idx);
        Value dst = b.create<NVVM::RcpApproxFtzF32Op>(f32Ty, elem);
        ret1DVec = b.create<LLVM::InsertElementOp>(ret1DVec, dst, idx);
      }
      return ret1DVec;
    };
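
    // Rank-1 vectors are converted directly; higher-rank vectors are unrolled
    // into 1-D slices by handleMultidimensionalVectors below.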
    if (inTy.getRank() == 1) {
      rewriter.replaceOp(op, convert1DVec(inTy, adaptor.getIn()));
      return success();
    }
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()),
        [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
          OpAdaptor adaptor(operands);
          return convert1DVec(llvm1DVectorTy, adaptor.getIn());
        },
        rewriter);
  }
};
} // namespace
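
/// Populates `patterns` with all the NVGPU-to-NVVM lowerings defined above.
/// A typical conversion pass uses it roughly as follows (sketch, assuming the
/// standard ConversionTarget / applyPartialConversion flow):
///   LLVMTypeConverter converter(ctx);
///   RewritePatternSet patterns(ctx);
///   populateNVGPUToNVVMConversionPatterns(converter, patterns);
///   LLVMConversionTarget target(*ctx);
///   if (failed(applyPartialConversion(module, target, std::move(patterns))))
///     signalPassFailure();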
void mlir::populateNVGPUToNVVMConversionPatterns(
    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<
      NVGPUMBarrierCreateLowering,            // nvgpu.mbarrier.create
      NVGPUMBarrierInitLowering,              // nvgpu.mbarrier.init
      NVGPUMBarrierArriveLowering,            // nvgpu.mbarrier.arrive
      NVGPUMBarrierArriveNoCompleteLowering,  // nvgpu.mbarrier.arrive.no_complete
      NVGPUMBarrierTestWaitLowering,          // nvgpu.mbarrier.test_wait_parity
      NVGPUMBarrierTryWaitParityLowering,     // nvgpu.mbarrier.try_wait_parity
      NVGPUTmaAsyncLoadOpLowering,            // nvgpu.tma.async.load
      NVGPUTmaAsyncStoreOpLowering,           // nvgpu.tma.async.store
      NVGPUTmaCreateDescriptorOpLowering,     // nvgpu.tma.create.descriptor
      NVGPUTmaPrefetchOpLowering,             // nvgpu.tma.prefetch.descriptor
      NVGPUMBarrierArriveExpectTxLowering,    // nvgpu.mbarrier.arrive.expect_tx
      NVGPUGenerateWarpgroupDescriptorLowering, // nvgpu.warpgroup.generate.descriptor
      NVGPUWarpgroupMmaOpLowering,            // nvgpu.warpgroup.mma
      NVGPUWarpgroupMmaStoreOpLowering,       // nvgpu.warpgroup.mma.store
      NVGPUWarpgroupMmaInitAccumulatorOpLowering, // nvgpu.warpgroup.mma.init.accumulator
      MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
      NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
      NVGPUMmaSparseSyncLowering, NVGPURcpOpLowering>(converter);
}