1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the parser for the MLIR Types.
11 //===----------------------------------------------------------------------===//
15 #include "AsmParserImpl.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AffineMap.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinDialect.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/DialectResourceBlobManager.h"
23 #include "mlir/IR/IntegerSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Support/Endian.h"
29 using namespace mlir::detail
;
31 /// Parse an arbitrary attribute.
33 /// attribute-value ::= `unit`
35 /// | integer-literal (`:` (index-type | integer-type))?
36 /// | float-literal (`:` float-type)?
37 /// | string-literal (`:` type)?
39 /// | `[` `:` (integer-type | float-type) tensor-literal `]`
40 /// | `[` (attribute-value (`,` attribute-value)*)? `]`
41 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
42 /// | symbol-ref-id (`::` symbol-ref-id)*
43 /// | `dense` `<` tensor-literal `>` `:`
44 /// (tensor-type | vector-type)
45 /// | `sparse` `<` attribute-value `,` attribute-value `>`
46 /// `:` (tensor-type | vector-type)
47 /// | `strided` `<` `[` comma-separated-int-or-question `]`
48 /// (`,` `offset` `:` integer-literal)? `>`
49 /// | distinct-attribute
50 /// | extended-attribute
52 Attribute
Parser::parseAttribute(Type type
) {
53 switch (getToken().getKind()) {
54 // Parse an AffineMap or IntegerSet attribute.
55 case Token::kw_affine_map
: {
56 consumeToken(Token::kw_affine_map
);
59 if (parseToken(Token::less
, "expected '<' in affine map") ||
60 parseAffineMapReference(map
) ||
61 parseToken(Token::greater
, "expected '>' in affine map"))
63 return AffineMapAttr::get(map
);
65 case Token::kw_affine_set
: {
66 consumeToken(Token::kw_affine_set
);
69 if (parseToken(Token::less
, "expected '<' in integer set") ||
70 parseIntegerSetReference(set
) ||
71 parseToken(Token::greater
, "expected '>' in integer set"))
73 return IntegerSetAttr::get(set
);
76 // Parse an array attribute.
77 case Token::l_square
: {
78 consumeToken(Token::l_square
);
79 SmallVector
<Attribute
, 4> elements
;
80 auto parseElt
= [&]() -> ParseResult
{
81 elements
.push_back(parseAttribute());
82 return elements
.back() ? success() : failure();
85 if (parseCommaSeparatedListUntil(Token::r_square
, parseElt
))
87 return builder
.getArrayAttr(elements
);
90 // Parse a boolean attribute.
92 consumeToken(Token::kw_false
);
93 return builder
.getBoolAttr(false);
95 consumeToken(Token::kw_true
);
96 return builder
.getBoolAttr(true);
98 // Parse a dense elements attribute.
100 return parseDenseElementsAttr(type
);
102 // Parse a dense resource elements attribute.
103 case Token::kw_dense_resource
:
104 return parseDenseResourceElementsAttr(type
);
106 // Parse a dense array attribute.
107 case Token::kw_array
:
108 return parseDenseArrayAttr(type
);
110 // Parse a dictionary attribute.
111 case Token::l_brace
: {
112 NamedAttrList elements
;
113 if (parseAttributeDict(elements
))
115 return elements
.getDictionary(getContext());
118 // Parse an extended attribute, i.e. alias or dialect attribute.
119 case Token::hash_identifier
:
120 return parseExtendedAttr(type
);
122 // Parse floating point and integer attributes.
123 case Token::floatliteral
:
124 return parseFloatAttr(type
, /*isNegative=*/false);
126 return parseDecOrHexAttr(type
, /*isNegative=*/false);
128 consumeToken(Token::minus
);
129 if (getToken().is(Token::integer
))
130 return parseDecOrHexAttr(type
, /*isNegative=*/true);
131 if (getToken().is(Token::floatliteral
))
132 return parseFloatAttr(type
, /*isNegative=*/true);
134 return (emitWrongTokenError(
135 "expected constant integer or floating point value"),
139 // Parse a location attribute.
140 case Token::kw_loc
: {
141 consumeToken(Token::kw_loc
);
143 LocationAttr locAttr
;
144 if (parseToken(Token::l_paren
, "expected '(' in inline location") ||
145 parseLocationInstance(locAttr
) ||
146 parseToken(Token::r_paren
, "expected ')' in inline location"))
151 // Parse a sparse elements attribute.
152 case Token::kw_sparse
:
153 return parseSparseElementsAttr(type
);
155 // Parse a strided layout attribute.
156 case Token::kw_strided
:
157 return parseStridedLayoutAttr();
159 // Parse a distinct attribute.
160 case Token::kw_distinct
:
161 return parseDistinctAttr(type
);
163 // Parse a string attribute.
164 case Token::string
: {
165 auto val
= getToken().getStringValue();
166 consumeToken(Token::string
);
167 // Parse the optional trailing colon type if one wasn't explicitly provided.
168 if (!type
&& consumeIf(Token::colon
) && !(type
= parseType()))
171 return type
? StringAttr::get(val
, type
)
172 : StringAttr::get(getContext(), val
);
175 // Parse a symbol reference attribute.
176 case Token::at_identifier
: {
177 // When populating the parser state, this is a list of locations for all of
178 // the nested references.
179 SmallVector
<SMRange
> referenceLocations
;
181 referenceLocations
.push_back(getToken().getLocRange());
183 // Parse the top-level reference.
184 std::string nameStr
= getToken().getSymbolReference();
185 consumeToken(Token::at_identifier
);
187 // Parse any nested references.
188 std::vector
<FlatSymbolRefAttr
> nestedRefs
;
189 while (getToken().is(Token::colon
)) {
190 // Check for the '::' prefix.
191 const char *curPointer
= getToken().getLoc().getPointer();
192 consumeToken(Token::colon
);
193 if (!consumeIf(Token::colon
)) {
194 if (getToken().isNot(Token::eof
, Token::error
)) {
195 state
.lex
.resetPointer(curPointer
);
200 // Parse the reference itself.
201 auto curLoc
= getToken().getLoc();
202 if (getToken().isNot(Token::at_identifier
)) {
203 emitError(curLoc
, "expected nested symbol reference identifier");
207 // If we are populating the assembly state, add the location for this
210 referenceLocations
.push_back(getToken().getLocRange());
212 std::string nameStr
= getToken().getSymbolReference();
213 consumeToken(Token::at_identifier
);
214 nestedRefs
.push_back(SymbolRefAttr::get(getContext(), nameStr
));
216 SymbolRefAttr symbolRefAttr
=
217 SymbolRefAttr::get(getContext(), nameStr
, nestedRefs
);
219 // If we are populating the assembly state, record this symbol reference.
221 state
.asmState
->addUses(symbolRefAttr
, referenceLocations
);
222 return symbolRefAttr
;
225 // Parse a 'unit' attribute.
227 consumeToken(Token::kw_unit
);
228 return builder
.getUnitAttr();
230 // Handle completion of an attribute.
231 case Token::code_complete
:
232 if (getToken().isCodeCompletionFor(Token::hash_identifier
))
233 return parseExtendedAttr(type
);
234 return codeCompleteAttribute();
237 // Parse a type attribute. We parse `Optional` here to allow for providing a
238 // better error message.
240 OptionalParseResult result
= parseOptionalType(type
);
241 if (!result
.has_value())
242 return emitWrongTokenError("expected attribute value"), Attribute();
243 return failed(*result
) ? Attribute() : TypeAttr::get(type
);
247 /// Parse an optional attribute with the provided type.
248 OptionalParseResult
Parser::parseOptionalAttribute(Attribute
&attribute
,
250 switch (getToken().getKind()) {
251 case Token::at_identifier
:
252 case Token::floatliteral
:
254 case Token::hash_identifier
:
255 case Token::kw_affine_map
:
256 case Token::kw_affine_set
:
257 case Token::kw_dense
:
258 case Token::kw_dense_resource
:
259 case Token::kw_false
:
261 case Token::kw_sparse
:
265 case Token::l_square
:
268 attribute
= parseAttribute(type
);
269 return success(attribute
!= nullptr);
272 // Parse an optional type attribute.
274 OptionalParseResult result
= parseOptionalType(type
);
275 if (result
.has_value() && succeeded(*result
))
276 attribute
= TypeAttr::get(type
);
280 OptionalParseResult
Parser::parseOptionalAttribute(ArrayAttr
&attribute
,
282 return parseOptionalAttributeWithToken(Token::l_square
, attribute
, type
);
284 OptionalParseResult
Parser::parseOptionalAttribute(StringAttr
&attribute
,
286 return parseOptionalAttributeWithToken(Token::string
, attribute
, type
);
288 OptionalParseResult
Parser::parseOptionalAttribute(SymbolRefAttr
&result
,
290 return parseOptionalAttributeWithToken(Token::at_identifier
, result
, type
);
293 /// Attribute dictionary.
295 /// attribute-dict ::= `{` `}`
296 /// | `{` attribute-entry (`,` attribute-entry)* `}`
297 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
299 ParseResult
Parser::parseAttributeDict(NamedAttrList
&attributes
) {
300 llvm::SmallDenseSet
<StringAttr
> seenKeys
;
301 auto parseElt
= [&]() -> ParseResult
{
302 // The name of an attribute can either be a bare identifier, or a string.
303 std::optional
<StringAttr
> nameId
;
304 if (getToken().is(Token::string
))
305 nameId
= builder
.getStringAttr(getToken().getStringValue());
306 else if (getToken().isAny(Token::bare_identifier
, Token::inttype
) ||
307 getToken().isKeyword())
308 nameId
= builder
.getStringAttr(getTokenSpelling());
310 return emitWrongTokenError("expected attribute name");
313 return emitError("expected valid attribute name");
315 if (!seenKeys
.insert(*nameId
).second
)
316 return emitError("duplicate key '")
317 << nameId
->getValue() << "' in dictionary attribute";
320 // Lazy load a dialect in the context if there is a possible namespace.
321 auto splitName
= nameId
->strref().split('.');
322 if (!splitName
.second
.empty())
323 getContext()->getOrLoadDialect(splitName
.first
);
325 // Try to parse the '=' for the attribute value.
326 if (!consumeIf(Token::equal
)) {
327 // If there is no '=', we treat this as a unit attribute.
328 attributes
.push_back({*nameId
, builder
.getUnitAttr()});
332 auto attr
= parseAttribute();
335 attributes
.push_back({*nameId
, attr
});
339 return parseCommaSeparatedList(Delimiter::Braces
, parseElt
,
340 " in attribute dictionary");
343 /// Parse a float attribute.
344 Attribute
Parser::parseFloatAttr(Type type
, bool isNegative
) {
345 auto val
= getToken().getFloatingPointValue();
347 return (emitError("floating point value too large for attribute"), nullptr);
348 consumeToken(Token::floatliteral
);
350 // Default to F64 when no type is specified.
351 if (!consumeIf(Token::colon
))
352 type
= builder
.getF64Type();
353 else if (!(type
= parseType()))
356 if (!isa
<FloatType
>(type
))
357 return (emitError("floating point value not valid for specified type"),
359 return FloatAttr::get(type
, isNegative
? -*val
: *val
);
362 /// Construct an APint from a parsed value, a known attribute type and
364 static std::optional
<APInt
> buildAttributeAPInt(Type type
, bool isNegative
,
365 StringRef spelling
) {
366 // Parse the integer value into an APInt that is big enough to hold the value.
368 bool isHex
= spelling
.size() > 1 && spelling
[1] == 'x';
369 if (spelling
.getAsInteger(isHex
? 0 : 10, result
))
372 // Extend or truncate the bitwidth to the right size.
373 unsigned width
= type
.isIndex() ? IndexType::kInternalStorageBitWidth
374 : type
.getIntOrFloatBitWidth();
376 if (width
> result
.getBitWidth()) {
377 result
= result
.zext(width
);
378 } else if (width
< result
.getBitWidth()) {
379 // The parser can return an unnecessarily wide result with leading zeros.
380 // This isn't a problem, but truncating off bits is bad.
381 if (result
.countl_zero() < result
.getBitWidth() - width
)
384 result
= result
.trunc(width
);
388 // 0 bit integers cannot be negative and manipulation of their sign bit will
389 // assert, so short-cut validation here.
392 } else if (isNegative
) {
393 // The value is negative, we have an overflow if the sign bit is not set
394 // in the negated apInt.
396 if (!result
.isSignBitSet())
398 } else if ((type
.isSignedInteger() || type
.isIndex()) &&
399 result
.isSignBitSet()) {
400 // The value is a positive signed integer or index,
401 // we have an overflow if the sign bit is set.
408 /// Parse a decimal or a hexadecimal literal, which can be either an integer
409 /// or a float attribute.
410 Attribute
Parser::parseDecOrHexAttr(Type type
, bool isNegative
) {
411 Token tok
= getToken();
412 StringRef spelling
= tok
.getSpelling();
413 SMLoc loc
= tok
.getLoc();
415 consumeToken(Token::integer
);
417 // Default to i64 if not type is specified.
418 if (!consumeIf(Token::colon
))
419 type
= builder
.getIntegerType(64);
420 else if (!(type
= parseType()))
424 if (auto floatType
= dyn_cast
<FloatType
>(type
)) {
425 std::optional
<APFloat
> result
;
426 if (failed(parseFloatFromIntegerLiteral(result
, tok
, isNegative
,
427 floatType
.getFloatSemantics())))
429 return FloatAttr::get(floatType
, *result
);
432 if (!isa
<IntegerType
, IndexType
>(type
))
433 return emitError(loc
, "integer literal not valid for specified type"),
436 if (isNegative
&& type
.isUnsignedInteger()) {
438 "negative integer literal not valid for unsigned integer type");
442 std::optional
<APInt
> apInt
= buildAttributeAPInt(type
, isNegative
, spelling
);
444 return emitError(loc
, "integer constant out of range for attribute"),
446 return builder
.getIntegerAttr(type
, *apInt
);
449 //===----------------------------------------------------------------------===//
450 // TensorLiteralParser
451 //===----------------------------------------------------------------------===//
453 /// Parse elements values stored within a hex string. On success, the values are
454 /// stored into 'result'.
455 static ParseResult
parseElementAttrHexValues(Parser
&parser
, Token tok
,
456 std::string
&result
) {
457 if (std::optional
<std::string
> value
= tok
.getHexStringValue()) {
458 result
= std::move(*value
);
461 return parser
.emitError(
462 tok
.getLoc(), "expected string containing hex digits starting with `0x`");
466 /// This class implements a parser for TensorLiterals. A tensor literal is
467 /// either a single element (e.g, 5) or a multi-dimensional list of elements
468 /// (e.g., [[5, 5]]).
469 class TensorLiteralParser
{
471 TensorLiteralParser(Parser
&p
) : p(p
) {}
473 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
474 /// may also parse a tensor literal that is store as a hex string.
475 ParseResult
parse(bool allowHex
);
477 /// Build a dense attribute instance with the parsed elements and the given
479 DenseElementsAttr
getAttr(SMLoc loc
, ShapedType type
);
481 ArrayRef
<int64_t> getShape() const { return shape
; }
484 /// Get the parsed elements for an integer attribute.
485 ParseResult
getIntAttrElements(SMLoc loc
, Type eltTy
,
486 std::vector
<APInt
> &intValues
);
488 /// Get the parsed elements for a float attribute.
489 ParseResult
getFloatAttrElements(SMLoc loc
, FloatType eltTy
,
490 std::vector
<APFloat
> &floatValues
);
492 /// Build a Dense String attribute for the given type.
493 DenseElementsAttr
getStringAttr(SMLoc loc
, ShapedType type
, Type eltTy
);
495 /// Build a Dense attribute with hex data for the given type.
496 DenseElementsAttr
getHexAttr(SMLoc loc
, ShapedType type
);
498 /// Parse a single element, returning failure if it isn't a valid element
499 /// literal. For example:
500 /// parseElement(1) -> Success, 1
501 /// parseElement([1]) -> Failure
502 ParseResult
parseElement();
504 /// Parse a list of either lists or elements, returning the dimensions of the
505 /// parsed sub-tensors in dims. For example:
506 /// parseList([1, 2, 3]) -> Success, [3]
507 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
508 /// parseList([[1, 2], 3]) -> Failure
509 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
510 ParseResult
parseList(SmallVectorImpl
<int64_t> &dims
);
512 /// Parse a literal that was printed as a hex string.
513 ParseResult
parseHexElements();
517 /// The shape inferred from the parsed elements.
518 SmallVector
<int64_t, 4> shape
;
520 /// Storage used when parsing elements, this is a pair of <is_negated, token>.
521 std::vector
<std::pair
<bool, Token
>> storage
;
523 /// Storage used when parsing elements that were stored as hex values.
524 std::optional
<Token
> hexStorage
;
528 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
529 /// may also parse a tensor literal that is store as a hex string.
530 ParseResult
TensorLiteralParser::parse(bool allowHex
) {
531 // If hex is allowed, check for a string literal.
532 if (allowHex
&& p
.getToken().is(Token::string
)) {
533 hexStorage
= p
.getToken();
534 p
.consumeToken(Token::string
);
537 // Otherwise, parse a list or an individual element.
538 if (p
.getToken().is(Token::l_square
))
539 return parseList(shape
);
540 return parseElement();
543 /// Build a dense attribute instance with the parsed elements and the given
545 DenseElementsAttr
TensorLiteralParser::getAttr(SMLoc loc
, ShapedType type
) {
546 Type eltType
= type
.getElementType();
548 // Check to see if we parse the literal from a hex string.
550 (eltType
.isIntOrIndexOrFloat() || isa
<ComplexType
>(eltType
)))
551 return getHexAttr(loc
, type
);
553 // Check that the parsed storage size has the same number of elements to the
554 // type, or is a known splat.
555 if (!shape
.empty() && getShape() != type
.getShape()) {
556 p
.emitError(loc
) << "inferred shape of elements literal ([" << getShape()
557 << "]) does not match type ([" << type
.getShape() << "])";
561 // Handle the case where no elements were parsed.
562 if (!hexStorage
&& storage
.empty() && type
.getNumElements()) {
563 p
.emitError(loc
) << "parsed zero elements, but type (" << type
564 << ") expected at least 1";
568 // Handle complex types in the specific element type cases below.
569 bool isComplex
= false;
570 if (ComplexType complexTy
= dyn_cast
<ComplexType
>(eltType
)) {
571 eltType
= complexTy
.getElementType();
575 // Handle integer and index types.
576 if (eltType
.isIntOrIndex()) {
577 std::vector
<APInt
> intValues
;
578 if (failed(getIntAttrElements(loc
, eltType
, intValues
)))
581 // If this is a complex, treat the parsed values as complex values.
582 auto complexData
= llvm::ArrayRef(
583 reinterpret_cast<std::complex<APInt
> *>(intValues
.data()),
584 intValues
.size() / 2);
585 return DenseElementsAttr::get(type
, complexData
);
587 return DenseElementsAttr::get(type
, intValues
);
589 // Handle floating point types.
590 if (FloatType floatTy
= dyn_cast
<FloatType
>(eltType
)) {
591 std::vector
<APFloat
> floatValues
;
592 if (failed(getFloatAttrElements(loc
, floatTy
, floatValues
)))
595 // If this is a complex, treat the parsed values as complex values.
596 auto complexData
= llvm::ArrayRef(
597 reinterpret_cast<std::complex<APFloat
> *>(floatValues
.data()),
598 floatValues
.size() / 2);
599 return DenseElementsAttr::get(type
, complexData
);
601 return DenseElementsAttr::get(type
, floatValues
);
604 // Other types are assumed to be string representations.
605 return getStringAttr(loc
, type
, type
.getElementType());
608 /// Build a Dense Integer attribute for the given type.
610 TensorLiteralParser::getIntAttrElements(SMLoc loc
, Type eltTy
,
611 std::vector
<APInt
> &intValues
) {
612 intValues
.reserve(storage
.size());
613 bool isUintType
= eltTy
.isUnsignedInteger();
614 for (const auto &signAndToken
: storage
) {
615 bool isNegative
= signAndToken
.first
;
616 const Token
&token
= signAndToken
.second
;
617 auto tokenLoc
= token
.getLoc();
619 if (isNegative
&& isUintType
) {
620 return p
.emitError(tokenLoc
)
621 << "expected unsigned integer elements, but parsed negative value";
624 // Check to see if floating point values were parsed.
625 if (token
.is(Token::floatliteral
)) {
626 return p
.emitError(tokenLoc
)
627 << "expected integer elements, but parsed floating-point";
630 assert(token
.isAny(Token::integer
, Token::kw_true
, Token::kw_false
) &&
631 "unexpected token type");
632 if (token
.isAny(Token::kw_true
, Token::kw_false
)) {
633 if (!eltTy
.isInteger(1)) {
634 return p
.emitError(tokenLoc
)
635 << "expected i1 type for 'true' or 'false' values";
637 APInt
apInt(1, token
.is(Token::kw_true
), /*isSigned=*/false);
638 intValues
.push_back(apInt
);
642 // Create APInt values for each element with the correct bitwidth.
643 std::optional
<APInt
> apInt
=
644 buildAttributeAPInt(eltTy
, isNegative
, token
.getSpelling());
646 return p
.emitError(tokenLoc
, "integer constant out of range for type");
647 intValues
.push_back(*apInt
);
652 /// Build a Dense Float attribute for the given type.
654 TensorLiteralParser::getFloatAttrElements(SMLoc loc
, FloatType eltTy
,
655 std::vector
<APFloat
> &floatValues
) {
656 floatValues
.reserve(storage
.size());
657 for (const auto &signAndToken
: storage
) {
658 bool isNegative
= signAndToken
.first
;
659 const Token
&token
= signAndToken
.second
;
660 std::optional
<APFloat
> result
;
661 if (failed(p
.parseFloatFromLiteral(result
, token
, isNegative
,
662 eltTy
.getFloatSemantics())))
664 floatValues
.push_back(*result
);
669 /// Build a Dense String attribute for the given type.
670 DenseElementsAttr
TensorLiteralParser::getStringAttr(SMLoc loc
, ShapedType type
,
672 if (hexStorage
.has_value()) {
673 auto stringValue
= hexStorage
->getStringValue();
674 return DenseStringElementsAttr::get(type
, {stringValue
});
677 std::vector
<std::string
> stringValues
;
678 std::vector
<StringRef
> stringRefValues
;
679 stringValues
.reserve(storage
.size());
680 stringRefValues
.reserve(storage
.size());
682 for (auto val
: storage
) {
683 stringValues
.push_back(val
.second
.getStringValue());
684 stringRefValues
.emplace_back(stringValues
.back());
687 return DenseStringElementsAttr::get(type
, stringRefValues
);
690 /// Build a Dense attribute with hex data for the given type.
691 DenseElementsAttr
TensorLiteralParser::getHexAttr(SMLoc loc
, ShapedType type
) {
692 Type elementType
= type
.getElementType();
693 if (!elementType
.isIntOrIndexOrFloat() && !isa
<ComplexType
>(elementType
)) {
695 << "expected floating-point, integer, or complex element type, got "
701 if (parseElementAttrHexValues(p
, *hexStorage
, data
))
704 ArrayRef
<char> rawData(data
.data(), data
.size());
705 bool detectedSplat
= false;
706 if (!DenseElementsAttr::isValidRawBuffer(type
, rawData
, detectedSplat
)) {
707 p
.emitError(loc
) << "elements hex data size is invalid for provided type: "
712 if (llvm::endianness::native
== llvm::endianness::big
) {
713 // Convert endianess in big-endian(BE) machines. `rawData` is
714 // little-endian(LE) because HEX in raw data of dense element attribute
715 // is always LE format. It is converted into BE here to be used in BE
717 SmallVector
<char, 64> outDataVec(rawData
.size());
718 MutableArrayRef
<char> convRawData(outDataVec
);
719 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
720 rawData
, convRawData
, type
);
721 return DenseElementsAttr::getFromRawBuffer(type
, convRawData
);
724 return DenseElementsAttr::getFromRawBuffer(type
, rawData
);
727 ParseResult
TensorLiteralParser::parseElement() {
728 switch (p
.getToken().getKind()) {
729 // Parse a boolean element.
731 case Token::kw_false
:
732 case Token::floatliteral
:
734 storage
.emplace_back(/*isNegative=*/false, p
.getToken());
738 // Parse a signed integer or a negative floating-point element.
740 p
.consumeToken(Token::minus
);
741 if (!p
.getToken().isAny(Token::floatliteral
, Token::integer
))
742 return p
.emitError("expected integer or floating point literal");
743 storage
.emplace_back(/*isNegative=*/true, p
.getToken());
748 storage
.emplace_back(/*isNegative=*/false, p
.getToken());
752 // Parse a complex element of the form '(' element ',' element ')'.
754 p
.consumeToken(Token::l_paren
);
755 if (parseElement() ||
756 p
.parseToken(Token::comma
, "expected ',' between complex elements") ||
758 p
.parseToken(Token::r_paren
, "expected ')' after complex elements"))
763 return p
.emitError("expected element literal of primitive type");
769 /// Parse a list of either lists or elements, returning the dimensions of the
770 /// parsed sub-tensors in dims. For example:
771 /// parseList([1, 2, 3]) -> Success, [3]
772 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
773 /// parseList([[1, 2], 3]) -> Failure
774 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
775 ParseResult
TensorLiteralParser::parseList(SmallVectorImpl
<int64_t> &dims
) {
776 auto checkDims
= [&](const SmallVectorImpl
<int64_t> &prevDims
,
777 const SmallVectorImpl
<int64_t> &newDims
) -> ParseResult
{
778 if (prevDims
== newDims
)
780 return p
.emitError("tensor literal is invalid; ranks are not consistent "
785 SmallVector
<int64_t, 4> newDims
;
787 auto parseOneElement
= [&]() -> ParseResult
{
788 SmallVector
<int64_t, 4> thisDims
;
789 if (p
.getToken().getKind() == Token::l_square
) {
790 if (parseList(thisDims
))
792 } else if (parseElement()) {
797 return checkDims(newDims
, thisDims
);
802 if (p
.parseCommaSeparatedList(Parser::Delimiter::Square
, parseOneElement
))
805 // Return the sublists' dimensions with 'size' prepended.
807 dims
.push_back(size
);
808 dims
.append(newDims
.begin(), newDims
.end());
812 //===----------------------------------------------------------------------===//
813 // DenseArrayAttr Parser
814 //===----------------------------------------------------------------------===//
817 /// A generic dense array element parser. It parsers integer and floating point
819 class DenseArrayElementParser
{
821 explicit DenseArrayElementParser(Type type
) : type(type
) {}
823 /// Parse an integer element.
824 ParseResult
parseIntegerElement(Parser
&p
);
826 /// Parse a floating point element.
827 ParseResult
parseFloatElement(Parser
&p
);
829 /// Convert the current contents to a dense array.
830 DenseArrayAttr
getAttr() { return DenseArrayAttr::get(type
, size
, rawData
); }
833 /// Append the raw data of an APInt to the result.
834 void append(const APInt
&data
);
836 /// The array element type.
838 /// The resultant byte array representing the contents of the array.
839 std::vector
<char> rawData
;
840 /// The number of elements in the array.
845 void DenseArrayElementParser::append(const APInt
&data
) {
846 if (data
.getBitWidth()) {
847 assert(data
.getBitWidth() % 8 == 0);
848 unsigned byteSize
= data
.getBitWidth() / 8;
849 size_t offset
= rawData
.size();
850 rawData
.insert(rawData
.end(), byteSize
, 0);
851 llvm::StoreIntToMemory(
852 data
, reinterpret_cast<uint8_t *>(rawData
.data() + offset
), byteSize
);
857 ParseResult
DenseArrayElementParser::parseIntegerElement(Parser
&p
) {
858 bool isNegative
= p
.consumeIf(Token::minus
);
860 // Parse an integer literal as an APInt.
861 std::optional
<APInt
> value
;
862 StringRef spelling
= p
.getToken().getSpelling();
863 if (p
.getToken().isAny(Token::kw_true
, Token::kw_false
)) {
864 if (!type
.isInteger(1))
865 return p
.emitError("expected i1 type for 'true' or 'false' values");
866 value
= APInt(/*numBits=*/8, p
.getToken().is(Token::kw_true
),
867 !type
.isUnsignedInteger());
869 } else if (p
.consumeIf(Token::integer
)) {
870 value
= buildAttributeAPInt(type
, isNegative
, spelling
);
872 return p
.emitError("integer constant out of range");
874 return p
.emitError("expected integer literal");
880 ParseResult
DenseArrayElementParser::parseFloatElement(Parser
&p
) {
881 bool isNegative
= p
.consumeIf(Token::minus
);
882 Token token
= p
.getToken();
883 std::optional
<APFloat
> fromIntLit
;
885 p
.parseFloatFromLiteral(fromIntLit
, token
, isNegative
,
886 cast
<FloatType
>(type
).getFloatSemantics())))
889 append(fromIntLit
->bitcastToAPInt());
893 /// Parse a dense array attribute.
894 Attribute
Parser::parseDenseArrayAttr(Type attrType
) {
895 consumeToken(Token::kw_array
);
896 if (parseToken(Token::less
, "expected '<' after 'array'"))
899 SMLoc typeLoc
= getToken().getLoc();
900 Type eltType
= parseType();
902 emitError(typeLoc
, "expected an integer or floating point type");
906 // Only bool or integer and floating point elements divisible by bytes are
908 if (!eltType
.isIntOrIndexOrFloat()) {
909 emitError(typeLoc
, "expected integer or float type, got: ") << eltType
;
912 if (!eltType
.isInteger(1) && eltType
.getIntOrFloatBitWidth() % 8 != 0) {
913 emitError(typeLoc
, "element type bitwidth must be a multiple of 8");
917 // Check for empty list.
918 if (consumeIf(Token::greater
))
919 return DenseArrayAttr::get(eltType
, 0, {});
921 if (parseToken(Token::colon
, "expected ':' after dense array type"))
924 DenseArrayElementParser
eltParser(eltType
);
925 if (eltType
.isIntOrIndex()) {
926 if (parseCommaSeparatedList(
927 [&] { return eltParser
.parseIntegerElement(*this); }))
930 if (parseCommaSeparatedList(
931 [&] { return eltParser
.parseFloatElement(*this); }))
934 if (parseToken(Token::greater
, "expected '>' to close an array attribute"))
936 return eltParser
.getAttr();
939 /// Parse a dense elements attribute.
940 Attribute
Parser::parseDenseElementsAttr(Type attrType
) {
941 auto attribLoc
= getToken().getLoc();
942 consumeToken(Token::kw_dense
);
943 if (parseToken(Token::less
, "expected '<' after 'dense'"))
946 // Parse the literal data if necessary.
947 TensorLiteralParser
literalParser(*this);
948 if (!consumeIf(Token::greater
)) {
949 if (literalParser
.parse(/*allowHex=*/true) ||
950 parseToken(Token::greater
, "expected '>'"))
954 // If the type is specified `parseElementsLiteralType` will not parse a type.
955 // Use the attribute location as the location for error reporting in that
957 auto loc
= attrType
? attribLoc
: getToken().getLoc();
958 auto type
= parseElementsLiteralType(attrType
);
961 return literalParser
.getAttr(loc
, type
);
964 Attribute
Parser::parseDenseResourceElementsAttr(Type attrType
) {
965 auto loc
= getToken().getLoc();
966 consumeToken(Token::kw_dense_resource
);
967 if (parseToken(Token::less
, "expected '<' after 'dense_resource'"))
970 // Parse the resource handle.
971 FailureOr
<AsmDialectResourceHandle
> rawHandle
=
972 parseResourceHandle(getContext()->getLoadedDialect
<BuiltinDialect
>());
973 if (failed(rawHandle
) || parseToken(Token::greater
, "expected '>'"))
976 auto *handle
= dyn_cast
<DenseResourceElementsHandle
>(&*rawHandle
);
978 return emitError(loc
, "invalid `dense_resource` handle type"), nullptr;
980 // Parse the type of the attribute if the user didn't provide one.
983 typeLoc
= getToken().getLoc();
984 if (parseToken(Token::colon
, "expected ':'") || !(attrType
= parseType()))
988 ShapedType shapedType
= dyn_cast
<ShapedType
>(attrType
);
990 emitError(typeLoc
, "`dense_resource` expected a shaped type");
994 return DenseResourceElementsAttr::get(shapedType
, *handle
);
997 /// Shaped type for elements attribute.
999 /// elements-literal-type ::= vector-type | ranked-tensor-type
1001 /// This method also checks the type has static shape.
1002 ShapedType
Parser::parseElementsLiteralType(Type type
) {
1003 // If the user didn't provide a type, parse the colon type for the literal.
1005 if (parseToken(Token::colon
, "expected ':'"))
1007 if (!(type
= parseType()))
1011 auto sType
= dyn_cast
<ShapedType
>(type
);
1013 emitError("elements literal must be a shaped type");
1017 if (!sType
.hasStaticShape())
1018 return (emitError("elements literal type must have static shape"), nullptr);
1023 /// Parse a sparse elements attribute.
1024 Attribute
Parser::parseSparseElementsAttr(Type attrType
) {
1025 SMLoc loc
= getToken().getLoc();
1026 consumeToken(Token::kw_sparse
);
1027 if (parseToken(Token::less
, "Expected '<' after 'sparse'"))
1030 // Check for the case where all elements are sparse. The indices are
1031 // represented by a 2-dimensional shape where the second dimension is the rank
1033 Type indiceEltType
= builder
.getIntegerType(64);
1034 if (consumeIf(Token::greater
)) {
1035 ShapedType type
= parseElementsLiteralType(attrType
);
1039 // Construct the sparse elements attr using zero element indice/value
1041 ShapedType indicesType
=
1042 RankedTensorType::get({0, type
.getRank()}, indiceEltType
);
1043 ShapedType valuesType
= RankedTensorType::get({0}, type
.getElementType());
1044 return getChecked
<SparseElementsAttr
>(
1045 loc
, type
, DenseElementsAttr::get(indicesType
, ArrayRef
<Attribute
>()),
1046 DenseElementsAttr::get(valuesType
, ArrayRef
<Attribute
>()));
1049 /// Parse the indices. We don't allow hex values here as we may need to use
1050 /// the inferred shape.
1051 auto indicesLoc
= getToken().getLoc();
1052 TensorLiteralParser
indiceParser(*this);
1053 if (indiceParser
.parse(/*allowHex=*/false))
1056 if (parseToken(Token::comma
, "expected ','"))
1059 /// Parse the values.
1060 auto valuesLoc
= getToken().getLoc();
1061 TensorLiteralParser
valuesParser(*this);
1062 if (valuesParser
.parse(/*allowHex=*/true))
1065 if (parseToken(Token::greater
, "expected '>'"))
1068 auto type
= parseElementsLiteralType(attrType
);
1072 // If the indices are a splat, i.e. the literal parser parsed an element and
1073 // not a list, we set the shape explicitly. The indices are represented by a
1074 // 2-dimensional shape where the second dimension is the rank of the type.
1075 // Given that the parsed indices is a splat, we know that we only have one
1076 // indice and thus one for the first dimension.
1077 ShapedType indicesType
;
1078 if (indiceParser
.getShape().empty()) {
1079 indicesType
= RankedTensorType::get({1, type
.getRank()}, indiceEltType
);
1081 // Otherwise, set the shape to the one parsed by the literal parser.
1082 indicesType
= RankedTensorType::get(indiceParser
.getShape(), indiceEltType
);
1084 auto indices
= indiceParser
.getAttr(indicesLoc
, indicesType
);
1086 // If the values are a splat, set the shape explicitly based on the number of
1087 // indices. The number of indices is encoded in the first dimension of the
1088 // indice shape type.
1089 auto valuesEltType
= type
.getElementType();
1090 ShapedType valuesType
=
1091 valuesParser
.getShape().empty()
1092 ? RankedTensorType::get({indicesType
.getDimSize(0)}, valuesEltType
)
1093 : RankedTensorType::get(valuesParser
.getShape(), valuesEltType
);
1094 auto values
= valuesParser
.getAttr(valuesLoc
, valuesType
);
1096 // Build the sparse elements attribute by the indices and values.
1097 return getChecked
<SparseElementsAttr
>(loc
, type
, indices
, values
);
1100 Attribute
Parser::parseStridedLayoutAttr() {
1101 // Callback for error emissing at the keyword token location.
1102 llvm::SMLoc loc
= getToken().getLoc();
1103 auto errorEmitter
= [&] { return emitError(loc
); };
1105 consumeToken(Token::kw_strided
);
1106 if (failed(parseToken(Token::less
, "expected '<' after 'strided'")) ||
1107 failed(parseToken(Token::l_square
, "expected '['")))
1110 // Parses either an integer token or a question mark token. Reports an error
1111 // and returns std::nullopt if the current token is neither. The integer token
1112 // must fit into int64_t limits.
1113 auto parseStrideOrOffset
= [&]() -> std::optional
<int64_t> {
1114 if (consumeIf(Token::question
))
1115 return ShapedType::kDynamic
;
1117 SMLoc loc
= getToken().getLoc();
1118 auto emitWrongTokenError
= [&] {
1119 emitError(loc
, "expected a 64-bit signed integer or '?'");
1120 return std::nullopt
;
1123 bool negative
= consumeIf(Token::minus
);
1125 if (getToken().is(Token::integer
)) {
1126 std::optional
<uint64_t> value
= getToken().getUInt64IntegerValue();
1128 *value
> static_cast<uint64_t>(std::numeric_limits
<int64_t>::max()))
1129 return emitWrongTokenError();
1131 auto result
= static_cast<int64_t>(*value
);
1138 return emitWrongTokenError();
1142 SmallVector
<int64_t> strides
;
1143 if (!getToken().is(Token::r_square
)) {
1145 std::optional
<int64_t> stride
= parseStrideOrOffset();
1148 strides
.push_back(*stride
);
1149 } while (consumeIf(Token::comma
));
1152 if (failed(parseToken(Token::r_square
, "expected ']'")))
1155 // Fast path in absence of offset.
1156 if (consumeIf(Token::greater
)) {
1157 if (failed(StridedLayoutAttr::verify(errorEmitter
,
1158 /*offset=*/0, strides
)))
1160 return StridedLayoutAttr::get(getContext(), /*offset=*/0, strides
);
1163 if (failed(parseToken(Token::comma
, "expected ','")) ||
1164 failed(parseToken(Token::kw_offset
, "expected 'offset' after comma")) ||
1165 failed(parseToken(Token::colon
, "expected ':' after 'offset'")))
1168 std::optional
<int64_t> offset
= parseStrideOrOffset();
1169 if (!offset
|| failed(parseToken(Token::greater
, "expected '>'")))
1172 if (failed(StridedLayoutAttr::verify(errorEmitter
, *offset
, strides
)))
1174 return StridedLayoutAttr::get(getContext(), *offset
, strides
);
1175 // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides);
1178 /// Parse a distinct attribute.
1180 /// distinct-attribute ::= `distinct`
1181 /// `[` integer-literal `]<` attribute-value `>`
1183 Attribute
Parser::parseDistinctAttr(Type type
) {
1184 SMLoc loc
= getToken().getLoc();
1185 consumeToken(Token::kw_distinct
);
1186 if (parseToken(Token::l_square
, "expected '[' after 'distinct'"))
1189 // Parse the distinct integer identifier.
1190 Token token
= getToken();
1191 if (parseToken(Token::integer
, "expected distinct ID"))
1193 std::optional
<uint64_t> value
= token
.getUInt64IntegerValue();
1195 emitError("expected an unsigned 64-bit integer");
1199 // Parse the referenced attribute.
1200 if (parseToken(Token::r_square
, "expected ']' to close distinct ID") ||
1201 parseToken(Token::less
, "expected '<' after distinct ID"))
1204 Attribute referencedAttr
;
1205 if (getToken().is(Token::greater
)) {
1207 referencedAttr
= builder
.getUnitAttr();
1209 referencedAttr
= parseAttribute(type
);
1210 if (!referencedAttr
) {
1211 emitError("expected attribute");
1215 if (parseToken(Token::greater
, "expected '>' to close distinct attribute"))
1219 // Add the distinct attribute to the parser state, if it has not been parsed
1220 // before. Otherwise, check if the parsed reference attribute matches the one
1221 // found in the parser state.
1222 DenseMap
<uint64_t, DistinctAttr
> &distinctAttrs
=
1223 state
.symbols
.distinctAttributes
;
1224 auto it
= distinctAttrs
.find(*value
);
1225 if (it
== distinctAttrs
.end()) {
1226 DistinctAttr distinctAttr
= DistinctAttr::create(referencedAttr
);
1227 it
= distinctAttrs
.try_emplace(*value
, distinctAttr
).first
;
1228 } else if (it
->getSecond().getReferencedAttr() != referencedAttr
) {
1229 emitError(loc
, "referenced attribute does not match previous definition: ")
1230 << it
->getSecond().getReferencedAttr();
1234 return it
->getSecond();