//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C runtime wrappers around the VulkanRuntime.
//
//===----------------------------------------------------------------------===//
#include "VulkanRuntime.h"

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <mutex>
// Explicitly export entry points to the vulkan-runtime-wrapper.
//
// The two export spellings are platform specific and mutually exclusive: with
// two unguarded #defines the second one overrides the first everywhere (and
// MSVC does not understand `__attribute__`), so pick exactly one per platform.
#ifdef _WIN32
#define VULKAN_WRAPPER_SYMBOL_EXPORT __declspec(dllexport)
#else
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
29 class VulkanRuntimeManager
{
31 VulkanRuntimeManager() = default;
32 VulkanRuntimeManager(const VulkanRuntimeManager
&) = delete;
33 VulkanRuntimeManager
operator=(const VulkanRuntimeManager
&) = delete;
34 ~VulkanRuntimeManager() = default;
36 void setResourceData(DescriptorSetIndex setIndex
, BindingIndex bindIndex
,
37 const VulkanHostMemoryBuffer
&memBuffer
) {
38 std::lock_guard
<std::mutex
> lock(mutex
);
39 vulkanRuntime
.setResourceData(setIndex
, bindIndex
, memBuffer
);
42 void setEntryPoint(const char *entryPoint
) {
43 std::lock_guard
<std::mutex
> lock(mutex
);
44 vulkanRuntime
.setEntryPoint(entryPoint
);
47 void setNumWorkGroups(NumWorkGroups numWorkGroups
) {
48 std::lock_guard
<std::mutex
> lock(mutex
);
49 vulkanRuntime
.setNumWorkGroups(numWorkGroups
);
52 void setShaderModule(uint8_t *shader
, uint32_t size
) {
53 std::lock_guard
<std::mutex
> lock(mutex
);
54 vulkanRuntime
.setShaderModule(shader
, size
);
58 std::lock_guard
<std::mutex
> lock(mutex
);
59 if (failed(vulkanRuntime
.initRuntime()) || failed(vulkanRuntime
.run()) ||
60 failed(vulkanRuntime
.updateHostMemoryBuffers()) ||
61 failed(vulkanRuntime
.destroy())) {
62 std::cerr
<< "runOnVulkan failed";
67 VulkanRuntime vulkanRuntime
;
73 template <typename T
, int N
>
74 struct MemRefDescriptor
{
82 template <typename T
, uint32_t S
>
83 void bindMemRef(void *vkRuntimeManager
, DescriptorSetIndex setIndex
,
84 BindingIndex bindIndex
, MemRefDescriptor
<T
, S
> *ptr
) {
85 uint32_t size
= sizeof(T
);
86 for (unsigned i
= 0; i
< S
; i
++)
87 size
*= ptr
->sizes
[i
];
88 VulkanHostMemoryBuffer memBuffer
{ptr
->aligned
, size
};
89 reinterpret_cast<VulkanRuntimeManager
*>(vkRuntimeManager
)
90 ->setResourceData(setIndex
, bindIndex
, memBuffer
);
94 /// Initializes `VulkanRuntimeManager` and returns a pointer to it.
95 VULKAN_WRAPPER_SYMBOL_EXPORT
void *initVulkan() {
96 return new VulkanRuntimeManager();
99 /// Deinitializes `VulkanRuntimeManager` by the given pointer.
100 VULKAN_WRAPPER_SYMBOL_EXPORT
void deinitVulkan(void *vkRuntimeManager
) {
101 delete reinterpret_cast<VulkanRuntimeManager
*>(vkRuntimeManager
);
104 VULKAN_WRAPPER_SYMBOL_EXPORT
void runOnVulkan(void *vkRuntimeManager
) {
105 reinterpret_cast<VulkanRuntimeManager
*>(vkRuntimeManager
)->runOnVulkan();
108 VULKAN_WRAPPER_SYMBOL_EXPORT
void setEntryPoint(void *vkRuntimeManager
,
109 const char *entryPoint
) {
110 reinterpret_cast<VulkanRuntimeManager
*>(vkRuntimeManager
)
111 ->setEntryPoint(entryPoint
);
114 VULKAN_WRAPPER_SYMBOL_EXPORT
void
115 setNumWorkGroups(void *vkRuntimeManager
, uint32_t x
, uint32_t y
, uint32_t z
) {
116 reinterpret_cast<VulkanRuntimeManager
*>(vkRuntimeManager
)
117 ->setNumWorkGroups({x
, y
, z
});
120 VULKAN_WRAPPER_SYMBOL_EXPORT
void
121 setBinaryShader(void *vkRuntimeManager
, uint8_t *shader
, uint32_t size
) {
122 reinterpret_cast<VulkanRuntimeManager
*>(vkRuntimeManager
)
123 ->setShaderModule(shader
, size
);
126 /// Binds the given memref to the given descriptor set and descriptor
128 #define DECLARE_BIND_MEMREF(size, type, typeName) \
129 VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName( \
130 void *vkRuntimeManager, DescriptorSetIndex setIndex, \
131 BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) { \
132 bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr); \
135 DECLARE_BIND_MEMREF(1, float, Float
)
136 DECLARE_BIND_MEMREF(2, float, Float
)
137 DECLARE_BIND_MEMREF(3, float, Float
)
138 DECLARE_BIND_MEMREF(1, int32_t, Int32
)
139 DECLARE_BIND_MEMREF(2, int32_t, Int32
)
140 DECLARE_BIND_MEMREF(3, int32_t, Int32
)
141 DECLARE_BIND_MEMREF(1, int16_t, Int16
)
142 DECLARE_BIND_MEMREF(2, int16_t, Int16
)
143 DECLARE_BIND_MEMREF(3, int16_t, Int16
)
144 DECLARE_BIND_MEMREF(1, int8_t, Int8
)
145 DECLARE_BIND_MEMREF(2, int8_t, Int8
)
146 DECLARE_BIND_MEMREF(3, int8_t, Int8
)
147 DECLARE_BIND_MEMREF(1, int16_t, Half
)
148 DECLARE_BIND_MEMREF(2, int16_t, Half
)
149 DECLARE_BIND_MEMREF(3, int16_t, Half
)
151 /// Fills the given 1D float memref with the given float value.
152 VULKAN_WRAPPER_SYMBOL_EXPORT
void
153 _mlir_ciface_fillResource1DFloat(MemRefDescriptor
<float, 1> *ptr
, // NOLINT
155 std::fill_n(ptr
->allocated
, ptr
->sizes
[0], value
);
158 /// Fills the given 2D float memref with the given float value.
159 VULKAN_WRAPPER_SYMBOL_EXPORT
void
160 _mlir_ciface_fillResource2DFloat(MemRefDescriptor
<float, 2> *ptr
, // NOLINT
162 std::fill_n(ptr
->allocated
, ptr
->sizes
[0] * ptr
->sizes
[1], value
);
165 /// Fills the given 3D float memref with the given float value.
166 VULKAN_WRAPPER_SYMBOL_EXPORT
void
167 _mlir_ciface_fillResource3DFloat(MemRefDescriptor
<float, 3> *ptr
, // NOLINT
169 std::fill_n(ptr
->allocated
, ptr
->sizes
[0] * ptr
->sizes
[1] * ptr
->sizes
[2],
173 /// Fills the given 1D int memref with the given int value.
174 VULKAN_WRAPPER_SYMBOL_EXPORT
void
175 _mlir_ciface_fillResource1DInt(MemRefDescriptor
<int32_t, 1> *ptr
, // NOLINT
177 std::fill_n(ptr
->allocated
, ptr
->sizes
[0], value
);
180 /// Fills the given 2D int memref with the given int value.
181 VULKAN_WRAPPER_SYMBOL_EXPORT
void
182 _mlir_ciface_fillResource2DInt(MemRefDescriptor
<int32_t, 2> *ptr
, // NOLINT
184 std::fill_n(ptr
->allocated
, ptr
->sizes
[0] * ptr
->sizes
[1], value
);
187 /// Fills the given 3D int memref with the given int value.
188 VULKAN_WRAPPER_SYMBOL_EXPORT
void
189 _mlir_ciface_fillResource3DInt(MemRefDescriptor
<int32_t, 3> *ptr
, // NOLINT
191 std::fill_n(ptr
->allocated
, ptr
->sizes
[0] * ptr
->sizes
[1] * ptr
->sizes
[2],
195 /// Fills the given 1D int memref with the given int8 value.
196 VULKAN_WRAPPER_SYMBOL_EXPORT
void
197 _mlir_ciface_fillResource1DInt8(MemRefDescriptor
<int8_t, 1> *ptr
, // NOLINT
199 std::fill_n(ptr
->allocated
, ptr
->sizes
[0], value
);
202 /// Fills the given 2D int memref with the given int8 value.
203 VULKAN_WRAPPER_SYMBOL_EXPORT
void
204 _mlir_ciface_fillResource2DInt8(MemRefDescriptor
<int8_t, 2> *ptr
, // NOLINT
206 std::fill_n(ptr
->allocated
, ptr
->sizes
[0] * ptr
->sizes
[1], value
);
209 /// Fills the given 3D int memref with the given int8 value.
210 VULKAN_WRAPPER_SYMBOL_EXPORT
void
211 _mlir_ciface_fillResource3DInt8(MemRefDescriptor
<int8_t, 3> *ptr
, // NOLINT
213 std::fill_n(ptr
->allocated
, ptr
->sizes
[0] * ptr
->sizes
[1] * ptr
->sizes
[2],