1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
12 //===----------------------------------------------------------------------===//
14 #include "AsmParserImpl.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinAttributeInterfaces.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SourceMgr.h"
33 using namespace mlir::detail
;
34 using llvm::MemoryBuffer
;
35 using llvm::SourceMgr
;
38 /// This class provides the main implementation of the DialectAsmParser that
39 /// allows for dialects to parse attributes and types. This allows for dialect
40 /// hooking into the main MLIR parsing logic.
// NOTE(review): this listing has gaps in its line numbering. Access
// specifiers, the final member initializer, the `fullSpec` member declaration
// and the closing of the class are not visible here -- confirm against the
// upstream file before relying on this text.
41 class CustomDialectAsmParser
: public AsmParserImpl
<DialectAsmParser
> {
/// Builds the parser from the full symbol text and the enclosing parser. The
/// base class is given the location of the enclosing parser's current token
/// together with the parser itself.
// NOTE(review): the initializer list below ends with a trailing comma, so a
// further initializer (presumably the one storing `fullSpec`) and the
// constructor body appear to have been dropped from this listing.
43 CustomDialectAsmParser(StringRef fullSpec
, Parser
&parser
)
44 : AsmParserImpl
<DialectAsmParser
>(parser
.getToken().getLoc(), parser
),
46 ~CustomDialectAsmParser() override
= default;
48 /// Returns the full specification of the symbol being parsed. This allows
49 /// for using a separate parser if necessary.
50 StringRef
getFullSymbolSpec() const override
{ return fullSpec
; }
53 /// The full symbol specification.
// NOTE(review): the member declaration documented by the comment above is not
// visible in this listing; `getFullSymbolSpec` returns a member named
// `fullSpec`, so that is presumably what was declared here.
/// Scans the body of a dialect symbol. Scanning starts at the spelling of the
/// current token, which is asserted to begin with `<`, and tracks a stack of
/// open punctuation. After the loop, `body` is resized so that it runs from
/// its existing start up to the position where scanning stopped.
/// `isCodeCompletion` is set to true when scanning reaches the lexer's
/// code-completion location or lexes a code-completion token.
///
/// Errors are emitted for unbalanced punctuation and for an unexpected nul/EOF.
///
/// Grammar of the body:
59 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
60 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
61 /// | '(' pretty-dialect-sym-contents+ ')'
62 /// | '[' pretty-dialect-sym-contents+ ']'
63 /// | '{' pretty-dialect-sym-contents+ '}'
64 /// | '[^[<({>\])}\0]+'
// NOTE(review): this listing has gaps in its line numbering. The loop header,
// the declaration of `c`, the switch/case labels and several return statements
// are not visible below -- confirm against the upstream file.
66 ParseResult
Parser::parseDialectSymbolBody(StringRef
&body
,
67 bool &isCodeCompletion
) {
68 // Symbol bodies are a relatively unstructured format that contains a series
69 // of properly nested punctuation, with anything else in the middle. Scan
70 // ahead to find it and consume it if successful, otherwise emit an error.
71 const char *curPtr
= getTokenSpelling().data();
73 // Scan over the nested punctuation, bailing out on error and consuming until
74 // we find the end. We know that we're currently looking at the '<', so we can
75 // go until we find the matching '>' character.
76 assert(*curPtr
== '<');
// Stack of punctuation characters that have been opened and not yet closed.
77 SmallVector
<char, 8> nestedPunctuation
;
// Position at which the lexer expects a code completion, compared against
// `curPtr` while scanning.
78 const char *codeCompleteLoc
= state
.lex
.getCodeCompleteLoc();
80 // Functor used to emit an unbalanced punctuation error.
// The reported character is the innermost entry still on the stack.
81 auto emitPunctError
= [&] {
82 return emitError() << "unbalanced '" << nestedPunctuation
.back()
83 << "' character in pretty dialect name";
85 // Functor used to check for unbalanced punctuation.
// Pops the innermost open punctuation when it equals `expectedToken`;
// otherwise reports that innermost character as unbalanced.
86 auto checkNestedPunctuation
= [&](char expectedToken
) -> ParseResult
{
87 if (nestedPunctuation
.back() != expectedToken
)
88 return emitPunctError();
89 nestedPunctuation
.pop_back();
93 // Handle code completions, which may appear in the middle of the symbol
// body.
95 if (curPtr
== codeCompleteLoc
) {
96 isCodeCompletion
= true;
// Emptying the stack means the loop condition at the bottom no longer holds.
97 nestedPunctuation
.clear();
104 // This also handles the EOF case.
// With punctuation still open, report it as unbalanced; otherwise report the
// unexpected nul/EOF itself.
105 if (!nestedPunctuation
.empty())
106 return emitPunctError();
107 return emitError("unexpected nul or EOF in pretty dialect name");
// Remember the character just scanned on the punctuation stack (presumably
// reached only for opening punctuation; the case labels are not visible).
112 nestedPunctuation
.push_back(c
);
116 // The sequence `->` is treated as special token.
// Each check below requires the innermost open punctuation to be the given
// opening character (presumably reached for `>`, `]`, `)` and `}` respectively;
// the case labels and the failure returns are not visible in this listing).
122 if (failed(checkNestedPunctuation('<')))
126 if (failed(checkNestedPunctuation('[')))
130 if (failed(checkNestedPunctuation('(')))
134 if (failed(checkNestedPunctuation('{')))
138 // Dispatch to the lexer to lex past strings.
// The `- 1` suggests `curPtr` has already been advanced past the character that
// starts the string -- TODO confirm. Scanning then resumes at the end of the
// token the lexer produced.
139 resetToken(curPtr
- 1);
140 curPtr
= state
.curToken
.getEndLoc().getPointer();
142 // Handle code completions, which may appear in the middle of the symbol
// body.
144 if (state
.curToken
.isCodeCompletion()) {
145 isCodeCompletion
= true;
146 nestedPunctuation
.clear();
150 // Otherwise, ensure this token was actually a string.
151 if (state
.curToken
.isNot(Token::string
))
// Keep scanning until every opened punctuation character has been closed.
159 } while (!nestedPunctuation
.empty());
161 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
162 // consuming all this stuff, and return.
// The new length is measured from the start of `body` as passed in by the
// caller up to the stopping position.
165 unsigned length
= curPtr
- body
.begin();
166 body
= StringRef(body
.data(), length
);
170 /// Parse an extended dialect symbol.
///
/// \tparam Symbol the kind of symbol produced; it is compared against `Type`
///         to choose which alias-use hook of `asmState` is called.
/// \param p the parser whose current token is the symbol being parsed.
/// \param asmState receives alias-use records for resolved aliases.
/// \param aliases lookup table from alias identifier to symbol.
/// \param createSymbol callback invoked as
///        `createSymbol(dialectName, symbolData, loc)` for dialect symbols.
/// \returns the aliased symbol, the result of `createSymbol`, or the result of
///          `p.codeCompleteDialectSymbol(aliases)` for an empty completion.
// NOTE(review): this listing has gaps in its line numbering. The token
// consumption, the tail of the error expression, the branch headers and the
// failure returns are not visible below -- confirm against the upstream file.
171 template <typename Symbol
, typename SymbolAliasMap
, typename CreateFn
>
172 static Symbol
parseExtendedSymbol(Parser
&p
, AsmParserState
*asmState
,
173 SymbolAliasMap
&aliases
,
174 CreateFn
&&createSymbol
) {
175 Token tok
= p
.getToken();
177 // Handle code completion of the extended symbol.
// The identifier is the token spelling without its first character
// (presumably the `#` or `!` sigil -- confirm against the lexer).
178 StringRef identifier
= tok
.getSpelling().drop_front();
179 if (tok
.isCodeCompletion() && identifier
.empty())
180 return p
.codeCompleteDialectSymbol(aliases
);
182 // Parse the dialect namespace.
// `range` and `loc` both describe the token holding the identifier.
183 SMRange range
= p
.getToken().getLocRange();
184 SMLoc loc
= p
.getToken().getLoc();
// NOTE(review): the code below compares the end of `identifier` with the start
// of the *current* token, which only makes sense once the identifier token has
// been consumed; that call is presumably on a line missing from this listing.
187 // Check to see if this is a pretty name.
// Everything before the first `.` is the dialect name; the remainder is the
// symbol data. A name that merely ends in `.` also counts as pretty.
188 auto [dialectName
, symbolData
] = identifier
.split('.');
189 bool isPrettyName
= !symbolData
.empty() || identifier
.back() == '.';
191 // Check to see if the symbol has trailing data, i.e. has an immediately
// following `<` with no characters between it and the identifier.
193 bool hasTrailingData
=
194 p
.getToken().is(Token::less
) &&
195 identifier
.bytes_end() == p
.getTokenSpelling().bytes_begin();
197 // If there is no '<' token following this, and if the typename contains no
198 // dot, then we are parsing a symbol alias.
199 if (!hasTrailingData
&& !isPrettyName
) {
200 // Check for an alias for this type.
201 auto aliasIt
= aliases
.find(identifier
);
202 if (aliasIt
== aliases
.end())
// NOTE(review): the remainder of this error expression is not visible.
203 return (p
.emitWrongTokenError("undefined symbol alias id '" + identifier
+
// Record the alias use with the hook matching the symbol kind.
// NOTE(review): `asmState` is dereferenced here, and `parseSymbol` in this file
// builds its ParserState with `/*asmState=*/nullptr`; a null guard and the
// `else` between the two calls are presumably on lines missing from this
// listing -- confirm.
207 if constexpr (std::is_same_v
<Symbol
, Type
>)
208 asmState
->addTypeAliasUses(identifier
, range
);
210 asmState
->addAttrAliasUses(identifier
, range
);
212 return aliasIt
->second
;
215 // If this isn't an alias, we are parsing a dialect-specific symbol. If the
216 // name contains a dot, then this is the "pretty" form. If not, it is the
217 // verbose form that looks like <...>.
// NOTE(review): the header that selects between the two forms below is not
// visible in this listing.
219 // Point the symbol data to the end of the dialect name to start.
220 symbolData
= StringRef(dialectName
.end(), 0);
222 // Parse the body of the symbol.
223 bool isCodeCompletion
= false;
224 if (p
.parseDialectSymbolBody(symbolData
, isCodeCompletion
))
// Drop the first character of the scanned body (the `<` that the body parser
// asserts it starts on).
226 symbolData
= symbolData
.drop_front();
228 // If the body contained a code completion it won't have the trailing `>`
229 // token, so don't drop it.
230 if (!isCodeCompletion
)
231 symbolData
= symbolData
.drop_back();
// Report later diagnostics at the start of the symbol data instead of at the
// start of the whole token.
233 loc
= SMLoc::getFromPointer(symbolData
.data());
235 // If the dialect's symbol is followed immediately by a <, then lex the body
236 // of it into prettyName.
// This uses a one-argument form of parseDialectSymbolBody, presumably an
// overload declared outside this file -- confirm.
237 if (hasTrailingData
&& p
.parseDialectSymbolBody(symbolData
))
241 return createSymbol(dialectName
, symbolData
, loc
);
244 /// Parse an extended attribute.
///
/// Delegates to parseExtendedSymbol with the attribute alias table. For a
/// dialect attribute, the callback lets the named dialect parse the symbol
/// data when the dialect can be loaded, and builds an opaque attribute
/// otherwise. Afterwards, a typed result whose type differs from a non-null
/// requested `type` produces an error.
///
246 /// extended-attribute ::= (dialect-attribute | attribute-alias)
247 /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
249 /// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
250 /// attribute-alias ::= `#` alias-name
// NOTE(review): this listing has gaps in its line numbering. The return
// statements and the closing braces of the callback and of this function are
// not visible below -- confirm against the upstream file.
252 Attribute
Parser::parseExtendedAttr(Type type
) {
253 MLIRContext
*ctx
= getContext();
254 Attribute attr
= parseExtendedSymbol
<Attribute
>(
255 *this, state
.asmState
, state
.symbols
.attributeAliasDefinitions
,
256 [&](StringRef dialectName
, StringRef symbolData
, SMLoc loc
) -> Attribute
{
257 // Parse an optional trailing colon type.
// `attrType` starts as the caller-supplied type and is replaced only when
// an explicit `: type` follows the symbol.
258 Type attrType
= type
;
259 if (consumeIf(Token::colon
) && !(attrType
= parseType()))
262 // If we found a registered dialect, then ask it to parse the attribute.
263 if (Dialect
*dialect
=
264 builder
.getContext()->getOrLoadDialect(dialectName
)) {
265 // Temporarily reset the lexer to let the dialect parse the attribute.
// The current lexer position is saved first so it can be restored once
// the dialect has finished.
266 const char *curLexerPos
= getToken().getLoc().getPointer();
267 resetToken(symbolData
.data());
269 // Parse the attribute.
270 CustomDialectAsmParser
customParser(symbolData
, *this);
271 Attribute attr
= dialect
->parseAttribute(customParser
, attrType
);
// Restore the lexer to where it was before the dialect parsed.
272 resetToken(curLexerPos
);
276 // Otherwise, form a new opaque attribute.
// The opaque attribute carries `attrType` when one is set and NoneType
// otherwise; construction errors are reported at `loc`.
277 return OpaqueAttr::getChecked(
278 [&] { return emitError(loc
); }, StringAttr::get(ctx
, dialectName
),
279 symbolData
, attrType
? attrType
: NoneType::get(ctx
));
282 // Ensure that the attribute has the same type as requested.
// dyn_cast_or_null is used, so a null `attr` yields a null `typedAttr` and the
// check is skipped.
283 auto typedAttr
= dyn_cast_or_null
<TypedAttr
>(attr
);
284 if (type
&& typedAttr
&& typedAttr
.getType() != type
) {
285 emitError("attribute type different than expected: expected ")
286 << type
<< ", but got " << typedAttr
.getType();
292 /// Parse an extended type.
///
/// Delegates to parseExtendedSymbol with the type alias table. For a dialect
/// type, the callback lets the named dialect parse the symbol data when the
/// dialect can be loaded, and builds an opaque type otherwise.
///
294 /// extended-type ::= (dialect-type | type-alias)
295 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
296 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
297 /// type-alias ::= `!` alias-name
// NOTE(review): this listing has gaps in its line numbering. The return of the
// dialect-parsed type, the remaining arguments of the opaque-type call and the
// closing braces are not visible below -- confirm against the upstream file.
299 Type
Parser::parseExtendedType() {
300 MLIRContext
*ctx
= getContext();
301 return parseExtendedSymbol
<Type
>(
302 *this, state
.asmState
, state
.symbols
.typeAliasDefinitions
,
303 [&](StringRef dialectName
, StringRef symbolData
, SMLoc loc
) -> Type
{
304 // If we found a registered dialect, then ask it to parse the type.
305 if (auto *dialect
= ctx
->getOrLoadDialect(dialectName
)) {
306 // Temporarily reset the lexer to let the dialect parse the type.
// The current lexer position is saved first so it can be restored once
// the dialect has finished.
307 const char *curLexerPos
= getToken().getLoc().getPointer();
308 resetToken(symbolData
.data());
311 CustomDialectAsmParser
customParser(symbolData
, *this);
312 Type type
= dialect
->parseType(customParser
);
// Restore the lexer to where it was before the dialect parsed.
313 resetToken(curLexerPos
);
317 // Otherwise, form a new opaque type.
// Construction errors are reported at `loc`.
318 return OpaqueType::getChecked([&] { return emitError(loc
); },
319 StringAttr::get(ctx
, dialectName
),
324 //===----------------------------------------------------------------------===//
325 // mlir::parseAttribute/parseType
326 //===----------------------------------------------------------------------===//
328 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
329 /// parsing failed, nullptr is returned.
///
/// \param inputStr text to parse; also used as the name of the memory buffer.
/// \param context context handed to the ParserConfig.
/// \param numReadOut written with the number of bytes consumed; see the note
///        inside the function about the branch that guards this write.
/// \param isKnownNullTerminated when true the input is wrapped in place,
///        otherwise it is copied into a new buffer.
/// \param parserFn callback invoked with the freshly built Parser to produce
///        the symbol.
// NOTE(review): this listing has gaps in its line numbering. The declarations
// of `sourceMgr`, `memBuffer` and `numRead`, the head of the `if` that pairs
// with the `else if` below, and the return statements are not visible --
// confirm against the upstream file.
330 template <typename T
, typename ParserFn
>
331 static T
parseSymbol(StringRef inputStr
, MLIRContext
*context
,
332 size_t *numReadOut
, bool isKnownNullTerminated
,
333 ParserFn
&&parserFn
) {
334 // Set the buffer name to the string being parsed, so that it appears in error
// diagnostics.
// A copy is made unless the caller vouches that the input is already
// null-terminated.
337 isKnownNullTerminated
338 ? MemoryBuffer::getMemBuffer(inputStr
,
339 /*BufferName=*/inputStr
)
340 : MemoryBuffer::getMemBufferCopy(inputStr
, /*BufferName=*/inputStr
);
342 sourceMgr
.AddNewSourceBuffer(std::move(memBuffer
), SMLoc());
// The parser runs with a fresh alias table, no AsmParserState and no
// code-completion context.
343 SymbolState aliasState
;
344 ParserConfig
config(context
);
345 ParserState
state(sourceMgr
, config
, aliasState
, /*asmState=*/nullptr,
346 /*codeCompleteContext=*/nullptr);
347 Parser
parser(state
);
// Remember the first token so the consumed length can be computed afterwards.
349 Token startTok
= parser
.getToken();
350 T symbol
= parserFn(parser
);
354 // Provide the number of bytes that were read.
// The count is the distance between the start of the first token and the start
// of the token the parser stopped on.
355 Token endTok
= parser
.getToken();
357 endTok
.getLoc().getPointer() - startTok
.getLoc().getPointer();
// The consumed length is reported through `numReadOut` (presumably only when
// the pointer is non-null -- the guarding `if` is not visible); otherwise any
// unconsumed input is diagnosed as trailing characters.
359 *numReadOut
= numRead
;
360 } else if (numRead
!= inputStr
.size()) {
361 parser
.emitError(endTok
.getLoc()) << "found trailing characters: '"
362 << inputStr
.drop_front(numRead
) << "'";
/// Parses an attribute from `attrStr` by forwarding to parseSymbol with a
/// callback that calls `Parser::parseAttribute(type)`. `context`, `numRead`
/// and `isKnownNullTerminated` are passed through to parseSymbol unchanged.
// NOTE(review): the closing brace of this function is not visible in this
// listing.
368 Attribute
mlir::parseAttribute(StringRef attrStr
, MLIRContext
*context
,
369 Type type
, size_t *numRead
,
370 bool isKnownNullTerminated
) {
371 return parseSymbol
<Attribute
>(
372 attrStr
, context
, numRead
, isKnownNullTerminated
,
373 [type
](Parser
&parser
) { return parser
.parseAttribute(type
); });
/// Parses a type from `typeStr` by forwarding to parseSymbol with a callback
/// that calls `Parser::parseType()`. `context`, `numRead` and
/// `isKnownNullTerminated` are passed through to parseSymbol unchanged.
// NOTE(review): the closing brace of this function is not visible in this
// listing.
375 Type
mlir::parseType(StringRef typeStr
, MLIRContext
*context
, size_t *numRead
,
376 bool isKnownNullTerminated
) {
377 return parseSymbol
<Type
>(typeStr
, context
, numRead
, isKnownNullTerminated
,
378 [](Parser
&parser
) { return parser
.parseType(); });