//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR Types.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

using namespace mlir;
using namespace mlir::detail;
31 /// Optionally parse a type.
32 OptionalParseResult
Parser::parseOptionalType(Type
&type
) {
33 // There are many different starting tokens for a type, check them here.
34 switch (getToken().getKind()) {
36 case Token::kw_memref
:
37 case Token::kw_tensor
:
38 case Token::kw_complex
:
40 case Token::kw_vector
:
42 case Token::kw_f4E2M1FN
:
43 case Token::kw_f6E2M3FN
:
44 case Token::kw_f6E3M2FN
:
45 case Token::kw_f8E5M2
:
46 case Token::kw_f8E4M3
:
47 case Token::kw_f8E4M3FN
:
48 case Token::kw_f8E5M2FNUZ
:
49 case Token::kw_f8E4M3FNUZ
:
50 case Token::kw_f8E4M3B11FNUZ
:
51 case Token::kw_f8E3M4
:
52 case Token::kw_f8E8M0FNU
:
62 case Token::exclamation_identifier
:
63 return failure(!(type
= parseType()));
70 /// Parse an arbitrary type.
72 /// type ::= function-type
73 /// | non-function-type
75 Type
Parser::parseType() {
76 if (getToken().is(Token::l_paren
))
77 return parseFunctionType();
78 return parseNonFunctionType();
81 /// Parse a function result type.
83 /// function-result-type ::= type-list-parens
84 /// | non-function-type
86 ParseResult
Parser::parseFunctionResultTypes(SmallVectorImpl
<Type
> &elements
) {
87 if (getToken().is(Token::l_paren
))
88 return parseTypeListParens(elements
);
90 Type t
= parseNonFunctionType();
93 elements
.push_back(t
);
97 /// Parse a list of types without an enclosing parenthesis. The list must have
98 /// at least one member.
100 /// type-list-no-parens ::= type (`,` type)*
102 ParseResult
Parser::parseTypeListNoParens(SmallVectorImpl
<Type
> &elements
) {
103 auto parseElt
= [&]() -> ParseResult
{
104 auto elt
= parseType();
105 elements
.push_back(elt
);
106 return elt
? success() : failure();
109 return parseCommaSeparatedList(parseElt
);
112 /// Parse a parenthesized list of types.
114 /// type-list-parens ::= `(` `)`
115 /// | `(` type-list-no-parens `)`
117 ParseResult
Parser::parseTypeListParens(SmallVectorImpl
<Type
> &elements
) {
118 if (parseToken(Token::l_paren
, "expected '('"))
121 // Handle empty lists.
122 if (getToken().is(Token::r_paren
))
123 return consumeToken(), success();
125 if (parseTypeListNoParens(elements
) ||
126 parseToken(Token::r_paren
, "expected ')'"))
131 /// Parse a complex type.
133 /// complex-type ::= `complex` `<` type `>`
135 Type
Parser::parseComplexType() {
136 consumeToken(Token::kw_complex
);
139 if (parseToken(Token::less
, "expected '<' in complex type"))
142 SMLoc elementTypeLoc
= getToken().getLoc();
143 auto elementType
= parseType();
145 parseToken(Token::greater
, "expected '>' in complex type"))
147 if (!isa
<FloatType
>(elementType
) && !isa
<IntegerType
>(elementType
))
148 return emitError(elementTypeLoc
, "invalid element type for complex"),
151 return ComplexType::get(elementType
);
154 /// Parse a function type.
156 /// function-type ::= type-list-parens `->` function-result-type
158 Type
Parser::parseFunctionType() {
159 assert(getToken().is(Token::l_paren
));
161 SmallVector
<Type
, 4> arguments
, results
;
162 if (parseTypeListParens(arguments
) ||
163 parseToken(Token::arrow
, "expected '->' in function type") ||
164 parseFunctionResultTypes(results
))
167 return builder
.getFunctionType(arguments
, results
);
170 /// Parse a memref type.
172 /// memref-type ::= ranked-memref-type | unranked-memref-type
174 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
175 /// (`,` layout-specification)? (`,` memory-space)? `>`
177 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
179 /// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
180 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
181 /// layout-specification ::= semi-affine-map | strided-layout | attribute
182 /// memory-space ::= integer-literal | attribute
184 Type
Parser::parseMemRefType() {
185 SMLoc loc
= getToken().getLoc();
186 consumeToken(Token::kw_memref
);
188 if (parseToken(Token::less
, "expected '<' in memref type"))
192 SmallVector
<int64_t, 4> dimensions
;
194 if (consumeIf(Token::star
)) {
195 // This is an unranked memref type.
197 if (parseXInDimensionList())
202 if (parseDimensionListRanked(dimensions
))
206 // Parse the element type.
207 auto typeLoc
= getToken().getLoc();
208 auto elementType
= parseType();
212 // Check that memref is formed from allowed types.
213 if (!BaseMemRefType::isValidElementType(elementType
))
214 return emitError(typeLoc
, "invalid memref element type"), nullptr;
216 MemRefLayoutAttrInterface layout
;
217 Attribute memorySpace
;
219 auto parseElt
= [&]() -> ParseResult
{
220 // Either it is MemRefLayoutAttrInterface or memory space attribute.
221 Attribute attr
= parseAttribute();
225 if (isa
<MemRefLayoutAttrInterface
>(attr
)) {
226 layout
= cast
<MemRefLayoutAttrInterface
>(attr
);
227 } else if (memorySpace
) {
228 return emitError("multiple memory spaces specified in memref type");
235 return emitError("cannot have affine map for unranked memref type");
237 return emitError("expected memory space to be last in memref type");
242 // Parse a list of mappings and address space if present.
243 if (!consumeIf(Token::greater
)) {
244 // Parse comma separated list of affine maps, followed by memory space.
245 if (parseToken(Token::comma
, "expected ',' or '>' in memref type") ||
246 parseCommaSeparatedListUntil(Token::greater
, parseElt
,
247 /*allowEmptyList=*/false)) {
253 return getChecked
<UnrankedMemRefType
>(loc
, elementType
, memorySpace
);
255 return getChecked
<MemRefType
>(loc
, dimensions
, elementType
, layout
,
259 /// Parse any type except the function type.
261 /// non-function-type ::= integer-type
272 /// index-type ::= `index`
273 /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128`
274 /// none-type ::= `none`
276 Type
Parser::parseNonFunctionType() {
277 switch (getToken().getKind()) {
279 return (emitWrongTokenError("expected non-function type"), nullptr);
280 case Token::kw_memref
:
281 return parseMemRefType();
282 case Token::kw_tensor
:
283 return parseTensorType();
284 case Token::kw_complex
:
285 return parseComplexType();
286 case Token::kw_tuple
:
287 return parseTupleType();
288 case Token::kw_vector
:
289 return parseVectorType();
291 case Token::inttype
: {
292 auto width
= getToken().getIntTypeBitwidth();
293 if (!width
.has_value())
294 return (emitError("invalid integer width"), nullptr);
295 if (*width
> IntegerType::kMaxWidth
) {
296 emitError(getToken().getLoc(), "integer bitwidth is limited to ")
297 << IntegerType::kMaxWidth
<< " bits";
301 IntegerType::SignednessSemantics signSemantics
= IntegerType::Signless
;
302 if (std::optional
<bool> signedness
= getToken().getIntTypeSignedness())
303 signSemantics
= *signedness
? IntegerType::Signed
: IntegerType::Unsigned
;
305 consumeToken(Token::inttype
);
306 return IntegerType::get(getContext(), *width
, signSemantics
);
310 case Token::kw_f4E2M1FN
:
311 consumeToken(Token::kw_f4E2M1FN
);
312 return builder
.getFloat4E2M1FNType();
313 case Token::kw_f6E2M3FN
:
314 consumeToken(Token::kw_f6E2M3FN
);
315 return builder
.getFloat6E2M3FNType();
316 case Token::kw_f6E3M2FN
:
317 consumeToken(Token::kw_f6E3M2FN
);
318 return builder
.getFloat6E3M2FNType();
319 case Token::kw_f8E5M2
:
320 consumeToken(Token::kw_f8E5M2
);
321 return builder
.getFloat8E5M2Type();
322 case Token::kw_f8E4M3
:
323 consumeToken(Token::kw_f8E4M3
);
324 return builder
.getFloat8E4M3Type();
325 case Token::kw_f8E4M3FN
:
326 consumeToken(Token::kw_f8E4M3FN
);
327 return builder
.getFloat8E4M3FNType();
328 case Token::kw_f8E5M2FNUZ
:
329 consumeToken(Token::kw_f8E5M2FNUZ
);
330 return builder
.getFloat8E5M2FNUZType();
331 case Token::kw_f8E4M3FNUZ
:
332 consumeToken(Token::kw_f8E4M3FNUZ
);
333 return builder
.getFloat8E4M3FNUZType();
334 case Token::kw_f8E4M3B11FNUZ
:
335 consumeToken(Token::kw_f8E4M3B11FNUZ
);
336 return builder
.getFloat8E4M3B11FNUZType();
337 case Token::kw_f8E3M4
:
338 consumeToken(Token::kw_f8E3M4
);
339 return builder
.getFloat8E3M4Type();
340 case Token::kw_f8E8M0FNU
:
341 consumeToken(Token::kw_f8E8M0FNU
);
342 return builder
.getFloat8E8M0FNUType();
344 consumeToken(Token::kw_bf16
);
345 return builder
.getBF16Type();
347 consumeToken(Token::kw_f16
);
348 return builder
.getF16Type();
350 consumeToken(Token::kw_tf32
);
351 return builder
.getTF32Type();
353 consumeToken(Token::kw_f32
);
354 return builder
.getF32Type();
356 consumeToken(Token::kw_f64
);
357 return builder
.getF64Type();
359 consumeToken(Token::kw_f80
);
360 return builder
.getF80Type();
362 consumeToken(Token::kw_f128
);
363 return builder
.getF128Type();
366 case Token::kw_index
:
367 consumeToken(Token::kw_index
);
368 return builder
.getIndexType();
372 consumeToken(Token::kw_none
);
373 return builder
.getNoneType();
376 case Token::exclamation_identifier
:
377 return parseExtendedType();
379 // Handle completion of a dialect type.
380 case Token::code_complete
:
381 if (getToken().isCodeCompletionFor(Token::exclamation_identifier
))
382 return parseExtendedType();
383 return codeCompleteType();
387 /// Parse a tensor type.
389 /// tensor-type ::= `tensor` `<` dimension-list type `>`
390 /// dimension-list ::= dimension-list-ranked | `*x`
392 Type
Parser::parseTensorType() {
393 consumeToken(Token::kw_tensor
);
395 if (parseToken(Token::less
, "expected '<' in tensor type"))
399 SmallVector
<int64_t, 4> dimensions
;
401 if (consumeIf(Token::star
)) {
402 // This is an unranked tensor type.
405 if (parseXInDimensionList())
410 if (parseDimensionListRanked(dimensions
))
414 // Parse the element type.
415 auto elementTypeLoc
= getToken().getLoc();
416 auto elementType
= parseType();
418 // Parse an optional encoding attribute.
420 if (consumeIf(Token::comma
)) {
421 auto parseResult
= parseOptionalAttribute(encoding
);
422 if (parseResult
.has_value()) {
423 if (failed(parseResult
.value()))
425 if (auto v
= dyn_cast_or_null
<VerifiableTensorEncoding
>(encoding
)) {
426 if (failed(v
.verifyEncoding(dimensions
, elementType
,
427 [&] { return emitError(); })))
433 if (!elementType
|| parseToken(Token::greater
, "expected '>' in tensor type"))
435 if (!TensorType::isValidElementType(elementType
))
436 return emitError(elementTypeLoc
, "invalid tensor element type"), nullptr;
440 return emitError("cannot apply encoding to unranked tensor"), nullptr;
441 return UnrankedTensorType::get(elementType
);
443 return RankedTensorType::get(dimensions
, elementType
, encoding
);
446 /// Parse a tuple type.
448 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
450 Type
Parser::parseTupleType() {
451 consumeToken(Token::kw_tuple
);
454 if (parseToken(Token::less
, "expected '<' in tuple type"))
457 // Check for an empty tuple by directly parsing '>'.
458 if (consumeIf(Token::greater
))
459 return TupleType::get(getContext());
461 // Parse the element types and the '>'.
462 SmallVector
<Type
, 4> types
;
463 if (parseTypeListNoParens(types
) ||
464 parseToken(Token::greater
, "expected '>' in tuple type"))
467 return TupleType::get(getContext(), types
);
470 /// Parse a vector type.
472 /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
473 /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
474 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
476 VectorType
Parser::parseVectorType() {
477 SMLoc loc
= getToken().getLoc();
478 consumeToken(Token::kw_vector
);
480 if (parseToken(Token::less
, "expected '<' in vector type"))
483 // Parse the dimensions.
484 SmallVector
<int64_t, 4> dimensions
;
485 SmallVector
<bool, 4> scalableDims
;
486 if (parseVectorDimensionList(dimensions
, scalableDims
))
489 // Parse the element type.
490 auto elementType
= parseType();
491 if (!elementType
|| parseToken(Token::greater
, "expected '>' in vector type"))
494 return getChecked
<VectorType
>(loc
, dimensions
, elementType
, scalableDims
);
497 /// Parse a dimension list in a vector type. This populates the dimension list.
498 /// For i-th dimension, `scalableDims[i]` contains either:
499 /// * `false` for a non-scalable dimension (e.g. `4`),
500 /// * `true` for a scalable dimension (e.g. `[4]`).
502 /// vector-dim-list := (static-dim-list `x`)?
503 /// static-dim-list ::= static-dim (`x` static-dim)*
504 /// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
507 Parser::parseVectorDimensionList(SmallVectorImpl
<int64_t> &dimensions
,
508 SmallVectorImpl
<bool> &scalableDims
) {
509 // If there is a set of fixed-length dimensions, consume it
510 while (getToken().is(Token::integer
) || getToken().is(Token::l_square
)) {
512 bool scalable
= consumeIf(Token::l_square
);
513 if (parseIntegerInDimensionList(value
))
515 dimensions
.push_back(value
);
517 if (!consumeIf(Token::r_square
))
518 return emitWrongTokenError("missing ']' closing scalable dimension");
520 scalableDims
.push_back(scalable
);
521 // Make sure we have an 'x' or something like 'xbf32'.
522 if (parseXInDimensionList())
529 /// Parse a dimension list of a tensor or memref type. This populates the
530 /// dimension list, using ShapedType::kDynamic for the `?` dimensions if
531 /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing
532 /// `x` is configurable.
534 /// dimension-list ::= eps | dimension (`x` dimension)*
535 /// dimension-list-with-trailing-x ::= (dimension `x`)*
536 /// dimension ::= `?` | decimal-literal
538 /// When `allowDynamic` is not set, this is used to parse:
540 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
541 /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
543 Parser::parseDimensionListRanked(SmallVectorImpl
<int64_t> &dimensions
,
544 bool allowDynamic
, bool withTrailingX
) {
545 auto parseDim
= [&]() -> LogicalResult
{
546 auto loc
= getToken().getLoc();
547 if (consumeIf(Token::question
)) {
549 return emitError(loc
, "expected static shape");
550 dimensions
.push_back(ShapedType::kDynamic
);
553 if (failed(parseIntegerInDimensionList(value
)))
555 dimensions
.push_back(value
);
561 while (getToken().isAny(Token::integer
, Token::question
)) {
562 if (failed(parseDim()) || failed(parseXInDimensionList()))
568 if (getToken().isAny(Token::integer
, Token::question
)) {
569 if (failed(parseDim()))
571 while (getToken().is(Token::bare_identifier
) &&
572 getTokenSpelling()[0] == 'x') {
573 if (failed(parseXInDimensionList()) || failed(parseDim()))
580 ParseResult
Parser::parseIntegerInDimensionList(int64_t &value
) {
581 // Hexadecimal integer literals (starting with `0x`) are not allowed in
582 // aggregate type declarations. Therefore, `0xf32` should be processed as
583 // a sequence of separate elements `0`, `x`, `f32`.
584 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
585 // We can get here only if the token is an integer literal. Hexadecimal
586 // integer literals can only start with `0x` (`1x` wouldn't lex as a
587 // literal, just `1` would, at which point we don't get into this
589 assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
591 state
.lex
.resetPointer(getTokenSpelling().data() + 1);
594 // Make sure this integer value is in bound and valid.
595 std::optional
<uint64_t> dimension
= getToken().getUInt64IntegerValue();
597 *dimension
> (uint64_t)std::numeric_limits
<int64_t>::max())
598 return emitError("invalid dimension");
599 value
= (int64_t)*dimension
;
600 consumeToken(Token::integer
);
605 /// Parse an 'x' token in a dimension list, handling the case where the x is
606 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
608 ParseResult
Parser::parseXInDimensionList() {
609 if (getToken().isNot(Token::bare_identifier
) || getTokenSpelling()[0] != 'x')
610 return emitWrongTokenError("expected 'x' in dimension list");
612 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
613 if (getTokenSpelling().size() != 1)
614 state
.lex
.resetPointer(getTokenSpelling().data() + 1);
617 consumeToken(Token::bare_identifier
);