//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include <limits>
#include <string>
#include <tuple>

using namespace mlir;
using namespace mlir::spirv;
36 //===----------------------------------------------------------------------===//
38 //===----------------------------------------------------------------------===//
40 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
41 static inline bool containsReturn(Region
®ion
) {
42 return llvm::any_of(region
, [](Block
&block
) {
43 Operation
*terminator
= block
.getTerminator();
44 return isa
<spirv::ReturnOp
, spirv::ReturnValueOp
>(terminator
);
49 /// This class defines the interface for inlining within the SPIR-V dialect.
50 struct SPIRVInlinerInterface
: public DialectInlinerInterface
{
51 using DialectInlinerInterface::DialectInlinerInterface
;
53 /// All call operations within SPIRV can be inlined.
54 bool isLegalToInline(Operation
*call
, Operation
*callable
,
55 bool wouldBeCloned
) const final
{
59 /// Returns true if the given region 'src' can be inlined into the region
60 /// 'dest' that is attached to an operation registered to the current dialect.
61 bool isLegalToInline(Region
*dest
, Region
*src
, bool wouldBeCloned
,
62 BlockAndValueMapping
&) const final
{
63 // Return true here when inlining into spv.func, spv.selection, and
64 // spv.loop operations.
65 auto *op
= dest
->getParentOp();
66 return isa
<spirv::FuncOp
, spirv::SelectionOp
, spirv::LoopOp
>(op
);
69 /// Returns true if the given operation 'op', that is registered to this
70 /// dialect, can be inlined into the region 'dest' that is attached to an
71 /// operation registered to the current dialect.
72 bool isLegalToInline(Operation
*op
, Region
*dest
, bool wouldBeCloned
,
73 BlockAndValueMapping
&) const final
{
74 // TODO: Enable inlining structured control flows with return.
75 if ((isa
<spirv::SelectionOp
, spirv::LoopOp
>(op
)) &&
76 containsReturn(op
->getRegion(0)))
78 // TODO: we need to filter OpKill here to avoid inlining it to
79 // a loop continue construct:
80 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
81 // However OpKill is fragment shader specific and we don't support it yet.
85 /// Handle the given inlined terminator by replacing it with a new operation
87 void handleTerminator(Operation
*op
, Block
*newDest
) const final
{
88 if (auto returnOp
= dyn_cast
<spirv::ReturnOp
>(op
)) {
89 OpBuilder(op
).create
<spirv::BranchOp
>(op
->getLoc(), newDest
);
91 } else if (auto retValOp
= dyn_cast
<spirv::ReturnValueOp
>(op
)) {
92 llvm_unreachable("unimplemented spv.ReturnValue in inliner");
96 /// Handle the given inlined terminator by replacing it with a new operation
98 void handleTerminator(Operation
*op
,
99 ArrayRef
<Value
> valuesToRepl
) const final
{
100 // Only spv.ReturnValue needs to be handled here.
101 auto retValOp
= dyn_cast
<spirv::ReturnValueOp
>(op
);
105 // Replace the values directly with the return operands.
106 assert(valuesToRepl
.size() == 1 &&
107 "spv.ReturnValue expected to only handle one result");
108 valuesToRepl
.front().replaceAllUsesWith(retValOp
.value());
113 //===----------------------------------------------------------------------===//
115 //===----------------------------------------------------------------------===//
117 void SPIRVDialect::initialize() {
118 addTypes
<ArrayType
, CooperativeMatrixNVType
, ImageType
, MatrixType
,
119 PointerType
, RuntimeArrayType
, StructType
>();
121 addAttributes
<InterfaceVarABIAttr
, TargetEnvAttr
, VerCapExtAttr
>();
126 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
129 addInterfaces
<SPIRVInlinerInterface
>();
131 // Allow unknown operations because SPIR-V is extensible.
132 allowUnknownOperations();
135 std::string
SPIRVDialect::getAttributeName(Decoration decoration
) {
136 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration
));
139 //===----------------------------------------------------------------------===//
141 //===----------------------------------------------------------------------===//
143 // Forward declarations.
144 template <typename ValTy
>
145 static Optional
<ValTy
> parseAndVerify(SPIRVDialect
const &dialect
,
146 DialectAsmParser
&parser
);
148 Optional
<Type
> parseAndVerify
<Type
>(SPIRVDialect
const &dialect
,
149 DialectAsmParser
&parser
);
152 Optional
<unsigned> parseAndVerify
<unsigned>(SPIRVDialect
const &dialect
,
153 DialectAsmParser
&parser
);
155 static Type
parseAndVerifyType(SPIRVDialect
const &dialect
,
156 DialectAsmParser
&parser
) {
158 llvm::SMLoc typeLoc
= parser
.getCurrentLocation();
159 if (parser
.parseType(type
))
162 // Allow SPIR-V dialect types
163 if (&type
.getDialect() == &dialect
)
166 // Check other allowed types
167 if (auto t
= type
.dyn_cast
<FloatType
>()) {
169 parser
.emitError(typeLoc
, "cannot use 'bf16' to compose SPIR-V types");
172 } else if (auto t
= type
.dyn_cast
<IntegerType
>()) {
173 if (!ScalarType::isValid(t
)) {
174 parser
.emitError(typeLoc
,
175 "only 1/8/16/32/64-bit integer type allowed but found ")
179 } else if (auto t
= type
.dyn_cast
<VectorType
>()) {
180 if (t
.getRank() != 1) {
181 parser
.emitError(typeLoc
, "only 1-D vector allowed but found ") << t
;
184 if (t
.getNumElements() > 4) {
186 typeLoc
, "vector length has to be less than or equal to 4 but found ")
187 << t
.getNumElements();
191 parser
.emitError(typeLoc
, "cannot use ")
192 << type
<< " to compose SPIR-V types";
199 static Type
parseAndVerifyMatrixType(SPIRVDialect
const &dialect
,
200 DialectAsmParser
&parser
) {
202 llvm::SMLoc typeLoc
= parser
.getCurrentLocation();
203 if (parser
.parseType(type
))
206 if (auto t
= type
.dyn_cast
<VectorType
>()) {
207 if (t
.getRank() != 1) {
208 parser
.emitError(typeLoc
, "only 1-D vector allowed but found ") << t
;
211 if (t
.getNumElements() > 4 || t
.getNumElements() < 2) {
212 parser
.emitError(typeLoc
,
213 "matrix columns size has to be less than or equal "
214 "to 4 and greater than or equal 2, but found ")
215 << t
.getNumElements();
219 if (!t
.getElementType().isa
<FloatType
>()) {
220 parser
.emitError(typeLoc
, "matrix columns' elements must be of "
222 << t
.getElementType();
226 parser
.emitError(typeLoc
, "matrix must be composed using vector "
235 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
236 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
238 static LogicalResult
parseOptionalArrayStride(const SPIRVDialect
&dialect
,
239 DialectAsmParser
&parser
,
241 if (failed(parser
.parseOptionalComma())) {
246 if (parser
.parseKeyword("stride") || parser
.parseEqual())
249 llvm::SMLoc strideLoc
= parser
.getCurrentLocation();
250 Optional
<unsigned> optStride
= parseAndVerify
<unsigned>(dialect
, parser
);
254 if (!(stride
= optStride
.getValue())) {
255 parser
.emitError(strideLoc
, "ArrayStride must be greater than zero");
261 // element-type ::= integer-type
262 // | floating-point-type
266 // array-type ::= `!spv.array` `<` integer-literal `x` element-type
267 // (`,` `stride` `=` integer-literal)? `>`
268 static Type
parseArrayType(SPIRVDialect
const &dialect
,
269 DialectAsmParser
&parser
) {
270 if (parser
.parseLess())
273 SmallVector
<int64_t, 1> countDims
;
274 llvm::SMLoc countLoc
= parser
.getCurrentLocation();
275 if (parser
.parseDimensionList(countDims
, /*allowDynamic=*/false))
277 if (countDims
.size() != 1) {
278 parser
.emitError(countLoc
,
279 "expected single integer for array element count");
283 // According to the SPIR-V spec:
284 // "Length is the number of elements in the array. It must be at least 1."
285 int64_t count
= countDims
[0];
287 parser
.emitError(countLoc
, "expected array length greater than 0");
291 Type elementType
= parseAndVerifyType(dialect
, parser
);
296 if (failed(parseOptionalArrayStride(dialect
, parser
, stride
)))
299 if (parser
.parseGreater())
301 return ArrayType::get(elementType
, count
, stride
);
304 // cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ','
305 // rows ',' columns>`
306 static Type
parseCooperativeMatrixType(SPIRVDialect
const &dialect
,
307 DialectAsmParser
&parser
) {
308 if (parser
.parseLess())
311 SmallVector
<int64_t, 2> dims
;
312 llvm::SMLoc countLoc
= parser
.getCurrentLocation();
313 if (parser
.parseDimensionList(dims
, /*allowDynamic=*/false))
316 if (dims
.size() != 2) {
317 parser
.emitError(countLoc
, "expected rows and columns size");
321 auto elementTy
= parseAndVerifyType(dialect
, parser
);
326 if (parser
.parseComma() || parseEnumKeywordAttr(scope
, parser
, "scope <id>"))
329 if (parser
.parseGreater())
331 return CooperativeMatrixNVType::get(elementTy
, scope
, dims
[0], dims
[1]);
334 // TODO: Reorder methods to be utilities first and parse*Type
335 // methods in alphabetical order
337 // storage-class ::= `UniformConstant`
340 // | <and other storage classes...>
342 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
343 static Type
parsePointerType(SPIRVDialect
const &dialect
,
344 DialectAsmParser
&parser
) {
345 if (parser
.parseLess())
348 auto pointeeType
= parseAndVerifyType(dialect
, parser
);
352 StringRef storageClassSpec
;
353 llvm::SMLoc storageClassLoc
= parser
.getCurrentLocation();
354 if (parser
.parseComma() || parser
.parseKeyword(&storageClassSpec
))
357 auto storageClass
= symbolizeStorageClass(storageClassSpec
);
359 parser
.emitError(storageClassLoc
, "unknown storage class: ")
363 if (parser
.parseGreater())
365 return PointerType::get(pointeeType
, *storageClass
);
368 // runtime-array-type ::= `!spv.rtarray` `<` element-type
369 // (`,` `stride` `=` integer-literal)? `>`
370 static Type
parseRuntimeArrayType(SPIRVDialect
const &dialect
,
371 DialectAsmParser
&parser
) {
372 if (parser
.parseLess())
375 Type elementType
= parseAndVerifyType(dialect
, parser
);
380 if (failed(parseOptionalArrayStride(dialect
, parser
, stride
)))
383 if (parser
.parseGreater())
385 return RuntimeArrayType::get(elementType
, stride
);
388 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
389 static Type
parseMatrixType(SPIRVDialect
const &dialect
,
390 DialectAsmParser
&parser
) {
391 if (parser
.parseLess())
394 SmallVector
<int64_t, 1> countDims
;
395 llvm::SMLoc countLoc
= parser
.getCurrentLocation();
396 if (parser
.parseDimensionList(countDims
, /*allowDynamic=*/false))
398 if (countDims
.size() != 1) {
399 parser
.emitError(countLoc
, "expected single unsigned "
400 "integer for number of columns");
404 int64_t columnCount
= countDims
[0];
405 // According to the specification, Matrices can have 2, 3, or 4 columns
406 if (columnCount
< 2 || columnCount
> 4) {
407 parser
.emitError(countLoc
, "matrix is expected to have 2, 3, or 4 "
412 Type columnType
= parseAndVerifyMatrixType(dialect
, parser
);
416 if (parser
.parseGreater())
419 return MatrixType::get(columnType
, columnCount
);
422 // Specialize this function to parse each of the parameters that define an
423 // ImageType. By default it assumes this is an enum type.
424 template <typename ValTy
>
425 static Optional
<ValTy
> parseAndVerify(SPIRVDialect
const &dialect
,
426 DialectAsmParser
&parser
) {
428 llvm::SMLoc enumLoc
= parser
.getCurrentLocation();
429 if (parser
.parseKeyword(&enumSpec
)) {
433 auto val
= spirv::symbolizeEnum
<ValTy
>(enumSpec
);
435 parser
.emitError(enumLoc
, "unknown attribute: '") << enumSpec
<< "'";
440 Optional
<Type
> parseAndVerify
<Type
>(SPIRVDialect
const &dialect
,
441 DialectAsmParser
&parser
) {
442 // TODO: Further verify that the element type can be sampled
443 auto ty
= parseAndVerifyType(dialect
, parser
);
449 template <typename IntTy
>
450 static Optional
<IntTy
> parseAndVerifyInteger(SPIRVDialect
const &dialect
,
451 DialectAsmParser
&parser
) {
452 IntTy offsetVal
= std::numeric_limits
<IntTy
>::max();
453 if (parser
.parseInteger(offsetVal
))
459 Optional
<unsigned> parseAndVerify
<unsigned>(SPIRVDialect
const &dialect
,
460 DialectAsmParser
&parser
) {
461 return parseAndVerifyInteger
<unsigned>(dialect
, parser
);
465 // Functor object to parse a comma separated list of specs. The function
466 // parseAndVerify does the actual parsing and verification of individual
467 // elements. This is a functor since parsing the last element of the list
468 // (termination condition) needs partial specialization.
469 template <typename ParseType
, typename
... Args
>
470 struct ParseCommaSeparatedList
{
471 Optional
<std::tuple
<ParseType
, Args
...>>
472 operator()(SPIRVDialect
const &dialect
, DialectAsmParser
&parser
) const {
473 auto parseVal
= parseAndVerify
<ParseType
>(dialect
, parser
);
477 auto numArgs
= std::tuple_size
<std::tuple
<Args
...>>::value
;
478 if (numArgs
!= 0 && failed(parser
.parseComma()))
480 auto remainingValues
= ParseCommaSeparatedList
<Args
...>{}(dialect
, parser
);
481 if (!remainingValues
)
483 return std::tuple_cat(std::tuple
<ParseType
>(parseVal
.getValue()),
484 remainingValues
.getValue());
488 // Partial specialization of the function to parse a comma separated list of
489 // specs to parse the last element of the list.
490 template <typename ParseType
>
491 struct ParseCommaSeparatedList
<ParseType
> {
492 Optional
<std::tuple
<ParseType
>> operator()(SPIRVDialect
const &dialect
,
493 DialectAsmParser
&parser
) const {
494 if (auto value
= parseAndVerify
<ParseType
>(dialect
, parser
))
495 return std::tuple
<ParseType
>(value
.getValue());
501 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
503 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
505 // arrayed-info ::= `NonArrayed` | `Arrayed`
507 // sampling-info ::= `SingleSampled` | `MultiSampled`
509 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
511 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
513 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
514 // arrayed-info `,` sampling-info `,`
515 // sampler-use-info `,` format `>`
516 static Type
parseImageType(SPIRVDialect
const &dialect
,
517 DialectAsmParser
&parser
) {
518 if (parser
.parseLess())
522 ParseCommaSeparatedList
<Type
, Dim
, ImageDepthInfo
, ImageArrayedInfo
,
523 ImageSamplingInfo
, ImageSamplerUseInfo
,
524 ImageFormat
>{}(dialect
, parser
);
528 if (parser
.parseGreater())
530 return ImageType::get(value
.getValue());
533 // Parse decorations associated with a member.
534 static ParseResult
parseStructMemberDecorations(
535 SPIRVDialect
const &dialect
, DialectAsmParser
&parser
,
536 ArrayRef
<Type
> memberTypes
,
537 SmallVectorImpl
<StructType::OffsetInfo
> &offsetInfo
,
538 SmallVectorImpl
<StructType::MemberDecorationInfo
> &memberDecorationInfo
) {
540 // Check if the first element is offset.
541 llvm::SMLoc offsetLoc
= parser
.getCurrentLocation();
542 StructType::OffsetInfo offset
= 0;
543 OptionalParseResult offsetParseResult
= parser
.parseOptionalInteger(offset
);
544 if (offsetParseResult
.hasValue()) {
545 if (failed(*offsetParseResult
))
548 if (offsetInfo
.size() != memberTypes
.size() - 1) {
549 return parser
.emitError(offsetLoc
,
550 "offset specification must be given for "
553 offsetInfo
.push_back(offset
);
556 // Check for no spirv::Decorations.
557 if (succeeded(parser
.parseOptionalRSquare()))
560 // If there was an offset, make sure to parse the comma.
561 if (offsetParseResult
.hasValue() && parser
.parseComma())
564 // Check for spirv::Decorations.
566 auto memberDecoration
= parseAndVerify
<spirv::Decoration
>(dialect
, parser
);
567 if (!memberDecoration
)
570 // Parse member decoration value if it exists.
571 if (succeeded(parser
.parseOptionalEqual())) {
572 auto memberDecorationValue
=
573 parseAndVerifyInteger
<uint32_t>(dialect
, parser
);
575 if (!memberDecorationValue
)
578 memberDecorationInfo
.emplace_back(
579 static_cast<uint32_t>(memberTypes
.size() - 1), 1,
580 memberDecoration
.getValue(), memberDecorationValue
.getValue());
582 memberDecorationInfo
.emplace_back(
583 static_cast<uint32_t>(memberTypes
.size() - 1), 0,
584 memberDecoration
.getValue(), 0);
587 } while (succeeded(parser
.parseOptionalComma()));
589 return parser
.parseRSquare();
592 // struct-member-decoration ::= integer-literal? spirv-decoration*
594 // `!spv.struct<` (id `,`)?
596 // (spirv-type (`[` struct-member-decoration `]`)?)*
598 static Type
parseStructType(SPIRVDialect
const &dialect
,
599 DialectAsmParser
&parser
) {
600 // TODO: This function is quite lengthy. Break it down into smaller chunks.
602 // To properly resolve recursive references while parsing recursive struct
603 // types, we need to maintain a list of enclosing struct type names. This set
604 // maintains the names of struct types in which the type we are about to parse
607 // Note: This has to be thread_local to enable multiple threads to safely
608 // parse concurrently.
609 thread_local
llvm::SetVector
<StringRef
> structContext
;
611 static auto removeIdentifierAndFail
=
612 [](llvm::SetVector
<StringRef
> &structContext
, StringRef identifier
) {
613 if (!identifier
.empty())
614 structContext
.remove(identifier
);
619 if (parser
.parseLess())
622 StringRef identifier
;
624 // Check if this is an identified struct type.
625 if (succeeded(parser
.parseOptionalKeyword(&identifier
))) {
626 // Check if this is a possible recursive reference.
627 if (succeeded(parser
.parseOptionalGreater())) {
628 if (structContext
.count(identifier
) == 0) {
631 "recursive struct reference not nested in struct definition");
636 return StructType::getIdentified(dialect
.getContext(), identifier
);
639 if (failed(parser
.parseComma()))
642 if (structContext
.count(identifier
) != 0) {
643 parser
.emitError(parser
.getNameLoc(),
644 "identifier already used for an enclosing struct");
646 return removeIdentifierAndFail(structContext
, identifier
);
649 structContext
.insert(identifier
);
652 if (failed(parser
.parseLParen()))
653 return removeIdentifierAndFail(structContext
, identifier
);
655 if (succeeded(parser
.parseOptionalRParen()) &&
656 succeeded(parser
.parseOptionalGreater())) {
657 if (!identifier
.empty())
658 structContext
.remove(identifier
);
660 return StructType::getEmpty(dialect
.getContext(), identifier
);
663 StructType idStructTy
;
665 if (!identifier
.empty())
666 idStructTy
= StructType::getIdentified(dialect
.getContext(), identifier
);
668 SmallVector
<Type
, 4> memberTypes
;
669 SmallVector
<StructType::OffsetInfo
, 4> offsetInfo
;
670 SmallVector
<StructType::MemberDecorationInfo
, 4> memberDecorationInfo
;
674 if (parser
.parseType(memberType
))
675 return removeIdentifierAndFail(structContext
, identifier
);
676 memberTypes
.push_back(memberType
);
678 if (succeeded(parser
.parseOptionalLSquare()))
679 if (parseStructMemberDecorations(dialect
, parser
, memberTypes
, offsetInfo
,
680 memberDecorationInfo
))
681 return removeIdentifierAndFail(structContext
, identifier
);
682 } while (succeeded(parser
.parseOptionalComma()));
684 if (!offsetInfo
.empty() && memberTypes
.size() != offsetInfo
.size()) {
685 parser
.emitError(parser
.getNameLoc(),
686 "offset specification must be given for all members");
687 return removeIdentifierAndFail(structContext
, identifier
);
690 if (failed(parser
.parseRParen()) || failed(parser
.parseGreater()))
691 return removeIdentifierAndFail(structContext
, identifier
);
693 if (!identifier
.empty()) {
694 if (failed(idStructTy
.trySetBody(memberTypes
, offsetInfo
,
695 memberDecorationInfo
)))
698 structContext
.remove(identifier
);
702 return StructType::get(memberTypes
, offsetInfo
, memberDecorationInfo
);
705 // spirv-type ::= array-type
709 // | runtime-array-type
711 Type
SPIRVDialect::parseType(DialectAsmParser
&parser
) const {
713 if (parser
.parseKeyword(&keyword
))
716 if (keyword
== "array")
717 return parseArrayType(*this, parser
);
718 if (keyword
== "coopmatrix")
719 return parseCooperativeMatrixType(*this, parser
);
720 if (keyword
== "image")
721 return parseImageType(*this, parser
);
722 if (keyword
== "ptr")
723 return parsePointerType(*this, parser
);
724 if (keyword
== "rtarray")
725 return parseRuntimeArrayType(*this, parser
);
726 if (keyword
== "struct")
727 return parseStructType(*this, parser
);
728 if (keyword
== "matrix")
729 return parseMatrixType(*this, parser
);
730 parser
.emitError(parser
.getNameLoc(), "unknown SPIR-V type: ") << keyword
;
734 //===----------------------------------------------------------------------===//
736 //===----------------------------------------------------------------------===//
738 static void print(ArrayType type
, DialectAsmPrinter
&os
) {
739 os
<< "array<" << type
.getNumElements() << " x " << type
.getElementType();
740 if (unsigned stride
= type
.getArrayStride())
741 os
<< ", stride=" << stride
;
745 static void print(RuntimeArrayType type
, DialectAsmPrinter
&os
) {
746 os
<< "rtarray<" << type
.getElementType();
747 if (unsigned stride
= type
.getArrayStride())
748 os
<< ", stride=" << stride
;
752 static void print(PointerType type
, DialectAsmPrinter
&os
) {
753 os
<< "ptr<" << type
.getPointeeType() << ", "
754 << stringifyStorageClass(type
.getStorageClass()) << ">";
757 static void print(ImageType type
, DialectAsmPrinter
&os
) {
758 os
<< "image<" << type
.getElementType() << ", " << stringifyDim(type
.getDim())
759 << ", " << stringifyImageDepthInfo(type
.getDepthInfo()) << ", "
760 << stringifyImageArrayedInfo(type
.getArrayedInfo()) << ", "
761 << stringifyImageSamplingInfo(type
.getSamplingInfo()) << ", "
762 << stringifyImageSamplerUseInfo(type
.getSamplerUseInfo()) << ", "
763 << stringifyImageFormat(type
.getImageFormat()) << ">";
766 static void print(StructType type
, DialectAsmPrinter
&os
) {
767 thread_local
llvm::SetVector
<StringRef
> structContext
;
771 if (type
.isIdentified()) {
772 os
<< type
.getIdentifier();
774 if (structContext
.count(type
.getIdentifier())) {
780 structContext
.insert(type
.getIdentifier());
785 auto printMember
= [&](unsigned i
) {
786 os
<< type
.getElementType(i
);
787 SmallVector
<spirv::StructType::MemberDecorationInfo
, 0> decorations
;
788 type
.getMemberDecorations(i
, decorations
);
789 if (type
.hasOffset() || !decorations
.empty()) {
791 if (type
.hasOffset()) {
792 os
<< type
.getMemberOffset(i
);
793 if (!decorations
.empty())
796 auto eachFn
= [&os
](spirv::StructType::MemberDecorationInfo decoration
) {
797 os
<< stringifyDecoration(decoration
.decoration
);
798 if (decoration
.hasValue
) {
799 os
<< "=" << decoration
.decorationValue
;
802 llvm::interleaveComma(decorations
, os
, eachFn
);
806 llvm::interleaveComma(llvm::seq
<unsigned>(0, type
.getNumElements()), os
,
810 if (type
.isIdentified())
811 structContext
.remove(type
.getIdentifier());
814 static void print(CooperativeMatrixNVType type
, DialectAsmPrinter
&os
) {
815 os
<< "coopmatrix<" << type
.getRows() << "x" << type
.getColumns() << "x";
816 os
<< type
.getElementType() << ", " << stringifyScope(type
.getScope());
820 static void print(MatrixType type
, DialectAsmPrinter
&os
) {
821 os
<< "matrix<" << type
.getNumColumns() << " x " << type
.getColumnType();
825 void SPIRVDialect::printType(Type type
, DialectAsmPrinter
&os
) const {
826 TypeSwitch
<Type
>(type
)
827 .Case
<ArrayType
, CooperativeMatrixNVType
, PointerType
, RuntimeArrayType
,
828 ImageType
, StructType
, MatrixType
>(
829 [&](auto type
) { print(type
, os
); })
830 .Default([](Type
) { llvm_unreachable("unhandled SPIR-V type"); });
833 //===----------------------------------------------------------------------===//
835 //===----------------------------------------------------------------------===//
837 /// Parses a comma-separated list of keywords, invokes `processKeyword` on each
838 /// of the parsed keyword, and returns failure if any error occurs.
839 static ParseResult
parseKeywordList(
840 DialectAsmParser
&parser
,
841 function_ref
<LogicalResult(llvm::SMLoc
, StringRef
)> processKeyword
) {
842 if (parser
.parseLSquare())
845 // Special case for empty list.
846 if (succeeded(parser
.parseOptionalRSquare()))
849 // Keep parsing the keyword and an optional comma following it. If the comma
850 // is successfully parsed, then we have more keywords to parse.
852 auto loc
= parser
.getCurrentLocation();
854 if (parser
.parseKeyword(&keyword
) || failed(processKeyword(loc
, keyword
)))
856 } while (succeeded(parser
.parseOptionalComma()));
858 if (parser
.parseRSquare())
864 /// Parses a spirv::InterfaceVarABIAttr.
865 static Attribute
parseInterfaceVarABIAttr(DialectAsmParser
&parser
) {
866 if (parser
.parseLess())
869 Builder
&builder
= parser
.getBuilder();
871 if (parser
.parseLParen())
874 IntegerAttr descriptorSetAttr
;
876 auto loc
= parser
.getCurrentLocation();
877 uint32_t descriptorSet
= 0;
878 auto descriptorSetParseResult
= parser
.parseOptionalInteger(descriptorSet
);
880 if (!descriptorSetParseResult
.hasValue() ||
881 failed(*descriptorSetParseResult
)) {
882 parser
.emitError(loc
, "missing descriptor set");
885 descriptorSetAttr
= builder
.getI32IntegerAttr(descriptorSet
);
888 if (parser
.parseComma())
891 IntegerAttr bindingAttr
;
893 auto loc
= parser
.getCurrentLocation();
894 uint32_t binding
= 0;
895 auto bindingParseResult
= parser
.parseOptionalInteger(binding
);
897 if (!bindingParseResult
.hasValue() || failed(*bindingParseResult
)) {
898 parser
.emitError(loc
, "missing binding");
901 bindingAttr
= builder
.getI32IntegerAttr(binding
);
904 if (parser
.parseRParen())
907 IntegerAttr storageClassAttr
;
909 if (succeeded(parser
.parseOptionalComma())) {
910 auto loc
= parser
.getCurrentLocation();
911 StringRef storageClass
;
912 if (parser
.parseKeyword(&storageClass
))
915 if (auto storageClassSymbol
=
916 spirv::symbolizeStorageClass(storageClass
)) {
917 storageClassAttr
= builder
.getI32IntegerAttr(
918 static_cast<uint32_t>(*storageClassSymbol
));
920 parser
.emitError(loc
, "unknown storage class: ") << storageClass
;
926 if (parser
.parseGreater())
929 return spirv::InterfaceVarABIAttr::get(descriptorSetAttr
, bindingAttr
,
933 static Attribute
parseVerCapExtAttr(DialectAsmParser
&parser
) {
934 if (parser
.parseLess())
937 Builder
&builder
= parser
.getBuilder();
939 IntegerAttr versionAttr
;
941 auto loc
= parser
.getCurrentLocation();
943 if (parser
.parseKeyword(&version
) || parser
.parseComma())
946 if (auto versionSymbol
= spirv::symbolizeVersion(version
)) {
948 builder
.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol
));
950 parser
.emitError(loc
, "unknown version: ") << version
;
955 ArrayAttr capabilitiesAttr
;
957 SmallVector
<Attribute
, 4> capabilities
;
958 llvm::SMLoc errorloc
;
959 StringRef errorKeyword
;
961 auto processCapability
= [&](llvm::SMLoc loc
, StringRef capability
) {
962 if (auto capSymbol
= spirv::symbolizeCapability(capability
)) {
963 capabilities
.push_back(
964 builder
.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol
)));
967 return errorloc
= loc
, errorKeyword
= capability
, failure();
969 if (parseKeywordList(parser
, processCapability
) || parser
.parseComma()) {
970 if (!errorKeyword
.empty())
971 parser
.emitError(errorloc
, "unknown capability: ") << errorKeyword
;
975 capabilitiesAttr
= builder
.getArrayAttr(capabilities
);
978 ArrayAttr extensionsAttr
;
980 SmallVector
<Attribute
, 1> extensions
;
981 llvm::SMLoc errorloc
;
982 StringRef errorKeyword
;
984 auto processExtension
= [&](llvm::SMLoc loc
, StringRef extension
) {
985 if (spirv::symbolizeExtension(extension
)) {
986 extensions
.push_back(builder
.getStringAttr(extension
));
989 return errorloc
= loc
, errorKeyword
= extension
, failure();
991 if (parseKeywordList(parser
, processExtension
)) {
992 if (!errorKeyword
.empty())
993 parser
.emitError(errorloc
, "unknown extension: ") << errorKeyword
;
997 extensionsAttr
= builder
.getArrayAttr(extensions
);
1000 if (parser
.parseGreater())
1003 return spirv::VerCapExtAttr::get(versionAttr
, capabilitiesAttr
,
1007 /// Parses a spirv::TargetEnvAttr.
1008 static Attribute
parseTargetEnvAttr(DialectAsmParser
&parser
) {
1009 if (parser
.parseLess())
1012 spirv::VerCapExtAttr tripleAttr
;
1013 if (parser
.parseAttribute(tripleAttr
) || parser
.parseComma())
1016 // Parse [vendor[:device-type[:device-id]]]
1017 Vendor vendorID
= Vendor::Unknown
;
1018 DeviceType deviceType
= DeviceType::Unknown
;
1019 uint32_t deviceID
= spirv::TargetEnvAttr::kUnknownDeviceID
;
1021 auto loc
= parser
.getCurrentLocation();
1022 StringRef vendorStr
;
1023 if (succeeded(parser
.parseOptionalKeyword(&vendorStr
))) {
1024 if (auto vendorSymbol
= spirv::symbolizeVendor(vendorStr
)) {
1025 vendorID
= *vendorSymbol
;
1027 parser
.emitError(loc
, "unknown vendor: ") << vendorStr
;
1030 if (succeeded(parser
.parseOptionalColon())) {
1031 loc
= parser
.getCurrentLocation();
1032 StringRef deviceTypeStr
;
1033 if (parser
.parseKeyword(&deviceTypeStr
))
1035 if (auto deviceTypeSymbol
= spirv::symbolizeDeviceType(deviceTypeStr
)) {
1036 deviceType
= *deviceTypeSymbol
;
1038 parser
.emitError(loc
, "unknown device type: ") << deviceTypeStr
;
1041 if (succeeded(parser
.parseOptionalColon())) {
1042 loc
= parser
.getCurrentLocation();
1043 if (parser
.parseInteger(deviceID
))
1047 if (parser
.parseComma())
1052 DictionaryAttr limitsAttr
;
1054 auto loc
= parser
.getCurrentLocation();
1055 if (parser
.parseAttribute(limitsAttr
))
1058 if (!limitsAttr
.isa
<spirv::ResourceLimitsAttr
>()) {
1061 "limits must be a dictionary attribute containing two 32-bit integer "
1062 "attributes 'max_compute_workgroup_invocations' and "
1063 "'max_compute_workgroup_size'");
1068 if (parser
.parseGreater())
1071 return spirv::TargetEnvAttr::get(tripleAttr
, vendorID
, deviceType
, deviceID
,
1075 Attribute
SPIRVDialect::parseAttribute(DialectAsmParser
&parser
,
1077 // SPIR-V attributes are dictionaries so they do not have type.
1079 parser
.emitError(parser
.getNameLoc(), "unexpected type");
1083 // Parse the kind keyword first.
1085 if (parser
.parseKeyword(&attrKind
))
1088 if (attrKind
== spirv::TargetEnvAttr::getKindName())
1089 return parseTargetEnvAttr(parser
);
1090 if (attrKind
== spirv::VerCapExtAttr::getKindName())
1091 return parseVerCapExtAttr(parser
);
1092 if (attrKind
== spirv::InterfaceVarABIAttr::getKindName())
1093 return parseInterfaceVarABIAttr(parser
);
1095 parser
.emitError(parser
.getNameLoc(), "unknown SPIR-V attribute kind: ")
1100 //===----------------------------------------------------------------------===//
1101 // Attribute Printing
1102 //===----------------------------------------------------------------------===//
1104 static void print(spirv::VerCapExtAttr triple
, DialectAsmPrinter
&printer
) {
1105 auto &os
= printer
.getStream();
1106 printer
<< spirv::VerCapExtAttr::getKindName() << "<"
1107 << spirv::stringifyVersion(triple
.getVersion()) << ", [";
1108 llvm::interleaveComma(
1109 triple
.getCapabilities(), os
,
1110 [&](spirv::Capability cap
) { os
<< spirv::stringifyCapability(cap
); });
1112 llvm::interleaveComma(triple
.getExtensionsAttr(), os
, [&](Attribute attr
) {
1113 os
<< attr
.cast
<StringAttr
>().getValue();
1118 static void print(spirv::TargetEnvAttr targetEnv
, DialectAsmPrinter
&printer
) {
1119 printer
<< spirv::TargetEnvAttr::getKindName() << "<#spv.";
1120 print(targetEnv
.getTripleAttr(), printer
);
1121 spirv::Vendor vendorID
= targetEnv
.getVendorID();
1122 spirv::DeviceType deviceType
= targetEnv
.getDeviceType();
1123 uint32_t deviceID
= targetEnv
.getDeviceID();
1124 if (vendorID
!= spirv::Vendor::Unknown
) {
1125 printer
<< ", " << spirv::stringifyVendor(vendorID
);
1126 if (deviceType
!= spirv::DeviceType::Unknown
) {
1127 printer
<< ":" << spirv::stringifyDeviceType(deviceType
);
1128 if (deviceID
!= spirv::TargetEnvAttr::kUnknownDeviceID
)
1129 printer
<< ":" << deviceID
;
1132 printer
<< ", " << targetEnv
.getResourceLimits() << ">";
1135 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr
,
1136 DialectAsmPrinter
&printer
) {
1137 printer
<< spirv::InterfaceVarABIAttr::getKindName() << "<("
1138 << interfaceVarABIAttr
.getDescriptorSet() << ", "
1139 << interfaceVarABIAttr
.getBinding() << ")";
1140 auto storageClass
= interfaceVarABIAttr
.getStorageClass();
1142 printer
<< ", " << spirv::stringifyStorageClass(*storageClass
);
1146 void SPIRVDialect::printAttribute(Attribute attr
,
1147 DialectAsmPrinter
&printer
) const {
1148 if (auto targetEnv
= attr
.dyn_cast
<TargetEnvAttr
>())
1149 print(targetEnv
, printer
);
1150 else if (auto vceAttr
= attr
.dyn_cast
<VerCapExtAttr
>())
1151 print(vceAttr
, printer
);
1152 else if (auto interfaceVarABIAttr
= attr
.dyn_cast
<InterfaceVarABIAttr
>())
1153 print(interfaceVarABIAttr
, printer
);
1155 llvm_unreachable("unhandled SPIR-V attribute kind");
1158 //===----------------------------------------------------------------------===//
1160 //===----------------------------------------------------------------------===//
1162 Operation
*SPIRVDialect::materializeConstant(OpBuilder
&builder
,
1163 Attribute value
, Type type
,
1165 if (!spirv::ConstantOp::isBuildableWith(type
))
1168 return builder
.create
<spirv::ConstantOp
>(loc
, type
, value
);
1171 //===----------------------------------------------------------------------===//
1172 // Shader Interface ABI
1173 //===----------------------------------------------------------------------===//
1175 LogicalResult
SPIRVDialect::verifyOperationAttribute(Operation
*op
,
1176 NamedAttribute attribute
) {
1177 StringRef symbol
= attribute
.first
.strref();
1178 Attribute attr
= attribute
.second
;
1180 // TODO: figure out a way to generate the description from the
1181 // StructAttr definition.
1182 if (symbol
== spirv::getEntryPointABIAttrName()) {
1183 if (!attr
.isa
<spirv::EntryPointABIAttr
>())
1184 return op
->emitError("'")
1186 << "' attribute must be a dictionary attribute containing one "
1187 "32-bit integer elements attribute: 'local_size'";
1188 } else if (symbol
== spirv::getTargetEnvAttrName()) {
1189 if (!attr
.isa
<spirv::TargetEnvAttr
>())
1190 return op
->emitError("'") << symbol
<< "' must be a spirv::TargetEnvAttr";
1192 return op
->emitError("found unsupported '")
1193 << symbol
<< "' attribute on operation";
1199 /// Verifies the given SPIR-V `attribute` attached to a value of the given
1200 /// `valueType` is valid.
1201 static LogicalResult
verifyRegionAttribute(Location loc
, Type valueType
,
1202 NamedAttribute attribute
) {
1203 StringRef symbol
= attribute
.first
.strref();
1204 Attribute attr
= attribute
.second
;
1206 if (symbol
!= spirv::getInterfaceVarABIAttrName())
1207 return emitError(loc
, "found unsupported '")
1208 << symbol
<< "' attribute on region argument";
1210 auto varABIAttr
= attr
.dyn_cast
<spirv::InterfaceVarABIAttr
>();
1212 return emitError(loc
, "'")
1213 << symbol
<< "' must be a spirv::InterfaceVarABIAttr";
1215 if (varABIAttr
.getStorageClass() && !valueType
.isIntOrIndexOrFloat())
1216 return emitError(loc
, "'") << symbol
1217 << "' attribute cannot specify storage class "
1218 "when attaching to a non-scalar value";
1223 LogicalResult
SPIRVDialect::verifyRegionArgAttribute(Operation
*op
,
1224 unsigned regionIndex
,
1226 NamedAttribute attribute
) {
1227 return verifyRegionAttribute(
1228 op
->getLoc(), op
->getRegion(regionIndex
).getArgument(argIndex
).getType(),
1232 LogicalResult
SPIRVDialect::verifyRegionResultAttribute(
1233 Operation
*op
, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
1234 NamedAttribute attribute
) {
1235 return op
->emitError("cannot attach SPIR-V attributes to region result");