[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-tblgen / AttrOrTypeFormatGen.cpp
blobdacc20b6ba2086619612a4582e168fb0eb9e7c88
1 //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "AttrOrTypeFormatGen.h"
10 #include "FormatGen.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/TableGen/AttrOrTypeDef.h"
13 #include "mlir/TableGen/Format.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "llvm/ADT/BitVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/MemoryBuffer.h"
20 #include "llvm/Support/SaveAndRestore.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "llvm/TableGen/Error.h"
23 #include "llvm/TableGen/TableGenBackend.h"
25 using namespace mlir;
26 using namespace mlir::tblgen;
28 using llvm::formatv;
30 //===----------------------------------------------------------------------===//
31 // Element
32 //===----------------------------------------------------------------------===//
34 namespace {
35 /// This class represents an instance of a variable element. A variable refers
36 /// to an attribute or type parameter.
37 class ParameterElement
38 : public VariableElementBase<VariableElement::Parameter> {
39 public:
40 ParameterElement(AttrOrTypeParameter param) : param(param) {}
42 /// Get the parameter in the element.
43 const AttrOrTypeParameter &getParam() const { return param; }
45 /// Indicate if this variable is printed "qualified" (that is it is
46 /// prefixed with the `#dialect.mnemonic`).
47 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
48 void setShouldBeQualified(bool qualified = true) {
49 shouldBeQualifiedFlag = qualified;
52 /// Returns true if the element contains an optional parameter.
53 bool isOptional() const { return param.isOptional(); }
55 /// Returns the name of the parameter.
56 StringRef getName() const { return param.getName(); }
58 /// Return the code to check whether the parameter is present.
59 auto genIsPresent(FmtContext &ctx, const Twine &self) const {
60 assert(isOptional() && "cannot guard on a mandatory parameter");
61 std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
62 ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
63 return tgfmt(getParam().getComparator(), &ctx);
66 /// Generate the code to check whether the parameter should be printed.
67 MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
68 assert(isOptional() && "cannot guard on a mandatory parameter");
69 std::string self = param.getAccessorName() + "()";
70 return os << "!(" << genIsPresent(ctx, self) << ")";
73 private:
74 bool shouldBeQualifiedFlag = false;
75 AttrOrTypeParameter param;
78 /// Shorthand functions that can be used with ranged-based conditions.
79 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
80 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
82 /// Base class for a directive that contains references to multiple variables.
83 template <DirectiveElement::Kind DirectiveKind>
84 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
85 public:
86 using Base = ParamsDirectiveBase<DirectiveKind>;
88 ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
89 : params(std::move(params)) {}
91 /// Get the parameters contained in this directive.
92 ArrayRef<ParameterElement *> getParams() const { return params; }
94 /// Get the number of parameters.
95 unsigned getNumParams() const { return params.size(); }
97 /// Take all of the parameters from this directive.
98 std::vector<ParameterElement *> takeParams() { return std::move(params); }
100 /// Returns true if there are optional parameters present.
101 bool hasOptionalParams() const {
102 return llvm::any_of(getParams(), paramIsOptional);
105 private:
106 /// The parameters captured by this directive.
107 std::vector<ParameterElement *> params;
110 /// This class represents a `params` directive that refers to all parameters
111 /// of an attribute or type. When used as a top-level directive, it generates
112 /// a format of the form:
114 /// (param-value (`,` param-value)*)?
116 /// When used as an argument to another directive that accepts variables,
117 /// `params` can be used in place of manually listing all parameters of an
118 /// attribute or type.
119 class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
120 public:
121 using Base::Base;
124 /// This class represents a `struct` directive that generates a struct format
125 /// of the form:
127 /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
129 class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
130 public:
131 using Base::Base;
134 } // namespace
136 //===----------------------------------------------------------------------===//
137 // Format Strings
138 //===----------------------------------------------------------------------===//
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter.
static const char *const defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";

/// Default printer for attribute or type parameters.
static const char *const defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";

/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";

/// Opening of the code that reports a parse error at the current parser
/// location. It takes no positional placeholders: the user of this template
/// appends the error message and the closing `);`.
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";

/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
/// {5}: Optional code to preload the dialect for this variable.
static const char *const variableParser = R"(
// Parse variable '{0}'{5}
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
  {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
  return {{};
}
)";
176 //===----------------------------------------------------------------------===//
177 // DefFormat
178 //===----------------------------------------------------------------------===//
180 namespace {
181 class DefFormat {
182 public:
183 DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
184 : def(def), elements(std::move(elements)) {}
186 /// Generate the attribute or type parser.
187 void genParser(MethodBody &os);
188 /// Generate the attribute or type printer.
189 void genPrinter(MethodBody &os);
191 private:
192 /// Generate the parser code for a specific format element.
193 void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
194 /// Generate the parser code for a literal.
195 void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
196 bool isOptional = false);
197 /// Generate the parser code for a variable.
198 void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
199 /// Generate the parser code for a `params` directive.
200 void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
201 /// Generate the parser code for a `struct` directive.
202 void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
203 /// Generate the parser code for a `custom` directive.
204 void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
205 bool isOptional = false);
206 /// Generate the parser code for an optional group.
207 void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
208 MethodBody &os);
210 /// Generate the printer code for a specific format element.
211 void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
212 /// Generate the printer code for a literal.
213 void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
214 /// Generate the printer code for a variable.
215 void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
216 bool skipGuard = false);
217 /// Generate a printer for comma-separated parameters.
218 void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
219 FmtContext &ctx, MethodBody &os,
220 function_ref<void(ParameterElement *)> extra);
221 /// Generate the printer code for a `params` directive.
222 void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
223 /// Generate the printer code for a `struct` directive.
224 void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
225 /// Generate the printer code for a `custom` directive.
226 void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
227 /// Generate the printer code for an optional group.
228 void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
229 MethodBody &os);
230 /// Generate a printer (or space eraser) for a whitespace element.
231 void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
232 MethodBody &os);
234 /// The ODS definition of the attribute or type whose format is being used to
235 /// generate a parser and printer.
236 const AttrOrTypeDef &def;
237 /// The list of top-level format elements returned by the assembly format
238 /// parser.
239 std::vector<FormatElement *> elements;
241 /// Flags for printing spaces.
242 bool shouldEmitSpace = false;
243 bool lastWasPunctuation = false;
245 } // namespace
247 //===----------------------------------------------------------------------===//
248 // ParserGen
249 //===----------------------------------------------------------------------===//
251 /// Generate a special-case "parser" for an attribute's self type parameter. The
252 /// self type parameter has special handling in the assembly format in that it
253 /// is derived from the optional trailing colon type after the attribute.
254 static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
255 const AttributeSelfTypeParameter &param) {
256 // "Parser" for an attribute self type parameter that checks the
257 // optionally-parsed trailing colon type.
259 // $0: The C++ storage class of the type parameter.
260 // $1: The self type parameter name.
261 const char *const selfTypeParser = R"(
262 if ($_type) {
263 if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
264 _result_$1 = reqType;
265 } else {
266 $_parser.emitError($_loc, "invalid kind of type specified");
267 return {};
269 })";
271 // If the attribute self type parameter is required, emit code that emits an
272 // error if the trailing type was not parsed.
273 const char *const selfTypeRequired = R"( else {
274 $_parser.emitError($_loc, "expected a trailing type");
275 return {};
276 })";
278 os << tgfmt(selfTypeParser, &ctx, param.getCppStorageType(), param.getName());
279 if (!param.isOptional())
280 os << tgfmt(selfTypeRequired, &ctx);
281 os << "\n";
284 void DefFormat::genParser(MethodBody &os) {
285 FmtContext ctx;
286 ctx.addSubst("_parser", "odsParser");
287 ctx.addSubst("_ctxt", "odsParser.getContext()");
288 ctx.withBuilder("odsBuilder");
289 if (isa<AttrDef>(def))
290 ctx.addSubst("_type", "odsType");
291 os.indent();
292 os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";
294 // Store the initial location of the parser.
295 ctx.addSubst("_loc", "odsLoc");
296 os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
297 "(void) $_loc;\n",
298 &ctx);
300 // Declare variables to store all of the parameters. Allocated parameters
301 // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
302 // FailureOr<T> to defer type construction for parameters that are parsed in
303 // a loop (parsers return FailureOr anyways).
304 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
305 for (const AttrOrTypeParameter &param : params) {
306 os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
307 param.getCppStorageType(), param.getName());
308 if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(&param))
309 genAttrSelfTypeParser(os, ctx, *selfTypeParam);
312 // Generate call to each parameter parser.
313 for (FormatElement *el : elements)
314 genElementParser(el, ctx, os);
316 // Emit an assert for each mandatory parameter. Triggering an assert means
317 // the generated parser is incorrect (i.e. there is a bug in this code).
318 for (const AttrOrTypeParameter &param : params) {
319 if (param.isOptional())
320 continue;
321 os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
324 // Generate call to the attribute or type builder. Use the checked getter
325 // if one was generated.
326 if (def.genVerifyDecl()) {
327 os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
328 &ctx, def.getCppClassName());
329 } else {
330 os << tgfmt("return $0::get($_parser.getContext()", &ctx,
331 def.getCppClassName());
333 for (const AttrOrTypeParameter &param : params) {
334 os << ",\n ";
335 std::string paramSelfStr;
336 llvm::raw_string_ostream selfOs(paramSelfStr);
337 if (std::optional<StringRef> defaultValue = param.getDefaultValue()) {
338 selfOs << formatv("(_result_{0}.value_or(", param.getName())
339 << tgfmt(*defaultValue, &ctx) << "))";
340 } else {
341 selfOs << formatv("(*_result_{0})", param.getName());
343 ctx.addSubst(param.getName(), selfOs.str());
344 os << param.getCppType() << "("
345 << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
346 << ")";
348 os << ");";
351 void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
352 MethodBody &os) {
353 if (auto *literal = dyn_cast<LiteralElement>(el))
354 return genLiteralParser(literal->getSpelling(), ctx, os);
355 if (auto *var = dyn_cast<ParameterElement>(el))
356 return genVariableParser(var, ctx, os);
357 if (auto *params = dyn_cast<ParamsDirective>(el))
358 return genParamsParser(params, ctx, os);
359 if (auto *strct = dyn_cast<StructDirective>(el))
360 return genStructParser(strct, ctx, os);
361 if (auto *custom = dyn_cast<CustomDirective>(el))
362 return genCustomParser(custom, ctx, os);
363 if (auto *optional = dyn_cast<OptionalElement>(el))
364 return genOptionalGroupParser(optional, ctx, os);
365 if (isa<WhitespaceElement>(el))
366 return;
368 llvm_unreachable("unknown format element");
371 void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
372 MethodBody &os, bool isOptional) {
373 os << "// Parse literal '" << value << "'\n";
374 os << tgfmt("if ($_parser.parse", &ctx);
375 if (isOptional)
376 os << "Optional";
377 if (value.front() == '_' || isalpha(value.front())) {
378 os << "Keyword(\"" << value << "\")";
379 } else {
380 os << StringSwitch<StringRef>(value)
381 .Case("->", "Arrow")
382 .Case(":", "Colon")
383 .Case(",", "Comma")
384 .Case("=", "Equal")
385 .Case("<", "Less")
386 .Case(">", "Greater")
387 .Case("{", "LBrace")
388 .Case("}", "RBrace")
389 .Case("(", "LParen")
390 .Case(")", "RParen")
391 .Case("[", "LSquare")
392 .Case("]", "RSquare")
393 .Case("?", "Question")
394 .Case("+", "Plus")
395 .Case("*", "Star")
396 .Case("...", "Ellipsis")
397 << "()";
399 if (isOptional) {
400 // Leave the `if` unclosed to guard optional groups.
401 return;
403 // Parser will emit an error
404 os << ") return {};\n";
407 void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
408 MethodBody &os) {
409 // Check for a custom parser. Use the default attribute parser otherwise.
410 const AttrOrTypeParameter &param = el->getParam();
411 auto customParser = param.getParser();
412 auto parser =
413 customParser ? *customParser : StringRef(defaultParameterParser);
415 // If the variable points to a dialect specific entity (type of attribute),
416 // we force load the dialect now before trying to parse it.
417 std::string dialectLoading;
418 if (auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
419 auto *dialectValue = defInit->getDef()->getValue("dialect");
420 if (dialectValue) {
421 if (auto *dialectInit =
422 dyn_cast<llvm::DefInit>(dialectValue->getValue())) {
423 Dialect dialect(dialectInit->getDef());
424 auto cppNamespace = dialect.getCppNamespace();
425 std::string name = dialect.getCppClassName();
426 if (name != "BuiltinDialect" || cppNamespace != "::mlir") {
427 dialectLoading = ("\nodsParser.getContext()->getOrLoadDialect<" +
428 cppNamespace + "::" + name + ">();")
429 .str();
434 os << formatv(variableParser, param.getName(),
435 tgfmt(parser, &ctx, param.getCppStorageType()),
436 tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType(),
437 dialectLoading);
440 void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
441 MethodBody &os) {
442 os << "// Parse parameter list\n";
444 // If there are optional parameters, we need to switch to `parseOptionalComma`
445 // if there are no more required parameters after a certain point.
446 bool hasOptional = el->hasOptionalParams();
447 if (hasOptional) {
448 // Wrap everything in a do-while so that we can `break`.
449 os << "do {\n";
450 os.indent();
453 ArrayRef<ParameterElement *> params = el->getParams();
454 using IteratorT = ParameterElement *const *;
455 IteratorT it = params.begin();
457 // Find the last required parameter. Commas become optional aftewards.
458 // Note: IteratorT's copy assignment is deleted.
459 ParameterElement *lastReq = nullptr;
460 for (ParameterElement *param : params)
461 if (!param->isOptional())
462 lastReq = param;
463 IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin();
465 auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
466 auto betweenFn = [&](IteratorT it) {
467 ParameterElement *el = *std::prev(it);
468 // Parse a comma if the last optional parameter had a value.
469 if (el->isOptional()) {
470 os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
471 el->getName(),
472 el->genIsPresent(ctx, "(*_result_" + el->getName() + ")"));
473 os.indent();
475 if (it <= lastReqIt) {
476 genLiteralParser(",", ctx, os);
477 } else {
478 genLiteralParser(",", ctx, os, /*isOptional=*/true);
479 os << ") break;\n";
481 if (el->isOptional())
482 os.unindent() << "}\n";
485 // llvm::interleave
486 if (it != params.end()) {
487 eachFn(*it++);
488 for (IteratorT e = params.end(); it != e; ++it) {
489 betweenFn(it);
490 eachFn(*it);
494 if (hasOptional)
495 os.unindent() << "} while(false);\n";
498 void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
499 MethodBody &os) {
500 // Loop declaration for struct parser with only required parameters.
502 // $0: Number of expected parameters.
503 const char *const loopHeader = R"(
504 for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
507 // Loop body start for struct parser.
508 const char *const loopStart = R"(
509 ::llvm::StringRef _paramKey;
510 if ($_parser.parseKeyword(&_paramKey)) {
511 $_parser.emitError($_parser.getCurrentLocation(),
512 "expected a parameter name in struct");
513 return {};
515 if (!_loop_body(_paramKey)) return {};
518 // Struct parser loop end. Check for duplicate or unknown struct parameters.
520 // {0}: Code template for printing an error.
521 const char *const loopEnd = R"({{
522 {0}"duplicate or unknown struct parameter name: ") << _paramKey;
523 return {{};
527 // Struct parser loop terminator. Parse a comma except on the last element.
529 // {0}: Number of elements in the struct.
530 const char *const loopTerminator = R"(
531 if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
532 return {{};
536 // Check that a mandatory parameter was parse.
538 // {0}: Name of the parameter.
539 const char *const checkParam = R"(
540 if (!_seen_{0}) {
541 {1}"struct is missing required parameter: ") << "{0}";
542 return {{};
546 // First iteration of the loop parsing an optional struct.
547 const char *const optionalStructFirst = R"(
548 ::llvm::StringRef _paramKey;
549 if (!$_parser.parseOptionalKeyword(&_paramKey)) {
550 if (!_loop_body(_paramKey)) return {};
551 while (!$_parser.parseOptionalComma()) {
554 os << "// Parse parameter struct\n";
556 // Declare a "seen" variable for each key.
557 for (ParameterElement *param : el->getParams())
558 os << formatv("bool _seen_{0} = false;\n", param->getName());
560 // Generate the body of the parsing loop inside a lambda.
561 os << "{\n";
562 os.indent()
563 << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
564 genLiteralParser("=", ctx, os.indent());
565 for (ParameterElement *param : el->getParams()) {
566 os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
567 " _seen_{0} = true;\n",
568 param->getName());
569 genVariableParser(param, ctx, os.indent());
570 os.unindent() << "} else ";
571 // Print the check for duplicate or unknown parameter.
573 os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx)));
574 os << "return true;\n";
575 os.unindent() << "};\n";
577 // Generate the parsing loop. If optional parameters are present, then the
578 // parse loop is guarded by commas.
579 unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
580 if (numOptional) {
581 // If the struct itself is optional, pull out the first iteration.
582 if (numOptional == el->getNumParams()) {
583 os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
584 os.indent();
585 } else {
586 os << "do {\n";
588 } else {
589 os.getStream().printReindented(
590 tgfmt(loopHeader, &ctx, el->getNumParams()).str());
592 os.indent();
593 os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
594 os.unindent();
596 // Print the loop terminator. For optional parameters, we have to check that
597 // all mandatory parameters have been parsed.
598 // The whole struct is optional if all its parameters are optional.
599 if (numOptional) {
600 if (numOptional == el->getNumParams()) {
601 os << "}\n";
602 os.unindent() << "}\n";
603 } else {
604 os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
605 for (ParameterElement *param : el->getParams()) {
606 if (param->isOptional())
607 continue;
608 os.getStream().printReindented(
609 strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx)));
612 } else {
613 // Because the loop loops N times and each non-failing iteration sets 1 of
614 // N flags, successfully exiting the loop means that all parameters have
615 // been seen. `parseOptionalComma` would cause issues with any formats that
616 // use "struct(...) `,`" beacuse structs aren't sounded by braces.
617 os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
619 os.unindent() << "}\n";
622 void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
623 MethodBody &os, bool isOptional) {
624 os << "{\n";
625 os.indent();
627 // Bound variables are passed directly to the parser as `FailureOr<T> &`.
628 // Referenced variables are passed as `T`. The custom parser fails if it
629 // returns failure or if any of the required parameters failed.
630 os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
631 os << "(void)odsCustomLoc;\n";
632 os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
633 os.indent();
634 for (FormatElement *arg : el->getArguments()) {
635 os << ",\n";
636 if (auto *param = dyn_cast<ParameterElement>(arg))
637 os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
638 << ")";
639 else if (auto *ref = dyn_cast<RefDirective>(arg))
640 os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
641 else
642 os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
644 os.unindent() << ");\n";
645 if (isOptional) {
646 os << "if (!odsCustomResult.has_value()) return {};\n";
647 os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
648 } else {
649 os << "if (::mlir::failed(odsCustomResult)) return {};\n";
651 for (FormatElement *arg : el->getArguments()) {
652 if (auto *param = dyn_cast<ParameterElement>(arg)) {
653 if (param->isOptional())
654 continue;
655 os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
656 os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
657 << "\"custom parser failed to parse parameter '"
658 << param->getName() << "'\");\n";
659 os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n";
660 os.unindent() << "}\n";
664 os.unindent() << "}\n";
667 void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
668 MethodBody &os) {
669 ArrayRef<FormatElement *> thenElements =
670 el->getThenElements(/*parseable=*/true);
672 FormatElement *first = thenElements.front();
673 const auto guardOn = [&](auto params) {
674 os << "if (!(";
675 llvm::interleave(
676 params, os,
677 [&](ParameterElement *el) {
678 os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
679 el->getName());
681 " || ");
682 os << ")) {\n";
684 if (auto *literal = dyn_cast<LiteralElement>(first)) {
685 genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true);
686 os << ") {\n";
687 } else if (auto *param = dyn_cast<ParameterElement>(first)) {
688 genVariableParser(param, ctx, os);
689 guardOn(llvm::ArrayRef(param));
690 } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
691 genParamsParser(params, ctx, os);
692 guardOn(params->getParams());
693 } else if (auto *custom = dyn_cast<CustomDirective>(first)) {
694 os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
695 os.indent();
696 genCustomParser(custom, ctx, os, /*isOptional=*/true);
697 os << "return ::mlir::success();\n";
698 os.unindent();
699 os << "}(); result.has_value() && ::mlir::failed(*result)) {\n";
700 os.indent();
701 os << "return {};\n";
702 os.unindent();
703 os << "} else if (result.has_value()) {\n";
704 } else {
705 auto *strct = cast<StructDirective>(first);
706 genStructParser(strct, ctx, os);
707 guardOn(params->getParams());
709 os.indent();
711 // Generate the parsers for the rest of the thenElements.
712 for (FormatElement *element : el->getElseElements(/*parseable=*/true))
713 genElementParser(element, ctx, os);
714 os.unindent() << "} else {\n";
715 os.indent();
716 for (FormatElement *element : thenElements.drop_front())
717 genElementParser(element, ctx, os);
718 os.unindent() << "}\n";
721 //===----------------------------------------------------------------------===//
722 // PrinterGen
723 //===----------------------------------------------------------------------===//
725 void DefFormat::genPrinter(MethodBody &os) {
726 FmtContext ctx;
727 ctx.addSubst("_printer", "odsPrinter");
728 ctx.addSubst("_ctxt", "getContext()");
729 ctx.withBuilder("odsBuilder");
730 os.indent();
731 os << "::mlir::Builder odsBuilder(getContext());\n";
733 // Generate printers.
734 shouldEmitSpace = true;
735 lastWasPunctuation = false;
736 for (FormatElement *el : elements)
737 genElementPrinter(el, ctx, os);
740 void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
741 MethodBody &os) {
742 if (auto *literal = dyn_cast<LiteralElement>(el))
743 return genLiteralPrinter(literal->getSpelling(), ctx, os);
744 if (auto *params = dyn_cast<ParamsDirective>(el))
745 return genParamsPrinter(params, ctx, os);
746 if (auto *strct = dyn_cast<StructDirective>(el))
747 return genStructPrinter(strct, ctx, os);
748 if (auto *custom = dyn_cast<CustomDirective>(el))
749 return genCustomPrinter(custom, ctx, os);
750 if (auto *var = dyn_cast<ParameterElement>(el))
751 return genVariablePrinter(var, ctx, os);
752 if (auto *optional = dyn_cast<OptionalElement>(el))
753 return genOptionalGroupPrinter(optional, ctx, os);
754 if (auto *whitespace = dyn_cast<WhitespaceElement>(el))
755 return genWhitespacePrinter(whitespace, ctx, os);
757 llvm::PrintFatalError("unsupported format element");
760 void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
761 MethodBody &os) {
762 // Don't insert a space before certain punctuation.
763 bool needSpace =
764 shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
765 os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
766 value);
768 // Update the flags.
769 shouldEmitSpace =
770 value.size() != 1 || !StringRef("<({[").contains(value.front());
771 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
774 void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
775 MethodBody &os, bool skipGuard) {
776 const AttrOrTypeParameter &param = el->getParam();
777 ctx.withSelf(param.getAccessorName() + "()");
779 // Guard the printer on the presence of optional parameters and that they
780 // aren't equal to their default values (if they have one).
781 if (el->isOptional() && !skipGuard) {
782 el->genPrintGuard(ctx, os << "if (") << ") {\n";
783 os.indent();
786 // Insert a space before the next parameter, if necessary.
787 if (shouldEmitSpace || !lastWasPunctuation)
788 os << tgfmt("$_printer << ' ';\n", &ctx);
789 shouldEmitSpace = true;
790 lastWasPunctuation = false;
792 if (el->shouldBeQualified())
793 os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n";
794 else if (auto printer = param.getPrinter())
795 os << tgfmt(*printer, &ctx) << ";\n";
796 else
797 os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
799 if (el->isOptional() && !skipGuard)
800 os.unindent() << "}\n";
803 /// Generate code to guard printing on the presence of any optional parameters.
804 template <typename ParameterRange>
805 static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
806 bool inverted = false) {
807 os << "if (";
808 if (inverted)
809 os << "!(";
810 llvm::interleave(
811 params, os,
812 [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
813 if (inverted)
814 os << ")";
815 os << ") {\n";
816 os.indent();
819 void DefFormat::genCommaSeparatedPrinter(
820 ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
821 function_ref<void(ParameterElement *)> extra) {
822 // Emit a space if necessary, but only if the struct is present.
823 if (shouldEmitSpace || !lastWasPunctuation) {
824 bool allOptional = llvm::all_of(params, paramIsOptional);
825 if (allOptional)
826 guardOnAny(ctx, os, params);
827 os << tgfmt("$_printer << ' ';\n", &ctx);
828 if (allOptional)
829 os.unindent() << "}\n";
832 // The first printed element does not need to emit a comma.
833 os << "{\n";
834 os.indent() << "bool _firstPrinted = true;\n";
835 for (ParameterElement *param : params) {
836 if (param->isOptional()) {
837 param->genPrintGuard(ctx, os << "if (") << ") {\n";
838 os.indent();
840 os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
841 os << "_firstPrinted = false;\n";
842 extra(param);
843 shouldEmitSpace = false;
844 lastWasPunctuation = true;
845 genVariablePrinter(param, ctx, os);
846 if (param->isOptional())
847 os.unindent() << "}\n";
849 os.unindent() << "}\n";
852 void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
853 MethodBody &os) {
854 genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
855 [&](ParameterElement *param) {});
858 void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
859 MethodBody &os) {
860 genCommaSeparatedPrinter(
861 llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
862 os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
866 void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
867 MethodBody &os) {
868 // Insert a space before the custom directive, if necessary.
869 if (shouldEmitSpace || !lastWasPunctuation)
870 os << tgfmt("$_printer << ' ';\n", &ctx);
871 shouldEmitSpace = true;
872 lastWasPunctuation = false;
874 os << tgfmt("print$0($_printer", &ctx, el->getName());
875 os.indent();
876 for (FormatElement *arg : el->getArguments()) {
877 os << ",\n";
878 if (auto *param = dyn_cast<ParameterElement>(arg)) {
879 os << param->getParam().getAccessorName() << "()";
880 } else if (auto *ref = dyn_cast<RefDirective>(arg)) {
881 os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName()
882 << "()";
883 } else {
884 os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
887 os.unindent() << ");\n";
/// Print an optional group. Emits an `if (...) {` guard on the anchor's
/// presence (negated when the group is inverted), prints the "then" elements
/// inside the guard, and prints the "else" elements in the `else` branch.
void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                        MethodBody &os) {
  // Build the guard from the parameters that the anchor captures. Each
  // `guardOnAny` call opens an `if (...) {` and indents the body.
  FormatElement *anchor = el->getAnchor();
  if (auto *param = dyn_cast<ParameterElement>(anchor)) {
    guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
  } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
    guardOnAny(ctx, os, params->getParams(), el->isInverted());
  } else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
    guardOnAny(ctx, os, strct->getParams(), el->isInverted());
  } else {
    // For a custom directive anchor, only its direct parameter arguments
    // (non-parameter arguments are filtered out) take part in the guard.
    auto *custom = cast<CustomDirective>(anchor);
    guardOnAny(ctx, os,
               llvm::make_filter_range(
                   llvm::map_range(custom->getArguments(),
                                   [](FormatElement *el) {
                                     return dyn_cast<ParameterElement>(el);
                                   }),
                   [](ParameterElement *param) { return !!param; }),
               el->isInverted());
  }
  // Generate the printer for the contained elements.
  // The spacing state is saved and restored around the "then" elements, so
  // the "else" branch below starts from the state that held before the group.
  {
    llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace);
    llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation);
    for (FormatElement *element : el->getThenElements())
      genElementPrinter(element, ctx, os);
  }
  // Close the guard opened by `guardOnAny` and emit the else branch. The
  // "else" elements are not wrapped in SaveAndRestore, so their state
  // changes persist after the group.
  os.unindent() << "} else {\n";
  os.indent();
  for (FormatElement *element : el->getElseElements())
    genElementPrinter(element, ctx, os);
  os.unindent() << "}\n";
}
924 void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
925 MethodBody &os) {
926 if (el->getValue() == "\\n") {
927 // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
928 // the printer.
929 os << tgfmt("$_printer << '\\n';\n", &ctx);
930 } else if (!el->getValue().empty()) {
931 os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue());
932 } else {
933 lastWasPunctuation = true;
935 shouldEmitSpace = false;
938 //===----------------------------------------------------------------------===//
939 // DefFormatParser
940 //===----------------------------------------------------------------------===//
942 namespace {
943 class DefFormatParser : public FormatParser {
944 public:
945 DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
946 : FormatParser(mgr, def.getLoc()[0]), def(def),
947 seenParams(def.getNumParameters()) {}
949 /// Parse the attribute or type format and create the format elements.
950 FailureOr<DefFormat> parse();
952 protected:
953 /// Verify the parsed elements.
954 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
955 /// Verify the elements of a custom directive.
956 LogicalResult
957 verifyCustomDirectiveArguments(SMLoc loc,
958 ArrayRef<FormatElement *> arguments) override;
959 /// Verify the elements of an optional group.
960 LogicalResult verifyOptionalGroupElements(SMLoc loc,
961 ArrayRef<FormatElement *> elements,
962 FormatElement *anchor) override;
964 LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
966 /// Parse an attribute or type variable.
967 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
968 Context ctx) override;
969 /// Parse an attribute or type format directive.
970 FailureOr<FormatElement *>
971 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
973 private:
974 /// Parse a `params` directive.
975 FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
976 /// Parse a `struct` directive.
977 FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
979 /// Attribute or type tablegen def.
980 const AttrOrTypeDef &def;
982 /// Seen attribute or type parameters.
983 BitVector seenParams;
985 } // namespace
987 LogicalResult DefFormatParser::verify(SMLoc loc,
988 ArrayRef<FormatElement *> elements) {
989 // Check that all parameters are referenced in the format.
990 for (auto [index, param] : llvm::enumerate(def.getParameters())) {
991 if (param.isOptional())
992 continue;
993 if (!seenParams.test(index)) {
994 if (isa<AttributeSelfTypeParameter>(param))
995 continue;
996 return emitError(loc, "format is missing reference to parameter: " +
997 param.getName());
999 if (isa<AttributeSelfTypeParameter>(param)) {
1000 return emitError(loc,
1001 "unexpected self type parameter in assembly format");
1004 if (elements.empty())
1005 return success();
1006 // A `struct` directive that contains optional parameters cannot be followed
1007 // by a comma literal, which is ambiguous.
1008 for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) {
1009 auto *structEl = dyn_cast<StructDirective>(std::get<0>(it));
1010 auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
1011 if (!structEl || !literalEl)
1012 continue;
1013 if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
1014 return emitError(loc, "`struct` directive with optional parameters "
1015 "cannot be followed by a comma literal");
1018 return success();
1021 LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
1022 SMLoc loc, ArrayRef<FormatElement *> arguments) {
1023 // Arguments are fully verified by the parser context.
1024 return success();
1027 LogicalResult
1028 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
1029 ArrayRef<FormatElement *> elements,
1030 FormatElement *anchor) {
1031 // `params` and `struct` directives are allowed only if all the contained
1032 // parameters are optional.
1033 for (FormatElement *el : elements) {
1034 if (auto *param = dyn_cast<ParameterElement>(el)) {
1035 if (!param->isOptional()) {
1036 return emitError(loc,
1037 "parameters in an optional group must be optional");
1039 } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
1040 if (llvm::any_of(params->getParams(), paramNotOptional)) {
1041 return emitError(loc, "`params` directive allowed in optional group "
1042 "only if all parameters are optional");
1044 } else if (auto *strct = dyn_cast<StructDirective>(el)) {
1045 if (llvm::any_of(strct->getParams(), paramNotOptional)) {
1046 return emitError(loc, "`struct` is only allowed in an optional group "
1047 "if all captured parameters are optional");
1049 } else if (auto *custom = dyn_cast<CustomDirective>(el)) {
1050 for (FormatElement *el : custom->getArguments()) {
1051 // If the custom argument is a variable, then it must be optional.
1052 if (auto *param = dyn_cast<ParameterElement>(el))
1053 if (!param->isOptional())
1054 return emitError(loc,
1055 "`custom` is only allowed in an optional group if "
1056 "all captured parameters are optional");
1060 // The anchor must be a parameter or one of the aforementioned directives.
1061 if (anchor) {
1062 if (!isa<ParameterElement, ParamsDirective, StructDirective,
1063 CustomDirective>(anchor)) {
1064 return emitError(
1065 loc, "optional group anchor must be a parameter or directive");
1067 // If the anchor is a custom directive, make sure at least one of its
1068 // arguments is a bound parameter.
1069 if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
1070 const auto *bound =
1071 llvm::find_if(custom->getArguments(), [](FormatElement *el) {
1072 return isa<ParameterElement>(el);
1074 if (bound == custom->getArguments().end())
1075 return emitError(loc, "`custom` directive with no bound parameters "
1076 "cannot be used as optional group anchor");
1079 return success();
1082 LogicalResult DefFormatParser::markQualified(SMLoc loc,
1083 FormatElement *element) {
1084 if (!isa<ParameterElement>(element))
1085 return emitError(loc, "`qualified` argument list expected a variable");
1086 cast<ParameterElement>(element)->setShouldBeQualified();
1087 return success();
1090 FailureOr<DefFormat> DefFormatParser::parse() {
1091 FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
1092 if (failed(elements))
1093 return failure();
1094 return DefFormat(def, std::move(*elements));
1097 FailureOr<FormatElement *>
1098 DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
1099 // Lookup the parameter.
1100 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
1101 auto *it = llvm::find_if(
1102 params, [&](auto &param) { return param.getName() == name; });
1104 // Check that the parameter reference is valid.
1105 if (it == params.end()) {
1106 return emitError(loc,
1107 def.getName() + " has no parameter named '" + name + "'");
1109 auto idx = std::distance(params.begin(), it);
1111 if (ctx != RefDirectiveContext) {
1112 // Check that the variable has not already been bound.
1113 if (seenParams.test(idx))
1114 return emitError(loc, "duplicate parameter '" + name + "'");
1115 seenParams.set(idx);
1117 // Otherwise, to be referenced, a variable must have been bound.
1118 } else if (!seenParams.test(idx) && !isa<AttributeSelfTypeParameter>(*it)) {
1119 return emitError(loc, "parameter '" + name +
1120 "' must be bound before it is referenced");
1123 return create<ParameterElement>(*it);
1126 FailureOr<FormatElement *>
1127 DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
1128 Context ctx) {
1130 switch (kind) {
1131 case FormatToken::kw_qualified:
1132 return parseQualifiedDirective(loc, ctx);
1133 case FormatToken::kw_params:
1134 return parseParamsDirective(loc, ctx);
1135 case FormatToken::kw_struct:
1136 return parseStructDirective(loc, ctx);
1137 default:
1138 return emitError(loc, "unsupported directive kind");
1142 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
1143 Context ctx) {
1144 // It doesn't make sense to allow references to all parameters in a custom
1145 // directive because parameters are the only things that can be bound.
1146 if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
1147 return emitError(loc, "`params` can only be used at the top-level context "
1148 "or within a `struct` directive");
1151 // Collect all of the attribute's or type's parameters and ensure that none of
1152 // the parameters have already been captured.
1153 std::vector<ParameterElement *> vars;
1154 for (const auto &it : llvm::enumerate(def.getParameters())) {
1155 if (seenParams.test(it.index())) {
1156 return emitError(loc, "`params` captures duplicate parameter: " +
1157 it.value().getName());
1159 // Self-type parameters are handled separately from the rest of the
1160 // parameters.
1161 if (isa<AttributeSelfTypeParameter>(it.value()))
1162 continue;
1163 seenParams.set(it.index());
1164 vars.push_back(create<ParameterElement>(it.value()));
1166 return create<ParamsDirective>(std::move(vars));
1169 FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
1170 Context ctx) {
1171 if (ctx != TopLevelContext)
1172 return emitError(loc, "`struct` can only be used at the top-level context");
1174 if (failed(parseToken(FormatToken::l_paren,
1175 "expected '(' before `struct` argument list")))
1176 return failure();
1178 // Parse variables captured by `struct`.
1179 std::vector<ParameterElement *> vars;
1181 // Parse first captured parameter or a `params` directive.
1182 FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
1183 if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
1184 return emitError(loc,
1185 "`struct` argument list expected a variable or directive");
1187 if (isa<VariableElement>(*var)) {
1188 // Parse any other parameters.
1189 vars.push_back(cast<ParameterElement>(*var));
1190 while (peekToken().is(FormatToken::comma)) {
1191 consumeToken();
1192 var = parseElement(StructDirectiveContext);
1193 if (failed(var) || !isa<VariableElement>(*var))
1194 return emitError(loc, "expected a variable in `struct` argument list");
1195 vars.push_back(cast<ParameterElement>(*var));
1197 } else {
1198 // `struct(params)` captures all parameters in the attribute or type.
1199 vars = cast<ParamsDirective>(*var)->takeParams();
1202 if (failed(parseToken(FormatToken::r_paren,
1203 "expected ')' at the end of an argument list")))
1204 return failure();
1206 return create<StructDirective>(std::move(vars));
1209 //===----------------------------------------------------------------------===//
1210 // Interface
1211 //===----------------------------------------------------------------------===//
1213 void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
1214 MethodBody &parser,
1215 MethodBody &printer) {
1216 llvm::SourceMgr mgr;
1217 mgr.AddNewSourceBuffer(
1218 llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
1220 // Parse the custom assembly format>
1221 DefFormatParser fmtParser(mgr, def);
1222 FailureOr<DefFormat> format = fmtParser.parse();
1223 if (failed(format)) {
1224 if (formatErrorIsFatal)
1225 PrintFatalError(def.getLoc(), "failed to parse assembly format");
1226 return;
1229 // Generate the parser and printer.
1230 format->genParser(parser);
1231 format->genPrinter(printer);