1 //===- SPIRVTypes.cpp - MLIR SPIR-V Types ---------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines the types in the SPIR-V dialect.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
14 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
15 #include "mlir/IR/Attributes.h"
16 #include "mlir/IR/BuiltinTypes.h"
17 #include "mlir/IR/Identifier.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/ADT/StringSwitch.h"
22 #include "llvm/ADT/TypeSwitch.h"
25 using namespace mlir::spirv
;
27 // Pull in all enum utility function definitions
28 #include "mlir/Dialect/SPIRV/SPIRVEnums.cpp.inc"
29 // Pull in all enum type availability query function definitions
30 #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.cpp.inc"
32 //===----------------------------------------------------------------------===//
33 // Availability relationship
34 //===----------------------------------------------------------------------===//
36 ArrayRef
<Extension
> spirv::getImpliedExtensions(Version version
) {
37 // Note: the following lists are from "Appendix A: Changes" of the spec.
39 #define V_1_3_IMPLIED_EXTS \
40 Extension::SPV_KHR_shader_draw_parameters, Extension::SPV_KHR_16bit_storage, \
41 Extension::SPV_KHR_device_group, Extension::SPV_KHR_multiview, \
42 Extension::SPV_KHR_storage_buffer_storage_class, \
43 Extension::SPV_KHR_variable_pointers
45 #define V_1_4_IMPLIED_EXTS \
46 Extension::SPV_KHR_no_integer_wrap_decoration, \
47 Extension::SPV_GOOGLE_decorate_string, \
48 Extension::SPV_GOOGLE_hlsl_functionality1, \
49 Extension::SPV_KHR_float_controls
51 #define V_1_5_IMPLIED_EXTS \
52 Extension::SPV_KHR_8bit_storage, Extension::SPV_EXT_descriptor_indexing, \
53 Extension::SPV_EXT_shader_viewport_index_layer, \
54 Extension::SPV_EXT_physical_storage_buffer, \
55 Extension::SPV_KHR_physical_storage_buffer, \
56 Extension::SPV_KHR_vulkan_memory_model
61 case Version::V_1_3
: {
62 // The following manual ArrayRef constructor call is to satisfy GCC 5.
63 static const Extension exts
[] = {V_1_3_IMPLIED_EXTS
};
64 return ArrayRef
<Extension
>(exts
, llvm::array_lengthof(exts
));
66 case Version::V_1_4
: {
67 static const Extension exts
[] = {V_1_3_IMPLIED_EXTS
, V_1_4_IMPLIED_EXTS
};
68 return ArrayRef
<Extension
>(exts
, llvm::array_lengthof(exts
));
70 case Version::V_1_5
: {
71 static const Extension exts
[] = {V_1_3_IMPLIED_EXTS
, V_1_4_IMPLIED_EXTS
,
73 return ArrayRef
<Extension
>(exts
, llvm::array_lengthof(exts
));
77 #undef V_1_5_IMPLIED_EXTS
78 #undef V_1_4_IMPLIED_EXTS
79 #undef V_1_3_IMPLIED_EXTS
82 // Pull in utility function definition for implied capabilities
83 #include "mlir/Dialect/SPIRV/SPIRVCapabilityImplication.inc"
85 SmallVector
<Capability
, 0>
86 spirv::getRecursiveImpliedCapabilities(Capability cap
) {
87 ArrayRef
<Capability
> directCaps
= getDirectImpliedCapabilities(cap
);
88 llvm::SetVector
<Capability
, SmallVector
<Capability
, 0>> allCaps(
89 directCaps
.begin(), directCaps
.end());
91 // TODO: This is inefficient; find a better way to handle this
92 // (e.g., using static lists) if this turns out to be a bottleneck.
93 for (unsigned i
= 0; i
< allCaps
.size(); ++i
)
94 for (Capability c
: getDirectImpliedCapabilities(allCaps
[i
]))
97 return allCaps
.takeVector();
100 //===----------------------------------------------------------------------===//
102 //===----------------------------------------------------------------------===//
104 struct spirv::detail::ArrayTypeStorage
: public TypeStorage
{
105 using KeyTy
= std::tuple
<Type
, unsigned, unsigned>;
107 static ArrayTypeStorage
*construct(TypeStorageAllocator
&allocator
,
109 return new (allocator
.allocate
<ArrayTypeStorage
>()) ArrayTypeStorage(key
);
112 bool operator==(const KeyTy
&key
) const {
113 return key
== KeyTy(elementType
, elementCount
, stride
);
116 ArrayTypeStorage(const KeyTy
&key
)
117 : elementType(std::get
<0>(key
)), elementCount(std::get
<1>(key
)),
118 stride(std::get
<2>(key
)) {}
121 unsigned elementCount
;
125 ArrayType
ArrayType::get(Type elementType
, unsigned elementCount
) {
126 assert(elementCount
&& "ArrayType needs at least one element");
127 return Base::get(elementType
.getContext(), elementType
, elementCount
,
131 ArrayType
ArrayType::get(Type elementType
, unsigned elementCount
,
133 assert(elementCount
&& "ArrayType needs at least one element");
134 return Base::get(elementType
.getContext(), elementType
, elementCount
, stride
);
137 unsigned ArrayType::getNumElements() const { return getImpl()->elementCount
; }
139 Type
ArrayType::getElementType() const { return getImpl()->elementType
; }
141 unsigned ArrayType::getArrayStride() const { return getImpl()->stride
; }
143 void ArrayType::getExtensions(SPIRVType::ExtensionArrayRefVector
&extensions
,
144 Optional
<StorageClass
> storage
) {
145 getElementType().cast
<SPIRVType
>().getExtensions(extensions
, storage
);
148 void ArrayType::getCapabilities(
149 SPIRVType::CapabilityArrayRefVector
&capabilities
,
150 Optional
<StorageClass
> storage
) {
151 getElementType().cast
<SPIRVType
>().getCapabilities(capabilities
, storage
);
154 Optional
<int64_t> ArrayType::getSizeInBytes() {
155 auto elementType
= getElementType().cast
<SPIRVType
>();
156 Optional
<int64_t> size
= elementType
.getSizeInBytes();
159 return (*size
+ getArrayStride()) * getNumElements();
162 //===----------------------------------------------------------------------===//
164 //===----------------------------------------------------------------------===//
166 bool CompositeType::classof(Type type
) {
167 if (auto vectorType
= type
.dyn_cast
<VectorType
>())
168 return isValid(vectorType
);
170 .isa
<spirv::ArrayType
, spirv::CooperativeMatrixNVType
, spirv::MatrixType
,
171 spirv::RuntimeArrayType
, spirv::StructType
>();
174 bool CompositeType::isValid(VectorType type
) {
175 switch (type
.getNumElements()) {
185 return type
.getRank() == 1 && type
.getElementType().isa
<ScalarType
>();
188 Type
CompositeType::getElementType(unsigned index
) const {
189 return TypeSwitch
<Type
, Type
>(*this)
190 .Case
<ArrayType
, CooperativeMatrixNVType
, RuntimeArrayType
, VectorType
>(
191 [](auto type
) { return type
.getElementType(); })
192 .Case
<MatrixType
>([](MatrixType type
) { return type
.getColumnType(); })
194 [index
](StructType type
) { return type
.getElementType(index
); })
196 [](Type
) -> Type
{ llvm_unreachable("invalid composite type"); });
199 unsigned CompositeType::getNumElements() const {
200 if (auto arrayType
= dyn_cast
<ArrayType
>())
201 return arrayType
.getNumElements();
202 if (auto matrixType
= dyn_cast
<MatrixType
>())
203 return matrixType
.getNumColumns();
204 if (auto structType
= dyn_cast
<StructType
>())
205 return structType
.getNumElements();
206 if (auto vectorType
= dyn_cast
<VectorType
>())
207 return vectorType
.getNumElements();
208 if (isa
<CooperativeMatrixNVType
>()) {
210 "invalid to query number of elements of spirv::CooperativeMatrix type");
212 if (isa
<RuntimeArrayType
>()) {
214 "invalid to query number of elements of spirv::RuntimeArray type");
216 llvm_unreachable("invalid composite type");
219 bool CompositeType::hasCompileTimeKnownNumElements() const {
220 return !isa
<CooperativeMatrixNVType
, RuntimeArrayType
>();
223 void CompositeType::getExtensions(
224 SPIRVType::ExtensionArrayRefVector
&extensions
,
225 Optional
<StorageClass
> storage
) {
226 TypeSwitch
<Type
>(*this)
227 .Case
<ArrayType
, CooperativeMatrixNVType
, MatrixType
, RuntimeArrayType
,
229 [&](auto type
) { type
.getExtensions(extensions
, storage
); })
230 .Case
<VectorType
>([&](VectorType type
) {
231 return type
.getElementType().cast
<ScalarType
>().getExtensions(
232 extensions
, storage
);
234 .Default([](Type
) { llvm_unreachable("invalid composite type"); });
237 void CompositeType::getCapabilities(
238 SPIRVType::CapabilityArrayRefVector
&capabilities
,
239 Optional
<StorageClass
> storage
) {
240 TypeSwitch
<Type
>(*this)
241 .Case
<ArrayType
, CooperativeMatrixNVType
, MatrixType
, RuntimeArrayType
,
243 [&](auto type
) { type
.getCapabilities(capabilities
, storage
); })
244 .Case
<VectorType
>([&](VectorType type
) {
245 auto vecSize
= getNumElements();
246 if (vecSize
== 8 || vecSize
== 16) {
247 static const Capability caps
[] = {Capability::Vector16
};
248 ArrayRef
<Capability
> ref(caps
, llvm::array_lengthof(caps
));
249 capabilities
.push_back(ref
);
251 return type
.getElementType().cast
<ScalarType
>().getCapabilities(
252 capabilities
, storage
);
254 .Default([](Type
) { llvm_unreachable("invalid composite type"); });
257 Optional
<int64_t> CompositeType::getSizeInBytes() {
258 if (auto arrayType
= dyn_cast
<ArrayType
>())
259 return arrayType
.getSizeInBytes();
260 if (auto structType
= dyn_cast
<StructType
>())
261 return structType
.getSizeInBytes();
262 if (auto vectorType
= dyn_cast
<VectorType
>()) {
263 Optional
<int64_t> elementSize
=
264 vectorType
.getElementType().cast
<ScalarType
>().getSizeInBytes();
267 return *elementSize
* vectorType
.getNumElements();
272 //===----------------------------------------------------------------------===//
273 // CooperativeMatrixType
274 //===----------------------------------------------------------------------===//
276 struct spirv::detail::CooperativeMatrixTypeStorage
: public TypeStorage
{
277 using KeyTy
= std::tuple
<Type
, Scope
, unsigned, unsigned>;
279 static CooperativeMatrixTypeStorage
*
280 construct(TypeStorageAllocator
&allocator
, const KeyTy
&key
) {
281 return new (allocator
.allocate
<CooperativeMatrixTypeStorage
>())
282 CooperativeMatrixTypeStorage(key
);
285 bool operator==(const KeyTy
&key
) const {
286 return key
== KeyTy(elementType
, scope
, rows
, columns
);
289 CooperativeMatrixTypeStorage(const KeyTy
&key
)
290 : elementType(std::get
<0>(key
)), rows(std::get
<2>(key
)),
291 columns(std::get
<3>(key
)), scope(std::get
<1>(key
)) {}
299 CooperativeMatrixNVType
CooperativeMatrixNVType::get(Type elementType
,
300 Scope scope
, unsigned rows
,
302 return Base::get(elementType
.getContext(), elementType
, scope
, rows
, columns
);
305 Type
CooperativeMatrixNVType::getElementType() const {
306 return getImpl()->elementType
;
309 Scope
CooperativeMatrixNVType::getScope() const { return getImpl()->scope
; }
311 unsigned CooperativeMatrixNVType::getRows() const { return getImpl()->rows
; }
313 unsigned CooperativeMatrixNVType::getColumns() const {
314 return getImpl()->columns
;
317 void CooperativeMatrixNVType::getExtensions(
318 SPIRVType::ExtensionArrayRefVector
&extensions
,
319 Optional
<StorageClass
> storage
) {
320 getElementType().cast
<SPIRVType
>().getExtensions(extensions
, storage
);
321 static const Extension exts
[] = {Extension::SPV_NV_cooperative_matrix
};
322 ArrayRef
<Extension
> ref(exts
, llvm::array_lengthof(exts
));
323 extensions
.push_back(ref
);
326 void CooperativeMatrixNVType::getCapabilities(
327 SPIRVType::CapabilityArrayRefVector
&capabilities
,
328 Optional
<StorageClass
> storage
) {
329 getElementType().cast
<SPIRVType
>().getCapabilities(capabilities
, storage
);
330 static const Capability caps
[] = {Capability::CooperativeMatrixNV
};
331 ArrayRef
<Capability
> ref(caps
, llvm::array_lengthof(caps
));
332 capabilities
.push_back(ref
);
335 //===----------------------------------------------------------------------===//
337 //===----------------------------------------------------------------------===//
/// Primary template: any type without a dedicated specialization reports a
/// width of zero bits.
template <typename T>
static constexpr unsigned getNumBits() {
  return 0u;
}
340 template <> constexpr unsigned getNumBits
<Dim
>() {
341 static_assert((1 << 3) > getMaxEnumValForDim(),
342 "Not enough bits to encode Dim value");
345 template <> constexpr unsigned getNumBits
<ImageDepthInfo
>() {
346 static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(),
347 "Not enough bits to encode ImageDepthInfo value");
350 template <> constexpr unsigned getNumBits
<ImageArrayedInfo
>() {
351 static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(),
352 "Not enough bits to encode ImageArrayedInfo value");
355 template <> constexpr unsigned getNumBits
<ImageSamplingInfo
>() {
356 static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(),
357 "Not enough bits to encode ImageSamplingInfo value");
360 template <> constexpr unsigned getNumBits
<ImageSamplerUseInfo
>() {
361 static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(),
362 "Not enough bits to encode ImageSamplerUseInfo value");
365 template <> constexpr unsigned getNumBits
<ImageFormat
>() {
366 static_assert((1 << 6) > getMaxEnumValForImageFormat(),
367 "Not enough bits to encode ImageFormat value");
371 struct spirv::detail::ImageTypeStorage
: public TypeStorage
{
373 using KeyTy
= std::tuple
<Type
, Dim
, ImageDepthInfo
, ImageArrayedInfo
,
374 ImageSamplingInfo
, ImageSamplerUseInfo
, ImageFormat
>;
376 static ImageTypeStorage
*construct(TypeStorageAllocator
&allocator
,
378 return new (allocator
.allocate
<ImageTypeStorage
>()) ImageTypeStorage(key
);
381 bool operator==(const KeyTy
&key
) const {
382 return key
== KeyTy(elementType
, dim
, depthInfo
, arrayedInfo
, samplingInfo
,
383 samplerUseInfo
, format
);
386 ImageTypeStorage(const KeyTy
&key
)
387 : elementType(std::get
<0>(key
)), dim(std::get
<1>(key
)),
388 depthInfo(std::get
<2>(key
)), arrayedInfo(std::get
<3>(key
)),
389 samplingInfo(std::get
<4>(key
)), samplerUseInfo(std::get
<5>(key
)),
390 format(std::get
<6>(key
)) {}
393 Dim dim
: getNumBits
<Dim
>();
394 ImageDepthInfo depthInfo
: getNumBits
<ImageDepthInfo
>();
395 ImageArrayedInfo arrayedInfo
: getNumBits
<ImageArrayedInfo
>();
396 ImageSamplingInfo samplingInfo
: getNumBits
<ImageSamplingInfo
>();
397 ImageSamplerUseInfo samplerUseInfo
: getNumBits
<ImageSamplerUseInfo
>();
398 ImageFormat format
: getNumBits
<ImageFormat
>();
402 ImageType::get(std::tuple
<Type
, Dim
, ImageDepthInfo
, ImageArrayedInfo
,
403 ImageSamplingInfo
, ImageSamplerUseInfo
, ImageFormat
>
405 return Base::get(std::get
<0>(value
).getContext(), value
);
408 Type
ImageType::getElementType() const { return getImpl()->elementType
; }
410 Dim
ImageType::getDim() const { return getImpl()->dim
; }
412 ImageDepthInfo
ImageType::getDepthInfo() const { return getImpl()->depthInfo
; }
414 ImageArrayedInfo
ImageType::getArrayedInfo() const {
415 return getImpl()->arrayedInfo
;
418 ImageSamplingInfo
ImageType::getSamplingInfo() const {
419 return getImpl()->samplingInfo
;
422 ImageSamplerUseInfo
ImageType::getSamplerUseInfo() const {
423 return getImpl()->samplerUseInfo
;
426 ImageFormat
ImageType::getImageFormat() const { return getImpl()->format
; }
428 void ImageType::getExtensions(SPIRVType::ExtensionArrayRefVector
&,
429 Optional
<StorageClass
>) {
430 // Image types do not require extra extensions thus far.
433 void ImageType::getCapabilities(
434 SPIRVType::CapabilityArrayRefVector
&capabilities
, Optional
<StorageClass
>) {
435 if (auto dimCaps
= spirv::getCapabilities(getDim()))
436 capabilities
.push_back(*dimCaps
);
438 if (auto fmtCaps
= spirv::getCapabilities(getImageFormat()))
439 capabilities
.push_back(*fmtCaps
);
442 //===----------------------------------------------------------------------===//
444 //===----------------------------------------------------------------------===//
446 struct spirv::detail::PointerTypeStorage
: public TypeStorage
{
447 // (Type, StorageClass) as the key: Type stored in this struct, and
448 // StorageClass stored as TypeStorage's subclass data.
449 using KeyTy
= std::pair
<Type
, StorageClass
>;
451 static PointerTypeStorage
*construct(TypeStorageAllocator
&allocator
,
453 return new (allocator
.allocate
<PointerTypeStorage
>())
454 PointerTypeStorage(key
);
457 bool operator==(const KeyTy
&key
) const {
458 return key
== KeyTy(pointeeType
, storageClass
);
461 PointerTypeStorage(const KeyTy
&key
)
462 : pointeeType(key
.first
), storageClass(key
.second
) {}
465 StorageClass storageClass
;
468 PointerType
PointerType::get(Type pointeeType
, StorageClass storageClass
) {
469 return Base::get(pointeeType
.getContext(), pointeeType
, storageClass
);
472 Type
PointerType::getPointeeType() const { return getImpl()->pointeeType
; }
474 StorageClass
PointerType::getStorageClass() const {
475 return getImpl()->storageClass
;
478 void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector
&extensions
,
479 Optional
<StorageClass
> storage
) {
480 // Use this pointer type's storage class because this pointer indicates we are
481 // using the pointee type in that specific storage class.
482 getPointeeType().cast
<SPIRVType
>().getExtensions(extensions
,
485 if (auto scExts
= spirv::getExtensions(getStorageClass()))
486 extensions
.push_back(*scExts
);
489 void PointerType::getCapabilities(
490 SPIRVType::CapabilityArrayRefVector
&capabilities
,
491 Optional
<StorageClass
> storage
) {
492 // Use this pointer type's storage class because this pointer indicates we are
493 // using the pointee type in that specific storage class.
494 getPointeeType().cast
<SPIRVType
>().getCapabilities(capabilities
,
497 if (auto scCaps
= spirv::getCapabilities(getStorageClass()))
498 capabilities
.push_back(*scCaps
);
501 //===----------------------------------------------------------------------===//
503 //===----------------------------------------------------------------------===//
505 struct spirv::detail::RuntimeArrayTypeStorage
: public TypeStorage
{
506 using KeyTy
= std::pair
<Type
, unsigned>;
508 static RuntimeArrayTypeStorage
*construct(TypeStorageAllocator
&allocator
,
510 return new (allocator
.allocate
<RuntimeArrayTypeStorage
>())
511 RuntimeArrayTypeStorage(key
);
514 bool operator==(const KeyTy
&key
) const {
515 return key
== KeyTy(elementType
, stride
);
518 RuntimeArrayTypeStorage(const KeyTy
&key
)
519 : elementType(key
.first
), stride(key
.second
) {}
525 RuntimeArrayType
RuntimeArrayType::get(Type elementType
) {
526 return Base::get(elementType
.getContext(), elementType
, /*stride=*/0);
529 RuntimeArrayType
RuntimeArrayType::get(Type elementType
, unsigned stride
) {
530 return Base::get(elementType
.getContext(), elementType
, stride
);
533 Type
RuntimeArrayType::getElementType() const { return getImpl()->elementType
; }
535 unsigned RuntimeArrayType::getArrayStride() const { return getImpl()->stride
; }
537 void RuntimeArrayType::getExtensions(
538 SPIRVType::ExtensionArrayRefVector
&extensions
,
539 Optional
<StorageClass
> storage
) {
540 getElementType().cast
<SPIRVType
>().getExtensions(extensions
, storage
);
543 void RuntimeArrayType::getCapabilities(
544 SPIRVType::CapabilityArrayRefVector
&capabilities
,
545 Optional
<StorageClass
> storage
) {
547 static const Capability caps
[] = {Capability::Shader
};
548 ArrayRef
<Capability
> ref(caps
, llvm::array_lengthof(caps
));
549 capabilities
.push_back(ref
);
551 getElementType().cast
<SPIRVType
>().getCapabilities(capabilities
, storage
);
554 //===----------------------------------------------------------------------===//
556 //===----------------------------------------------------------------------===//
558 bool ScalarType::classof(Type type
) {
559 if (auto floatType
= type
.dyn_cast
<FloatType
>()) {
560 return isValid(floatType
);
562 if (auto intType
= type
.dyn_cast
<IntegerType
>()) {
563 return isValid(intType
);
568 bool ScalarType::isValid(FloatType type
) { return !type
.isBF16(); }
570 bool ScalarType::isValid(IntegerType type
) {
571 switch (type
.getWidth()) {
583 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector
&extensions
,
584 Optional
<StorageClass
> storage
) {
585 // 8- or 16-bit integer/floating-point numbers will require extra extensions
586 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
587 // SPV_KHR_8bit_storage for more details.
592 case StorageClass::PushConstant
:
593 case StorageClass::StorageBuffer
:
594 case StorageClass::Uniform
:
595 if (getIntOrFloatBitWidth() == 8) {
596 static const Extension exts
[] = {Extension::SPV_KHR_8bit_storage
};
597 ArrayRef
<Extension
> ref(exts
, llvm::array_lengthof(exts
));
598 extensions
.push_back(ref
);
601 case StorageClass::Input
:
602 case StorageClass::Output
:
603 if (getIntOrFloatBitWidth() == 16) {
604 static const Extension exts
[] = {Extension::SPV_KHR_16bit_storage
};
605 ArrayRef
<Extension
> ref(exts
, llvm::array_lengthof(exts
));
606 extensions
.push_back(ref
);
614 void ScalarType::getCapabilities(
615 SPIRVType::CapabilityArrayRefVector
&capabilities
,
616 Optional
<StorageClass
> storage
) {
617 unsigned bitwidth
= getIntOrFloatBitWidth();
619 // 8- or 16-bit integer/floating-point numbers will require extra capabilities
620 // to appear in interface storage classes. See SPV_KHR_16bit_storage and
621 // SPV_KHR_8bit_storage for more details.
623 #define STORAGE_CASE(storage, cap8, cap16) \
624 case StorageClass::storage: { \
625 if (bitwidth == 8) { \
626 static const Capability caps[] = {Capability::cap8}; \
627 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
628 capabilities.push_back(ref); \
629 } else if (bitwidth == 16) { \
630 static const Capability caps[] = {Capability::cap16}; \
631 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
632 capabilities.push_back(ref); \
634 /* No requirements for other bitwidths */ \
638 // This part only handles the cases where special bitwidths appearing in
639 // interface storage classes.
642 STORAGE_CASE(PushConstant
, StoragePushConstant8
, StoragePushConstant16
);
643 STORAGE_CASE(StorageBuffer
, StorageBuffer8BitAccess
,
644 StorageBuffer16BitAccess
);
645 STORAGE_CASE(Uniform
, UniformAndStorageBuffer8BitAccess
,
647 case StorageClass::Input
:
648 case StorageClass::Output
: {
649 if (bitwidth
== 16) {
650 static const Capability caps
[] = {Capability::StorageInputOutput16
};
651 ArrayRef
<Capability
> ref(caps
, llvm::array_lengthof(caps
));
652 capabilities
.push_back(ref
);
662 // For other non-interface storage classes, require a different set of
663 // capabilities for special bitwidths.
665 #define WIDTH_CASE(type, width) \
667 static const Capability caps[] = {Capability::type##width}; \
668 ArrayRef<Capability> ref(caps, llvm::array_lengthof(caps)); \
669 capabilities.push_back(ref); \
672 if (auto intType
= dyn_cast
<IntegerType
>()) {
681 llvm_unreachable("invalid bitwidth to getCapabilities");
684 assert(isa
<FloatType
>());
688 WIDTH_CASE(Float
, 16);
689 WIDTH_CASE(Float
, 64);
691 llvm_unreachable("invalid bitwidth to getCapabilities");
698 Optional
<int64_t> ScalarType::getSizeInBytes() {
699 auto bitWidth
= getIntOrFloatBitWidth();
700 // According to the SPIR-V spec:
701 // "There is no physical size or bit pattern defined for values with boolean
702 // type. If they are stored (in conjunction with OpVariable), they can only
703 // be used with logical addressing operations, not physical, and only with
704 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
705 // Private, Function, Input, and Output."
711 //===----------------------------------------------------------------------===//
713 //===----------------------------------------------------------------------===//
715 bool SPIRVType::classof(Type type
) {
716 // Allow SPIR-V dialect types
717 if (llvm::isa
<SPIRVDialect
>(type
.getDialect()))
719 if (type
.isa
<ScalarType
>())
721 if (auto vectorType
= type
.dyn_cast
<VectorType
>())
722 return CompositeType::isValid(vectorType
);
726 bool SPIRVType::isScalarOrVector() {
727 return isIntOrFloat() || isa
<VectorType
>();
730 void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector
&extensions
,
731 Optional
<StorageClass
> storage
) {
732 if (auto scalarType
= dyn_cast
<ScalarType
>()) {
733 scalarType
.getExtensions(extensions
, storage
);
734 } else if (auto compositeType
= dyn_cast
<CompositeType
>()) {
735 compositeType
.getExtensions(extensions
, storage
);
736 } else if (auto imageType
= dyn_cast
<ImageType
>()) {
737 imageType
.getExtensions(extensions
, storage
);
738 } else if (auto matrixType
= dyn_cast
<MatrixType
>()) {
739 matrixType
.getExtensions(extensions
, storage
);
740 } else if (auto ptrType
= dyn_cast
<PointerType
>()) {
741 ptrType
.getExtensions(extensions
, storage
);
743 llvm_unreachable("invalid SPIR-V Type to getExtensions");
747 void SPIRVType::getCapabilities(
748 SPIRVType::CapabilityArrayRefVector
&capabilities
,
749 Optional
<StorageClass
> storage
) {
750 if (auto scalarType
= dyn_cast
<ScalarType
>()) {
751 scalarType
.getCapabilities(capabilities
, storage
);
752 } else if (auto compositeType
= dyn_cast
<CompositeType
>()) {
753 compositeType
.getCapabilities(capabilities
, storage
);
754 } else if (auto imageType
= dyn_cast
<ImageType
>()) {
755 imageType
.getCapabilities(capabilities
, storage
);
756 } else if (auto matrixType
= dyn_cast
<MatrixType
>()) {
757 matrixType
.getCapabilities(capabilities
, storage
);
758 } else if (auto ptrType
= dyn_cast
<PointerType
>()) {
759 ptrType
.getCapabilities(capabilities
, storage
);
761 llvm_unreachable("invalid SPIR-V Type to getCapabilities");
765 Optional
<int64_t> SPIRVType::getSizeInBytes() {
766 if (auto scalarType
= dyn_cast
<ScalarType
>())
767 return scalarType
.getSizeInBytes();
768 if (auto compositeType
= dyn_cast
<CompositeType
>())
769 return compositeType
.getSizeInBytes();
773 //===----------------------------------------------------------------------===//
775 //===----------------------------------------------------------------------===//
777 /// Type storage for SPIR-V structure types:
779 /// Structures are uniqued using:
780 /// - for identified structs:
781 /// - a string identifier;
782 /// - for literal structs:
783 /// - a list of member types;
784 /// - a list of member offset info;
785 /// - a list of member decoration info.
787 /// Identified structures only have a mutable component consisting of:
788 /// - a list of member types;
789 /// - a list of member offset info;
790 /// - a list of member decoration info.
791 struct spirv::detail::StructTypeStorage
: public TypeStorage
{
792 /// Construct a storage object for an identified struct type. A struct type
793 /// associated with such storage must call StructType::trySetBody(...) later
794 /// in order to mutate the storage object providing the actual content.
795 StructTypeStorage(StringRef identifier
)
796 : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
797 numMemberDecorations(0), memberDecorationsInfo(nullptr),
798 identifier(identifier
) {}
800 /// Construct a storage object for a literal struct type. A struct type
801 /// associated with such storage is immutable.
803 unsigned numMembers
, Type
const *memberTypes
,
804 StructType::OffsetInfo
const *layoutInfo
, unsigned numMemberDecorations
,
805 StructType::MemberDecorationInfo
const *memberDecorationsInfo
)
806 : memberTypesAndIsBodySet(memberTypes
, false), offsetInfo(layoutInfo
),
807 numMembers(numMembers
), numMemberDecorations(numMemberDecorations
),
808 memberDecorationsInfo(memberDecorationsInfo
), identifier(StringRef()) {}
810 /// A storage key is divided into 2 parts:
811 /// - for identified structs:
812 /// - a StringRef representing the struct identifier;
813 /// - for literal structs:
814 /// - an ArrayRef<Type> for member types;
815 /// - an ArrayRef<StructType::OffsetInfo> for member offset info;
816 /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
819 /// An identified struct type is uniqued only by the first part (field 0)
822 /// A literal struct type is uniqued only by the second part (fields 1, 2, and
823 /// 3) of the key. The identifier field (field 0) must be empty.
825 std::tuple
<StringRef
, ArrayRef
<Type
>, ArrayRef
<StructType::OffsetInfo
>,
826 ArrayRef
<StructType::MemberDecorationInfo
>>;
828 /// For identified structs, return true if the given key contains the same
831 /// For literal structs, return true if the given key contains a matching list
832 /// of member types + offset info + decoration info.
833 bool operator==(const KeyTy
&key
) const {
834 if (isIdentified()) {
835 // Identified types are uniqued by their identifier.
836 return getIdentifier() == std::get
<0>(key
);
839 return key
== KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
840 getMemberDecorationsInfo());
843 /// If the given key contains a non-empty identifier, this method constructs
844 /// an identified struct and leaves the rest of the struct type data to be set
845 /// through a later call to StructType::trySetBody(...).
847 /// If, on the other hand, the key contains an empty identifier, a literal
848 /// struct is constructed using the other fields of the key.
849 static StructTypeStorage
*construct(TypeStorageAllocator
&allocator
,
851 StringRef keyIdentifier
= std::get
<0>(key
);
853 if (!keyIdentifier
.empty()) {
854 StringRef identifier
= allocator
.copyInto(keyIdentifier
);
856 // Identified StructType body/members will be set through trySetBody(...)
858 return new (allocator
.allocate
<StructTypeStorage
>())
859 StructTypeStorage(identifier
);
862 ArrayRef
<Type
> keyTypes
= std::get
<1>(key
);
864 // Copy the member type and layout information into the bump pointer
865 const Type
*typesList
= nullptr;
866 if (!keyTypes
.empty()) {
867 typesList
= allocator
.copyInto(keyTypes
).data();
870 const StructType::OffsetInfo
*offsetInfoList
= nullptr;
871 if (!std::get
<2>(key
).empty()) {
872 ArrayRef
<StructType::OffsetInfo
> keyOffsetInfo
= std::get
<2>(key
);
873 assert(keyOffsetInfo
.size() == keyTypes
.size() &&
874 "size of offset information must be same as the size of number of "
876 offsetInfoList
= allocator
.copyInto(keyOffsetInfo
).data();
879 const StructType::MemberDecorationInfo
*memberDecorationList
= nullptr;
880 unsigned numMemberDecorations
= 0;
881 if (!std::get
<3>(key
).empty()) {
882 auto keyMemberDecorations
= std::get
<3>(key
);
883 numMemberDecorations
= keyMemberDecorations
.size();
884 memberDecorationList
= allocator
.copyInto(keyMemberDecorations
).data();
887 return new (allocator
.allocate
<StructTypeStorage
>())
888 StructTypeStorage(keyTypes
.size(), typesList
, offsetInfoList
,
889 numMemberDecorations
, memberDecorationList
);
892 ArrayRef
<Type
> getMemberTypes() const {
893 return ArrayRef
<Type
>(memberTypesAndIsBodySet
.getPointer(), numMembers
);
896 ArrayRef
<StructType::OffsetInfo
> getOffsetInfo() const {
898 return ArrayRef
<StructType::OffsetInfo
>(offsetInfo
, numMembers
);
903 ArrayRef
<StructType::MemberDecorationInfo
> getMemberDecorationsInfo() const {
904 if (memberDecorationsInfo
) {
905 return ArrayRef
<StructType::MemberDecorationInfo
>(memberDecorationsInfo
,
906 numMemberDecorations
);
911 StringRef
getIdentifier() const { return identifier
; }
913 bool isIdentified() const { return !identifier
.empty(); }
915 /// Sets the struct type content for identified structs. Calling this method
916 /// is only valid for identified structs.
918 /// Fails under the following conditions:
919 /// - If called for a literal struct;
920 /// - If called for an identified struct whose body was set before (through a
921 /// call to this method) but with different contents from the passed
923 LogicalResult
mutate(
924 TypeStorageAllocator
&allocator
, ArrayRef
<Type
> structMemberTypes
,
925 ArrayRef
<StructType::OffsetInfo
> structOffsetInfo
,
926 ArrayRef
<StructType::MemberDecorationInfo
> structMemberDecorationInfo
) {
930 if (memberTypesAndIsBodySet
.getInt() &&
931 (getMemberTypes() != structMemberTypes
||
932 getOffsetInfo() != structOffsetInfo
||
933 getMemberDecorationsInfo() != structMemberDecorationInfo
))
936 memberTypesAndIsBodySet
.setInt(true);
937 numMembers
= structMemberTypes
.size();
939 // Copy the member type and layout information into the bump pointer.
940 if (!structMemberTypes
.empty())
941 memberTypesAndIsBodySet
.setPointer(
942 allocator
.copyInto(structMemberTypes
).data());
944 if (!structOffsetInfo
.empty()) {
945 assert(structOffsetInfo
.size() == structMemberTypes
.size() &&
946 "size of offset information must be same as the size of number of "
948 offsetInfo
= allocator
.copyInto(structOffsetInfo
).data();
951 if (!structMemberDecorationInfo
.empty()) {
952 numMemberDecorations
= structMemberDecorationInfo
.size();
953 memberDecorationsInfo
=
954 allocator
.copyInto(structMemberDecorationInfo
).data();
960 llvm::PointerIntPair
<Type
const *, 1, bool> memberTypesAndIsBodySet
;
961 StructType::OffsetInfo
const *offsetInfo
;
963 unsigned numMemberDecorations
;
964 StructType::MemberDecorationInfo
const *memberDecorationsInfo
;
965 StringRef identifier
;
969 StructType::get(ArrayRef
<Type
> memberTypes
,
970 ArrayRef
<StructType::OffsetInfo
> offsetInfo
,
971 ArrayRef
<StructType::MemberDecorationInfo
> memberDecorations
) {
972 assert(!memberTypes
.empty() && "Struct needs at least one member type");
973 // Sort the decorations.
974 SmallVector
<StructType::MemberDecorationInfo
, 4> sortedDecorations(
975 memberDecorations
.begin(), memberDecorations
.end());
976 llvm::array_pod_sort(sortedDecorations
.begin(), sortedDecorations
.end());
977 return Base::get(memberTypes
.vec().front().getContext(),
978 /*identifier=*/StringRef(), memberTypes
, offsetInfo
,
982 StructType
StructType::getIdentified(MLIRContext
*context
,
983 StringRef identifier
) {
984 assert(!identifier
.empty() &&
985 "StructType identifier must be non-empty string");
987 return Base::get(context
, identifier
, ArrayRef
<Type
>(),
988 ArrayRef
<StructType::OffsetInfo
>(),
989 ArrayRef
<StructType::MemberDecorationInfo
>());
992 StructType
StructType::getEmpty(MLIRContext
*context
, StringRef identifier
) {
993 StructType newStructType
= Base::get(
994 context
, identifier
, ArrayRef
<Type
>(), ArrayRef
<StructType::OffsetInfo
>(),
995 ArrayRef
<StructType::MemberDecorationInfo
>());
996 // Set an empty body in case this is a identified struct.
997 if (newStructType
.isIdentified() &&
998 failed(newStructType
.trySetBody(
999 ArrayRef
<Type
>(), ArrayRef
<StructType::OffsetInfo
>(),
1000 ArrayRef
<StructType::MemberDecorationInfo
>())))
1001 return StructType();
1003 return newStructType
;
1006 StringRef
StructType::getIdentifier() const { return getImpl()->identifier
; }
1008 bool StructType::isIdentified() const { return getImpl()->isIdentified(); }
1010 unsigned StructType::getNumElements() const { return getImpl()->numMembers
; }
1012 Type
StructType::getElementType(unsigned index
) const {
1013 assert(getNumElements() > index
&& "member index out of range");
1014 return getImpl()->memberTypesAndIsBodySet
.getPointer()[index
];
1017 StructType::ElementTypeRange
StructType::getElementTypes() const {
1018 return ElementTypeRange(getImpl()->memberTypesAndIsBodySet
.getPointer(),
1022 bool StructType::hasOffset() const { return getImpl()->offsetInfo
; }
1024 uint64_t StructType::getMemberOffset(unsigned index
) const {
1025 assert(getNumElements() > index
&& "member index out of range");
1026 return getImpl()->offsetInfo
[index
];
1029 void StructType::getMemberDecorations(
1030 SmallVectorImpl
<StructType::MemberDecorationInfo
> &memberDecorations
)
1032 memberDecorations
.clear();
1033 auto implMemberDecorations
= getImpl()->getMemberDecorationsInfo();
1034 memberDecorations
.append(implMemberDecorations
.begin(),
1035 implMemberDecorations
.end());
1038 void StructType::getMemberDecorations(
1040 SmallVectorImpl
<StructType::MemberDecorationInfo
> &decorationsInfo
) const {
1041 assert(getNumElements() > index
&& "member index out of range");
1042 auto memberDecorations
= getImpl()->getMemberDecorationsInfo();
1043 decorationsInfo
.clear();
1044 for (const auto &memberDecoration
: memberDecorations
) {
1045 if (memberDecoration
.memberIndex
== index
) {
1046 decorationsInfo
.push_back(memberDecoration
);
1048 if (memberDecoration
.memberIndex
> index
) {
1049 // Early exit since the decorations are stored sorted.
1056 StructType::trySetBody(ArrayRef
<Type
> memberTypes
,
1057 ArrayRef
<OffsetInfo
> offsetInfo
,
1058 ArrayRef
<MemberDecorationInfo
> memberDecorations
) {
1059 return Base::mutate(memberTypes
, offsetInfo
, memberDecorations
);
1062 void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector
&extensions
,
1063 Optional
<StorageClass
> storage
) {
1064 for (Type elementType
: getElementTypes())
1065 elementType
.cast
<SPIRVType
>().getExtensions(extensions
, storage
);
1068 void StructType::getCapabilities(
1069 SPIRVType::CapabilityArrayRefVector
&capabilities
,
1070 Optional
<StorageClass
> storage
) {
1071 for (Type elementType
: getElementTypes())
1072 elementType
.cast
<SPIRVType
>().getCapabilities(capabilities
, storage
);
1075 llvm::hash_code
spirv::hash_value(
1076 const StructType::MemberDecorationInfo
&memberDecorationInfo
) {
1077 return llvm::hash_combine(memberDecorationInfo
.memberIndex
,
1078 memberDecorationInfo
.decoration
);
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
1085 struct spirv::detail::MatrixTypeStorage
: public TypeStorage
{
1086 MatrixTypeStorage(Type columnType
, uint32_t columnCount
)
1087 : TypeStorage(), columnType(columnType
), columnCount(columnCount
) {}
1089 using KeyTy
= std::tuple
<Type
, uint32_t>;
1091 static MatrixTypeStorage
*construct(TypeStorageAllocator
&allocator
,
1094 // Initialize the memory using placement new.
1095 return new (allocator
.allocate
<MatrixTypeStorage
>())
1096 MatrixTypeStorage(std::get
<0>(key
), std::get
<1>(key
));
1099 bool operator==(const KeyTy
&key
) const {
1100 return key
== KeyTy(columnType
, columnCount
);
1104 const uint32_t columnCount
;
1107 MatrixType
MatrixType::get(Type columnType
, uint32_t columnCount
) {
1108 return Base::get(columnType
.getContext(), columnType
, columnCount
);
1111 MatrixType
MatrixType::getChecked(Type columnType
, uint32_t columnCount
,
1112 Location location
) {
1113 return Base::getChecked(location
, columnType
, columnCount
);
1116 LogicalResult
MatrixType::verifyConstructionInvariants(Location loc
,
1118 uint32_t columnCount
) {
1119 if (columnCount
< 2 || columnCount
> 4)
1120 return emitError(loc
, "matrix can have 2, 3, or 4 columns only");
1122 if (!isValidColumnType(columnType
))
1123 return emitError(loc
, "matrix columns must be vectors of floats");
1125 /// The underlying vectors (columns) must be of size 2, 3, or 4
1126 ArrayRef
<int64_t> columnShape
= columnType
.cast
<VectorType
>().getShape();
1127 if (columnShape
.size() != 1)
1128 return emitError(loc
, "matrix columns must be 1D vectors");
1130 if (columnShape
[0] < 2 || columnShape
[0] > 4)
1131 return emitError(loc
, "matrix columns must be of size 2, 3, or 4");
1136 /// Returns true if the matrix elements are vectors of float elements
1137 bool MatrixType::isValidColumnType(Type columnType
) {
1138 if (auto vectorType
= columnType
.dyn_cast
<VectorType
>()) {
1139 if (vectorType
.getElementType().isa
<FloatType
>())
1145 Type
MatrixType::getColumnType() const { return getImpl()->columnType
; }
1147 Type
MatrixType::getElementType() const {
1148 return getImpl()->columnType
.cast
<VectorType
>().getElementType();
1151 unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount
; }
1153 unsigned MatrixType::getNumRows() const {
1154 return getImpl()->columnType
.cast
<VectorType
>().getShape()[0];
1157 unsigned MatrixType::getNumElements() const {
1158 return (getImpl()->columnCount
) * getNumRows();
1161 void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector
&extensions
,
1162 Optional
<StorageClass
> storage
) {
1163 getColumnType().cast
<SPIRVType
>().getExtensions(extensions
, storage
);
1166 void MatrixType::getCapabilities(
1167 SPIRVType::CapabilityArrayRefVector
&capabilities
,
1168 Optional
<StorageClass
> storage
) {
1170 static const Capability caps
[] = {Capability::Matrix
};
1171 ArrayRef
<Capability
> ref(caps
, llvm::array_lengthof(caps
));
1172 capabilities
.push_back(ref
);
1174 // Add any capabilities associated with the underlying vectors (i.e., columns)
1175 getColumnType().cast
<SPIRVType
>().getCapabilities(capabilities
, storage
);