// [Infra] Fix version-check workflow (#100090)
// [llvm-project.git] / mlir / tools / mlir-tblgen / FormatGen.cpp
// blob 7540e584b8fac5d4b4a4bda7b85ee89da63dfec5
//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "FormatGen.h"
10 #include "llvm/ADT/StringSwitch.h"
11 #include "llvm/Support/SourceMgr.h"
12 #include "llvm/TableGen/Error.h"
14 using namespace mlir;
15 using namespace mlir::tblgen;
17 //===----------------------------------------------------------------------===//
18 // FormatToken
19 //===----------------------------------------------------------------------===//
21 SMLoc FormatToken::getLoc() const {
22 return SMLoc::getFromPointer(spelling.data());
25 //===----------------------------------------------------------------------===//
26 // FormatLexer
27 //===----------------------------------------------------------------------===//
29 FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc)
30 : mgr(mgr), loc(loc),
31 curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
32 curPtr(curBuffer.begin()) {}
34 FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
35 mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
36 llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
37 "in custom assembly format for this operation");
38 return formToken(FormatToken::error, loc.getPointer());
41 FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
42 return emitError(SMLoc::getFromPointer(loc), msg);
45 FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
46 const Twine &note) {
47 mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
48 llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
49 "in custom assembly format for this operation");
50 mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
51 return formToken(FormatToken::error, loc.getPointer());
54 int FormatLexer::getNextChar() {
55 char curChar = *curPtr++;
56 switch (curChar) {
57 default:
58 return (unsigned char)curChar;
59 case 0: {
60 // A nul character in the stream is either the end of the current buffer or
61 // a random nul in the file. Disambiguate that here.
62 if (curPtr - 1 != curBuffer.end())
63 return 0;
65 // Otherwise, return end of file.
66 --curPtr;
67 return EOF;
69 case '\n':
70 case '\r':
71 // Handle the newline character by ignoring it and incrementing the line
72 // count. However, be careful about 'dos style' files with \n\r in them.
73 // Only treat a \n\r or \r\n as a single line.
74 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
75 ++curPtr;
76 return '\n';
80 FormatToken FormatLexer::lexToken() {
81 const char *tokStart = curPtr;
83 // This always consumes at least one character.
84 int curChar = getNextChar();
85 switch (curChar) {
86 default:
87 // Handle identifiers: [a-zA-Z_]
88 if (isalpha(curChar) || curChar == '_')
89 return lexIdentifier(tokStart);
91 // Unknown character, emit an error.
92 return emitError(tokStart, "unexpected character");
93 case EOF:
94 // Return EOF denoting the end of lexing.
95 return formToken(FormatToken::eof, tokStart);
97 // Lex punctuation.
98 case '^':
99 return formToken(FormatToken::caret, tokStart);
100 case ':':
101 return formToken(FormatToken::colon, tokStart);
102 case ',':
103 return formToken(FormatToken::comma, tokStart);
104 case '=':
105 return formToken(FormatToken::equal, tokStart);
106 case '<':
107 return formToken(FormatToken::less, tokStart);
108 case '>':
109 return formToken(FormatToken::greater, tokStart);
110 case '?':
111 return formToken(FormatToken::question, tokStart);
112 case '(':
113 return formToken(FormatToken::l_paren, tokStart);
114 case ')':
115 return formToken(FormatToken::r_paren, tokStart);
116 case '*':
117 return formToken(FormatToken::star, tokStart);
118 case '|':
119 return formToken(FormatToken::pipe, tokStart);
121 // Ignore whitespace characters.
122 case 0:
123 case ' ':
124 case '\t':
125 case '\n':
126 return lexToken();
128 case '`':
129 return lexLiteral(tokStart);
130 case '$':
131 return lexVariable(tokStart);
132 case '"':
133 return lexString(tokStart);
137 FormatToken FormatLexer::lexLiteral(const char *tokStart) {
138 assert(curPtr[-1] == '`');
140 // Lex a literal surrounded by ``.
141 while (const char curChar = *curPtr++) {
142 if (curChar == '`')
143 return formToken(FormatToken::literal, tokStart);
145 return emitError(curPtr - 1, "unexpected end of file in literal");
148 FormatToken FormatLexer::lexVariable(const char *tokStart) {
149 if (!isalpha(curPtr[0]) && curPtr[0] != '_')
150 return emitError(curPtr - 1, "expected variable name");
152 // Otherwise, consume the rest of the characters.
153 while (isalnum(*curPtr) || *curPtr == '_')
154 ++curPtr;
155 return formToken(FormatToken::variable, tokStart);
158 FormatToken FormatLexer::lexString(const char *tokStart) {
159 // Lex until another quote, respecting escapes.
160 bool escape = false;
161 while (const char curChar = *curPtr++) {
162 if (!escape && curChar == '"')
163 return formToken(FormatToken::string, tokStart);
164 escape = curChar == '\\';
166 return emitError(curPtr - 1, "unexpected end of file in string");
169 FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
170 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
171 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
172 ++curPtr;
174 // Check to see if this identifier is a keyword.
175 StringRef str(tokStart, curPtr - tokStart);
176 auto kind =
177 StringSwitch<FormatToken::Kind>(str)
178 .Case("attr-dict", FormatToken::kw_attr_dict)
179 .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
180 .Case("prop-dict", FormatToken::kw_prop_dict)
181 .Case("custom", FormatToken::kw_custom)
182 .Case("functional-type", FormatToken::kw_functional_type)
183 .Case("oilist", FormatToken::kw_oilist)
184 .Case("operands", FormatToken::kw_operands)
185 .Case("params", FormatToken::kw_params)
186 .Case("ref", FormatToken::kw_ref)
187 .Case("regions", FormatToken::kw_regions)
188 .Case("results", FormatToken::kw_results)
189 .Case("struct", FormatToken::kw_struct)
190 .Case("successors", FormatToken::kw_successors)
191 .Case("type", FormatToken::kw_type)
192 .Case("qualified", FormatToken::kw_qualified)
193 .Default(FormatToken::identifier);
194 return FormatToken(kind, str);
197 //===----------------------------------------------------------------------===//
198 // FormatParser
199 //===----------------------------------------------------------------------===//
201 FormatElement::~FormatElement() = default;
203 FormatParser::~FormatParser() = default;
205 FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
206 SMLoc loc = curToken.getLoc();
208 // Parse each of the format elements into the main format.
209 std::vector<FormatElement *> elements;
210 while (curToken.getKind() != FormatToken::eof) {
211 FailureOr<FormatElement *> element = parseElement(TopLevelContext);
212 if (failed(element))
213 return failure();
214 elements.push_back(*element);
217 // Verify the format.
218 if (failed(verify(loc, elements)))
219 return failure();
220 return elements;
223 //===----------------------------------------------------------------------===//
224 // Element Parsing
226 FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
227 if (curToken.is(FormatToken::literal))
228 return parseLiteral(ctx);
229 if (curToken.is(FormatToken::string))
230 return parseString(ctx);
231 if (curToken.is(FormatToken::variable))
232 return parseVariable(ctx);
233 if (curToken.isKeyword())
234 return parseDirective(ctx);
235 if (curToken.is(FormatToken::l_paren))
236 return parseOptionalGroup(ctx);
237 return emitError(curToken.getLoc(),
238 "expected literal, variable, directive, or optional group");
241 FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
242 FormatToken tok = curToken;
243 SMLoc loc = tok.getLoc();
244 consumeToken();
246 if (ctx != TopLevelContext) {
247 return emitError(
248 loc,
249 "literals may only be used in the top-level section of the format");
251 // Get the spelling without the surrounding backticks.
252 StringRef value = tok.getSpelling();
253 // Prevents things like `$arg0` or empty literals (when a literal is expected
254 // but not found) from getting segmentation faults.
255 if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`')
256 return emitError(tok.getLoc(), "expected literal, but got '" + value + "'");
257 value = value.drop_front().drop_back();
259 // The parsed literal is a space element (`` or ` `) or a newline.
260 if (value.empty() || value == " " || value == "\\n")
261 return create<WhitespaceElement>(value);
263 // Check that the parsed literal is valid.
264 if (!isValidLiteral(value, [&](Twine msg) {
265 (void)emitError(loc, "expected valid literal but got '" + value +
266 "': " + msg);
268 return failure();
269 return create<LiteralElement>(value);
272 FailureOr<FormatElement *> FormatParser::parseString(Context ctx) {
273 FormatToken tok = curToken;
274 SMLoc loc = tok.getLoc();
275 consumeToken();
277 if (ctx != CustomDirectiveContext) {
278 return emitError(
279 loc, "strings may only be used as 'custom' directive arguments");
281 // Escape the string.
282 std::string value;
283 StringRef contents = tok.getSpelling().drop_front().drop_back();
284 value.reserve(contents.size());
285 bool escape = false;
286 for (char c : contents) {
287 escape = c == '\\';
288 if (!escape)
289 value.push_back(c);
291 return create<StringElement>(std::move(value));
294 FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
295 FormatToken tok = curToken;
296 SMLoc loc = tok.getLoc();
297 consumeToken();
299 // Get the name of the variable without the leading `$`.
300 StringRef name = tok.getSpelling().drop_front();
301 return parseVariableImpl(loc, name, ctx);
304 FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
305 FormatToken tok = curToken;
306 SMLoc loc = tok.getLoc();
307 consumeToken();
309 if (tok.is(FormatToken::kw_custom))
310 return parseCustomDirective(loc, ctx);
311 if (tok.is(FormatToken::kw_ref))
312 return parseRefDirective(loc, ctx);
313 if (tok.is(FormatToken::kw_qualified))
314 return parseQualifiedDirective(loc, ctx);
315 return parseDirectiveImpl(loc, tok.getKind(), ctx);
318 FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
319 SMLoc loc = curToken.getLoc();
320 consumeToken();
321 if (ctx != TopLevelContext) {
322 return emitError(loc,
323 "optional groups can only be used as top-level elements");
326 // Parse the child elements for this optional group.
327 std::vector<FormatElement *> thenElements, elseElements;
328 FormatElement *anchor = nullptr;
329 auto parseChildElements =
330 [this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult {
331 do {
332 FailureOr<FormatElement *> element = parseElement(TopLevelContext);
333 if (failed(element))
334 return failure();
335 // Check for an anchor.
336 if (curToken.is(FormatToken::caret)) {
337 if (anchor) {
338 return emitError(curToken.getLoc(),
339 "only one element can be marked as the anchor of an "
340 "optional group");
342 anchor = *element;
343 consumeToken();
345 elements.push_back(*element);
346 } while (!curToken.is(FormatToken::r_paren));
347 return success();
350 // Parse the 'then' elements. If the anchor was found in this group, then the
351 // optional is not inverted.
352 if (failed(parseChildElements(thenElements)))
353 return failure();
354 consumeToken();
355 bool inverted = !anchor;
357 // Parse the `else` elements of this optional group.
358 if (curToken.is(FormatToken::colon)) {
359 consumeToken();
360 if (failed(parseToken(
361 FormatToken::l_paren,
362 "expected '(' to start else branch of optional group")) ||
363 failed(parseChildElements(elseElements)))
364 return failure();
365 consumeToken();
367 if (failed(parseToken(FormatToken::question,
368 "expected '?' after optional group")))
369 return failure();
371 // The optional group is required to have an anchor.
372 if (!anchor)
373 return emitError(loc, "optional group has no anchor element");
375 // Verify the child elements.
376 if (failed(verifyOptionalGroupElements(loc, thenElements, anchor)) ||
377 failed(verifyOptionalGroupElements(loc, elseElements, nullptr)))
378 return failure();
380 // Get the first parsable element. It must be an element that can be
381 // optionally-parsed.
382 auto isWhitespace = [](FormatElement *element) {
383 return isa<WhitespaceElement>(element);
385 auto thenParseBegin = llvm::find_if_not(thenElements, isWhitespace);
386 auto elseParseBegin = llvm::find_if_not(elseElements, isWhitespace);
387 unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin);
388 unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin);
390 if (!isa<LiteralElement, VariableElement, CustomDirective>(*thenParseBegin)) {
391 return emitError(loc, "first parsable element of an optional group must be "
392 "a literal, variable, or custom directive");
394 return create<OptionalElement>(std::move(thenElements),
395 std::move(elseElements), thenParseStart,
396 elseParseStart, anchor, inverted);
399 FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
400 Context ctx) {
401 if (ctx != TopLevelContext)
402 return emitError(loc, "'custom' is only valid as a top-level directive");
404 FailureOr<FormatToken> nameTok;
405 if (failed(parseToken(FormatToken::less,
406 "expected '<' before custom directive name")) ||
407 failed(nameTok =
408 parseToken(FormatToken::identifier,
409 "expected custom directive name identifier")) ||
410 failed(parseToken(FormatToken::greater,
411 "expected '>' after custom directive name")) ||
412 failed(parseToken(FormatToken::l_paren,
413 "expected '(' before custom directive parameters")))
414 return failure();
416 // Parse the arguments.
417 std::vector<FormatElement *> arguments;
418 while (true) {
419 FailureOr<FormatElement *> argument = parseElement(CustomDirectiveContext);
420 if (failed(argument))
421 return failure();
422 arguments.push_back(*argument);
423 if (!curToken.is(FormatToken::comma))
424 break;
425 consumeToken();
428 if (failed(parseToken(FormatToken::r_paren,
429 "expected ')' after custom directive parameters")))
430 return failure();
432 if (failed(verifyCustomDirectiveArguments(loc, arguments)))
433 return failure();
434 return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
437 FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc,
438 Context context) {
439 if (context != CustomDirectiveContext)
440 return emitError(loc, "'ref' is only valid within a `custom` directive");
442 FailureOr<FormatElement *> arg;
443 if (failed(parseToken(FormatToken::l_paren,
444 "expected '(' before argument list")) ||
445 failed(arg = parseElement(RefDirectiveContext)) ||
446 failed(
447 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
448 return failure();
450 return create<RefDirective>(*arg);
453 FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc,
454 Context ctx) {
455 if (failed(parseToken(FormatToken::l_paren,
456 "expected '(' before argument list")))
457 return failure();
458 FailureOr<FormatElement *> var = parseElement(ctx);
459 if (failed(var))
460 return var;
461 if (failed(markQualified(loc, *var)))
462 return failure();
463 if (failed(
464 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
465 return failure();
466 return var;
469 //===----------------------------------------------------------------------===//
470 // Utility Functions
471 //===----------------------------------------------------------------------===//
473 bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
474 bool lastWasPunctuation) {
475 if (value.size() != 1 && value != "->")
476 return true;
477 if (lastWasPunctuation)
478 return !StringRef(">)}],").contains(value.front());
479 return !StringRef("<>(){}[],").contains(value.front());
482 bool mlir::tblgen::canFormatStringAsKeyword(
483 StringRef value, function_ref<void(Twine)> emitError) {
484 if (value.empty()) {
485 if (emitError)
486 emitError("keywords cannot be empty");
487 return false;
489 if (!isalpha(value.front()) && value.front() != '_') {
490 if (emitError)
491 emitError("valid keyword starts with a letter or '_'");
492 return false;
494 if (!llvm::all_of(value.drop_front(), [](char c) {
495 return isalnum(c) || c == '_' || c == '$' || c == '.';
496 })) {
497 if (emitError)
498 emitError(
499 "keywords should contain only alphanum, '_', '$', or '.' characters");
500 return false;
502 return true;
505 bool mlir::tblgen::isValidLiteral(StringRef value,
506 function_ref<void(Twine)> emitError) {
507 if (value.empty()) {
508 if (emitError)
509 emitError("literal can't be empty");
510 return false;
512 char front = value.front();
514 // If there is only one character, this must either be punctuation or a
515 // single character bare identifier.
516 if (value.size() == 1) {
517 StringRef bare = "_:,=<>()[]{}?+*";
518 if (isalpha(front) || bare.contains(front))
519 return true;
520 if (emitError)
521 emitError("single character literal must be a letter or one of '" + bare +
522 "'");
523 return false;
525 // Check the punctuation that are larger than a single character.
526 if (value == "->")
527 return true;
528 if (value == "...")
529 return true;
531 // Otherwise, this must be an identifier.
532 return canFormatStringAsKeyword(value, emitError);
535 //===----------------------------------------------------------------------===//
536 // Commandline Options
537 //===----------------------------------------------------------------------===//
539 llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
540 "asmformat-error-is-fatal",
541 llvm::cl::desc("Emit a fatal error if format parsing fails"),
542 llvm::cl::init(true));