[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / AsmParser / AttributeParser.cpp
blobff616dac9625b4156dcc6e659a48f600e92ed4c2
1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for MLIR attributes.
11 //===----------------------------------------------------------------------===//
13 #include "Parser.h"
15 #include "AsmParserImpl.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/DialectResourceBlobManager.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/Endian.h"
26 #include <optional>
28 using namespace mlir;
29 using namespace mlir::detail;
31 /// Parse an arbitrary attribute.
32 ///
33 /// attribute-value ::= `unit`
34 /// | bool-literal
35 /// | integer-literal (`:` (index-type | integer-type))?
36 /// | float-literal (`:` float-type)?
37 /// | string-literal (`:` type)?
38 /// | type
39 /// | `[` `:` (integer-type | float-type) tensor-literal `]`
40 /// | `[` (attribute-value (`,` attribute-value)*)? `]`
41 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
42 /// | symbol-ref-id (`::` symbol-ref-id)*
43 /// | `dense` `<` tensor-literal `>` `:`
44 /// (tensor-type | vector-type)
45 /// | `sparse` `<` attribute-value `,` attribute-value `>`
46 /// `:` (tensor-type | vector-type)
47 /// | `strided` `<` `[` comma-separated-int-or-question `]`
48 /// (`,` `offset` `:` integer-literal)? `>`
49 /// | distinct-attribute
50 /// | extended-attribute
51 ///
52 Attribute Parser::parseAttribute(Type type) {
53 switch (getToken().getKind()) {
54 // Parse an AffineMap or IntegerSet attribute.
55 case Token::kw_affine_map: {
56 consumeToken(Token::kw_affine_map);
58 AffineMap map;
59 if (parseToken(Token::less, "expected '<' in affine map") ||
60 parseAffineMapReference(map) ||
61 parseToken(Token::greater, "expected '>' in affine map"))
62 return Attribute();
63 return AffineMapAttr::get(map);
65 case Token::kw_affine_set: {
66 consumeToken(Token::kw_affine_set);
68 IntegerSet set;
69 if (parseToken(Token::less, "expected '<' in integer set") ||
70 parseIntegerSetReference(set) ||
71 parseToken(Token::greater, "expected '>' in integer set"))
72 return Attribute();
73 return IntegerSetAttr::get(set);
76 // Parse an array attribute.
77 case Token::l_square: {
78 consumeToken(Token::l_square);
79 SmallVector<Attribute, 4> elements;
80 auto parseElt = [&]() -> ParseResult {
81 elements.push_back(parseAttribute());
82 return elements.back() ? success() : failure();
85 if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
86 return nullptr;
87 return builder.getArrayAttr(elements);
90 // Parse a boolean attribute.
91 case Token::kw_false:
92 consumeToken(Token::kw_false);
93 return builder.getBoolAttr(false);
94 case Token::kw_true:
95 consumeToken(Token::kw_true);
96 return builder.getBoolAttr(true);
98 // Parse a dense elements attribute.
99 case Token::kw_dense:
100 return parseDenseElementsAttr(type);
102 // Parse a dense resource elements attribute.
103 case Token::kw_dense_resource:
104 return parseDenseResourceElementsAttr(type);
106 // Parse a dense array attribute.
107 case Token::kw_array:
108 return parseDenseArrayAttr(type);
110 // Parse a dictionary attribute.
111 case Token::l_brace: {
112 NamedAttrList elements;
113 if (parseAttributeDict(elements))
114 return nullptr;
115 return elements.getDictionary(getContext());
118 // Parse an extended attribute, i.e. alias or dialect attribute.
119 case Token::hash_identifier:
120 return parseExtendedAttr(type);
122 // Parse floating point and integer attributes.
123 case Token::floatliteral:
124 return parseFloatAttr(type, /*isNegative=*/false);
125 case Token::integer:
126 return parseDecOrHexAttr(type, /*isNegative=*/false);
127 case Token::minus: {
128 consumeToken(Token::minus);
129 if (getToken().is(Token::integer))
130 return parseDecOrHexAttr(type, /*isNegative=*/true);
131 if (getToken().is(Token::floatliteral))
132 return parseFloatAttr(type, /*isNegative=*/true);
134 return (emitWrongTokenError(
135 "expected constant integer or floating point value"),
136 nullptr);
139 // Parse a location attribute.
140 case Token::kw_loc: {
141 consumeToken(Token::kw_loc);
143 LocationAttr locAttr;
144 if (parseToken(Token::l_paren, "expected '(' in inline location") ||
145 parseLocationInstance(locAttr) ||
146 parseToken(Token::r_paren, "expected ')' in inline location"))
147 return Attribute();
148 return locAttr;
151 // Parse a sparse elements attribute.
152 case Token::kw_sparse:
153 return parseSparseElementsAttr(type);
155 // Parse a strided layout attribute.
156 case Token::kw_strided:
157 return parseStridedLayoutAttr();
159 // Parse a distinct attribute.
160 case Token::kw_distinct:
161 return parseDistinctAttr(type);
163 // Parse a string attribute.
164 case Token::string: {
165 auto val = getToken().getStringValue();
166 consumeToken(Token::string);
167 // Parse the optional trailing colon type if one wasn't explicitly provided.
168 if (!type && consumeIf(Token::colon) && !(type = parseType()))
169 return Attribute();
171 return type ? StringAttr::get(val, type)
172 : StringAttr::get(getContext(), val);
175 // Parse a symbol reference attribute.
176 case Token::at_identifier: {
177 // When populating the parser state, this is a list of locations for all of
178 // the nested references.
179 SmallVector<SMRange> referenceLocations;
180 if (state.asmState)
181 referenceLocations.push_back(getToken().getLocRange());
183 // Parse the top-level reference.
184 std::string nameStr = getToken().getSymbolReference();
185 consumeToken(Token::at_identifier);
187 // Parse any nested references.
188 std::vector<FlatSymbolRefAttr> nestedRefs;
189 while (getToken().is(Token::colon)) {
190 // Check for the '::' prefix.
191 const char *curPointer = getToken().getLoc().getPointer();
192 consumeToken(Token::colon);
193 if (!consumeIf(Token::colon)) {
194 if (getToken().isNot(Token::eof, Token::error)) {
195 state.lex.resetPointer(curPointer);
196 consumeToken();
198 break;
200 // Parse the reference itself.
201 auto curLoc = getToken().getLoc();
202 if (getToken().isNot(Token::at_identifier)) {
203 emitError(curLoc, "expected nested symbol reference identifier");
204 return Attribute();
207 // If we are populating the assembly state, add the location for this
208 // reference.
209 if (state.asmState)
210 referenceLocations.push_back(getToken().getLocRange());
212 std::string nameStr = getToken().getSymbolReference();
213 consumeToken(Token::at_identifier);
214 nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr));
216 SymbolRefAttr symbolRefAttr =
217 SymbolRefAttr::get(getContext(), nameStr, nestedRefs);
219 // If we are populating the assembly state, record this symbol reference.
220 if (state.asmState)
221 state.asmState->addUses(symbolRefAttr, referenceLocations);
222 return symbolRefAttr;
225 // Parse a 'unit' attribute.
226 case Token::kw_unit:
227 consumeToken(Token::kw_unit);
228 return builder.getUnitAttr();
230 // Handle completion of an attribute.
231 case Token::code_complete:
232 if (getToken().isCodeCompletionFor(Token::hash_identifier))
233 return parseExtendedAttr(type);
234 return codeCompleteAttribute();
236 default:
237 // Parse a type attribute. We parse `Optional` here to allow for providing a
238 // better error message.
239 Type type;
240 OptionalParseResult result = parseOptionalType(type);
241 if (!result.has_value())
242 return emitWrongTokenError("expected attribute value"), Attribute();
243 return failed(*result) ? Attribute() : TypeAttr::get(type);
247 /// Parse an optional attribute with the provided type.
248 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
249 Type type) {
250 switch (getToken().getKind()) {
251 case Token::at_identifier:
252 case Token::floatliteral:
253 case Token::integer:
254 case Token::hash_identifier:
255 case Token::kw_affine_map:
256 case Token::kw_affine_set:
257 case Token::kw_dense:
258 case Token::kw_dense_resource:
259 case Token::kw_false:
260 case Token::kw_loc:
261 case Token::kw_sparse:
262 case Token::kw_true:
263 case Token::kw_unit:
264 case Token::l_brace:
265 case Token::l_square:
266 case Token::minus:
267 case Token::string:
268 attribute = parseAttribute(type);
269 return success(attribute != nullptr);
271 default:
272 // Parse an optional type attribute.
273 Type type;
274 OptionalParseResult result = parseOptionalType(type);
275 if (result.has_value() && succeeded(*result))
276 attribute = TypeAttr::get(type);
277 return result;
280 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
281 Type type) {
282 return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
284 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
285 Type type) {
286 return parseOptionalAttributeWithToken(Token::string, attribute, type);
288 OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
289 Type type) {
290 return parseOptionalAttributeWithToken(Token::at_identifier, result, type);
293 /// Attribute dictionary.
295 /// attribute-dict ::= `{` `}`
296 /// | `{` attribute-entry (`,` attribute-entry)* `}`
297 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
299 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
300 llvm::SmallDenseSet<StringAttr> seenKeys;
301 auto parseElt = [&]() -> ParseResult {
302 // The name of an attribute can either be a bare identifier, or a string.
303 std::optional<StringAttr> nameId;
304 if (getToken().is(Token::string))
305 nameId = builder.getStringAttr(getToken().getStringValue());
306 else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
307 getToken().isKeyword())
308 nameId = builder.getStringAttr(getTokenSpelling());
309 else
310 return emitWrongTokenError("expected attribute name");
312 if (nameId->empty())
313 return emitError("expected valid attribute name");
315 if (!seenKeys.insert(*nameId).second)
316 return emitError("duplicate key '")
317 << nameId->getValue() << "' in dictionary attribute";
318 consumeToken();
320 // Lazy load a dialect in the context if there is a possible namespace.
321 auto splitName = nameId->strref().split('.');
322 if (!splitName.second.empty())
323 getContext()->getOrLoadDialect(splitName.first);
325 // Try to parse the '=' for the attribute value.
326 if (!consumeIf(Token::equal)) {
327 // If there is no '=', we treat this as a unit attribute.
328 attributes.push_back({*nameId, builder.getUnitAttr()});
329 return success();
332 auto attr = parseAttribute();
333 if (!attr)
334 return failure();
335 attributes.push_back({*nameId, attr});
336 return success();
339 return parseCommaSeparatedList(Delimiter::Braces, parseElt,
340 " in attribute dictionary");
343 /// Parse a float attribute.
344 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
345 auto val = getToken().getFloatingPointValue();
346 if (!val)
347 return (emitError("floating point value too large for attribute"), nullptr);
348 consumeToken(Token::floatliteral);
349 if (!type) {
350 // Default to F64 when no type is specified.
351 if (!consumeIf(Token::colon))
352 type = builder.getF64Type();
353 else if (!(type = parseType()))
354 return nullptr;
356 if (!isa<FloatType>(type))
357 return (emitError("floating point value not valid for specified type"),
358 nullptr);
359 return FloatAttr::get(type, isNegative ? -*val : *val);
362 /// Construct an APint from a parsed value, a known attribute type and
363 /// sign.
364 static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
365 StringRef spelling) {
366 // Parse the integer value into an APInt that is big enough to hold the value.
367 APInt result;
368 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
369 if (spelling.getAsInteger(isHex ? 0 : 10, result))
370 return std::nullopt;
372 // Extend or truncate the bitwidth to the right size.
373 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
374 : type.getIntOrFloatBitWidth();
376 if (width > result.getBitWidth()) {
377 result = result.zext(width);
378 } else if (width < result.getBitWidth()) {
379 // The parser can return an unnecessarily wide result with leading zeros.
380 // This isn't a problem, but truncating off bits is bad.
381 if (result.countl_zero() < result.getBitWidth() - width)
382 return std::nullopt;
384 result = result.trunc(width);
387 if (width == 0) {
388 // 0 bit integers cannot be negative and manipulation of their sign bit will
389 // assert, so short-cut validation here.
390 if (isNegative)
391 return std::nullopt;
392 } else if (isNegative) {
393 // The value is negative, we have an overflow if the sign bit is not set
394 // in the negated apInt.
395 result.negate();
396 if (!result.isSignBitSet())
397 return std::nullopt;
398 } else if ((type.isSignedInteger() || type.isIndex()) &&
399 result.isSignBitSet()) {
400 // The value is a positive signed integer or index,
401 // we have an overflow if the sign bit is set.
402 return std::nullopt;
405 return result;
408 /// Parse a decimal or a hexadecimal literal, which can be either an integer
409 /// or a float attribute.
410 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
411 Token tok = getToken();
412 StringRef spelling = tok.getSpelling();
413 SMLoc loc = tok.getLoc();
415 consumeToken(Token::integer);
416 if (!type) {
417 // Default to i64 if not type is specified.
418 if (!consumeIf(Token::colon))
419 type = builder.getIntegerType(64);
420 else if (!(type = parseType()))
421 return nullptr;
424 if (auto floatType = dyn_cast<FloatType>(type)) {
425 std::optional<APFloat> result;
426 if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
427 floatType.getFloatSemantics())))
428 return Attribute();
429 return FloatAttr::get(floatType, *result);
432 if (!isa<IntegerType, IndexType>(type))
433 return emitError(loc, "integer literal not valid for specified type"),
434 nullptr;
436 if (isNegative && type.isUnsignedInteger()) {
437 emitError(loc,
438 "negative integer literal not valid for unsigned integer type");
439 return nullptr;
442 std::optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
443 if (!apInt)
444 return emitError(loc, "integer constant out of range for attribute"),
445 nullptr;
446 return builder.getIntegerAttr(type, *apInt);
449 //===----------------------------------------------------------------------===//
450 // TensorLiteralParser
451 //===----------------------------------------------------------------------===//
453 /// Parse elements values stored within a hex string. On success, the values are
454 /// stored into 'result'.
455 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
456 std::string &result) {
457 if (std::optional<std::string> value = tok.getHexStringValue()) {
458 result = std::move(*value);
459 return success();
461 return parser.emitError(
462 tok.getLoc(), "expected string containing hex digits starting with `0x`");
465 namespace {
466 /// This class implements a parser for TensorLiterals. A tensor literal is
467 /// either a single element (e.g, 5) or a multi-dimensional list of elements
468 /// (e.g., [[5, 5]]).
469 class TensorLiteralParser {
470 public:
471 TensorLiteralParser(Parser &p) : p(p) {}
473 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
474 /// may also parse a tensor literal that is store as a hex string.
475 ParseResult parse(bool allowHex);
477 /// Build a dense attribute instance with the parsed elements and the given
478 /// shaped type.
479 DenseElementsAttr getAttr(SMLoc loc, ShapedType type);
481 ArrayRef<int64_t> getShape() const { return shape; }
483 private:
484 /// Get the parsed elements for an integer attribute.
485 ParseResult getIntAttrElements(SMLoc loc, Type eltTy,
486 std::vector<APInt> &intValues);
488 /// Get the parsed elements for a float attribute.
489 ParseResult getFloatAttrElements(SMLoc loc, FloatType eltTy,
490 std::vector<APFloat> &floatValues);
492 /// Build a Dense String attribute for the given type.
493 DenseElementsAttr getStringAttr(SMLoc loc, ShapedType type, Type eltTy);
495 /// Build a Dense attribute with hex data for the given type.
496 DenseElementsAttr getHexAttr(SMLoc loc, ShapedType type);
498 /// Parse a single element, returning failure if it isn't a valid element
499 /// literal. For example:
500 /// parseElement(1) -> Success, 1
501 /// parseElement([1]) -> Failure
502 ParseResult parseElement();
504 /// Parse a list of either lists or elements, returning the dimensions of the
505 /// parsed sub-tensors in dims. For example:
506 /// parseList([1, 2, 3]) -> Success, [3]
507 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
508 /// parseList([[1, 2], 3]) -> Failure
509 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
510 ParseResult parseList(SmallVectorImpl<int64_t> &dims);
512 /// Parse a literal that was printed as a hex string.
513 ParseResult parseHexElements();
515 Parser &p;
517 /// The shape inferred from the parsed elements.
518 SmallVector<int64_t, 4> shape;
520 /// Storage used when parsing elements, this is a pair of <is_negated, token>.
521 std::vector<std::pair<bool, Token>> storage;
523 /// Storage used when parsing elements that were stored as hex values.
524 std::optional<Token> hexStorage;
526 } // namespace
528 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
529 /// may also parse a tensor literal that is store as a hex string.
530 ParseResult TensorLiteralParser::parse(bool allowHex) {
531 // If hex is allowed, check for a string literal.
532 if (allowHex && p.getToken().is(Token::string)) {
533 hexStorage = p.getToken();
534 p.consumeToken(Token::string);
535 return success();
537 // Otherwise, parse a list or an individual element.
538 if (p.getToken().is(Token::l_square))
539 return parseList(shape);
540 return parseElement();
543 /// Build a dense attribute instance with the parsed elements and the given
544 /// shaped type.
545 DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
546 Type eltType = type.getElementType();
548 // Check to see if we parse the literal from a hex string.
549 if (hexStorage &&
550 (eltType.isIntOrIndexOrFloat() || isa<ComplexType>(eltType)))
551 return getHexAttr(loc, type);
553 // Check that the parsed storage size has the same number of elements to the
554 // type, or is a known splat.
555 if (!shape.empty() && getShape() != type.getShape()) {
556 p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
557 << "]) does not match type ([" << type.getShape() << "])";
558 return nullptr;
561 // Handle the case where no elements were parsed.
562 if (!hexStorage && storage.empty() && type.getNumElements()) {
563 p.emitError(loc) << "parsed zero elements, but type (" << type
564 << ") expected at least 1";
565 return nullptr;
568 // Handle complex types in the specific element type cases below.
569 bool isComplex = false;
570 if (ComplexType complexTy = dyn_cast<ComplexType>(eltType)) {
571 eltType = complexTy.getElementType();
572 isComplex = true;
575 // Handle integer and index types.
576 if (eltType.isIntOrIndex()) {
577 std::vector<APInt> intValues;
578 if (failed(getIntAttrElements(loc, eltType, intValues)))
579 return nullptr;
580 if (isComplex) {
581 // If this is a complex, treat the parsed values as complex values.
582 auto complexData = llvm::ArrayRef(
583 reinterpret_cast<std::complex<APInt> *>(intValues.data()),
584 intValues.size() / 2);
585 return DenseElementsAttr::get(type, complexData);
587 return DenseElementsAttr::get(type, intValues);
589 // Handle floating point types.
590 if (FloatType floatTy = dyn_cast<FloatType>(eltType)) {
591 std::vector<APFloat> floatValues;
592 if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
593 return nullptr;
594 if (isComplex) {
595 // If this is a complex, treat the parsed values as complex values.
596 auto complexData = llvm::ArrayRef(
597 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
598 floatValues.size() / 2);
599 return DenseElementsAttr::get(type, complexData);
601 return DenseElementsAttr::get(type, floatValues);
604 // Other types are assumed to be string representations.
605 return getStringAttr(loc, type, type.getElementType());
608 /// Build a Dense Integer attribute for the given type.
609 ParseResult
610 TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
611 std::vector<APInt> &intValues) {
612 intValues.reserve(storage.size());
613 bool isUintType = eltTy.isUnsignedInteger();
614 for (const auto &signAndToken : storage) {
615 bool isNegative = signAndToken.first;
616 const Token &token = signAndToken.second;
617 auto tokenLoc = token.getLoc();
619 if (isNegative && isUintType) {
620 return p.emitError(tokenLoc)
621 << "expected unsigned integer elements, but parsed negative value";
624 // Check to see if floating point values were parsed.
625 if (token.is(Token::floatliteral)) {
626 return p.emitError(tokenLoc)
627 << "expected integer elements, but parsed floating-point";
630 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
631 "unexpected token type");
632 if (token.isAny(Token::kw_true, Token::kw_false)) {
633 if (!eltTy.isInteger(1)) {
634 return p.emitError(tokenLoc)
635 << "expected i1 type for 'true' or 'false' values";
637 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
638 intValues.push_back(apInt);
639 continue;
642 // Create APInt values for each element with the correct bitwidth.
643 std::optional<APInt> apInt =
644 buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
645 if (!apInt)
646 return p.emitError(tokenLoc, "integer constant out of range for type");
647 intValues.push_back(*apInt);
649 return success();
652 /// Build a Dense Float attribute for the given type.
653 ParseResult
654 TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
655 std::vector<APFloat> &floatValues) {
656 floatValues.reserve(storage.size());
657 for (const auto &signAndToken : storage) {
658 bool isNegative = signAndToken.first;
659 const Token &token = signAndToken.second;
660 std::optional<APFloat> result;
661 if (failed(p.parseFloatFromLiteral(result, token, isNegative,
662 eltTy.getFloatSemantics())))
663 return failure();
664 floatValues.push_back(*result);
666 return success();
669 /// Build a Dense String attribute for the given type.
670 DenseElementsAttr TensorLiteralParser::getStringAttr(SMLoc loc, ShapedType type,
671 Type eltTy) {
672 if (hexStorage.has_value()) {
673 auto stringValue = hexStorage->getStringValue();
674 return DenseStringElementsAttr::get(type, {stringValue});
677 std::vector<std::string> stringValues;
678 std::vector<StringRef> stringRefValues;
679 stringValues.reserve(storage.size());
680 stringRefValues.reserve(storage.size());
682 for (auto val : storage) {
683 stringValues.push_back(val.second.getStringValue());
684 stringRefValues.emplace_back(stringValues.back());
687 return DenseStringElementsAttr::get(type, stringRefValues);
690 /// Build a Dense attribute with hex data for the given type.
691 DenseElementsAttr TensorLiteralParser::getHexAttr(SMLoc loc, ShapedType type) {
692 Type elementType = type.getElementType();
693 if (!elementType.isIntOrIndexOrFloat() && !isa<ComplexType>(elementType)) {
694 p.emitError(loc)
695 << "expected floating-point, integer, or complex element type, got "
696 << elementType;
697 return nullptr;
700 std::string data;
701 if (parseElementAttrHexValues(p, *hexStorage, data))
702 return nullptr;
704 ArrayRef<char> rawData(data.data(), data.size());
705 bool detectedSplat = false;
706 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
707 p.emitError(loc) << "elements hex data size is invalid for provided type: "
708 << type;
709 return nullptr;
712 if (llvm::endianness::native == llvm::endianness::big) {
713 // Convert endianess in big-endian(BE) machines. `rawData` is
714 // little-endian(LE) because HEX in raw data of dense element attribute
715 // is always LE format. It is converted into BE here to be used in BE
716 // machines.
717 SmallVector<char, 64> outDataVec(rawData.size());
718 MutableArrayRef<char> convRawData(outDataVec);
719 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
720 rawData, convRawData, type);
721 return DenseElementsAttr::getFromRawBuffer(type, convRawData);
724 return DenseElementsAttr::getFromRawBuffer(type, rawData);
727 ParseResult TensorLiteralParser::parseElement() {
728 switch (p.getToken().getKind()) {
729 // Parse a boolean element.
730 case Token::kw_true:
731 case Token::kw_false:
732 case Token::floatliteral:
733 case Token::integer:
734 storage.emplace_back(/*isNegative=*/false, p.getToken());
735 p.consumeToken();
736 break;
738 // Parse a signed integer or a negative floating-point element.
739 case Token::minus:
740 p.consumeToken(Token::minus);
741 if (!p.getToken().isAny(Token::floatliteral, Token::integer))
742 return p.emitError("expected integer or floating point literal");
743 storage.emplace_back(/*isNegative=*/true, p.getToken());
744 p.consumeToken();
745 break;
747 case Token::string:
748 storage.emplace_back(/*isNegative=*/false, p.getToken());
749 p.consumeToken();
750 break;
752 // Parse a complex element of the form '(' element ',' element ')'.
753 case Token::l_paren:
754 p.consumeToken(Token::l_paren);
755 if (parseElement() ||
756 p.parseToken(Token::comma, "expected ',' between complex elements") ||
757 parseElement() ||
758 p.parseToken(Token::r_paren, "expected ')' after complex elements"))
759 return failure();
760 break;
762 default:
763 return p.emitError("expected element literal of primitive type");
766 return success();
769 /// Parse a list of either lists or elements, returning the dimensions of the
770 /// parsed sub-tensors in dims. For example:
771 /// parseList([1, 2, 3]) -> Success, [3]
772 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
773 /// parseList([[1, 2], 3]) -> Failure
774 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
775 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
776 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
777 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
778 if (prevDims == newDims)
779 return success();
780 return p.emitError("tensor literal is invalid; ranks are not consistent "
781 "between elements");
784 bool first = true;
785 SmallVector<int64_t, 4> newDims;
786 unsigned size = 0;
787 auto parseOneElement = [&]() -> ParseResult {
788 SmallVector<int64_t, 4> thisDims;
789 if (p.getToken().getKind() == Token::l_square) {
790 if (parseList(thisDims))
791 return failure();
792 } else if (parseElement()) {
793 return failure();
795 ++size;
796 if (!first)
797 return checkDims(newDims, thisDims);
798 newDims = thisDims;
799 first = false;
800 return success();
802 if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
803 return failure();
805 // Return the sublists' dimensions with 'size' prepended.
806 dims.clear();
807 dims.push_back(size);
808 dims.append(newDims.begin(), newDims.end());
809 return success();
812 //===----------------------------------------------------------------------===//
813 // DenseArrayAttr Parser
814 //===----------------------------------------------------------------------===//
816 namespace {
817 /// A generic dense array element parser. It parsers integer and floating point
818 /// elements.
819 class DenseArrayElementParser {
820 public:
821 explicit DenseArrayElementParser(Type type) : type(type) {}
823 /// Parse an integer element.
824 ParseResult parseIntegerElement(Parser &p);
826 /// Parse a floating point element.
827 ParseResult parseFloatElement(Parser &p);
829 /// Convert the current contents to a dense array.
830 DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
832 private:
833 /// Append the raw data of an APInt to the result.
834 void append(const APInt &data);
836 /// The array element type.
837 Type type;
838 /// The resultant byte array representing the contents of the array.
839 std::vector<char> rawData;
840 /// The number of elements in the array.
841 int64_t size = 0;
843 } // namespace
845 void DenseArrayElementParser::append(const APInt &data) {
846 if (data.getBitWidth()) {
847 assert(data.getBitWidth() % 8 == 0);
848 unsigned byteSize = data.getBitWidth() / 8;
849 size_t offset = rawData.size();
850 rawData.insert(rawData.end(), byteSize, 0);
851 llvm::StoreIntToMemory(
852 data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize);
854 ++size;
857 ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
858 bool isNegative = p.consumeIf(Token::minus);
860 // Parse an integer literal as an APInt.
861 std::optional<APInt> value;
862 StringRef spelling = p.getToken().getSpelling();
863 if (p.getToken().isAny(Token::kw_true, Token::kw_false)) {
864 if (!type.isInteger(1))
865 return p.emitError("expected i1 type for 'true' or 'false' values");
866 value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true),
867 !type.isUnsignedInteger());
868 p.consumeToken();
869 } else if (p.consumeIf(Token::integer)) {
870 value = buildAttributeAPInt(type, isNegative, spelling);
871 if (!value)
872 return p.emitError("integer constant out of range");
873 } else {
874 return p.emitError("expected integer literal");
876 append(*value);
877 return success();
880 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
881 bool isNegative = p.consumeIf(Token::minus);
882 Token token = p.getToken();
883 std::optional<APFloat> fromIntLit;
884 if (failed(
885 p.parseFloatFromLiteral(fromIntLit, token, isNegative,
886 cast<FloatType>(type).getFloatSemantics())))
887 return failure();
888 p.consumeToken();
889 append(fromIntLit->bitcastToAPInt());
890 return success();
893 /// Parse a dense array attribute.
894 Attribute Parser::parseDenseArrayAttr(Type attrType) {
895 consumeToken(Token::kw_array);
896 if (parseToken(Token::less, "expected '<' after 'array'"))
897 return {};
899 SMLoc typeLoc = getToken().getLoc();
900 Type eltType = parseType();
901 if (!eltType) {
902 emitError(typeLoc, "expected an integer or floating point type");
903 return {};
906 // Only bool or integer and floating point elements divisible by bytes are
907 // supported.
908 if (!eltType.isIntOrIndexOrFloat()) {
909 emitError(typeLoc, "expected integer or float type, got: ") << eltType;
910 return {};
912 if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
913 emitError(typeLoc, "element type bitwidth must be a multiple of 8");
914 return {};
917 // Check for empty list.
918 if (consumeIf(Token::greater))
919 return DenseArrayAttr::get(eltType, 0, {});
921 if (parseToken(Token::colon, "expected ':' after dense array type"))
922 return {};
924 DenseArrayElementParser eltParser(eltType);
925 if (eltType.isIntOrIndex()) {
926 if (parseCommaSeparatedList(
927 [&] { return eltParser.parseIntegerElement(*this); }))
928 return {};
929 } else {
930 if (parseCommaSeparatedList(
931 [&] { return eltParser.parseFloatElement(*this); }))
932 return {};
934 if (parseToken(Token::greater, "expected '>' to close an array attribute"))
935 return {};
936 return eltParser.getAttr();
939 /// Parse a dense elements attribute.
940 Attribute Parser::parseDenseElementsAttr(Type attrType) {
941 auto attribLoc = getToken().getLoc();
942 consumeToken(Token::kw_dense);
943 if (parseToken(Token::less, "expected '<' after 'dense'"))
944 return nullptr;
946 // Parse the literal data if necessary.
947 TensorLiteralParser literalParser(*this);
948 if (!consumeIf(Token::greater)) {
949 if (literalParser.parse(/*allowHex=*/true) ||
950 parseToken(Token::greater, "expected '>'"))
951 return nullptr;
954 // If the type is specified `parseElementsLiteralType` will not parse a type.
955 // Use the attribute location as the location for error reporting in that
956 // case.
957 auto loc = attrType ? attribLoc : getToken().getLoc();
958 auto type = parseElementsLiteralType(attrType);
959 if (!type)
960 return nullptr;
961 return literalParser.getAttr(loc, type);
964 Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
965 auto loc = getToken().getLoc();
966 consumeToken(Token::kw_dense_resource);
967 if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
968 return nullptr;
970 // Parse the resource handle.
971 FailureOr<AsmDialectResourceHandle> rawHandle =
972 parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
973 if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
974 return nullptr;
976 auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
977 if (!handle)
978 return emitError(loc, "invalid `dense_resource` handle type"), nullptr;
980 // Parse the type of the attribute if the user didn't provide one.
981 SMLoc typeLoc = loc;
982 if (!attrType) {
983 typeLoc = getToken().getLoc();
984 if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
985 return nullptr;
988 ShapedType shapedType = dyn_cast<ShapedType>(attrType);
989 if (!shapedType) {
990 emitError(typeLoc, "`dense_resource` expected a shaped type");
991 return nullptr;
994 return DenseResourceElementsAttr::get(shapedType, *handle);
997 /// Shaped type for elements attribute.
999 /// elements-literal-type ::= vector-type | ranked-tensor-type
1001 /// This method also checks the type has static shape.
1002 ShapedType Parser::parseElementsLiteralType(Type type) {
1003 // If the user didn't provide a type, parse the colon type for the literal.
1004 if (!type) {
1005 if (parseToken(Token::colon, "expected ':'"))
1006 return nullptr;
1007 if (!(type = parseType()))
1008 return nullptr;
1011 auto sType = dyn_cast<ShapedType>(type);
1012 if (!sType) {
1013 emitError("elements literal must be a shaped type");
1014 return nullptr;
1017 if (!sType.hasStaticShape())
1018 return (emitError("elements literal type must have static shape"), nullptr);
1020 return sType;
1023 /// Parse a sparse elements attribute.
1024 Attribute Parser::parseSparseElementsAttr(Type attrType) {
1025 SMLoc loc = getToken().getLoc();
1026 consumeToken(Token::kw_sparse);
1027 if (parseToken(Token::less, "Expected '<' after 'sparse'"))
1028 return nullptr;
1030 // Check for the case where all elements are sparse. The indices are
1031 // represented by a 2-dimensional shape where the second dimension is the rank
1032 // of the type.
1033 Type indiceEltType = builder.getIntegerType(64);
1034 if (consumeIf(Token::greater)) {
1035 ShapedType type = parseElementsLiteralType(attrType);
1036 if (!type)
1037 return nullptr;
1039 // Construct the sparse elements attr using zero element indice/value
1040 // attributes.
1041 ShapedType indicesType =
1042 RankedTensorType::get({0, type.getRank()}, indiceEltType);
1043 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
1044 return getChecked<SparseElementsAttr>(
1045 loc, type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
1046 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
1049 /// Parse the indices. We don't allow hex values here as we may need to use
1050 /// the inferred shape.
1051 auto indicesLoc = getToken().getLoc();
1052 TensorLiteralParser indiceParser(*this);
1053 if (indiceParser.parse(/*allowHex=*/false))
1054 return nullptr;
1056 if (parseToken(Token::comma, "expected ','"))
1057 return nullptr;
1059 /// Parse the values.
1060 auto valuesLoc = getToken().getLoc();
1061 TensorLiteralParser valuesParser(*this);
1062 if (valuesParser.parse(/*allowHex=*/true))
1063 return nullptr;
1065 if (parseToken(Token::greater, "expected '>'"))
1066 return nullptr;
1068 auto type = parseElementsLiteralType(attrType);
1069 if (!type)
1070 return nullptr;
1072 // If the indices are a splat, i.e. the literal parser parsed an element and
1073 // not a list, we set the shape explicitly. The indices are represented by a
1074 // 2-dimensional shape where the second dimension is the rank of the type.
1075 // Given that the parsed indices is a splat, we know that we only have one
1076 // indice and thus one for the first dimension.
1077 ShapedType indicesType;
1078 if (indiceParser.getShape().empty()) {
1079 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
1080 } else {
1081 // Otherwise, set the shape to the one parsed by the literal parser.
1082 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
1084 auto indices = indiceParser.getAttr(indicesLoc, indicesType);
1086 // If the values are a splat, set the shape explicitly based on the number of
1087 // indices. The number of indices is encoded in the first dimension of the
1088 // indice shape type.
1089 auto valuesEltType = type.getElementType();
1090 ShapedType valuesType =
1091 valuesParser.getShape().empty()
1092 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
1093 : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
1094 auto values = valuesParser.getAttr(valuesLoc, valuesType);
1096 // Build the sparse elements attribute by the indices and values.
1097 return getChecked<SparseElementsAttr>(loc, type, indices, values);
1100 Attribute Parser::parseStridedLayoutAttr() {
1101 // Callback for error emissing at the keyword token location.
1102 llvm::SMLoc loc = getToken().getLoc();
1103 auto errorEmitter = [&] { return emitError(loc); };
1105 consumeToken(Token::kw_strided);
1106 if (failed(parseToken(Token::less, "expected '<' after 'strided'")) ||
1107 failed(parseToken(Token::l_square, "expected '['")))
1108 return nullptr;
1110 // Parses either an integer token or a question mark token. Reports an error
1111 // and returns std::nullopt if the current token is neither. The integer token
1112 // must fit into int64_t limits.
1113 auto parseStrideOrOffset = [&]() -> std::optional<int64_t> {
1114 if (consumeIf(Token::question))
1115 return ShapedType::kDynamic;
1117 SMLoc loc = getToken().getLoc();
1118 auto emitWrongTokenError = [&] {
1119 emitError(loc, "expected a 64-bit signed integer or '?'");
1120 return std::nullopt;
1123 bool negative = consumeIf(Token::minus);
1125 if (getToken().is(Token::integer)) {
1126 std::optional<uint64_t> value = getToken().getUInt64IntegerValue();
1127 if (!value ||
1128 *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
1129 return emitWrongTokenError();
1130 consumeToken();
1131 auto result = static_cast<int64_t>(*value);
1132 if (negative)
1133 result = -result;
1135 return result;
1138 return emitWrongTokenError();
1141 // Parse strides.
1142 SmallVector<int64_t> strides;
1143 if (!getToken().is(Token::r_square)) {
1144 do {
1145 std::optional<int64_t> stride = parseStrideOrOffset();
1146 if (!stride)
1147 return nullptr;
1148 strides.push_back(*stride);
1149 } while (consumeIf(Token::comma));
1152 if (failed(parseToken(Token::r_square, "expected ']'")))
1153 return nullptr;
1155 // Fast path in absence of offset.
1156 if (consumeIf(Token::greater)) {
1157 if (failed(StridedLayoutAttr::verify(errorEmitter,
1158 /*offset=*/0, strides)))
1159 return nullptr;
1160 return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides);
1163 if (failed(parseToken(Token::comma, "expected ','")) ||
1164 failed(parseToken(Token::kw_offset, "expected 'offset' after comma")) ||
1165 failed(parseToken(Token::colon, "expected ':' after 'offset'")))
1166 return nullptr;
1168 std::optional<int64_t> offset = parseStrideOrOffset();
1169 if (!offset || failed(parseToken(Token::greater, "expected '>'")))
1170 return nullptr;
1172 if (failed(StridedLayoutAttr::verify(errorEmitter, *offset, strides)))
1173 return nullptr;
1174 return StridedLayoutAttr::get(getContext(), *offset, strides);
1175 // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides);
1178 /// Parse a distinct attribute.
1180 /// distinct-attribute ::= `distinct`
1181 /// `[` integer-literal `]<` attribute-value `>`
1183 Attribute Parser::parseDistinctAttr(Type type) {
1184 SMLoc loc = getToken().getLoc();
1185 consumeToken(Token::kw_distinct);
1186 if (parseToken(Token::l_square, "expected '[' after 'distinct'"))
1187 return {};
1189 // Parse the distinct integer identifier.
1190 Token token = getToken();
1191 if (parseToken(Token::integer, "expected distinct ID"))
1192 return {};
1193 std::optional<uint64_t> value = token.getUInt64IntegerValue();
1194 if (!value) {
1195 emitError("expected an unsigned 64-bit integer");
1196 return {};
1199 // Parse the referenced attribute.
1200 if (parseToken(Token::r_square, "expected ']' to close distinct ID") ||
1201 parseToken(Token::less, "expected '<' after distinct ID"))
1202 return {};
1204 Attribute referencedAttr;
1205 if (getToken().is(Token::greater)) {
1206 consumeToken();
1207 referencedAttr = builder.getUnitAttr();
1208 } else {
1209 referencedAttr = parseAttribute(type);
1210 if (!referencedAttr) {
1211 emitError("expected attribute");
1212 return {};
1215 if (parseToken(Token::greater, "expected '>' to close distinct attribute"))
1216 return {};
1219 // Add the distinct attribute to the parser state, if it has not been parsed
1220 // before. Otherwise, check if the parsed reference attribute matches the one
1221 // found in the parser state.
1222 DenseMap<uint64_t, DistinctAttr> &distinctAttrs =
1223 state.symbols.distinctAttributes;
1224 auto it = distinctAttrs.find(*value);
1225 if (it == distinctAttrs.end()) {
1226 DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr);
1227 it = distinctAttrs.try_emplace(*value, distinctAttr).first;
1228 } else if (it->getSecond().getReferencedAttr() != referencedAttr) {
1229 emitError(loc, "referenced attribute does not match previous definition: ")
1230 << it->getSecond().getReferencedAttr();
1231 return {};
1234 return it->getSecond();