1 //===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file declares Vulkan runtime API.
11 //===----------------------------------------------------------------------===//
13 #ifndef VULKAN_RUNTIME_H
14 #define VULKAN_RUNTIME_H
16 #include "mlir/Support/LogicalResult.h"
18 #include <unordered_map>
20 #include <vulkan/vulkan.h>
24 using DescriptorSetIndex
= uint32_t;
25 using BindingIndex
= uint32_t;
27 /// Struct containing information regarding to a device memory buffer.
28 struct VulkanDeviceMemoryBuffer
{
29 BindingIndex bindingIndex
{0};
30 VkDescriptorType descriptorType
{VK_DESCRIPTOR_TYPE_MAX_ENUM
};
31 VkDescriptorBufferInfo bufferInfo
{};
32 VkBuffer hostBuffer
{VK_NULL_HANDLE
};
33 VkDeviceMemory hostMemory
{VK_NULL_HANDLE
};
34 VkBuffer deviceBuffer
{VK_NULL_HANDLE
};
35 VkDeviceMemory deviceMemory
{VK_NULL_HANDLE
};
36 uint32_t bufferSize
{0};
39 /// Struct containing information regarding to a host memory buffer.
40 struct VulkanHostMemoryBuffer
{
41 /// Pointer to a host memory.
43 /// Size of a host memory in bytes.
47 /// Struct containing the number of local workgroups to dispatch for each
49 struct NumWorkGroups
{
55 /// Struct containing information regarding a descriptor set.
56 struct DescriptorSetInfo
{
57 /// Index of a descriptor set in descriptor sets.
58 DescriptorSetIndex descriptorSet
{0};
59 /// Number of descriptors in a set.
60 uint32_t descriptorSize
{0};
61 /// Type of a descriptor set.
62 VkDescriptorType descriptorType
{VK_DESCRIPTOR_TYPE_MAX_ENUM
};
65 /// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
66 using ResourceData
= std::unordered_map
<
68 std::unordered_map
<BindingIndex
, VulkanHostMemoryBuffer
>>;
70 /// SPIR-V storage classes.
71 /// Note that this duplicates spirv::StorageClass but it keeps the Vulkan
72 /// runtime library detached from SPIR-V dialect, so we can avoid pick up lots
74 enum class SPIRVStorageClass
{
79 /// StorageClass mapped into a descriptor set and a binding.
80 using ResourceStorageClassBindingMap
=
81 std::unordered_map
<DescriptorSetIndex
,
82 std::unordered_map
<BindingIndex
, SPIRVStorageClass
>>;
85 /// The purpose of this class is to run SPIR-V compute shader on Vulkan
87 /// Before the run, user must provide and set resource data with descriptors,
88 /// SPIR-V shader, number of work groups and entry point. After the creation of
89 /// VulkanRuntime, special methods must be called in the following
90 /// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
91 /// each method in the sequence returns success or failure depends on the Vulkan
95 explicit VulkanRuntime() = default;
96 VulkanRuntime(const VulkanRuntime
&) = delete;
97 VulkanRuntime
&operator=(const VulkanRuntime
&) = delete;
99 /// Sets needed data for Vulkan runtime.
100 void setResourceData(const ResourceData
&resData
);
101 void setResourceData(const DescriptorSetIndex desIndex
,
102 const BindingIndex bindIndex
,
103 const VulkanHostMemoryBuffer
&hostMemBuffer
);
104 void setShaderModule(uint8_t *shader
, uint32_t size
);
105 void setNumWorkGroups(const NumWorkGroups
&numberWorkGroups
);
106 void setResourceStorageClassBindingMap(
107 const ResourceStorageClassBindingMap
&stClassData
);
108 void setEntryPoint(const char *entryPointName
);
110 /// Runtime initialization.
111 LogicalResult
initRuntime();
116 /// Updates host memory buffers.
117 LogicalResult
updateHostMemoryBuffers();
119 /// Destroys all created vulkan objects and resources.
120 LogicalResult
destroy();
123 //===--------------------------------------------------------------------===//
124 // Pipeline creation methods.
125 //===--------------------------------------------------------------------===//
127 LogicalResult
createInstance();
128 LogicalResult
createDevice();
129 LogicalResult
getBestComputeQueue();
130 LogicalResult
createMemoryBuffers();
131 LogicalResult
createShaderModule();
132 void initDescriptorSetLayoutBindingMap();
133 LogicalResult
createDescriptorSetLayout();
134 LogicalResult
createPipelineLayout();
135 LogicalResult
createComputePipeline();
136 LogicalResult
createDescriptorPool();
137 LogicalResult
allocateDescriptorSets();
138 LogicalResult
setWriteDescriptors();
139 LogicalResult
createCommandPool();
140 LogicalResult
createQueryPool();
141 LogicalResult
createComputeCommandBuffer();
142 LogicalResult
submitCommandBuffersToQueue();
143 // Copy resources from host (staging buffer) to device buffer or from device
144 // buffer to host buffer.
145 LogicalResult
copyResource(bool deviceToHost
);
147 //===--------------------------------------------------------------------===//
149 //===--------------------------------------------------------------------===//
151 /// Maps storage class to a descriptor type.
153 mapStorageClassToDescriptorType(SPIRVStorageClass storageClass
,
154 VkDescriptorType
&descriptorType
);
156 /// Maps storage class to buffer usage flags.
158 mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass
,
159 VkBufferUsageFlagBits
&bufferUsage
);
161 LogicalResult
countDeviceMemorySize();
163 //===--------------------------------------------------------------------===//
165 //===--------------------------------------------------------------------===//
167 VkInstance instance
{VK_NULL_HANDLE
};
168 VkPhysicalDevice physicalDevice
{VK_NULL_HANDLE
};
169 VkDevice device
{VK_NULL_HANDLE
};
170 VkQueue queue
{VK_NULL_HANDLE
};
172 /// Specifies VulkanDeviceMemoryBuffers divided into sets.
173 std::unordered_map
<DescriptorSetIndex
, std::vector
<VulkanDeviceMemoryBuffer
>>
174 deviceMemoryBufferMap
;
176 /// Specifies shader module.
177 VkShaderModule shaderModule
{VK_NULL_HANDLE
};
179 /// Specifies layout bindings.
180 std::unordered_map
<DescriptorSetIndex
,
181 std::vector
<VkDescriptorSetLayoutBinding
>>
182 descriptorSetLayoutBindingMap
;
184 /// Specifies layouts of descriptor sets.
185 std::vector
<VkDescriptorSetLayout
> descriptorSetLayouts
;
186 VkPipelineLayout pipelineLayout
{VK_NULL_HANDLE
};
188 /// Specifies descriptor sets.
189 std::vector
<VkDescriptorSet
> descriptorSets
;
191 /// Specifies a pool of descriptor set info, each descriptor set must have
192 /// information such as type, index and amount of bindings.
193 std::vector
<DescriptorSetInfo
> descriptorSetInfoPool
;
194 VkDescriptorPool descriptorPool
{VK_NULL_HANDLE
};
197 VkQueryPool queryPool
{VK_NULL_HANDLE
};
198 // Number of nonoseconds for timestamp to increase 1
199 float timestampPeriod
{0.f
};
201 /// Computation pipeline.
202 VkPipeline pipeline
{VK_NULL_HANDLE
};
203 VkCommandPool commandPool
{VK_NULL_HANDLE
};
204 std::vector
<VkCommandBuffer
> commandBuffers
;
206 //===--------------------------------------------------------------------===//
207 // Vulkan memory context.
208 //===--------------------------------------------------------------------===//
210 uint32_t queueFamilyIndex
{0};
211 VkQueueFamilyProperties queueFamilyProperties
{};
212 uint32_t hostMemoryTypeIndex
{VK_MAX_MEMORY_TYPES
};
213 uint32_t deviceMemoryTypeIndex
{VK_MAX_MEMORY_TYPES
};
214 VkDeviceSize memorySize
{0};
216 //===--------------------------------------------------------------------===//
217 // Vulkan execution context.
218 //===--------------------------------------------------------------------===//
220 NumWorkGroups numWorkGroups
;
221 const char *entryPoint
{nullptr};
222 uint8_t *binary
{nullptr};
223 uint32_t binarySize
{0};
225 //===--------------------------------------------------------------------===//
226 // Vulkan resource data and storage classes.
227 //===--------------------------------------------------------------------===//
229 ResourceData resourceData
;
230 ResourceStorageClassBindingMap resourceStorageClassData
;