[RISCV][NFC] precommit for D159399
[llvm-project.git] / mlir / tools / mlir-vulkan-runner / VulkanRuntime.h
blob9fa52b00a0ac917a37bfc5fc81130acc0fbd771b
//===- VulkanRuntime.h - MLIR Vulkan runtime --------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file declares the Vulkan runtime API.
11 //===----------------------------------------------------------------------===//
13 #ifndef VULKAN_RUNTIME_H
14 #define VULKAN_RUNTIME_H
16 #include "mlir/Support/LogicalResult.h"
18 #include <unordered_map>
19 #include <vector>
20 #include <vulkan/vulkan.h>
22 using namespace mlir;
/// Integer type used to identify a descriptor set.
typedef uint32_t DescriptorSetIndex;
/// Integer type used to identify a binding inside a descriptor set.
typedef uint32_t BindingIndex;
27 /// Struct containing information regarding to a device memory buffer.
28 struct VulkanDeviceMemoryBuffer {
29 BindingIndex bindingIndex{0};
30 VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
31 VkDescriptorBufferInfo bufferInfo{};
32 VkBuffer hostBuffer{VK_NULL_HANDLE};
33 VkDeviceMemory hostMemory{VK_NULL_HANDLE};
34 VkBuffer deviceBuffer{VK_NULL_HANDLE};
35 VkDeviceMemory deviceMemory{VK_NULL_HANDLE};
36 uint32_t bufferSize{0};
/// Struct containing information regarding a host memory buffer.
/// NOTE(review): the struct only stores a raw pointer; ownership of the
/// pointed-to memory is not expressed here — confirm with the callers.
struct VulkanHostMemoryBuffer {
  /// Pointer to a host memory.
  void *ptr{nullptr};
  /// Size of a host memory in bytes.
  uint32_t size{0};
};
/// Struct containing the number of local workgroups to dispatch for each
/// dimension. Every dimension defaults to 1.
struct NumWorkGroups {
  uint32_t x{1};
  uint32_t y{1};
  uint32_t z{1};
};
55 /// Struct containing information regarding a descriptor set.
56 struct DescriptorSetInfo {
57 /// Index of a descriptor set in descriptor sets.
58 DescriptorSetIndex descriptorSet{0};
59 /// Number of descriptors in a set.
60 uint32_t descriptorSize{0};
61 /// Type of a descriptor set.
62 VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
65 /// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
66 using ResourceData = std::unordered_map<
67 DescriptorSetIndex,
68 std::unordered_map<BindingIndex, VulkanHostMemoryBuffer>>;
/// SPIR-V storage classes.
/// Note that this duplicates spirv::StorageClass but it keeps the Vulkan
/// runtime library detached from the SPIR-V dialect, so we can avoid picking
/// up lots of dependencies. Keep the enumerator values in sync with
/// spirv::StorageClass.
enum class SPIRVStorageClass {
  Uniform = 2,
  StorageBuffer = 12,
};
79 /// StorageClass mapped into a descriptor set and a binding.
80 using ResourceStorageClassBindingMap =
81 std::unordered_map<DescriptorSetIndex,
82 std::unordered_map<BindingIndex, SPIRVStorageClass>>;
84 /// Vulkan runtime.
85 /// The purpose of this class is to run SPIR-V compute shader on Vulkan
86 /// device.
87 /// Before the run, user must provide and set resource data with descriptors,
88 /// SPIR-V shader, number of work groups and entry point. After the creation of
89 /// VulkanRuntime, special methods must be called in the following
90 /// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
91 /// each method in the sequence returns success or failure depends on the Vulkan
92 /// result code.
93 class VulkanRuntime {
94 public:
95 explicit VulkanRuntime() = default;
96 VulkanRuntime(const VulkanRuntime &) = delete;
97 VulkanRuntime &operator=(const VulkanRuntime &) = delete;
99 /// Sets needed data for Vulkan runtime.
100 void setResourceData(const ResourceData &resData);
101 void setResourceData(const DescriptorSetIndex desIndex,
102 const BindingIndex bindIndex,
103 const VulkanHostMemoryBuffer &hostMemBuffer);
104 void setShaderModule(uint8_t *shader, uint32_t size);
105 void setNumWorkGroups(const NumWorkGroups &numberWorkGroups);
106 void setResourceStorageClassBindingMap(
107 const ResourceStorageClassBindingMap &stClassData);
108 void setEntryPoint(const char *entryPointName);
110 /// Runtime initialization.
111 LogicalResult initRuntime();
113 /// Runs runtime.
114 LogicalResult run();
116 /// Updates host memory buffers.
117 LogicalResult updateHostMemoryBuffers();
119 /// Destroys all created vulkan objects and resources.
120 LogicalResult destroy();
122 private:
123 //===--------------------------------------------------------------------===//
124 // Pipeline creation methods.
125 //===--------------------------------------------------------------------===//
127 LogicalResult createInstance();
128 LogicalResult createDevice();
129 LogicalResult getBestComputeQueue();
130 LogicalResult createMemoryBuffers();
131 LogicalResult createShaderModule();
132 void initDescriptorSetLayoutBindingMap();
133 LogicalResult createDescriptorSetLayout();
134 LogicalResult createPipelineLayout();
135 LogicalResult createComputePipeline();
136 LogicalResult createDescriptorPool();
137 LogicalResult allocateDescriptorSets();
138 LogicalResult setWriteDescriptors();
139 LogicalResult createCommandPool();
140 LogicalResult createQueryPool();
141 LogicalResult createComputeCommandBuffer();
142 LogicalResult submitCommandBuffersToQueue();
143 // Copy resources from host (staging buffer) to device buffer or from device
144 // buffer to host buffer.
145 LogicalResult copyResource(bool deviceToHost);
147 //===--------------------------------------------------------------------===//
148 // Helper methods.
149 //===--------------------------------------------------------------------===//
151 /// Maps storage class to a descriptor type.
152 LogicalResult
153 mapStorageClassToDescriptorType(SPIRVStorageClass storageClass,
154 VkDescriptorType &descriptorType);
156 /// Maps storage class to buffer usage flags.
157 LogicalResult
158 mapStorageClassToBufferUsageFlag(SPIRVStorageClass storageClass,
159 VkBufferUsageFlagBits &bufferUsage);
161 LogicalResult countDeviceMemorySize();
163 //===--------------------------------------------------------------------===//
164 // Vulkan objects.
165 //===--------------------------------------------------------------------===//
167 VkInstance instance{VK_NULL_HANDLE};
168 VkPhysicalDevice physicalDevice{VK_NULL_HANDLE};
169 VkDevice device{VK_NULL_HANDLE};
170 VkQueue queue{VK_NULL_HANDLE};
172 /// Specifies VulkanDeviceMemoryBuffers divided into sets.
173 std::unordered_map<DescriptorSetIndex, std::vector<VulkanDeviceMemoryBuffer>>
174 deviceMemoryBufferMap;
176 /// Specifies shader module.
177 VkShaderModule shaderModule{VK_NULL_HANDLE};
179 /// Specifies layout bindings.
180 std::unordered_map<DescriptorSetIndex,
181 std::vector<VkDescriptorSetLayoutBinding>>
182 descriptorSetLayoutBindingMap;
184 /// Specifies layouts of descriptor sets.
185 std::vector<VkDescriptorSetLayout> descriptorSetLayouts;
186 VkPipelineLayout pipelineLayout{VK_NULL_HANDLE};
188 /// Specifies descriptor sets.
189 std::vector<VkDescriptorSet> descriptorSets;
191 /// Specifies a pool of descriptor set info, each descriptor set must have
192 /// information such as type, index and amount of bindings.
193 std::vector<DescriptorSetInfo> descriptorSetInfoPool;
194 VkDescriptorPool descriptorPool{VK_NULL_HANDLE};
196 /// Timestamp query.
197 VkQueryPool queryPool{VK_NULL_HANDLE};
198 // Number of nonoseconds for timestamp to increase 1
199 float timestampPeriod{0.f};
201 /// Computation pipeline.
202 VkPipeline pipeline{VK_NULL_HANDLE};
203 VkCommandPool commandPool{VK_NULL_HANDLE};
204 std::vector<VkCommandBuffer> commandBuffers;
206 //===--------------------------------------------------------------------===//
207 // Vulkan memory context.
208 //===--------------------------------------------------------------------===//
210 uint32_t queueFamilyIndex{0};
211 VkQueueFamilyProperties queueFamilyProperties{};
212 uint32_t hostMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
213 uint32_t deviceMemoryTypeIndex{VK_MAX_MEMORY_TYPES};
214 VkDeviceSize memorySize{0};
216 //===--------------------------------------------------------------------===//
217 // Vulkan execution context.
218 //===--------------------------------------------------------------------===//
220 NumWorkGroups numWorkGroups;
221 const char *entryPoint{nullptr};
222 uint8_t *binary{nullptr};
223 uint32_t binarySize{0};
225 //===--------------------------------------------------------------------===//
226 // Vulkan resource data and storage classes.
227 //===--------------------------------------------------------------------===//
229 ResourceData resourceData;
230 ResourceStorageClassBindingMap resourceStorageClassData;
232 #endif