[mlir][spirv] NFC: Shuffle code around to better follow convention
[llvm-project.git] / mlir / lib / Dialect / SPIRV / IR / SPIRVDialect.cpp
blobc53048c7a5ff477d15c71b447169d49cec63e983
1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/raw_ostream.h"
33 using namespace mlir;
34 using namespace mlir::spirv;
36 //===----------------------------------------------------------------------===//
37 // InlinerInterface
38 //===----------------------------------------------------------------------===//
40 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
41 static inline bool containsReturn(Region &region) {
42 return llvm::any_of(region, [](Block &block) {
43 Operation *terminator = block.getTerminator();
44 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
45 });
48 namespace {
49 /// This class defines the interface for inlining within the SPIR-V dialect.
50 struct SPIRVInlinerInterface : public DialectInlinerInterface {
51 using DialectInlinerInterface::DialectInlinerInterface;
53 /// All call operations within SPIRV can be inlined.
54 bool isLegalToInline(Operation *call, Operation *callable,
55 bool wouldBeCloned) const final {
56 return true;
59 /// Returns true if the given region 'src' can be inlined into the region
60 /// 'dest' that is attached to an operation registered to the current dialect.
61 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
62 BlockAndValueMapping &) const final {
63 // Return true here when inlining into spv.func, spv.selection, and
64 // spv.loop operations.
65 auto *op = dest->getParentOp();
66 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69 /// Returns true if the given operation 'op', that is registered to this
70 /// dialect, can be inlined into the region 'dest' that is attached to an
71 /// operation registered to the current dialect.
72 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
73 BlockAndValueMapping &) const final {
74 // TODO: Enable inlining structured control flows with return.
75 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
76 containsReturn(op->getRegion(0)))
77 return false;
78 // TODO: we need to filter OpKill here to avoid inlining it to
79 // a loop continue construct:
80 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
81 // However OpKill is fragment shader specific and we don't support it yet.
82 return true;
85 /// Handle the given inlined terminator by replacing it with a new operation
86 /// as necessary.
87 void handleTerminator(Operation *op, Block *newDest) const final {
88 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
89 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
90 op->erase();
91 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
92 llvm_unreachable("unimplemented spv.ReturnValue in inliner");
96 /// Handle the given inlined terminator by replacing it with a new operation
97 /// as necessary.
98 void handleTerminator(Operation *op,
99 ArrayRef<Value> valuesToRepl) const final {
100 // Only spv.ReturnValue needs to be handled here.
101 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
102 if (!retValOp)
103 return;
105 // Replace the values directly with the return operands.
106 assert(valuesToRepl.size() == 1 &&
107 "spv.ReturnValue expected to only handle one result");
108 valuesToRepl.front().replaceAllUsesWith(retValOp.value());
111 } // namespace
113 //===----------------------------------------------------------------------===//
114 // SPIR-V Dialect
115 //===----------------------------------------------------------------------===//
117 void SPIRVDialect::initialize() {
118 addTypes<ArrayType, CooperativeMatrixNVType, ImageType, MatrixType,
119 PointerType, RuntimeArrayType, StructType>();
121 addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
123 // Add SPIR-V ops.
124 addOperations<
125 #define GET_OP_LIST
126 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
127 >();
129 addInterfaces<SPIRVInlinerInterface>();
131 // Allow unknown operations because SPIR-V is extensible.
132 allowUnknownOperations();
135 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
136 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
139 //===----------------------------------------------------------------------===//
140 // Type Parsing
141 //===----------------------------------------------------------------------===//
143 // Forward declarations.
144 template <typename ValTy>
145 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
146 DialectAsmParser &parser);
147 template <>
148 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
149 DialectAsmParser &parser);
151 template <>
152 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
153 DialectAsmParser &parser);
155 static Type parseAndVerifyType(SPIRVDialect const &dialect,
156 DialectAsmParser &parser) {
157 Type type;
158 llvm::SMLoc typeLoc = parser.getCurrentLocation();
159 if (parser.parseType(type))
160 return Type();
162 // Allow SPIR-V dialect types
163 if (&type.getDialect() == &dialect)
164 return type;
166 // Check other allowed types
167 if (auto t = type.dyn_cast<FloatType>()) {
168 if (type.isBF16()) {
169 parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
170 return Type();
172 } else if (auto t = type.dyn_cast<IntegerType>()) {
173 if (!ScalarType::isValid(t)) {
174 parser.emitError(typeLoc,
175 "only 1/8/16/32/64-bit integer type allowed but found ")
176 << type;
177 return Type();
179 } else if (auto t = type.dyn_cast<VectorType>()) {
180 if (t.getRank() != 1) {
181 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
182 return Type();
184 if (t.getNumElements() > 4) {
185 parser.emitError(
186 typeLoc, "vector length has to be less than or equal to 4 but found ")
187 << t.getNumElements();
188 return Type();
190 } else {
191 parser.emitError(typeLoc, "cannot use ")
192 << type << " to compose SPIR-V types";
193 return Type();
196 return type;
199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
200 DialectAsmParser &parser) {
201 Type type;
202 llvm::SMLoc typeLoc = parser.getCurrentLocation();
203 if (parser.parseType(type))
204 return Type();
206 if (auto t = type.dyn_cast<VectorType>()) {
207 if (t.getRank() != 1) {
208 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
209 return Type();
211 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
212 parser.emitError(typeLoc,
213 "matrix columns size has to be less than or equal "
214 "to 4 and greater than or equal 2, but found ")
215 << t.getNumElements();
216 return Type();
219 if (!t.getElementType().isa<FloatType>()) {
220 parser.emitError(typeLoc, "matrix columns' elements must be of "
221 "Float type, got ")
222 << t.getElementType();
223 return Type();
225 } else {
226 parser.emitError(typeLoc, "matrix must be composed using vector "
227 "type, got ")
228 << type;
229 return Type();
232 return type;
235 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
236 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
237 /// missing.
238 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
239 DialectAsmParser &parser,
240 unsigned &stride) {
241 if (failed(parser.parseOptionalComma())) {
242 stride = 0;
243 return success();
246 if (parser.parseKeyword("stride") || parser.parseEqual())
247 return failure();
249 llvm::SMLoc strideLoc = parser.getCurrentLocation();
250 Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
251 if (!optStride)
252 return failure();
254 if (!(stride = optStride.getValue())) {
255 parser.emitError(strideLoc, "ArrayStride must be greater than zero");
256 return failure();
258 return success();
261 // element-type ::= integer-type
262 // | floating-point-type
263 // | vector-type
264 // | spirv-type
266 // array-type ::= `!spv.array` `<` integer-literal `x` element-type
267 // (`,` `stride` `=` integer-literal)? `>`
268 static Type parseArrayType(SPIRVDialect const &dialect,
269 DialectAsmParser &parser) {
270 if (parser.parseLess())
271 return Type();
273 SmallVector<int64_t, 1> countDims;
274 llvm::SMLoc countLoc = parser.getCurrentLocation();
275 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
276 return Type();
277 if (countDims.size() != 1) {
278 parser.emitError(countLoc,
279 "expected single integer for array element count");
280 return Type();
283 // According to the SPIR-V spec:
284 // "Length is the number of elements in the array. It must be at least 1."
285 int64_t count = countDims[0];
286 if (count == 0) {
287 parser.emitError(countLoc, "expected array length greater than 0");
288 return Type();
291 Type elementType = parseAndVerifyType(dialect, parser);
292 if (!elementType)
293 return Type();
295 unsigned stride = 0;
296 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
297 return Type();
299 if (parser.parseGreater())
300 return Type();
301 return ArrayType::get(elementType, count, stride);
304 // cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ','
305 // rows ',' columns>`
306 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
307 DialectAsmParser &parser) {
308 if (parser.parseLess())
309 return Type();
311 SmallVector<int64_t, 2> dims;
312 llvm::SMLoc countLoc = parser.getCurrentLocation();
313 if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
314 return Type();
316 if (dims.size() != 2) {
317 parser.emitError(countLoc, "expected rows and columns size");
318 return Type();
321 auto elementTy = parseAndVerifyType(dialect, parser);
322 if (!elementTy)
323 return Type();
325 Scope scope;
326 if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
327 return Type();
329 if (parser.parseGreater())
330 return Type();
331 return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
334 // TODO: Reorder methods to be utilities first and parse*Type
335 // methods in alphabetical order
337 // storage-class ::= `UniformConstant`
338 // | `Uniform`
339 // | `Workgroup`
340 // | <and other storage classes...>
342 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
343 static Type parsePointerType(SPIRVDialect const &dialect,
344 DialectAsmParser &parser) {
345 if (parser.parseLess())
346 return Type();
348 auto pointeeType = parseAndVerifyType(dialect, parser);
349 if (!pointeeType)
350 return Type();
352 StringRef storageClassSpec;
353 llvm::SMLoc storageClassLoc = parser.getCurrentLocation();
354 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
355 return Type();
357 auto storageClass = symbolizeStorageClass(storageClassSpec);
358 if (!storageClass) {
359 parser.emitError(storageClassLoc, "unknown storage class: ")
360 << storageClassSpec;
361 return Type();
363 if (parser.parseGreater())
364 return Type();
365 return PointerType::get(pointeeType, *storageClass);
368 // runtime-array-type ::= `!spv.rtarray` `<` element-type
369 // (`,` `stride` `=` integer-literal)? `>`
370 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
371 DialectAsmParser &parser) {
372 if (parser.parseLess())
373 return Type();
375 Type elementType = parseAndVerifyType(dialect, parser);
376 if (!elementType)
377 return Type();
379 unsigned stride = 0;
380 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
381 return Type();
383 if (parser.parseGreater())
384 return Type();
385 return RuntimeArrayType::get(elementType, stride);
388 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
389 static Type parseMatrixType(SPIRVDialect const &dialect,
390 DialectAsmParser &parser) {
391 if (parser.parseLess())
392 return Type();
394 SmallVector<int64_t, 1> countDims;
395 llvm::SMLoc countLoc = parser.getCurrentLocation();
396 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
397 return Type();
398 if (countDims.size() != 1) {
399 parser.emitError(countLoc, "expected single unsigned "
400 "integer for number of columns");
401 return Type();
404 int64_t columnCount = countDims[0];
405 // According to the specification, Matrices can have 2, 3, or 4 columns
406 if (columnCount < 2 || columnCount > 4) {
407 parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
408 "columns");
409 return Type();
412 Type columnType = parseAndVerifyMatrixType(dialect, parser);
413 if (!columnType)
414 return Type();
416 if (parser.parseGreater())
417 return Type();
419 return MatrixType::get(columnType, columnCount);
422 // Specialize this function to parse each of the parameters that define an
423 // ImageType. By default it assumes this is an enum type.
424 template <typename ValTy>
425 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
426 DialectAsmParser &parser) {
427 StringRef enumSpec;
428 llvm::SMLoc enumLoc = parser.getCurrentLocation();
429 if (parser.parseKeyword(&enumSpec)) {
430 return llvm::None;
433 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
434 if (!val)
435 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
436 return val;
439 template <>
440 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
441 DialectAsmParser &parser) {
442 // TODO: Further verify that the element type can be sampled
443 auto ty = parseAndVerifyType(dialect, parser);
444 if (!ty)
445 return llvm::None;
446 return ty;
449 template <typename IntTy>
450 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
451 DialectAsmParser &parser) {
452 IntTy offsetVal = std::numeric_limits<IntTy>::max();
453 if (parser.parseInteger(offsetVal))
454 return llvm::None;
455 return offsetVal;
458 template <>
459 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
460 DialectAsmParser &parser) {
461 return parseAndVerifyInteger<unsigned>(dialect, parser);
464 namespace {
465 // Functor object to parse a comma separated list of specs. The function
466 // parseAndVerify does the actual parsing and verification of individual
467 // elements. This is a functor since parsing the last element of the list
468 // (termination condition) needs partial specialization.
469 template <typename ParseType, typename... Args>
470 struct ParseCommaSeparatedList {
471 Optional<std::tuple<ParseType, Args...>>
472 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
473 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
474 if (!parseVal)
475 return llvm::None;
477 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
478 if (numArgs != 0 && failed(parser.parseComma()))
479 return llvm::None;
480 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
481 if (!remainingValues)
482 return llvm::None;
483 return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
484 remainingValues.getValue());
488 // Partial specialization of the function to parse a comma separated list of
489 // specs to parse the last element of the list.
490 template <typename ParseType>
491 struct ParseCommaSeparatedList<ParseType> {
492 Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
493 DialectAsmParser &parser) const {
494 if (auto value = parseAndVerify<ParseType>(dialect, parser))
495 return std::tuple<ParseType>(value.getValue());
496 return llvm::None;
499 } // namespace
501 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
503 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
505 // arrayed-info ::= `NonArrayed` | `Arrayed`
507 // sampling-info ::= `SingleSampled` | `MultiSampled`
509 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
511 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
513 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
514 // arrayed-info `,` sampling-info `,`
515 // sampler-use-info `,` format `>`
516 static Type parseImageType(SPIRVDialect const &dialect,
517 DialectAsmParser &parser) {
518 if (parser.parseLess())
519 return Type();
521 auto value =
522 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
523 ImageSamplingInfo, ImageSamplerUseInfo,
524 ImageFormat>{}(dialect, parser);
525 if (!value)
526 return Type();
528 if (parser.parseGreater())
529 return Type();
530 return ImageType::get(value.getValue());
533 // Parse decorations associated with a member.
534 static ParseResult parseStructMemberDecorations(
535 SPIRVDialect const &dialect, DialectAsmParser &parser,
536 ArrayRef<Type> memberTypes,
537 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
538 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
540 // Check if the first element is offset.
541 llvm::SMLoc offsetLoc = parser.getCurrentLocation();
542 StructType::OffsetInfo offset = 0;
543 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
544 if (offsetParseResult.hasValue()) {
545 if (failed(*offsetParseResult))
546 return failure();
548 if (offsetInfo.size() != memberTypes.size() - 1) {
549 return parser.emitError(offsetLoc,
550 "offset specification must be given for "
551 "all members");
553 offsetInfo.push_back(offset);
556 // Check for no spirv::Decorations.
557 if (succeeded(parser.parseOptionalRSquare()))
558 return success();
560 // If there was an offset, make sure to parse the comma.
561 if (offsetParseResult.hasValue() && parser.parseComma())
562 return failure();
564 // Check for spirv::Decorations.
565 do {
566 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
567 if (!memberDecoration)
568 return failure();
570 // Parse member decoration value if it exists.
571 if (succeeded(parser.parseOptionalEqual())) {
572 auto memberDecorationValue =
573 parseAndVerifyInteger<uint32_t>(dialect, parser);
575 if (!memberDecorationValue)
576 return failure();
578 memberDecorationInfo.emplace_back(
579 static_cast<uint32_t>(memberTypes.size() - 1), 1,
580 memberDecoration.getValue(), memberDecorationValue.getValue());
581 } else {
582 memberDecorationInfo.emplace_back(
583 static_cast<uint32_t>(memberTypes.size() - 1), 0,
584 memberDecoration.getValue(), 0);
587 } while (succeeded(parser.parseOptionalComma()));
589 return parser.parseRSquare();
592 // struct-member-decoration ::= integer-literal? spirv-decoration*
593 // struct-type ::=
594 // `!spv.struct<` (id `,`)?
595 // `(`
596 // (spirv-type (`[` struct-member-decoration `]`)?)*
597 // `)>`
598 static Type parseStructType(SPIRVDialect const &dialect,
599 DialectAsmParser &parser) {
600 // TODO: This function is quite lengthy. Break it down into smaller chunks.
602 // To properly resolve recursive references while parsing recursive struct
603 // types, we need to maintain a list of enclosing struct type names. This set
604 // maintains the names of struct types in which the type we are about to parse
605 // is nested.
607 // Note: This has to be thread_local to enable multiple threads to safely
608 // parse concurrently.
609 thread_local llvm::SetVector<StringRef> structContext;
611 static auto removeIdentifierAndFail =
612 [](llvm::SetVector<StringRef> &structContext, StringRef identifier) {
613 if (!identifier.empty())
614 structContext.remove(identifier);
616 return Type();
619 if (parser.parseLess())
620 return Type();
622 StringRef identifier;
624 // Check if this is an identified struct type.
625 if (succeeded(parser.parseOptionalKeyword(&identifier))) {
626 // Check if this is a possible recursive reference.
627 if (succeeded(parser.parseOptionalGreater())) {
628 if (structContext.count(identifier) == 0) {
629 parser.emitError(
630 parser.getNameLoc(),
631 "recursive struct reference not nested in struct definition");
633 return Type();
636 return StructType::getIdentified(dialect.getContext(), identifier);
639 if (failed(parser.parseComma()))
640 return Type();
642 if (structContext.count(identifier) != 0) {
643 parser.emitError(parser.getNameLoc(),
644 "identifier already used for an enclosing struct");
646 return removeIdentifierAndFail(structContext, identifier);
649 structContext.insert(identifier);
652 if (failed(parser.parseLParen()))
653 return removeIdentifierAndFail(structContext, identifier);
655 if (succeeded(parser.parseOptionalRParen()) &&
656 succeeded(parser.parseOptionalGreater())) {
657 if (!identifier.empty())
658 structContext.remove(identifier);
660 return StructType::getEmpty(dialect.getContext(), identifier);
663 StructType idStructTy;
665 if (!identifier.empty())
666 idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
668 SmallVector<Type, 4> memberTypes;
669 SmallVector<StructType::OffsetInfo, 4> offsetInfo;
670 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
672 do {
673 Type memberType;
674 if (parser.parseType(memberType))
675 return removeIdentifierAndFail(structContext, identifier);
676 memberTypes.push_back(memberType);
678 if (succeeded(parser.parseOptionalLSquare()))
679 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
680 memberDecorationInfo))
681 return removeIdentifierAndFail(structContext, identifier);
682 } while (succeeded(parser.parseOptionalComma()));
684 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
685 parser.emitError(parser.getNameLoc(),
686 "offset specification must be given for all members");
687 return removeIdentifierAndFail(structContext, identifier);
690 if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
691 return removeIdentifierAndFail(structContext, identifier);
693 if (!identifier.empty()) {
694 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
695 memberDecorationInfo)))
696 return Type();
698 structContext.remove(identifier);
699 return idStructTy;
702 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
705 // spirv-type ::= array-type
706 // | element-type
707 // | image-type
708 // | pointer-type
709 // | runtime-array-type
710 // | struct-type
711 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
712 StringRef keyword;
713 if (parser.parseKeyword(&keyword))
714 return Type();
716 if (keyword == "array")
717 return parseArrayType(*this, parser);
718 if (keyword == "coopmatrix")
719 return parseCooperativeMatrixType(*this, parser);
720 if (keyword == "image")
721 return parseImageType(*this, parser);
722 if (keyword == "ptr")
723 return parsePointerType(*this, parser);
724 if (keyword == "rtarray")
725 return parseRuntimeArrayType(*this, parser);
726 if (keyword == "struct")
727 return parseStructType(*this, parser);
728 if (keyword == "matrix")
729 return parseMatrixType(*this, parser);
730 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
731 return Type();
734 //===----------------------------------------------------------------------===//
735 // Type Printing
736 //===----------------------------------------------------------------------===//
738 static void print(ArrayType type, DialectAsmPrinter &os) {
739 os << "array<" << type.getNumElements() << " x " << type.getElementType();
740 if (unsigned stride = type.getArrayStride())
741 os << ", stride=" << stride;
742 os << ">";
745 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
746 os << "rtarray<" << type.getElementType();
747 if (unsigned stride = type.getArrayStride())
748 os << ", stride=" << stride;
749 os << ">";
752 static void print(PointerType type, DialectAsmPrinter &os) {
753 os << "ptr<" << type.getPointeeType() << ", "
754 << stringifyStorageClass(type.getStorageClass()) << ">";
757 static void print(ImageType type, DialectAsmPrinter &os) {
758 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
759 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
760 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
761 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
762 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
763 << stringifyImageFormat(type.getImageFormat()) << ">";
766 static void print(StructType type, DialectAsmPrinter &os) {
767 thread_local llvm::SetVector<StringRef> structContext;
769 os << "struct<";
771 if (type.isIdentified()) {
772 os << type.getIdentifier();
774 if (structContext.count(type.getIdentifier())) {
775 os << ">";
776 return;
779 os << ", ";
780 structContext.insert(type.getIdentifier());
783 os << "(";
785 auto printMember = [&](unsigned i) {
786 os << type.getElementType(i);
787 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
788 type.getMemberDecorations(i, decorations);
789 if (type.hasOffset() || !decorations.empty()) {
790 os << " [";
791 if (type.hasOffset()) {
792 os << type.getMemberOffset(i);
793 if (!decorations.empty())
794 os << ", ";
796 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
797 os << stringifyDecoration(decoration.decoration);
798 if (decoration.hasValue) {
799 os << "=" << decoration.decorationValue;
802 llvm::interleaveComma(decorations, os, eachFn);
803 os << "]";
806 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
807 printMember);
808 os << ")>";
810 if (type.isIdentified())
811 structContext.remove(type.getIdentifier());
814 static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
815 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
816 os << type.getElementType() << ", " << stringifyScope(type.getScope());
817 os << ">";
820 static void print(MatrixType type, DialectAsmPrinter &os) {
821 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
822 os << ">";
825 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
826 TypeSwitch<Type>(type)
827 .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
828 ImageType, StructType, MatrixType>(
829 [&](auto type) { print(type, os); })
830 .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
833 //===----------------------------------------------------------------------===//
834 // Attribute Parsing
835 //===----------------------------------------------------------------------===//
837 /// Parses a comma-separated list of keywords, invokes `processKeyword` on each
838 /// of the parsed keyword, and returns failure if any error occurs.
839 static ParseResult parseKeywordList(
840 DialectAsmParser &parser,
841 function_ref<LogicalResult(llvm::SMLoc, StringRef)> processKeyword) {
842 if (parser.parseLSquare())
843 return failure();
845 // Special case for empty list.
846 if (succeeded(parser.parseOptionalRSquare()))
847 return success();
849 // Keep parsing the keyword and an optional comma following it. If the comma
850 // is successfully parsed, then we have more keywords to parse.
851 do {
852 auto loc = parser.getCurrentLocation();
853 StringRef keyword;
854 if (parser.parseKeyword(&keyword) || failed(processKeyword(loc, keyword)))
855 return failure();
856 } while (succeeded(parser.parseOptionalComma()));
858 if (parser.parseRSquare())
859 return failure();
861 return success();
864 /// Parses a spirv::InterfaceVarABIAttr.
865 static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
866 if (parser.parseLess())
867 return {};
869 Builder &builder = parser.getBuilder();
871 if (parser.parseLParen())
872 return {};
874 IntegerAttr descriptorSetAttr;
876 auto loc = parser.getCurrentLocation();
877 uint32_t descriptorSet = 0;
878 auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
880 if (!descriptorSetParseResult.hasValue() ||
881 failed(*descriptorSetParseResult)) {
882 parser.emitError(loc, "missing descriptor set");
883 return {};
885 descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
888 if (parser.parseComma())
889 return {};
891 IntegerAttr bindingAttr;
893 auto loc = parser.getCurrentLocation();
894 uint32_t binding = 0;
895 auto bindingParseResult = parser.parseOptionalInteger(binding);
897 if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) {
898 parser.emitError(loc, "missing binding");
899 return {};
901 bindingAttr = builder.getI32IntegerAttr(binding);
904 if (parser.parseRParen())
905 return {};
907 IntegerAttr storageClassAttr;
909 if (succeeded(parser.parseOptionalComma())) {
910 auto loc = parser.getCurrentLocation();
911 StringRef storageClass;
912 if (parser.parseKeyword(&storageClass))
913 return {};
915 if (auto storageClassSymbol =
916 spirv::symbolizeStorageClass(storageClass)) {
917 storageClassAttr = builder.getI32IntegerAttr(
918 static_cast<uint32_t>(*storageClassSymbol));
919 } else {
920 parser.emitError(loc, "unknown storage class: ") << storageClass;
921 return {};
926 if (parser.parseGreater())
927 return {};
929 return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
930 storageClassAttr);
933 static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
934 if (parser.parseLess())
935 return {};
937 Builder &builder = parser.getBuilder();
939 IntegerAttr versionAttr;
941 auto loc = parser.getCurrentLocation();
942 StringRef version;
943 if (parser.parseKeyword(&version) || parser.parseComma())
944 return {};
946 if (auto versionSymbol = spirv::symbolizeVersion(version)) {
947 versionAttr =
948 builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
949 } else {
950 parser.emitError(loc, "unknown version: ") << version;
951 return {};
955 ArrayAttr capabilitiesAttr;
957 SmallVector<Attribute, 4> capabilities;
958 llvm::SMLoc errorloc;
959 StringRef errorKeyword;
961 auto processCapability = [&](llvm::SMLoc loc, StringRef capability) {
962 if (auto capSymbol = spirv::symbolizeCapability(capability)) {
963 capabilities.push_back(
964 builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol)));
965 return success();
967 return errorloc = loc, errorKeyword = capability, failure();
969 if (parseKeywordList(parser, processCapability) || parser.parseComma()) {
970 if (!errorKeyword.empty())
971 parser.emitError(errorloc, "unknown capability: ") << errorKeyword;
972 return {};
975 capabilitiesAttr = builder.getArrayAttr(capabilities);
978 ArrayAttr extensionsAttr;
980 SmallVector<Attribute, 1> extensions;
981 llvm::SMLoc errorloc;
982 StringRef errorKeyword;
984 auto processExtension = [&](llvm::SMLoc loc, StringRef extension) {
985 if (spirv::symbolizeExtension(extension)) {
986 extensions.push_back(builder.getStringAttr(extension));
987 return success();
989 return errorloc = loc, errorKeyword = extension, failure();
991 if (parseKeywordList(parser, processExtension)) {
992 if (!errorKeyword.empty())
993 parser.emitError(errorloc, "unknown extension: ") << errorKeyword;
994 return {};
997 extensionsAttr = builder.getArrayAttr(extensions);
1000 if (parser.parseGreater())
1001 return {};
1003 return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr,
1004 extensionsAttr);
1007 /// Parses a spirv::TargetEnvAttr.
1008 static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
1009 if (parser.parseLess())
1010 return {};
1012 spirv::VerCapExtAttr tripleAttr;
1013 if (parser.parseAttribute(tripleAttr) || parser.parseComma())
1014 return {};
1016 // Parse [vendor[:device-type[:device-id]]]
1017 Vendor vendorID = Vendor::Unknown;
1018 DeviceType deviceType = DeviceType::Unknown;
1019 uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
1021 auto loc = parser.getCurrentLocation();
1022 StringRef vendorStr;
1023 if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
1024 if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
1025 vendorID = *vendorSymbol;
1026 } else {
1027 parser.emitError(loc, "unknown vendor: ") << vendorStr;
1030 if (succeeded(parser.parseOptionalColon())) {
1031 loc = parser.getCurrentLocation();
1032 StringRef deviceTypeStr;
1033 if (parser.parseKeyword(&deviceTypeStr))
1034 return {};
1035 if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
1036 deviceType = *deviceTypeSymbol;
1037 } else {
1038 parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
1041 if (succeeded(parser.parseOptionalColon())) {
1042 loc = parser.getCurrentLocation();
1043 if (parser.parseInteger(deviceID))
1044 return {};
1047 if (parser.parseComma())
1048 return {};
1052 DictionaryAttr limitsAttr;
1054 auto loc = parser.getCurrentLocation();
1055 if (parser.parseAttribute(limitsAttr))
1056 return {};
1058 if (!limitsAttr.isa<spirv::ResourceLimitsAttr>()) {
1059 parser.emitError(
1060 loc,
1061 "limits must be a dictionary attribute containing two 32-bit integer "
1062 "attributes 'max_compute_workgroup_invocations' and "
1063 "'max_compute_workgroup_size'");
1064 return {};
1068 if (parser.parseGreater())
1069 return {};
1071 return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
1072 limitsAttr);
1075 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
1076 Type type) const {
1077 // SPIR-V attributes are dictionaries so they do not have type.
1078 if (type) {
1079 parser.emitError(parser.getNameLoc(), "unexpected type");
1080 return {};
1083 // Parse the kind keyword first.
1084 StringRef attrKind;
1085 if (parser.parseKeyword(&attrKind))
1086 return {};
1088 if (attrKind == spirv::TargetEnvAttr::getKindName())
1089 return parseTargetEnvAttr(parser);
1090 if (attrKind == spirv::VerCapExtAttr::getKindName())
1091 return parseVerCapExtAttr(parser);
1092 if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
1093 return parseInterfaceVarABIAttr(parser);
1095 parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
1096 << attrKind;
1097 return {};
1100 //===----------------------------------------------------------------------===//
1101 // Attribute Printing
1102 //===----------------------------------------------------------------------===//
1104 static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
1105 auto &os = printer.getStream();
1106 printer << spirv::VerCapExtAttr::getKindName() << "<"
1107 << spirv::stringifyVersion(triple.getVersion()) << ", [";
1108 llvm::interleaveComma(
1109 triple.getCapabilities(), os,
1110 [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
1111 printer << "], [";
1112 llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
1113 os << attr.cast<StringAttr>().getValue();
1115 printer << "]>";
1118 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
1119 printer << spirv::TargetEnvAttr::getKindName() << "<#spv.";
1120 print(targetEnv.getTripleAttr(), printer);
1121 spirv::Vendor vendorID = targetEnv.getVendorID();
1122 spirv::DeviceType deviceType = targetEnv.getDeviceType();
1123 uint32_t deviceID = targetEnv.getDeviceID();
1124 if (vendorID != spirv::Vendor::Unknown) {
1125 printer << ", " << spirv::stringifyVendor(vendorID);
1126 if (deviceType != spirv::DeviceType::Unknown) {
1127 printer << ":" << spirv::stringifyDeviceType(deviceType);
1128 if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID)
1129 printer << ":" << deviceID;
1132 printer << ", " << targetEnv.getResourceLimits() << ">";
1135 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
1136 DialectAsmPrinter &printer) {
1137 printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
1138 << interfaceVarABIAttr.getDescriptorSet() << ", "
1139 << interfaceVarABIAttr.getBinding() << ")";
1140 auto storageClass = interfaceVarABIAttr.getStorageClass();
1141 if (storageClass)
1142 printer << ", " << spirv::stringifyStorageClass(*storageClass);
1143 printer << ">";
1146 void SPIRVDialect::printAttribute(Attribute attr,
1147 DialectAsmPrinter &printer) const {
1148 if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
1149 print(targetEnv, printer);
1150 else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
1151 print(vceAttr, printer);
1152 else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
1153 print(interfaceVarABIAttr, printer);
1154 else
1155 llvm_unreachable("unhandled SPIR-V attribute kind");
1158 //===----------------------------------------------------------------------===//
1159 // Constant
1160 //===----------------------------------------------------------------------===//
1162 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
1163 Attribute value, Type type,
1164 Location loc) {
1165 if (!spirv::ConstantOp::isBuildableWith(type))
1166 return nullptr;
1168 return builder.create<spirv::ConstantOp>(loc, type, value);
1171 //===----------------------------------------------------------------------===//
1172 // Shader Interface ABI
1173 //===----------------------------------------------------------------------===//
1175 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
1176 NamedAttribute attribute) {
1177 StringRef symbol = attribute.first.strref();
1178 Attribute attr = attribute.second;
1180 // TODO: figure out a way to generate the description from the
1181 // StructAttr definition.
1182 if (symbol == spirv::getEntryPointABIAttrName()) {
1183 if (!attr.isa<spirv::EntryPointABIAttr>())
1184 return op->emitError("'")
1185 << symbol
1186 << "' attribute must be a dictionary attribute containing one "
1187 "32-bit integer elements attribute: 'local_size'";
1188 } else if (symbol == spirv::getTargetEnvAttrName()) {
1189 if (!attr.isa<spirv::TargetEnvAttr>())
1190 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
1191 } else {
1192 return op->emitError("found unsupported '")
1193 << symbol << "' attribute on operation";
1196 return success();
1199 /// Verifies the given SPIR-V `attribute` attached to a value of the given
1200 /// `valueType` is valid.
1201 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
1202 NamedAttribute attribute) {
1203 StringRef symbol = attribute.first.strref();
1204 Attribute attr = attribute.second;
1206 if (symbol != spirv::getInterfaceVarABIAttrName())
1207 return emitError(loc, "found unsupported '")
1208 << symbol << "' attribute on region argument";
1210 auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
1211 if (!varABIAttr)
1212 return emitError(loc, "'")
1213 << symbol << "' must be a spirv::InterfaceVarABIAttr";
1215 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
1216 return emitError(loc, "'") << symbol
1217 << "' attribute cannot specify storage class "
1218 "when attaching to a non-scalar value";
1220 return success();
1223 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1224 unsigned regionIndex,
1225 unsigned argIndex,
1226 NamedAttribute attribute) {
1227 return verifyRegionAttribute(
1228 op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
1229 attribute);
1232 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1233 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
1234 NamedAttribute attribute) {
1235 return op->emitError("cannot attach SPIR-V attributes to region result");