//===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR NVVM dialect and
// LLVM IR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
28 static llvm::Intrinsic::ID
getReduxIntrinsicId(llvm::Type
*resultType
,
29 NVVM::ReduxKind kind
) {
30 if (!resultType
->isIntegerTy(32))
31 llvm_unreachable("unsupported data type for redux");
34 case NVVM::ReduxKind::ADD
:
35 return llvm::Intrinsic::nvvm_redux_sync_add
;
36 case NVVM::ReduxKind::UMAX
:
37 return llvm::Intrinsic::nvvm_redux_sync_umax
;
38 case NVVM::ReduxKind::UMIN
:
39 return llvm::Intrinsic::nvvm_redux_sync_umin
;
40 case NVVM::ReduxKind::AND
:
41 return llvm::Intrinsic::nvvm_redux_sync_and
;
42 case NVVM::ReduxKind::OR
:
43 return llvm::Intrinsic::nvvm_redux_sync_or
;
44 case NVVM::ReduxKind::XOR
:
45 return llvm::Intrinsic::nvvm_redux_sync_xor
;
46 case NVVM::ReduxKind::MAX
:
47 return llvm::Intrinsic::nvvm_redux_sync_max
;
48 case NVVM::ReduxKind::MIN
:
49 return llvm::Intrinsic::nvvm_redux_sync_min
;
51 llvm_unreachable("unknown redux kind");
54 static llvm::Intrinsic::ID
getShflIntrinsicId(llvm::Type
*resultType
,
59 resultType
= cast
<llvm::StructType
>(resultType
)->getElementType(0);
61 case NVVM::ShflKind::bfly
:
62 return resultType
->isFloatTy()
63 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
64 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p
;
65 case NVVM::ShflKind::up
:
66 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
67 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p
;
68 case NVVM::ShflKind::down
:
69 return resultType
->isFloatTy()
70 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
71 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p
;
72 case NVVM::ShflKind::idx
:
73 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
74 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p
;
78 case NVVM::ShflKind::bfly
:
79 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
80 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32
;
81 case NVVM::ShflKind::up
:
82 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
83 : llvm::Intrinsic::nvvm_shfl_sync_up_i32
;
84 case NVVM::ShflKind::down
:
85 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
86 : llvm::Intrinsic::nvvm_shfl_sync_down_i32
;
87 case NVVM::ShflKind::idx
:
88 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
89 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32
;
92 llvm_unreachable("unknown shuffle kind");
95 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
96 static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout
,
98 if (layout
== NVVM::MMALayout::row
) {
101 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
;
103 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
;
105 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
;
107 llvm_unreachable("unsupported number of matrix");
113 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16
;
115 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16
;
117 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16
;
119 llvm_unreachable("unsupported number of matrix");
125 /// Implementation of the dialect interface that converts operations belonging
126 /// to the NVVM dialect to LLVM IR.
127 class NVVMDialectLLVMIRTranslationInterface
128 : public LLVMTranslationDialectInterface
{
130 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface
;
132 /// Translates the given operation to LLVM IR using the provided IR builder
133 /// and saving the state in `moduleTranslation`.
135 convertOperation(Operation
*op
, llvm::IRBuilderBase
&builder
,
136 LLVM::ModuleTranslation
&moduleTranslation
) const final
{
137 Operation
&opInst
= *op
;
138 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
143 /// Attaches module-level metadata for functions marked as kernels.
145 amendOperation(Operation
*op
, ArrayRef
<llvm::Instruction
*> instructions
,
146 NamedAttribute attribute
,
147 LLVM::ModuleTranslation
&moduleTranslation
) const final
{
148 auto func
= dyn_cast
<LLVM::LLVMFuncOp
>(op
);
151 llvm::LLVMContext
&llvmContext
= moduleTranslation
.getLLVMContext();
152 llvm::Function
*llvmFunc
= moduleTranslation
.lookupFunction(func
.getName());
154 auto generateMetadata
= [&](int dim
, StringRef name
) {
155 llvm::Metadata
*llvmMetadata
[] = {
156 llvm::ValueAsMetadata::get(llvmFunc
),
157 llvm::MDString::get(llvmContext
, name
),
158 llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
159 llvm::Type::getInt32Ty(llvmContext
), dim
))};
160 llvm::MDNode
*llvmMetadataNode
=
161 llvm::MDNode::get(llvmContext
, llvmMetadata
);
162 moduleTranslation
.getOrInsertNamedModuleMetadata("nvvm.annotations")
163 ->addOperand(llvmMetadataNode
);
165 if (attribute
.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
166 if (!dyn_cast
<DenseI32ArrayAttr
>(attribute
.getValue()))
168 auto values
= cast
<DenseI32ArrayAttr
>(attribute
.getValue());
169 generateMetadata(values
[0], NVVM::NVVMDialect::getMaxntidXName());
170 if (values
.size() > 1)
171 generateMetadata(values
[1], NVVM::NVVMDialect::getMaxntidYName());
172 if (values
.size() > 2)
173 generateMetadata(values
[2], NVVM::NVVMDialect::getMaxntidZName());
174 } else if (attribute
.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
175 if (!dyn_cast
<DenseI32ArrayAttr
>(attribute
.getValue()))
177 auto values
= cast
<DenseI32ArrayAttr
>(attribute
.getValue());
178 generateMetadata(values
[0], NVVM::NVVMDialect::getReqntidXName());
179 if (values
.size() > 1)
180 generateMetadata(values
[1], NVVM::NVVMDialect::getReqntidYName());
181 if (values
.size() > 2)
182 generateMetadata(values
[2], NVVM::NVVMDialect::getReqntidZName());
183 } else if (attribute
.getName() ==
184 NVVM::NVVMDialect::getMinctasmAttrName()) {
185 auto value
= dyn_cast
<IntegerAttr
>(attribute
.getValue());
186 generateMetadata(value
.getInt(), "minctasm");
187 } else if (attribute
.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
188 auto value
= dyn_cast
<IntegerAttr
>(attribute
.getValue());
189 generateMetadata(value
.getInt(), "maxnreg");
190 } else if (attribute
.getName() ==
191 NVVM::NVVMDialect::getKernelFuncAttrName()) {
192 llvm::Metadata
*llvmMetadataKernel
[] = {
193 llvm::ValueAsMetadata::get(llvmFunc
),
194 llvm::MDString::get(llvmContext
, "kernel"),
195 llvm::ValueAsMetadata::get(
196 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext
), 1))};
197 llvm::MDNode
*llvmMetadataNode
=
198 llvm::MDNode::get(llvmContext
, llvmMetadataKernel
);
199 moduleTranslation
.getOrInsertNamedModuleMetadata("nvvm.annotations")
200 ->addOperand(llvmMetadataNode
);
206 convertParameterAttr(LLVMFuncOp funcOp
, int argIdx
, NamedAttribute attribute
,
207 LLVM::ModuleTranslation
&moduleTranslation
) const final
{
209 llvm::LLVMContext
&llvmContext
= moduleTranslation
.getLLVMContext();
210 llvm::Function
*llvmFunc
=
211 moduleTranslation
.lookupFunction(funcOp
.getName());
212 llvm::NamedMDNode
*nvvmAnnotations
=
213 moduleTranslation
.getOrInsertNamedModuleMetadata("nvvm.annotations");
215 if (attribute
.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
216 llvm::MDNode
*gridConstantMetaData
= nullptr;
218 // Check if a 'grid_constant' metadata node exists for the given function
219 for (llvm::MDNode
*opnd
: llvm::reverse(nvvmAnnotations
->operands())) {
220 if (opnd
->getNumOperands() == 3 &&
221 opnd
->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc
) &&
222 opnd
->getOperand(1) ==
223 llvm::MDString::get(llvmContext
, "grid_constant")) {
224 gridConstantMetaData
= opnd
;
229 // 'grid_constant' is a function-level meta data node with a list of
230 // integers, where each integer n denotes that the nth parameter has the
231 // grid_constant annotation (numbering from 1). This requires aggregating
232 // the indices of the individual parameters that have this attribute.
233 llvm::Type
*i32
= llvm::IntegerType::get(llvmContext
, 32);
234 if (gridConstantMetaData
== nullptr) {
235 // Create a new 'grid_constant' metadata node
236 SmallVector
<llvm::Metadata
*> gridConstMetadata
= {
237 llvm::ValueAsMetadata::getConstant(
238 llvm::ConstantInt::get(i32
, argIdx
+ 1))};
239 llvm::Metadata
*llvmMetadata
[] = {
240 llvm::ValueAsMetadata::get(llvmFunc
),
241 llvm::MDString::get(llvmContext
, "grid_constant"),
242 llvm::MDNode::get(llvmContext
, gridConstMetadata
)};
243 llvm::MDNode
*llvmMetadataNode
=
244 llvm::MDNode::get(llvmContext
, llvmMetadata
);
245 nvvmAnnotations
->addOperand(llvmMetadataNode
);
247 // Append argIdx + 1 to the 'grid_constant' argument list
249 dyn_cast
<llvm::MDTuple
>(gridConstantMetaData
->getOperand(2))) {
250 llvm::TempMDTuple clonedArgList
= argList
->clone();
251 clonedArgList
->push_back((llvm::ValueAsMetadata::getConstant(
252 llvm::ConstantInt::get(i32
, argIdx
+ 1))));
253 gridConstantMetaData
->replaceOperandWith(
254 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList
)));
263 void mlir::registerNVVMDialectTranslation(DialectRegistry
®istry
) {
264 registry
.insert
<NVVM::NVVMDialect
>();
265 registry
.addExtension(+[](MLIRContext
*ctx
, NVVM::NVVMDialect
*dialect
) {
266 dialect
->addInterfaces
<NVVMDialectLLVMIRTranslationInterface
>();
270 void mlir::registerNVVMDialectTranslation(MLIRContext
&context
) {
271 DialectRegistry registry
;
272 registerNVVMDialectTranslation(registry
);
273 context
.appendDialectRegistry(registry
);