//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides a library for running a module on a Vulkan device.
// Implements a Vulkan runtime.
//
//===----------------------------------------------------------------------===//
#include "VulkanRuntime.h"

#include <chrono>
#include <cstring>
// TODO: It's generally bad to access stdout/stderr in a library.
// Figure out a better way for error reporting.
#include <iomanip>
#include <iostream>

inline void emitVulkanError(const char *api, VkResult error) {
  std::cerr << " failed with error code " << error << " when executing "
            << api;
}
#define RETURN_ON_VULKAN_ERROR(result, api)                                    \
  if ((result) != VK_SUCCESS) {                                                \
    emitVulkanError(api, (result));                                            \
    return failure();                                                          \
  }
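// Illustrative note (not additional runtime logic): a use such as
//   RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
// expands to a check of the VkResult that reports the failing API name and
// error code through emitVulkanError() and early-returns failure() from the
// enclosing function.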
void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
  numWorkGroups = numberWorkGroups;
}
void VulkanRuntime::setResourceStorageClassBindingMap(
    const ResourceStorageClassBindingMap &stClassData) {
  resourceStorageClassData = stClassData;
}
void VulkanRuntime::setResourceData(
    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
    const VulkanHostMemoryBuffer &hostMemBuffer) {
  resourceData[desIndex][bindIndex] = hostMemBuffer;
  resourceStorageClassData[desIndex][bindIndex] =
      SPIRVStorageClass::StorageBuffer;
}
void VulkanRuntime::setEntryPoint(const char *entryPointName) {
  entryPoint = entryPointName;
}
void VulkanRuntime::setResourceData(const ResourceData &resData) {
  resourceData = resData;
}
void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
  binary = shader;
  binarySize = size;
}
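// A minimal usage sketch of the setters above (illustrative only; the SPIR-V
// blob, sizes, and host buffer below are hypothetical placeholders):
//   VulkanRuntime runtime;
//   runtime.setShaderModule(spirvBlob, spirvBlobSizeInBytes);
//   runtime.setEntryPoint("main");
//   runtime.setNumWorkGroups({4, 1, 1});
//   runtime.setResourceData(/*desIndex=*/0, /*bindIndex=*/0, hostMemBuffer);
// After that, initRuntime(), run(), updateHostMemoryBuffers(), and destroy()
// drive the full lifecycle implemented in this file.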
LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
    SPIRVStorageClass storageClass, VkDescriptorType &descriptorType) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    break;
  case SPIRVStorageClass::Uniform:
    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
    break;
  }
  return success();
}
LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
    SPIRVStorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
  switch (storageClass) {
  case SPIRVStorageClass::StorageBuffer:
    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
    break;
  case SPIRVStorageClass::Uniform:
    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
    break;
  }
  return success();
}
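// Taken together, the two helpers above translate a resource's SPIR-V storage
// class into the descriptor type used in descriptor set layouts and the usage
// flag used at VkBuffer creation: StorageBuffer maps to
// VK_DESCRIPTOR_TYPE_STORAGE_BUFFER / VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, and
// Uniform maps to VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER /
// VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT.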
LogicalResult VulkanRuntime::countDeviceMemorySize() {
  for (const auto &resourceDataMapPair : resourceData) {
    const auto &resourceDataMap = resourceDataMapPair.second;
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      if (resourceDataBindingPair.second.size) {
        memorySize += resourceDataBindingPair.second.size;
      } else {
        std::cerr
            << "expected buffer size greater than zero for resource data";
        return failure();
      }
    }
  }
  return success();
}
LogicalResult VulkanRuntime::initRuntime() {
  if (resourceData.empty()) {
    std::cerr << "Vulkan runtime needs at least one resource";
    return failure();
  }
  if (!binarySize || !binary) {
    std::cerr << "binary shader size must be greater than zero";
    return failure();
  }
  if (failed(countDeviceMemorySize())) {
    return failure();
  }
  return success();
}
LogicalResult VulkanRuntime::destroy() {
  // According to Vulkan spec:
  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
  // used to gate the destruction of the device. Prior to destroying a device,
  // an application is responsible for destroying/freeing any Vulkan objects
  // that were created using that device as the first parameter of the
  // corresponding vkCreate* or vkAllocate* command."
  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");

  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
                       commandBuffers.data());
  vkDestroyQueryPool(device, queryPool, nullptr);
  vkDestroyCommandPool(device, commandPool, nullptr);
  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
                       descriptorSets.data());
  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
  vkDestroyPipeline(device, pipeline, nullptr);
  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
  for (auto &descriptorSetLayout : descriptorSetLayouts) {
    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
  }
  vkDestroyShaderModule(device, shaderModule, nullptr);

  // For each descriptor set.
  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // For each descriptor binding.
    for (auto &memoryBuffer : deviceMemoryBuffers) {
      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
      vkFreeMemory(device, memoryBuffer.hostMemory, nullptr);
      vkDestroyBuffer(device, memoryBuffer.hostBuffer, nullptr);
      vkDestroyBuffer(device, memoryBuffer.deviceBuffer, nullptr);
    }
  }

  vkDestroyDevice(device, nullptr);
  vkDestroyInstance(instance, nullptr);
  return success();
}
LogicalResult VulkanRuntime::run() {
  // Create logical device, shader module and memory buffers.
  if (failed(createInstance()) || failed(createDevice()) ||
      failed(createMemoryBuffers()) || failed(createShaderModule())) {
    return failure();
  }

  // Descriptor bindings are divided into sets. Each descriptor binding must
  // have a layout binding attached into a descriptor set layout, and each
  // layout set must be bound to a pipeline layout.
  initDescriptorSetLayoutBindingMap();
  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
      // Each descriptor set must be allocated from a descriptor pool.
      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
      // Create command buffer.
      failed(createCommandPool()) || failed(createQueryPool()) ||
      failed(createComputeCommandBuffer())) {
    return failure();
  }

  // Get working queue.
  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);

  if (failed(copyResource(/*deviceToHost=*/false)))
    return failure();

  auto submitStart = std::chrono::high_resolution_clock::now();
  // Submit command buffer into the queue.
  if (failed(submitCommandBuffersToQueue()))
    return failure();
  auto submitEnd = std::chrono::high_resolution_clock::now();

  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
  auto execEnd = std::chrono::high_resolution_clock::now();

  auto submitDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      submitEnd - submitStart);
  auto execDuration = std::chrono::duration_cast<std::chrono::microseconds>(
      execEnd - submitEnd);

  if (queryPool != VK_NULL_HANDLE) {
    uint64_t timestamps[2];
    RETURN_ON_VULKAN_ERROR(
        vkGetQueryPoolResults(
            device, queryPool, /*firstQuery=*/0, /*queryCount=*/2,
            /*dataSize=*/sizeof(timestamps),
            /*pData=*/reinterpret_cast<void *>(timestamps),
            /*stride=*/sizeof(uint64_t),
            VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT),
        "vkGetQueryPoolResults");
    float microsec = (timestamps[1] - timestamps[0]) * timestampPeriod / 1000;
    std::cout << "Compute shader execution time: " << std::setprecision(3)
              << microsec << "us\n";
  }

  std::cout << "Command buffer submit time: " << submitDuration.count()
            << "us\nWait idle time: " << execDuration.count() << "us\n";

  return success();
}
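// Note on the three numbers printed above: "Command buffer submit time" covers
// only the host-side submitCommandBuffersToQueue() call, "Wait idle time" is
// the wall-clock span from submission until vkQueueWaitIdle returns, and
// "Compute shader execution time" converts the two device timestamps (counted
// in units of timestampPeriod nanoseconds) into microseconds.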
LogicalResult VulkanRuntime::createInstance() {
  VkApplicationInfo applicationInfo = {};
  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
  applicationInfo.pNext = nullptr;
  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
  applicationInfo.applicationVersion = 0;
  applicationInfo.pEngineName = "mlir";
  applicationInfo.engineVersion = 0;
  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);

  VkInstanceCreateInfo instanceCreateInfo = {};
  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
  instanceCreateInfo.pNext = nullptr;
  instanceCreateInfo.pApplicationInfo = &applicationInfo;
  instanceCreateInfo.enabledLayerCount = 0;
  instanceCreateInfo.ppEnabledLayerNames = nullptr;

  std::vector<const char *> extNames;
#if defined(__APPLE__)
  // Enumerate MoltenVK for Vulkan 1.0.
  instanceCreateInfo.flags = VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
  // Add KHR portability instance extensions.
  extNames.push_back(VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME);
  extNames.push_back(VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME);
#else
  instanceCreateInfo.flags = 0;
#endif // defined(__APPLE__)
  instanceCreateInfo.enabledExtensionCount =
      static_cast<uint32_t>(extNames.size());
  instanceCreateInfo.ppEnabledExtensionNames = extNames.data();

  RETURN_ON_VULKAN_ERROR(
      vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
      "vkCreateInstance");
  return success();
}
LogicalResult VulkanRuntime::createDevice() {
  uint32_t physicalDeviceCount = 0;
  RETURN_ON_VULKAN_ERROR(
      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
      "vkEnumeratePhysicalDevices");

  std::vector<VkPhysicalDevice> physicalDevices(physicalDeviceCount);
  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
                                                    &physicalDeviceCount,
                                                    physicalDevices.data()),
                         "vkEnumeratePhysicalDevices");

  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
                         "physicalDeviceCount");

  // TODO: find the best device.
  physicalDevice = physicalDevices.front();
  if (failed(getBestComputeQueue()))
    return failure();

  const float queuePriority = 1.0f;
  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
  deviceQueueCreateInfo.pNext = nullptr;
  deviceQueueCreateInfo.flags = 0;
  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
  deviceQueueCreateInfo.queueCount = 1;
  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;

  // Structure specifying parameters of a newly created device.
  VkDeviceCreateInfo deviceCreateInfo = {};
  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
  deviceCreateInfo.pNext = nullptr;
  deviceCreateInfo.flags = 0;
  deviceCreateInfo.queueCreateInfoCount = 1;
  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
  deviceCreateInfo.enabledLayerCount = 0;
  deviceCreateInfo.ppEnabledLayerNames = nullptr;
  deviceCreateInfo.enabledExtensionCount = 0;
  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
  deviceCreateInfo.pEnabledFeatures = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
      "vkCreateDevice");

  VkPhysicalDeviceMemoryProperties properties = {};
  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);

  // Try to find a memory type with the following properties:
  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT specifies that memory allocated with
  // this type can be mapped for host access using vkMapMemory;
  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT specifies that the host cache
  // management commands vkFlushMappedMemoryRanges and
  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
  // device or make device writes visible to the host, respectively.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      hostMemoryTypeIndex = i;
      break;
    }
  }

  // Find a memory type with VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT to be used on
  // the device. This allows better performance access for GPUs with dedicated
  // device memory.
  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
    if ((VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT &
         properties.memoryTypes[i].propertyFlags) &&
        (memorySize <=
         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
      deviceMemoryTypeIndex = i;
      break;
    }
  }

  RETURN_ON_VULKAN_ERROR((hostMemoryTypeIndex == VK_MAX_MEMORY_TYPES ||
                          deviceMemoryTypeIndex == VK_MAX_MEMORY_TYPES)
                             ? VK_INCOMPLETE
                             : VK_SUCCESS,
                         "invalid memoryTypeIndex");
  return success();
}
LogicalResult VulkanRuntime::getBestComputeQueue() {
  uint32_t queueFamilyPropertiesCount = 0;
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, nullptr);

  std::vector<VkQueueFamilyProperties> familyProperties(
      queueFamilyPropertiesCount);
  vkGetPhysicalDeviceQueueFamilyProperties(
      physicalDevice, &queueFamilyPropertiesCount, familyProperties.data());

  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
  // compute operations. Try to find a compute-only queue first if possible.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT) && !(flags & VK_QUEUE_GRAPHICS_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  // Otherwise use a queue that can also support graphics.
  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
    auto flags = familyProperties[i].queueFlags;
    if ((flags & VK_QUEUE_COMPUTE_BIT)) {
      queueFamilyIndex = i;
      queueFamilyProperties = familyProperties[i];
      return success();
    }
  }

  std::cerr << "cannot find valid queue";
  return failure();
}
LogicalResult VulkanRuntime::createMemoryBuffers() {
  // For each descriptor set.
  for (const auto &resourceDataMapPair : resourceData) {
    std::vector<VulkanDeviceMemoryBuffer> deviceMemoryBuffers;
    const auto descriptorSetIndex = resourceDataMapPair.first;
    const auto &resourceDataMap = resourceDataMapPair.second;

    // For each descriptor binding.
    for (const auto &resourceDataBindingPair : resourceDataMap) {
      // Create device memory buffer.
      VulkanDeviceMemoryBuffer memoryBuffer;
      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
      VkDescriptorType descriptorType = {};
      VkBufferUsageFlagBits bufferUsage = {};

      // Check that descriptor set has storage class map.
      const auto resourceStorageClassMapIt =
          resourceStorageClassData.find(descriptorSetIndex);
      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
        std::cerr
            << "cannot find storage class for resource in descriptor set: "
            << descriptorSetIndex;
        return failure();
      }

      // Check that specific descriptor binding has storage class.
      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
      const auto resourceStorageClassIt =
          resourceStorageClassMap.find(resourceDataBindingPair.first);
      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
        std::cerr
            << "cannot find storage class for resource with descriptor index: "
            << resourceDataBindingPair.first;
        return failure();
      }

      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
                                                 descriptorType)) ||
          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
                                                  bufferUsage))) {
        std::cerr << "storage class for resource with descriptor binding: "
                  << resourceDataBindingPair.first
                  << " in the descriptor set: " << descriptorSetIndex
                  << " is not supported";
        return failure();
      }

      // Set descriptor type for the specific device memory buffer.
      memoryBuffer.descriptorType = descriptorType;
      const auto bufferSize = resourceDataBindingPair.second.size;
      memoryBuffer.bufferSize = bufferSize;
      // Specify memory allocation info.
      VkMemoryAllocateInfo memoryAllocateInfo = {};
      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
      memoryAllocateInfo.pNext = nullptr;
      memoryAllocateInfo.allocationSize = bufferSize;
      memoryAllocateInfo.memoryTypeIndex = hostMemoryTypeIndex;

      // Allocate host-visible and device-local memory.
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              /*pAllocator=*/nullptr,
                                              &memoryBuffer.hostMemory),
                             "vkAllocateMemory");
      memoryAllocateInfo.memoryTypeIndex = deviceMemoryTypeIndex;
      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
                                              /*pAllocator=*/nullptr,
                                              &memoryBuffer.deviceMemory),
                             "vkAllocateMemory");
      void *payload = nullptr;
      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.hostMemory, 0,
                                         bufferSize, 0,
                                         reinterpret_cast<void **>(&payload)),
                             "vkMapMemory");

      // Copy host memory into the mapped area.
      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
      vkUnmapMemory(device, memoryBuffer.hostMemory);

      VkBufferCreateInfo bufferCreateInfo = {};
      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
      bufferCreateInfo.pNext = nullptr;
      bufferCreateInfo.flags = 0;
      bufferCreateInfo.size = bufferSize;
      bufferCreateInfo.usage = bufferUsage | VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                               VK_BUFFER_USAGE_TRANSFER_SRC_BIT;
      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
      bufferCreateInfo.queueFamilyIndexCount = 1;
      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.hostBuffer),
                             "vkCreateBuffer");
      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
                                            &memoryBuffer.deviceBuffer),
                             "vkCreateBuffer");

      // Bind buffer and device memory.
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.hostBuffer,
                                                memoryBuffer.hostMemory, 0),
                             "vkBindBufferMemory");
      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device,
                                                memoryBuffer.deviceBuffer,
                                                memoryBuffer.deviceMemory, 0),
                             "vkBindBufferMemory");

      // Update buffer info.
      memoryBuffer.bufferInfo.buffer = memoryBuffer.deviceBuffer;
      memoryBuffer.bufferInfo.offset = 0;
      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
      deviceMemoryBuffers.push_back(memoryBuffer);
    }

    // Associate device memory buffers with a descriptor set.
    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
  }
  return success();
}
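// Each binding now owns two buffer/memory pairs: a host-visible, host-coherent
// staging buffer that was just filled via vkMapMemory/std::memcpy, and a
// device-local buffer that the descriptor actually references through
// bufferInfo.buffer. copyResource() below records the transfers that move
// data between the two in either direction.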
LogicalResult VulkanRuntime::copyResource(bool deviceToHost) {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, // sType
      nullptr,                                        // pNext
      commandPool,                                    // commandPool
      VK_COMMAND_BUFFER_LEVEL_PRIMARY,                // level
      1,                                              // commandBufferCount
  };
  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, // sType
      nullptr,                                     // pNext
      0,                                           // flags
      nullptr,                                     // pInheritanceInfo
  };
  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkBufferCopy copy = {0, 0, memBuffer.bufferSize};
      if (deviceToHost)
        vkCmdCopyBuffer(commandBuffer, memBuffer.deviceBuffer,
                        memBuffer.hostBuffer, 1, &copy);
      else
        vkCmdCopyBuffer(commandBuffer, memBuffer.hostBuffer,
                        memBuffer.deviceBuffer, 1, &copy);
    }
  }

  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");
  VkSubmitInfo submitInfo = {
      VK_STRUCTURE_TYPE_SUBMIT_INFO, // sType
      nullptr,                       // pNext
      0,                             // waitSemaphoreCount
      nullptr,                       // pWaitSemaphores
      nullptr,                       // pWaitDstStageMask
      1,                             // commandBufferCount
      nullptr,                       // pCommandBuffers (set below)
      0,                             // signalSemaphoreCount
      nullptr,                       // pSignalSemaphores
  };
  submitInfo.pCommandBuffers = &commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
                         "vkQueueSubmit");
  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");

  vkFreeCommandBuffers(device, commandPool, 1, &commandBuffer);
  return success();
}
LogicalResult VulkanRuntime::createShaderModule() {
  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
  shaderModuleCreateInfo.pNext = nullptr;
  shaderModuleCreateInfo.flags = 0;
  // Set size in bytes.
  shaderModuleCreateInfo.codeSize = binarySize;
  // Set pointer to the binary shader.
  shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
  RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
                                              nullptr, &shaderModule),
                         "vkCreateShaderModule");
  return success();
}
void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    std::vector<VkDescriptorSetLayoutBinding> descriptorSetLayoutBindings;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;

    // Create a layout binding for each descriptor.
    for (const auto &memBuffer : deviceMemoryBuffers) {
      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
      descriptorSetLayoutBinding.descriptorCount = 1;
      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
      descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
    }
    descriptorSetLayoutBindingMap[descriptorSetIndex] =
        descriptorSetLayoutBindings;
  }
}
LogicalResult VulkanRuntime::createDescriptorSetLayout() {
  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
    // Each descriptor in a descriptor set must be the same type.
    VkDescriptorType descriptorType =
        deviceMemoryBuffers.front().descriptorType;
    const uint32_t descriptorSize = deviceMemoryBuffers.size();
    const auto descriptorSetLayoutBindingIt =
        descriptorSetLayoutBindingMap.find(descriptorSetIndex);

    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
      std::cerr << "cannot find layout bindings for the set with number: "
                << descriptorSetIndex;
      return failure();
    }

    const auto &descriptorSetLayoutBindings =
        descriptorSetLayoutBindingIt->second;
    // Create descriptor set layout.
    VkDescriptorSetLayout descriptorSetLayout = {};
    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};

    descriptorSetLayoutCreateInfo.sType =
        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    descriptorSetLayoutCreateInfo.pNext = nullptr;
    descriptorSetLayoutCreateInfo.flags = 0;
    // Number of descriptor bindings in a layout set.
    descriptorSetLayoutCreateInfo.bindingCount =
        descriptorSetLayoutBindings.size();
    descriptorSetLayoutCreateInfo.pBindings =
        descriptorSetLayoutBindings.data();
    RETURN_ON_VULKAN_ERROR(
        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo,
                                    nullptr, &descriptorSetLayout),
        "vkCreateDescriptorSetLayout");

    descriptorSetLayouts.push_back(descriptorSetLayout);
    descriptorSetInfoPool.push_back(
        {descriptorSetIndex, descriptorSize, descriptorType});
  }
  return success();
}
LogicalResult VulkanRuntime::createPipelineLayout() {
  // Associate descriptor sets with a pipeline layout.
  VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
  pipelineLayoutCreateInfo.sType =
      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
  pipelineLayoutCreateInfo.pNext = nullptr;
  pipelineLayoutCreateInfo.flags = 0;
  pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
  pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
  pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
  pipelineLayoutCreateInfo.pPushConstantRanges = nullptr;
  RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
                                                &pipelineLayoutCreateInfo,
                                                nullptr, &pipelineLayout),
                         "vkCreatePipelineLayout");
  return success();
}
LogicalResult VulkanRuntime::createComputePipeline() {
  VkPipelineShaderStageCreateInfo stageInfo = {};
  stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
  stageInfo.pNext = nullptr;
  stageInfo.flags = 0;
  stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
  stageInfo.module = shaderModule;
  // Set the entry point of the compute shader.
  stageInfo.pName = entryPoint;
  stageInfo.pSpecializationInfo = nullptr;

  VkComputePipelineCreateInfo computePipelineCreateInfo = {};
  computePipelineCreateInfo.sType =
      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
  computePipelineCreateInfo.pNext = nullptr;
  computePipelineCreateInfo.flags = 0;
  computePipelineCreateInfo.stage = stageInfo;
  computePipelineCreateInfo.layout = pipelineLayout;
  computePipelineCreateInfo.basePipelineHandle = nullptr;
  computePipelineCreateInfo.basePipelineIndex = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, nullptr, 1,
                                                  &computePipelineCreateInfo,
                                                  nullptr, &pipeline),
                         "vkCreateComputePipelines");
  return success();
}
LogicalResult VulkanRuntime::createDescriptorPool() {
  std::vector<VkDescriptorPoolSize> descriptorPoolSizes;
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each descriptor set populate descriptor pool size.
    VkDescriptorPoolSize descriptorPoolSize = {};
    descriptorPoolSize.type = descriptorSetInfo.descriptorType;
    descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
    descriptorPoolSizes.push_back(descriptorPoolSize);
  }

  VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
  descriptorPoolCreateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
  descriptorPoolCreateInfo.pNext = nullptr;
  descriptorPoolCreateInfo.flags = 0;
  descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
  descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
  RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
                                                &descriptorPoolCreateInfo,
                                                nullptr, &descriptorPool),
                         "vkCreateDescriptorPool");
  return success();
}
LogicalResult VulkanRuntime::allocateDescriptorSets() {
  VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
  // Size of descriptor sets and descriptor layout sets is the same.
  descriptorSets.resize(descriptorSetLayouts.size());
  descriptorSetAllocateInfo.sType =
      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
  descriptorSetAllocateInfo.pNext = nullptr;
  descriptorSetAllocateInfo.descriptorPool = descriptorPool;
  descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
  descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
  RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
                                                  &descriptorSetAllocateInfo,
                                                  descriptorSets.data()),
                         "vkAllocateDescriptorSets");
  return success();
}
LogicalResult VulkanRuntime::setWriteDescriptors() {
  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
    std::cerr << "Each descriptor set must have descriptor set information";
    return failure();
  }
  // For each descriptor set.
  auto descriptorSetIt = descriptorSets.begin();
  // Each descriptor set is associated with descriptor set info.
  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
    // For each device memory buffer in the descriptor set.
    const auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
    for (const auto &memoryBuffer : deviceMemoryBuffers) {
      // Structure describing descriptor sets to write to.
      VkWriteDescriptorSet wSet = {};
      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
      wSet.pNext = nullptr;
      wSet.dstSet = *descriptorSetIt;
      wSet.dstBinding = memoryBuffer.bindingIndex;
      wSet.dstArrayElement = 0;
      wSet.descriptorCount = 1;
      wSet.descriptorType = memoryBuffer.descriptorType;
      wSet.pImageInfo = nullptr;
      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
      wSet.pTexelBufferView = nullptr;
      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
    }
    // Increment descriptor set iterator.
    ++descriptorSetIt;
  }
  return success();
}
LogicalResult VulkanRuntime::createCommandPool() {
  VkCommandPoolCreateInfo commandPoolCreateInfo = {};
  commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
  commandPoolCreateInfo.pNext = nullptr;
  commandPoolCreateInfo.flags = 0;
  commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
  RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
                                             /*pAllocator=*/nullptr,
                                             &commandPool),
                         "vkCreateCommandPool");
  return success();
}
LogicalResult VulkanRuntime::createQueryPool() {
  // Return directly if timestamp query is not supported.
  if (queueFamilyProperties.timestampValidBits == 0)
    return success();

  // Get timestamp period for this physical device.
  VkPhysicalDeviceProperties deviceProperties = {};
  vkGetPhysicalDeviceProperties(physicalDevice, &deviceProperties);
  timestampPeriod = deviceProperties.limits.timestampPeriod;

  // Create query pool.
  VkQueryPoolCreateInfo queryPoolCreateInfo = {};
  queryPoolCreateInfo.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO;
  queryPoolCreateInfo.pNext = nullptr;
  queryPoolCreateInfo.flags = 0;
  queryPoolCreateInfo.queryType = VK_QUERY_TYPE_TIMESTAMP;
  queryPoolCreateInfo.queryCount = 2;
  queryPoolCreateInfo.pipelineStatistics = 0;
  RETURN_ON_VULKAN_ERROR(vkCreateQueryPool(device, &queryPoolCreateInfo,
                                           /*pAllocator=*/nullptr, &queryPool),
                         "vkCreateQueryPool");

  return success();
}
LogicalResult VulkanRuntime::createComputeCommandBuffer() {
  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
  commandBufferAllocateInfo.sType =
      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
  commandBufferAllocateInfo.pNext = nullptr;
  commandBufferAllocateInfo.commandPool = commandPool;
  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
  commandBufferAllocateInfo.commandBufferCount = 1;

  VkCommandBuffer commandBuffer;
  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
                                                  &commandBufferAllocateInfo,
                                                  &commandBuffer),
                         "vkAllocateCommandBuffers");

  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
  commandBufferBeginInfo.pNext = nullptr;
  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
  commandBufferBeginInfo.pInheritanceInfo = nullptr;

  RETURN_ON_VULKAN_ERROR(
      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
      "vkBeginCommandBuffer");

  if (queryPool != VK_NULL_HANDLE)
    vkCmdResetQueryPool(commandBuffer, queryPool, 0, 2);

  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                          pipelineLayout, 0, descriptorSets.size(),
                          descriptorSets.data(), 0, nullptr);
  // Get a timestamp before invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                        queryPool, 0);
  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
                numWorkGroups.z);
  // Get another timestamp after invoking the compute shader.
  if (queryPool != VK_NULL_HANDLE)
    vkCmdWriteTimestamp(commandBuffer, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT,
                        queryPool, 1);

  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
                         "vkEndCommandBuffer");

  commandBuffers.push_back(commandBuffer);
  return success();
}
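// The recorded command buffer therefore contains, in order: an optional
// query-pool reset, the pipeline and descriptor-set binds, a top-of-pipe
// timestamp (query 0), the vkCmdDispatch over numWorkGroups, and a
// bottom-of-pipe timestamp (query 1) that run() later reads back.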
LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
  VkSubmitInfo submitInfo = {};
  submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
  submitInfo.pNext = nullptr;
  submitInfo.waitSemaphoreCount = 0;
  submitInfo.pWaitSemaphores = nullptr;
  submitInfo.pWaitDstStageMask = nullptr;
  submitInfo.commandBufferCount = commandBuffers.size();
  submitInfo.pCommandBuffers = commandBuffers.data();
  submitInfo.signalSemaphoreCount = 0;
  submitInfo.pSignalSemaphores = nullptr;
  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, nullptr),
                         "vkQueueSubmit");
  return success();
}
LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
  // First copy back the data to the staging buffer.
  (void)copyResource(/*deviceToHost=*/true);

  // For each descriptor set.
  for (auto &resourceDataMapPair : resourceData) {
    auto &resourceDataMap = resourceDataMapPair.second;
    auto &deviceMemoryBuffers =
        deviceMemoryBufferMap[resourceDataMapPair.first];
    // For each device memory buffer in the set.
    for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
      if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
        void *payload = nullptr;
        auto &hostMemoryBuffer =
            resourceDataMap[deviceMemoryBuffer.bindingIndex];
        RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
                                           deviceMemoryBuffer.hostMemory, 0,
                                           hostMemoryBuffer.size, 0,
                                           reinterpret_cast<void **>(&payload)),
                               "vkMapMemory");
        std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
        vkUnmapMemory(device, deviceMemoryBuffer.hostMemory
);