1 //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "AttrOrTypeFormatGen.h"
10 #include "FormatGen.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/TableGen/AttrOrTypeDef.h"
13 #include "mlir/TableGen/Format.h"
14 #include "mlir/TableGen/GenInfo.h"
15 #include "llvm/ADT/BitVector.h"
16 #include "llvm/ADT/StringExtras.h"
17 #include "llvm/ADT/StringSwitch.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/MemoryBuffer.h"
20 #include "llvm/Support/SaveAndRestore.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "llvm/TableGen/Error.h"
23 #include "llvm/TableGen/TableGenBackend.h"
26 using namespace mlir::tblgen
;
30 //===----------------------------------------------------------------------===//
32 //===----------------------------------------------------------------------===//
35 /// This class represents an instance of a variable element. A variable refers
36 /// to an attribute or type parameter.
37 class ParameterElement
38 : public VariableElementBase
<VariableElement::Parameter
> {
40 ParameterElement(AttrOrTypeParameter param
) : param(param
) {}
42 /// Get the parameter in the element.
43 const AttrOrTypeParameter
&getParam() const { return param
; }
45 /// Indicate if this variable is printed "qualified" (that is it is
46 /// prefixed with the `#dialect.mnemonic`).
47 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
48 void setShouldBeQualified(bool qualified
= true) {
49 shouldBeQualifiedFlag
= qualified
;
52 /// Returns true if the element contains an optional parameter.
53 bool isOptional() const { return param
.isOptional(); }
55 /// Returns the name of the parameter.
56 StringRef
getName() const { return param
.getName(); }
58 /// Return the code to check whether the parameter is present.
59 auto genIsPresent(FmtContext
&ctx
, const Twine
&self
) const {
60 assert(isOptional() && "cannot guard on a mandatory parameter");
61 std::string valueStr
= tgfmt(*param
.getDefaultValue(), &ctx
).str();
62 ctx
.addSubst("_lhs", self
).addSubst("_rhs", valueStr
);
63 return tgfmt(getParam().getComparator(), &ctx
);
66 /// Generate the code to check whether the parameter should be printed.
67 MethodBody
&genPrintGuard(FmtContext
&ctx
, MethodBody
&os
) const {
68 assert(isOptional() && "cannot guard on a mandatory parameter");
69 std::string self
= param
.getAccessorName() + "()";
70 return os
<< "!(" << genIsPresent(ctx
, self
) << ")";
74 bool shouldBeQualifiedFlag
= false;
75 AttrOrTypeParameter param
;
78 /// Shorthand predicates that can be used with range-based conditions.
/// Returns true if the parameter wrapped by `el` is optional.
79 static bool paramIsOptional(ParameterElement
*el
) { return el
->isOptional(); }
/// Returns true if the parameter wrapped by `el` is mandatory (not optional).
80 static bool paramNotOptional(ParameterElement
*el
) { return !el
->isOptional(); }
82 /// Base class for a directive that contains references to multiple variables.
83 template <DirectiveElement::Kind DirectiveKind
>
84 class ParamsDirectiveBase
: public DirectiveElementBase
<DirectiveKind
> {
86 using Base
= ParamsDirectiveBase
<DirectiveKind
>;
88 ParamsDirectiveBase(std::vector
<ParameterElement
*> &¶ms
)
89 : params(std::move(params
)) {}
91 /// Get the parameters contained in this directive.
92 ArrayRef
<ParameterElement
*> getParams() const { return params
; }
94 /// Get the number of parameters.
95 unsigned getNumParams() const { return params
.size(); }
97 /// Take all of the parameters from this directive.
98 std::vector
<ParameterElement
*> takeParams() { return std::move(params
); }
100 /// Returns true if there are optional parameters present.
101 bool hasOptionalParams() const {
102 return llvm::any_of(getParams(), paramIsOptional
);
106 /// The parameters captured by this directive.
107 std::vector
<ParameterElement
*> params
;
110 /// This class represents a `params` directive that refers to all parameters
111 /// of an attribute or type. When used as a top-level directive, it generates
112 /// a format of the form:
114 /// (param-value (`,` param-value)*)?
116 /// When used as an argument to another directive that accepts variables,
117 /// `params` can be used in place of manually listing all parameters of an
118 /// attribute or type.
119 class ParamsDirective
: public ParamsDirectiveBase
<DirectiveElement::Params
> {
124 /// This class represents a `struct` directive that generates a struct format
127 /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
129 class StructDirective
: public ParamsDirectiveBase
<DirectiveElement::Struct
> {
136 //===----------------------------------------------------------------------===//
138 //===----------------------------------------------------------------------===//
140 /// Default parser for attribute or type parameters.
/// `$0` is substituted with the parameter's C++ storage type and `$_parser`
/// with the parser variable.
141 static const char *const defaultParameterParser
=
142 "::mlir::FieldParser<$0>::parse($_parser)";
144 /// Default printer for attribute or type parameters.
/// `$_printer` names the printer and `$_self` the expression being printed.
145 static const char *const defaultParameterPrinter
=
146 "$_printer.printStrippedAttrOrType($_self)";
148 /// Qualified printer for attribute or type parameters: it does not elide
149 /// dialect and mnemonic.
/// Selected by the variable printer when the element reports
/// `shouldBeQualified()`.
150 static const char *const qualifiedParameterPrinter
= "$_printer << $_self";
152 /// Print an error when failing to parse an element.
154 /// Only `$_parser` is substituted; there are no positional arguments.
/// The template intentionally stops after the location argument: each user
/// appends the message string and the closing `)` itself.
155 static const char *const parserErrorStr
=
156 "$_parser.emitError($_parser.getCurrentLocation(), ";
158 /// Code format to parse a variable. Separate by lines because variable parsers
159 /// may be generated inside other directives, which requires indentation.
161 /// {0}: The parameter name.
162 /// {1}: The parse code for the parameter.
163 /// {2}: Code template for printing an error.
164 /// {3}: Name of the attribute or type.
165 /// {4}: C++ class of the parameter.
166 /// {5}: Optional code to preload the dialect for this variable.
167 static const char *const variableParser
= R
"(
168 // Parse variable '{0}'{5}
170 if (::mlir::failed(_result_{0})) {{
171 {2}"failed to parse
{3} parameter
'{0}' which is to be a `
{4}`
");
176 //===----------------------------------------------------------------------===//
178 //===----------------------------------------------------------------------===//
183 DefFormat(const AttrOrTypeDef
&def
, std::vector
<FormatElement
*> &&elements
)
184 : def(def
), elements(std::move(elements
)) {}
186 /// Generate the attribute or type parser.
187 void genParser(MethodBody
&os
);
188 /// Generate the attribute or type printer.
189 void genPrinter(MethodBody
&os
);
192 /// Generate the parser code for a specific format element.
193 void genElementParser(FormatElement
*el
, FmtContext
&ctx
, MethodBody
&os
);
194 /// Generate the parser code for a literal.
195 void genLiteralParser(StringRef value
, FmtContext
&ctx
, MethodBody
&os
,
196 bool isOptional
= false);
197 /// Generate the parser code for a variable.
198 void genVariableParser(ParameterElement
*el
, FmtContext
&ctx
, MethodBody
&os
);
199 /// Generate the parser code for a `params` directive.
200 void genParamsParser(ParamsDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
201 /// Generate the parser code for a `struct` directive.
202 void genStructParser(StructDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
203 /// Generate the parser code for a `custom` directive.
204 void genCustomParser(CustomDirective
*el
, FmtContext
&ctx
, MethodBody
&os
,
205 bool isOptional
= false);
206 /// Generate the parser code for an optional group.
207 void genOptionalGroupParser(OptionalElement
*el
, FmtContext
&ctx
,
210 /// Generate the printer code for a specific format element.
211 void genElementPrinter(FormatElement
*el
, FmtContext
&ctx
, MethodBody
&os
);
212 /// Generate the printer code for a literal.
213 void genLiteralPrinter(StringRef value
, FmtContext
&ctx
, MethodBody
&os
);
214 /// Generate the printer code for a variable.
215 void genVariablePrinter(ParameterElement
*el
, FmtContext
&ctx
, MethodBody
&os
,
216 bool skipGuard
= false);
217 /// Generate a printer for comma-separated parameters.
218 void genCommaSeparatedPrinter(ArrayRef
<ParameterElement
*> params
,
219 FmtContext
&ctx
, MethodBody
&os
,
220 function_ref
<void(ParameterElement
*)> extra
);
221 /// Generate the printer code for a `params` directive.
222 void genParamsPrinter(ParamsDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
223 /// Generate the printer code for a `struct` directive.
224 void genStructPrinter(StructDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
225 /// Generate the printer code for a `custom` directive.
226 void genCustomPrinter(CustomDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
227 /// Generate the printer code for an optional group.
228 void genOptionalGroupPrinter(OptionalElement
*el
, FmtContext
&ctx
,
230 /// Generate a printer (or space eraser) for a whitespace element.
231 void genWhitespacePrinter(WhitespaceElement
*el
, FmtContext
&ctx
,
234 /// The ODS definition of the attribute or type whose format is being used to
235 /// generate a parser and printer.
236 const AttrOrTypeDef
&def
;
237 /// The list of top-level format elements returned by the assembly format
239 std::vector
<FormatElement
*> elements
;
241 /// Flags for printing spaces.
242 bool shouldEmitSpace
= false;
243 bool lastWasPunctuation
= false;
247 //===----------------------------------------------------------------------===//
249 //===----------------------------------------------------------------------===//
251 /// Generate a special-case "parser" for an attribute's self type parameter. The
252 /// self type parameter has special handling in the assembly format in that it
253 /// is derived from the optional trailing colon type after the attribute.
254 static void genAttrSelfTypeParser(MethodBody
&os
, const FmtContext
&ctx
,
255 const AttributeSelfTypeParameter
¶m
) {
256 // "Parser" for an attribute self type parameter that checks the
257 // optionally-parsed trailing colon type.
259 // $0: The C++ storage class of the type parameter.
260 // $1: The self type parameter name.
261 const char *const selfTypeParser
= R
"(
263 if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
264 _result_$1 = reqType;
266 $_parser.emitError($_loc, "invalid kind of type specified
");
271 // If the attribute self type parameter is required, emit code that emits an
272 // error if the trailing type was not parsed.
273 const char *const selfTypeRequired
= R
"( else {
274 $_parser.emitError($_loc, "expected a trailing type
");
278 os
<< tgfmt(selfTypeParser
, &ctx
, param
.getCppStorageType(), param
.getName());
279 if (!param
.isOptional())
280 os
<< tgfmt(selfTypeRequired
, &ctx
);
284 void DefFormat::genParser(MethodBody
&os
) {
286 ctx
.addSubst("_parser", "odsParser");
287 ctx
.addSubst("_ctxt", "odsParser.getContext()");
288 ctx
.withBuilder("odsBuilder");
289 if (isa
<AttrDef
>(def
))
290 ctx
.addSubst("_type", "odsType");
292 os
<< "::mlir::Builder odsBuilder(odsParser.getContext());\n";
294 // Store the initial location of the parser.
295 ctx
.addSubst("_loc", "odsLoc");
296 os
<< tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
300 // Declare variables to store all of the parameters. Allocated parameters
301 // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
302 // FailureOr<T> to defer type construction for parameters that are parsed in
303 // a loop (parsers return FailureOr anyways).
304 ArrayRef
<AttrOrTypeParameter
> params
= def
.getParameters();
305 for (const AttrOrTypeParameter
¶m
: params
) {
306 os
<< formatv("::mlir::FailureOr<{0}> _result_{1};\n",
307 param
.getCppStorageType(), param
.getName());
308 if (auto *selfTypeParam
= dyn_cast
<AttributeSelfTypeParameter
>(¶m
))
309 genAttrSelfTypeParser(os
, ctx
, *selfTypeParam
);
312 // Generate call to each parameter parser.
313 for (FormatElement
*el
: elements
)
314 genElementParser(el
, ctx
, os
);
316 // Emit an assert for each mandatory parameter. Triggering an assert means
317 // the generated parser is incorrect (i.e. there is a bug in this code).
318 for (const AttrOrTypeParameter
¶m
: params
) {
319 if (param
.isOptional())
321 os
<< formatv("assert(::mlir::succeeded(_result_{0}));\n", param
.getName());
324 // Generate call to the attribute or type builder. Use the checked getter
325 // if one was generated.
326 if (def
.genVerifyDecl()) {
327 os
<< tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
328 &ctx
, def
.getCppClassName());
330 os
<< tgfmt("return $0::get($_parser.getContext()", &ctx
,
331 def
.getCppClassName());
333 for (const AttrOrTypeParameter
¶m
: params
) {
335 std::string paramSelfStr
;
336 llvm::raw_string_ostream
selfOs(paramSelfStr
);
337 if (std::optional
<StringRef
> defaultValue
= param
.getDefaultValue()) {
338 selfOs
<< formatv("(_result_{0}.value_or(", param
.getName())
339 << tgfmt(*defaultValue
, &ctx
) << "))";
341 selfOs
<< formatv("(*_result_{0})", param
.getName());
343 ctx
.addSubst(param
.getName(), selfOs
.str());
344 os
<< param
.getCppType() << "("
345 << tgfmt(param
.getConvertFromStorage(), &ctx
.withSelf(selfOs
.str()))
351 void DefFormat::genElementParser(FormatElement
*el
, FmtContext
&ctx
,
353 if (auto *literal
= dyn_cast
<LiteralElement
>(el
))
354 return genLiteralParser(literal
->getSpelling(), ctx
, os
);
355 if (auto *var
= dyn_cast
<ParameterElement
>(el
))
356 return genVariableParser(var
, ctx
, os
);
357 if (auto *params
= dyn_cast
<ParamsDirective
>(el
))
358 return genParamsParser(params
, ctx
, os
);
359 if (auto *strct
= dyn_cast
<StructDirective
>(el
))
360 return genStructParser(strct
, ctx
, os
);
361 if (auto *custom
= dyn_cast
<CustomDirective
>(el
))
362 return genCustomParser(custom
, ctx
, os
);
363 if (auto *optional
= dyn_cast
<OptionalElement
>(el
))
364 return genOptionalGroupParser(optional
, ctx
, os
);
365 if (isa
<WhitespaceElement
>(el
))
368 llvm_unreachable("unknown format element");
371 void DefFormat::genLiteralParser(StringRef value
, FmtContext
&ctx
,
372 MethodBody
&os
, bool isOptional
) {
373 os
<< "// Parse literal '" << value
<< "'\n";
374 os
<< tgfmt("if ($_parser.parse", &ctx
);
377 if (value
.front() == '_' || isalpha(value
.front())) {
378 os
<< "Keyword(\"" << value
<< "\")";
380 os
<< StringSwitch
<StringRef
>(value
)
386 .Case(">", "Greater")
391 .Case("[", "LSquare")
392 .Case("]", "RSquare")
393 .Case("?", "Question")
396 .Case("...", "Ellipsis")
400 // Leave the `if` unclosed to guard optional groups.
403 // Parser will emit an error
404 os
<< ") return {};\n";
407 void DefFormat::genVariableParser(ParameterElement
*el
, FmtContext
&ctx
,
409 // Check for a custom parser. Use the default attribute parser otherwise.
410 const AttrOrTypeParameter
¶m
= el
->getParam();
411 auto customParser
= param
.getParser();
413 customParser
? *customParser
: StringRef(defaultParameterParser
);
415 // If the variable points to a dialect-specific entity (type or attribute),
416 // we force load the dialect now before trying to parse it.
417 std::string dialectLoading
;
418 if (auto *defInit
= dyn_cast
<llvm::DefInit
>(param
.getDef())) {
419 auto *dialectValue
= defInit
->getDef()->getValue("dialect");
421 if (auto *dialectInit
=
422 dyn_cast
<llvm::DefInit
>(dialectValue
->getValue())) {
423 Dialect
dialect(dialectInit
->getDef());
424 auto cppNamespace
= dialect
.getCppNamespace();
425 std::string name
= dialect
.getCppClassName();
426 if (name
!= "BuiltinDialect" || cppNamespace
!= "::mlir") {
427 dialectLoading
= ("\nodsParser.getContext()->getOrLoadDialect<" +
428 cppNamespace
+ "::" + name
+ ">();")
434 os
<< formatv(variableParser
, param
.getName(),
435 tgfmt(parser
, &ctx
, param
.getCppStorageType()),
436 tgfmt(parserErrorStr
, &ctx
), def
.getName(), param
.getCppType(),
440 void DefFormat::genParamsParser(ParamsDirective
*el
, FmtContext
&ctx
,
442 os
<< "// Parse parameter list\n";
444 // If there are optional parameters, we need to switch to `parseOptionalComma`
445 // if there are no more required parameters after a certain point.
446 bool hasOptional
= el
->hasOptionalParams();
448 // Wrap everything in a do-while so that we can `break`.
453 ArrayRef
<ParameterElement
*> params
= el
->getParams();
454 using IteratorT
= ParameterElement
*const *;
455 IteratorT it
= params
.begin();
457 // Find the last required parameter. Commas become optional afterwards.
458 // Note: IteratorT's copy assignment is deleted.
459 ParameterElement
*lastReq
= nullptr;
460 for (ParameterElement
*param
: params
)
461 if (!param
->isOptional())
463 IteratorT lastReqIt
= lastReq
? llvm::find(params
, lastReq
) : params
.begin();
465 auto eachFn
= [&](ParameterElement
*el
) { genVariableParser(el
, ctx
, os
); };
466 auto betweenFn
= [&](IteratorT it
) {
467 ParameterElement
*el
= *std::prev(it
);
468 // Parse a comma if the last optional parameter had a value.
469 if (el
->isOptional()) {
470 os
<< formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
472 el
->genIsPresent(ctx
, "(*_result_" + el
->getName() + ")"));
475 if (it
<= lastReqIt
) {
476 genLiteralParser(",", ctx
, os
);
478 genLiteralParser(",", ctx
, os
, /*isOptional=*/true);
481 if (el
->isOptional())
482 os
.unindent() << "}\n";
486 if (it
!= params
.end()) {
488 for (IteratorT e
= params
.end(); it
!= e
; ++it
) {
495 os
.unindent() << "} while(false);\n";
498 void DefFormat::genStructParser(StructDirective
*el
, FmtContext
&ctx
,
500 // Loop declaration for struct parser with only required parameters.
502 // $0: Number of expected parameters.
503 const char *const loopHeader
= R
"(
504 for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
507 // Loop body start for struct parser.
508 const char *const loopStart
= R
"(
509 ::llvm::StringRef _paramKey;
510 if ($_parser.parseKeyword(&_paramKey)) {
511 $_parser.emitError($_parser.getCurrentLocation(),
512 "expected a parameter name in
struct");
515 if (!_loop_body(_paramKey)) return {};
518 // Struct parser loop end. Check for duplicate or unknown struct parameters.
520 // {0}: Code template for printing an error.
521 const char *const loopEnd
= R
"({{
522 {0}"duplicate
or unknown
struct parameter name
: ") << _paramKey;
527 // Struct parser loop terminator. Parse a comma except on the last element.
529 // {0}: Number of elements in the struct.
530 const char *const loopTerminator
= R
"(
531 if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
536 // Check that a mandatory parameter was parsed.
538 // {0}: Name of the parameter.
// {1}: Code template for printing an error.
539 const char *const checkParam
= R
"(
541 {1}"struct is missing required parameter
: ") << "{0}";
546 // First iteration of the loop parsing an optional struct.
547 const char *const optionalStructFirst
= R
"(
548 ::llvm::StringRef _paramKey;
549 if (!$_parser.parseOptionalKeyword(&_paramKey)) {
550 if (!_loop_body(_paramKey)) return {};
551 while (!$_parser.parseOptionalComma()) {
554 os
<< "// Parse parameter struct\n";
556 // Declare a "seen" variable for each key.
557 for (ParameterElement
*param
: el
->getParams())
558 os
<< formatv("bool _seen_{0} = false;\n", param
->getName());
560 // Generate the body of the parsing loop inside a lambda.
563 << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
564 genLiteralParser("=", ctx
, os
.indent());
565 for (ParameterElement
*param
: el
->getParams()) {
566 os
<< formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
567 " _seen_{0} = true;\n",
569 genVariableParser(param
, ctx
, os
.indent());
570 os
.unindent() << "} else ";
571 // Print the check for duplicate or unknown parameter.
573 os
.getStream().printReindented(strfmt(loopEnd
, tgfmt(parserErrorStr
, &ctx
)));
574 os
<< "return true;\n";
575 os
.unindent() << "};\n";
577 // Generate the parsing loop. If optional parameters are present, then the
578 // parse loop is guarded by commas.
579 unsigned numOptional
= llvm::count_if(el
->getParams(), paramIsOptional
);
581 // If the struct itself is optional, pull out the first iteration.
582 if (numOptional
== el
->getNumParams()) {
583 os
.getStream().printReindented(tgfmt(optionalStructFirst
, &ctx
).str());
589 os
.getStream().printReindented(
590 tgfmt(loopHeader
, &ctx
, el
->getNumParams()).str());
593 os
.getStream().printReindented(tgfmt(loopStart
, &ctx
).str());
596 // Print the loop terminator. For optional parameters, we have to check that
597 // all mandatory parameters have been parsed.
598 // The whole struct is optional if all its parameters are optional.
600 if (numOptional
== el
->getNumParams()) {
602 os
.unindent() << "}\n";
604 os
<< tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx
);
605 for (ParameterElement
*param
: el
->getParams()) {
606 if (param
->isOptional())
608 os
.getStream().printReindented(
609 strfmt(checkParam
, param
->getName(), tgfmt(parserErrorStr
, &ctx
)));
613 // Because the loop loops N times and each non-failing iteration sets 1 of
614 // N flags, successfully exiting the loop means that all parameters have
615 // been seen. `parseOptionalComma` would cause issues with any formats that
616 // use "struct(...) `,`" because structs aren't bounded by braces.
617 os
.getStream().printReindented(strfmt(loopTerminator
, el
->getNumParams()));
619 os
.unindent() << "}\n";
622 void DefFormat::genCustomParser(CustomDirective
*el
, FmtContext
&ctx
,
623 MethodBody
&os
, bool isOptional
) {
627 // Bound variables are passed directly to the parser as `FailureOr<T> &`.
628 // Referenced variables are passed as `T`. The custom parser fails if it
629 // returns failure or if any of the required parameters failed.
630 os
<< tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx
);
631 os
<< "(void)odsCustomLoc;\n";
632 os
<< tgfmt("auto odsCustomResult = parse$0($_parser", &ctx
, el
->getName());
634 for (FormatElement
*arg
: el
->getArguments()) {
636 if (auto *param
= dyn_cast
<ParameterElement
>(arg
))
637 os
<< "::mlir::detail::unwrapForCustomParse(_result_" << param
->getName()
639 else if (auto *ref
= dyn_cast
<RefDirective
>(arg
))
640 os
<< "*_result_" << cast
<ParameterElement
>(ref
->getArg())->getName();
642 os
<< tgfmt(cast
<StringElement
>(arg
)->getValue(), &ctx
);
644 os
.unindent() << ");\n";
646 os
<< "if (!odsCustomResult.has_value()) return {};\n";
647 os
<< "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
649 os
<< "if (::mlir::failed(odsCustomResult)) return {};\n";
651 for (FormatElement
*arg
: el
->getArguments()) {
652 if (auto *param
= dyn_cast
<ParameterElement
>(arg
)) {
653 if (param
->isOptional())
655 os
<< formatv("if (::mlir::failed(_result_{0})) {{\n", param
->getName());
656 os
.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx
)
657 << "\"custom parser failed to parse parameter '"
658 << param
->getName() << "'\");\n";
659 os
<< "return " << (isOptional
? "::mlir::failure()" : "{}") << ";\n";
660 os
.unindent() << "}\n";
664 os
.unindent() << "}\n";
667 void DefFormat::genOptionalGroupParser(OptionalElement
*el
, FmtContext
&ctx
,
669 ArrayRef
<FormatElement
*> thenElements
=
670 el
->getThenElements(/*parseable=*/true);
672 FormatElement
*first
= thenElements
.front();
673 const auto guardOn
= [&](auto params
) {
677 [&](ParameterElement
*el
) {
678 os
<< formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
684 if (auto *literal
= dyn_cast
<LiteralElement
>(first
)) {
685 genLiteralParser(literal
->getSpelling(), ctx
, os
, /*isOptional=*/true);
687 } else if (auto *param
= dyn_cast
<ParameterElement
>(first
)) {
688 genVariableParser(param
, ctx
, os
);
689 guardOn(llvm::ArrayRef(param
));
690 } else if (auto *params
= dyn_cast
<ParamsDirective
>(first
)) {
691 genParamsParser(params
, ctx
, os
);
692 guardOn(params
->getParams());
693 } else if (auto *custom
= dyn_cast
<CustomDirective
>(first
)) {
694 os
<< "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
696 genCustomParser(custom
, ctx
, os
, /*isOptional=*/true);
697 os
<< "return ::mlir::success();\n";
699 os
<< "}(); result.has_value() && ::mlir::failed(*result)) {\n";
701 os
<< "return {};\n";
703 os
<< "} else if (result.has_value()) {\n";
705 auto *strct = cast<StructDirective>(first);
706 genStructParser(strct, ctx, os);
// Guard the optional group on the parameters of the struct directive itself.
// The name `params` must not be used here: it refers to the result of the
// earlier `dyn_cast<ParamsDirective>(first)`, which did not match on this
// path, so dereferencing it would be a null-pointer access.
707 guardOn(strct->getParams());
711 // Generate the parsers for the rest of the thenElements.
712 for (FormatElement
*element
: el
->getElseElements(/*parseable=*/true))
713 genElementParser(element
, ctx
, os
);
714 os
.unindent() << "} else {\n";
716 for (FormatElement
*element
: thenElements
.drop_front())
717 genElementParser(element
, ctx
, os
);
718 os
.unindent() << "}\n";
721 //===----------------------------------------------------------------------===//
723 //===----------------------------------------------------------------------===//
725 void DefFormat::genPrinter(MethodBody
&os
) {
727 ctx
.addSubst("_printer", "odsPrinter");
728 ctx
.addSubst("_ctxt", "getContext()");
729 ctx
.withBuilder("odsBuilder");
731 os
<< "::mlir::Builder odsBuilder(getContext());\n";
733 // Generate printers.
734 shouldEmitSpace
= true;
735 lastWasPunctuation
= false;
736 for (FormatElement
*el
: elements
)
737 genElementPrinter(el
, ctx
, os
);
740 void DefFormat::genElementPrinter(FormatElement
*el
, FmtContext
&ctx
,
742 if (auto *literal
= dyn_cast
<LiteralElement
>(el
))
743 return genLiteralPrinter(literal
->getSpelling(), ctx
, os
);
744 if (auto *params
= dyn_cast
<ParamsDirective
>(el
))
745 return genParamsPrinter(params
, ctx
, os
);
746 if (auto *strct
= dyn_cast
<StructDirective
>(el
))
747 return genStructPrinter(strct
, ctx
, os
);
748 if (auto *custom
= dyn_cast
<CustomDirective
>(el
))
749 return genCustomPrinter(custom
, ctx
, os
);
750 if (auto *var
= dyn_cast
<ParameterElement
>(el
))
751 return genVariablePrinter(var
, ctx
, os
);
752 if (auto *optional
= dyn_cast
<OptionalElement
>(el
))
753 return genOptionalGroupPrinter(optional
, ctx
, os
);
754 if (auto *whitespace
= dyn_cast
<WhitespaceElement
>(el
))
755 return genWhitespacePrinter(whitespace
, ctx
, os
);
757 llvm::PrintFatalError("unsupported format element");
760 void DefFormat::genLiteralPrinter(StringRef value
, FmtContext
&ctx
,
762 // Don't insert a space before certain punctuation.
764 shouldEmitSpace
&& shouldEmitSpaceBefore(value
, lastWasPunctuation
);
765 os
<< tgfmt("$_printer$0 << \"$1\";\n", &ctx
, needSpace
? " << ' '" : "",
770 value
.size() != 1 || !StringRef("<({[").contains(value
.front());
771 lastWasPunctuation
= value
.front() != '_' && !isalpha(value
.front());
774 void DefFormat::genVariablePrinter(ParameterElement
*el
, FmtContext
&ctx
,
775 MethodBody
&os
, bool skipGuard
) {
776 const AttrOrTypeParameter
¶m
= el
->getParam();
777 ctx
.withSelf(param
.getAccessorName() + "()");
779 // Guard the printer on the presence of optional parameters and that they
780 // aren't equal to their default values (if they have one).
781 if (el
->isOptional() && !skipGuard
) {
782 el
->genPrintGuard(ctx
, os
<< "if (") << ") {\n";
786 // Insert a space before the next parameter, if necessary.
787 if (shouldEmitSpace
|| !lastWasPunctuation
)
788 os
<< tgfmt("$_printer << ' ';\n", &ctx
);
789 shouldEmitSpace
= true;
790 lastWasPunctuation
= false;
792 if (el
->shouldBeQualified())
793 os
<< tgfmt(qualifiedParameterPrinter
, &ctx
) << ";\n";
794 else if (auto printer
= param
.getPrinter())
795 os
<< tgfmt(*printer
, &ctx
) << ";\n";
797 os
<< tgfmt(defaultParameterPrinter
, &ctx
) << ";\n";
799 if (el
->isOptional() && !skipGuard
)
800 os
.unindent() << "}\n";
803 /// Generate code to guard printing on the presence of any optional parameters.
804 template <typename ParameterRange
>
805 static void guardOnAny(FmtContext
&ctx
, MethodBody
&os
, ParameterRange
&¶ms
,
806 bool inverted
= false) {
812 [&](ParameterElement
*param
) { param
->genPrintGuard(ctx
, os
); }, " || ");
819 void DefFormat::genCommaSeparatedPrinter(
820 ArrayRef
<ParameterElement
*> params
, FmtContext
&ctx
, MethodBody
&os
,
821 function_ref
<void(ParameterElement
*)> extra
) {
822 // Emit a space if necessary, but only if the struct is present.
823 if (shouldEmitSpace
|| !lastWasPunctuation
) {
824 bool allOptional
= llvm::all_of(params
, paramIsOptional
);
826 guardOnAny(ctx
, os
, params
);
827 os
<< tgfmt("$_printer << ' ';\n", &ctx
);
829 os
.unindent() << "}\n";
832 // The first printed element does not need to emit a comma.
834 os
.indent() << "bool _firstPrinted = true;\n";
835 for (ParameterElement
*param
: params
) {
836 if (param
->isOptional()) {
837 param
->genPrintGuard(ctx
, os
<< "if (") << ") {\n";
840 os
<< tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx
);
841 os
<< "_firstPrinted = false;\n";
843 shouldEmitSpace
= false;
844 lastWasPunctuation
= true;
845 genVariablePrinter(param
, ctx
, os
);
846 if (param
->isOptional())
847 os
.unindent() << "}\n";
849 os
.unindent() << "}\n";
852 void DefFormat::genParamsPrinter(ParamsDirective
*el
, FmtContext
&ctx
,
854 genCommaSeparatedPrinter(llvm::to_vector(el
->getParams()), ctx
, os
,
855 [&](ParameterElement
*param
) {});
858 void DefFormat::genStructPrinter(StructDirective
*el
, FmtContext
&ctx
,
860 genCommaSeparatedPrinter(
861 llvm::to_vector(el
->getParams()), ctx
, os
, [&](ParameterElement
*param
) {
862 os
<< tgfmt("$_printer << \"$0 = \";\n", &ctx
, param
->getName());
866 void DefFormat::genCustomPrinter(CustomDirective
*el
, FmtContext
&ctx
,
868 // Insert a space before the custom directive, if necessary.
869 if (shouldEmitSpace
|| !lastWasPunctuation
)
870 os
<< tgfmt("$_printer << ' ';\n", &ctx
);
871 shouldEmitSpace
= true;
872 lastWasPunctuation
= false;
874 os
<< tgfmt("print$0($_printer", &ctx
, el
->getName());
876 for (FormatElement
*arg
: el
->getArguments()) {
878 if (auto *param
= dyn_cast
<ParameterElement
>(arg
)) {
879 os
<< param
->getParam().getAccessorName() << "()";
880 } else if (auto *ref
= dyn_cast
<RefDirective
>(arg
)) {
881 os
<< cast
<ParameterElement
>(ref
->getArg())->getParam().getAccessorName()
884 os
<< tgfmt(cast
<StringElement
>(arg
)->getValue(), &ctx
);
887 os
.unindent() << ");\n";
890 void DefFormat::genOptionalGroupPrinter(OptionalElement
*el
, FmtContext
&ctx
,
892 FormatElement
*anchor
= el
->getAnchor();
893 if (auto *param
= dyn_cast
<ParameterElement
>(anchor
)) {
894 guardOnAny(ctx
, os
, llvm::ArrayRef(param
), el
->isInverted());
895 } else if (auto *params
= dyn_cast
<ParamsDirective
>(anchor
)) {
896 guardOnAny(ctx
, os
, params
->getParams(), el
->isInverted());
897 } else if (auto *strct
= dyn_cast
<StructDirective
>(anchor
)) {
898 guardOnAny(ctx
, os
, strct
->getParams(), el
->isInverted());
900 auto *custom
= cast
<CustomDirective
>(anchor
);
902 llvm::make_filter_range(
903 llvm::map_range(custom
->getArguments(),
904 [](FormatElement
*el
) {
905 return dyn_cast
<ParameterElement
>(el
);
907 [](ParameterElement
*param
) { return !!param
; }),
910 // Generate the printer for the contained elements.
912 llvm::SaveAndRestore
shouldEmitSpaceFlag(shouldEmitSpace
);
913 llvm::SaveAndRestore
lastWasPunctuationFlag(lastWasPunctuation
);
914 for (FormatElement
*element
: el
->getThenElements())
915 genElementPrinter(element
, ctx
, os
);
917 os
.unindent() << "} else {\n";
919 for (FormatElement
*element
: el
->getElseElements())
920 genElementPrinter(element
, ctx
, os
);
921 os
.unindent() << "}\n";
924 void DefFormat::genWhitespacePrinter(WhitespaceElement
*el
, FmtContext
&ctx
,
926 if (el
->getValue() == "\\n") {
927 // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
929 os
<< tgfmt("$_printer << '\\n';\n", &ctx
);
930 } else if (!el
->getValue().empty()) {
931 os
<< tgfmt("$_printer << \"$0\";\n", &ctx
, el
->getValue());
933 lastWasPunctuation
= true;
935 shouldEmitSpace
= false;
//===----------------------------------------------------------------------===//
// DefFormatParser
//===----------------------------------------------------------------------===//
943 class DefFormatParser
: public FormatParser
{
945 DefFormatParser(llvm::SourceMgr
&mgr
, const AttrOrTypeDef
&def
)
946 : FormatParser(mgr
, def
.getLoc()[0]), def(def
),
947 seenParams(def
.getNumParameters()) {}
949 /// Parse the attribute or type format and create the format elements.
950 FailureOr
<DefFormat
> parse();
953 /// Verify the parsed elements.
954 LogicalResult
verify(SMLoc loc
, ArrayRef
<FormatElement
*> elements
) override
;
955 /// Verify the elements of a custom directive.
957 verifyCustomDirectiveArguments(SMLoc loc
,
958 ArrayRef
<FormatElement
*> arguments
) override
;
959 /// Verify the elements of an optional group.
960 LogicalResult
verifyOptionalGroupElements(SMLoc loc
,
961 ArrayRef
<FormatElement
*> elements
,
962 FormatElement
*anchor
) override
;
964 LogicalResult
markQualified(SMLoc loc
, FormatElement
*element
) override
;
966 /// Parse an attribute or type variable.
967 FailureOr
<FormatElement
*> parseVariableImpl(SMLoc loc
, StringRef name
,
968 Context ctx
) override
;
969 /// Parse an attribute or type format directive.
970 FailureOr
<FormatElement
*>
971 parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
, Context ctx
) override
;
974 /// Parse a `params` directive.
975 FailureOr
<FormatElement
*> parseParamsDirective(SMLoc loc
, Context ctx
);
976 /// Parse a `struct` directive.
977 FailureOr
<FormatElement
*> parseStructDirective(SMLoc loc
, Context ctx
);
979 /// Attribute or type tablegen def.
980 const AttrOrTypeDef
&def
;
982 /// Seen attribute or type parameters.
983 BitVector seenParams
;
987 LogicalResult
DefFormatParser::verify(SMLoc loc
,
988 ArrayRef
<FormatElement
*> elements
) {
989 // Check that all parameters are referenced in the format.
990 for (auto [index
, param
] : llvm::enumerate(def
.getParameters())) {
991 if (param
.isOptional())
993 if (!seenParams
.test(index
)) {
994 if (isa
<AttributeSelfTypeParameter
>(param
))
996 return emitError(loc
, "format is missing reference to parameter: " +
999 if (isa
<AttributeSelfTypeParameter
>(param
)) {
1000 return emitError(loc
,
1001 "unexpected self type parameter in assembly format");
1004 if (elements
.empty())
1006 // A `struct` directive that contains optional parameters cannot be followed
1007 // by a comma literal, which is ambiguous.
1008 for (auto it
: llvm::zip(elements
.drop_back(), elements
.drop_front())) {
1009 auto *structEl
= dyn_cast
<StructDirective
>(std::get
<0>(it
));
1010 auto *literalEl
= dyn_cast
<LiteralElement
>(std::get
<1>(it
));
1011 if (!structEl
|| !literalEl
)
1013 if (literalEl
->getSpelling() == "," && structEl
->hasOptionalParams()) {
1014 return emitError(loc
, "`struct` directive with optional parameters "
1015 "cannot be followed by a comma literal");
1021 LogicalResult
DefFormatParser::verifyCustomDirectiveArguments(
1022 SMLoc loc
, ArrayRef
<FormatElement
*> arguments
) {
1023 // Arguments are fully verified by the parser context.
1028 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc
,
1029 ArrayRef
<FormatElement
*> elements
,
1030 FormatElement
*anchor
) {
1031 // `params` and `struct` directives are allowed only if all the contained
1032 // parameters are optional.
1033 for (FormatElement
*el
: elements
) {
1034 if (auto *param
= dyn_cast
<ParameterElement
>(el
)) {
1035 if (!param
->isOptional()) {
1036 return emitError(loc
,
1037 "parameters in an optional group must be optional");
1039 } else if (auto *params
= dyn_cast
<ParamsDirective
>(el
)) {
1040 if (llvm::any_of(params
->getParams(), paramNotOptional
)) {
1041 return emitError(loc
, "`params` directive allowed in optional group "
1042 "only if all parameters are optional");
1044 } else if (auto *strct
= dyn_cast
<StructDirective
>(el
)) {
1045 if (llvm::any_of(strct
->getParams(), paramNotOptional
)) {
1046 return emitError(loc
, "`struct` is only allowed in an optional group "
1047 "if all captured parameters are optional");
1049 } else if (auto *custom
= dyn_cast
<CustomDirective
>(el
)) {
1050 for (FormatElement
*el
: custom
->getArguments()) {
1051 // If the custom argument is a variable, then it must be optional.
1052 if (auto *param
= dyn_cast
<ParameterElement
>(el
))
1053 if (!param
->isOptional())
1054 return emitError(loc
,
1055 "`custom` is only allowed in an optional group if "
1056 "all captured parameters are optional");
1060 // The anchor must be a parameter or one of the aforementioned directives.
1062 if (!isa
<ParameterElement
, ParamsDirective
, StructDirective
,
1063 CustomDirective
>(anchor
)) {
1065 loc
, "optional group anchor must be a parameter or directive");
1067 // If the anchor is a custom directive, make sure at least one of its
1068 // arguments is a bound parameter.
1069 if (auto *custom
= dyn_cast
<CustomDirective
>(anchor
)) {
1071 llvm::find_if(custom
->getArguments(), [](FormatElement
*el
) {
1072 return isa
<ParameterElement
>(el
);
1074 if (bound
== custom
->getArguments().end())
1075 return emitError(loc
, "`custom` directive with no bound parameters "
1076 "cannot be used as optional group anchor");
1082 LogicalResult
DefFormatParser::markQualified(SMLoc loc
,
1083 FormatElement
*element
) {
1084 if (!isa
<ParameterElement
>(element
))
1085 return emitError(loc
, "`qualified` argument list expected a variable");
1086 cast
<ParameterElement
>(element
)->setShouldBeQualified();
1090 FailureOr
<DefFormat
> DefFormatParser::parse() {
1091 FailureOr
<std::vector
<FormatElement
*>> elements
= FormatParser::parse();
1092 if (failed(elements
))
1094 return DefFormat(def
, std::move(*elements
));
1097 FailureOr
<FormatElement
*>
1098 DefFormatParser::parseVariableImpl(SMLoc loc
, StringRef name
, Context ctx
) {
1099 // Lookup the parameter.
1100 ArrayRef
<AttrOrTypeParameter
> params
= def
.getParameters();
1101 auto *it
= llvm::find_if(
1102 params
, [&](auto ¶m
) { return param
.getName() == name
; });
1104 // Check that the parameter reference is valid.
1105 if (it
== params
.end()) {
1106 return emitError(loc
,
1107 def
.getName() + " has no parameter named '" + name
+ "'");
1109 auto idx
= std::distance(params
.begin(), it
);
1111 if (ctx
!= RefDirectiveContext
) {
1112 // Check that the variable has not already been bound.
1113 if (seenParams
.test(idx
))
1114 return emitError(loc
, "duplicate parameter '" + name
+ "'");
1115 seenParams
.set(idx
);
1117 // Otherwise, to be referenced, a variable must have been bound.
1118 } else if (!seenParams
.test(idx
) && !isa
<AttributeSelfTypeParameter
>(*it
)) {
1119 return emitError(loc
, "parameter '" + name
+
1120 "' must be bound before it is referenced");
1123 return create
<ParameterElement
>(*it
);
1126 FailureOr
<FormatElement
*>
1127 DefFormatParser::parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
,
1131 case FormatToken::kw_qualified
:
1132 return parseQualifiedDirective(loc
, ctx
);
1133 case FormatToken::kw_params
:
1134 return parseParamsDirective(loc
, ctx
);
1135 case FormatToken::kw_struct
:
1136 return parseStructDirective(loc
, ctx
);
1138 return emitError(loc
, "unsupported directive kind");
1142 FailureOr
<FormatElement
*> DefFormatParser::parseParamsDirective(SMLoc loc
,
1144 // It doesn't make sense to allow references to all parameters in a custom
1145 // directive because parameters are the only things that can be bound.
1146 if (ctx
!= TopLevelContext
&& ctx
!= StructDirectiveContext
) {
1147 return emitError(loc
, "`params` can only be used at the top-level context "
1148 "or within a `struct` directive");
1151 // Collect all of the attribute's or type's parameters and ensure that none of
1152 // the parameters have already been captured.
1153 std::vector
<ParameterElement
*> vars
;
1154 for (const auto &it
: llvm::enumerate(def
.getParameters())) {
1155 if (seenParams
.test(it
.index())) {
1156 return emitError(loc
, "`params` captures duplicate parameter: " +
1157 it
.value().getName());
1159 // Self-type parameters are handled separately from the rest of the
1161 if (isa
<AttributeSelfTypeParameter
>(it
.value()))
1163 seenParams
.set(it
.index());
1164 vars
.push_back(create
<ParameterElement
>(it
.value()));
1166 return create
<ParamsDirective
>(std::move(vars
));
1169 FailureOr
<FormatElement
*> DefFormatParser::parseStructDirective(SMLoc loc
,
1171 if (ctx
!= TopLevelContext
)
1172 return emitError(loc
, "`struct` can only be used at the top-level context");
1174 if (failed(parseToken(FormatToken::l_paren
,
1175 "expected '(' before `struct` argument list")))
1178 // Parse variables captured by `struct`.
1179 std::vector
<ParameterElement
*> vars
;
1181 // Parse first captured parameter or a `params` directive.
1182 FailureOr
<FormatElement
*> var
= parseElement(StructDirectiveContext
);
1183 if (failed(var
) || !isa
<VariableElement
, ParamsDirective
>(*var
)) {
1184 return emitError(loc
,
1185 "`struct` argument list expected a variable or directive");
1187 if (isa
<VariableElement
>(*var
)) {
1188 // Parse any other parameters.
1189 vars
.push_back(cast
<ParameterElement
>(*var
));
1190 while (peekToken().is(FormatToken::comma
)) {
1192 var
= parseElement(StructDirectiveContext
);
1193 if (failed(var
) || !isa
<VariableElement
>(*var
))
1194 return emitError(loc
, "expected a variable in `struct` argument list");
1195 vars
.push_back(cast
<ParameterElement
>(*var
));
1198 // `struct(params)` captures all parameters in the attribute or type.
1199 vars
= cast
<ParamsDirective
>(*var
)->takeParams();
1202 if (failed(parseToken(FormatToken::r_paren
,
1203 "expected ')' at the end of an argument list")))
1206 return create
<StructDirective
>(std::move(vars
));
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
1213 void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef
&def
,
1215 MethodBody
&printer
) {
1216 llvm::SourceMgr mgr
;
1217 mgr
.AddNewSourceBuffer(
1218 llvm::MemoryBuffer::getMemBuffer(*def
.getAssemblyFormat()), SMLoc());
1220 // Parse the custom assembly format>
1221 DefFormatParser
fmtParser(mgr
, def
);
1222 FailureOr
<DefFormat
> format
= fmtParser
.parse();
1223 if (failed(format
)) {
1224 if (formatErrorIsFatal
)
1225 PrintFatalError(def
.getLoc(), "failed to parse assembly format");
1229 // Generate the parser and printer.
1230 format
->genParser(parser
);
1231 format
->genPrinter(printer
);