[rtsan] Remove mkfifoat interceptor (#116997)
[llvm-project.git] / mlir / lib / AsmParser / DialectSymbolParser.cpp
blob9f4a87a6a02de5c083dbb487512abcb97b49329c
1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
12 //===----------------------------------------------------------------------===//
14 #include "AsmParserImpl.h"
15 #include "Parser.h"
16 #include "mlir/AsmParser/AsmParserState.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/Attributes.h"
19 #include "mlir/IR/BuiltinAttributeInterfaces.h"
20 #include "mlir/IR/BuiltinAttributes.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/Dialect.h"
23 #include "mlir/IR/DialectImplementation.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/MemoryBuffer.h"
27 #include "llvm/Support/SourceMgr.h"
28 #include <cassert>
29 #include <cstddef>
30 #include <utility>
32 using namespace mlir;
33 using namespace mlir::detail;
34 using llvm::MemoryBuffer;
35 using llvm::SourceMgr;
37 namespace {
38 /// This class provides the main implementation of the DialectAsmParser that
39 /// allows for dialects to parse attributes and types. This allows for dialect
40 /// hooking into the main MLIR parsing logic.
41 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
42 public:
43 CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
44 : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
45 fullSpec(fullSpec) {}
46 ~CustomDialectAsmParser() override = default;
48 /// Returns the full specification of the symbol being parsed. This allows
49 /// for using a separate parser if necessary.
50 StringRef getFullSymbolSpec() const override { return fullSpec; }
52 private:
53 /// The full symbol specification.
54 StringRef fullSpec;
56 } // namespace
58 ///
59 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
60 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
61 /// | '(' pretty-dialect-sym-contents+ ')'
62 /// | '[' pretty-dialect-sym-contents+ ']'
63 /// | '{' pretty-dialect-sym-contents+ '}'
64 /// | '[^[<({>\])}\0]+'
65 ///
66 ParseResult Parser::parseDialectSymbolBody(StringRef &body,
67 bool &isCodeCompletion) {
68 // Symbol bodies are a relatively unstructured format that contains a series
69 // of properly nested punctuation, with anything else in the middle. Scan
70 // ahead to find it and consume it if successful, otherwise emit an error.
71 const char *curPtr = getTokenSpelling().data();
73 // Scan over the nested punctuation, bailing out on error and consuming until
74 // we find the end. We know that we're currently looking at the '<', so we can
75 // go until we find the matching '>' character.
76 assert(*curPtr == '<');
77 SmallVector<char, 8> nestedPunctuation;
78 const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
80 // Functor used to emit an unbalanced punctuation error.
81 auto emitPunctError = [&] {
82 return emitError() << "unbalanced '" << nestedPunctuation.back()
83 << "' character in pretty dialect name";
85 // Functor used to check for unbalanced punctuation.
86 auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
87 if (nestedPunctuation.back() != expectedToken)
88 return emitPunctError();
89 nestedPunctuation.pop_back();
90 return success();
92 do {
93 // Handle code completions, which may appear in the middle of the symbol
94 // body.
95 if (curPtr == codeCompleteLoc) {
96 isCodeCompletion = true;
97 nestedPunctuation.clear();
98 break;
101 char c = *curPtr++;
102 switch (c) {
103 case '\0':
104 // This also handles the EOF case.
105 if (!nestedPunctuation.empty())
106 return emitPunctError();
107 return emitError("unexpected nul or EOF in pretty dialect name");
108 case '<':
109 case '[':
110 case '(':
111 case '{':
112 nestedPunctuation.push_back(c);
113 continue;
115 case '-':
116 // The sequence `->` is treated as special token.
117 if (*curPtr == '>')
118 ++curPtr;
119 continue;
121 case '>':
122 if (failed(checkNestedPunctuation('<')))
123 return failure();
124 break;
125 case ']':
126 if (failed(checkNestedPunctuation('[')))
127 return failure();
128 break;
129 case ')':
130 if (failed(checkNestedPunctuation('(')))
131 return failure();
132 break;
133 case '}':
134 if (failed(checkNestedPunctuation('{')))
135 return failure();
136 break;
137 case '"': {
138 // Dispatch to the lexer to lex past strings.
139 resetToken(curPtr - 1);
140 curPtr = state.curToken.getEndLoc().getPointer();
142 // Handle code completions, which may appear in the middle of the symbol
143 // body.
144 if (state.curToken.isCodeCompletion()) {
145 isCodeCompletion = true;
146 nestedPunctuation.clear();
147 break;
150 // Otherwise, ensure this token was actually a string.
151 if (state.curToken.isNot(Token::string))
152 return failure();
153 break;
156 default:
157 continue;
159 } while (!nestedPunctuation.empty());
161 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
162 // consuming all this stuff, and return.
163 resetToken(curPtr);
165 unsigned length = curPtr - body.begin();
166 body = StringRef(body.data(), length);
167 return success();
170 /// Parse an extended dialect symbol.
171 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
172 static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
173 SymbolAliasMap &aliases,
174 CreateFn &&createSymbol) {
175 Token tok = p.getToken();
177 // Handle code completion of the extended symbol.
178 StringRef identifier = tok.getSpelling().drop_front();
179 if (tok.isCodeCompletion() && identifier.empty())
180 return p.codeCompleteDialectSymbol(aliases);
182 // Parse the dialect namespace.
183 SMRange range = p.getToken().getLocRange();
184 SMLoc loc = p.getToken().getLoc();
185 p.consumeToken();
187 // Check to see if this is a pretty name.
188 auto [dialectName, symbolData] = identifier.split('.');
189 bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
191 // Check to see if the symbol has trailing data, i.e. has an immediately
192 // following '<'.
193 bool hasTrailingData =
194 p.getToken().is(Token::less) &&
195 identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
197 // If there is no '<' token following this, and if the typename contains no
198 // dot, then we are parsing a symbol alias.
199 if (!hasTrailingData && !isPrettyName) {
200 // Check for an alias for this type.
201 auto aliasIt = aliases.find(identifier);
202 if (aliasIt == aliases.end())
203 return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
204 "'"),
205 nullptr);
206 if (asmState) {
207 if constexpr (std::is_same_v<Symbol, Type>)
208 asmState->addTypeAliasUses(identifier, range);
209 else
210 asmState->addAttrAliasUses(identifier, range);
212 return aliasIt->second;
215 // If this isn't an alias, we are parsing a dialect-specific symbol. If the
216 // name contains a dot, then this is the "pretty" form. If not, it is the
217 // verbose form that looks like <...>.
218 if (!isPrettyName) {
219 // Point the symbol data to the end of the dialect name to start.
220 symbolData = StringRef(dialectName.end(), 0);
222 // Parse the body of the symbol.
223 bool isCodeCompletion = false;
224 if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
225 return nullptr;
226 symbolData = symbolData.drop_front();
228 // If the body contained a code completion it won't have the trailing `>`
229 // token, so don't drop it.
230 if (!isCodeCompletion)
231 symbolData = symbolData.drop_back();
232 } else {
233 loc = SMLoc::getFromPointer(symbolData.data());
235 // If the dialect's symbol is followed immediately by a <, then lex the body
236 // of it into prettyName.
237 if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
238 return nullptr;
241 return createSymbol(dialectName, symbolData, loc);
244 /// Parse an extended attribute.
246 /// extended-attribute ::= (dialect-attribute | attribute-alias)
247 /// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
248 /// (`:` type)?
249 /// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
250 /// attribute-alias ::= `#` alias-name
252 Attribute Parser::parseExtendedAttr(Type type) {
253 MLIRContext *ctx = getContext();
254 Attribute attr = parseExtendedSymbol<Attribute>(
255 *this, state.asmState, state.symbols.attributeAliasDefinitions,
256 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
257 // Parse an optional trailing colon type.
258 Type attrType = type;
259 if (consumeIf(Token::colon) && !(attrType = parseType()))
260 return Attribute();
262 // If we found a registered dialect, then ask it to parse the attribute.
263 if (Dialect *dialect =
264 builder.getContext()->getOrLoadDialect(dialectName)) {
265 // Temporarily reset the lexer to let the dialect parse the attribute.
266 const char *curLexerPos = getToken().getLoc().getPointer();
267 resetToken(symbolData.data());
269 // Parse the attribute.
270 CustomDialectAsmParser customParser(symbolData, *this);
271 Attribute attr = dialect->parseAttribute(customParser, attrType);
272 resetToken(curLexerPos);
273 return attr;
276 // Otherwise, form a new opaque attribute.
277 return OpaqueAttr::getChecked(
278 [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
279 symbolData, attrType ? attrType : NoneType::get(ctx));
282 // Ensure that the attribute has the same type as requested.
283 auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
284 if (type && typedAttr && typedAttr.getType() != type) {
285 emitError("attribute type different than expected: expected ")
286 << type << ", but got " << typedAttr.getType();
287 return nullptr;
289 return attr;
292 /// Parse an extended type.
294 /// extended-type ::= (dialect-type | type-alias)
295 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
296 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
297 /// type-alias ::= `!` alias-name
299 Type Parser::parseExtendedType() {
300 MLIRContext *ctx = getContext();
301 return parseExtendedSymbol<Type>(
302 *this, state.asmState, state.symbols.typeAliasDefinitions,
303 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
304 // If we found a registered dialect, then ask it to parse the type.
305 if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
306 // Temporarily reset the lexer to let the dialect parse the type.
307 const char *curLexerPos = getToken().getLoc().getPointer();
308 resetToken(symbolData.data());
310 // Parse the type.
311 CustomDialectAsmParser customParser(symbolData, *this);
312 Type type = dialect->parseType(customParser);
313 resetToken(curLexerPos);
314 return type;
317 // Otherwise, form a new opaque type.
318 return OpaqueType::getChecked([&] { return emitError(loc); },
319 StringAttr::get(ctx, dialectName),
320 symbolData);
324 //===----------------------------------------------------------------------===//
325 // mlir::parseAttribute/parseType
326 //===----------------------------------------------------------------------===//
328 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
329 /// parsing failed, nullptr is returned.
330 template <typename T, typename ParserFn>
331 static T parseSymbol(StringRef inputStr, MLIRContext *context,
332 size_t *numReadOut, bool isKnownNullTerminated,
333 ParserFn &&parserFn) {
334 // Set the buffer name to the string being parsed, so that it appears in error
335 // diagnostics.
336 auto memBuffer =
337 isKnownNullTerminated
338 ? MemoryBuffer::getMemBuffer(inputStr,
339 /*BufferName=*/inputStr)
340 : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
341 SourceMgr sourceMgr;
342 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
343 SymbolState aliasState;
344 ParserConfig config(context);
345 ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
346 /*codeCompleteContext=*/nullptr);
347 Parser parser(state);
349 Token startTok = parser.getToken();
350 T symbol = parserFn(parser);
351 if (!symbol)
352 return T();
354 // Provide the number of bytes that were read.
355 Token endTok = parser.getToken();
356 size_t numRead =
357 endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
358 if (numReadOut) {
359 *numReadOut = numRead;
360 } else if (numRead != inputStr.size()) {
361 parser.emitError(endTok.getLoc()) << "found trailing characters: '"
362 << inputStr.drop_front(numRead) << "'";
363 return T();
365 return symbol;
368 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
369 Type type, size_t *numRead,
370 bool isKnownNullTerminated) {
371 return parseSymbol<Attribute>(
372 attrStr, context, numRead, isKnownNullTerminated,
373 [type](Parser &parser) { return parser.parseAttribute(type); });
375 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
376 bool isKnownNullTerminated) {
377 return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
378 [](Parser &parser) { return parser.parseType(); });