1 //===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // Implements wrappers around the sycl runtime library with C linkage
11 //===----------------------------------------------------------------------===//
13 #include <CL/sycl.hpp>
14 #include <level_zero/ze_api.h>
15 #include <sycl/ext/oneapi/backend/level_zero.hpp>
18 #define SYCL_RUNTIME_EXPORT __declspec(dllexport)
20 #define SYCL_RUNTIME_EXPORT
/// Flushes stdout so that an already-printed diagnostic is not lost, then
/// terminates the process. Never returns.
[[noreturn]] static void flushAndAbort() {
  fflush(stdout);
  abort();
}

/// Invokes `func` and returns whatever it returns.
///
/// Exceptions must not escape through the C-linkage entry points that use this
/// helper, so any exception thrown by `func` is reported on stdout and the
/// process is aborted instead of unwinding further.
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &ex) {
    fprintf(stdout, "An exception was thrown: %s\n", ex.what());
  } catch (...) {
    fprintf(stdout, "An unknown exception was thrown\n");
  }
  // Only reached from one of the handlers above.
  flushAndAbort();
}
/// Evaluates the Level Zero API expression `call` exactly once. If the result
/// is anything other than ZE_RESULT_SUCCESS, prints the numeric result code to
/// stdout, flushes stdout and aborts the process.
///
/// The body is wrapped in `do { ... } while (0)` so that the expansion is a
/// single statement that needs a trailing semicolon and can be used safely in
/// an unbraced if/else. The result code is converted to `int` explicitly
/// because it is passed through varargs to a `%d` conversion.
#define L0_SAFE_CALL(call)                                                     \
  do {                                                                         \
    ze_result_t l0SafeCallStatus = (call);                                     \
    if (l0SafeCallStatus != ZE_RESULT_SUCCESS) {                               \
      fprintf(stdout, "L0 error %d\n", static_cast<int>(l0SafeCallStatus));    \
      fflush(stdout);                                                          \
      abort();                                                                 \
    }                                                                          \
  } while (0)
52 static sycl::device
getDefaultDevice() {
53 static sycl::device syclDevice
;
54 static bool isDeviceInitialised
= false;
55 if (!isDeviceInitialised
) {
56 auto platformList
= sycl::platform::get_platforms();
57 for (const auto &platform
: platformList
) {
58 auto platformName
= platform
.get_info
<sycl::info::platform::name
>();
59 bool isLevelZero
= platformName
.find("Level-Zero") != std::string::npos
;
63 syclDevice
= platform
.get_devices()[0];
64 isDeviceInitialised
= true;
67 throw std::runtime_error("getDefaultDevice failed");
72 static sycl::context
getDefaultContext() {
73 static sycl::context syclContext
{getDefaultDevice()};
77 static void *allocDeviceMemory(sycl::queue
*queue
, size_t size
, bool isShared
) {
78 void *memPtr
= nullptr;
80 memPtr
= sycl::aligned_alloc_shared(64, size
, getDefaultDevice(),
83 memPtr
= sycl::aligned_alloc_device(64, size
, getDefaultDevice(),
86 if (memPtr
== nullptr) {
87 throw std::runtime_error("mem allocation failed!");
92 static void deallocDeviceMemory(sycl::queue
*queue
, void *ptr
) {
93 sycl::free(ptr
, *queue
);
96 static ze_module_handle_t
loadModule(const void *data
, size_t dataSize
) {
98 ze_module_handle_t zeModule
;
99 ze_module_desc_t desc
= {ZE_STRUCTURE_TYPE_MODULE_DESC
,
101 ZE_MODULE_FORMAT_IL_SPIRV
,
103 (const uint8_t *)data
,
106 auto zeDevice
= sycl::get_native
<sycl::backend::ext_oneapi_level_zero
>(
108 auto zeContext
= sycl::get_native
<sycl::backend::ext_oneapi_level_zero
>(
109 getDefaultContext());
110 L0_SAFE_CALL(zeModuleCreate(zeContext
, zeDevice
, &desc
, &zeModule
, nullptr));
114 static sycl::kernel
*getKernel(ze_module_handle_t zeModule
, const char *name
) {
117 ze_kernel_handle_t zeKernel
;
118 ze_kernel_desc_t desc
= {};
119 desc
.pKernelName
= name
;
121 L0_SAFE_CALL(zeKernelCreate(zeModule
, &desc
, &zeKernel
));
122 sycl::kernel_bundle
<sycl::bundle_state::executable
> kernelBundle
=
123 sycl::make_kernel_bundle
<sycl::backend::ext_oneapi_level_zero
,
124 sycl::bundle_state::executable
>(
125 {zeModule
}, getDefaultContext());
127 auto kernel
= sycl::make_kernel
<sycl::backend::ext_oneapi_level_zero
>(
128 {kernelBundle
, zeKernel
}, getDefaultContext());
129 return new sycl::kernel(kernel
);
132 static void launchKernel(sycl::queue
*queue
, sycl::kernel
*kernel
, size_t gridX
,
133 size_t gridY
, size_t gridZ
, size_t blockX
,
134 size_t blockY
, size_t blockZ
, size_t sharedMemBytes
,
135 void **params
, size_t paramsCount
) {
136 auto syclGlobalRange
=
137 sycl::range
<3>(blockZ
* gridZ
, blockY
* gridY
, blockX
* gridX
);
138 auto syclLocalRange
= sycl::range
<3>(blockZ
, blockY
, blockX
);
139 sycl::nd_range
<3> syclNdRange(syclGlobalRange
, syclLocalRange
);
141 queue
->submit([&](sycl::handler
&cgh
) {
142 for (size_t i
= 0; i
< paramsCount
; i
++) {
143 cgh
.set_arg(static_cast<uint32_t>(i
), *(static_cast<void **>(params
[i
])));
145 cgh
.parallel_for(syclNdRange
, *kernel
);
151 extern "C" SYCL_RUNTIME_EXPORT
sycl::queue
*mgpuStreamCreate() {
153 return catchAll([&]() {
155 new sycl::queue(getDefaultContext(), getDefaultDevice());
160 extern "C" SYCL_RUNTIME_EXPORT
void mgpuStreamDestroy(sycl::queue
*queue
) {
161 catchAll([&]() { delete queue
; });
164 extern "C" SYCL_RUNTIME_EXPORT
void *
165 mgpuMemAlloc(uint64_t size
, sycl::queue
*queue
, bool isShared
) {
166 return catchAll([&]() {
167 return allocDeviceMemory(queue
, static_cast<size_t>(size
), true);
171 extern "C" SYCL_RUNTIME_EXPORT
void mgpuMemFree(void *ptr
, sycl::queue
*queue
) {
174 deallocDeviceMemory(queue
, ptr
);
179 extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
180 mgpuModuleLoad(const void *data
, size_t gpuBlobSize
) {
181 return catchAll([&]() { return loadModule(data
, gpuBlobSize
); });
184 extern "C" SYCL_RUNTIME_EXPORT
sycl::kernel
*
185 mgpuModuleGetFunction(ze_module_handle_t module
, const char *name
) {
186 return catchAll([&]() { return getKernel(module
, name
); });
189 extern "C" SYCL_RUNTIME_EXPORT
void
190 mgpuLaunchKernel(sycl::kernel
*kernel
, size_t gridX
, size_t gridY
, size_t gridZ
,
191 size_t blockX
, size_t blockY
, size_t blockZ
,
192 size_t sharedMemBytes
, sycl::queue
*queue
, void **params
,
193 void ** /*extra*/, size_t paramsCount
) {
194 return catchAll([&]() {
195 launchKernel(queue
, kernel
, gridX
, gridY
, gridZ
, blockX
, blockY
, blockZ
,
196 sharedMemBytes
, params
, paramsCount
);
200 extern "C" SYCL_RUNTIME_EXPORT
void mgpuStreamSynchronize(sycl::queue
*queue
) {
202 catchAll([&]() { queue
->wait(); });
205 extern "C" SYCL_RUNTIME_EXPORT
void
206 mgpuModuleUnload(ze_module_handle_t module
) {
208 catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module
)); });