1 //===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "AttrOrTypeFormatGen.h"
10 #include "FormatGen.h"
11 #include "mlir/Support/LLVM.h"
12 #include "mlir/Support/LogicalResult.h"
13 #include "mlir/TableGen/AttrOrTypeDef.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/GenInfo.h"
16 #include "llvm/ADT/BitVector.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/ADT/StringSwitch.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 #include "llvm/Support/MemoryBuffer.h"
21 #include "llvm/Support/SaveAndRestore.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include "llvm/TableGen/Error.h"
24 #include "llvm/TableGen/TableGenBackend.h"
27 using namespace mlir::tblgen
;
31 //===----------------------------------------------------------------------===//
33 //===----------------------------------------------------------------------===//
36 /// This class represents an instance of a variable element. A variable refers
37 /// to an attribute or type parameter.
38 class ParameterElement
39 : public VariableElementBase
<VariableElement::Parameter
> {
41 ParameterElement(AttrOrTypeParameter param
) : param(param
) {}
43 /// Get the parameter in the element.
44 const AttrOrTypeParameter
&getParam() const { return param
; }
46 /// Indicate if this variable is printed "qualified" (that is it is
47 /// prefixed with the `#dialect.mnemonic`).
48 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
49 void setShouldBeQualified(bool qualified
= true) {
50 shouldBeQualifiedFlag
= qualified
;
53 /// Returns true if the element contains an optional parameter.
54 bool isOptional() const { return param
.isOptional(); }
56 /// Returns the name of the parameter.
57 StringRef
getName() const { return param
.getName(); }
59 /// Return the code to check whether the parameter is present.
60 auto genIsPresent(FmtContext
&ctx
, const Twine
&self
) const {
61 assert(isOptional() && "cannot guard on a mandatory parameter");
62 std::string valueStr
= tgfmt(*param
.getDefaultValue(), &ctx
).str();
63 ctx
.addSubst("_lhs", self
).addSubst("_rhs", valueStr
);
64 return tgfmt(getParam().getComparator(), &ctx
);
67 /// Generate the code to check whether the parameter should be printed.
68 MethodBody
&genPrintGuard(FmtContext
&ctx
, MethodBody
&os
) const {
69 assert(isOptional() && "cannot guard on a mandatory parameter");
70 std::string self
= param
.getAccessorName() + "()";
71 return os
<< "!(" << genIsPresent(ctx
, self
) << ")";
75 bool shouldBeQualifiedFlag
= false;
76 AttrOrTypeParameter param
;
79 /// Shorthand functions that can be used with ranged-based conditions.
80 static bool paramIsOptional(ParameterElement
*el
) { return el
->isOptional(); }
81 static bool paramNotOptional(ParameterElement
*el
) { return !el
->isOptional(); }
83 /// Base class for a directive that contains references to multiple variables.
84 template <DirectiveElement::Kind DirectiveKind
>
85 class ParamsDirectiveBase
: public DirectiveElementBase
<DirectiveKind
> {
87 using Base
= ParamsDirectiveBase
<DirectiveKind
>;
89 ParamsDirectiveBase(std::vector
<ParameterElement
*> &¶ms
)
90 : params(std::move(params
)) {}
92 /// Get the parameters contained in this directive.
93 ArrayRef
<ParameterElement
*> getParams() const { return params
; }
95 /// Get the number of parameters.
96 unsigned getNumParams() const { return params
.size(); }
98 /// Take all of the parameters from this directive.
99 std::vector
<ParameterElement
*> takeParams() { return std::move(params
); }
101 /// Returns true if there are optional parameters present.
102 bool hasOptionalParams() const {
103 return llvm::any_of(getParams(), paramIsOptional
);
107 /// The parameters captured by this directive.
108 std::vector
<ParameterElement
*> params
;
111 /// This class represents a `params` directive that refers to all parameters
112 /// of an attribute or type. When used as a top-level directive, it generates
113 /// a format of the form:
115 /// (param-value (`,` param-value)*)?
117 /// When used as an argument to another directive that accepts variables,
118 /// `params` can be used in place of manually listing all parameters of an
119 /// attribute or type.
120 class ParamsDirective
: public ParamsDirectiveBase
<DirectiveElement::Params
> {
125 /// This class represents a `struct` directive that generates a struct format
128 /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
130 class StructDirective
: public ParamsDirectiveBase
<DirectiveElement::Struct
> {
137 //===----------------------------------------------------------------------===//
139 //===----------------------------------------------------------------------===//
/// Fallback parser template for a parameter that declares no parser of its
/// own. `$0` is substituted with the parameter's C++ storage type.
static constexpr const char *defaultParameterParser =
    "::mlir::FieldParser<$0>::parse($_parser)";
/// Fallback printer template for a parameter that declares no printer of its
/// own. Prints the "stripped" form, as opposed to `qualifiedParameterPrinter`,
/// which keeps the dialect and mnemonic.
static constexpr const char *defaultParameterPrinter =
    "$_printer.printStrippedAttrOrType($_self)";
/// Printer template for a parameter marked as qualified: streaming the value
/// directly keeps its dialect and mnemonic instead of eliding them. It takes
/// precedence over a parameter's own custom printer.
static constexpr const char *qualifiedParameterPrinter = "$_printer << $_self";
/// Opening of the generated call that reports a parse error at the parser's
/// current location. The template has no positional placeholders: callers
/// place it in front of a message string literal and close the call
/// themselves (see `variableParser`, where it is substituted as `{2}`).
static const char *const parserErrorStr =
    "$_parser.emitError($_parser.getCurrentLocation(), ";
159 /// Code format to parse a variable. Separate by lines because variable parsers
160 /// may be generated inside other directives, which requires indentation.
162 /// {0}: The parameter name.
163 /// {1}: The parse code for the parameter.
164 /// {2}: Code template for printing an error.
165 /// {3}: Name of the attribute or type.
166 /// {4}: C++ class of the parameter.
167 static const char *const variableParser
= R
"(
168 // Parse variable '{0}'
170 if (::mlir::failed(_result_{0})) {{
171 {2}"failed to parse
{3} parameter
'{0}' which is to be a `
{4}`
");
176 //===----------------------------------------------------------------------===//
178 //===----------------------------------------------------------------------===//
183 DefFormat(const AttrOrTypeDef
&def
, std::vector
<FormatElement
*> &&elements
)
184 : def(def
), elements(std::move(elements
)) {}
186 /// Generate the attribute or type parser.
187 void genParser(MethodBody
&os
);
188 /// Generate the attribute or type printer.
189 void genPrinter(MethodBody
&os
);
192 /// Generate the parser code for a specific format element.
193 void genElementParser(FormatElement
*el
, FmtContext
&ctx
, MethodBody
&os
);
194 /// Generate the parser code for a literal.
195 void genLiteralParser(StringRef value
, FmtContext
&ctx
, MethodBody
&os
,
196 bool isOptional
= false);
197 /// Generate the parser code for a variable.
198 void genVariableParser(ParameterElement
*el
, FmtContext
&ctx
, MethodBody
&os
);
199 /// Generate the parser code for a `params` directive.
200 void genParamsParser(ParamsDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
201 /// Generate the parser code for a `struct` directive.
202 void genStructParser(StructDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
203 /// Generate the parser code for a `custom` directive.
204 void genCustomParser(CustomDirective
*el
, FmtContext
&ctx
, MethodBody
&os
,
205 bool isOptional
= false);
206 /// Generate the parser code for an optional group.
207 void genOptionalGroupParser(OptionalElement
*el
, FmtContext
&ctx
,
210 /// Generate the printer code for a specific format element.
211 void genElementPrinter(FormatElement
*el
, FmtContext
&ctx
, MethodBody
&os
);
212 /// Generate the printer code for a literal.
213 void genLiteralPrinter(StringRef value
, FmtContext
&ctx
, MethodBody
&os
);
214 /// Generate the printer code for a variable.
215 void genVariablePrinter(ParameterElement
*el
, FmtContext
&ctx
, MethodBody
&os
,
216 bool skipGuard
= false);
217 /// Generate a printer for comma-separated parameters.
218 void genCommaSeparatedPrinter(ArrayRef
<ParameterElement
*> params
,
219 FmtContext
&ctx
, MethodBody
&os
,
220 function_ref
<void(ParameterElement
*)> extra
);
221 /// Generate the printer code for a `params` directive.
222 void genParamsPrinter(ParamsDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
223 /// Generate the printer code for a `struct` directive.
224 void genStructPrinter(StructDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
225 /// Generate the printer code for a `custom` directive.
226 void genCustomPrinter(CustomDirective
*el
, FmtContext
&ctx
, MethodBody
&os
);
227 /// Generate the printer code for an optional group.
228 void genOptionalGroupPrinter(OptionalElement
*el
, FmtContext
&ctx
,
230 /// Generate a printer (or space eraser) for a whitespace element.
231 void genWhitespacePrinter(WhitespaceElement
*el
, FmtContext
&ctx
,
234 /// The ODS definition of the attribute or type whose format is being used to
235 /// generate a parser and printer.
236 const AttrOrTypeDef
&def
;
237 /// The list of top-level format elements returned by the assembly format
239 std::vector
<FormatElement
*> elements
;
241 /// Flags for printing spaces.
242 bool shouldEmitSpace
= false;
243 bool lastWasPunctuation
= false;
247 //===----------------------------------------------------------------------===//
249 //===----------------------------------------------------------------------===//
251 /// Generate a special-case "parser" for an attribute's self type parameter. The
252 /// self type parameter has special handling in the assembly format in that it
253 /// is derived from the optional trailing colon type after the attribute.
254 static void genAttrSelfTypeParser(MethodBody
&os
, const FmtContext
&ctx
,
255 const AttributeSelfTypeParameter
¶m
) {
256 // "Parser" for an attribute self type parameter that checks the
257 // optionally-parsed trailing colon type.
259 // $0: The C++ storage class of the type parameter.
260 // $1: The self type parameter name.
261 const char *const selfTypeParser
= R
"(
263 if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
264 _result_$1 = reqType;
266 $_parser.emitError($_loc, "invalid kind of type specified
");
271 // If the attribute self type parameter is required, emit code that emits an
272 // error if the trailing type was not parsed.
273 const char *const selfTypeRequired
= R
"( else {
274 $_parser.emitError($_loc, "expected a trailing type
");
278 os
<< tgfmt(selfTypeParser
, &ctx
, param
.getCppStorageType(), param
.getName());
279 if (!param
.isOptional())
280 os
<< tgfmt(selfTypeRequired
, &ctx
);
284 void DefFormat::genParser(MethodBody
&os
) {
286 ctx
.addSubst("_parser", "odsParser");
287 ctx
.addSubst("_ctxt", "odsParser.getContext()");
288 ctx
.withBuilder("odsBuilder");
289 if (isa
<AttrDef
>(def
))
290 ctx
.addSubst("_type", "odsType");
292 os
<< "::mlir::Builder odsBuilder(odsParser.getContext());\n";
294 // Store the initial location of the parser.
295 ctx
.addSubst("_loc", "odsLoc");
296 os
<< tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
300 // Declare variables to store all of the parameters. Allocated parameters
301 // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
302 // FailureOr<T> to defer type construction for parameters that are parsed in
303 // a loop (parsers return FailureOr anyways).
304 ArrayRef
<AttrOrTypeParameter
> params
= def
.getParameters();
305 for (const AttrOrTypeParameter
¶m
: params
) {
306 os
<< formatv("::mlir::FailureOr<{0}> _result_{1};\n",
307 param
.getCppStorageType(), param
.getName());
308 if (auto *selfTypeParam
= dyn_cast
<AttributeSelfTypeParameter
>(¶m
))
309 genAttrSelfTypeParser(os
, ctx
, *selfTypeParam
);
312 // Generate call to each parameter parser.
313 for (FormatElement
*el
: elements
)
314 genElementParser(el
, ctx
, os
);
316 // Emit an assert for each mandatory parameter. Triggering an assert means
317 // the generated parser is incorrect (i.e. there is a bug in this code).
318 for (const AttrOrTypeParameter
¶m
: params
) {
319 if (param
.isOptional())
321 os
<< formatv("assert(::mlir::succeeded(_result_{0}));\n", param
.getName());
324 // Generate call to the attribute or type builder. Use the checked getter
325 // if one was generated.
326 if (def
.genVerifyDecl()) {
327 os
<< tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
328 &ctx
, def
.getCppClassName());
330 os
<< tgfmt("return $0::get($_parser.getContext()", &ctx
,
331 def
.getCppClassName());
333 for (const AttrOrTypeParameter
¶m
: params
) {
335 std::string paramSelfStr
;
336 llvm::raw_string_ostream
selfOs(paramSelfStr
);
337 if (std::optional
<StringRef
> defaultValue
= param
.getDefaultValue()) {
338 selfOs
<< formatv("(_result_{0}.value_or(", param
.getName())
339 << tgfmt(*defaultValue
, &ctx
) << "))";
341 selfOs
<< formatv("(*_result_{0})", param
.getName());
343 ctx
.addSubst(param
.getName(), selfOs
.str());
344 os
<< param
.getCppType() << "("
345 << tgfmt(param
.getConvertFromStorage(), &ctx
.withSelf(selfOs
.str()))
351 void DefFormat::genElementParser(FormatElement
*el
, FmtContext
&ctx
,
353 if (auto *literal
= dyn_cast
<LiteralElement
>(el
))
354 return genLiteralParser(literal
->getSpelling(), ctx
, os
);
355 if (auto *var
= dyn_cast
<ParameterElement
>(el
))
356 return genVariableParser(var
, ctx
, os
);
357 if (auto *params
= dyn_cast
<ParamsDirective
>(el
))
358 return genParamsParser(params
, ctx
, os
);
359 if (auto *strct
= dyn_cast
<StructDirective
>(el
))
360 return genStructParser(strct
, ctx
, os
);
361 if (auto *custom
= dyn_cast
<CustomDirective
>(el
))
362 return genCustomParser(custom
, ctx
, os
);
363 if (auto *optional
= dyn_cast
<OptionalElement
>(el
))
364 return genOptionalGroupParser(optional
, ctx
, os
);
365 if (isa
<WhitespaceElement
>(el
))
368 llvm_unreachable("unknown format element");
371 void DefFormat::genLiteralParser(StringRef value
, FmtContext
&ctx
,
372 MethodBody
&os
, bool isOptional
) {
373 os
<< "// Parse literal '" << value
<< "'\n";
374 os
<< tgfmt("if ($_parser.parse", &ctx
);
377 if (value
.front() == '_' || isalpha(value
.front())) {
378 os
<< "Keyword(\"" << value
<< "\")";
380 os
<< StringSwitch
<StringRef
>(value
)
386 .Case(">", "Greater")
391 .Case("[", "LSquare")
392 .Case("]", "RSquare")
393 .Case("?", "Question")
396 .Case("...", "Ellipsis")
400 // Leave the `if` unclosed to guard optional groups.
403 // Parser will emit an error
404 os
<< ") return {};\n";
407 void DefFormat::genVariableParser(ParameterElement
*el
, FmtContext
&ctx
,
409 // Check for a custom parser. Use the default attribute parser otherwise.
410 const AttrOrTypeParameter
¶m
= el
->getParam();
411 auto customParser
= param
.getParser();
413 customParser
? *customParser
: StringRef(defaultParameterParser
);
414 os
<< formatv(variableParser
, param
.getName(),
415 tgfmt(parser
, &ctx
, param
.getCppStorageType()),
416 tgfmt(parserErrorStr
, &ctx
), def
.getName(), param
.getCppType());
419 void DefFormat::genParamsParser(ParamsDirective
*el
, FmtContext
&ctx
,
421 os
<< "// Parse parameter list\n";
423 // If there are optional parameters, we need to switch to `parseOptionalComma`
424 // if there are no more required parameters after a certain point.
425 bool hasOptional
= el
->hasOptionalParams();
427 // Wrap everything in a do-while so that we can `break`.
432 ArrayRef
<ParameterElement
*> params
= el
->getParams();
433 using IteratorT
= ParameterElement
*const *;
434 IteratorT it
= params
.begin();
436 // Find the last required parameter. Commas become optional afterwards.
437 // Note: IteratorT's copy assignment is deleted.
438 ParameterElement
*lastReq
= nullptr;
439 for (ParameterElement
*param
: params
)
440 if (!param
->isOptional())
442 IteratorT lastReqIt
= lastReq
? llvm::find(params
, lastReq
) : params
.begin();
444 auto eachFn
= [&](ParameterElement
*el
) { genVariableParser(el
, ctx
, os
); };
445 auto betweenFn
= [&](IteratorT it
) {
446 ParameterElement
*el
= *std::prev(it
);
447 // Parse a comma if the last optional parameter had a value.
448 if (el
->isOptional()) {
449 os
<< formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
451 el
->genIsPresent(ctx
, "(*_result_" + el
->getName() + ")"));
454 if (it
<= lastReqIt
) {
455 genLiteralParser(",", ctx
, os
);
457 genLiteralParser(",", ctx
, os
, /*isOptional=*/true);
460 if (el
->isOptional())
461 os
.unindent() << "}\n";
465 if (it
!= params
.end()) {
467 for (IteratorT e
= params
.end(); it
!= e
; ++it
) {
474 os
.unindent() << "} while(false);\n";
477 void DefFormat::genStructParser(StructDirective
*el
, FmtContext
&ctx
,
479 // Loop declaration for struct parser with only required parameters.
481 // $0: Number of expected parameters.
482 const char *const loopHeader
= R
"(
483 for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
486 // Loop body start for struct parser.
487 const char *const loopStart
= R
"(
488 ::llvm::StringRef _paramKey;
489 if ($_parser.parseKeyword(&_paramKey)) {
490 $_parser.emitError($_parser.getCurrentLocation(),
491 "expected a parameter name in
struct");
494 if (!_loop_body(_paramKey)) return {};
497 // Struct parser loop end. Check for duplicate or unknown struct parameters.
499 // {0}: Code template for printing an error.
500 const char *const loopEnd
= R
"({{
501 {0}"duplicate
or unknown
struct parameter name
: ") << _paramKey;
506 // Struct parser loop terminator. Parse a comma except on the last element.
508 // {0}: Number of elements in the struct.
509 const char *const loopTerminator
= R
"(
510 if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
515 // Check that a mandatory parameter was parsed.
517 // {0}: Name of the parameter.
518 const char *const checkParam
= R
"(
520 {1}"struct is missing required parameter
: ") << "{0}";
525 // First iteration of the loop parsing an optional struct.
526 const char *const optionalStructFirst
= R
"(
527 ::llvm::StringRef _paramKey;
528 if (!$_parser.parseOptionalKeyword(&_paramKey)) {
529 if (!_loop_body(_paramKey)) return {};
530 while (!$_parser.parseOptionalComma()) {
533 os
<< "// Parse parameter struct\n";
535 // Declare a "seen" variable for each key.
536 for (ParameterElement
*param
: el
->getParams())
537 os
<< formatv("bool _seen_{0} = false;\n", param
->getName());
539 // Generate the body of the parsing loop inside a lambda.
542 << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
543 genLiteralParser("=", ctx
, os
.indent());
544 for (ParameterElement
*param
: el
->getParams()) {
545 os
<< formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
546 " _seen_{0} = true;\n",
548 genVariableParser(param
, ctx
, os
.indent());
549 os
.unindent() << "} else ";
550 // Print the check for duplicate or unknown parameter.
552 os
.getStream().printReindented(strfmt(loopEnd
, tgfmt(parserErrorStr
, &ctx
)));
553 os
<< "return true;\n";
554 os
.unindent() << "};\n";
556 // Generate the parsing loop. If optional parameters are present, then the
557 // parse loop is guarded by commas.
558 unsigned numOptional
= llvm::count_if(el
->getParams(), paramIsOptional
);
560 // If the struct itself is optional, pull out the first iteration.
561 if (numOptional
== el
->getNumParams()) {
562 os
.getStream().printReindented(tgfmt(optionalStructFirst
, &ctx
).str());
568 os
.getStream().printReindented(
569 tgfmt(loopHeader
, &ctx
, el
->getNumParams()).str());
572 os
.getStream().printReindented(tgfmt(loopStart
, &ctx
).str());
575 // Print the loop terminator. For optional parameters, we have to check that
576 // all mandatory parameters have been parsed.
577 // The whole struct is optional if all its parameters are optional.
579 if (numOptional
== el
->getNumParams()) {
581 os
.unindent() << "}\n";
583 os
<< tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx
);
584 for (ParameterElement
*param
: el
->getParams()) {
585 if (param
->isOptional())
587 os
.getStream().printReindented(
588 strfmt(checkParam
, param
->getName(), tgfmt(parserErrorStr
, &ctx
)));
592 // Because the loop loops N times and each non-failing iteration sets 1 of
593 // N flags, successfully exiting the loop means that all parameters have
594 // been seen. `parseOptionalComma` would cause issues with any formats that
595 // use "struct(...) `,`" because structs aren't surrounded by braces.
596 os
.getStream().printReindented(strfmt(loopTerminator
, el
->getNumParams()));
598 os
.unindent() << "}\n";
601 void DefFormat::genCustomParser(CustomDirective
*el
, FmtContext
&ctx
,
602 MethodBody
&os
, bool isOptional
) {
606 // Bound variables are passed directly to the parser as `FailureOr<T> &`.
607 // Referenced variables are passed as `T`. The custom parser fails if it
608 // returns failure or if any of the required parameters failed.
609 os
<< tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx
);
610 os
<< "(void)odsCustomLoc;\n";
611 os
<< tgfmt("auto odsCustomResult = parse$0($_parser", &ctx
, el
->getName());
613 for (FormatElement
*arg
: el
->getArguments()) {
615 if (auto *param
= dyn_cast
<ParameterElement
>(arg
))
616 os
<< "::mlir::detail::unwrapForCustomParse(_result_" << param
->getName()
618 else if (auto *ref
= dyn_cast
<RefDirective
>(arg
))
619 os
<< "*_result_" << cast
<ParameterElement
>(ref
->getArg())->getName();
621 os
<< tgfmt(cast
<StringElement
>(arg
)->getValue(), &ctx
);
623 os
.unindent() << ");\n";
625 os
<< "if (!odsCustomResult) return {};\n";
626 os
<< "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
628 os
<< "if (::mlir::failed(odsCustomResult)) return {};\n";
630 for (FormatElement
*arg
: el
->getArguments()) {
631 if (auto *param
= dyn_cast
<ParameterElement
>(arg
)) {
632 if (param
->isOptional())
634 os
<< formatv("if (::mlir::failed(_result_{0})) {{\n", param
->getName());
635 os
.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx
)
636 << "\"custom parser failed to parse parameter '"
637 << param
->getName() << "'\");\n";
638 os
<< "return " << (isOptional
? "::mlir::failure()" : "{}") << ";\n";
639 os
.unindent() << "}\n";
643 os
.unindent() << "}\n";
646 void DefFormat::genOptionalGroupParser(OptionalElement
*el
, FmtContext
&ctx
,
648 ArrayRef
<FormatElement
*> thenElements
=
649 el
->getThenElements(/*parseable=*/true);
651 FormatElement
*first
= thenElements
.front();
652 const auto guardOn
= [&](auto params
) {
656 [&](ParameterElement
*el
) {
657 os
<< formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
663 if (auto *literal
= dyn_cast
<LiteralElement
>(first
)) {
664 genLiteralParser(literal
->getSpelling(), ctx
, os
, /*isOptional=*/true);
666 } else if (auto *param
= dyn_cast
<ParameterElement
>(first
)) {
667 genVariableParser(param
, ctx
, os
);
668 guardOn(llvm::ArrayRef(param
));
669 } else if (auto *params
= dyn_cast
<ParamsDirective
>(first
)) {
670 genParamsParser(params
, ctx
, os
);
671 guardOn(params
->getParams());
672 } else if (auto *custom
= dyn_cast
<CustomDirective
>(first
)) {
673 os
<< "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
675 genCustomParser(custom
, ctx
, os
, /*isOptional=*/true);
676 os
<< "return ::mlir::success();\n";
678 os
<< "}(); result.has_value() && ::mlir::failed(*result)) {\n";
680 os
<< "return {};\n";
682 os
<< "} else if (result.has_value()) {\n";
684 auto *strct
= cast
<StructDirective
>(first
);
685 genStructParser(strct
, ctx
, os
);
686 guardOn(params
->getParams());
690 // Generate the parsers for the else elements.
691 for (FormatElement
*element
: el
->getElseElements(/*parseable=*/true))
692 genElementParser(element
, ctx
, os
);
693 os
.unindent() << "} else {\n";
695 for (FormatElement
*element
: thenElements
.drop_front())
696 genElementParser(element
, ctx
, os
);
697 os
.unindent() << "}\n";
700 //===----------------------------------------------------------------------===//
702 //===----------------------------------------------------------------------===//
704 void DefFormat::genPrinter(MethodBody
&os
) {
706 ctx
.addSubst("_printer", "odsPrinter");
707 ctx
.addSubst("_ctxt", "getContext()");
708 ctx
.withBuilder("odsBuilder");
710 os
<< "::mlir::Builder odsBuilder(getContext());\n";
712 // Generate printers.
713 shouldEmitSpace
= true;
714 lastWasPunctuation
= false;
715 for (FormatElement
*el
: elements
)
716 genElementPrinter(el
, ctx
, os
);
719 void DefFormat::genElementPrinter(FormatElement
*el
, FmtContext
&ctx
,
721 if (auto *literal
= dyn_cast
<LiteralElement
>(el
))
722 return genLiteralPrinter(literal
->getSpelling(), ctx
, os
);
723 if (auto *params
= dyn_cast
<ParamsDirective
>(el
))
724 return genParamsPrinter(params
, ctx
, os
);
725 if (auto *strct
= dyn_cast
<StructDirective
>(el
))
726 return genStructPrinter(strct
, ctx
, os
);
727 if (auto *custom
= dyn_cast
<CustomDirective
>(el
))
728 return genCustomPrinter(custom
, ctx
, os
);
729 if (auto *var
= dyn_cast
<ParameterElement
>(el
))
730 return genVariablePrinter(var
, ctx
, os
);
731 if (auto *optional
= dyn_cast
<OptionalElement
>(el
))
732 return genOptionalGroupPrinter(optional
, ctx
, os
);
733 if (auto *whitespace
= dyn_cast
<WhitespaceElement
>(el
))
734 return genWhitespacePrinter(whitespace
, ctx
, os
);
736 llvm::PrintFatalError("unsupported format element");
739 void DefFormat::genLiteralPrinter(StringRef value
, FmtContext
&ctx
,
741 // Don't insert a space before certain punctuation.
743 shouldEmitSpace
&& shouldEmitSpaceBefore(value
, lastWasPunctuation
);
744 os
<< tgfmt("$_printer$0 << \"$1\";\n", &ctx
, needSpace
? " << ' '" : "",
749 value
.size() != 1 || !StringRef("<({[").contains(value
.front());
750 lastWasPunctuation
= value
.front() != '_' && !isalpha(value
.front());
753 void DefFormat::genVariablePrinter(ParameterElement
*el
, FmtContext
&ctx
,
754 MethodBody
&os
, bool skipGuard
) {
755 const AttrOrTypeParameter
¶m
= el
->getParam();
756 ctx
.withSelf(param
.getAccessorName() + "()");
758 // Guard the printer on the presence of optional parameters and that they
759 // aren't equal to their default values (if they have one).
760 if (el
->isOptional() && !skipGuard
) {
761 el
->genPrintGuard(ctx
, os
<< "if (") << ") {\n";
765 // Insert a space before the next parameter, if necessary.
766 if (shouldEmitSpace
|| !lastWasPunctuation
)
767 os
<< tgfmt("$_printer << ' ';\n", &ctx
);
768 shouldEmitSpace
= true;
769 lastWasPunctuation
= false;
771 if (el
->shouldBeQualified())
772 os
<< tgfmt(qualifiedParameterPrinter
, &ctx
) << ";\n";
773 else if (auto printer
= param
.getPrinter())
774 os
<< tgfmt(*printer
, &ctx
) << ";\n";
776 os
<< tgfmt(defaultParameterPrinter
, &ctx
) << ";\n";
778 if (el
->isOptional() && !skipGuard
)
779 os
.unindent() << "}\n";
782 /// Generate code to guard printing on the presence of any optional parameters.
783 template <typename ParameterRange
>
784 static void guardOnAny(FmtContext
&ctx
, MethodBody
&os
, ParameterRange
&¶ms
,
785 bool inverted
= false) {
791 [&](ParameterElement
*param
) { param
->genPrintGuard(ctx
, os
); }, " || ");
798 void DefFormat::genCommaSeparatedPrinter(
799 ArrayRef
<ParameterElement
*> params
, FmtContext
&ctx
, MethodBody
&os
,
800 function_ref
<void(ParameterElement
*)> extra
) {
801 // Emit a space if necessary, but only if at least one of the parameters is printed.
802 if (shouldEmitSpace
|| !lastWasPunctuation
) {
803 bool allOptional
= llvm::all_of(params
, paramIsOptional
);
805 guardOnAny(ctx
, os
, params
);
806 os
<< tgfmt("$_printer << ' ';\n", &ctx
);
808 os
.unindent() << "}\n";
811 // The first printed element does not need to emit a comma.
813 os
.indent() << "bool _firstPrinted = true;\n";
814 for (ParameterElement
*param
: params
) {
815 if (param
->isOptional()) {
816 param
->genPrintGuard(ctx
, os
<< "if (") << ") {\n";
819 os
<< tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx
);
820 os
<< "_firstPrinted = false;\n";
822 shouldEmitSpace
= false;
823 lastWasPunctuation
= true;
824 genVariablePrinter(param
, ctx
, os
);
825 if (param
->isOptional())
826 os
.unindent() << "}\n";
828 os
.unindent() << "}\n";
831 void DefFormat::genParamsPrinter(ParamsDirective
*el
, FmtContext
&ctx
,
833 genCommaSeparatedPrinter(llvm::to_vector(el
->getParams()), ctx
, os
,
834 [&](ParameterElement
*param
) {});
837 void DefFormat::genStructPrinter(StructDirective
*el
, FmtContext
&ctx
,
839 genCommaSeparatedPrinter(
840 llvm::to_vector(el
->getParams()), ctx
, os
, [&](ParameterElement
*param
) {
841 os
<< tgfmt("$_printer << \"$0 = \";\n", &ctx
, param
->getName());
845 void DefFormat::genCustomPrinter(CustomDirective
*el
, FmtContext
&ctx
,
847 // Insert a space before the custom directive, if necessary.
848 if (shouldEmitSpace
|| !lastWasPunctuation
)
849 os
<< tgfmt("$_printer << ' ';\n", &ctx
);
850 shouldEmitSpace
= true;
851 lastWasPunctuation
= false;
853 os
<< tgfmt("print$0($_printer", &ctx
, el
->getName());
855 for (FormatElement
*arg
: el
->getArguments()) {
857 if (auto *param
= dyn_cast
<ParameterElement
>(arg
)) {
858 os
<< param
->getParam().getAccessorName() << "()";
859 } else if (auto *ref
= dyn_cast
<RefDirective
>(arg
)) {
860 os
<< cast
<ParameterElement
>(ref
->getArg())->getParam().getAccessorName()
863 os
<< tgfmt(cast
<StringElement
>(arg
)->getValue(), &ctx
);
866 os
.unindent() << ");\n";
869 void DefFormat::genOptionalGroupPrinter(OptionalElement
*el
, FmtContext
&ctx
,
871 FormatElement
*anchor
= el
->getAnchor();
872 if (auto *param
= dyn_cast
<ParameterElement
>(anchor
)) {
873 guardOnAny(ctx
, os
, llvm::ArrayRef(param
), el
->isInverted());
874 } else if (auto *params
= dyn_cast
<ParamsDirective
>(anchor
)) {
875 guardOnAny(ctx
, os
, params
->getParams(), el
->isInverted());
876 } else if (auto *strct
= dyn_cast
<StructDirective
>(anchor
)) {
877 guardOnAny(ctx
, os
, strct
->getParams(), el
->isInverted());
879 auto *custom
= cast
<CustomDirective
>(anchor
);
881 llvm::make_filter_range(
882 llvm::map_range(custom
->getArguments(),
883 [](FormatElement
*el
) {
884 return dyn_cast
<ParameterElement
>(el
);
886 [](ParameterElement
*param
) { return !!param
; }),
889 // Generate the printer for the contained elements.
891 llvm::SaveAndRestore
shouldEmitSpaceFlag(shouldEmitSpace
);
892 llvm::SaveAndRestore
lastWasPunctuationFlag(lastWasPunctuation
);
893 for (FormatElement
*element
: el
->getThenElements())
894 genElementPrinter(element
, ctx
, os
);
896 os
.unindent() << "} else {\n";
898 for (FormatElement
*element
: el
->getElseElements())
899 genElementPrinter(element
, ctx
, os
);
900 os
.unindent() << "}\n";
903 void DefFormat::genWhitespacePrinter(WhitespaceElement
*el
, FmtContext
&ctx
,
905 if (el
->getValue() == "\\n") {
906 // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
908 os
<< tgfmt("$_printer << '\\n';\n", &ctx
);
909 } else if (!el
->getValue().empty()) {
910 os
<< tgfmt("$_printer << \"$0\";\n", &ctx
, el
->getValue());
912 lastWasPunctuation
= true;
914 shouldEmitSpace
= false;
917 //===----------------------------------------------------------------------===//
919 //===----------------------------------------------------------------------===//
922 class DefFormatParser
: public FormatParser
{
924 DefFormatParser(llvm::SourceMgr
&mgr
, const AttrOrTypeDef
&def
)
925 : FormatParser(mgr
, def
.getLoc()[0]), def(def
),
926 seenParams(def
.getNumParameters()) {}
928 /// Parse the attribute or type format and create the format elements.
929 FailureOr
<DefFormat
> parse();
932 /// Verify the parsed elements.
933 LogicalResult
verify(SMLoc loc
, ArrayRef
<FormatElement
*> elements
) override
;
934 /// Verify the elements of a custom directive.
936 verifyCustomDirectiveArguments(SMLoc loc
,
937 ArrayRef
<FormatElement
*> arguments
) override
;
938 /// Verify the elements of an optional group.
939 LogicalResult
verifyOptionalGroupElements(SMLoc loc
,
940 ArrayRef
<FormatElement
*> elements
,
941 FormatElement
*anchor
) override
;
943 /// Parse an attribute or type variable.
944 FailureOr
<FormatElement
*> parseVariableImpl(SMLoc loc
, StringRef name
,
945 Context ctx
) override
;
946 /// Parse an attribute or type format directive.
947 FailureOr
<FormatElement
*>
948 parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
, Context ctx
) override
;
951 /// Parse a `params` directive.
952 FailureOr
<FormatElement
*> parseParamsDirective(SMLoc loc
, Context ctx
);
953 /// Parse a `qualified` directive.
954 FailureOr
<FormatElement
*> parseQualifiedDirective(SMLoc loc
, Context ctx
);
955 /// Parse a `struct` directive.
956 FailureOr
<FormatElement
*> parseStructDirective(SMLoc loc
, Context ctx
);
957 /// Parse a `ref` directive.
958 FailureOr
<FormatElement
*> parseRefDirective(SMLoc loc
, Context ctx
);
960 /// Attribute or type tablegen def.
961 const AttrOrTypeDef
&def
;
963 /// Seen attribute or type parameters.
964 BitVector seenParams
;
968 LogicalResult
DefFormatParser::verify(SMLoc loc
,
969 ArrayRef
<FormatElement
*> elements
) {
970 // Check that all parameters are referenced in the format.
971 for (auto [index
, param
] : llvm::enumerate(def
.getParameters())) {
972 if (param
.isOptional())
974 if (!seenParams
.test(index
)) {
975 if (isa
<AttributeSelfTypeParameter
>(param
))
977 return emitError(loc
, "format is missing reference to parameter: " +
980 if (isa
<AttributeSelfTypeParameter
>(param
)) {
981 return emitError(loc
,
982 "unexpected self type parameter in assembly format");
985 if (elements
.empty())
987 // A `struct` directive that contains optional parameters cannot be followed
988 // by a comma literal, which is ambiguous.
989 for (auto it
: llvm::zip(elements
.drop_back(), elements
.drop_front())) {
990 auto *structEl
= dyn_cast
<StructDirective
>(std::get
<0>(it
));
991 auto *literalEl
= dyn_cast
<LiteralElement
>(std::get
<1>(it
));
992 if (!structEl
|| !literalEl
)
994 if (literalEl
->getSpelling() == "," && structEl
->hasOptionalParams()) {
995 return emitError(loc
, "`struct` directive with optional parameters "
996 "cannot be followed by a comma literal");
1002 LogicalResult
DefFormatParser::verifyCustomDirectiveArguments(
1003 SMLoc loc
, ArrayRef
<FormatElement
*> arguments
) {
1004 // Arguments are fully verified by the parser context.
1009 DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc
,
1010 ArrayRef
<FormatElement
*> elements
,
1011 FormatElement
*anchor
) {
1012 // `params` and `struct` directives are allowed only if all the contained
1013 // parameters are optional.
1014 for (FormatElement
*el
: elements
) {
1015 if (auto *param
= dyn_cast
<ParameterElement
>(el
)) {
1016 if (!param
->isOptional()) {
1017 return emitError(loc
,
1018 "parameters in an optional group must be optional");
1020 } else if (auto *params
= dyn_cast
<ParamsDirective
>(el
)) {
1021 if (llvm::any_of(params
->getParams(), paramNotOptional
)) {
1022 return emitError(loc
, "`params` directive allowed in optional group "
1023 "only if all parameters are optional");
1025 } else if (auto *strct
= dyn_cast
<StructDirective
>(el
)) {
1026 if (llvm::any_of(strct
->getParams(), paramNotOptional
)) {
1027 return emitError(loc
, "`struct` is only allowed in an optional group "
1028 "if all captured parameters are optional");
1030 } else if (auto *custom
= dyn_cast
<CustomDirective
>(el
)) {
1031 for (FormatElement
*el
: custom
->getArguments()) {
1032 // If the custom argument is a variable, then it must be optional.
1033 if (auto *param
= dyn_cast
<ParameterElement
>(el
))
1034 if (!param
->isOptional())
1035 return emitError(loc
,
1036 "`custom` is only allowed in an optional group if "
1037 "all captured parameters are optional");
1041 // The anchor must be a parameter or one of the aforementioned directives.
1043 if (!isa
<ParameterElement
, ParamsDirective
, StructDirective
,
1044 CustomDirective
>(anchor
)) {
1046 loc
, "optional group anchor must be a parameter or directive");
1048 // If the anchor is a custom directive, make sure at least one of its
1049 // arguments is a bound parameter.
1050 if (auto *custom
= dyn_cast
<CustomDirective
>(anchor
)) {
1052 llvm::find_if(custom
->getArguments(), [](FormatElement
*el
) {
1053 return isa
<ParameterElement
>(el
);
1055 if (bound
== custom
->getArguments().end())
1056 return emitError(loc
, "`custom` directive with no bound parameters "
1057 "cannot be used as optional group anchor");
1063 FailureOr
<DefFormat
> DefFormatParser::parse() {
1064 FailureOr
<std::vector
<FormatElement
*>> elements
= FormatParser::parse();
1065 if (failed(elements
))
1067 return DefFormat(def
, std::move(*elements
));
1070 FailureOr
<FormatElement
*>
1071 DefFormatParser::parseVariableImpl(SMLoc loc
, StringRef name
, Context ctx
) {
1072 // Lookup the parameter.
1073 ArrayRef
<AttrOrTypeParameter
> params
= def
.getParameters();
1074 auto *it
= llvm::find_if(
1075 params
, [&](auto ¶m
) { return param
.getName() == name
; });
1077 // Check that the parameter reference is valid.
1078 if (it
== params
.end()) {
1079 return emitError(loc
,
1080 def
.getName() + " has no parameter named '" + name
+ "'");
1082 auto idx
= std::distance(params
.begin(), it
);
1084 if (ctx
!= RefDirectiveContext
) {
1085 // Check that the variable has not already been bound.
1086 if (seenParams
.test(idx
))
1087 return emitError(loc
, "duplicate parameter '" + name
+ "'");
1088 seenParams
.set(idx
);
1090 // Otherwise, to be referenced, a variable must have been bound.
1091 } else if (!seenParams
.test(idx
) && !isa
<AttributeSelfTypeParameter
>(*it
)) {
1092 return emitError(loc
, "parameter '" + name
+
1093 "' must be bound before it is referenced");
1096 return create
<ParameterElement
>(*it
);
1099 FailureOr
<FormatElement
*>
1100 DefFormatParser::parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
,
1104 case FormatToken::kw_qualified
:
1105 return parseQualifiedDirective(loc
, ctx
);
1106 case FormatToken::kw_params
:
1107 return parseParamsDirective(loc
, ctx
);
1108 case FormatToken::kw_struct
:
1109 return parseStructDirective(loc
, ctx
);
1110 case FormatToken::kw_ref
:
1111 return parseRefDirective(loc
, ctx
);
1112 case FormatToken::kw_custom
:
1113 return parseCustomDirective(loc
, ctx
);
1116 return emitError(loc
, "unsupported directive kind");
1120 FailureOr
<FormatElement
*>
1121 DefFormatParser::parseQualifiedDirective(SMLoc loc
, Context ctx
) {
1122 if (failed(parseToken(FormatToken::l_paren
,
1123 "expected '(' before argument list")))
1125 FailureOr
<FormatElement
*> var
= parseElement(ctx
);
1128 if (!isa
<ParameterElement
>(*var
))
1129 return emitError(loc
, "`qualified` argument list expected a variable");
1130 cast
<ParameterElement
>(*var
)->setShouldBeQualified();
1132 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
1137 FailureOr
<FormatElement
*> DefFormatParser::parseParamsDirective(SMLoc loc
,
1139 // It doesn't make sense to allow references to all parameters in a custom
1140 // directive because parameters are the only things that can be bound.
1141 if (ctx
!= TopLevelContext
&& ctx
!= StructDirectiveContext
) {
1142 return emitError(loc
, "`params` can only be used at the top-level context "
1143 "or within a `struct` directive");
1146 // Collect all of the attribute's or type's parameters and ensure that none of
1147 // the parameters have already been captured.
1148 std::vector
<ParameterElement
*> vars
;
1149 for (const auto &it
: llvm::enumerate(def
.getParameters())) {
1150 if (seenParams
.test(it
.index())) {
1151 return emitError(loc
, "`params` captures duplicate parameter: " +
1152 it
.value().getName());
1154 // Self-type parameters are handled separately from the rest of the
1156 if (isa
<AttributeSelfTypeParameter
>(it
.value()))
1158 seenParams
.set(it
.index());
1159 vars
.push_back(create
<ParameterElement
>(it
.value()));
1161 return create
<ParamsDirective
>(std::move(vars
));
1164 FailureOr
<FormatElement
*> DefFormatParser::parseStructDirective(SMLoc loc
,
1166 if (ctx
!= TopLevelContext
)
1167 return emitError(loc
, "`struct` can only be used at the top-level context");
1169 if (failed(parseToken(FormatToken::l_paren
,
1170 "expected '(' before `struct` argument list")))
1173 // Parse variables captured by `struct`.
1174 std::vector
<ParameterElement
*> vars
;
1176 // Parse first captured parameter or a `params` directive.
1177 FailureOr
<FormatElement
*> var
= parseElement(StructDirectiveContext
);
1178 if (failed(var
) || !isa
<VariableElement
, ParamsDirective
>(*var
)) {
1179 return emitError(loc
,
1180 "`struct` argument list expected a variable or directive");
1182 if (isa
<VariableElement
>(*var
)) {
1183 // Parse any other parameters.
1184 vars
.push_back(cast
<ParameterElement
>(*var
));
1185 while (peekToken().is(FormatToken::comma
)) {
1187 var
= parseElement(StructDirectiveContext
);
1188 if (failed(var
) || !isa
<VariableElement
>(*var
))
1189 return emitError(loc
, "expected a variable in `struct` argument list");
1190 vars
.push_back(cast
<ParameterElement
>(*var
));
1193 // `struct(params)` captures all parameters in the attribute or type.
1194 vars
= cast
<ParamsDirective
>(*var
)->takeParams();
1197 if (failed(parseToken(FormatToken::r_paren
,
1198 "expected ')' at the end of an argument list")))
1201 return create
<StructDirective
>(std::move(vars
));
1204 FailureOr
<FormatElement
*> DefFormatParser::parseRefDirective(SMLoc loc
,
1206 if (ctx
!= CustomDirectiveContext
)
1207 return emitError(loc
, "`ref` is only allowed inside custom directives");
1209 // Parse the child parameter element.
1210 FailureOr
<FormatElement
*> child
;
1211 if (failed(parseToken(FormatToken::l_paren
, "expected '('")) ||
1212 failed(child
= parseElement(RefDirectiveContext
)) ||
1213 failed(parseToken(FormatToken::r_paren
, "expeced ')'")))
1216 // Only parameter elements are allowed to be parsed under a `ref` directive.
1217 return create
<RefDirective
>(*child
);
1220 //===----------------------------------------------------------------------===//
1222 //===----------------------------------------------------------------------===//
1224 void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef
&def
,
1226 MethodBody
&printer
) {
1227 llvm::SourceMgr mgr
;
1228 mgr
.AddNewSourceBuffer(
1229 llvm::MemoryBuffer::getMemBuffer(*def
.getAssemblyFormat()), SMLoc());
1231 // Parse the custom assembly format>
1232 DefFormatParser
fmtParser(mgr
, def
);
1233 FailureOr
<DefFormat
> format
= fmtParser
.parse();
1234 if (failed(format
)) {
1235 if (formatErrorIsFatal
)
1236 PrintFatalError(def
.getLoc(), "failed to parse assembly format");
1240 // Generate the parser and printer.
1241 format
->genParser(parser
);
1242 format
->genPrinter(printer
);