[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / AttrOrTypeFormatGen.cpp
blobf8e0c83da3c8a66f2a580fae7a74cfdabaf21fe0
1 //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "AttrOrTypeFormatGen.h"
10 #include "FormatGen.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/AttrOrTypeDef.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/StringSwitch.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/MemoryBuffer.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include "llvm/TableGen/Error.h"
24 #include "llvm/TableGen/TableGenBackend.h"
26 using namespace mlir;
27 using namespace mlir::tblgen;
29 using llvm::formatv;
31 //===----------------------------------------------------------------------===//
32 // Element
33 //===----------------------------------------------------------------------===//
35 namespace {
36 /// This class represents an instance of a variable element. A variable refers
37 /// to an attribute or type parameter.
38 class ParameterElement
39 : public VariableElementBase<VariableElement::Parameter> {
40 public:
41 ParameterElement(AttrOrTypeParameter param) : param(param) {}
43 /// Get the parameter in the element.
44 const AttrOrTypeParameter &getParam() const { return param; }
46 /// Indicate if this variable is printed "qualified" (that is it is
47 /// prefixed with the `#dialect.mnemonic`).
48 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
49 void setShouldBeQualified(bool qualified = true) {
50 shouldBeQualifiedFlag = qualified;
53 /// Returns true if the element contains an optional parameter.
54 bool isOptional() const { return param.isOptional(); }
56 /// Returns the name of the parameter.
57 StringRef getName() const { return param.getName(); }
59 /// Return the code to check whether the parameter is present.
60 auto genIsPresent(FmtContext &ctx, const Twine &self) const {
61 assert(isOptional() && "cannot guard on a mandatory parameter");
62 std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
63 ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
64 return tgfmt(getParam().getComparator(), &ctx);
67 /// Generate the code to check whether the parameter should be printed.
68 MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
69 assert(isOptional() && "cannot guard on a mandatory parameter");
70 std::string self = param.getAccessorName() + "()";
71 return os << "!(" << genIsPresent(ctx, self) << ")";
74 private:
75 bool shouldBeQualifiedFlag = false;
76 AttrOrTypeParameter param;
79 /// Shorthand functions that can be used with ranged-based conditions.
80 static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
81 static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
83 /// Base class for a directive that contains references to multiple variables.
84 template <DirectiveElement::Kind DirectiveKind>
85 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
86 public:
87 using Base = ParamsDirectiveBase<DirectiveKind>;
89 ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
90 : params(std::move(params)) {}
92 /// Get the parameters contained in this directive.
93 ArrayRef<ParameterElement *> getParams() const { return params; }
95 /// Get the number of parameters.
96 unsigned getNumParams() const { return params.size(); }
98 /// Take all of the parameters from this directive.
99 std::vector<ParameterElement *> takeParams() { return std::move(params); }
101 /// Returns true if there are optional parameters present.
102 bool hasOptionalParams() const {
103 return llvm::any_of(getParams(), paramIsOptional);
106 private:
107 /// The parameters captured by this directive.
108 std::vector<ParameterElement *> params;
111 /// This class represents a `params` directive that refers to all parameters
112 /// of an attribute or type. When used as a top-level directive, it generates
113 /// a format of the form:
115 /// (param-value (`,` param-value)*)?
117 /// When used as an argument to another directive that accepts variables,
118 /// `params` can be used in place of manually listing all parameters of an
119 /// attribute or type.
120 class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
121 public:
122 using Base::Base;
125 /// This class represents a `struct` directive that generates a struct format
126 /// of the form:
128 /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
130 class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
131 public:
132 using Base::Base;
135 } // namespace
137 //===----------------------------------------------------------------------===//
138 // Format Strings
139 //===----------------------------------------------------------------------===//
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter.
static const char *const defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";

/// Default printer for attribute or type parameters.
static const char *const defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";

/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";

/// Prefix of the code that emits a parse error at the current parser
/// location. Users append the error message and the closing parenthesis.
/// The template has no positional placeholders; only `$_parser` is
/// substituted.
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";

/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
static const char *const variableParser = R"(
// Parse variable '{0}'
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
  {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
  return {{};
}
)";
176 //===----------------------------------------------------------------------===//
177 // DefFormat
178 //===----------------------------------------------------------------------===//
180 namespace {
181 class DefFormat {
182 public:
183 DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
184 : def(def), elements(std::move(elements)) {}
186 /// Generate the attribute or type parser.
187 void genParser(MethodBody &os);
188 /// Generate the attribute or type printer.
189 void genPrinter(MethodBody &os);
191 private:
192 /// Generate the parser code for a specific format element.
193 void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
194 /// Generate the parser code for a literal.
195 void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
196 bool isOptional = false);
197 /// Generate the parser code for a variable.
198 void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
199 /// Generate the parser code for a `params` directive.
200 void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
201 /// Generate the parser code for a `struct` directive.
202 void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
203 /// Generate the parser code for a `custom` directive.
204 void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
205 bool isOptional = false);
206 /// Generate the parser code for an optional group.
207 void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
208 MethodBody &os);
210 /// Generate the printer code for a specific format element.
211 void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
212 /// Generate the printer code for a literal.
213 void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
214 /// Generate the printer code for a variable.
215 void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
216 bool skipGuard = false);
217 /// Generate a printer for comma-separated parameters.
218 void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
219 FmtContext &ctx, MethodBody &os,
220 function_ref<void(ParameterElement *)> extra);
221 /// Generate the printer code for a `params` directive.
222 void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
223 /// Generate the printer code for a `struct` directive.
224 void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
225 /// Generate the printer code for a `custom` directive.
226 void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
227 /// Generate the printer code for an optional group.
228 void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
229 MethodBody &os);
230 /// Generate a printer (or space eraser) for a whitespace element.
231 void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
232 MethodBody &os);
234 /// The ODS definition of the attribute or type whose format is being used to
235 /// generate a parser and printer.
236 const AttrOrTypeDef &def;
237 /// The list of top-level format elements returned by the assembly format
238 /// parser.
239 std::vector<FormatElement *> elements;
241 /// Flags for printing spaces.
242 bool shouldEmitSpace = false;
243 bool lastWasPunctuation = false;
245 } // namespace
247 //===----------------------------------------------------------------------===//
248 // ParserGen
249 //===----------------------------------------------------------------------===//
251 /// Generate a special-case "parser" for an attribute's self type parameter. The
252 /// self type parameter has special handling in the assembly format in that it
253 /// is derived from the optional trailing colon type after the attribute.
254 static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
255 const AttributeSelfTypeParameter &param) {
256 // "Parser" for an attribute self type parameter that checks the
257 // optionally-parsed trailing colon type.
259 // $0: The C++ storage class of the type parameter.
260 // $1: The self type parameter name.
261 const char *const selfTypeParser = R"(
262 if ($_type) {
263 if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
264 _result_$1 = reqType;
265 } else {
266 $_parser.emitError($_loc, "invalid kind of type specified");
267 return {};
269 })";
271 // If the attribute self type parameter is required, emit code that emits an
272 // error if the trailing type was not parsed.
273 const char *const selfTypeRequired = R"( else {
274 $_parser.emitError($_loc, "expected a trailing type");
275 return {};
276 })";
278 os << tgfmt(selfTypeParser, &ctx, param.getCppStorageType(), param.getName());
279 if (!param.isOptional())
280 os << tgfmt(selfTypeRequired, &ctx);
281 os << "\n";
284 void DefFormat::genParser(MethodBody &os) {
285 FmtContext ctx;
286 ctx.addSubst("_parser", "odsParser");
287 ctx.addSubst("_ctxt", "odsParser.getContext()");
288 ctx.withBuilder("odsBuilder");
289 if (isa<AttrDef>(def))
290 ctx.addSubst("_type", "odsType");
291 os.indent();
292 os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";
294 // Store the initial location of the parser.
295 ctx.addSubst("_loc", "odsLoc");
296 os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
297 "(void) $_loc;\n",
298 &ctx);
300 // Declare variables to store all of the parameters. Allocated parameters
301 // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
302 // FailureOr<T> to defer type construction for parameters that are parsed in
303 // a loop (parsers return FailureOr anyways).
304 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
305 for (const AttrOrTypeParameter &param : params) {
306 os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
307 param.getCppStorageType(), param.getName());
308 if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(&param))
309 genAttrSelfTypeParser(os, ctx, *selfTypeParam);
312 // Generate call to each parameter parser.
313 for (FormatElement *el : elements)
314 genElementParser(el, ctx, os);
316 // Emit an assert for each mandatory parameter. Triggering an assert means
317 // the generated parser is incorrect (i.e. there is a bug in this code).
318 for (const AttrOrTypeParameter &param : params) {
319 if (param.isOptional())
320 continue;
321 os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
324 // Generate call to the attribute or type builder. Use the checked getter
325 // if one was generated.
326 if (def.genVerifyDecl()) {
327 os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
328 &ctx, def.getCppClassName());
329 } else {
330 os << tgfmt("return $0::get($_parser.getContext()", &ctx,
331 def.getCppClassName());
333 for (const AttrOrTypeParameter &param : params) {
334 os << ",\n ";
335 std::string paramSelfStr;
336 llvm::raw_string_ostream selfOs(paramSelfStr);
337 if (std::optional<StringRef> defaultValue = param.getDefaultValue()) {
338 selfOs << formatv("(_result_{0}.value_or(", param.getName())
339 << tgfmt(*defaultValue, &ctx) << "))";
340 } else {
341 selfOs << formatv("(*_result_{0})", param.getName());
343 ctx.addSubst(param.getName(), selfOs.str());
344 os << param.getCppType() << "("
345 << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
346 << ")";
348 os << ");";
351 void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
352 MethodBody &os) {
353 if (auto *literal = dyn_cast<LiteralElement>(el))
354 return genLiteralParser(literal->getSpelling(), ctx, os);
355 if (auto *var = dyn_cast<ParameterElement>(el))
356 return genVariableParser(var, ctx, os);
357 if (auto *params = dyn_cast<ParamsDirective>(el))
358 return genParamsParser(params, ctx, os);
359 if (auto *strct = dyn_cast<StructDirective>(el))
360 return genStructParser(strct, ctx, os);
361 if (auto *custom = dyn_cast<CustomDirective>(el))
362 return genCustomParser(custom, ctx, os);
363 if (auto *optional = dyn_cast<OptionalElement>(el))
364 return genOptionalGroupParser(optional, ctx, os);
365 if (isa<WhitespaceElement>(el))
366 return;
368 llvm_unreachable("unknown format element");
371 void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
372 MethodBody &os, bool isOptional) {
373 os << "// Parse literal '" << value << "'\n";
374 os << tgfmt("if ($_parser.parse", &ctx);
375 if (isOptional)
376 os << "Optional";
377 if (value.front() == '_' || isalpha(value.front())) {
378 os << "Keyword(\"" << value << "\")";
379 } else {
380 os << StringSwitch<StringRef>(value)
381 .Case("->", "Arrow")
382 .Case(":", "Colon")
383 .Case(",", "Comma")
384 .Case("=", "Equal")
385 .Case("<", "Less")
386 .Case(">", "Greater")
387 .Case("{", "LBrace")
388 .Case("}", "RBrace")
389 .Case("(", "LParen")
390 .Case(")", "RParen")
391 .Case("[", "LSquare")
392 .Case("]", "RSquare")
393 .Case("?", "Question")
394 .Case("+", "Plus")
395 .Case("*", "Star")
396 .Case("...", "Ellipsis")
397 << "()";
399 if (isOptional) {
400 // Leave the `if` unclosed to guard optional groups.
401 return;
403 // Parser will emit an error
404 os << ") return {};\n";
407 void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
408 MethodBody &os) {
409 // Check for a custom parser. Use the default attribute parser otherwise.
410 const AttrOrTypeParameter &param = el->getParam();
411 auto customParser = param.getParser();
412 auto parser =
413 customParser ? *customParser : StringRef(defaultParameterParser);
414 os << formatv(variableParser, param.getName(),
415 tgfmt(parser, &ctx, param.getCppStorageType()),
416 tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
419 void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
420 MethodBody &os) {
421 os << "// Parse parameter list\n";
423 // If there are optional parameters, we need to switch to `parseOptionalComma`
424 // if there are no more required parameters after a certain point.
425 bool hasOptional = el->hasOptionalParams();
426 if (hasOptional) {
427 // Wrap everything in a do-while so that we can `break`.
428 os << "do {\n";
429 os.indent();
432 ArrayRef<ParameterElement *> params = el->getParams();
433 using IteratorT = ParameterElement *const *;
434 IteratorT it = params.begin();
436 // Find the last required parameter. Commas become optional aftewards.
437 // Note: IteratorT's copy assignment is deleted.
438 ParameterElement *lastReq = nullptr;
439 for (ParameterElement *param : params)
440 if (!param->isOptional())
441 lastReq = param;
442 IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin();
444 auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
445 auto betweenFn = [&](IteratorT it) {
446 ParameterElement *el = *std::prev(it);
447 // Parse a comma if the last optional parameter had a value.
448 if (el->isOptional()) {
449 os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
450 el->getName(),
451 el->genIsPresent(ctx, "(*_result_" + el->getName() + ")"));
452 os.indent();
454 if (it <= lastReqIt) {
455 genLiteralParser(",", ctx, os);
456 } else {
457 genLiteralParser(",", ctx, os, /*isOptional=*/true);
458 os << ") break;\n";
460 if (el->isOptional())
461 os.unindent() << "}\n";
464 // llvm::interleave
465 if (it != params.end()) {
466 eachFn(*it++);
467 for (IteratorT e = params.end(); it != e; ++it) {
468 betweenFn(it);
469 eachFn(*it);
473 if (hasOptional)
474 os.unindent() << "} while(false);\n";
477 void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
478 MethodBody &os) {
479 // Loop declaration for struct parser with only required parameters.
481 // $0: Number of expected parameters.
482 const char *const loopHeader = R"(
483 for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
486 // Loop body start for struct parser.
487 const char *const loopStart = R"(
488 ::llvm::StringRef _paramKey;
489 if ($_parser.parseKeyword(&_paramKey)) {
490 $_parser.emitError($_parser.getCurrentLocation(),
491 "expected a parameter name in struct");
492 return {};
494 if (!_loop_body(_paramKey)) return {};
497 // Struct parser loop end. Check for duplicate or unknown struct parameters.
499 // {0}: Code template for printing an error.
500 const char *const loopEnd = R"({{
501 {0}"duplicate or unknown struct parameter name: ") << _paramKey;
502 return {{};
506 // Struct parser loop terminator. Parse a comma except on the last element.
508 // {0}: Number of elements in the struct.
509 const char *const loopTerminator = R"(
510 if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
511 return {{};
515 // Check that a mandatory parameter was parse.
517 // {0}: Name of the parameter.
518 const char *const checkParam = R"(
519 if (!_seen_{0}) {
520 {1}"struct is missing required parameter: ") << "{0}";
521 return {{};
525 // First iteration of the loop parsing an optional struct.
526 const char *const optionalStructFirst = R"(
527 ::llvm::StringRef _paramKey;
528 if (!$_parser.parseOptionalKeyword(&_paramKey)) {
529 if (!_loop_body(_paramKey)) return {};
530 while (!$_parser.parseOptionalComma()) {
533 os << "// Parse parameter struct\n";
535 // Declare a "seen" variable for each key.
536 for (ParameterElement *param : el->getParams())
537 os << formatv("bool _seen_{0} = false;\n", param->getName());
539 // Generate the body of the parsing loop inside a lambda.
540 os << "{\n";
541 os.indent()
542 << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
543 genLiteralParser("=", ctx, os.indent());
544 for (ParameterElement *param : el->getParams()) {
545 os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
546 " _seen_{0} = true;\n",
547 param->getName());
548 genVariableParser(param, ctx, os.indent());
549 os.unindent() << "} else ";
550 // Print the check for duplicate or unknown parameter.
552 os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx)));
553 os << "return true;\n";
554 os.unindent() << "};\n";
556 // Generate the parsing loop. If optional parameters are present, then the
557 // parse loop is guarded by commas.
558 unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
559 if (numOptional) {
560 // If the struct itself is optional, pull out the first iteration.
561 if (numOptional == el->getNumParams()) {
562 os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
563 os.indent();
564 } else {
565 os << "do {\n";
567 } else {
568 os.getStream().printReindented(
569 tgfmt(loopHeader, &ctx, el->getNumParams()).str());
571 os.indent();
572 os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
573 os.unindent();
575 // Print the loop terminator. For optional parameters, we have to check that
576 // all mandatory parameters have been parsed.
577 // The whole struct is optional if all its parameters are optional.
578 if (numOptional) {
579 if (numOptional == el->getNumParams()) {
580 os << "}\n";
581 os.unindent() << "}\n";
582 } else {
583 os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
584 for (ParameterElement *param : el->getParams()) {
585 if (param->isOptional())
586 continue;
587 os.getStream().printReindented(
588 strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx)));
591 } else {
592 // Because the loop loops N times and each non-failing iteration sets 1 of
593 // N flags, successfully exiting the loop means that all parameters have
594 // been seen. `parseOptionalComma` would cause issues with any formats that
595 // use "struct(...) `,`" beacuse structs aren't sounded by braces.
596 os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
598 os.unindent() << "}\n";
601 void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
602 MethodBody &os, bool isOptional) {
603 os << "{\n";
604 os.indent();
606 // Bound variables are passed directly to the parser as `FailureOr<T> &`.
607 // Referenced variables are passed as `T`. The custom parser fails if it
608 // returns failure or if any of the required parameters failed.
609 os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
610 os << "(void)odsCustomLoc;\n";
611 os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
612 os.indent();
613 for (FormatElement *arg : el->getArguments()) {
614 os << ",\n";
615 if (auto *param = dyn_cast<ParameterElement>(arg))
616 os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
617 << ")";
618 else if (auto *ref = dyn_cast<RefDirective>(arg))
619 os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
620 else
621 os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
623 os.unindent() << ");\n";
624 if (isOptional) {
625 os << "if (!odsCustomResult) return {};\n";
626 os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
627 } else {
628 os << "if (::mlir::failed(odsCustomResult)) return {};\n";
630 for (FormatElement *arg : el->getArguments()) {
631 if (auto *param = dyn_cast<ParameterElement>(arg)) {
632 if (param->isOptional())
633 continue;
634 os << formatv("if (::mlir::failed(_result_{0})) {{\n", param->getName());
635 os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
636 << "\"custom parser failed to parse parameter '"
637 << param->getName() << "'\");\n";
638 os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n";
639 os.unindent() << "}\n";
643 os.unindent() << "}\n";
646 void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
647 MethodBody &os) {
648 ArrayRef<FormatElement *> thenElements =
649 el->getThenElements(/*parseable=*/true);
651 FormatElement *first = thenElements.front();
652 const auto guardOn = [&](auto params) {
653 os << "if (!(";
654 llvm::interleave(
655 params, os,
656 [&](ParameterElement *el) {
657 os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
658 el->getName());
660 " || ");
661 os << ")) {\n";
663 if (auto *literal = dyn_cast<LiteralElement>(first)) {
664 genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true);
665 os << ") {\n";
666 } else if (auto *param = dyn_cast<ParameterElement>(first)) {
667 genVariableParser(param, ctx, os);
668 guardOn(llvm::ArrayRef(param));
669 } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
670 genParamsParser(params, ctx, os);
671 guardOn(params->getParams());
672 } else if (auto *custom = dyn_cast<CustomDirective>(first)) {
673 os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
674 os.indent();
675 genCustomParser(custom, ctx, os, /*isOptional=*/true);
676 os << "return ::mlir::success();\n";
677 os.unindent();
678 os << "}(); result.has_value() && ::mlir::failed(*result)) {\n";
679 os.indent();
680 os << "return {};\n";
681 os.unindent();
682 os << "} else if (result.has_value()) {\n";
683 } else {
684 auto *strct = cast<StructDirective>(first);
685 genStructParser(strct, ctx, os);
686 guardOn(params->getParams());
688 os.indent();
690 // Generate the parsers for the rest of the thenElements.
691 for (FormatElement *element : el->getElseElements(/*parseable=*/true))
692 genElementParser(element, ctx, os);
693 os.unindent() << "} else {\n";
694 os.indent();
695 for (FormatElement *element : thenElements.drop_front())
696 genElementParser(element, ctx, os);
697 os.unindent() << "}\n";
700 //===----------------------------------------------------------------------===//
701 // PrinterGen
702 //===----------------------------------------------------------------------===//
704 void DefFormat::genPrinter(MethodBody &os) {
705 FmtContext ctx;
706 ctx.addSubst("_printer", "odsPrinter");
707 ctx.addSubst("_ctxt", "getContext()");
708 ctx.withBuilder("odsBuilder");
709 os.indent();
710 os << "::mlir::Builder odsBuilder(getContext());\n";
712 // Generate printers.
713 shouldEmitSpace = true;
714 lastWasPunctuation = false;
715 for (FormatElement *el : elements)
716 genElementPrinter(el, ctx, os);
719 void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
720 MethodBody &os) {
721 if (auto *literal = dyn_cast<LiteralElement>(el))
722 return genLiteralPrinter(literal->getSpelling(), ctx, os);
723 if (auto *params = dyn_cast<ParamsDirective>(el))
724 return genParamsPrinter(params, ctx, os);
725 if (auto *strct = dyn_cast<StructDirective>(el))
726 return genStructPrinter(strct, ctx, os);
727 if (auto *custom = dyn_cast<CustomDirective>(el))
728 return genCustomPrinter(custom, ctx, os);
729 if (auto *var = dyn_cast<ParameterElement>(el))
730 return genVariablePrinter(var, ctx, os);
731 if (auto *optional = dyn_cast<OptionalElement>(el))
732 return genOptionalGroupPrinter(optional, ctx, os);
733 if (auto *whitespace = dyn_cast<WhitespaceElement>(el))
734 return genWhitespacePrinter(whitespace, ctx, os);
736 llvm::PrintFatalError("unsupported format element");
739 void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
740 MethodBody &os) {
741 // Don't insert a space before certain punctuation.
742 bool needSpace =
743 shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
744 os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
745 value);
747 // Update the flags.
748 shouldEmitSpace =
749 value.size() != 1 || !StringRef("<({[").contains(value.front());
750 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
753 void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
754 MethodBody &os, bool skipGuard) {
755 const AttrOrTypeParameter &param = el->getParam();
756 ctx.withSelf(param.getAccessorName() + "()");
758 // Guard the printer on the presence of optional parameters and that they
759 // aren't equal to their default values (if they have one).
760 if (el->isOptional() && !skipGuard) {
761 el->genPrintGuard(ctx, os << "if (") << ") {\n";
762 os.indent();
765 // Insert a space before the next parameter, if necessary.
766 if (shouldEmitSpace || !lastWasPunctuation)
767 os << tgfmt("$_printer << ' ';\n", &ctx);
768 shouldEmitSpace = true;
769 lastWasPunctuation = false;
771 if (el->shouldBeQualified())
772 os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n";
773 else if (auto printer = param.getPrinter())
774 os << tgfmt(*printer, &ctx) << ";\n";
775 else
776 os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
778 if (el->isOptional() && !skipGuard)
779 os.unindent() << "}\n";
782 /// Generate code to guard printing on the presence of any optional parameters.
783 template <typename ParameterRange>
784 static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
785 bool inverted = false) {
786 os << "if (";
787 if (inverted)
788 os << "!(";
789 llvm::interleave(
790 params, os,
791 [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
792 if (inverted)
793 os << ")";
794 os << ") {\n";
795 os.indent();
798 void DefFormat::genCommaSeparatedPrinter(
799 ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
800 function_ref<void(ParameterElement *)> extra) {
801 // Emit a space if necessary, but only if the struct is present.
802 if (shouldEmitSpace || !lastWasPunctuation) {
803 bool allOptional = llvm::all_of(params, paramIsOptional);
804 if (allOptional)
805 guardOnAny(ctx, os, params);
806 os << tgfmt("$_printer << ' ';\n", &ctx);
807 if (allOptional)
808 os.unindent() << "}\n";
811 // The first printed element does not need to emit a comma.
812 os << "{\n";
813 os.indent() << "bool _firstPrinted = true;\n";
814 for (ParameterElement *param : params) {
815 if (param->isOptional()) {
816 param->genPrintGuard(ctx, os << "if (") << ") {\n";
817 os.indent();
819 os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
820 os << "_firstPrinted = false;\n";
821 extra(param);
822 shouldEmitSpace = false;
823 lastWasPunctuation = true;
824 genVariablePrinter(param, ctx, os);
825 if (param->isOptional())
826 os.unindent() << "}\n";
828 os.unindent() << "}\n";
831 void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
832 MethodBody &os) {
833 genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
834 [&](ParameterElement *param) {});
837 void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
838 MethodBody &os) {
839 genCommaSeparatedPrinter(
840 llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
841 os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
845 void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
846 MethodBody &os) {
847 // Insert a space before the custom directive, if necessary.
848 if (shouldEmitSpace || !lastWasPunctuation)
849 os << tgfmt("$_printer << ' ';\n", &ctx);
850 shouldEmitSpace = true;
851 lastWasPunctuation = false;
853 os << tgfmt("print$0($_printer", &ctx, el->getName());
854 os.indent();
855 for (FormatElement *arg : el->getArguments()) {
856 os << ",\n";
857 if (auto *param = dyn_cast<ParameterElement>(arg)) {
858 os << param->getParam().getAccessorName() << "()";
859 } else if (auto *ref = dyn_cast<RefDirective>(arg)) {
860 os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName()
861 << "()";
862 } else {
863 os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
866 os.unindent() << ");\n";
869 void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
870 MethodBody &os) {
871 FormatElement *anchor = el->getAnchor();
872 if (auto *param = dyn_cast<ParameterElement>(anchor)) {
873 guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted());
874 } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
875 guardOnAny(ctx, os, params->getParams(), el->isInverted());
876 } else if (auto *strct = dyn_cast<StructDirective>(anchor)) {
877 guardOnAny(ctx, os, strct->getParams(), el->isInverted());
878 } else {
879 auto *custom = cast<CustomDirective>(anchor);
880 guardOnAny(ctx, os,
881 llvm::make_filter_range(
882 llvm::map_range(custom->getArguments(),
883 [](FormatElement *el) {
884 return dyn_cast<ParameterElement>(el);
886 [](ParameterElement *param) { return !!param; }),
887 el->isInverted());
889 // Generate the printer for the contained elements.
891 llvm::SaveAndRestore shouldEmitSpaceFlag(shouldEmitSpace);
892 llvm::SaveAndRestore lastWasPunctuationFlag(lastWasPunctuation);
893 for (FormatElement *element : el->getThenElements())
894 genElementPrinter(element, ctx, os);
896 os.unindent() << "} else {\n";
897 os.indent();
898 for (FormatElement *element : el->getElseElements())
899 genElementPrinter(element, ctx, os);
900 os.unindent() << "}\n";
903 void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
904 MethodBody &os) {
905 if (el->getValue() == "\\n") {
906 // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
907 // the printer.
908 os << tgfmt("$_printer << '\\n';\n", &ctx);
909 } else if (!el->getValue().empty()) {
910 os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue());
911 } else {
912 lastWasPunctuation = true;
914 shouldEmitSpace = false;
917 //===----------------------------------------------------------------------===//
918 // DefFormatParser
919 //===----------------------------------------------------------------------===//
921 namespace {
922 class DefFormatParser : public FormatParser {
923 public:
924 DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
925 : FormatParser(mgr, def.getLoc()[0]), def(def),
926 seenParams(def.getNumParameters()) {}
928 /// Parse the attribute or type format and create the format elements.
929 FailureOr<DefFormat> parse();
931 protected:
932 /// Verify the parsed elements.
933 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
934 /// Verify the elements of a custom directive.
935 LogicalResult
936 verifyCustomDirectiveArguments(SMLoc loc,
937 ArrayRef<FormatElement *> arguments) override;
938 /// Verify the elements of an optional group.
939 LogicalResult verifyOptionalGroupElements(SMLoc loc,
940 ArrayRef<FormatElement *> elements,
941 FormatElement *anchor) override;
943 /// Parse an attribute or type variable.
944 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
945 Context ctx) override;
946 /// Parse an attribute or type format directive.
947 FailureOr<FormatElement *>
948 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
950 private:
951 /// Parse a `params` directive.
952 FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
953 /// Parse a `qualified` directive.
954 FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
955 /// Parse a `struct` directive.
956 FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
957 /// Parse a `ref` directive.
958 FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
960 /// Attribute or type tablegen def.
961 const AttrOrTypeDef &def;
963 /// Seen attribute or type parameters.
964 BitVector seenParams;
966 } // namespace
968 LogicalResult DefFormatParser::verify(SMLoc loc,
969 ArrayRef<FormatElement *> elements) {
970 // Check that all parameters are referenced in the format.
971 for (auto [index, param] : llvm::enumerate(def.getParameters())) {
972 if (param.isOptional())
973 continue;
974 if (!seenParams.test(index)) {
975 if (isa<AttributeSelfTypeParameter>(param))
976 continue;
977 return emitError(loc, "format is missing reference to parameter: " +
978 param.getName());
980 if (isa<AttributeSelfTypeParameter>(param)) {
981 return emitError(loc,
982 "unexpected self type parameter in assembly format");
985 if (elements.empty())
986 return success();
987 // A `struct` directive that contains optional parameters cannot be followed
988 // by a comma literal, which is ambiguous.
989 for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) {
990 auto *structEl = dyn_cast<StructDirective>(std::get<0>(it));
991 auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
992 if (!structEl || !literalEl)
993 continue;
994 if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
995 return emitError(loc, "`struct` directive with optional parameters "
996 "cannot be followed by a comma literal");
999 return success();
1002 LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
1003 SMLoc loc, ArrayRef<FormatElement *> arguments) {
1004 // Arguments are fully verified by the parser context.
1005 return success();
1008 LogicalResult
1009 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
1010 ArrayRef<FormatElement *> elements,
1011 FormatElement *anchor) {
1012 // `params` and `struct` directives are allowed only if all the contained
1013 // parameters are optional.
1014 for (FormatElement *el : elements) {
1015 if (auto *param = dyn_cast<ParameterElement>(el)) {
1016 if (!param->isOptional()) {
1017 return emitError(loc,
1018 "parameters in an optional group must be optional");
1020 } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
1021 if (llvm::any_of(params->getParams(), paramNotOptional)) {
1022 return emitError(loc, "`params` directive allowed in optional group "
1023 "only if all parameters are optional");
1025 } else if (auto *strct = dyn_cast<StructDirective>(el)) {
1026 if (llvm::any_of(strct->getParams(), paramNotOptional)) {
1027 return emitError(loc, "`struct` is only allowed in an optional group "
1028 "if all captured parameters are optional");
1030 } else if (auto *custom = dyn_cast<CustomDirective>(el)) {
1031 for (FormatElement *el : custom->getArguments()) {
1032 // If the custom argument is a variable, then it must be optional.
1033 if (auto *param = dyn_cast<ParameterElement>(el))
1034 if (!param->isOptional())
1035 return emitError(loc,
1036 "`custom` is only allowed in an optional group if "
1037 "all captured parameters are optional");
1041 // The anchor must be a parameter or one of the aforementioned directives.
1042 if (anchor) {
1043 if (!isa<ParameterElement, ParamsDirective, StructDirective,
1044 CustomDirective>(anchor)) {
1045 return emitError(
1046 loc, "optional group anchor must be a parameter or directive");
1048 // If the anchor is a custom directive, make sure at least one of its
1049 // arguments is a bound parameter.
1050 if (auto *custom = dyn_cast<CustomDirective>(anchor)) {
1051 const auto *bound =
1052 llvm::find_if(custom->getArguments(), [](FormatElement *el) {
1053 return isa<ParameterElement>(el);
1055 if (bound == custom->getArguments().end())
1056 return emitError(loc, "`custom` directive with no bound parameters "
1057 "cannot be used as optional group anchor");
1060 return success();
1063 FailureOr<DefFormat> DefFormatParser::parse() {
1064 FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
1065 if (failed(elements))
1066 return failure();
1067 return DefFormat(def, std::move(*elements));
1070 FailureOr<FormatElement *>
1071 DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
1072 // Lookup the parameter.
1073 ArrayRef<AttrOrTypeParameter> params = def.getParameters();
1074 auto *it = llvm::find_if(
1075 params, [&](auto &param) { return param.getName() == name; });
1077 // Check that the parameter reference is valid.
1078 if (it == params.end()) {
1079 return emitError(loc,
1080 def.getName() + " has no parameter named '" + name + "'");
1082 auto idx = std::distance(params.begin(), it);
1084 if (ctx != RefDirectiveContext) {
1085 // Check that the variable has not already been bound.
1086 if (seenParams.test(idx))
1087 return emitError(loc, "duplicate parameter '" + name + "'");
1088 seenParams.set(idx);
1090 // Otherwise, to be referenced, a variable must have been bound.
1091 } else if (!seenParams.test(idx) && !isa<AttributeSelfTypeParameter>(*it)) {
1092 return emitError(loc, "parameter '" + name +
1093 "' must be bound before it is referenced");
1096 return create<ParameterElement>(*it);
1099 FailureOr<FormatElement *>
1100 DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
1101 Context ctx) {
1103 switch (kind) {
1104 case FormatToken::kw_qualified:
1105 return parseQualifiedDirective(loc, ctx);
1106 case FormatToken::kw_params:
1107 return parseParamsDirective(loc, ctx);
1108 case FormatToken::kw_struct:
1109 return parseStructDirective(loc, ctx);
1110 case FormatToken::kw_ref:
1111 return parseRefDirective(loc, ctx);
1112 case FormatToken::kw_custom:
1113 return parseCustomDirective(loc, ctx);
1115 default:
1116 return emitError(loc, "unsupported directive kind");
1120 FailureOr<FormatElement *>
1121 DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
1122 if (failed(parseToken(FormatToken::l_paren,
1123 "expected '(' before argument list")))
1124 return failure();
1125 FailureOr<FormatElement *> var = parseElement(ctx);
1126 if (failed(var))
1127 return var;
1128 if (!isa<ParameterElement>(*var))
1129 return emitError(loc, "`qualified` argument list expected a variable");
1130 cast<ParameterElement>(*var)->setShouldBeQualified();
1131 if (failed(
1132 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
1133 return failure();
1134 return var;
1137 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
1138 Context ctx) {
1139 // It doesn't make sense to allow references to all parameters in a custom
1140 // directive because parameters are the only things that can be bound.
1141 if (ctx != TopLevelContext && ctx != StructDirectiveContext) {
1142 return emitError(loc, "`params` can only be used at the top-level context "
1143 "or within a `struct` directive");
1146 // Collect all of the attribute's or type's parameters and ensure that none of
1147 // the parameters have already been captured.
1148 std::vector<ParameterElement *> vars;
1149 for (const auto &it : llvm::enumerate(def.getParameters())) {
1150 if (seenParams.test(it.index())) {
1151 return emitError(loc, "`params` captures duplicate parameter: " +
1152 it.value().getName());
1154 // Self-type parameters are handled separately from the rest of the
1155 // parameters.
1156 if (isa<AttributeSelfTypeParameter>(it.value()))
1157 continue;
1158 seenParams.set(it.index());
1159 vars.push_back(create<ParameterElement>(it.value()));
1161 return create<ParamsDirective>(std::move(vars));
1164 FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
1165 Context ctx) {
1166 if (ctx != TopLevelContext)
1167 return emitError(loc, "`struct` can only be used at the top-level context");
1169 if (failed(parseToken(FormatToken::l_paren,
1170 "expected '(' before `struct` argument list")))
1171 return failure();
1173 // Parse variables captured by `struct`.
1174 std::vector<ParameterElement *> vars;
1176 // Parse first captured parameter or a `params` directive.
1177 FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
1178 if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
1179 return emitError(loc,
1180 "`struct` argument list expected a variable or directive");
1182 if (isa<VariableElement>(*var)) {
1183 // Parse any other parameters.
1184 vars.push_back(cast<ParameterElement>(*var));
1185 while (peekToken().is(FormatToken::comma)) {
1186 consumeToken();
1187 var = parseElement(StructDirectiveContext);
1188 if (failed(var) || !isa<VariableElement>(*var))
1189 return emitError(loc, "expected a variable in `struct` argument list");
1190 vars.push_back(cast<ParameterElement>(*var));
1192 } else {
1193 // `struct(params)` captures all parameters in the attribute or type.
1194 vars = cast<ParamsDirective>(*var)->takeParams();
1197 if (failed(parseToken(FormatToken::r_paren,
1198 "expected ')' at the end of an argument list")))
1199 return failure();
1201 return create<StructDirective>(std::move(vars));
1204 FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
1205 Context ctx) {
1206 if (ctx != CustomDirectiveContext)
1207 return emitError(loc, "`ref` is only allowed inside custom directives");
1209 // Parse the child parameter element.
1210 FailureOr<FormatElement *> child;
1211 if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
1212 failed(child = parseElement(RefDirectiveContext)) ||
1213 failed(parseToken(FormatToken::r_paren, "expeced ')'")))
1214 return failure();
1216 // Only parameter elements are allowed to be parsed under a `ref` directive.
1217 return create<RefDirective>(*child);
1220 //===----------------------------------------------------------------------===//
1221 // Interface
1222 //===----------------------------------------------------------------------===//
1224 void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
1225 MethodBody &parser,
1226 MethodBody &printer) {
1227 llvm::SourceMgr mgr;
1228 mgr.AddNewSourceBuffer(
1229 llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
1231 // Parse the custom assembly format>
1232 DefFormatParser fmtParser(mgr, def);
1233 FailureOr<DefFormat> format = fmtParser.parse();
1234 if (failed(format)) {
1235 if (formatErrorIsFatal)
1236 PrintFatalError(def.getLoc(), "failed to parse assembly format");
1237 return;
1240 // Generate the parser and printer.
1241 format->genParser(parser);
1242 format->genPrinter(printer);