//===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR NVVM dialect and
// LLVM IR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::createIntrinsicCall;
27 static llvm::Intrinsic::ID
getReduxIntrinsicId(llvm::Type
*resultType
,
28 NVVM::ReduxKind kind
) {
29 if (!resultType
->isIntegerTy(32))
30 llvm_unreachable("unsupported data type for redux");
33 case NVVM::ReduxKind::ADD
:
34 return llvm::Intrinsic::nvvm_redux_sync_add
;
35 case NVVM::ReduxKind::UMAX
:
36 return llvm::Intrinsic::nvvm_redux_sync_umax
;
37 case NVVM::ReduxKind::UMIN
:
38 return llvm::Intrinsic::nvvm_redux_sync_umin
;
39 case NVVM::ReduxKind::AND
:
40 return llvm::Intrinsic::nvvm_redux_sync_and
;
41 case NVVM::ReduxKind::OR
:
42 return llvm::Intrinsic::nvvm_redux_sync_or
;
43 case NVVM::ReduxKind::XOR
:
44 return llvm::Intrinsic::nvvm_redux_sync_xor
;
45 case NVVM::ReduxKind::MAX
:
46 return llvm::Intrinsic::nvvm_redux_sync_max
;
47 case NVVM::ReduxKind::MIN
:
48 return llvm::Intrinsic::nvvm_redux_sync_min
;
50 llvm_unreachable("unknown redux kind");
53 static llvm::Intrinsic::ID
getShflIntrinsicId(llvm::Type
*resultType
,
58 resultType
= cast
<llvm::StructType
>(resultType
)->getElementType(0);
60 case NVVM::ShflKind::bfly
:
61 return resultType
->isFloatTy()
62 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
63 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p
;
64 case NVVM::ShflKind::up
:
65 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
66 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p
;
67 case NVVM::ShflKind::down
:
68 return resultType
->isFloatTy()
69 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
70 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p
;
71 case NVVM::ShflKind::idx
:
72 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
73 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p
;
77 case NVVM::ShflKind::bfly
:
78 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
79 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32
;
80 case NVVM::ShflKind::up
:
81 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
82 : llvm::Intrinsic::nvvm_shfl_sync_up_i32
;
83 case NVVM::ShflKind::down
:
84 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
85 : llvm::Intrinsic::nvvm_shfl_sync_down_i32
;
86 case NVVM::ShflKind::idx
:
87 return resultType
->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
88 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32
;
91 llvm_unreachable("unknown shuffle kind");
94 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
95 static llvm::Intrinsic::ID
getLdMatrixIntrinsicId(NVVM::MMALayout layout
,
97 if (layout
== NVVM::MMALayout::row
) {
100 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
;
102 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
;
104 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
;
106 llvm_unreachable("unsupported number of matrix");
112 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16
;
114 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16
;
116 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16
;
118 llvm_unreachable("unsupported number of matrix");
123 static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy
,
124 NVVM::ProxyKind toProxy
,
125 NVVM::MemScopeKind scope
,
127 if (fromProxy
== NVVM::ProxyKind::GENERIC
&&
128 toProxy
== NVVM::ProxyKind::TENSORMAP
) {
130 case NVVM::MemScopeKind::CTA
: {
132 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta
;
133 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta
;
135 case NVVM::MemScopeKind::CLUSTER
: {
137 return llvm::Intrinsic::
138 nvvm_fence_proxy_tensormap_generic_release_cluster
;
139 return llvm::Intrinsic::
140 nvvm_fence_proxy_tensormap_generic_acquire_cluster
;
142 case NVVM::MemScopeKind::GPU
: {
144 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu
;
145 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu
;
147 case NVVM::MemScopeKind::SYS
: {
149 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys
;
150 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys
;
153 llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
155 llvm_unreachable("Unsupported proxy kinds");
159 /// Implementation of the dialect interface that converts operations belonging
160 /// to the NVVM dialect to LLVM IR.
161 class NVVMDialectLLVMIRTranslationInterface
162 : public LLVMTranslationDialectInterface
{
164 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface
;
166 /// Translates the given operation to LLVM IR using the provided IR builder
167 /// and saving the state in `moduleTranslation`.
169 convertOperation(Operation
*op
, llvm::IRBuilderBase
&builder
,
170 LLVM::ModuleTranslation
&moduleTranslation
) const final
{
171 Operation
&opInst
= *op
;
172 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
177 /// Attaches module-level metadata for functions marked as kernels.
179 amendOperation(Operation
*op
, ArrayRef
<llvm::Instruction
*> instructions
,
180 NamedAttribute attribute
,
181 LLVM::ModuleTranslation
&moduleTranslation
) const final
{
182 auto func
= dyn_cast
<LLVM::LLVMFuncOp
>(op
);
185 llvm::LLVMContext
&llvmContext
= moduleTranslation
.getLLVMContext();
186 llvm::Function
*llvmFunc
= moduleTranslation
.lookupFunction(func
.getName());
188 auto generateMetadata
= [&](int dim
, StringRef name
) {
189 llvm::Metadata
*llvmMetadata
[] = {
190 llvm::ValueAsMetadata::get(llvmFunc
),
191 llvm::MDString::get(llvmContext
, name
),
192 llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
193 llvm::Type::getInt32Ty(llvmContext
), dim
))};
194 llvm::MDNode
*llvmMetadataNode
=
195 llvm::MDNode::get(llvmContext
, llvmMetadata
);
196 moduleTranslation
.getOrInsertNamedModuleMetadata("nvvm.annotations")
197 ->addOperand(llvmMetadataNode
);
199 if (attribute
.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
200 if (!dyn_cast
<DenseI32ArrayAttr
>(attribute
.getValue()))
202 auto values
= cast
<DenseI32ArrayAttr
>(attribute
.getValue());
203 generateMetadata(values
[0], NVVM::NVVMDialect::getMaxntidXName());
204 if (values
.size() > 1)
205 generateMetadata(values
[1], NVVM::NVVMDialect::getMaxntidYName());
206 if (values
.size() > 2)
207 generateMetadata(values
[2], NVVM::NVVMDialect::getMaxntidZName());
208 } else if (attribute
.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
209 if (!dyn_cast
<DenseI32ArrayAttr
>(attribute
.getValue()))
211 auto values
= cast
<DenseI32ArrayAttr
>(attribute
.getValue());
212 generateMetadata(values
[0], NVVM::NVVMDialect::getReqntidXName());
213 if (values
.size() > 1)
214 generateMetadata(values
[1], NVVM::NVVMDialect::getReqntidYName());
215 if (values
.size() > 2)
216 generateMetadata(values
[2], NVVM::NVVMDialect::getReqntidZName());
217 } else if (attribute
.getName() ==
218 NVVM::NVVMDialect::getClusterDimAttrName()) {
219 if (!dyn_cast
<DenseI32ArrayAttr
>(attribute
.getValue()))
221 auto values
= cast
<DenseI32ArrayAttr
>(attribute
.getValue());
222 generateMetadata(values
[0], NVVM::NVVMDialect::getClusterDimXName());
223 if (values
.size() > 1)
224 generateMetadata(values
[1], NVVM::NVVMDialect::getClusterDimYName());
225 if (values
.size() > 2)
226 generateMetadata(values
[2], NVVM::NVVMDialect::getClusterDimZName());
227 } else if (attribute
.getName() ==
228 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
229 auto value
= dyn_cast
<IntegerAttr
>(attribute
.getValue());
230 generateMetadata(value
.getInt(), "cluster_max_blocks");
231 } else if (attribute
.getName() ==
232 NVVM::NVVMDialect::getMinctasmAttrName()) {
233 auto value
= dyn_cast
<IntegerAttr
>(attribute
.getValue());
234 generateMetadata(value
.getInt(), "minctasm");
235 } else if (attribute
.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
236 auto value
= dyn_cast
<IntegerAttr
>(attribute
.getValue());
237 generateMetadata(value
.getInt(), "maxnreg");
238 } else if (attribute
.getName() ==
239 NVVM::NVVMDialect::getKernelFuncAttrName()) {
240 llvm::Metadata
*llvmMetadataKernel
[] = {
241 llvm::ValueAsMetadata::get(llvmFunc
),
242 llvm::MDString::get(llvmContext
, "kernel"),
243 llvm::ValueAsMetadata::get(
244 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext
), 1))};
245 llvm::MDNode
*llvmMetadataNode
=
246 llvm::MDNode::get(llvmContext
, llvmMetadataKernel
);
247 moduleTranslation
.getOrInsertNamedModuleMetadata("nvvm.annotations")
248 ->addOperand(llvmMetadataNode
);
254 convertParameterAttr(LLVMFuncOp funcOp
, int argIdx
, NamedAttribute attribute
,
255 LLVM::ModuleTranslation
&moduleTranslation
) const final
{
257 llvm::LLVMContext
&llvmContext
= moduleTranslation
.getLLVMContext();
258 llvm::Function
*llvmFunc
=
259 moduleTranslation
.lookupFunction(funcOp
.getName());
260 llvm::NamedMDNode
*nvvmAnnotations
=
261 moduleTranslation
.getOrInsertNamedModuleMetadata("nvvm.annotations");
263 if (attribute
.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
264 llvm::MDNode
*gridConstantMetaData
= nullptr;
266 // Check if a 'grid_constant' metadata node exists for the given function
267 for (llvm::MDNode
*opnd
: llvm::reverse(nvvmAnnotations
->operands())) {
268 if (opnd
->getNumOperands() == 3 &&
269 opnd
->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc
) &&
270 opnd
->getOperand(1) ==
271 llvm::MDString::get(llvmContext
, "grid_constant")) {
272 gridConstantMetaData
= opnd
;
277 // 'grid_constant' is a function-level meta data node with a list of
278 // integers, where each integer n denotes that the nth parameter has the
279 // grid_constant annotation (numbering from 1). This requires aggregating
280 // the indices of the individual parameters that have this attribute.
281 llvm::Type
*i32
= llvm::IntegerType::get(llvmContext
, 32);
282 if (gridConstantMetaData
== nullptr) {
283 // Create a new 'grid_constant' metadata node
284 SmallVector
<llvm::Metadata
*> gridConstMetadata
= {
285 llvm::ValueAsMetadata::getConstant(
286 llvm::ConstantInt::get(i32
, argIdx
+ 1))};
287 llvm::Metadata
*llvmMetadata
[] = {
288 llvm::ValueAsMetadata::get(llvmFunc
),
289 llvm::MDString::get(llvmContext
, "grid_constant"),
290 llvm::MDNode::get(llvmContext
, gridConstMetadata
)};
291 llvm::MDNode
*llvmMetadataNode
=
292 llvm::MDNode::get(llvmContext
, llvmMetadata
);
293 nvvmAnnotations
->addOperand(llvmMetadataNode
);
295 // Append argIdx + 1 to the 'grid_constant' argument list
297 dyn_cast
<llvm::MDTuple
>(gridConstantMetaData
->getOperand(2))) {
298 llvm::TempMDTuple clonedArgList
= argList
->clone();
299 clonedArgList
->push_back((llvm::ValueAsMetadata::getConstant(
300 llvm::ConstantInt::get(i32
, argIdx
+ 1))));
301 gridConstantMetaData
->replaceOperandWith(
302 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList
)));
311 void mlir::registerNVVMDialectTranslation(DialectRegistry
®istry
) {
312 registry
.insert
<NVVM::NVVMDialect
>();
313 registry
.addExtension(+[](MLIRContext
*ctx
, NVVM::NVVMDialect
*dialect
) {
314 dialect
->addInterfaces
<NVVMDialectLLVMIRTranslationInterface
>();
318 void mlir::registerNVVMDialectTranslation(MLIRContext
&context
) {
319 DialectRegistry registry
;
320 registerNVVMDialectTranslation(registry
);
321 context
.appendDialectRegistry(registry
);