1 //===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// Implements C wrappers around the ROCM library for easy linking in ORC jit.
// Also adds some debugging helpers that are helpful when writing MLIR code to
// run on GPUs.
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

#include "hip/hip_runtime.h"
/// Evaluates a HIP runtime call `expr` and, if it returned an error, prints
/// the stringified call and the HIP error name to stderr. Errors are only
/// reported, never propagated: execution continues after the report.
///
/// The check is an immediately-invoked lambda, so the whole macro is a single
/// expression of type void. That keeps `return HIP_REPORT_IF_ERROR(...);`
/// valid inside void functions.
#define HIP_REPORT_IF_ERROR(expr)                                              \
  [](hipError_t result) {                                                      \
    if (result == hipSuccess)                                                  \
      return;                                                                  \
    const char *name = hipGetErrorName(result);                                \
    if (!name)                                                                 \
      name = "<unknown>";                                                      \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                   \
  }(expr)
/// Device index most recently passed to mgpuSetDefaultDevice on the calling
/// thread. Every thread has its own copy, and each copy starts at device 0.
static thread_local int32_t defaultDevice = 0;
35 extern "C" hipModule_t
mgpuModuleLoad(void *data
, size_t /*gpuBlobSize*/) {
36 hipModule_t module
= nullptr;
37 HIP_REPORT_IF_ERROR(hipModuleLoadData(&module
, data
));
41 extern "C" hipModule_t
mgpuModuleLoadJIT(void *data
, int optLevel
) {
42 assert(false && "This function is not available in HIP.");
46 extern "C" void mgpuModuleUnload(hipModule_t module
) {
47 HIP_REPORT_IF_ERROR(hipModuleUnload(module
));
50 extern "C" hipFunction_t
mgpuModuleGetFunction(hipModule_t module
,
52 hipFunction_t function
= nullptr;
53 HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function
, module
, name
));
57 // The wrapper uses intptr_t instead of ROCM's unsigned int to match
58 // the type of MLIR's index type. This avoids the need for casts in the
59 // generated MLIR code.
60 extern "C" void mgpuLaunchKernel(hipFunction_t function
, intptr_t gridX
,
61 intptr_t gridY
, intptr_t gridZ
,
62 intptr_t blockX
, intptr_t blockY
,
63 intptr_t blockZ
, int32_t smem
,
64 hipStream_t stream
, void **params
,
65 void **extra
, size_t /*paramsCount*/) {
66 HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function
, gridX
, gridY
, gridZ
,
67 blockX
, blockY
, blockZ
, smem
,
68 stream
, params
, extra
));
71 extern "C" hipStream_t
mgpuStreamCreate() {
72 hipStream_t stream
= nullptr;
73 HIP_REPORT_IF_ERROR(hipStreamCreate(&stream
));
77 extern "C" void mgpuStreamDestroy(hipStream_t stream
) {
78 HIP_REPORT_IF_ERROR(hipStreamDestroy(stream
));
81 extern "C" void mgpuStreamSynchronize(hipStream_t stream
) {
82 return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream
));
85 extern "C" void mgpuStreamWaitEvent(hipStream_t stream
, hipEvent_t event
) {
86 HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream
, event
, /*flags=*/0));
89 extern "C" hipEvent_t
mgpuEventCreate() {
90 hipEvent_t event
= nullptr;
91 HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event
, hipEventDisableTiming
));
95 extern "C" void mgpuEventDestroy(hipEvent_t event
) {
96 HIP_REPORT_IF_ERROR(hipEventDestroy(event
));
99 extern "C" void mgpuEventSynchronize(hipEvent_t event
) {
100 HIP_REPORT_IF_ERROR(hipEventSynchronize(event
));
103 extern "C" void mgpuEventRecord(hipEvent_t event
, hipStream_t stream
) {
104 HIP_REPORT_IF_ERROR(hipEventRecord(event
, stream
));
107 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes
, hipStream_t
/*stream*/,
108 bool /*isHostShared*/) {
110 HIP_REPORT_IF_ERROR(hipMalloc(&ptr
, sizeBytes
));
114 extern "C" void mgpuMemFree(void *ptr
, hipStream_t
/*stream*/) {
115 HIP_REPORT_IF_ERROR(hipFree(ptr
));
118 extern "C" void mgpuMemcpy(void *dst
, void *src
, size_t sizeBytes
,
119 hipStream_t stream
) {
121 hipMemcpyAsync(dst
, src
, sizeBytes
, hipMemcpyDefault
, stream
));
124 extern "C" void mgpuMemset32(void *dst
, int value
, size_t count
,
125 hipStream_t stream
) {
126 HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t
>(dst
),
127 value
, count
, stream
));
130 extern "C" void mgpuMemset16(void *dst
, int short value
, size_t count
,
131 hipStream_t stream
) {
132 HIP_REPORT_IF_ERROR(hipMemsetD16Async(reinterpret_cast<hipDeviceptr_t
>(dst
),
133 value
, count
, stream
));
136 /// Helper functions for writing mlir example code
138 // Allows to register byte array with the ROCM runtime. Helpful until we have
139 // transfer functions implemented.
140 extern "C" void mgpuMemHostRegister(void *ptr
, uint64_t sizeBytes
) {
141 HIP_REPORT_IF_ERROR(hipHostRegister(ptr
, sizeBytes
, /*flags=*/0));
144 // Allows to register a MemRef with the ROCm runtime. Helpful until we have
145 // transfer functions implemented.
147 mgpuMemHostRegisterMemRef(int64_t rank
, StridedMemRefType
<char, 1> *descriptor
,
148 int64_t elementSizeBytes
) {
150 llvm::SmallVector
<int64_t, 4> denseStrides(rank
);
151 llvm::ArrayRef
<int64_t> sizes(descriptor
->sizes
, rank
);
152 llvm::ArrayRef
<int64_t> strides(sizes
.end(), rank
);
154 std::partial_sum(sizes
.rbegin(), sizes
.rend(), denseStrides
.rbegin(),
155 std::multiplies
<int64_t>());
156 auto sizeBytes
= denseStrides
.front() * elementSizeBytes
;
158 // Only densely packed tensors are currently supported.
159 std::rotate(denseStrides
.begin(), denseStrides
.begin() + 1,
161 denseStrides
.back() = 1;
162 assert(strides
== llvm::ArrayRef(denseStrides
));
164 auto ptr
= descriptor
->data
+ descriptor
->offset
* elementSizeBytes
;
165 mgpuMemHostRegister(ptr
, sizeBytes
);
168 // Allows to unregister byte array with the ROCM runtime. Helpful until we have
169 // transfer functions implemented.
170 extern "C" void mgpuMemHostUnregister(void *ptr
) {
171 HIP_REPORT_IF_ERROR(hipHostUnregister(ptr
));
174 // Allows to unregister a MemRef with the ROCm runtime. Helpful until we have
175 // transfer functions implemented.
177 mgpuMemHostUnregisterMemRef(int64_t rank
,
178 StridedMemRefType
<char, 1> *descriptor
,
179 int64_t elementSizeBytes
) {
180 auto ptr
= descriptor
->data
+ descriptor
->offset
* elementSizeBytes
;
181 mgpuMemHostUnregister(ptr
);
184 template <typename T
>
185 void mgpuMemGetDevicePointer(T
*hostPtr
, T
**devicePtr
) {
186 HIP_REPORT_IF_ERROR(hipSetDevice(0));
188 hipHostGetDevicePointer((void **)devicePtr
, hostPtr
, /*flags=*/0));
191 extern "C" StridedMemRefType
<float, 1>
192 mgpuMemGetDeviceMemRef1dFloat(float *allocated
, float *aligned
, int64_t offset
,
193 int64_t size
, int64_t stride
) {
194 float *devicePtr
= nullptr;
195 mgpuMemGetDevicePointer(aligned
, &devicePtr
);
196 return {devicePtr
, devicePtr
, offset
, {size
}, {stride
}};
199 extern "C" StridedMemRefType
<int32_t, 1>
200 mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated
, int32_t *aligned
,
201 int64_t offset
, int64_t size
, int64_t stride
) {
202 int32_t *devicePtr
= nullptr;
203 mgpuMemGetDevicePointer(aligned
, &devicePtr
);
204 return {devicePtr
, devicePtr
, offset
, {size
}, {stride
}};
207 extern "C" void mgpuSetDefaultDevice(int32_t device
) {
208 defaultDevice
= device
;
209 HIP_REPORT_IF_ERROR(hipSetDevice(device
));