[mlir][nvvm] Add attributes for cluster dimension PTX directives (#116973)
[llvm-project.git] / mlir / lib / Target / LLVMIR / Dialect / NVVM / NVVMToLLVMIRTranslation.cpp
blobcf58bc5d8f475a73c9bdd0f6f3ecb27fc86086d8
1 //===- NVVMToLLVMIRTranslation.cpp - Translate NVVM to LLVM IR ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR NVVM dialect and
10 // LLVM IR.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
15 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
16 #include "mlir/Dialect/Utils/StaticValueUtils.h"
17 #include "mlir/IR/Operation.h"
18 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
20 #include "llvm/IR/IRBuilder.h"
21 #include "llvm/IR/IntrinsicsNVPTX.h"
23 using namespace mlir;
24 using namespace mlir::LLVM;
25 using mlir::LLVM::detail::createIntrinsicCall;
27 static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
28 NVVM::ReduxKind kind) {
29 if (!resultType->isIntegerTy(32))
30 llvm_unreachable("unsupported data type for redux");
32 switch (kind) {
33 case NVVM::ReduxKind::ADD:
34 return llvm::Intrinsic::nvvm_redux_sync_add;
35 case NVVM::ReduxKind::UMAX:
36 return llvm::Intrinsic::nvvm_redux_sync_umax;
37 case NVVM::ReduxKind::UMIN:
38 return llvm::Intrinsic::nvvm_redux_sync_umin;
39 case NVVM::ReduxKind::AND:
40 return llvm::Intrinsic::nvvm_redux_sync_and;
41 case NVVM::ReduxKind::OR:
42 return llvm::Intrinsic::nvvm_redux_sync_or;
43 case NVVM::ReduxKind::XOR:
44 return llvm::Intrinsic::nvvm_redux_sync_xor;
45 case NVVM::ReduxKind::MAX:
46 return llvm::Intrinsic::nvvm_redux_sync_max;
47 case NVVM::ReduxKind::MIN:
48 return llvm::Intrinsic::nvvm_redux_sync_min;
50 llvm_unreachable("unknown redux kind");
53 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
54 NVVM::ShflKind kind,
55 bool withPredicate) {
57 if (withPredicate) {
58 resultType = cast<llvm::StructType>(resultType)->getElementType(0);
59 switch (kind) {
60 case NVVM::ShflKind::bfly:
61 return resultType->isFloatTy()
62 ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32p
63 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32p;
64 case NVVM::ShflKind::up:
65 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32p
66 : llvm::Intrinsic::nvvm_shfl_sync_up_i32p;
67 case NVVM::ShflKind::down:
68 return resultType->isFloatTy()
69 ? llvm::Intrinsic::nvvm_shfl_sync_down_f32p
70 : llvm::Intrinsic::nvvm_shfl_sync_down_i32p;
71 case NVVM::ShflKind::idx:
72 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32p
73 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32p;
75 } else {
76 switch (kind) {
77 case NVVM::ShflKind::bfly:
78 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_bfly_f32
79 : llvm::Intrinsic::nvvm_shfl_sync_bfly_i32;
80 case NVVM::ShflKind::up:
81 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_up_f32
82 : llvm::Intrinsic::nvvm_shfl_sync_up_i32;
83 case NVVM::ShflKind::down:
84 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_down_f32
85 : llvm::Intrinsic::nvvm_shfl_sync_down_i32;
86 case NVVM::ShflKind::idx:
87 return resultType->isFloatTy() ? llvm::Intrinsic::nvvm_shfl_sync_idx_f32
88 : llvm::Intrinsic::nvvm_shfl_sync_idx_i32;
91 llvm_unreachable("unknown shuffle kind");
94 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
95 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
96 int32_t num) {
97 if (layout == NVVM::MMALayout::row) {
98 switch (num) {
99 case 1:
100 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
101 case 2:
102 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
103 case 4:
104 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
105 default:
106 llvm_unreachable("unsupported number of matrix");
109 } else {
110 switch (num) {
111 case 1:
112 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
113 case 2:
114 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
115 case 4:
116 return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
117 default:
118 llvm_unreachable("unsupported number of matrix");
123 static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
124 NVVM::ProxyKind toProxy,
125 NVVM::MemScopeKind scope,
126 bool isRelease) {
127 if (fromProxy == NVVM::ProxyKind::GENERIC &&
128 toProxy == NVVM::ProxyKind::TENSORMAP) {
129 switch (scope) {
130 case NVVM::MemScopeKind::CTA: {
131 if (isRelease)
132 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_cta;
133 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_cta;
135 case NVVM::MemScopeKind::CLUSTER: {
136 if (isRelease)
137 return llvm::Intrinsic::
138 nvvm_fence_proxy_tensormap_generic_release_cluster;
139 return llvm::Intrinsic::
140 nvvm_fence_proxy_tensormap_generic_acquire_cluster;
142 case NVVM::MemScopeKind::GPU: {
143 if (isRelease)
144 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_gpu;
145 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_gpu;
147 case NVVM::MemScopeKind::SYS: {
148 if (isRelease)
149 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_release_sys;
150 return llvm::Intrinsic::nvvm_fence_proxy_tensormap_generic_acquire_sys;
153 llvm_unreachable("Unknown scope for uni-directional fence.proxy operation");
155 llvm_unreachable("Unsupported proxy kinds");
158 namespace {
159 /// Implementation of the dialect interface that converts operations belonging
160 /// to the NVVM dialect to LLVM IR.
161 class NVVMDialectLLVMIRTranslationInterface
162 : public LLVMTranslationDialectInterface {
163 public:
164 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
166 /// Translates the given operation to LLVM IR using the provided IR builder
167 /// and saving the state in `moduleTranslation`.
168 LogicalResult
169 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
170 LLVM::ModuleTranslation &moduleTranslation) const final {
171 Operation &opInst = *op;
172 #include "mlir/Dialect/LLVMIR/NVVMConversions.inc"
174 return failure();
177 /// Attaches module-level metadata for functions marked as kernels.
178 LogicalResult
179 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
180 NamedAttribute attribute,
181 LLVM::ModuleTranslation &moduleTranslation) const final {
182 auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
183 if (!func)
184 return failure();
185 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
186 llvm::Function *llvmFunc = moduleTranslation.lookupFunction(func.getName());
188 auto generateMetadata = [&](int dim, StringRef name) {
189 llvm::Metadata *llvmMetadata[] = {
190 llvm::ValueAsMetadata::get(llvmFunc),
191 llvm::MDString::get(llvmContext, name),
192 llvm::ValueAsMetadata::get(llvm::ConstantInt::get(
193 llvm::Type::getInt32Ty(llvmContext), dim))};
194 llvm::MDNode *llvmMetadataNode =
195 llvm::MDNode::get(llvmContext, llvmMetadata);
196 moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
197 ->addOperand(llvmMetadataNode);
199 if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
200 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
201 return failure();
202 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
203 generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
204 if (values.size() > 1)
205 generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
206 if (values.size() > 2)
207 generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
208 } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
209 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
210 return failure();
211 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
212 generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
213 if (values.size() > 1)
214 generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
215 if (values.size() > 2)
216 generateMetadata(values[2], NVVM::NVVMDialect::getReqntidZName());
217 } else if (attribute.getName() ==
218 NVVM::NVVMDialect::getClusterDimAttrName()) {
219 if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
220 return failure();
221 auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
222 generateMetadata(values[0], NVVM::NVVMDialect::getClusterDimXName());
223 if (values.size() > 1)
224 generateMetadata(values[1], NVVM::NVVMDialect::getClusterDimYName());
225 if (values.size() > 2)
226 generateMetadata(values[2], NVVM::NVVMDialect::getClusterDimZName());
227 } else if (attribute.getName() ==
228 NVVM::NVVMDialect::getClusterMaxBlocksAttrName()) {
229 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
230 generateMetadata(value.getInt(), "cluster_max_blocks");
231 } else if (attribute.getName() ==
232 NVVM::NVVMDialect::getMinctasmAttrName()) {
233 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
234 generateMetadata(value.getInt(), "minctasm");
235 } else if (attribute.getName() == NVVM::NVVMDialect::getMaxnregAttrName()) {
236 auto value = dyn_cast<IntegerAttr>(attribute.getValue());
237 generateMetadata(value.getInt(), "maxnreg");
238 } else if (attribute.getName() ==
239 NVVM::NVVMDialect::getKernelFuncAttrName()) {
240 llvm::Metadata *llvmMetadataKernel[] = {
241 llvm::ValueAsMetadata::get(llvmFunc),
242 llvm::MDString::get(llvmContext, "kernel"),
243 llvm::ValueAsMetadata::get(
244 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 1))};
245 llvm::MDNode *llvmMetadataNode =
246 llvm::MDNode::get(llvmContext, llvmMetadataKernel);
247 moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations")
248 ->addOperand(llvmMetadataNode);
250 return success();
253 LogicalResult
254 convertParameterAttr(LLVMFuncOp funcOp, int argIdx, NamedAttribute attribute,
255 LLVM::ModuleTranslation &moduleTranslation) const final {
257 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
258 llvm::Function *llvmFunc =
259 moduleTranslation.lookupFunction(funcOp.getName());
260 llvm::NamedMDNode *nvvmAnnotations =
261 moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations");
263 if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) {
264 llvm::MDNode *gridConstantMetaData = nullptr;
266 // Check if a 'grid_constant' metadata node exists for the given function
267 for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) {
268 if (opnd->getNumOperands() == 3 &&
269 opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) &&
270 opnd->getOperand(1) ==
271 llvm::MDString::get(llvmContext, "grid_constant")) {
272 gridConstantMetaData = opnd;
273 break;
277 // 'grid_constant' is a function-level meta data node with a list of
278 // integers, where each integer n denotes that the nth parameter has the
279 // grid_constant annotation (numbering from 1). This requires aggregating
280 // the indices of the individual parameters that have this attribute.
281 llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32);
282 if (gridConstantMetaData == nullptr) {
283 // Create a new 'grid_constant' metadata node
284 SmallVector<llvm::Metadata *> gridConstMetadata = {
285 llvm::ValueAsMetadata::getConstant(
286 llvm::ConstantInt::get(i32, argIdx + 1))};
287 llvm::Metadata *llvmMetadata[] = {
288 llvm::ValueAsMetadata::get(llvmFunc),
289 llvm::MDString::get(llvmContext, "grid_constant"),
290 llvm::MDNode::get(llvmContext, gridConstMetadata)};
291 llvm::MDNode *llvmMetadataNode =
292 llvm::MDNode::get(llvmContext, llvmMetadata);
293 nvvmAnnotations->addOperand(llvmMetadataNode);
294 } else {
295 // Append argIdx + 1 to the 'grid_constant' argument list
296 if (auto argList =
297 dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) {
298 llvm::TempMDTuple clonedArgList = argList->clone();
299 clonedArgList->push_back((llvm::ValueAsMetadata::getConstant(
300 llvm::ConstantInt::get(i32, argIdx + 1))));
301 gridConstantMetaData->replaceOperandWith(
302 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList)));
306 return success();
309 } // namespace
311 void mlir::registerNVVMDialectTranslation(DialectRegistry &registry) {
312 registry.insert<NVVM::NVVMDialect>();
313 registry.addExtension(+[](MLIRContext *ctx, NVVM::NVVMDialect *dialect) {
314 dialect->addInterfaces<NVVMDialectLLVMIRTranslationInterface>();
318 void mlir::registerNVVMDialectTranslation(MLIRContext &context) {
319 DialectRegistry registry;
320 registerNVVMDialectTranslation(registry);
321 context.appendDialectRegistry(registry);