1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
12 //===----------------------------------------------------------------------===//
14 #include "AsmParserImpl.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinAttributeInterfaces.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Support/LogicalResult.h"
27 #include "llvm/Support/MemoryBuffer.h"
28 #include "llvm/Support/SourceMgr.h"
34 using namespace mlir::detail
;
35 using llvm::MemoryBuffer
;
36 using llvm::SourceMgr
;
39 /// This class provides the main implementation of the DialectAsmParser that
40 /// allows for dialects to parse attributes and types. This allows for dialect
41 /// hooking into the main MLIR parsing logic.
42 class CustomDialectAsmParser
: public AsmParserImpl
<DialectAsmParser
> {
/// Constructs the dialect parser on top of `parser`. The base class receives
/// the location of the token `parser` is currently positioned on, together
/// with `parser` itself.
// NOTE(review): the initializer list below ends in a trailing comma and the
// embedded numbering jumps from 45 to 47, so one initializer is not visible
// here; presumably it stores `fullSpec` -- confirm against the complete file.
44 CustomDialectAsmParser(StringRef fullSpec
, Parser
&parser
)
45 : AsmParserImpl
<DialectAsmParser
>(parser
.getToken().getLoc(), parser
),
47 ~CustomDialectAsmParser() override
= default;
49 /// Returns the full specification of the symbol being parsed. This allows
50 /// for using a separate parser if necessary.
51 StringRef
getFullSymbolSpec() const override
{ return fullSpec
; }
54 /// The full symbol specification.
// NOTE(review): the member documented by the comment above (the `fullSpec`
// returned by getFullSymbolSpec) and the end of the class are not visible in
// this listing.
/// Scans the body of a pretty dialect symbol while tracking nested
/// punctuation. The current token is asserted to start with `<`. At the end of
/// the visible success path, `body` keeps its start pointer and is resized to
/// end where scanning stopped. `isCodeCompletion` is set to true when the scan
/// reaches a code completion. Unbalanced punctuation and an unexpected
/// nul/EOF are reported through emitError.
///
/// NOTE(review): the embedded listing numbers skip many lines inside this
/// function (the loop header, the per-character dispatch and several return
/// statements are not visible), so the comments below describe only the
/// statements that are shown.
///
60 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
61 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
62 /// | '(' pretty-dialect-sym-contents+ ')'
63 /// | '[' pretty-dialect-sym-contents+ ']'
64 /// | '{' pretty-dialect-sym-contents+ '}'
65 /// | '[^[<({>\])}\0]+'
67 ParseResult
Parser::parseDialectSymbolBody(StringRef
&body
,
68 bool &isCodeCompletion
) {
69 // Symbol bodies are a relatively unstructured format that contains a series
70 // of properly nested punctuation, with anything else in the middle. Scan
71 // ahead to find it and consume it if successful, otherwise emit an error.
72 const char *curPtr
= getTokenSpelling().data();
74 // Scan over the nested punctuation, bailing out on error and consuming until
75 // we find the end. We know that we're currently looking at the '<', so we can
76 // go until we find the matching '>' character.
77 assert(*curPtr
== '<');
// Stack of the punctuation characters that are currently open; the scan loop
// runs until this stack is empty (see the `while` condition further down).
78 SmallVector
<char, 8> nestedPunctuation
;
// Position the lexer reports for a code completion request; compared against
// `curPtr` while scanning.
79 const char *codeCompleteLoc
= state
.lex
.getCodeCompleteLoc();
81 // Functor used to emit an unbalanced punctuation error.
// The character named in the message is the innermost one still open, i.e.
// the back of the stack.
82 auto emitPunctError
= [&] {
83 return emitError() << "unbalanced '" << nestedPunctuation
.back()
84 << "' character in pretty dialect name";
86 // Functor used to check for unbalanced punctuation.
// Reports an error when the innermost open character differs from
// `expectedToken`; otherwise pops it off the stack.
87 auto checkNestedPunctuation
= [&](char expectedToken
) -> ParseResult
{
88 if (nestedPunctuation
.back() != expectedToken
)
89 return emitPunctError();
90 nestedPunctuation
.pop_back();
94 // Handle code completions, which may appear in the middle of the symbol body.
// Clearing the stack leaves no open punctuation, which is the exit condition
// of the scan loop.
96 if (curPtr
== codeCompleteLoc
) {
97 isCodeCompletion
= true;
98 nestedPunctuation
.clear();
105 // This also handles the EOF case.
// Prefer the more specific "unbalanced" message when something is still open.
106 if (!nestedPunctuation
.empty())
107 return emitPunctError();
108 return emitError("unexpected nul or EOF in pretty dialect name");
// Remember an opening character so that its closer can be checked later.
// `c` is presumably the character just read; its declaration is not visible
// in this listing.
113 nestedPunctuation
.push_back(c
);
117 // The sequence `->` is treated as special token.
// NOTE(review): presumably so that the '>' of an arrow is not taken as a
// closing bracket; the handling code itself is not visible here.
// One check per opener kind (`<`, `[`, `(`, `{`). Presumably each runs when
// the corresponding closing character is seen; the case labels are not
// visible in this listing.
123 if (failed(checkNestedPunctuation('<')))
127 if (failed(checkNestedPunctuation('[')))
131 if (failed(checkNestedPunctuation('(')))
135 if (failed(checkNestedPunctuation('{')))
139 // Dispatch to the lexer to lex past strings.
// `curPtr - 1` presumably points back at the character that was just read
// (the increment is not visible here). After the lexer has produced a token
// from that position, scanning resumes at the end of that token.
140 resetToken(curPtr
- 1);
141 curPtr
= state
.curToken
.getEndLoc().getPointer();
143 // Handle code completions, which may appear in the middle of the symbol body.
145 if (state
.curToken
.isCodeCompletion()) {
146 isCodeCompletion
= true;
147 nestedPunctuation
.clear();
151 // Otherwise, ensure this token was actually a string.
152 if (state
.curToken
.isNot(Token::string
))
// Keep scanning while at least one punctuation character is still open.
160 } while (!nestedPunctuation
.empty());
162 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
163 // consuming all this stuff, and return.
// `body` keeps its original start pointer and is resized to end at `curPtr`.
166 unsigned length
= curPtr
- body
.begin();
167 body
= StringRef(body
.data(), length
);
171 /// Parse an extended dialect symbol.
///
/// `Symbol` is the kind of entity produced; the code below distinguishes
/// `Type` from any other `Symbol` when recording alias uses on `asmState`.
/// `aliases` is searched by identifier for alias definitions, and
/// `createSymbol` is invoked as `createSymbol(dialectName, symbolData, loc)`
/// to build a dialect-specific symbol. `p` is the parser whose current token
/// is the symbol being parsed.
///
/// NOTE(review): the embedded listing numbers skip several lines inside this
/// function (among them some returns and branch headers), so the comments
/// describe only the statements that are shown.
172 template <typename Symbol
, typename SymbolAliasMap
, typename CreateFn
>
173 static Symbol
parseExtendedSymbol(Parser
&p
, AsmParserState
*asmState
,
174 SymbolAliasMap
&aliases
,
175 CreateFn
&&createSymbol
) {
176 Token tok
= p
.getToken();
178 // Handle code completion of the extended symbol.
// `identifier` is the token spelling without its first character (presumably
// the leading `#` or `!` sigil). An empty identifier on a code completion
// token means the dialect name itself is being completed.
179 StringRef identifier
= tok
.getSpelling().drop_front();
180 if (tok
.isCodeCompletion() && identifier
.empty())
181 return p
.codeCompleteDialectSymbol(aliases
);
183 // Parse the dialect namespace.
// Capture the range and location of the current token for later use in alias
// bookkeeping and diagnostics.
184 SMRange range
= p
.getToken().getLocRange();
185 SMLoc loc
= p
.getToken().getLoc();
// NOTE(review): the checks below compare `identifier` against the parser's
// current token, which only makes sense once the identifier token has been
// consumed; that step is not visible in this listing (line 186 is absent).
188 // Check to see if this is a pretty name.
// Splitting at the first '.' yields the dialect name and whatever follows
// it. A name counts as "pretty" when there is text after the dot or when the
// identifier ends with a dot.
189 auto [dialectName
, symbolData
] = identifier
.split('.');
190 bool isPrettyName
= !symbolData
.empty() || identifier
.back() == '.';
192 // Check to see if the symbol has trailing data, i.e. has an immediately following '<'.
// "Immediately" is enforced by requiring the `<` token to start at the exact
// byte where the identifier ends.
194 bool hasTrailingData
=
195 p
.getToken().is(Token::less
) &&
196 identifier
.bytes_end() == p
.getTokenSpelling().bytes_begin();
198 // If there is no '<' token following this, and if the typename contains no
199 // dot, then we are parsing a symbol alias.
200 if (!hasTrailingData
&& !isPrettyName
) {
201 // Check for an alias for this type.
202 auto aliasIt
= aliases
.find(identifier
);
203 if (aliasIt
== aliases
.end())
204 return (p
.emitWrongTokenError("undefined symbol alias id '" + identifier
+
// Record the alias use on `asmState`, choosing the type or attribute variant
// based on `Symbol`.
// NOTE(review): `asmState` is dereferenced here; whether it is checked for
// null first is not visible in this listing.
208 if constexpr (std::is_same_v
<Symbol
, Type
>)
209 asmState
->addTypeAliasUses(identifier
, range
);
211 asmState
->addAttrAliasUses(identifier
, range
);
// The alias resolves to the symbol stored in the map.
213 return aliasIt
->second
;
216 // If this isn't an alias, we are parsing a dialect-specific symbol. If the
217 // name contains a dot, then this is the "pretty" form. If not, it is the
218 // verbose form that looks like <...>.
220 // Point the symbol data to the end of the dialect name to start.
221 symbolData
= StringRef(dialectName
.end(), 0);
223 // Parse the body of the symbol.
224 bool isCodeCompletion
= false;
225 if (p
.parseDialectSymbolBody(symbolData
, isCodeCompletion
))
// Strip the first character of the parsed body (the opening `<` that the body
// parser asserts on).
227 symbolData
= symbolData
.drop_front();
229 // If the body contained a code completion it won't have the trailing `>`
230 // token, so don't drop it.
231 if (!isCodeCompletion
)
232 symbolData
= symbolData
.drop_back();
// In this path the reported location is moved to the start of the symbol
// data.
234 loc
= SMLoc::getFromPointer(symbolData
.data());
236 // If the dialect's symbol is followed immediately by a <, then lex the body
237 // of it into prettyName.
238 if (hasTrailingData
&& p
.parseDialectSymbolBody(symbolData
))
// Hand the pieces to the caller-provided factory.
242 return createSymbol(dialectName
, symbolData
, loc
);
245 /// Parse an extended attribute.
247 /// extended-attribute ::= (dialect-attribute | attribute-alias)
248 /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
250 /// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
251 /// attribute-alias ::= `#` alias-name
///
/// `type` is the type requested by the caller. It is used as the attribute
/// type unless a trailing `: type` is parsed, and a typed result is compared
/// against it at the end.
///
/// NOTE(review): several lines (returns and closing braces) are not visible
/// in this listing; the comments describe only the statements that are shown.
253 Attribute
Parser::parseExtendedAttr(Type type
) {
254 MLIRContext
*ctx
= getContext();
// Delegate the common symbol handling to parseExtendedSymbol, using the
// attribute alias table; the lambda builds the attribute for a dialect
// symbol.
255 Attribute attr
= parseExtendedSymbol
<Attribute
>(
256 *this, state
.asmState
, state
.symbols
.attributeAliasDefinitions
,
257 [&](StringRef dialectName
, StringRef symbolData
, SMLoc loc
) -> Attribute
{
258 // Parse an optional trailing colon type.
// Starts from the requested type and is overwritten when a `:` is consumed;
// a null result from parseType is treated as the failure case.
259 Type attrType
= type
;
260 if (consumeIf(Token::colon
) && !(attrType
= parseType()))
263 // If we found a registered dialect, then ask it to parse the attribute.
// NOTE(review): the context is fetched through `builder` here although `ctx`
// is already available in this function -- presumably the same context.
264 if (Dialect
*dialect
=
265 builder
.getContext()->getOrLoadDialect(dialectName
)) {
266 // Temporarily reset the lexer to let the dialect parse the attribute.
// Remember the current position so it can be restored after the dialect
// hook has run on the symbol data.
267 const char *curLexerPos
= getToken().getLoc().getPointer();
268 resetToken(symbolData
.data());
270 // Parse the attribute.
271 CustomDialectAsmParser
customParser(symbolData
, *this);
272 Attribute attr
= dialect
->parseAttribute(customParser
, attrType
);
// Restore the lexer to where it was before the dialect hook ran.
273 resetToken(curLexerPos
);
277 // Otherwise, form a new opaque attribute.
// Errors from construction are reported at `loc`; when no type is known the
// attribute is given NoneType.
278 return OpaqueAttr::getChecked(
279 [&] { return emitError(loc
); }, StringAttr::get(ctx
, dialectName
),
280 symbolData
, attrType
? attrType
: NoneType::get(ctx
));
283 // Ensure that the attribute has the same type as requested.
// The check only applies when a type was requested and the parsed attribute
// is a TypedAttr; dyn_cast_or_null tolerates a null `attr`.
284 auto typedAttr
= dyn_cast_or_null
<TypedAttr
>(attr
);
285 if (type
&& typedAttr
&& typedAttr
.getType() != type
) {
286 emitError("attribute type different than expected: expected ")
287 << type
<< ", but got " << typedAttr
.getType();
293 /// Parse an extended type.
295 /// extended-type ::= (dialect-type | type-alias)
296 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
297 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
298 /// type-alias ::= `!` alias-name
///
/// NOTE(review): several lines (returns, closing braces and the remaining
/// arguments of the final call) are not visible in this listing; the comments
/// describe only the statements that are shown.
300 Type
Parser::parseExtendedType() {
301 MLIRContext
*ctx
= getContext();
// Delegate the common symbol handling to parseExtendedSymbol, using the type
// alias table; the lambda builds the type for a dialect symbol.
302 return parseExtendedSymbol
<Type
>(
303 *this, state
.asmState
, state
.symbols
.typeAliasDefinitions
,
304 [&](StringRef dialectName
, StringRef symbolData
, SMLoc loc
) -> Type
{
305 // If we found a registered dialect, then ask it to parse the type.
306 if (auto *dialect
= ctx
->getOrLoadDialect(dialectName
)) {
307 // Temporarily reset the lexer to let the dialect parse the type.
// Remember the current position so it can be restored after the dialect
// hook has run on the symbol data.
308 const char *curLexerPos
= getToken().getLoc().getPointer();
309 resetToken(symbolData
.data());
312 CustomDialectAsmParser
customParser(symbolData
, *this);
313 Type type
= dialect
->parseType(customParser
);
// Restore the lexer to where it was before the dialect hook ran.
314 resetToken(curLexerPos
);
318 // Otherwise, form a new opaque type.
// Errors from construction are reported at `loc`.
319 return OpaqueType::getChecked([&] { return emitError(loc
); },
320 StringAttr::get(ctx
, dialectName
),
325 //===----------------------------------------------------------------------===//
326 // mlir::parseAttribute/parseType
327 //===----------------------------------------------------------------------===//
329 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
330 /// parsing failed, nullptr is returned.
///
/// `inputStr` is the text to parse and is also used as the buffer name.
/// `parserFn` is called with a freshly built Parser and produces the symbol.
/// `numReadOut` receives the number of bytes consumed; in the `else` branch
/// shown below, input left unread is reported as an error instead.
/// `isKnownNullTerminated` selects whether the input is wrapped directly or
/// copied into the memory buffer.
///
/// NOTE(review): several lines are not visible in this listing, among them
/// the declarations of `sourceMgr`, `memBuffer` and `numRead`, the failure
/// check after `parserFn`, and the `if` that guards the write to
/// `numReadOut`.
331 template <typename T
, typename ParserFn
>
332 static T
parseSymbol(StringRef inputStr
, MLIRContext
*context
,
333 size_t *numReadOut
, bool isKnownNullTerminated
,
334 ParserFn
&&parserFn
) {
335 // Set the buffer name to the string being parsed, so that it appears in error diagnostics.
// Wrap the input directly when it is known to be null terminated; otherwise
// go through getMemBufferCopy (presumably because the buffer consumer needs a
// null terminated buffer -- confirm against the MemoryBuffer documentation).
338 isKnownNullTerminated
339 ? MemoryBuffer::getMemBuffer(inputStr
,
340 /*BufferName=*/inputStr
)
341 : MemoryBuffer::getMemBufferCopy(inputStr
, /*BufferName=*/inputStr
);
// Ownership of the buffer moves into the source manager.
343 sourceMgr
.AddNewSourceBuffer(std::move(memBuffer
), SMLoc());
// Build a standalone parser: fresh alias state, no AsmParserState and no code
// completion context.
344 SymbolState aliasState
;
345 ParserConfig
config(context
);
346 ParserState
state(sourceMgr
, config
, aliasState
, /*asmState=*/nullptr,
347 /*codeCompleteContext=*/nullptr);
348 Parser
parser(state
);
// Remember the first token so the consumed byte count can be derived from it
// once parsing is done.
350 Token startTok
= parser
.getToken();
351 T symbol
= parserFn(parser
);
355 // Provide the number of bytes that were read.
// The count is the pointer distance between the token the parser stopped on
// and the token it started on.
356 Token endTok
= parser
.getToken();
358 endTok
.getLoc().getPointer() - startTok
.getLoc().getPointer();
360 *numReadOut
= numRead
;
// In the other branch, anything left unread in the input is an error.
361 } else if (numRead
!= inputStr
.size()) {
362 parser
.emitError(endTok
.getLoc()) << "found trailing characters: '"
363 << inputStr
.drop_front(numRead
) << "'";
/// Parses an attribute from `attrStr` in `context` by delegating to
/// parseSymbol with a callback that invokes `Parser::parseAttribute(type)`.
/// `numRead` and `isKnownNullTerminated` are forwarded to parseSymbol
/// unchanged; `type` is captured by the callback.
// NOTE(review): the closing brace of this function is not visible in this
// listing.
369 Attribute
mlir::parseAttribute(StringRef attrStr
, MLIRContext
*context
,
370 Type type
, size_t *numRead
,
371 bool isKnownNullTerminated
) {
372 return parseSymbol
<Attribute
>(
373 attrStr
, context
, numRead
, isKnownNullTerminated
,
374 [type
](Parser
&parser
) { return parser
.parseAttribute(type
); });
/// Parses a type from `typeStr` in `context` by delegating to parseSymbol
/// with a callback that invokes `Parser::parseType()`. `numRead` and
/// `isKnownNullTerminated` are forwarded to parseSymbol unchanged.
// NOTE(review): the closing brace of this function is not visible in this
// listing.
376 Type
mlir::parseType(StringRef typeStr
, MLIRContext
*context
, size_t *numRead
,
377 bool isKnownNullTerminated
) {
378 return parseSymbol
<Type
>(typeStr
, context
, numRead
, isKnownNullTerminated
,
379 [](Parser
&parser
) { return parser
.parseType(); });