//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"

#include "SPIRVParsingUtils.h"

#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::spirv;

#include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// InlinerInterface.
//===----------------------------------------------------------------------===//
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
47 static inline bool containsReturn(Region
®ion
) {
48 return llvm::any_of(region
, [](Block
&block
) {
49 Operation
*terminator
= block
.getTerminator();
50 return isa
<spirv::ReturnOp
, spirv::ReturnValueOp
>(terminator
);
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface
: public DialectInlinerInterface
{
57 using DialectInlinerInterface::DialectInlinerInterface
;
59 /// All call operations within SPIRV can be inlined.
60 bool isLegalToInline(Operation
*call
, Operation
*callable
,
61 bool wouldBeCloned
) const final
{
65 /// Returns true if the given region 'src' can be inlined into the region
66 /// 'dest' that is attached to an operation registered to the current dialect.
67 bool isLegalToInline(Region
*dest
, Region
*src
, bool wouldBeCloned
,
68 IRMapping
&) const final
{
69 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70 // spirv.mlir.loop operations.
71 auto *op
= dest
->getParentOp();
72 return isa
<spirv::FuncOp
, spirv::SelectionOp
, spirv::LoopOp
>(op
);
75 /// Returns true if the given operation 'op', that is registered to this
76 /// dialect, can be inlined into the region 'dest' that is attached to an
77 /// operation registered to the current dialect.
78 bool isLegalToInline(Operation
*op
, Region
*dest
, bool wouldBeCloned
,
79 IRMapping
&) const final
{
80 // TODO: Enable inlining structured control flows with return.
81 if ((isa
<spirv::SelectionOp
, spirv::LoopOp
>(op
)) &&
82 containsReturn(op
->getRegion(0)))
84 // TODO: we need to filter OpKill here to avoid inlining it to
85 // a loop continue construct:
86 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87 // However OpKill is fragment shader specific and we don't support it yet.
91 /// Handle the given inlined terminator by replacing it with a new operation
93 void handleTerminator(Operation
*op
, Block
*newDest
) const final
{
94 if (auto returnOp
= dyn_cast
<spirv::ReturnOp
>(op
)) {
95 OpBuilder(op
).create
<spirv::BranchOp
>(op
->getLoc(), newDest
);
97 } else if (auto retValOp
= dyn_cast
<spirv::ReturnValueOp
>(op
)) {
98 OpBuilder(op
).create
<spirv::BranchOp
>(retValOp
->getLoc(), newDest
,
99 retValOp
->getOperands());
104 /// Handle the given inlined terminator by replacing it with a new operation
106 void handleTerminator(Operation
*op
, ValueRange valuesToRepl
) const final
{
107 // Only spirv.ReturnValue needs to be handled here.
108 auto retValOp
= dyn_cast
<spirv::ReturnValueOp
>(op
);
112 // Replace the values directly with the return operands.
113 assert(valuesToRepl
.size() == 1 &&
114 "spirv.ReturnValue expected to only handle one result");
115 valuesToRepl
.front().replaceAllUsesWith(retValOp
.getValue());
//===----------------------------------------------------------------------===//
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
124 void SPIRVDialect::initialize() {
125 registerAttributes();
131 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
134 addInterfaces
<SPIRVInlinerInterface
>();
136 // Allow unknown operations because SPIR-V is extensible.
137 allowUnknownOperations();
138 declarePromisedInterface
<gpu::TargetAttrInterface
, TargetEnvAttr
>();
141 std::string
SPIRVDialect::getAttributeName(Decoration decoration
) {
142 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration
));
//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//
149 // Forward declarations.
150 template <typename ValTy
>
151 static std::optional
<ValTy
> parseAndVerify(SPIRVDialect
const &dialect
,
152 DialectAsmParser
&parser
);
154 std::optional
<Type
> parseAndVerify
<Type
>(SPIRVDialect
const &dialect
,
155 DialectAsmParser
&parser
);
158 std::optional
<unsigned> parseAndVerify
<unsigned>(SPIRVDialect
const &dialect
,
159 DialectAsmParser
&parser
);
161 static Type
parseAndVerifyType(SPIRVDialect
const &dialect
,
162 DialectAsmParser
&parser
) {
164 SMLoc typeLoc
= parser
.getCurrentLocation();
165 if (parser
.parseType(type
))
168 // Allow SPIR-V dialect types
169 if (&type
.getDialect() == &dialect
)
172 // Check other allowed types
173 if (auto t
= llvm::dyn_cast
<FloatType
>(type
)) {
175 parser
.emitError(typeLoc
, "cannot use 'bf16' to compose SPIR-V types");
178 } else if (auto t
= llvm::dyn_cast
<IntegerType
>(type
)) {
179 if (!ScalarType::isValid(t
)) {
180 parser
.emitError(typeLoc
,
181 "only 1/8/16/32/64-bit integer type allowed but found ")
185 } else if (auto t
= llvm::dyn_cast
<VectorType
>(type
)) {
186 if (t
.getRank() != 1) {
187 parser
.emitError(typeLoc
, "only 1-D vector allowed but found ") << t
;
190 if (t
.getNumElements() > 4) {
192 typeLoc
, "vector length has to be less than or equal to 4 but found ")
193 << t
.getNumElements();
197 parser
.emitError(typeLoc
, "cannot use ")
198 << type
<< " to compose SPIR-V types";
205 static Type
parseAndVerifyMatrixType(SPIRVDialect
const &dialect
,
206 DialectAsmParser
&parser
) {
208 SMLoc typeLoc
= parser
.getCurrentLocation();
209 if (parser
.parseType(type
))
212 if (auto t
= llvm::dyn_cast
<VectorType
>(type
)) {
213 if (t
.getRank() != 1) {
214 parser
.emitError(typeLoc
, "only 1-D vector allowed but found ") << t
;
217 if (t
.getNumElements() > 4 || t
.getNumElements() < 2) {
218 parser
.emitError(typeLoc
,
219 "matrix columns size has to be less than or equal "
220 "to 4 and greater than or equal 2, but found ")
221 << t
.getNumElements();
225 if (!llvm::isa
<FloatType
>(t
.getElementType())) {
226 parser
.emitError(typeLoc
, "matrix columns' elements must be of "
228 << t
.getElementType();
232 parser
.emitError(typeLoc
, "matrix must be composed using vector "
241 static Type
parseAndVerifySampledImageType(SPIRVDialect
const &dialect
,
242 DialectAsmParser
&parser
) {
244 SMLoc typeLoc
= parser
.getCurrentLocation();
245 if (parser
.parseType(type
))
248 if (!llvm::isa
<ImageType
>(type
)) {
249 parser
.emitError(typeLoc
,
250 "sampled image must be composed using image type, got ")
258 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
259 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
261 static LogicalResult
parseOptionalArrayStride(const SPIRVDialect
&dialect
,
262 DialectAsmParser
&parser
,
264 if (failed(parser
.parseOptionalComma())) {
269 if (parser
.parseKeyword("stride") || parser
.parseEqual())
272 SMLoc strideLoc
= parser
.getCurrentLocation();
273 std::optional
<unsigned> optStride
= parseAndVerify
<unsigned>(dialect
, parser
);
277 if (!(stride
= *optStride
)) {
278 parser
.emitError(strideLoc
, "ArrayStride must be greater than zero");
284 // element-type ::= integer-type
285 // | floating-point-type
289 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
290 // (`,` `stride` `=` integer-literal)? `>`
291 static Type
parseArrayType(SPIRVDialect
const &dialect
,
292 DialectAsmParser
&parser
) {
293 if (parser
.parseLess())
296 SmallVector
<int64_t, 1> countDims
;
297 SMLoc countLoc
= parser
.getCurrentLocation();
298 if (parser
.parseDimensionList(countDims
, /*allowDynamic=*/false))
300 if (countDims
.size() != 1) {
301 parser
.emitError(countLoc
,
302 "expected single integer for array element count");
306 // According to the SPIR-V spec:
307 // "Length is the number of elements in the array. It must be at least 1."
308 int64_t count
= countDims
[0];
310 parser
.emitError(countLoc
, "expected array length greater than 0");
314 Type elementType
= parseAndVerifyType(dialect
, parser
);
319 if (failed(parseOptionalArrayStride(dialect
, parser
, stride
)))
322 if (parser
.parseGreater())
324 return ArrayType::get(elementType
, count
, stride
);
327 // cooperative-matrix-type ::=
328 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
330 static Type
parseCooperativeMatrixType(SPIRVDialect
const &dialect
,
331 DialectAsmParser
&parser
) {
332 if (parser
.parseLess())
335 SmallVector
<int64_t, 2> dims
;
336 SMLoc countLoc
= parser
.getCurrentLocation();
337 if (parser
.parseDimensionList(dims
, /*allowDynamic=*/false))
340 if (dims
.size() != 2) {
341 parser
.emitError(countLoc
, "expected row and column count");
345 auto elementTy
= parseAndVerifyType(dialect
, parser
);
350 if (parser
.parseComma() ||
351 spirv::parseEnumKeywordAttr(scope
, parser
, "scope <id>"))
354 CooperativeMatrixUseKHR use
;
355 if (parser
.parseComma() ||
356 spirv::parseEnumKeywordAttr(use
, parser
, "use <id>"))
359 if (parser
.parseGreater())
362 return CooperativeMatrixType::get(elementTy
, dims
[0], dims
[1], scope
, use
);
365 // TODO: Reorder methods to be utilities first and parse*Type
366 // methods in alphabetical order
368 // storage-class ::= `UniformConstant`
371 // | <and other storage classes...>
373 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
374 static Type
parsePointerType(SPIRVDialect
const &dialect
,
375 DialectAsmParser
&parser
) {
376 if (parser
.parseLess())
379 auto pointeeType
= parseAndVerifyType(dialect
, parser
);
383 StringRef storageClassSpec
;
384 SMLoc storageClassLoc
= parser
.getCurrentLocation();
385 if (parser
.parseComma() || parser
.parseKeyword(&storageClassSpec
))
388 auto storageClass
= symbolizeStorageClass(storageClassSpec
);
390 parser
.emitError(storageClassLoc
, "unknown storage class: ")
394 if (parser
.parseGreater())
396 return PointerType::get(pointeeType
, *storageClass
);
399 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
400 // (`,` `stride` `=` integer-literal)? `>`
401 static Type
parseRuntimeArrayType(SPIRVDialect
const &dialect
,
402 DialectAsmParser
&parser
) {
403 if (parser
.parseLess())
406 Type elementType
= parseAndVerifyType(dialect
, parser
);
411 if (failed(parseOptionalArrayStride(dialect
, parser
, stride
)))
414 if (parser
.parseGreater())
416 return RuntimeArrayType::get(elementType
, stride
);
419 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
420 static Type
parseMatrixType(SPIRVDialect
const &dialect
,
421 DialectAsmParser
&parser
) {
422 if (parser
.parseLess())
425 SmallVector
<int64_t, 1> countDims
;
426 SMLoc countLoc
= parser
.getCurrentLocation();
427 if (parser
.parseDimensionList(countDims
, /*allowDynamic=*/false))
429 if (countDims
.size() != 1) {
430 parser
.emitError(countLoc
, "expected single unsigned "
431 "integer for number of columns");
435 int64_t columnCount
= countDims
[0];
436 // According to the specification, Matrices can have 2, 3, or 4 columns
437 if (columnCount
< 2 || columnCount
> 4) {
438 parser
.emitError(countLoc
, "matrix is expected to have 2, 3, or 4 "
443 Type columnType
= parseAndVerifyMatrixType(dialect
, parser
);
447 if (parser
.parseGreater())
450 return MatrixType::get(columnType
, columnCount
);
453 // Specialize this function to parse each of the parameters that define an
454 // ImageType. By default it assumes this is an enum type.
455 template <typename ValTy
>
456 static std::optional
<ValTy
> parseAndVerify(SPIRVDialect
const &dialect
,
457 DialectAsmParser
&parser
) {
459 SMLoc enumLoc
= parser
.getCurrentLocation();
460 if (parser
.parseKeyword(&enumSpec
)) {
464 auto val
= spirv::symbolizeEnum
<ValTy
>(enumSpec
);
466 parser
.emitError(enumLoc
, "unknown attribute: '") << enumSpec
<< "'";
471 std::optional
<Type
> parseAndVerify
<Type
>(SPIRVDialect
const &dialect
,
472 DialectAsmParser
&parser
) {
473 // TODO: Further verify that the element type can be sampled
474 auto ty
= parseAndVerifyType(dialect
, parser
);
480 template <typename IntTy
>
481 static std::optional
<IntTy
> parseAndVerifyInteger(SPIRVDialect
const &dialect
,
482 DialectAsmParser
&parser
) {
483 IntTy offsetVal
= std::numeric_limits
<IntTy
>::max();
484 if (parser
.parseInteger(offsetVal
))
490 std::optional
<unsigned> parseAndVerify
<unsigned>(SPIRVDialect
const &dialect
,
491 DialectAsmParser
&parser
) {
492 return parseAndVerifyInteger
<unsigned>(dialect
, parser
);
496 // Functor object to parse a comma separated list of specs. The function
497 // parseAndVerify does the actual parsing and verification of individual
498 // elements. This is a functor since parsing the last element of the list
499 // (termination condition) needs partial specialization.
500 template <typename ParseType
, typename
... Args
>
501 struct ParseCommaSeparatedList
{
502 std::optional
<std::tuple
<ParseType
, Args
...>>
503 operator()(SPIRVDialect
const &dialect
, DialectAsmParser
&parser
) const {
504 auto parseVal
= parseAndVerify
<ParseType
>(dialect
, parser
);
508 auto numArgs
= std::tuple_size
<std::tuple
<Args
...>>::value
;
509 if (numArgs
!= 0 && failed(parser
.parseComma()))
511 auto remainingValues
= ParseCommaSeparatedList
<Args
...>{}(dialect
, parser
);
512 if (!remainingValues
)
514 return std::tuple_cat(std::tuple
<ParseType
>(parseVal
.value()),
515 remainingValues
.value());
519 // Partial specialization of the function to parse a comma separated list of
520 // specs to parse the last element of the list.
521 template <typename ParseType
>
522 struct ParseCommaSeparatedList
<ParseType
> {
523 std::optional
<std::tuple
<ParseType
>>
524 operator()(SPIRVDialect
const &dialect
, DialectAsmParser
&parser
) const {
525 if (auto value
= parseAndVerify
<ParseType
>(dialect
, parser
))
526 return std::tuple
<ParseType
>(*value
);
532 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
534 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
536 // arrayed-info ::= `NonArrayed` | `Arrayed`
538 // sampling-info ::= `SingleSampled` | `MultiSampled`
540 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
542 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
544 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
545 // arrayed-info `,` sampling-info `,`
546 // sampler-use-info `,` format `>`
547 static Type
parseImageType(SPIRVDialect
const &dialect
,
548 DialectAsmParser
&parser
) {
549 if (parser
.parseLess())
553 ParseCommaSeparatedList
<Type
, Dim
, ImageDepthInfo
, ImageArrayedInfo
,
554 ImageSamplingInfo
, ImageSamplerUseInfo
,
555 ImageFormat
>{}(dialect
, parser
);
559 if (parser
.parseGreater())
561 return ImageType::get(*value
);
564 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
565 static Type
parseSampledImageType(SPIRVDialect
const &dialect
,
566 DialectAsmParser
&parser
) {
567 if (parser
.parseLess())
570 Type parsedType
= parseAndVerifySampledImageType(dialect
, parser
);
574 if (parser
.parseGreater())
576 return SampledImageType::get(parsedType
);
579 // Parse decorations associated with a member.
580 static ParseResult
parseStructMemberDecorations(
581 SPIRVDialect
const &dialect
, DialectAsmParser
&parser
,
582 ArrayRef
<Type
> memberTypes
,
583 SmallVectorImpl
<StructType::OffsetInfo
> &offsetInfo
,
584 SmallVectorImpl
<StructType::MemberDecorationInfo
> &memberDecorationInfo
) {
586 // Check if the first element is offset.
587 SMLoc offsetLoc
= parser
.getCurrentLocation();
588 StructType::OffsetInfo offset
= 0;
589 OptionalParseResult offsetParseResult
= parser
.parseOptionalInteger(offset
);
590 if (offsetParseResult
.has_value()) {
591 if (failed(*offsetParseResult
))
594 if (offsetInfo
.size() != memberTypes
.size() - 1) {
595 return parser
.emitError(offsetLoc
,
596 "offset specification must be given for "
599 offsetInfo
.push_back(offset
);
602 // Check for no spirv::Decorations.
603 if (succeeded(parser
.parseOptionalRSquare()))
606 // If there was an offset, make sure to parse the comma.
607 if (offsetParseResult
.has_value() && parser
.parseComma())
610 // Check for spirv::Decorations.
611 auto parseDecorations
= [&]() {
612 auto memberDecoration
= parseAndVerify
<spirv::Decoration
>(dialect
, parser
);
613 if (!memberDecoration
)
616 // Parse member decoration value if it exists.
617 if (succeeded(parser
.parseOptionalEqual())) {
618 auto memberDecorationValue
=
619 parseAndVerifyInteger
<uint32_t>(dialect
, parser
);
621 if (!memberDecorationValue
)
624 memberDecorationInfo
.emplace_back(
625 static_cast<uint32_t>(memberTypes
.size() - 1), 1,
626 memberDecoration
.value(), memberDecorationValue
.value());
628 memberDecorationInfo
.emplace_back(
629 static_cast<uint32_t>(memberTypes
.size() - 1), 0,
630 memberDecoration
.value(), 0);
634 if (failed(parser
.parseCommaSeparatedList(parseDecorations
)) ||
635 failed(parser
.parseRSquare()))
641 // struct-member-decoration ::= integer-literal? spirv-decoration*
643 // `!spirv.struct<` (id `,`)?
645 // (spirv-type (`[` struct-member-decoration `]`)?)*
647 static Type
parseStructType(SPIRVDialect
const &dialect
,
648 DialectAsmParser
&parser
) {
649 // TODO: This function is quite lengthy. Break it down into smaller chunks.
651 if (parser
.parseLess())
654 StringRef identifier
;
655 FailureOr
<DialectAsmParser::CyclicParseReset
> cyclicParse
;
657 // Check if this is an identified struct type.
658 if (succeeded(parser
.parseOptionalKeyword(&identifier
))) {
659 // Check if this is a possible recursive reference.
661 StructType::getIdentified(dialect
.getContext(), identifier
);
662 cyclicParse
= parser
.tryStartCyclicParse(structType
);
663 if (succeeded(parser
.parseOptionalGreater())) {
664 if (succeeded(cyclicParse
)) {
667 "recursive struct reference not nested in struct definition");
675 if (failed(parser
.parseComma()))
678 if (failed(cyclicParse
)) {
679 parser
.emitError(parser
.getNameLoc(),
680 "identifier already used for an enclosing struct");
685 if (failed(parser
.parseLParen()))
688 if (succeeded(parser
.parseOptionalRParen()) &&
689 succeeded(parser
.parseOptionalGreater())) {
690 return StructType::getEmpty(dialect
.getContext(), identifier
);
693 StructType idStructTy
;
695 if (!identifier
.empty())
696 idStructTy
= StructType::getIdentified(dialect
.getContext(), identifier
);
698 SmallVector
<Type
, 4> memberTypes
;
699 SmallVector
<StructType::OffsetInfo
, 4> offsetInfo
;
700 SmallVector
<StructType::MemberDecorationInfo
, 4> memberDecorationInfo
;
704 if (parser
.parseType(memberType
))
706 memberTypes
.push_back(memberType
);
708 if (succeeded(parser
.parseOptionalLSquare()))
709 if (parseStructMemberDecorations(dialect
, parser
, memberTypes
, offsetInfo
,
710 memberDecorationInfo
))
712 } while (succeeded(parser
.parseOptionalComma()));
714 if (!offsetInfo
.empty() && memberTypes
.size() != offsetInfo
.size()) {
715 parser
.emitError(parser
.getNameLoc(),
716 "offset specification must be given for all members");
720 if (failed(parser
.parseRParen()) || failed(parser
.parseGreater()))
723 if (!identifier
.empty()) {
724 if (failed(idStructTy
.trySetBody(memberTypes
, offsetInfo
,
725 memberDecorationInfo
)))
730 return StructType::get(memberTypes
, offsetInfo
, memberDecorationInfo
);
733 // spirv-type ::= array-type
737 // | runtime-array-type
738 // | sampled-image-type
740 Type
SPIRVDialect::parseType(DialectAsmParser
&parser
) const {
742 if (parser
.parseKeyword(&keyword
))
745 if (keyword
== "array")
746 return parseArrayType(*this, parser
);
747 if (keyword
== "coopmatrix")
748 return parseCooperativeMatrixType(*this, parser
);
749 if (keyword
== "image")
750 return parseImageType(*this, parser
);
751 if (keyword
== "ptr")
752 return parsePointerType(*this, parser
);
753 if (keyword
== "rtarray")
754 return parseRuntimeArrayType(*this, parser
);
755 if (keyword
== "sampled_image")
756 return parseSampledImageType(*this, parser
);
757 if (keyword
== "struct")
758 return parseStructType(*this, parser
);
759 if (keyword
== "matrix")
760 return parseMatrixType(*this, parser
);
761 parser
.emitError(parser
.getNameLoc(), "unknown SPIR-V type: ") << keyword
;
//===----------------------------------------------------------------------===//
// Type Printing
//===----------------------------------------------------------------------===//
769 static void print(ArrayType type
, DialectAsmPrinter
&os
) {
770 os
<< "array<" << type
.getNumElements() << " x " << type
.getElementType();
771 if (unsigned stride
= type
.getArrayStride())
772 os
<< ", stride=" << stride
;
776 static void print(RuntimeArrayType type
, DialectAsmPrinter
&os
) {
777 os
<< "rtarray<" << type
.getElementType();
778 if (unsigned stride
= type
.getArrayStride())
779 os
<< ", stride=" << stride
;
783 static void print(PointerType type
, DialectAsmPrinter
&os
) {
784 os
<< "ptr<" << type
.getPointeeType() << ", "
785 << stringifyStorageClass(type
.getStorageClass()) << ">";
788 static void print(ImageType type
, DialectAsmPrinter
&os
) {
789 os
<< "image<" << type
.getElementType() << ", " << stringifyDim(type
.getDim())
790 << ", " << stringifyImageDepthInfo(type
.getDepthInfo()) << ", "
791 << stringifyImageArrayedInfo(type
.getArrayedInfo()) << ", "
792 << stringifyImageSamplingInfo(type
.getSamplingInfo()) << ", "
793 << stringifyImageSamplerUseInfo(type
.getSamplerUseInfo()) << ", "
794 << stringifyImageFormat(type
.getImageFormat()) << ">";
797 static void print(SampledImageType type
, DialectAsmPrinter
&os
) {
798 os
<< "sampled_image<" << type
.getImageType() << ">";
801 static void print(StructType type
, DialectAsmPrinter
&os
) {
802 FailureOr
<AsmPrinter::CyclicPrintReset
> cyclicPrint
;
806 if (type
.isIdentified()) {
807 os
<< type
.getIdentifier();
809 cyclicPrint
= os
.tryStartCyclicPrint(type
);
810 if (failed(cyclicPrint
)) {
820 auto printMember
= [&](unsigned i
) {
821 os
<< type
.getElementType(i
);
822 SmallVector
<spirv::StructType::MemberDecorationInfo
, 0> decorations
;
823 type
.getMemberDecorations(i
, decorations
);
824 if (type
.hasOffset() || !decorations
.empty()) {
826 if (type
.hasOffset()) {
827 os
<< type
.getMemberOffset(i
);
828 if (!decorations
.empty())
831 auto eachFn
= [&os
](spirv::StructType::MemberDecorationInfo decoration
) {
832 os
<< stringifyDecoration(decoration
.decoration
);
833 if (decoration
.hasValue
) {
834 os
<< "=" << decoration
.decorationValue
;
837 llvm::interleaveComma(decorations
, os
, eachFn
);
841 llvm::interleaveComma(llvm::seq
<unsigned>(0, type
.getNumElements()), os
,
846 static void print(CooperativeMatrixType type
, DialectAsmPrinter
&os
) {
847 os
<< "coopmatrix<" << type
.getRows() << "x" << type
.getColumns() << "x"
848 << type
.getElementType() << ", " << type
.getScope() << ", "
849 << type
.getUse() << ">";
852 static void print(MatrixType type
, DialectAsmPrinter
&os
) {
853 os
<< "matrix<" << type
.getNumColumns() << " x " << type
.getColumnType();
857 void SPIRVDialect::printType(Type type
, DialectAsmPrinter
&os
) const {
858 TypeSwitch
<Type
>(type
)
859 .Case
<ArrayType
, CooperativeMatrixType
, PointerType
, RuntimeArrayType
,
860 ImageType
, SampledImageType
, StructType
, MatrixType
>(
861 [&](auto type
) { print(type
, os
); })
862 .Default([](Type
) { llvm_unreachable("unhandled SPIR-V type"); });
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
869 Operation
*SPIRVDialect::materializeConstant(OpBuilder
&builder
,
870 Attribute value
, Type type
,
872 if (auto poison
= dyn_cast
<ub::PoisonAttr
>(value
))
873 return builder
.create
<ub::PoisonOp
>(loc
, type
, poison
);
875 if (!spirv::ConstantOp::isBuildableWith(type
))
878 return builder
.create
<spirv::ConstantOp
>(loc
, type
, value
);
881 //===----------------------------------------------------------------------===//
882 // Shader Interface ABI
883 //===----------------------------------------------------------------------===//
885 LogicalResult
SPIRVDialect::verifyOperationAttribute(Operation
*op
,
886 NamedAttribute attribute
) {
887 StringRef symbol
= attribute
.getName().strref();
888 Attribute attr
= attribute
.getValue();
890 if (symbol
== spirv::getEntryPointABIAttrName()) {
891 if (!llvm::isa
<spirv::EntryPointABIAttr
>(attr
)) {
892 return op
->emitError("'")
893 << symbol
<< "' attribute must be an entry point ABI attribute";
895 } else if (symbol
== spirv::getTargetEnvAttrName()) {
896 if (!llvm::isa
<spirv::TargetEnvAttr
>(attr
))
897 return op
->emitError("'") << symbol
<< "' must be a spirv::TargetEnvAttr";
899 return op
->emitError("found unsupported '")
900 << symbol
<< "' attribute on operation";
906 /// Verifies the given SPIR-V `attribute` attached to a value of the given
907 /// `valueType` is valid.
908 static LogicalResult
verifyRegionAttribute(Location loc
, Type valueType
,
909 NamedAttribute attribute
) {
910 StringRef symbol
= attribute
.getName().strref();
911 Attribute attr
= attribute
.getValue();
913 if (symbol
== spirv::getInterfaceVarABIAttrName()) {
914 auto varABIAttr
= llvm::dyn_cast
<spirv::InterfaceVarABIAttr
>(attr
);
916 return emitError(loc
, "'")
917 << symbol
<< "' must be a spirv::InterfaceVarABIAttr";
919 if (varABIAttr
.getStorageClass() && !valueType
.isIntOrIndexOrFloat())
920 return emitError(loc
, "'") << symbol
921 << "' attribute cannot specify storage class "
922 "when attaching to a non-scalar value";
925 if (symbol
== spirv::DecorationAttr::name
) {
926 if (!isa
<spirv::DecorationAttr
>(attr
))
927 return emitError(loc
, "'")
928 << symbol
<< "' must be a spirv::DecorationAttr";
932 return emitError(loc
, "found unsupported '")
933 << symbol
<< "' attribute on region argument";
936 LogicalResult
SPIRVDialect::verifyRegionArgAttribute(Operation
*op
,
937 unsigned regionIndex
,
939 NamedAttribute attribute
) {
940 auto funcOp
= dyn_cast
<FunctionOpInterface
>(op
);
943 Type argType
= funcOp
.getArgumentTypes()[argIndex
];
945 return verifyRegionAttribute(op
->getLoc(), argType
, attribute
);
948 LogicalResult
SPIRVDialect::verifyRegionResultAttribute(
949 Operation
*op
, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
950 NamedAttribute attribute
) {
951 return op
->emitError("cannot attach SPIR-V attributes to region result");