//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR Types.
//
//===----------------------------------------------------------------------===//
// NOTE(review): the header that declares mlir::detail::Parser is not among the
// visible includes — confirm it is included ahead of these.
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::detail;

32 /// Optionally parse a type.
33 OptionalParseResult
Parser::parseOptionalType(Type
&type
) {
34 // There are many different starting tokens for a type, check them here.
35 switch (getToken().getKind()) {
37 case Token::kw_memref
:
38 case Token::kw_tensor
:
39 case Token::kw_complex
:
41 case Token::kw_vector
:
43 case Token::kw_f8E5M2
:
44 case Token::kw_f8E4M3FN
:
45 case Token::kw_f8E5M2FNUZ
:
46 case Token::kw_f8E4M3FNUZ
:
47 case Token::kw_f8E4M3B11FNUZ
:
57 case Token::exclamation_identifier
:
58 return failure(!(type
= parseType()));
65 /// Parse an arbitrary type.
67 /// type ::= function-type
68 /// | non-function-type
70 Type
Parser::parseType() {
71 if (getToken().is(Token::l_paren
))
72 return parseFunctionType();
73 return parseNonFunctionType();
76 /// Parse a function result type.
78 /// function-result-type ::= type-list-parens
79 /// | non-function-type
81 ParseResult
Parser::parseFunctionResultTypes(SmallVectorImpl
<Type
> &elements
) {
82 if (getToken().is(Token::l_paren
))
83 return parseTypeListParens(elements
);
85 Type t
= parseNonFunctionType();
88 elements
.push_back(t
);
92 /// Parse a list of types without an enclosing parenthesis. The list must have
93 /// at least one member.
95 /// type-list-no-parens ::= type (`,` type)*
97 ParseResult
Parser::parseTypeListNoParens(SmallVectorImpl
<Type
> &elements
) {
98 auto parseElt
= [&]() -> ParseResult
{
99 auto elt
= parseType();
100 elements
.push_back(elt
);
101 return elt
? success() : failure();
104 return parseCommaSeparatedList(parseElt
);
107 /// Parse a parenthesized list of types.
109 /// type-list-parens ::= `(` `)`
110 /// | `(` type-list-no-parens `)`
112 ParseResult
Parser::parseTypeListParens(SmallVectorImpl
<Type
> &elements
) {
113 if (parseToken(Token::l_paren
, "expected '('"))
116 // Handle empty lists.
117 if (getToken().is(Token::r_paren
))
118 return consumeToken(), success();
120 if (parseTypeListNoParens(elements
) ||
121 parseToken(Token::r_paren
, "expected ')'"))
126 /// Parse a complex type.
128 /// complex-type ::= `complex` `<` type `>`
130 Type
Parser::parseComplexType() {
131 consumeToken(Token::kw_complex
);
134 if (parseToken(Token::less
, "expected '<' in complex type"))
137 SMLoc elementTypeLoc
= getToken().getLoc();
138 auto elementType
= parseType();
140 parseToken(Token::greater
, "expected '>' in complex type"))
142 if (!isa
<FloatType
>(elementType
) && !isa
<IntegerType
>(elementType
))
143 return emitError(elementTypeLoc
, "invalid element type for complex"),
146 return ComplexType::get(elementType
);
149 /// Parse a function type.
151 /// function-type ::= type-list-parens `->` function-result-type
153 Type
Parser::parseFunctionType() {
154 assert(getToken().is(Token::l_paren
));
156 SmallVector
<Type
, 4> arguments
, results
;
157 if (parseTypeListParens(arguments
) ||
158 parseToken(Token::arrow
, "expected '->' in function type") ||
159 parseFunctionResultTypes(results
))
162 return builder
.getFunctionType(arguments
, results
);
165 /// Parse a memref type.
167 /// memref-type ::= ranked-memref-type | unranked-memref-type
169 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
170 /// (`,` layout-specification)? (`,` memory-space)? `>`
172 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
174 /// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
175 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
176 /// layout-specification ::= semi-affine-map | strided-layout | attribute
177 /// memory-space ::= integer-literal | attribute
179 Type
Parser::parseMemRefType() {
180 SMLoc loc
= getToken().getLoc();
181 consumeToken(Token::kw_memref
);
183 if (parseToken(Token::less
, "expected '<' in memref type"))
187 SmallVector
<int64_t, 4> dimensions
;
189 if (consumeIf(Token::star
)) {
190 // This is an unranked memref type.
192 if (parseXInDimensionList())
197 if (parseDimensionListRanked(dimensions
))
201 // Parse the element type.
202 auto typeLoc
= getToken().getLoc();
203 auto elementType
= parseType();
207 // Check that memref is formed from allowed types.
208 if (!BaseMemRefType::isValidElementType(elementType
))
209 return emitError(typeLoc
, "invalid memref element type"), nullptr;
211 MemRefLayoutAttrInterface layout
;
212 Attribute memorySpace
;
214 auto parseElt
= [&]() -> ParseResult
{
215 // Either it is MemRefLayoutAttrInterface or memory space attribute.
216 Attribute attr
= parseAttribute();
220 if (isa
<MemRefLayoutAttrInterface
>(attr
)) {
221 layout
= cast
<MemRefLayoutAttrInterface
>(attr
);
222 } else if (memorySpace
) {
223 return emitError("multiple memory spaces specified in memref type");
230 return emitError("cannot have affine map for unranked memref type");
232 return emitError("expected memory space to be last in memref type");
237 // Parse a list of mappings and address space if present.
238 if (!consumeIf(Token::greater
)) {
239 // Parse comma separated list of affine maps, followed by memory space.
240 if (parseToken(Token::comma
, "expected ',' or '>' in memref type") ||
241 parseCommaSeparatedListUntil(Token::greater
, parseElt
,
242 /*allowEmptyList=*/false)) {
248 return getChecked
<UnrankedMemRefType
>(loc
, elementType
, memorySpace
);
250 return getChecked
<MemRefType
>(loc
, dimensions
, elementType
, layout
,
254 /// Parse any type except the function type.
256 /// non-function-type ::= integer-type
267 /// index-type ::= `index`
268 /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
269 /// none-type ::= `none`
271 Type
Parser::parseNonFunctionType() {
272 switch (getToken().getKind()) {
274 return (emitWrongTokenError("expected non-function type"), nullptr);
275 case Token::kw_memref
:
276 return parseMemRefType();
277 case Token::kw_tensor
:
278 return parseTensorType();
279 case Token::kw_complex
:
280 return parseComplexType();
281 case Token::kw_tuple
:
282 return parseTupleType();
283 case Token::kw_vector
:
284 return parseVectorType();
286 case Token::inttype
: {
287 auto width
= getToken().getIntTypeBitwidth();
288 if (!width
.has_value())
289 return (emitError("invalid integer width"), nullptr);
290 if (*width
> IntegerType::kMaxWidth
) {
291 emitError(getToken().getLoc(), "integer bitwidth is limited to ")
292 << IntegerType::kMaxWidth
<< " bits";
296 IntegerType::SignednessSemantics signSemantics
= IntegerType::Signless
;
297 if (std::optional
<bool> signedness
= getToken().getIntTypeSignedness())
298 signSemantics
= *signedness
? IntegerType::Signed
: IntegerType::Unsigned
;
300 consumeToken(Token::inttype
);
301 return IntegerType::get(getContext(), *width
, signSemantics
);
305 case Token::kw_f8E5M2
:
306 consumeToken(Token::kw_f8E5M2
);
307 return builder
.getFloat8E5M2Type();
308 case Token::kw_f8E4M3FN
:
309 consumeToken(Token::kw_f8E4M3FN
);
310 return builder
.getFloat8E4M3FNType();
311 case Token::kw_f8E5M2FNUZ
:
312 consumeToken(Token::kw_f8E5M2FNUZ
);
313 return builder
.getFloat8E5M2FNUZType();
314 case Token::kw_f8E4M3FNUZ
:
315 consumeToken(Token::kw_f8E4M3FNUZ
);
316 return builder
.getFloat8E4M3FNUZType();
317 case Token::kw_f8E4M3B11FNUZ
:
318 consumeToken(Token::kw_f8E4M3B11FNUZ
);
319 return builder
.getFloat8E4M3B11FNUZType();
321 consumeToken(Token::kw_bf16
);
322 return builder
.getBF16Type();
324 consumeToken(Token::kw_f16
);
325 return builder
.getF16Type();
327 consumeToken(Token::kw_tf32
);
328 return builder
.getTF32Type();
330 consumeToken(Token::kw_f32
);
331 return builder
.getF32Type();
333 consumeToken(Token::kw_f64
);
334 return builder
.getF64Type();
336 consumeToken(Token::kw_f80
);
337 return builder
.getF80Type();
339 consumeToken(Token::kw_f128
);
340 return builder
.getF128Type();
343 case Token::kw_index
:
344 consumeToken(Token::kw_index
);
345 return builder
.getIndexType();
349 consumeToken(Token::kw_none
);
350 return builder
.getNoneType();
353 case Token::exclamation_identifier
:
354 return parseExtendedType();
356 // Handle completion of a dialect type.
357 case Token::code_complete
:
358 if (getToken().isCodeCompletionFor(Token::exclamation_identifier
))
359 return parseExtendedType();
360 return codeCompleteType();
364 /// Parse a tensor type.
366 /// tensor-type ::= `tensor` `<` dimension-list type `>`
367 /// dimension-list ::= dimension-list-ranked | `*x`
369 Type
Parser::parseTensorType() {
370 consumeToken(Token::kw_tensor
);
372 if (parseToken(Token::less
, "expected '<' in tensor type"))
376 SmallVector
<int64_t, 4> dimensions
;
378 if (consumeIf(Token::star
)) {
379 // This is an unranked tensor type.
382 if (parseXInDimensionList())
387 if (parseDimensionListRanked(dimensions
))
391 // Parse the element type.
392 auto elementTypeLoc
= getToken().getLoc();
393 auto elementType
= parseType();
395 // Parse an optional encoding attribute.
397 if (consumeIf(Token::comma
)) {
398 auto parseResult
= parseOptionalAttribute(encoding
);
399 if (parseResult
.has_value()) {
400 if (failed(parseResult
.value()))
402 if (auto v
= dyn_cast_or_null
<VerifiableTensorEncoding
>(encoding
)) {
403 if (failed(v
.verifyEncoding(dimensions
, elementType
,
404 [&] { return emitError(); })))
410 if (!elementType
|| parseToken(Token::greater
, "expected '>' in tensor type"))
412 if (!TensorType::isValidElementType(elementType
))
413 return emitError(elementTypeLoc
, "invalid tensor element type"), nullptr;
417 return emitError("cannot apply encoding to unranked tensor"), nullptr;
418 return UnrankedTensorType::get(elementType
);
420 return RankedTensorType::get(dimensions
, elementType
, encoding
);
423 /// Parse a tuple type.
425 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
427 Type
Parser::parseTupleType() {
428 consumeToken(Token::kw_tuple
);
431 if (parseToken(Token::less
, "expected '<' in tuple type"))
434 // Check for an empty tuple by directly parsing '>'.
435 if (consumeIf(Token::greater
))
436 return TupleType::get(getContext());
438 // Parse the element types and the '>'.
439 SmallVector
<Type
, 4> types
;
440 if (parseTypeListNoParens(types
) ||
441 parseToken(Token::greater
, "expected '>' in tuple type"))
444 return TupleType::get(getContext(), types
);
447 /// Parse a vector type.
449 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
450 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
451 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
453 VectorType
Parser::parseVectorType() {
454 consumeToken(Token::kw_vector
);
456 if (parseToken(Token::less
, "expected '<' in vector type"))
459 SmallVector
<int64_t, 4> dimensions
;
460 SmallVector
<bool, 4> scalableDims
;
461 if (parseVectorDimensionList(dimensions
, scalableDims
))
463 if (any_of(dimensions
, [](int64_t i
) { return i
<= 0; }))
464 return emitError(getToken().getLoc(),
465 "vector types must have positive constant sizes"),
468 // Parse the element type.
469 auto typeLoc
= getToken().getLoc();
470 auto elementType
= parseType();
471 if (!elementType
|| parseToken(Token::greater
, "expected '>' in vector type"))
474 if (!VectorType::isValidElementType(elementType
))
475 return emitError(typeLoc
, "vector elements must be int/index/float type"),
478 return VectorType::get(dimensions
, elementType
, scalableDims
);
481 /// Parse a dimension list in a vector type. This populates the dimension list.
482 /// For i-th dimension, `scalableDims[i]` contains either:
483 /// * `false` for a non-scalable dimension (e.g. `4`),
484 /// * `true` for a scalable dimension (e.g. `[4]`).
486 /// vector-dim-list := (static-dim-list `x`)?
487 /// static-dim-list ::= static-dim (`x` static-dim)*
488 /// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
491 Parser::parseVectorDimensionList(SmallVectorImpl
<int64_t> &dimensions
,
492 SmallVectorImpl
<bool> &scalableDims
) {
493 // If there is a set of fixed-length dimensions, consume it
494 while (getToken().is(Token::integer
) || getToken().is(Token::l_square
)) {
496 bool scalable
= consumeIf(Token::l_square
);
497 if (parseIntegerInDimensionList(value
))
499 dimensions
.push_back(value
);
501 if (!consumeIf(Token::r_square
))
502 return emitWrongTokenError("missing ']' closing scalable dimension");
504 scalableDims
.push_back(scalable
);
505 // Make sure we have an 'x' or something like 'xbf32'.
506 if (parseXInDimensionList())
513 /// Parse a dimension list of a tensor or memref type. This populates the
514 /// dimension list, using ShapedType::kDynamic for the `?` dimensions if
515 /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing
516 /// `x` is configurable.
518 /// dimension-list ::= eps | dimension (`x` dimension)*
519 /// dimension-list-with-trailing-x ::= (dimension `x`)*
520 /// dimension ::= `?` | decimal-literal
522 /// When `allowDynamic` is not set, this is used to parse:
524 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
525 /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
527 Parser::parseDimensionListRanked(SmallVectorImpl
<int64_t> &dimensions
,
528 bool allowDynamic
, bool withTrailingX
) {
529 auto parseDim
= [&]() -> LogicalResult
{
530 auto loc
= getToken().getLoc();
531 if (consumeIf(Token::question
)) {
533 return emitError(loc
, "expected static shape");
534 dimensions
.push_back(ShapedType::kDynamic
);
537 if (failed(parseIntegerInDimensionList(value
)))
539 dimensions
.push_back(value
);
545 while (getToken().isAny(Token::integer
, Token::question
)) {
546 if (failed(parseDim()) || failed(parseXInDimensionList()))
552 if (getToken().isAny(Token::integer
, Token::question
)) {
553 if (failed(parseDim()))
555 while (getToken().is(Token::bare_identifier
) &&
556 getTokenSpelling()[0] == 'x') {
557 if (failed(parseXInDimensionList()) || failed(parseDim()))
564 ParseResult
Parser::parseIntegerInDimensionList(int64_t &value
) {
565 // Hexadecimal integer literals (starting with `0x`) are not allowed in
566 // aggregate type declarations. Therefore, `0xf32` should be processed as
567 // a sequence of separate elements `0`, `x`, `f32`.
568 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
569 // We can get here only if the token is an integer literal. Hexadecimal
570 // integer literals can only start with `0x` (`1x` wouldn't lex as a
571 // literal, just `1` would, at which point we don't get into this
573 assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
575 state
.lex
.resetPointer(getTokenSpelling().data() + 1);
578 // Make sure this integer value is in bound and valid.
579 std::optional
<uint64_t> dimension
= getToken().getUInt64IntegerValue();
581 *dimension
> (uint64_t)std::numeric_limits
<int64_t>::max())
582 return emitError("invalid dimension");
583 value
= (int64_t)*dimension
;
584 consumeToken(Token::integer
);
589 /// Parse an 'x' token in a dimension list, handling the case where the x is
590 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
592 ParseResult
Parser::parseXInDimensionList() {
593 if (getToken().isNot(Token::bare_identifier
) || getTokenSpelling()[0] != 'x')
594 return emitWrongTokenError("expected 'x' in dimension list");
596 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
597 if (getTokenSpelling().size() != 1)
598 state
.lex
.resetPointer(getTokenSpelling().data() + 1);
601 consumeToken(Token::bare_identifier
);