//===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C wrappers around the CUDA library for easy linking in ORC jit.
// Also adds some debugging helpers that are helpful when writing MLIR code to
// run on GPUs.
//
//===----------------------------------------------------------------------===//

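// A rough sketch of how JIT-compiled host code typically drives these
// wrappers (illustrative only; the exact call sequence is produced by the
// GPU dialect lowering, and the blob, kernel name, and launch parameters
// below are placeholders):
//
//   CUmodule module = mgpuModuleLoad(cubinBlob, cubinBlobSize);
//   CUfunction func = mgpuModuleGetFunction(module, "my_kernel");
//   CUstream stream = mgpuStreamCreate();
//   mgpuLaunchKernel(func, gridX, gridY, gridZ, blockX, blockY, blockZ,
//                    /*smem=*/0, stream, kernelParams, /*extra=*/nullptr,
//                    /*paramsCount=*/0);
//   mgpuStreamSynchronize(stream);
//   mgpuStreamDestroy(stream);
//   mgpuModuleUnload(module);
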
15 #include "mlir/ExecutionEngine/CRunnerUtils.h"
17 #include <stdio.h>
19 #include "cuda.h"
20 #include "cuda_bf16.h"
21 #include "cuda_fp16.h"
23 #ifdef MLIR_ENABLE_CUDA_CUSPARSE
24 #include "cusparse.h"
25 #ifdef MLIR_ENABLE_CUDA_CUSPARSELT
26 #include "cusparseLt.h"
27 #endif // MLIR_ENABLE_CUDA_CUSPARSELT
28 #endif // MLIR_ENABLE_CUDA_CUSPARSE
30 #ifdef _WIN32
31 #include <malloc.h>
32 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
33 #else
34 #define MLIR_CUDA_WRAPPERS_EXPORT __attribute__((visibility("default")))
35 #endif // _WIN32
37 #define CUDA_REPORT_IF_ERROR(expr) \
38 [](CUresult result) { \
39 if (!result) \
40 return; \
41 const char *name = nullptr; \
42 cuGetErrorName(result, &name); \
43 if (!name) \
44 name = "<unknown>"; \
45 fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \
46 }(expr)
48 #define CUSPARSE_REPORT_IF_ERROR(expr) \
49 { \
50 cusparseStatus_t status = (expr); \
51 if (status != CUSPARSE_STATUS_SUCCESS) { \
52 fprintf(stderr, "cuSPARSE '%s' failed with '%s'\n", #expr, \
53 cusparseGetErrorString(status)); \
54 } \
57 thread_local static int32_t defaultDevice = 0;
59 const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
/// Helper method that checks the debugging environment variable; the result
/// is cached after the first call.
bool isDebugEnabled() {
  static bool isInitialized = false;
  static bool isEnabled = false;
  if (!isInitialized) {
    isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
    isInitialized = true;
  }
  return isEnabled;
}

#define debug_print(fmt, ...)                                                 \
  do {                                                                        \
    if (isDebugEnabled())                                                     \
      fprintf(stderr, "%s:%d:%s(): " fmt, "CudaRuntimeWrappers.cpp", __LINE__, \
              __func__, __VA_ARGS__);                                         \
  } while (0)

// Returns default CUdevice
CUdevice getDefaultCuDevice() {
  CUdevice device;
  CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
  return device;
}

// Make the primary context of the current default device current for the
// duration of the instance and restore the previous context on destruction.
class ScopedContext {
public:
  ScopedContext() {
    // Static reference to CUDA primary context for device ordinal
    // defaultDevice.
    static CUcontext context = [] {
      CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
      CUcontext ctx;
      // Note: this does not affect the current context.
      CUDA_REPORT_IF_ERROR(
          cuDevicePrimaryCtxRetain(&ctx, getDefaultCuDevice()));
      return ctx;
    }();

    CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
  }

  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
};

#ifdef MLIR_ENABLE_CUDA_CUSPARSE
// Note that (1) NVIDIA confirms it is safe to share a handle across multiple
// instances and streams, and (2) clients are responsible for calling the @mgpu
// environment initialization/destruction in a thread-safe manner, e.g.,
// at the beginning of the program before multiple threads are created.
static cusparseHandle_t cusparse_env = nullptr;

#ifdef MLIR_ENABLE_CUDA_CUSPARSELT
// cusparseLtHandle_t is not a pointer type, so we need an additional flag to
// indicate whether it is initialized.
static cusparseLtHandle_t cusparseLt_env;
static bool cusparseLt_initiated = false;

#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE

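/// Loads `data` as a CUDA module via cuModuleLoadData. The blob size argument
/// is currently unused.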
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule
mgpuModuleLoad(void *data, size_t /*gpuBlobSize*/) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
  return module;
}

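/// Loads `data` with cuModuleLoadDataEx so the driver JIT-compiles it at the
/// requested optimization level; JIT errors are captured in a local log buffer
/// and reported on stderr.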
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoadJIT(void *data,
                                                                int optLevel) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  char jitErrorBuffer[4096] = {0};
  CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
                               CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
                               CU_JIT_OPTIMIZATION_LEVEL};
  void *jitOptionsVals[] = {jitErrorBuffer,
                            reinterpret_cast<void *>(sizeof(jitErrorBuffer)),
                            reinterpret_cast<void *>(optLevel)};

  CUresult result =
      cuModuleLoadDataEx(&module, data, 3, jitOptions, jitOptionsVals);
  if (result) {
    fprintf(stderr, "JIT compilation failed with: '%s'\n", jitErrorBuffer);
    CUDA_REPORT_IF_ERROR(result);
  }
  return module;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) {
  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction
mgpuModuleGetFunction(CUmodule module, const char *name) {
  CUfunction function = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
  return function;
}

// The wrapper uses intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                 intptr_t gridZ, intptr_t blockX, intptr_t blockY,
                 intptr_t blockZ, int32_t smem, CUstream stream, void **params,
                 void **extra, size_t /*paramsCount*/) {
  ScopedContext scopedContext;
  if (smem > 0) {
    // Only query the driver when dynamic shared memory is actually requested;
    // the query is more expensive than this check.
    int32_t maxShmem = 0;
    CUdevice device = getDefaultCuDevice();
    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));
    if (maxShmem < smem) {
      fprintf(stderr,
              "Requested shared memory (%d bytes) is larger than maximum "
              "allowed shared memory (%d bytes) for this device\n",
              smem, maxShmem);
    }
    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
  }
  debug_print("Launching kernel, grid=%ld,%ld,%ld, "
              "threads: %ld, %ld, %ld, "
              "smem: %d bytes\n",
              gridX, gridY, gridZ, blockX, blockY, blockZ, smem);
  CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                      blockY, blockZ, smem, stream, params,
                                      extra));
}

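/// Stream and event helpers. Streams are created non-blocking and events are
/// created with timing disabled.
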
extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
  ScopedContext scopedContext;
  CUstream stream = nullptr;
  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
  return stream;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuStreamSynchronize(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream,
                                                              CUevent event) {
  CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() {
  ScopedContext scopedContext;
  CUevent event = nullptr;
  CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
  return event;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventSynchronize(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventRecord(CUevent event,
                                                          CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
}

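/// Allocates device memory, or managed (host-visible) memory when
/// `isHostShared` is true. Zero-byte requests return a null pointer. The
/// stream argument is currently unused, so the allocation is synchronous.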
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuMemAlloc(uint64_t sizeBytes, CUstream stream, bool isHostShared) {
  ScopedContext scopedContext;
  CUdeviceptr ptr = 0;
  if (sizeBytes == 0)
    return reinterpret_cast<void *>(ptr);

  if (isHostShared) {
    CUDA_REPORT_IF_ERROR(
        cuMemAllocManaged(&ptr, sizeBytes, CU_MEM_ATTACH_GLOBAL));
    return reinterpret_cast<void *>(ptr);
  }
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
  return reinterpret_cast<void *>(ptr);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemFree(void *ptr,
                                                      CUstream /*stream*/) {
  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemcpy(void *dst, void *src, size_t sizeBytes, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
                                     reinterpret_cast<CUdeviceptr>(src),
                                     sizeBytes, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemset32(void *dst, unsigned int value, size_t count, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemset16(void *dst, unsigned short value, size_t count, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}

///
/// Helper functions for writing MLIR example code
///

// Registers a byte array with the CUDA runtime. Helpful until we have
// transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}

/// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
/// ranked memref descriptor struct of rank `rank`. Helpful until we have
/// transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
                          int64_t elementSizeBytes) {
  // Only densely packed tensors are currently supported.
#ifdef _WIN32
  int64_t *denseStrides = (int64_t *)_alloca(rank * sizeof(int64_t));
#else
  int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t));
#endif // _WIN32
  int64_t *sizes = descriptor->sizes;
  for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) {
    denseStrides[i] = runningStride;
    runningStride *= sizes[i];
  }
  uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes;
  int64_t *strides = &sizes[rank];
  (void)strides;
  for (unsigned i = 0; i < rank; ++i)
    assert(strides[i] == denseStrides[i] &&
           "Mismatch in computed dense strides");

  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostRegister(ptr, sizeBytes);
}

// Unregisters a byte array previously registered with the CUDA runtime.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemHostUnregister(void *ptr) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuMemHostUnregister(ptr));
}

/// Unregisters a memref with the CUDA runtime. `descriptor` is a pointer to a
/// ranked memref descriptor struct of rank `rank`.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostUnregisterMemRef(int64_t rank,
                            StridedMemRefType<char, 1> *descriptor,
                            int64_t elementSizeBytes) {
  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostUnregister(ptr);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
  defaultDevice = device;
}

///
/// Runtime methods using CUDA 12.0+ driver
///

#if (CUDA_VERSION >= 12000)

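/// Launches a kernel with a thread-block cluster configuration; the cluster
/// scheduling policy preference is set to SPREAD.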
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuLaunchClusterKernel(
    CUfunction function, intptr_t clusterX, intptr_t clusterY,
    intptr_t clusterZ, intptr_t gridX, intptr_t gridY, intptr_t gridZ,
    intptr_t blockX, intptr_t blockY, intptr_t blockZ, int32_t smem,
    CUstream stream, void **params, void **extra, size_t /*paramsCount*/) {
  ScopedContext scopedContext;
  if (smem > 0) {
    // Only query the driver when dynamic shared memory is actually requested;
    // the query is more expensive than this check.
    int32_t maxShmem = 0;
    CUdevice device = getDefaultCuDevice();
    CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
    CUDA_REPORT_IF_ERROR(cuDeviceGetAttribute(
        &maxShmem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
        device));
    if (maxShmem < smem) {
      fprintf(stderr,
              "Requested shared memory (%d bytes) is larger than maximum "
              "allowed shared memory (%d bytes) for this device\n",
              smem, maxShmem);
    }
    CUDA_REPORT_IF_ERROR(cuFuncSetAttribute(
        function, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem));
  }
  CUlaunchConfig config;
  config.gridDimX = gridX;
  config.gridDimY = gridY;
  config.gridDimZ = gridZ;
  config.blockDimX = blockX;
  config.blockDimY = blockY;
  config.blockDimZ = blockZ;
  config.sharedMemBytes = smem;
  config.hStream = stream;
  CUlaunchAttribute launchAttr[2];
  launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  launchAttr[0].value.clusterDim.x = clusterX;
  launchAttr[0].value.clusterDim.y = clusterY;
  launchAttr[0].value.clusterDim.z = clusterZ;
  launchAttr[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
  launchAttr[1].value.clusterSchedulingPolicyPreference =
      CU_CLUSTER_SCHEDULING_POLICY_SPREAD;
  config.numAttrs = 2;
  config.attrs = launchAttr;

  debug_print("Launching kernel,"
              "cluster: %ld, %ld, %ld, "
              "grid=%ld,%ld,%ld, "
              "threads: %ld, %ld, %ld, "
              "smem: %d bytes\n",
              clusterX, clusterY, clusterZ, gridX, gridY, gridZ, blockX, blockY,
              blockZ, smem);

  CUDA_REPORT_IF_ERROR(cuLaunchKernelEx(&config, function, params, extra));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
    CUtensorMap *tensorMap,             // Tensor map object
    CUtensorMapDataType tensorDataType, // Tensor data type
    cuuint32_t tensorRank,              // Dimensionality of tensor
    void *globalAddress,                // Starting address
    const cuuint64_t *globalDim,        // Tensor size (number of elements)
    const cuuint64_t *globalStrides,    // Stride size (in bytes)
    const cuuint32_t *boxDim,           // Traversal box (number of elements)
    const cuuint32_t *elementStrides,   // Traversal stride
    CUtensorMapInterleave interleave,   // Type of interleaved layout
    CUtensorMapSwizzle swizzle,         // Bank swizzling pattern
    CUtensorMapL2promotion l2Promotion, // L2 promotion size
    CUtensorMapFloatOOBfill oobFill     // Padding zfill or NaN fill
) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuTensorMapEncodeTiled(
      tensorMap, tensorDataType, tensorRank, globalAddress, globalDim,
      globalStrides, boxDim, elementStrides, interleave, swizzle, l2Promotion,
      oobFill));
  debug_print("Created TMA descriptor\n Addr: %p\n"
              "data type : %d\n"
              "rank : %d\n"
              "globalDim[5]: %zu, %zu, %zu, %zu, %zu\n"
              "globalStrides[5]: %zu, %zu, %zu, %zu, %zu\n"
              "boxDim[5]: %u, %u, %u, %u, %u\n"
              "elementStrides[5]: %u, %u, %u, %u, %u\n"
              "interleave: %u \n"
              "swizzle: %u \n"
              "l2Promotion: %u \n"
              "oobFill: %u \n",
              (void *)&tensorMap, tensorDataType, tensorRank, globalDim[0],
              globalDim[1], globalDim[2], globalDim[3], globalDim[4],
              globalStrides[0], globalStrides[1], globalStrides[2],
              globalStrides[3], globalStrides[4], boxDim[0], boxDim[1],
              boxDim[2], boxDim[3], boxDim[4], elementStrides[0],
              elementStrides[1], elementStrides[2], elementStrides[3],
              elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
}

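// Extracts the data pointer, sizes, and strides from a ranked memref
// descriptor, reversing the dimension order and converting strides to bytes
// as expected by cuTensorMapEncodeTiled.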
template <int Rank>
void mgpuGetMemRefDataAndShape(void *rawDescriptor, char **addr,
                               uint64_t *globalDim, uint64_t *globalStrides,
                               const CUtensorMapDataType tensorDataType) {
  auto descriptor =
      reinterpret_cast<StridedMemRefType<char, Rank> *>(rawDescriptor);
  *addr = descriptor->data;
  for (int i = 0; i < Rank; ++i) {
    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[Rank - i - 1]);
  }
  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
                                               4, 8, 2, 4, 4, 4};
  for (int i = 0; i < Rank - 1; ++i) {
    globalStrides[i] = static_cast<uint64_t>(
        descriptor->strides[Rank - i - 2] * elementSizeInBytes[tensorDataType]);
  }
}

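/// Builds a TMA tensor map from a ranked memref descriptor (rank 1 to 5),
/// copies the descriptor to device memory, and returns the device pointer;
/// the caller owns the returned allocation.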
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
    int64_t tensorRank,                       // Dimensionality of tensor
    void *rankedDescriptor,                   // Ranked MemRef descriptor
    const CUtensorMapDataType tensorDataType, // Tensor data type
    CUtensorMapInterleave interleave,         // Type of interleaved layout
    CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
    CUtensorMapL2promotion l2Promotion,       // L2 promotion size
    CUtensorMapFloatOOBfill oobFill,          // Padding zfill or NaN fill
    int64_t *inputBoxDims // Box dimensions (number of elements)
) {
  CUtensorMap tensorMap;

  uint32_t boxDim[5] = {1, 1, 1, 1, 1}, elementStrides[5] = {1, 1, 1, 1, 1};
  uint64_t globalDim[5] = {1, 1, 1, 1, 1}, globalStrides[5] = {0};
  uint32_t tensorRank32 = uint32_t(tensorRank);

  char *globalAddress = nullptr;
  switch (tensorRank) {
  case 1:
    mgpuGetMemRefDataAndShape<1>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 2:
    mgpuGetMemRefDataAndShape<2>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 3:
    mgpuGetMemRefDataAndShape<3>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 4:
    mgpuGetMemRefDataAndShape<4>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  case 5:
    mgpuGetMemRefDataAndShape<5>(rankedDescriptor, &globalAddress, globalDim,
                                 globalStrides, tensorDataType);
    break;
  default:
    fprintf(
        stderr,
        "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
    return nullptr;
  }

  for (int64_t r = 0; r < tensorRank; ++r) {
    boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
  }

  ScopedContext scopedContext;
  mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                           globalAddress, globalDim, globalStrides, boxDim,
                           elementStrides, interleave, swizzle, l2Promotion,
                           oobFill);
  // Copy the created tensor map to the device.
  CUdeviceptr dTensorMap;
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&dTensorMap, sizeof(CUtensorMap)));
  CUDA_REPORT_IF_ERROR(cuMemcpy(dTensorMap,
                                reinterpret_cast<CUdeviceptr>(&tensorMap),
                                sizeof(CUtensorMap)));
  return reinterpret_cast<void *>(dTensorMap);
}

#endif

#ifdef MLIR_ENABLE_CUDA_CUSPARSE

///
/// Wrapper methods for the cuSparse library.
///

// Some macro magic to get float/double alpha and beta on host.
// TODO: add support for passing alpha and beta as arguments
#define ALPHABETA(dtp, alpha, beta)                                           \
  __nv_bfloat16(alpha##16bf) = 1.0f;                                          \
  __nv_bfloat16(beta##16bf) = 1.0f;                                           \
  __half(alpha##16f) = 1.0f;                                                  \
  __half(beta##16f) = 1.0f;                                                   \
  float(alpha##f) = 1.0f;                                                     \
  float(beta##f) = 1.0f;                                                      \
  double(alpha##d) = 1.0;                                                     \
  double(beta##d) = 1.0;                                                      \
  const void *(alpha##p) = nullptr;                                           \
  const void *(beta##p) = nullptr;                                            \
  if (dtp == CUDA_R_16BF || dtp == CUDA_C_16BF) {                             \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16bf));                    \
    (beta##p) = reinterpret_cast<void *>(&(beta##16bf));                      \
  } else if (dtp == CUDA_R_16F || dtp == CUDA_C_16F) {                        \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##16f));                     \
    (beta##p) = reinterpret_cast<void *>(&(beta##16f));                       \
  } else if (dtp == CUDA_R_32F || dtp == CUDA_C_32F) {                        \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##f));                       \
    (beta##p) = reinterpret_cast<void *>(&(beta##f));                         \
  } else {                                                                    \
    (alpha##p) = reinterpret_cast<void *>(&(alpha##d));                       \
    (beta##p) = reinterpret_cast<void *>(&(beta##d));                         \
  }

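// A rough sketch of a sparse y = A * x computation with these wrappers
// (illustrative only; the device pointers, sizes, and type codes below are
// placeholders normally supplied by compiler-generated host code):
//
//   mgpuCreateSparseEnv();
//   void *spA = mgpuCreateCsr(m, n, nnz, dRowPos, dColIdxs, dValues,
//                             ptp, itp, dtp, stream);
//   void *dnX = mgpuCreateDnVec(n, dX, dtp, stream);
//   void *dnY = mgpuCreateDnVec(m, dY, dtp, stream);
//   intptr_t sz = mgpuSpMVBufferSize(modeA, spA, dnX, dnY, ctp, stream);
//   void *buf = mgpuMemAlloc(sz, stream, /*isHostShared=*/false);
//   mgpuSpMV(modeA, spA, dnX, dnY, ctp, buf, stream);
//   // ... destroy the descriptors, free the buffer, and finally call
//   // mgpuDestroySparseEnv().
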
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseEnv() {
  // ScopedContext is for cuda initialization.
  ScopedContext scopedContext;
  assert(!cusparse_env && "client called mgpuCreateSparseEnv() twice");
  CUSPARSE_REPORT_IF_ERROR(cusparseCreate(&cusparse_env));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseEnv() {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroy(cusparse_env));
  cusparse_env = nullptr;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnVec(intptr_t size, void *values, int32_t dtp, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnVec(&vec, size, values, dTp))
  return reinterpret_cast<void *>(vec);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnVec(void *v, CUstream /*stream*/) {
  cusparseDnVecDescr_t vec = reinterpret_cast<cusparseDnVecDescr_t>(v);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnVec(vec))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateDnMat(intptr_t rows, intptr_t cols, void *values, int32_t dtp,
                CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = nullptr;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateDnMat(&mat, rows, cols, /*ld=*/cols,
                                               values, dTp, CUSPARSE_ORDER_ROW))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyDnMat(void *m, CUstream /*stream*/) {
  cusparseDnMatDescr_t mat = reinterpret_cast<cusparseDnMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroyDnMat(mat))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCoo(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowIdxs,
              void *colIdxs, void *values, int32_t itp, int32_t dtp,
              CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCoo(&mat, rows, cols, nnz, rowIdxs,
                                             colIdxs, values, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

#ifdef CUSPARSE_COO_AOS // deprecated in cuSPARSE 11.2
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCooAoS(intptr_t rows, intptr_t cols, intptr_t nnz, void *idxs,
                 void *values, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCooAoS(
      &mat, rows, cols, nnz, idxs, values, iTp, CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}
#endif // CUSPARSE_COO_AOS

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsr(intptr_t rows, intptr_t cols, intptr_t nnz, void *rowPos,
              void *colIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsr(&mat, rows, cols, nnz, rowPos,
                                             colIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateCsc(intptr_t rows, intptr_t cols, intptr_t nnz, void *colPos,
              void *rowIdxs, void *values, int32_t ptp, int32_t itp,
              int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateCsc(&mat, rows, cols, nnz, colPos,
                                             rowIdxs, values, pTp, iTp,
                                             CUSPARSE_INDEX_BASE_ZERO, dTp))
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuCreateBsr(intptr_t brows, intptr_t bcols, intptr_t bnnz, intptr_t rBsz,
              intptr_t cBsz, void *rowPos, void *colIdxs, void *values,
              int32_t ptp, int32_t itp, int32_t dtp, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = nullptr;
#if CUSPARSE_VERSION >= 12100
  auto pTp = static_cast<cusparseIndexType_t>(ptp);
  auto iTp = static_cast<cusparseIndexType_t>(itp);
  auto dTp = static_cast<cudaDataType_t>(dtp);
  CUSPARSE_REPORT_IF_ERROR(cusparseCreateBsr(
      &mat, brows, bcols, bnnz, rBsz, cBsz, rowPos, colIdxs, values, pTp, iTp,
      CUSPARSE_INDEX_BASE_ZERO, dTp, CUSPARSE_ORDER_ROW))
#endif
  return reinterpret_cast<void *>(mat);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroySpMat(void *m, CUstream /*stream*/) {
  cusparseSpMatDescr_t mat = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseDestroySpMat(mat))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpMVBufferSize(
    int32_t ma, void *a, void *x, void *y, int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV_bufferSize(
      cusparse_env, modeA, alphap, matA, vecX, betap, vecY, cTp,
      CUSPARSE_SPMV_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMV(int32_t ma, void *a, void *x,
                                                   void *y, int32_t ctp,
                                                   void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnVecDescr_t vecX = reinterpret_cast<cusparseDnVecDescr_t>(x);
  cusparseDnVecDescr_t vecY = reinterpret_cast<cusparseDnVecDescr_t>(y);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMV(cusparse_env, modeA, alphap, matA, vecX,
                                        betap, vecY, cTp,
                                        CUSPARSE_SPMV_ALG_DEFAULT, buf))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                   int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(int32_t ma, int32_t mb,
                                                   void *a, void *b, void *c,
                                                   int32_t ctp, void *buf,
                                                   CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseDnMatDescr_t matC = reinterpret_cast<cusparseDnMatDescr_t>(c);
  cudaDataType_t cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMM(cusparse_env, modeA, modeB, alphap,
                                        matA, matB, betap, matC, cTp,
                                        CUSPARSE_SPMM_ALG_DEFAULT, buf))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSDDMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                    int32_t ctp, CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t bufferSize = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
  return bufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSDDMM(int32_t ma, int32_t mb,
                                                    void *a, void *b, void *c,
                                                    int32_t ctp, void *buf,
                                                    CUstream /*stream*/) {
  assert(cusparse_env && "client did not call mgpuCreateSparseEnv()");
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(cusparse_env, modeA, modeB, alphap,
                                         matA, matB, betap, matC, cTp,
                                         CUSPARSE_SDDMM_ALG_DEFAULT, buf))
}

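// SpGEMM is a staged API: create a descriptor, run work estimation, compute,
// and finally copy the result into the output matrix C. The estimation and
// compute stages return the buffer size they need, so callers typically invoke
// them once to size a buffer and once more with the buffer supplied.
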
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = nullptr;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_createDescr(&spgemmDesc))
  return reinterpret_cast<void *>(spgemmDesc);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMDestroyDescr(void *s, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_destroyDescr(spgemmDesc))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
    void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
    intptr_t bs, void *buf, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize = bs;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
  return newBufferSize;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
                  int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  size_t newBufferSize2 = bsz2;
  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_compute(
      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize2, buf2))
  return newBufferSize2;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
               int32_t ctp, CUstream /*stream*/) {
  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
  auto cTp = static_cast<cudaDataType_t>(ctp);
  ALPHABETA(cTp, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(
      cusparseSpGEMM_copy(cusparse_env, modeA, modeB, alphap, matA, matB, betap,
                          matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSpMatGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
  cusparseConstSpMatDescr_t matDescr =
      reinterpret_cast<cusparseConstSpMatDescr_t>(m);
  int64_t *rows = reinterpret_cast<int64_t *>(r);
  int64_t *cols = reinterpret_cast<int64_t *>(c);
  int64_t *nnz = reinterpret_cast<int64_t *>(n);
  CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuSetCsrPointers(void *m, void *p, void *c, void *v, CUstream /*stream*/) {
  cusparseSpMatDescr_t matDescr = reinterpret_cast<cusparseSpMatDescr_t>(m);
  CUSPARSE_REPORT_IF_ERROR(cusparseCsrSetPointers(matDescr, p, c, v));
}

#ifdef MLIR_ENABLE_CUDA_CUSPARSELT

///
/// Wrapper methods for the cuSparseLt library.
///

struct cusparseLtSpMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  // TODO: the following three are associated with the SpMM operator rather than
  // the sparse matrix. Create workspace buffers and pass them to the SpMM
  // execution.
  cusparseLtMatmulAlgSelection_t alg_sel;
  cusparseLtMatmulPlan_t plan;
  cusparseLtMatmulDescriptor_t matmul;
  void *values{nullptr};
};

struct cusparseLtDnMatHandleAndData {
  cusparseLtMatDescriptor_t mat;
  void *values{nullptr};
};

static_assert(sizeof(cusparseLtHandle_t) == 11024,
              "Unexpected cusparseLt handle size");
static_assert(sizeof(cusparseLtSpMatHandleAndData) == 44104,
              "Unexpected cusparseLt sparse matrix handle size");
static_assert(sizeof(cusparseLtDnMatHandleAndData) == 11032,
              "Unexpected cusparseLt dense matrix handle size");

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuCreateSparseLtEnv() {
  // ScopedContext is for cuda initialization.
  ScopedContext scopedContext;
  assert(!cusparseLt_initiated &&
         "client called mgpuCreateSparseLtEnv() twice");
  // Note that cuSparseLt still uses cusparseStatus_t.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtInit(&cusparseLt_env));
  cusparseLt_initiated = true;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuDestroySparseLtEnv() {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDestroy(&cusparseLt_env));
  cusparseLt_initiated = false;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCreateCuSparseLtDnMat(void *dh, intptr_t rows, intptr_t cols, void *values,
                          int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  dnmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
      &cusparseLt_env, &(dnmat_handle->mat), rows, cols, /*lda=*/cols,
      alignment, dTp, CUSPARSE_ORDER_ROW))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtDnMat(void *dh, CUstream /*stream*/) {
  auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(dnmat_handle->mat)))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCusparseLtCreate2To4SpMat(void *sh, intptr_t rows, intptr_t cols,
                              void *values, int32_t dtp, CUstream /*stream*/) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  spmat_handle->values = values;
  auto dTp = static_cast<cudaDataType_t>(dtp);
  // Assume row-major when deciding lda.
  const uint32_t alignment = 16;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
      &cusparseLt_env, &(spmat_handle->mat), rows, cols, /*ld=*/cols, alignment,
      dTp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuDestroyCuSparseLtSpMat(void *sh, CUstream /*stream*/) {
  auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(spmat_handle->mat)))
}

// Several things are done in this stage: algorithm selection, planning, and
// returning the workspace and compressed matrix data buffer sizes.
// The parameter prune_flag indicates whether pruning and the pruning check
// should happen: 0 means neither prune nor prune check, 1 means prune, and
// 2 means prune and prune check.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
                             void *c, int32_t ctp, int32_t prune_flag,
                             CUstream stream) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  // TODO: support more advanced settings, e.g., a sparse right-hand operand;
  // for now matA is assumed to be the sparse matrix.
  auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
  auto workspace_size = reinterpret_cast<size_t *>(bs);
  auto compressed_size = &(reinterpret_cast<size_t *>(bs)[1]);
  auto compressed_buffer_size = &(reinterpret_cast<size_t *>(bs)[2]);
  auto cTp = static_cast<cusparseComputeType>(ctp);

  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulDescriptorInit(
      &cusparseLt_env, &(matA->matmul), modeA, modeB, &(matA->mat),
      &(matB->mat), &(matC->mat), &(matC->mat), cTp))
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSelectionInit(
      &cusparseLt_env, &(matA->alg_sel), &(matA->matmul),
      CUSPARSELT_MATMUL_ALG_DEFAULT))
  int alg = 0;
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSetAttribute(
      &cusparseLt_env, &(matA->alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
      sizeof(alg)))

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
      &cusparseLt_env, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))

  // Pruning step (in-place).
  if (prune_flag > 0)
    CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPrune(
        &cusparseLt_env, &(matA->matmul), matA->values, matA->values,
        CUSPARSELT_PRUNE_SPMMA_STRIP, stream))

  // Check structure of A.
  // Note that this adds a synchronization on the stream.
  // TODO: Do we want that?
  if (prune_flag == 2) {
    int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
    CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
        &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
    int valid = 0;
    mgpuMemcpy(&valid, dvalid, sizeof(int), stream);
    mgpuStreamSynchronize(stream);
    mgpuMemFree(dvalid, stream);
    if (valid != 0)
      fprintf(stderr, "CUSPARSE-LT: sparse matrix is not 2:4; computed results "
                      "will be invalid\n");
  }

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulGetWorkspace(
      &cusparseLt_env, &(matA->plan), workspace_size))
  CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize(
      &cusparseLt_env, &(matA->plan), compressed_size, compressed_buffer_size))
}

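/// Performs the 2:4 structured-sparse matmul D = A * B + C (with C == D):
/// compresses A into the provided buffers, runs cusparseLtMatmul, and then
/// destroys the sparse-matrix descriptor and the plan.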
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuCuSparseLtSpMM(void *a, void *b, void *c, void *d_workspace,
                   void *dA_compressed, void *dA_compressedBuffer,
                   CUstream stream) {
  assert(cusparseLt_initiated && "client did not call mgpuCreateSparseLtEnv()");
  auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);

  ALPHABETA(CUDA_R_32F, alpha, beta)
  CUSPARSE_REPORT_IF_ERROR(
      cusparseLtSpMMACompress(&cusparseLt_env, &(matA->plan), (matA->values),
                              dA_compressed, dA_compressedBuffer, stream))

  // TODO: add support for multi-stream execution.
  // Perform the matrix multiplication D = A*B+C, using C == D for now.
  CUSPARSE_REPORT_IF_ERROR(
      cusparseLtMatmul(&cusparseLt_env, &(matA->plan), alphap, dA_compressed,
                       matB->values, betap, matC->values,
                       /*dD*/ matC->values, d_workspace, nullptr, 0))

  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matA->mat)))
  // Destroy the plan associated with the sparse matrix.
  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matA->plan)))
}

#endif // MLIR_ENABLE_CUDA_CUSPARSELT
#endif // MLIR_ENABLE_CUDA_CUSPARSE