Ensure SplitEdge returns the new block between the two given blocks
[llvm-project.git] / mlir / lib / Dialect / SPIRV / SPIRVTypes.cpp
blob15fafddf9f247f828101da3f1526b2a9343cf60b
1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types in the SPIR-V dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Identifier.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
24 using namespace mlir;
25 using namespace mlir::spirv;
27 // Pull in all enum utility function definitions
28 #include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
29 // Pull in all enum type availability query function definitions
30 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc"
32 //===----------------------------------------------------------------------===//
33 // Availability relationship
34 //===----------------------------------------------------------------------===//
36 ArrayRef<Extension> spirv::getImpliedExtensions(Version version) {
37 // Note: the following lists are from "Appendix A: Changes" of the spec.
39 #define V_1_3_IMPLIED_EXTS \
40 Extension::SPV_KHR_shader_draw_parameters, Extension::SPV_KHR_16bit_storage, \
41 Extension::SPV_KHR_device_group, Extension::SPV_KHR_multiview, \
42 Extension::SPV_KHR_storage_buffer_storage_class, \
43 Extension::SPV_KHR_variable_pointers
45 #define V_1_4_IMPLIED_EXTS \
46 Extension::SPV_KHR_no_integer_wrap_decoration, \
47 Extension::SPV_GOOGLE_decorate_string, \
48 Extension::SPV_GOOGLE_hlsl_functionality1, \
49 Extension::SPV_KHR_float_controls
51 #define V_1_5_IMPLIED_EXTS \
52 Extension::SPV_KHR_8bit_storage, Extension::SPV_EXT_descriptor_indexing, \
53 Extension::SPV_EXT_shader_viewport_index_layer, \
54 Extension::SPV_EXT_physical_storage_buffer, \
55 Extension::SPV_KHR_physical_storage_buffer, \
56 Extension::SPV_KHR_vulkan_memory_model
58 switch (version) {
59 default:
60 return {};
61 case Version::V_1_3: {
62 // The following manual ArrayRef constructor call is to satisfy GCC 5.
63 static const Extension exts[] = {V_1_3_IMPLIED_EXTS};
64 return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
66 case Version::V_1_4: {
67 static const Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS};
68 return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
70 case Version::V_1_5: {
71 static const Extension exts[] = {V_1_3_IMPLIED_EXTS, V_1_4_IMPLIED_EXTS,
72 V_1_5_IMPLIED_EXTS};
73 return ArrayRef<Extension>(exts, llvm::array_lengthof(exts));
77 #undef V_1_5_IMPLIED_EXTS
78 #undef V_1_4_IMPLIED_EXTS
79 #undef V_1_3_IMPLIED_EXTS
82 // Pull in utility function definition for implied capabilities
83 #include "mlir/Dialect/SPIRV/SPIRVCapabilityImplication.inc"
85 SmallVector<Capability, 0>
86 spirv::getRecursiveImpliedCapabilities(Capability cap) {
87 ArrayRef<Capability> directCaps = getDirectImpliedCapabilities(cap);
88 llvm::SetVector<Capability, SmallVector<Capability, 0>> allCaps(
89 directCaps.begin(), directCaps.end());
91 // TODO: This is insufficient; find a better way to handle this
92 // (e.g., using static lists) if this turns out to be a bottleneck.
93 for (unsigned i = 0; i < allCaps.size(); ++i)
94 for (Capability c : getDirectImpliedCapabilities(allCaps[i]))
95 allCaps.insert(c);
97 return allCaps.takeVector();
100 //===----------------------------------------------------------------------===//
101 // ArrayType
102 //===----------------------------------------------------------------------===//
104 struct spirv::detail::ArrayTypeStorage : public TypeStorage {
105 using KeyTy = std::tuple<Type, unsigned, unsigned>;
107 static ArrayTypeStorage *construct(TypeStorageAllocator &allocator,
108 const KeyTy &key) {
109 return new (allocator.allocate<ArrayTypeStorage>()) ArrayTypeStorage(key);
112 bool operator==(const KeyTy &key) const {
113 return key == KeyTy(elementType, elementCount, stride);
116 ArrayTypeStorage(const KeyTy &key)
117 : elementType(std::get<0>(key)), elementCount(std::get<1>(key)),
118 stride(std::get<2>(key)) {}
120 Type elementType;
121 unsigned elementCount;
122 unsigned stride;
125 ArrayType ArrayType::get(Type elementType, unsigned elementCount) {
126 assert(elementCount && "ArrayType needs at least one element");
127 return Base::get(elementType.getContext(), elementType, elementCount,
128 /*stride=*/0);
131 ArrayType ArrayType::get(Type elementType, unsigned elementCount,
132 unsigned stride) {
133 assert(elementCount && "ArrayType needs at least one element");
134 return Base::get(elementType.getContext(), elementType, elementCount, stride);
137 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; }
139 Type ArrayType::getElementType() const { return getImpl()->elementType; }
141 unsigned ArrayType::getArrayStride() const { return getImpl()->stride; }
143 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
144 Optional<StorageClass> storage) {
145 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
148 void ArrayType::getCapabilities(
149 SPIRVType::CapabilityArrayRefVector &capabilities,
150 Optional<StorageClass> storage) {
151 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
154 Optional<int64_t> ArrayType::getSizeInBytes() {
155 auto elementType = getElementType().cast<SPIRVType>();
156 Optional<int64_t> size = elementType.getSizeInBytes();
157 if (!size)
158 return llvm::None;
159 return (*size + getArrayStride()) * getNumElements();
162 //===----------------------------------------------------------------------===//
163 // CompositeType
164 //===----------------------------------------------------------------------===//
166 bool CompositeType::classof(Type type) {
167 if (auto vectorType = type.dyn_cast<VectorType>())
168 return isValid(vectorType);
169 return type
170 .isa<spirv::ArrayType, spirv::CooperativeMatrixNVType, spirv::MatrixType,
171 spirv::RuntimeArrayType, spirv::StructType>();
174 bool CompositeType::isValid(VectorType type) {
175 switch (type.getNumElements()) {
176 case 2:
177 case 3:
178 case 4:
179 case 8:
180 case 16:
181 break;
182 default:
183 return false;
185 return type.getRank() == 1 && type.getElementType().isa<ScalarType>();
188 Type CompositeType::getElementType(unsigned index) const {
189 return TypeSwitch<Type, Type>(*this)
190 .Case<ArrayType, CooperativeMatrixNVType, RuntimeArrayType, VectorType>(
191 [](auto type) { return type.getElementType(); })
192 .Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
193 .Case<StructType>(
194 [index](StructType type) { return type.getElementType(index); })
195 .Default(
196 [](Type) -> Type { llvm_unreachable("invalid composite type"); });
199 unsigned CompositeType::getNumElements() const {
200 if (auto arrayType = dyn_cast<ArrayType>())
201 return arrayType.getNumElements();
202 if (auto matrixType = dyn_cast<MatrixType>())
203 return matrixType.getNumColumns();
204 if (auto structType = dyn_cast<StructType>())
205 return structType.getNumElements();
206 if (auto vectorType = dyn_cast<VectorType>())
207 return vectorType.getNumElements();
208 if (isa<CooperativeMatrixNVType>()) {
209 llvm_unreachable(
210 "invalid to query number of elements of spirv::CooperativeMatrix type");
212 if (isa<RuntimeArrayType>()) {
213 llvm_unreachable(
214 "invalid to query number of elements of spirv::RuntimeArray type");
216 llvm_unreachable("invalid composite type");
219 bool CompositeType::hasCompileTimeKnownNumElements() const {
220 return !isa<CooperativeMatrixNVType, RuntimeArrayType>();
223 void CompositeType::getExtensions(
224 SPIRVType::ExtensionArrayRefVector &extensions,
225 Optional<StorageClass> storage) {
226 TypeSwitch<Type>(*this)
227 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
228 StructType>(
229 [&](auto type) { type.getExtensions(extensions, storage); })
230 .Case<VectorType>([&](VectorType type) {
231 return type.getElementType().cast<ScalarType>().getExtensions(
232 extensions, storage);
234 .Default([](Type) { llvm_unreachable("invalid composite type"); });
237 void CompositeType::getCapabilities(
238 SPIRVType::CapabilityArrayRefVector &capabilities,
239 Optional<StorageClass> storage) {
240 TypeSwitch<Type>(*this)
241 .Case<ArrayType, CooperativeMatrixNVType, MatrixType, RuntimeArrayType,
242 StructType>(
243 [&](auto type) { type.getCapabilities(capabilities, storage); })
244 .Case<VectorType>([&](VectorType type) {
245 auto vecSize = getNumElements();
246 if (vecSize == 8 || vecSize == 16) {
247 static const Capability caps[] = {Capability::Vector16};
248 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
249 capabilities.push_back(ref);
251 return type.getElementType().cast<ScalarType>().getCapabilities(
252 capabilities, storage);
254 .Default([](Type) { llvm_unreachable("invalid composite type"); });
257 Optional<int64_t> CompositeType::getSizeInBytes() {
258 if (auto arrayType = dyn_cast<ArrayType>())
259 return arrayType.getSizeInBytes();
260 if (auto structType = dyn_cast<StructType>())
261 return structType.getSizeInBytes();
262 if (auto vectorType = dyn_cast<VectorType>()) {
263 Optional<int64_t> elementSize =
264 vectorType.getElementType().cast<ScalarType>().getSizeInBytes();
265 if (!elementSize)
266 return llvm::None;
267 return *elementSize * vectorType.getNumElements();
269 return llvm::None;
272 //===----------------------------------------------------------------------===//
273 // CooperativeMatrixType
274 //===----------------------------------------------------------------------===//
276 struct spirv::detail::CooperativeMatrixTypeStorage : public TypeStorage {
277 using KeyTy = std::tuple<Type, Scope, unsigned, unsigned>;
279 static CooperativeMatrixTypeStorage *
280 construct(TypeStorageAllocator &allocator, const KeyTy &key) {
281 return new (allocator.allocate<CooperativeMatrixTypeStorage>())
282 CooperativeMatrixTypeStorage(key);
285 bool operator==(const KeyTy &key) const {
286 return key == KeyTy(elementType, scope, rows, columns);
289 CooperativeMatrixTypeStorage(const KeyTy &key)
290 : elementType(std::get<0>(key)), rows(std::get<2>(key)),
291 columns(std::get<3>(key)), scope(std::get<1>(key)) {}
293 Type elementType;
294 unsigned rows;
295 unsigned columns;
296 Scope scope;
299 CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType,
300 Scope scope, unsigned rows,
301 unsigned columns) {
302 return Base::get(elementType.getContext(), elementType, scope, rows, columns);
305 Type CooperativeMatrixNVType::getElementType() const {
306 return getImpl()->elementType;
309 Scope CooperativeMatrixNVType::getScope() const { return getImpl()->scope; }
311 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows; }
313 unsigned CooperativeMatrixNVType::getColumns() const {
314 return getImpl()->columns;
317 void CooperativeMatrixNVType::getExtensions(
318 SPIRVType::ExtensionArrayRefVector &extensions,
319 Optional<StorageClass> storage) {
320 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
321 static const Extension exts[] = {Extension::SPV_NV_cooperative_matrix};
322 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
323 extensions.push_back(ref);
326 void CooperativeMatrixNVType::getCapabilities(
327 SPIRVType::CapabilityArrayRefVector &capabilities,
328 Optional<StorageClass> storage) {
329 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
330 static const Capability caps[] = {Capability::CooperativeMatrixNV};
331 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
332 capabilities.push_back(ref);
335 //===----------------------------------------------------------------------===//
336 // ImageType
337 //===----------------------------------------------------------------------===//
339 template <typename T> static constexpr unsigned getNumBits() { return 0; }
340 template <> constexpr unsigned getNumBits<Dim>() {
341 static_assert((1 << 3) > getMaxEnumValForDim(),
342 "Not enough bits to encode Dim value");
343 return 3;
345 template <> constexpr unsigned getNumBits<ImageDepthInfo>() {
346 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
347 "Not enough bits to encode ImageDepthInfo value");
348 return 2;
350 template <> constexpr unsigned getNumBits<ImageArrayedInfo>() {
351 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
352 "Not enough bits to encode ImageArrayedInfo value");
353 return 1;
355 template <> constexpr unsigned getNumBits<ImageSamplingInfo>() {
356 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
357 "Not enough bits to encode ImageSamplingInfo value");
358 return 1;
360 template <> constexpr unsigned getNumBits<ImageSamplerUseInfo>() {
361 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
362 "Not enough bits to encode ImageSamplerUseInfo value");
363 return 2;
365 template <> constexpr unsigned getNumBits<ImageFormat>() {
366 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
367 "Not enough bits to encode ImageFormat value");
368 return 6;
371 struct spirv::detail::ImageTypeStorage : public TypeStorage {
372 public:
373 using KeyTy = std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
374 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>;
376 static ImageTypeStorage *construct(TypeStorageAllocator &allocator,
377 const KeyTy &key) {
378 return new (allocator.allocate<ImageTypeStorage>()) ImageTypeStorage(key);
381 bool operator==(const KeyTy &key) const {
382 return key == KeyTy(elementType, dim, depthInfo, arrayedInfo, samplingInfo,
383 samplerUseInfo, format);
386 ImageTypeStorage(const KeyTy &key)
387 : elementType(std::get<0>(key)), dim(std::get<1>(key)),
388 depthInfo(std::get<2>(key)), arrayedInfo(std::get<3>(key)),
389 samplingInfo(std::get<4>(key)), samplerUseInfo(std::get<5>(key)),
390 format(std::get<6>(key)) {}
392 Type elementType;
393 Dim dim : getNumBits<Dim>();
394 ImageDepthInfo depthInfo : getNumBits<ImageDepthInfo>();
395 ImageArrayedInfo arrayedInfo : getNumBits<ImageArrayedInfo>();
396 ImageSamplingInfo samplingInfo : getNumBits<ImageSamplingInfo>();
397 ImageSamplerUseInfo samplerUseInfo : getNumBits<ImageSamplerUseInfo>();
398 ImageFormat format : getNumBits<ImageFormat>();
401 ImageType
402 ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
403 ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
404 value) {
405 return Base::get(std::get<0>(value).getContext(), value);
408 Type ImageType::getElementType() const { return getImpl()->elementType; }
410 Dim ImageType::getDim() const { return getImpl()->dim; }
412 ImageDepthInfo ImageType::getDepthInfo() const { return getImpl()->depthInfo; }
414 ImageArrayedInfo ImageType::getArrayedInfo() const {
415 return getImpl()->arrayedInfo;
418 ImageSamplingInfo ImageType::getSamplingInfo() const {
419 return getImpl()->samplingInfo;
422 ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
423 return getImpl()->samplerUseInfo;
426 ImageFormat ImageType::getImageFormat() const { return getImpl()->format; }
428 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector &,
429 Optional<StorageClass>) {
430 // Image types do not require extra extensions thus far.
433 void ImageType::getCapabilities(
434 SPIRVType::CapabilityArrayRefVector &capabilities, Optional<StorageClass>) {
435 if (auto dimCaps = spirv::getCapabilities(getDim()))
436 capabilities.push_back(*dimCaps);
438 if (auto fmtCaps = spirv::getCapabilities(getImageFormat()))
439 capabilities.push_back(*fmtCaps);
442 //===----------------------------------------------------------------------===//
443 // PointerType
444 //===----------------------------------------------------------------------===//
446 struct spirv::detail::PointerTypeStorage : public TypeStorage {
447 // (Type, StorageClass) as the key: Type stored in this struct, and
448 // StorageClass stored as TypeStorage's subclass data.
449 using KeyTy = std::pair<Type, StorageClass>;
451 static PointerTypeStorage *construct(TypeStorageAllocator &allocator,
452 const KeyTy &key) {
453 return new (allocator.allocate<PointerTypeStorage>())
454 PointerTypeStorage(key);
457 bool operator==(const KeyTy &key) const {
458 return key == KeyTy(pointeeType, storageClass);
461 PointerTypeStorage(const KeyTy &key)
462 : pointeeType(key.first), storageClass(key.second) {}
464 Type pointeeType;
465 StorageClass storageClass;
468 PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
469 return Base::get(pointeeType.getContext(), pointeeType, storageClass);
472 Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
474 StorageClass PointerType::getStorageClass() const {
475 return getImpl()->storageClass;
478 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
479 Optional<StorageClass> storage) {
480 // Use this pointer type's storage class because this pointer indicates we are
481 // using the pointee type in that specific storage class.
482 getPointeeType().cast<SPIRVType>().getExtensions(extensions,
483 getStorageClass());
485 if (auto scExts = spirv::getExtensions(getStorageClass()))
486 extensions.push_back(*scExts);
489 void PointerType::getCapabilities(
490 SPIRVType::CapabilityArrayRefVector &capabilities,
491 Optional<StorageClass> storage) {
492 // Use this pointer type's storage class because this pointer indicates we are
493 // using the pointee type in that specific storage class.
494 getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,
495 getStorageClass());
497 if (auto scCaps = spirv::getCapabilities(getStorageClass()))
498 capabilities.push_back(*scCaps);
501 //===----------------------------------------------------------------------===//
502 // RuntimeArrayType
503 //===----------------------------------------------------------------------===//
505 struct spirv::detail::RuntimeArrayTypeStorage : public TypeStorage {
506 using KeyTy = std::pair<Type, unsigned>;
508 static RuntimeArrayTypeStorage *construct(TypeStorageAllocator &allocator,
509 const KeyTy &key) {
510 return new (allocator.allocate<RuntimeArrayTypeStorage>())
511 RuntimeArrayTypeStorage(key);
514 bool operator==(const KeyTy &key) const {
515 return key == KeyTy(elementType, stride);
518 RuntimeArrayTypeStorage(const KeyTy &key)
519 : elementType(key.first), stride(key.second) {}
521 Type elementType;
522 unsigned stride;
525 RuntimeArrayType RuntimeArrayType::get(Type elementType) {
526 return Base::get(elementType.getContext(), elementType, /*stride=*/0);
529 RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) {
530 return Base::get(elementType.getContext(), elementType, stride);
533 Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
535 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride; }
537 void RuntimeArrayType::getExtensions(
538 SPIRVType::ExtensionArrayRefVector &extensions,
539 Optional<StorageClass> storage) {
540 getElementType().cast<SPIRVType>().getExtensions(extensions, storage);
543 void RuntimeArrayType::getCapabilities(
544 SPIRVType::CapabilityArrayRefVector &capabilities,
545 Optional<StorageClass> storage) {
547 static const Capability caps[] = {Capability::Shader};
548 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
549 capabilities.push_back(ref);
551 getElementType().cast<SPIRVType>().getCapabilities(capabilities, storage);
554 //===----------------------------------------------------------------------===//
555 // ScalarType
556 //===----------------------------------------------------------------------===//
558 bool ScalarType::classof(Type type) {
559 if (auto floatType = type.dyn_cast<FloatType>()) {
560 return isValid(floatType);
562 if (auto intType = type.dyn_cast<IntegerType>()) {
563 return isValid(intType);
565 return false;
568 bool ScalarType::isValid(FloatType type) { return !type.isBF16(); }
570 bool ScalarType::isValid(IntegerType type) {
571 switch (type.getWidth()) {
572 case 1:
573 case 8:
574 case 16:
575 case 32:
576 case 64:
577 return true;
578 default:
579 return false;
583 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
584 Optional<StorageClass> storage) {
585 // 8- or 16-bit integer/floating-point numbers will require extra extensions
586 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
587 // SPV_KHR_8bit_storage for more details.
588 if (!storage)
589 return;
591 switch (*storage) {
592 case StorageClass::PushConstant:
593 case StorageClass::StorageBuffer:
594 case StorageClass::Uniform:
595 if (getIntOrFloatBitWidth() == 8) {
596 static const Extension exts[] = {Extension::SPV_KHR_8bit_storage};
597 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
598 extensions.push_back(ref);
600 LLVM_FALLTHROUGH;
601 case StorageClass::Input:
602 case StorageClass::Output:
603 if (getIntOrFloatBitWidth() == 16) {
604 static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
605 ArrayRef<Extension> ref(exts, llvm::array_lengthof(exts));
606 extensions.push_back(ref);
608 break;
609 default:
610 break;
614 void ScalarType::getCapabilities(
615 SPIRVType::CapabilityArrayRefVector &capabilities,
616 Optional<StorageClass> storage) {
617 unsigned bitwidth = getIntOrFloatBitWidth();
619 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
620 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
621 // SPV_KHR_8bit_storage for more details.
623 #define STORAGE_CASE(storage, cap8, cap16) \
624 case StorageClass::storage: { \
625 if (bitwidth == 8) { \
626 static const Capability caps[] = {Capability::cap8}; \
627 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
628 capabilities.push_back(ref); \
629 } else if (bitwidth == 16) { \
630 static const Capability caps[] = {Capability::cap16}; \
631 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
632 capabilities.push_back(ref); \
634 /* No requirements for other bitwidths */ \
635 return; \
638 // This part only handles the cases where special bitwidths appearing in
639 // interface storage classes.
640 if (storage) {
641 switch (*storage) {
642 STORAGE_CASE(PushConstant, StoragePushConstant8, StoragePushConstant16);
643 STORAGE_CASE(StorageBuffer, StorageBuffer8BitAccess,
644 StorageBuffer16BitAccess);
645 STORAGE_CASE(Uniform, UniformAndStorageBuffer8BitAccess,
646 StorageUniform16);
647 case StorageClass::Input:
648 case StorageClass::Output: {
649 if (bitwidth == 16) {
650 static const Capability caps[] = {Capability::StorageInputOutput16};
651 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
652 capabilities.push_back(ref);
654 return;
656 default:
657 break;
660 #undef STORAGE_CASE
662 // For other non-interface storage classes, require a different set of
663 // capabilities for special bitwidths.
665 #define WIDTH_CASE(type, width) \
666 case width: { \
667 static const Capability caps[] = {Capability::type##width}; \
668 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
669 capabilities.push_back(ref); \
670 } break
672 if (auto intType = dyn_cast<IntegerType>()) {
673 switch (bitwidth) {
674 case 32:
675 case 1:
676 break;
677 WIDTH_CASE(Int, 8);
678 WIDTH_CASE(Int, 16);
679 WIDTH_CASE(Int, 64);
680 default:
681 llvm_unreachable("invalid bitwidth to getCapabilities");
683 } else {
684 assert(isa<FloatType>());
685 switch (bitwidth) {
686 case 32:
687 break;
688 WIDTH_CASE(Float, 16);
689 WIDTH_CASE(Float, 64);
690 default:
691 llvm_unreachable("invalid bitwidth to getCapabilities");
695 #undef WIDTH_CASE
698 Optional<int64_t> ScalarType::getSizeInBytes() {
699 auto bitWidth = getIntOrFloatBitWidth();
700 // According to the SPIR-V spec:
701 // "There is no physical size or bit pattern defined for values with boolean
702 // type. If they are stored (in conjunction with OpVariable), they can only
703 // be used with logical addressing operations, not physical, and only with
704 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
705 // Private, Function, Input, and Output."
706 if (bitWidth == 1)
707 return llvm::None;
708 return bitWidth / 8;
711 //===----------------------------------------------------------------------===//
712 // SPIRVType
713 //===----------------------------------------------------------------------===//
715 bool SPIRVType::classof(Type type) {
716 // Allow SPIR-V dialect types
717 if (llvm::isa<SPIRVDialect>(type.getDialect()))
718 return true;
719 if (type.isa<ScalarType>())
720 return true;
721 if (auto vectorType = type.dyn_cast<VectorType>())
722 return CompositeType::isValid(vectorType);
723 return false;
726 bool SPIRVType::isScalarOrVector() {
727 return isIntOrFloat() || isa<VectorType>();
730 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
731 Optional<StorageClass> storage) {
732 if (auto scalarType = dyn_cast<ScalarType>()) {
733 scalarType.getExtensions(extensions, storage);
734 } else if (auto compositeType = dyn_cast<CompositeType>()) {
735 compositeType.getExtensions(extensions, storage);
736 } else if (auto imageType = dyn_cast<ImageType>()) {
737 imageType.getExtensions(extensions, storage);
738 } else if (auto matrixType = dyn_cast<MatrixType>()) {
739 matrixType.getExtensions(extensions, storage);
740 } else if (auto ptrType = dyn_cast<PointerType>()) {
741 ptrType.getExtensions(extensions, storage);
742 } else {
743 llvm_unreachable("invalid SPIR-V Type to getExtensions");
747 void SPIRVType::getCapabilities(
748 SPIRVType::CapabilityArrayRefVector &capabilities,
749 Optional<StorageClass> storage) {
750 if (auto scalarType = dyn_cast<ScalarType>()) {
751 scalarType.getCapabilities(capabilities, storage);
752 } else if (auto compositeType = dyn_cast<CompositeType>()) {
753 compositeType.getCapabilities(capabilities, storage);
754 } else if (auto imageType = dyn_cast<ImageType>()) {
755 imageType.getCapabilities(capabilities, storage);
756 } else if (auto matrixType = dyn_cast<MatrixType>()) {
757 matrixType.getCapabilities(capabilities, storage);
758 } else if (auto ptrType = dyn_cast<PointerType>()) {
759 ptrType.getCapabilities(capabilities, storage);
760 } else {
761 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
765 Optional<int64_t> SPIRVType::getSizeInBytes() {
766 if (auto scalarType = dyn_cast<ScalarType>())
767 return scalarType.getSizeInBytes();
768 if (auto compositeType = dyn_cast<CompositeType>())
769 return compositeType.getSizeInBytes();
770 return llvm::None;
773 //===----------------------------------------------------------------------===//
774 // StructType
775 //===----------------------------------------------------------------------===//
777 /// Type storage for SPIR-V structure types:
779 /// Structures are uniqued using:
780 /// - for identified structs:
781 /// - a string identifier;
782 /// - for literal structs:
783 /// - a list of member types;
784 /// - a list of member offset info;
785 /// - a list of member decoration info.
787 /// Identified structures only have a mutable component consisting of:
788 /// - a list of member types;
789 /// - a list of member offset info;
790 /// - a list of member decoration info.
791 struct spirv::detail::StructTypeStorage : public TypeStorage {
792 /// Construct a storage object for an identified struct type. A struct type
793 /// associated with such storage must call StructType::trySetBody(...) later
794 /// in order to mutate the storage object providing the actual content.
795 StructTypeStorage(StringRef identifier)
796 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
797 numMemberDecorations(0), memberDecorationsInfo(nullptr),
798 identifier(identifier) {}
800 /// Construct a storage object for a literal struct type. A struct type
801 /// associated with such storage is immutable.
802 StructTypeStorage(
803 unsigned numMembers, Type const *memberTypes,
804 StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
805 StructType::MemberDecorationInfo const *memberDecorationsInfo)
806 : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
807 numMembers(numMembers), numMemberDecorations(numMemberDecorations),
808 memberDecorationsInfo(memberDecorationsInfo), identifier(StringRef()) {}
810 /// A storage key is divided into 2 parts:
811 /// - for identified structs:
812 /// - a StringRef representing the struct identifier;
813 /// - for literal structs:
814 /// - an ArrayRef<Type> for member types;
815 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
816 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
817 /// info.
819 /// An identified struct type is uniqued only by the first part (field 0)
820 /// of the key.
822 /// A literal struct type is unqiued only by the second part (fields 1, 2, and
823 /// 3) of the key. The identifier field (field 0) must be empty.
824 using KeyTy =
825 std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
826 ArrayRef<StructType::MemberDecorationInfo>>;
828 /// For identified structs, return true if the given key contains the same
829 /// identifier.
831 /// For literal structs, return true if the given key contains a matching list
832 /// of member types + offset info + decoration info.
833 bool operator==(const KeyTy &key) const {
834 if (isIdentified()) {
835 // Identified types are uniqued by their identifier.
836 return getIdentifier() == std::get<0>(key);
839 return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
840 getMemberDecorationsInfo());
843 /// If the given key contains a non-empty identifier, this method constructs
844 /// an identified struct and leaves the rest of the struct type data to be set
845 /// through a later call to StructType::trySetBody(...).
847 /// If, on the other hand, the key contains an empty identifier, a literal
848 /// struct is constructed using the other fields of the key.
849 static StructTypeStorage *construct(TypeStorageAllocator &allocator,
850 const KeyTy &key) {
851 StringRef keyIdentifier = std::get<0>(key);
853 if (!keyIdentifier.empty()) {
854 StringRef identifier = allocator.copyInto(keyIdentifier);
856 // Identified StructType body/members will be set through trySetBody(...)
857 // later.
858 return new (allocator.allocate<StructTypeStorage>())
859 StructTypeStorage(identifier);
862 ArrayRef<Type> keyTypes = std::get<1>(key);
864 // Copy the member type and layout information into the bump pointer
865 const Type *typesList = nullptr;
866 if (!keyTypes.empty()) {
867 typesList = allocator.copyInto(keyTypes).data();
870 const StructType::OffsetInfo *offsetInfoList = nullptr;
871 if (!std::get<2>(key).empty()) {
872 ArrayRef<StructType::OffsetInfo> keyOffsetInfo = std::get<2>(key);
873 assert(keyOffsetInfo.size() == keyTypes.size() &&
874 "size of offset information must be same as the size of number of "
875 "elements");
876 offsetInfoList = allocator.copyInto(keyOffsetInfo).data();
879 const StructType::MemberDecorationInfo *memberDecorationList = nullptr;
880 unsigned numMemberDecorations = 0;
881 if (!std::get<3>(key).empty()) {
882 auto keyMemberDecorations = std::get<3>(key);
883 numMemberDecorations = keyMemberDecorations.size();
884 memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
887 return new (allocator.allocate<StructTypeStorage>())
888 StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
889 numMemberDecorations, memberDecorationList);
892 ArrayRef<Type> getMemberTypes() const {
893 return ArrayRef<Type>(memberTypesAndIsBodySet.getPointer(), numMembers);
896 ArrayRef<StructType::OffsetInfo> getOffsetInfo() const {
897 if (offsetInfo) {
898 return ArrayRef<StructType::OffsetInfo>(offsetInfo, numMembers);
900 return {};
903 ArrayRef<StructType::MemberDecorationInfo> getMemberDecorationsInfo() const {
904 if (memberDecorationsInfo) {
905 return ArrayRef<StructType::MemberDecorationInfo>(memberDecorationsInfo,
906 numMemberDecorations);
908 return {};
911 StringRef getIdentifier() const { return identifier; }
913 bool isIdentified() const { return !identifier.empty(); }
915 /// Sets the struct type content for identified structs. Calling this method
916 /// is only valid for identified structs.
918 /// Fails under the following conditions:
919 /// - If called for a literal struct;
920 /// - If called for an identified struct whose body was set before (through a
921 /// call to this method) but with different contents from the passed
922 /// arguments.
923 LogicalResult mutate(
924 TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
925 ArrayRef<StructType::OffsetInfo> structOffsetInfo,
926 ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
927 if (!isIdentified())
928 return failure();
930 if (memberTypesAndIsBodySet.getInt() &&
931 (getMemberTypes() != structMemberTypes ||
932 getOffsetInfo() != structOffsetInfo ||
933 getMemberDecorationsInfo() != structMemberDecorationInfo))
934 return failure();
936 memberTypesAndIsBodySet.setInt(true);
937 numMembers = structMemberTypes.size();
939 // Copy the member type and layout information into the bump pointer.
940 if (!structMemberTypes.empty())
941 memberTypesAndIsBodySet.setPointer(
942 allocator.copyInto(structMemberTypes).data());
944 if (!structOffsetInfo.empty()) {
945 assert(structOffsetInfo.size() == structMemberTypes.size() &&
946 "size of offset information must be same as the size of number of "
947 "elements");
948 offsetInfo = allocator.copyInto(structOffsetInfo).data();
951 if (!structMemberDecorationInfo.empty()) {
952 numMemberDecorations = structMemberDecorationInfo.size();
953 memberDecorationsInfo =
954 allocator.copyInto(structMemberDecorationInfo).data();
957 return success();
// Pointer to the member type array. The int bit records whether the body of
// an identified struct has already been set (see mutate()).
llvm::PointerIntPair<Type const *, 1, bool> memberTypesAndIsBodySet;
// Per-member offsets; null when no offset information was provided.
StructType::OffsetInfo const *offsetInfo;
// Length of the member type array (and of the offset array when present).
unsigned numMembers;
// Length of the memberDecorationsInfo array.
unsigned numMemberDecorations;
// Member decorations; null when there are none (construct() leaves it null
// for an empty list -- NOTE(review): initial value for identified structs is
// set by a constructor outside this view).
StructType::MemberDecorationInfo const *memberDecorationsInfo;
// Struct name; isIdentified() treats an empty name as a literal struct.
StringRef identifier;
968 StructType
969 StructType::get(ArrayRef<Type> memberTypes,
970 ArrayRef<StructType::OffsetInfo> offsetInfo,
971 ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
972 assert(!memberTypes.empty() && "Struct needs at least one member type");
973 // Sort the decorations.
974 SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
975 memberDecorations.begin(), memberDecorations.end());
976 llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
977 return Base::get(memberTypes.vec().front().getContext(),
978 /*identifier=*/StringRef(), memberTypes, offsetInfo,
979 sortedDecorations);
982 StructType StructType::getIdentified(MLIRContext *context,
983 StringRef identifier) {
984 assert(!identifier.empty() &&
985 "StructType identifier must be non-empty string");
987 return Base::get(context, identifier, ArrayRef<Type>(),
988 ArrayRef<StructType::OffsetInfo>(),
989 ArrayRef<StructType::MemberDecorationInfo>());
992 StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
993 StructType newStructType = Base::get(
994 context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
995 ArrayRef<StructType::MemberDecorationInfo>());
996 // Set an empty body in case this is a identified struct.
997 if (newStructType.isIdentified() &&
998 failed(newStructType.trySetBody(
999 ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
1000 ArrayRef<StructType::MemberDecorationInfo>())))
1001 return StructType();
1003 return newStructType;
1006 StringRef StructType::getIdentifier() const { return getImpl()->identifier; }
1008 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1010 unsigned StructType::getNumElements() const { return getImpl()->numMembers; }
1012 Type StructType::getElementType(unsigned index) const {
1013 assert(getNumElements() > index && "member index out of range");
1014 return getImpl()->memberTypesAndIsBodySet.getPointer()[index];
1017 StructType::ElementTypeRange StructType::getElementTypes() const {
1018 return ElementTypeRange(getImpl()->memberTypesAndIsBodySet.getPointer(),
1019 getNumElements());
1022 bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
1024 uint64_t StructType::getMemberOffset(unsigned index) const {
1025 assert(getNumElements() > index && "member index out of range");
1026 return getImpl()->offsetInfo[index];
1029 void StructType::getMemberDecorations(
1030 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorations)
1031 const {
1032 memberDecorations.clear();
1033 auto implMemberDecorations = getImpl()->getMemberDecorationsInfo();
1034 memberDecorations.append(implMemberDecorations.begin(),
1035 implMemberDecorations.end());
1038 void StructType::getMemberDecorations(
1039 unsigned index,
1040 SmallVectorImpl<StructType::MemberDecorationInfo> &decorationsInfo) const {
1041 assert(getNumElements() > index && "member index out of range");
1042 auto memberDecorations = getImpl()->getMemberDecorationsInfo();
1043 decorationsInfo.clear();
1044 for (const auto &memberDecoration : memberDecorations) {
1045 if (memberDecoration.memberIndex == index) {
1046 decorationsInfo.push_back(memberDecoration);
1048 if (memberDecoration.memberIndex > index) {
1049 // Early exit since the decorations are stored sorted.
1050 return;
1055 LogicalResult
1056 StructType::trySetBody(ArrayRef<Type> memberTypes,
1057 ArrayRef<OffsetInfo> offsetInfo,
1058 ArrayRef<MemberDecorationInfo> memberDecorations) {
1059 return Base::mutate(memberTypes, offsetInfo, memberDecorations);
1062 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1063 Optional<StorageClass> storage) {
1064 for (Type elementType : getElementTypes())
1065 elementType.cast<SPIRVType>().getExtensions(extensions, storage);
1068 void StructType::getCapabilities(
1069 SPIRVType::CapabilityArrayRefVector &capabilities,
1070 Optional<StorageClass> storage) {
1071 for (Type elementType : getElementTypes())
1072 elementType.cast<SPIRVType>().getCapabilities(capabilities, storage);
1075 llvm::hash_code spirv::hash_value(
1076 const StructType::MemberDecorationInfo &memberDecorationInfo) {
1077 return llvm::hash_combine(memberDecorationInfo.memberIndex,
1078 memberDecorationInfo.decoration);
1081 //===----------------------------------------------------------------------===//
1082 // MatrixType
1083 //===----------------------------------------------------------------------===//
1085 struct spirv::detail::MatrixTypeStorage : public TypeStorage {
1086 MatrixTypeStorage(Type columnType, uint32_t columnCount)
1087 : TypeStorage(), columnType(columnType), columnCount(columnCount) {}
1089 using KeyTy = std::tuple<Type, uint32_t>;
1091 static MatrixTypeStorage *construct(TypeStorageAllocator &allocator,
1092 const KeyTy &key) {
1094 // Initialize the memory using placement new.
1095 return new (allocator.allocate<MatrixTypeStorage>())
1096 MatrixTypeStorage(std::get<0>(key), std::get<1>(key));
1099 bool operator==(const KeyTy &key) const {
1100 return key == KeyTy(columnType, columnCount);
1103 Type columnType;
1104 const uint32_t columnCount;
1107 MatrixType MatrixType::get(Type columnType, uint32_t columnCount) {
1108 return Base::get(columnType.getContext(), columnType, columnCount);
1111 MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount,
1112 Location location) {
1113 return Base::getChecked(location, columnType, columnCount);
1116 LogicalResult MatrixType::verifyConstructionInvariants(Location loc,
1117 Type columnType,
1118 uint32_t columnCount) {
1119 if (columnCount < 2 || columnCount > 4)
1120 return emitError(loc, "matrix can have 2, 3, or 4 columns only");
1122 if (!isValidColumnType(columnType))
1123 return emitError(loc, "matrix columns must be vectors of floats");
1125 /// The underlying vectors (columns) must be of size 2, 3, or 4
1126 ArrayRef<int64_t> columnShape = columnType.cast<VectorType>().getShape();
1127 if (columnShape.size() != 1)
1128 return emitError(loc, "matrix columns must be 1D vectors");
1130 if (columnShape[0] < 2 || columnShape[0] > 4)
1131 return emitError(loc, "matrix columns must be of size 2, 3, or 4");
1133 return success();
1136 /// Returns true if the matrix elements are vectors of float elements
1137 bool MatrixType::isValidColumnType(Type columnType) {
1138 if (auto vectorType = columnType.dyn_cast<VectorType>()) {
1139 if (vectorType.getElementType().isa<FloatType>())
1140 return true;
1142 return false;
1145 Type MatrixType::getColumnType() const { return getImpl()->columnType; }
1147 Type MatrixType::getElementType() const {
1148 return getImpl()->columnType.cast<VectorType>().getElementType();
1151 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; }
1153 unsigned MatrixType::getNumRows() const {
1154 return getImpl()->columnType.cast<VectorType>().getShape()[0];
1157 unsigned MatrixType::getNumElements() const {
1158 return (getImpl()->columnCount) * getNumRows();
1161 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
1162 Optional<StorageClass> storage) {
1163 getColumnType().cast<SPIRVType>().getExtensions(extensions, storage);
1166 void MatrixType::getCapabilities(
1167 SPIRVType::CapabilityArrayRefVector &capabilities,
1168 Optional<StorageClass> storage) {
1170 static const Capability caps[] = {Capability::Matrix};
1171 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps));
1172 capabilities.push_back(ref);
1174 // Add any capabilities associated with the underlying vectors (i.e., columns)
1175 getColumnType().cast<SPIRVType>().getCapabilities(capabilities, storage);