[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / Dialect / SPIRV / IR / SPIRVDialect.cpp
blob48be287ef833b274648f8107b15957ab705909b3
//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
11 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "SPIRVParsingUtils.h"
17 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/Dialect/UB/IR/UBOps.h"
22 #include "mlir/IR/Builders.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/DialectImplementation.h"
25 #include "mlir/IR/MLIRContext.h"
26 #include "mlir/Parser/Parser.h"
27 #include "mlir/Transforms/InliningUtils.h"
28 #include "llvm/ADT/DenseMap.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/StringExtras.h"
32 #include "llvm/ADT/StringMap.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/raw_ostream.h"
36 using namespace mlir;
37 using namespace mlir::spirv;
39 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
41 //===----------------------------------------------------------------------===//
42 // InlinerInterface
43 //===----------------------------------------------------------------------===//
45 /// Returns true if the given region contains spirv.Return or spirv.ReturnValue
46 /// ops.
47 static inline bool containsReturn(Region &region) {
48 return llvm::any_of(region, [](Block &block) {
49 Operation *terminator = block.getTerminator();
50 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
51 });
54 namespace {
55 /// This class defines the interface for inlining within the SPIR-V dialect.
56 struct SPIRVInlinerInterface : public DialectInlinerInterface {
57 using DialectInlinerInterface::DialectInlinerInterface;
59 /// All call operations within SPIRV can be inlined.
60 bool isLegalToInline(Operation *call, Operation *callable,
61 bool wouldBeCloned) const final {
62 return true;
65 /// Returns true if the given region 'src' can be inlined into the region
66 /// 'dest' that is attached to an operation registered to the current dialect.
67 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
68 IRMapping &) const final {
69 // Return true here when inlining into spirv.func, spirv.mlir.selection, and
70 // spirv.mlir.loop operations.
71 auto *op = dest->getParentOp();
72 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
75 /// Returns true if the given operation 'op', that is registered to this
76 /// dialect, can be inlined into the region 'dest' that is attached to an
77 /// operation registered to the current dialect.
78 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
79 IRMapping &) const final {
80 // TODO: Enable inlining structured control flows with return.
81 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
82 containsReturn(op->getRegion(0)))
83 return false;
84 // TODO: we need to filter OpKill here to avoid inlining it to
85 // a loop continue construct:
86 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
87 // However OpKill is fragment shader specific and we don't support it yet.
88 return true;
91 /// Handle the given inlined terminator by replacing it with a new operation
92 /// as necessary.
93 void handleTerminator(Operation *op, Block *newDest) const final {
94 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
95 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
96 op->erase();
97 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
98 OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest,
99 retValOp->getOperands());
100 op->erase();
104 /// Handle the given inlined terminator by replacing it with a new operation
105 /// as necessary.
106 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
107 // Only spirv.ReturnValue needs to be handled here.
108 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
109 if (!retValOp)
110 return;
112 // Replace the values directly with the return operands.
113 assert(valuesToRepl.size() == 1 &&
114 "spirv.ReturnValue expected to only handle one result");
115 valuesToRepl.front().replaceAllUsesWith(retValOp.getValue());
118 } // namespace
120 //===----------------------------------------------------------------------===//
121 // SPIR-V Dialect
122 //===----------------------------------------------------------------------===//
124 void SPIRVDialect::initialize() {
125 registerAttributes();
126 registerTypes();
128 // Add SPIR-V ops.
129 addOperations<
130 #define GET_OP_LIST
131 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
132 >();
134 addInterfaces<SPIRVInlinerInterface>();
136 // Allow unknown operations because SPIR-V is extensible.
137 allowUnknownOperations();
138 declarePromisedInterface<gpu::TargetAttrInterface, TargetEnvAttr>();
141 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
142 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
145 //===----------------------------------------------------------------------===//
146 // Type Parsing
147 //===----------------------------------------------------------------------===//
149 // Forward declarations.
150 template <typename ValTy>
151 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
152 DialectAsmParser &parser);
153 template <>
154 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
155 DialectAsmParser &parser);
157 template <>
158 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
159 DialectAsmParser &parser);
161 static Type parseAndVerifyType(SPIRVDialect const &dialect,
162 DialectAsmParser &parser) {
163 Type type;
164 SMLoc typeLoc = parser.getCurrentLocation();
165 if (parser.parseType(type))
166 return Type();
168 // Allow SPIR-V dialect types
169 if (&type.getDialect() == &dialect)
170 return type;
172 // Check other allowed types
173 if (auto t = llvm::dyn_cast<FloatType>(type)) {
174 if (type.isBF16()) {
175 parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
176 return Type();
178 } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
179 if (!ScalarType::isValid(t)) {
180 parser.emitError(typeLoc,
181 "only 1/8/16/32/64-bit integer type allowed but found ")
182 << type;
183 return Type();
185 } else if (auto t = llvm::dyn_cast<VectorType>(type)) {
186 if (t.getRank() != 1) {
187 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
188 return Type();
190 if (t.getNumElements() > 4) {
191 parser.emitError(
192 typeLoc, "vector length has to be less than or equal to 4 but found ")
193 << t.getNumElements();
194 return Type();
196 } else {
197 parser.emitError(typeLoc, "cannot use ")
198 << type << " to compose SPIR-V types";
199 return Type();
202 return type;
205 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
206 DialectAsmParser &parser) {
207 Type type;
208 SMLoc typeLoc = parser.getCurrentLocation();
209 if (parser.parseType(type))
210 return Type();
212 if (auto t = llvm::dyn_cast<VectorType>(type)) {
213 if (t.getRank() != 1) {
214 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
215 return Type();
217 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
218 parser.emitError(typeLoc,
219 "matrix columns size has to be less than or equal "
220 "to 4 and greater than or equal 2, but found ")
221 << t.getNumElements();
222 return Type();
225 if (!llvm::isa<FloatType>(t.getElementType())) {
226 parser.emitError(typeLoc, "matrix columns' elements must be of "
227 "Float type, got ")
228 << t.getElementType();
229 return Type();
231 } else {
232 parser.emitError(typeLoc, "matrix must be composed using vector "
233 "type, got ")
234 << type;
235 return Type();
238 return type;
241 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
242 DialectAsmParser &parser) {
243 Type type;
244 SMLoc typeLoc = parser.getCurrentLocation();
245 if (parser.parseType(type))
246 return Type();
248 if (!llvm::isa<ImageType>(type)) {
249 parser.emitError(typeLoc,
250 "sampled image must be composed using image type, got ")
251 << type;
252 return Type();
255 return type;
258 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
259 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
260 /// missing.
261 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
262 DialectAsmParser &parser,
263 unsigned &stride) {
264 if (failed(parser.parseOptionalComma())) {
265 stride = 0;
266 return success();
269 if (parser.parseKeyword("stride") || parser.parseEqual())
270 return failure();
272 SMLoc strideLoc = parser.getCurrentLocation();
273 std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
274 if (!optStride)
275 return failure();
277 if (!(stride = *optStride)) {
278 parser.emitError(strideLoc, "ArrayStride must be greater than zero");
279 return failure();
281 return success();
284 // element-type ::= integer-type
285 // | floating-point-type
286 // | vector-type
287 // | spirv-type
289 // array-type ::= `!spirv.array` `<` integer-literal `x` element-type
290 // (`,` `stride` `=` integer-literal)? `>`
291 static Type parseArrayType(SPIRVDialect const &dialect,
292 DialectAsmParser &parser) {
293 if (parser.parseLess())
294 return Type();
296 SmallVector<int64_t, 1> countDims;
297 SMLoc countLoc = parser.getCurrentLocation();
298 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
299 return Type();
300 if (countDims.size() != 1) {
301 parser.emitError(countLoc,
302 "expected single integer for array element count");
303 return Type();
306 // According to the SPIR-V spec:
307 // "Length is the number of elements in the array. It must be at least 1."
308 int64_t count = countDims[0];
309 if (count == 0) {
310 parser.emitError(countLoc, "expected array length greater than 0");
311 return Type();
314 Type elementType = parseAndVerifyType(dialect, parser);
315 if (!elementType)
316 return Type();
318 unsigned stride = 0;
319 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
320 return Type();
322 if (parser.parseGreater())
323 return Type();
324 return ArrayType::get(elementType, count, stride);
327 // cooperative-matrix-type ::=
328 // `!spirv.coopmatrix` `<` rows `x` columns `x` element-type `,`
329 // scope `,` use `>`
330 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
331 DialectAsmParser &parser) {
332 if (parser.parseLess())
333 return {};
335 SmallVector<int64_t, 2> dims;
336 SMLoc countLoc = parser.getCurrentLocation();
337 if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
338 return {};
340 if (dims.size() != 2) {
341 parser.emitError(countLoc, "expected row and column count");
342 return {};
345 auto elementTy = parseAndVerifyType(dialect, parser);
346 if (!elementTy)
347 return {};
349 Scope scope;
350 if (parser.parseComma() ||
351 spirv::parseEnumKeywordAttr(scope, parser, "scope <id>"))
352 return {};
354 CooperativeMatrixUseKHR use;
355 if (parser.parseComma() ||
356 spirv::parseEnumKeywordAttr(use, parser, "use <id>"))
357 return {};
359 if (parser.parseGreater())
360 return {};
362 return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
365 // TODO: Reorder methods to be utilities first and parse*Type
366 // methods in alphabetical order
368 // storage-class ::= `UniformConstant`
369 // | `Uniform`
370 // | `Workgroup`
371 // | <and other storage classes...>
373 // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>`
374 static Type parsePointerType(SPIRVDialect const &dialect,
375 DialectAsmParser &parser) {
376 if (parser.parseLess())
377 return Type();
379 auto pointeeType = parseAndVerifyType(dialect, parser);
380 if (!pointeeType)
381 return Type();
383 StringRef storageClassSpec;
384 SMLoc storageClassLoc = parser.getCurrentLocation();
385 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
386 return Type();
388 auto storageClass = symbolizeStorageClass(storageClassSpec);
389 if (!storageClass) {
390 parser.emitError(storageClassLoc, "unknown storage class: ")
391 << storageClassSpec;
392 return Type();
394 if (parser.parseGreater())
395 return Type();
396 return PointerType::get(pointeeType, *storageClass);
399 // runtime-array-type ::= `!spirv.rtarray` `<` element-type
400 // (`,` `stride` `=` integer-literal)? `>`
401 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
402 DialectAsmParser &parser) {
403 if (parser.parseLess())
404 return Type();
406 Type elementType = parseAndVerifyType(dialect, parser);
407 if (!elementType)
408 return Type();
410 unsigned stride = 0;
411 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
412 return Type();
414 if (parser.parseGreater())
415 return Type();
416 return RuntimeArrayType::get(elementType, stride);
419 // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>`
420 static Type parseMatrixType(SPIRVDialect const &dialect,
421 DialectAsmParser &parser) {
422 if (parser.parseLess())
423 return Type();
425 SmallVector<int64_t, 1> countDims;
426 SMLoc countLoc = parser.getCurrentLocation();
427 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
428 return Type();
429 if (countDims.size() != 1) {
430 parser.emitError(countLoc, "expected single unsigned "
431 "integer for number of columns");
432 return Type();
435 int64_t columnCount = countDims[0];
436 // According to the specification, Matrices can have 2, 3, or 4 columns
437 if (columnCount < 2 || columnCount > 4) {
438 parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
439 "columns");
440 return Type();
443 Type columnType = parseAndVerifyMatrixType(dialect, parser);
444 if (!columnType)
445 return Type();
447 if (parser.parseGreater())
448 return Type();
450 return MatrixType::get(columnType, columnCount);
453 // Specialize this function to parse each of the parameters that define an
454 // ImageType. By default it assumes this is an enum type.
455 template <typename ValTy>
456 static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
457 DialectAsmParser &parser) {
458 StringRef enumSpec;
459 SMLoc enumLoc = parser.getCurrentLocation();
460 if (parser.parseKeyword(&enumSpec)) {
461 return std::nullopt;
464 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
465 if (!val)
466 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
467 return val;
470 template <>
471 std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
472 DialectAsmParser &parser) {
473 // TODO: Further verify that the element type can be sampled
474 auto ty = parseAndVerifyType(dialect, parser);
475 if (!ty)
476 return std::nullopt;
477 return ty;
480 template <typename IntTy>
481 static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
482 DialectAsmParser &parser) {
483 IntTy offsetVal = std::numeric_limits<IntTy>::max();
484 if (parser.parseInteger(offsetVal))
485 return std::nullopt;
486 return offsetVal;
489 template <>
490 std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
491 DialectAsmParser &parser) {
492 return parseAndVerifyInteger<unsigned>(dialect, parser);
495 namespace {
496 // Functor object to parse a comma separated list of specs. The function
497 // parseAndVerify does the actual parsing and verification of individual
498 // elements. This is a functor since parsing the last element of the list
499 // (termination condition) needs partial specialization.
500 template <typename ParseType, typename... Args>
501 struct ParseCommaSeparatedList {
502 std::optional<std::tuple<ParseType, Args...>>
503 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
504 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
505 if (!parseVal)
506 return std::nullopt;
508 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
509 if (numArgs != 0 && failed(parser.parseComma()))
510 return std::nullopt;
511 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
512 if (!remainingValues)
513 return std::nullopt;
514 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
515 remainingValues.value());
519 // Partial specialization of the function to parse a comma separated list of
520 // specs to parse the last element of the list.
521 template <typename ParseType>
522 struct ParseCommaSeparatedList<ParseType> {
523 std::optional<std::tuple<ParseType>>
524 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
525 if (auto value = parseAndVerify<ParseType>(dialect, parser))
526 return std::tuple<ParseType>(*value);
527 return std::nullopt;
530 } // namespace
532 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
534 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
536 // arrayed-info ::= `NonArrayed` | `Arrayed`
538 // sampling-info ::= `SingleSampled` | `MultiSampled`
540 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
542 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
544 // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,`
545 // arrayed-info `,` sampling-info `,`
546 // sampler-use-info `,` format `>`
547 static Type parseImageType(SPIRVDialect const &dialect,
548 DialectAsmParser &parser) {
549 if (parser.parseLess())
550 return Type();
552 auto value =
553 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
554 ImageSamplingInfo, ImageSamplerUseInfo,
555 ImageFormat>{}(dialect, parser);
556 if (!value)
557 return Type();
559 if (parser.parseGreater())
560 return Type();
561 return ImageType::get(*value);
564 // sampledImage-type :: = `!spirv.sampledImage<` image-type `>`
565 static Type parseSampledImageType(SPIRVDialect const &dialect,
566 DialectAsmParser &parser) {
567 if (parser.parseLess())
568 return Type();
570 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
571 if (!parsedType)
572 return Type();
574 if (parser.parseGreater())
575 return Type();
576 return SampledImageType::get(parsedType);
579 // Parse decorations associated with a member.
580 static ParseResult parseStructMemberDecorations(
581 SPIRVDialect const &dialect, DialectAsmParser &parser,
582 ArrayRef<Type> memberTypes,
583 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
584 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
586 // Check if the first element is offset.
587 SMLoc offsetLoc = parser.getCurrentLocation();
588 StructType::OffsetInfo offset = 0;
589 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
590 if (offsetParseResult.has_value()) {
591 if (failed(*offsetParseResult))
592 return failure();
594 if (offsetInfo.size() != memberTypes.size() - 1) {
595 return parser.emitError(offsetLoc,
596 "offset specification must be given for "
597 "all members");
599 offsetInfo.push_back(offset);
602 // Check for no spirv::Decorations.
603 if (succeeded(parser.parseOptionalRSquare()))
604 return success();
606 // If there was an offset, make sure to parse the comma.
607 if (offsetParseResult.has_value() && parser.parseComma())
608 return failure();
610 // Check for spirv::Decorations.
611 auto parseDecorations = [&]() {
612 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
613 if (!memberDecoration)
614 return failure();
616 // Parse member decoration value if it exists.
617 if (succeeded(parser.parseOptionalEqual())) {
618 auto memberDecorationValue =
619 parseAndVerifyInteger<uint32_t>(dialect, parser);
621 if (!memberDecorationValue)
622 return failure();
624 memberDecorationInfo.emplace_back(
625 static_cast<uint32_t>(memberTypes.size() - 1), 1,
626 memberDecoration.value(), memberDecorationValue.value());
627 } else {
628 memberDecorationInfo.emplace_back(
629 static_cast<uint32_t>(memberTypes.size() - 1), 0,
630 memberDecoration.value(), 0);
632 return success();
634 if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
635 failed(parser.parseRSquare()))
636 return failure();
638 return success();
641 // struct-member-decoration ::= integer-literal? spirv-decoration*
642 // struct-type ::=
643 // `!spirv.struct<` (id `,`)?
644 // `(`
645 // (spirv-type (`[` struct-member-decoration `]`)?)*
646 // `)>`
647 static Type parseStructType(SPIRVDialect const &dialect,
648 DialectAsmParser &parser) {
649 // TODO: This function is quite lengthy. Break it down into smaller chunks.
651 if (parser.parseLess())
652 return Type();
654 StringRef identifier;
655 FailureOr<DialectAsmParser::CyclicParseReset> cyclicParse;
657 // Check if this is an identified struct type.
658 if (succeeded(parser.parseOptionalKeyword(&identifier))) {
659 // Check if this is a possible recursive reference.
660 auto structType =
661 StructType::getIdentified(dialect.getContext(), identifier);
662 cyclicParse = parser.tryStartCyclicParse(structType);
663 if (succeeded(parser.parseOptionalGreater())) {
664 if (succeeded(cyclicParse)) {
665 parser.emitError(
666 parser.getNameLoc(),
667 "recursive struct reference not nested in struct definition");
669 return Type();
672 return structType;
675 if (failed(parser.parseComma()))
676 return Type();
678 if (failed(cyclicParse)) {
679 parser.emitError(parser.getNameLoc(),
680 "identifier already used for an enclosing struct");
681 return Type();
685 if (failed(parser.parseLParen()))
686 return Type();
688 if (succeeded(parser.parseOptionalRParen()) &&
689 succeeded(parser.parseOptionalGreater())) {
690 return StructType::getEmpty(dialect.getContext(), identifier);
693 StructType idStructTy;
695 if (!identifier.empty())
696 idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
698 SmallVector<Type, 4> memberTypes;
699 SmallVector<StructType::OffsetInfo, 4> offsetInfo;
700 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
702 do {
703 Type memberType;
704 if (parser.parseType(memberType))
705 return Type();
706 memberTypes.push_back(memberType);
708 if (succeeded(parser.parseOptionalLSquare()))
709 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
710 memberDecorationInfo))
711 return Type();
712 } while (succeeded(parser.parseOptionalComma()));
714 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
715 parser.emitError(parser.getNameLoc(),
716 "offset specification must be given for all members");
717 return Type();
720 if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
721 return Type();
723 if (!identifier.empty()) {
724 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
725 memberDecorationInfo)))
726 return Type();
727 return idStructTy;
730 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
733 // spirv-type ::= array-type
734 // | element-type
735 // | image-type
736 // | pointer-type
737 // | runtime-array-type
738 // | sampled-image-type
739 // | struct-type
740 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
741 StringRef keyword;
742 if (parser.parseKeyword(&keyword))
743 return Type();
745 if (keyword == "array")
746 return parseArrayType(*this, parser);
747 if (keyword == "coopmatrix")
748 return parseCooperativeMatrixType(*this, parser);
749 if (keyword == "image")
750 return parseImageType(*this, parser);
751 if (keyword == "ptr")
752 return parsePointerType(*this, parser);
753 if (keyword == "rtarray")
754 return parseRuntimeArrayType(*this, parser);
755 if (keyword == "sampled_image")
756 return parseSampledImageType(*this, parser);
757 if (keyword == "struct")
758 return parseStructType(*this, parser);
759 if (keyword == "matrix")
760 return parseMatrixType(*this, parser);
761 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
762 return Type();
765 //===----------------------------------------------------------------------===//
766 // Type Printing
767 //===----------------------------------------------------------------------===//
769 static void print(ArrayType type, DialectAsmPrinter &os) {
770 os << "array<" << type.getNumElements() << " x " << type.getElementType();
771 if (unsigned stride = type.getArrayStride())
772 os << ", stride=" << stride;
773 os << ">";
776 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
777 os << "rtarray<" << type.getElementType();
778 if (unsigned stride = type.getArrayStride())
779 os << ", stride=" << stride;
780 os << ">";
783 static void print(PointerType type, DialectAsmPrinter &os) {
784 os << "ptr<" << type.getPointeeType() << ", "
785 << stringifyStorageClass(type.getStorageClass()) << ">";
788 static void print(ImageType type, DialectAsmPrinter &os) {
789 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
790 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
791 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
792 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
793 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
794 << stringifyImageFormat(type.getImageFormat()) << ">";
797 static void print(SampledImageType type, DialectAsmPrinter &os) {
798 os << "sampled_image<" << type.getImageType() << ">";
801 static void print(StructType type, DialectAsmPrinter &os) {
802 FailureOr<AsmPrinter::CyclicPrintReset> cyclicPrint;
804 os << "struct<";
806 if (type.isIdentified()) {
807 os << type.getIdentifier();
809 cyclicPrint = os.tryStartCyclicPrint(type);
810 if (failed(cyclicPrint)) {
811 os << ">";
812 return;
815 os << ", ";
818 os << "(";
820 auto printMember = [&](unsigned i) {
821 os << type.getElementType(i);
822 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
823 type.getMemberDecorations(i, decorations);
824 if (type.hasOffset() || !decorations.empty()) {
825 os << " [";
826 if (type.hasOffset()) {
827 os << type.getMemberOffset(i);
828 if (!decorations.empty())
829 os << ", ";
831 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
832 os << stringifyDecoration(decoration.decoration);
833 if (decoration.hasValue) {
834 os << "=" << decoration.decorationValue;
837 llvm::interleaveComma(decorations, os, eachFn);
838 os << "]";
841 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
842 printMember);
843 os << ")>";
846 static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
847 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"
848 << type.getElementType() << ", " << type.getScope() << ", "
849 << type.getUse() << ">";
852 static void print(MatrixType type, DialectAsmPrinter &os) {
853 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
854 os << ">";
857 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
858 TypeSwitch<Type>(type)
859 .Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
860 ImageType, SampledImageType, StructType, MatrixType>(
861 [&](auto type) { print(type, os); })
862 .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
865 //===----------------------------------------------------------------------===//
866 // Constant
867 //===----------------------------------------------------------------------===//
869 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
870 Attribute value, Type type,
871 Location loc) {
872 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
873 return builder.create<ub::PoisonOp>(loc, type, poison);
875 if (!spirv::ConstantOp::isBuildableWith(type))
876 return nullptr;
878 return builder.create<spirv::ConstantOp>(loc, type, value);
881 //===----------------------------------------------------------------------===//
882 // Shader Interface ABI
883 //===----------------------------------------------------------------------===//
885 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
886 NamedAttribute attribute) {
887 StringRef symbol = attribute.getName().strref();
888 Attribute attr = attribute.getValue();
890 if (symbol == spirv::getEntryPointABIAttrName()) {
891 if (!llvm::isa<spirv::EntryPointABIAttr>(attr)) {
892 return op->emitError("'")
893 << symbol << "' attribute must be an entry point ABI attribute";
895 } else if (symbol == spirv::getTargetEnvAttrName()) {
896 if (!llvm::isa<spirv::TargetEnvAttr>(attr))
897 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
898 } else {
899 return op->emitError("found unsupported '")
900 << symbol << "' attribute on operation";
903 return success();
906 /// Verifies the given SPIR-V `attribute` attached to a value of the given
907 /// `valueType` is valid.
908 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
909 NamedAttribute attribute) {
910 StringRef symbol = attribute.getName().strref();
911 Attribute attr = attribute.getValue();
913 if (symbol == spirv::getInterfaceVarABIAttrName()) {
914 auto varABIAttr = llvm::dyn_cast<spirv::InterfaceVarABIAttr>(attr);
915 if (!varABIAttr)
916 return emitError(loc, "'")
917 << symbol << "' must be a spirv::InterfaceVarABIAttr";
919 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
920 return emitError(loc, "'") << symbol
921 << "' attribute cannot specify storage class "
922 "when attaching to a non-scalar value";
923 return success();
925 if (symbol == spirv::DecorationAttr::name) {
926 if (!isa<spirv::DecorationAttr>(attr))
927 return emitError(loc, "'")
928 << symbol << "' must be a spirv::DecorationAttr";
929 return success();
932 return emitError(loc, "found unsupported '")
933 << symbol << "' attribute on region argument";
936 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
937 unsigned regionIndex,
938 unsigned argIndex,
939 NamedAttribute attribute) {
940 auto funcOp = dyn_cast<FunctionOpInterface>(op);
941 if (!funcOp)
942 return success();
943 Type argType = funcOp.getArgumentTypes()[argIndex];
945 return verifyRegionAttribute(op->getLoc(), argType, attribute);
948 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
949 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
950 NamedAttribute attribute) {
951 return op->emitError("cannot attach SPIR-V attributes to region result");