[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / OpFormatGen.cpp
blob18ca34379a71a0eba49f39aaab75cfb8bf2d300e
1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
11 #include "OpClass.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
29 using namespace mlir;
30 using namespace mlir::tblgen;
32 //===----------------------------------------------------------------------===//
33 // VariableElement
35 namespace {
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT, VariableElement::Kind VariableKind>
40 class OpVariableElement : public VariableElementBase<VariableKind> {
41 public:
42 using Base = OpVariableElement<VarT, VariableKind>;
44 /// Create an op variable element with the variable value.
45 OpVariableElement(const VarT *var) : var(var) {}
47 /// Get the variable.
48 const VarT *getVar() { return var; }
50 protected:
51 /// The op variable, e.g. a type or attribute constraint.
52 const VarT *var;
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57 : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
58 using Base::Base;
60 /// Return the constant builder call for the type of this attribute, or
61 /// std::nullopt if it doesn't have one.
62 std::optional<StringRef> getTypeBuilder() const {
63 std::optional<Type> attrType = var->attr.getValueType();
64 return attrType ? attrType->getBuilderCall() : std::nullopt;
67 /// Return if this attribute refers to a UnitAttr.
68 bool isUnitAttr() const {
69 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
74 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
75 void setShouldBeQualified(bool qualified = true) {
76 shouldBeQualifiedFlag = qualified;
79 private:
80 bool shouldBeQualifiedFlag = false;
83 /// This class represents a variable that refers to an operand argument.
84 using OperandVariable =
85 OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;
87 /// This class represents a variable that refers to a result.
88 using ResultVariable =
89 OpVariableElement<NamedTypeConstraint, VariableElement::Result>;
91 /// This class represents a variable that refers to a region.
92 using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;
94 /// This class represents a variable that refers to a successor.
95 using SuccessorVariable =
96 OpVariableElement<NamedSuccessor, VariableElement::Successor>;
98 /// This class represents a variable that refers to a property argument.
99 using PropertyVariable =
100 OpVariableElement<NamedProperty, VariableElement::Property>;
101 } // namespace
103 //===----------------------------------------------------------------------===//
104 // DirectiveElement
106 namespace {
107 /// This class represents the `operands` directive. This directive represents
108 /// all of the operands of an operation.
109 using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;
111 /// This class represents the `results` directive. This directive represents
112 /// all of the results of an operation.
113 using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;
115 /// This class represents the `regions` directive. This directive represents
116 /// all of the regions of an operation.
117 using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;
119 /// This class represents the `successors` directive. This directive represents
120 /// all of the successors of an operation.
121 using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
123 /// This class represents the `attr-dict` directive. This directive represents
124 /// the attribute dictionary of the operation.
125 class AttrDictDirective
126 : public DirectiveElementBase<DirectiveElement::AttrDict> {
127 public:
128 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
130 /// Return whether the dictionary should be printed with the 'attributes'
131 /// keyword.
132 bool isWithKeyword() const { return withKeyword; }
134 private:
135 /// If the dictionary should be printed with the 'attributes' keyword.
136 bool withKeyword;
139 /// This class represents the `prop-dict` directive. This directive represents
140 /// the properties of the operation, expressed as a directionary.
141 class PropDictDirective
142 : public DirectiveElementBase<DirectiveElement::PropDict> {
143 public:
144 explicit PropDictDirective() = default;
147 /// This class represents the `functional-type` directive. This directive takes
148 /// two arguments and formats them, respectively, as the inputs and results of a
149 /// FunctionType.
150 class FunctionalTypeDirective
151 : public DirectiveElementBase<DirectiveElement::FunctionalType> {
152 public:
153 FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
154 : inputs(inputs), results(results) {}
156 FormatElement *getInputs() const { return inputs; }
157 FormatElement *getResults() const { return results; }
159 private:
160 /// The input and result arguments.
161 FormatElement *inputs, *results;
164 /// This class represents the `type` directive.
165 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
166 public:
167 TypeDirective(FormatElement *arg) : arg(arg) {}
169 FormatElement *getArg() const { return arg; }
171 /// Indicate if this type is printed "qualified" (that is it is
172 /// prefixed with the `!dialect.mnemonic`).
173 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
174 void setShouldBeQualified(bool qualified = true) {
175 shouldBeQualifiedFlag = qualified;
178 private:
179 /// The argument that is used to format the directive.
180 FormatElement *arg;
182 bool shouldBeQualifiedFlag = false;
185 /// This class represents a group of order-independent optional clauses. Each
186 /// clause starts with a literal element and has a coressponding parsing
187 /// element. A parsing element is a continous sequence of format elements.
188 /// Each clause can appear 0 or 1 time.
189 class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
190 public:
191 OIListElement(std::vector<FormatElement *> &&literalElements,
192 std::vector<std::vector<FormatElement *>> &&parsingElements)
193 : literalElements(std::move(literalElements)),
194 parsingElements(std::move(parsingElements)) {}
196 /// Returns a range to iterate over the LiteralElements.
197 auto getLiteralElements() const {
198 function_ref<LiteralElement *(FormatElement * el)>
199 literalElementCastConverter =
200 [](FormatElement *el) { return cast<LiteralElement>(el); };
201 return llvm::map_range(literalElements, literalElementCastConverter);
204 /// Returns a range to iterate over the parsing elements corresponding to the
205 /// clauses.
206 ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
207 return parsingElements;
210 /// Returns a range to iterate over tuples of parsing and literal elements.
211 auto getClauses() const {
212 return llvm::zip(getLiteralElements(), getParsingElements());
215 /// If the parsing element is a single UnitAttr element, then it returns the
216 /// attribute variable. Otherwise, returns nullptr.
217 AttributeVariable *
218 getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
219 if (pelement.size() == 1) {
220 auto *attrElem = dyn_cast<AttributeVariable>(pelement[0]);
221 if (attrElem && attrElem->isUnitAttr())
222 return attrElem;
224 return nullptr;
227 private:
228 /// A vector of `LiteralElement` objects. Each element stores the keyword
229 /// for one case of oilist element. For example, an oilist element along with
230 /// the `literalElements` vector:
231 /// ```
232 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
233 /// literalElements = { `keyword`, `otherKeyword` }
234 /// ```
235 std::vector<FormatElement *> literalElements;
237 /// A vector of valid declarative assembly format vectors. Each object in
238 /// parsing elements is a vector of elements in assembly format syntax.
239 /// For example, an oilist element along with the parsingElements vector:
240 /// ```
241 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
242 /// parsingElements = {
243 /// { `=`, `(`, $arg0, `)` },
244 /// { `<`, $arg1, `>` }
245 /// }
246 /// ```
247 std::vector<std::vector<FormatElement *>> parsingElements;
249 } // namespace
251 //===----------------------------------------------------------------------===//
252 // OperationFormat
253 //===----------------------------------------------------------------------===//
255 namespace {
257 using ConstArgument =
258 llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
260 struct OperationFormat {
261 /// This class represents a specific resolver for an operand or result type.
262 class TypeResolution {
263 public:
264 TypeResolution() = default;
266 /// Get the index into the buildable types for this type, or std::nullopt.
267 std::optional<int> getBuilderIdx() const { return builderIdx; }
268 void setBuilderIdx(int idx) { builderIdx = idx; }
270 /// Get the variable this type is resolved to, or nullptr.
271 const NamedTypeConstraint *getVariable() const {
272 return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
274 /// Get the attribute this type is resolved to, or nullptr.
275 const NamedAttribute *getAttribute() const {
276 return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
278 /// Get the transformer for the type of the variable, or std::nullopt.
279 std::optional<StringRef> getVarTransformer() const {
280 return variableTransformer;
282 void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
283 resolver = arg;
284 variableTransformer = transformer;
285 assert(getVariable() || getAttribute());
288 private:
289 /// If the type is resolved with a buildable type, this is the index into
290 /// 'buildableTypes' in the parent format.
291 std::optional<int> builderIdx;
292 /// If the type is resolved based upon another operand or result, this is
293 /// the variable or the attribute that this type is resolved to.
294 ConstArgument resolver;
295 /// If the type is resolved based upon another operand or result, this is
296 /// a transformer to apply to the variable when resolving.
297 std::optional<StringRef> variableTransformer;
300 /// The context in which an element is generated.
301 enum class GenContext {
302 /// The element is generated at the top-level or with the same behaviour.
303 Normal,
304 /// The element is generated inside an optional group.
305 Optional
308 OperationFormat(const Operator &op)
309 : useProperties(op.getDialect().usePropertiesForAttributes() &&
310 !op.getAttributes().empty()),
311 opCppClassName(op.getCppClassName()) {
312 operandTypes.resize(op.getNumOperands(), TypeResolution());
313 resultTypes.resize(op.getNumResults(), TypeResolution());
315 hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
316 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
319 hasSingleBlockTrait = op.getTrait("::mlir::OpTrait::SingleBlock");
322 /// Generate the operation parser from this format.
323 void genParser(Operator &op, OpClass &opClass);
324 /// Generate the parser code for a specific format element.
325 void genElementParser(FormatElement *element, MethodBody &body,
326 FmtContext &attrTypeCtx,
327 GenContext genCtx = GenContext::Normal);
328 /// Generate the C++ to resolve the types of operands and results during
329 /// parsing.
330 void genParserTypeResolution(Operator &op, MethodBody &body);
331 /// Generate the C++ to resolve the types of the operands during parsing.
332 void genParserOperandTypeResolution(
333 Operator &op, MethodBody &body,
334 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
335 /// Generate the C++ to resolve regions during parsing.
336 void genParserRegionResolution(Operator &op, MethodBody &body);
337 /// Generate the C++ to resolve successors during parsing.
338 void genParserSuccessorResolution(Operator &op, MethodBody &body);
339 /// Generate the C++ to handling variadic segment size traits.
340 void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
342 /// Generate the operation printer from this format.
343 void genPrinter(Operator &op, OpClass &opClass);
345 /// Generate the printer code for a specific format element.
346 void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
347 bool &shouldEmitSpace, bool &lastWasPunctuation);
349 /// The various elements in this format.
350 std::vector<FormatElement *> elements;
352 /// A flag indicating if all operand/result types were seen. If the format
353 /// contains these, it can not contain individual type resolvers.
354 bool allOperands = false, allOperandTypes = false, allResultTypes = false;
356 /// A flag indicating if this operation infers its result types
357 bool infersResultTypes = false;
359 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
360 /// trait.
361 bool hasImplicitTermTrait;
363 /// A flag indicating if this operation has the SingleBlock trait.
364 bool hasSingleBlockTrait;
366 /// Indicate whether attribute are stored in properties.
367 bool useProperties;
369 /// The Operation class name
370 StringRef opCppClassName;
372 /// A map of buildable types to indices.
373 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
375 /// The index of the buildable type, if valid, for every operand and result.
376 std::vector<TypeResolution> operandTypes, resultTypes;
378 /// The set of attributes explicitly used within the format.
379 SmallVector<const NamedAttribute *, 8> usedAttributes;
380 llvm::StringSet<> inferredAttributes;
382 } // namespace
384 //===----------------------------------------------------------------------===//
385 // Parser Gen
387 /// Returns true if we can format the given attribute as an EnumAttr in the
388 /// parser format.
389 static bool canFormatEnumAttr(const NamedAttribute *attr) {
390 Attribute baseAttr = attr->attr.getBaseAttr();
391 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
392 if (!enumAttr)
393 return false;
395 // The attribute must have a valid underlying type and a constant builder.
396 return !enumAttr->getUnderlyingType().empty() &&
397 !enumAttr->getConstBuilderTemplate().empty();
400 /// Returns if we should format the given attribute as an SymbolNameAttr.
401 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
402 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
405 /// The code snippet used to generate a parser call for an attribute, using
/// `parseCustomAttributeWithFallback`.
407 /// {0}: The name of the attribute.
408 /// {1}: The type for the attribute.
409 const char *const attrParserCode = R"(
410 if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
411 return ::mlir::failure();
415 /// The code snippet used to generate a generic parser call for an attribute,
/// using `parseAttribute`.
417 /// {0}: The name of the attribute.
418 /// {1}: The type for the attribute.
419 const char *const genericAttrParserCode = R"(
420 if (parser.parseAttribute({0}Attr, {1}))
421 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional attribute,
/// using `parseOptionalAttribute`.
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
424 const char *const optionalAttrParserCode = R"(
425 ::mlir::OptionalParseResult parseResult{0}Attr =
426 parser.parseOptionalAttribute({0}Attr, {1});
427 if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
428 return ::mlir::failure();
429 if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
432 /// The code snippet used to generate a parser call for a symbol name attribute.
434 /// {0}: The name of the attribute.
435 const char *const symbolNameAttrParserCode = R"(
436 if (parser.parseSymbolName({0}Attr))
437 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional symbol name
/// attribute.
/// {0}: The name of the attribute.
439 const char *const optionalSymbolNameAttrParserCode = R"(
440 // Parsing an optional symbol name doesn't fail, so no need to check the
441 // result.
442 (void)parser.parseOptionalSymbolName({0}Attr);
445 /// The code snippet used to generate a parser call for an enum attribute. The
/// enum is accepted either as a keyword or as a string attribute.
447 /// {0}: The name of the attribute.
448 /// {1}: The c++ namespace for the enum symbolize functions.
449 /// {2}: The function to symbolize a string of the enum.
450 /// {3}: The constant builder call to create an attribute of the enum type.
451 /// {4}: The set of allowed enum keywords.
452 /// {5}: The error message on failure when the enum isn't present.
453 /// {6}: The attribute assignment expression
454 const char *const enumAttrParserCode = R"(
456 ::llvm::StringRef attrStr;
457 ::mlir::NamedAttrList attrStorage;
458 auto loc = parser.getCurrentLocation();
459 if (parser.parseOptionalKeyword(&attrStr, {4})) {
460 ::mlir::StringAttr attrVal;
461 ::mlir::OptionalParseResult parseResult =
462 parser.parseOptionalAttribute(attrVal,
463 parser.getBuilder().getNoneType(),
464 "{0}", attrStorage);
465 if (parseResult.has_value()) {{
466 if (failed(*parseResult))
467 return ::mlir::failure();
468 attrStr = attrVal.getValue();
469 } else {
473 if (!attrStr.empty()) {
474 auto attrOptional = {1}::{2}(attrStr);
475 if (!attrOptional)
476 return parser.emitError(loc, "invalid ")
477 << "{0} attribute specification: \"" << attrStr << '"';;
479 {0}Attr = {3};
485 /// The code snippet used to generate a parser call for a variadic operand,
/// using `parseOperandList`.
487 /// {0}: The name of the operand.
488 const char *const variadicOperandParserCode = R"(
489 {0}OperandsLoc = parser.getCurrentLocation();
490 if (parser.parseOperandList({0}Operands))
491 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional operand,
/// using `parseOptionalOperand`. A parsed operand is appended to the operand
/// list.
/// {0}: The name of the operand.
493 const char *const optionalOperandParserCode = R"(
495 {0}OperandsLoc = parser.getCurrentLocation();
496 ::mlir::OpAsmParser::UnresolvedOperand operand;
497 ::mlir::OptionalParseResult parseResult =
498 parser.parseOptionalOperand(operand);
499 if (parseResult.has_value()) {
500 if (failed(*parseResult))
501 return ::mlir::failure();
502 {0}Operands.push_back(operand);
/// The code snippet used to generate a parser call for a single operand,
/// parsed into the first slot of the raw operand storage.
/// {0}: The name of the operand.
506 const char *const operandParserCode = R"(
507 {0}OperandsLoc = parser.getCurrentLocation();
508 if (parser.parseOperand({0}RawOperands[0]))
509 return ::mlir::failure();
511 /// The code snippet used to generate a parser call for a VariadicOfVariadic
512 /// operand. Each parenthesized group contributes one entry to the group sizes.
514 /// {0}: The name of the operand.
515 /// {1}: The name of segment size attribute.
/// NOTE(review): {1} is not referenced by the visible part of the snippet.
516 const char *const variadicOfVariadicOperandParserCode = R"(
518 {0}OperandsLoc = parser.getCurrentLocation();
519 int32_t curSize = 0;
520 do {
521 if (parser.parseOptionalLParen())
522 break;
523 if (parser.parseOperandList({0}Operands) || parser.parseRParen())
524 return ::mlir::failure();
525 {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
526 curSize = {0}Operands.size();
527 } while (succeeded(parser.parseOptionalComma()));
531 /// The code snippet used to generate a parser call for the type list of a
/// VariadicOfVariadic argument: a comma-separated sequence of parenthesized
/// type lists.
533 /// {0}: The name for the type list.
534 const char *const variadicOfVariadicTypeParserCode = R"(
535 do {
536 if (parser.parseOptionalLParen())
537 break;
538 if (parser.parseOptionalRParen() &&
539 (parser.parseTypeList({0}Types) || parser.parseRParen()))
540 return ::mlir::failure();
541 } while (succeeded(parser.parseOptionalComma()));
/// The code snippet used to generate a parser call for a variadic type list.
/// {0}: The name for the type list.
543 const char *const variadicTypeParserCode = R"(
544 if (parser.parseTypeList({0}Types))
545 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional type. A
/// parsed type is appended to the type list.
/// {0}: The name for the type list.
547 const char *const optionalTypeParserCode = R"(
549 ::mlir::Type optionalType;
550 ::mlir::OptionalParseResult parseResult =
551 parser.parseOptionalType(optionalType);
552 if (parseResult.has_value()) {
553 if (failed(*parseResult))
554 return ::mlir::failure();
555 {0}Types.push_back(optionalType);
/// The code snippet used to generate a parser call for a single type, using
/// `parseCustomTypeWithFallback`.
/// {0}: The C++ type of the local variable the type is parsed into.
/// {1}: The name for the type list.
559 const char *const typeParserCode = R"(
561 {0} type;
562 if (parser.parseCustomTypeWithFallback(type))
563 return ::mlir::failure();
564 {1}RawTypes[0] = type;
/// The code snippet used to generate a parser call for a single type, using
/// `parseType`.
/// {1}: The name for the type list. ({0} is not used by this snippet.)
567 const char *const qualifiedTypeParserCode = R"(
568 if (parser.parseType({1}RawTypes[0]))
569 return ::mlir::failure();
572 /// The code snippet used to generate a parser call for a functional type.
574 /// {0}: The name for the input type list.
575 /// {1}: The name for the result type list.
576 const char *const functionalTypeParserCode = R"(
577 ::mlir::FunctionType {0}__{1}_functionType;
578 if (parser.parseType({0}__{1}_functionType))
579 return ::mlir::failure();
580 {0}Types = {0}__{1}_functionType.getInputs();
581 {1}Types = {0}__{1}_functionType.getResults();
584 /// The code snippet used to generate a parser call to infer return types. The
/// inferred types are added to the operation state.
586 /// {0}: The operation class name
587 const char *const inferReturnTypesParserCode = R"(
588 ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
589 if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
590 result.location, result.operands,
591 result.attributes.getDictionary(parser.getContext()),
592 result.getRawProperties(),
593 result.regions, inferredReturnTypes)))
594 return ::mlir::failure();
595 result.addTypes(inferredReturnTypes);
598 /// The code snippet used to generate a parser call for a region list: an
/// optional first region followed by any number of comma-separated regions.
600 /// {0}: The name for the region list.
601 const char *regionListParserCode = R"(
603 std::unique_ptr<::mlir::Region> region;
604 auto firstRegionResult = parser.parseOptionalRegion(region);
605 if (firstRegionResult.has_value()) {
606 if (failed(*firstRegionResult))
607 return ::mlir::failure();
608 {0}Regions.emplace_back(std::move(region));
610 // Parse any trailing regions.
611 while (succeeded(parser.parseOptionalComma())) {
612 region = std::make_unique<::mlir::Region>();
613 if (parser.parseRegion(*region))
614 return ::mlir::failure();
615 {0}Regions.emplace_back(std::move(region));
621 /// The code snippet used to ensure a list of regions have terminators.
623 /// {0}: The name of the region list.
624 const char *regionListEnsureTerminatorParserCode = R"(
625 for (auto &region : {0}Regions)
626 ensureTerminator(*region, parser.getBuilder(), result.location);
629 /// The code snippet used to ensure a list of regions have a block.
631 /// {0}: The name of the region list.
632 const char *regionListEnsureSingleBlockParserCode = R"(
633 for (auto &region : {0}Regions)
634 if (region->empty()) region->emplaceBlock();
637 /// The code snippet used to generate a parser call for an optional region.
639 /// {0}: The name of the region.
640 const char *optionalRegionParserCode = R"(
642 auto parseResult = parser.parseOptionalRegion(*{0}Region);
643 if (parseResult.has_value() && failed(*parseResult))
644 return ::mlir::failure();
648 /// The code snippet used to generate a parser call for a region.
650 /// {0}: The name of the region.
651 const char *regionParserCode = R"(
652 if (parser.parseRegion(*{0}Region))
653 return ::mlir::failure();
656 /// The code snippet used to ensure a region has a terminator.
658 /// {0}: The name of the region.
659 const char *regionEnsureTerminatorParserCode = R"(
660 ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
663 /// The code snippet used to ensure a region has a block.
665 /// {0}: The name of the region.
666 const char *regionEnsureSingleBlockParserCode = R"(
667 if ({0}Region->empty()) {0}Region->emplaceBlock();
670 /// The code snippet used to generate a parser call for a successor list: an
/// optional first successor followed by any number of comma-separated
/// successors.
672 /// {0}: The name for the successor list.
673 const char *successorListParserCode = R"(
675 ::mlir::Block *succ;
676 auto firstSucc = parser.parseOptionalSuccessor(succ);
677 if (firstSucc.has_value()) {
678 if (failed(*firstSucc))
679 return ::mlir::failure();
680 {0}Successors.emplace_back(succ);
682 // Parse any trailing successors.
683 while (succeeded(parser.parseOptionalComma())) {
684 if (parser.parseSuccessor(succ))
685 return ::mlir::failure();
686 {0}Successors.emplace_back(succ);
692 /// The code snippet used to generate a parser call for a successor.
694 /// {0}: The name of the successor.
695 const char *successorParserCode = R"(
696 if (parser.parseSuccessor({0}Successor))
697 return ::mlir::failure();
700 /// The code snippet used to generate a parser for OIList. It reports an error
/// if the clause was already seen and otherwise marks the clause as seen.
702 /// {0}: literal keyword corresponding to a case for oilist
703 const char *oilistParserCode = R"(
704 if ({0}Clause) {
705 return parser.emitError(parser.getNameLoc())
706 << "`{0}` clause can appear at most once in the expansion of the "
707 "oilist directive";
709 {0}Clause = true;
712 namespace {
/// How many elements a given parse argument may stand for.
enum class ArgumentLengthKind {
  /// A variadic of a variadic: may contain 0->N ranges of elements.
  VariadicOfVariadic,
  /// A variadic argument: may contain 0->N elements.
  Variadic,
  /// An optional argument: may contain 0 or 1 elements.
  Optional,
  /// A single element, i.e. the argument always represents 1 element.
  Single
};
725 } // namespace
727 /// Get the length kind for the given constraint.
728 static ArgumentLengthKind
729 getArgumentLengthKind(const NamedTypeConstraint *var) {
730 if (var->isOptional())
731 return ArgumentLengthKind::Optional;
732 if (var->isVariadicOfVariadic())
733 return ArgumentLengthKind::VariadicOfVariadic;
734 if (var->isVariadic())
735 return ArgumentLengthKind::Variadic;
736 return ArgumentLengthKind::Single;
739 /// Get the name used for the type list for the given type directive operand.
740 /// 'lengthKind' to the corresponding kind for the given argument.
741 static StringRef getTypeListName(FormatElement *arg,
742 ArgumentLengthKind &lengthKind) {
743 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
744 lengthKind = getArgumentLengthKind(operand->getVar());
745 return operand->getVar()->name;
747 if (auto *result = dyn_cast<ResultVariable>(arg)) {
748 lengthKind = getArgumentLengthKind(result->getVar());
749 return result->getVar()->name;
751 lengthKind = ArgumentLengthKind::Variadic;
752 if (isa<OperandsDirective>(arg))
753 return "allOperand";
754 if (isa<ResultsDirective>(arg))
755 return "allResult";
756 llvm_unreachable("unknown 'type' directive argument");
759 /// Generate the parser for a literal value.
760 static void genLiteralParser(StringRef value, MethodBody &body) {
761 // Handle the case of a keyword/identifier.
762 if (value.front() == '_' || isalpha(value.front())) {
763 body << "Keyword(\"" << value << "\")";
764 return;
766 body << (StringRef)StringSwitch<StringRef>(value)
767 .Case("->", "Arrow()")
768 .Case(":", "Colon()")
769 .Case(",", "Comma()")
770 .Case("=", "Equal()")
771 .Case("<", "Less()")
772 .Case(">", "Greater()")
773 .Case("{", "LBrace()")
774 .Case("}", "RBrace()")
775 .Case("(", "LParen()")
776 .Case(")", "RParen()")
777 .Case("[", "LSquare()")
778 .Case("]", "RSquare()")
779 .Case("?", "Question()")
780 .Case("+", "Plus()")
781 .Case("*", "Star()")
782 .Case("...", "Ellipsis()");
785 /// Generate the storage code required for parsing the given element.
786 static void genElementParserStorage(FormatElement *element, const Operator &op,
787 MethodBody &body) {
788 if (auto *optional = dyn_cast<OptionalElement>(element)) {
789 ArrayRef<FormatElement *> elements = optional->getThenElements();
791 // If the anchor is a unit attribute, it won't be parsed directly so elide
792 // it.
793 auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
794 FormatElement *elidedAnchorElement = nullptr;
795 if (anchor && anchor != elements.front() && anchor->isUnitAttr())
796 elidedAnchorElement = anchor;
797 for (FormatElement *childElement : elements)
798 if (childElement != elidedAnchorElement)
799 genElementParserStorage(childElement, op, body);
800 for (FormatElement *childElement : optional->getElseElements())
801 genElementParserStorage(childElement, op, body);
803 } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
804 for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
805 if (!oilist->getUnitAttrParsingElement(pelement))
806 for (FormatElement *element : pelement)
807 genElementParserStorage(element, op, body);
810 } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
811 for (FormatElement *paramElement : custom->getArguments())
812 genElementParserStorage(paramElement, op, body);
814 } else if (isa<OperandsDirective>(element)) {
815 body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
816 "allOperands;\n";
818 } else if (isa<RegionsDirective>(element)) {
819 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
820 "fullRegions;\n";
822 } else if (isa<SuccessorsDirective>(element)) {
823 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
825 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
826 const NamedAttribute *var = attr->getVar();
827 body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
828 var->name);
830 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
831 StringRef name = operand->getVar()->name;
832 if (operand->getVar()->isVariableLength()) {
833 body
834 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
835 << name << "Operands;\n";
836 if (operand->getVar()->isVariadicOfVariadic()) {
837 body << " llvm::SmallVector<int32_t> " << name
838 << "OperandGroupSizes;\n";
840 } else {
841 body << " ::mlir::OpAsmParser::UnresolvedOperand " << name
842 << "RawOperands[1];\n"
843 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
844 << name << "Operands(" << name << "RawOperands);";
846 body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
847 " (void){0}OperandsLoc;\n",
848 name);
850 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
851 StringRef name = region->getVar()->name;
852 if (region->getVar()->isVariadic()) {
853 body << llvm::formatv(
854 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
855 "{0}Regions;\n",
856 name);
857 } else {
858 body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
859 "std::make_unique<::mlir::Region>();\n",
860 name);
863 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
864 StringRef name = successor->getVar()->name;
865 if (successor->getVar()->isVariadic()) {
866 body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
867 "{0}Successors;\n",
868 name);
869 } else {
870 body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
873 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
874 ArgumentLengthKind lengthKind;
875 StringRef name = getTypeListName(dir->getArg(), lengthKind);
876 if (lengthKind != ArgumentLengthKind::Single)
877 body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
878 else
879 body << llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name)
880 << llvm::formatv(
881 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
882 name);
883 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
884 ArgumentLengthKind ignored;
885 body << " ::llvm::ArrayRef<::mlir::Type> "
886 << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
887 body << " ::llvm::ArrayRef<::mlir::Type> "
888 << getTypeListName(dir->getResults(), ignored) << "Types;\n";
892 /// Generate the parser for a parameter to a custom directive.
893 static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
894 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
895 body << attr->getVar()->name << "Attr";
896 } else if (isa<AttrDictDirective>(param)) {
897 body << "result.attributes";
898 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
899 StringRef name = operand->getVar()->name;
900 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
901 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
902 body << llvm::formatv("{0}OperandGroups", name);
903 else if (lengthKind == ArgumentLengthKind::Variadic)
904 body << llvm::formatv("{0}Operands", name);
905 else if (lengthKind == ArgumentLengthKind::Optional)
906 body << llvm::formatv("{0}Operand", name);
907 else
908 body << formatv("{0}RawOperands[0]", name);
910 } else if (auto *region = dyn_cast<RegionVariable>(param)) {
911 StringRef name = region->getVar()->name;
912 if (region->getVar()->isVariadic())
913 body << llvm::formatv("{0}Regions", name);
914 else
915 body << llvm::formatv("*{0}Region", name);
917 } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
918 StringRef name = successor->getVar()->name;
919 if (successor->getVar()->isVariadic())
920 body << llvm::formatv("{0}Successors", name);
921 else
922 body << llvm::formatv("{0}Successor", name);
924 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
925 genCustomParameterParser(dir->getArg(), body);
927 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
928 ArgumentLengthKind lengthKind;
929 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
930 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
931 body << llvm::formatv("{0}TypeGroups", listName);
932 else if (lengthKind == ArgumentLengthKind::Variadic)
933 body << llvm::formatv("{0}Types", listName);
934 else if (lengthKind == ArgumentLengthKind::Optional)
935 body << llvm::formatv("{0}Type", listName);
936 else
937 body << formatv("{0}RawTypes[0]", listName);
939 } else if (auto *string = dyn_cast<StringElement>(param)) {
940 FmtContext ctx;
941 ctx.withBuilder("parser.getBuilder()");
942 ctx.addSubst("_ctxt", "parser.getContext()");
943 body << tgfmt(string->getValue(), &ctx);
945 } else if (auto *property = dyn_cast<PropertyVariable>(param)) {
946 body << llvm::formatv("result.getOrAddProperties<Properties>().{0}",
947 property->getVar()->name);
948 } else {
949 llvm_unreachable("unknown custom directive parameter");
953 /// Generate the parser for a custom directive.
954 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
955 bool useProperties,
956 StringRef opCppClassName,
957 bool isOptional = false) {
958 body << " {\n";
960 // Preprocess the directive variables.
961 // * Add a local variable for optional operands and types. This provides a
962 // better API to the user defined parser methods.
963 // * Set the location of operand variables.
964 for (FormatElement *param : dir->getArguments()) {
965 if (auto *operand = dyn_cast<OperandVariable>(param)) {
966 auto *var = operand->getVar();
967 body << " " << var->name
968 << "OperandsLoc = parser.getCurrentLocation();\n";
969 if (var->isOptional()) {
970 body << llvm::formatv(
971 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
972 "{0}Operand;\n",
973 var->name);
974 } else if (var->isVariadicOfVariadic()) {
975 body << llvm::formatv(" "
976 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
977 "OpAsmParser::UnresolvedOperand>> "
978 "{0}OperandGroups;\n",
979 var->name);
981 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
982 ArgumentLengthKind lengthKind;
983 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
984 if (lengthKind == ArgumentLengthKind::Optional) {
985 body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
986 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
987 body << llvm::formatv(
988 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
989 "{0}TypeGroups;\n",
990 listName);
992 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
993 FormatElement *input = dir->getArg();
994 if (auto *operand = dyn_cast<OperandVariable>(input)) {
995 if (!operand->getVar()->isOptional())
996 continue;
997 body << llvm::formatv(
998 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
999 "{1}Operands[0];\n",
1000 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
1001 operand->getVar()->name);
1003 } else if (auto *type = dyn_cast<TypeDirective>(input)) {
1004 ArgumentLengthKind lengthKind;
1005 StringRef listName = getTypeListName(type->getArg(), lengthKind);
1006 if (lengthKind == ArgumentLengthKind::Optional) {
1007 body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
1008 "::mlir::Type() : {0}Types[0];\n",
1009 listName);
1015 body << " auto odsResult = parse" << dir->getName() << "(parser";
1016 for (FormatElement *param : dir->getArguments()) {
1017 body << ", ";
1018 genCustomParameterParser(param, body);
1020 body << ");\n";
1022 if (isOptional) {
1023 body << " if (!odsResult) return {};\n"
1024 << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
1025 } else {
1026 body << " if (odsResult) return ::mlir::failure();\n";
1029 // After parsing, add handling for any of the optional constructs.
1030 for (FormatElement *param : dir->getArguments()) {
1031 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
1032 const NamedAttribute *var = attr->getVar();
1033 if (var->attr.isOptional() || var->attr.hasDefaultValue())
1034 body << llvm::formatv(" if ({0}Attr)\n ", var->name);
1035 if (useProperties) {
1036 body << formatv(
1037 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1038 var->name, opCppClassName);
1039 } else {
1040 body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1041 var->name);
1044 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
1045 const NamedTypeConstraint *var = operand->getVar();
1046 if (var->isOptional()) {
1047 body << llvm::formatv(" if ({0}Operand.has_value())\n"
1048 " {0}Operands.push_back(*{0}Operand);\n",
1049 var->name);
1050 } else if (var->isVariadicOfVariadic()) {
1051 body << llvm::formatv(
1052 " for (const auto &subRange : {0}OperandGroups) {{\n"
1053 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1054 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1055 " }\n",
1056 var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
1058 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1059 ArgumentLengthKind lengthKind;
1060 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1061 if (lengthKind == ArgumentLengthKind::Optional) {
1062 body << llvm::formatv(" if ({0}Type)\n"
1063 " {0}Types.push_back({0}Type);\n",
1064 listName);
1065 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1066 body << llvm::formatv(
1067 " for (const auto &subRange : {0}TypeGroups)\n"
1068 " {0}Types.append(subRange.begin(), subRange.end());\n",
1069 listName);
1074 body << " }\n";
1077 /// Generate the parser for a enum attribute.
1078 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1079 FmtContext &attrTypeCtx, bool parseAsOptional,
1080 bool useProperties, StringRef opCppClassName) {
1081 Attribute baseAttr = var->attr.getBaseAttr();
1082 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1083 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1085 // Generate the code for building an attribute for this enum.
1086 std::string attrBuilderStr;
1088 llvm::raw_string_ostream os(attrBuilderStr);
1089 os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1090 "*attrOptional");
1093 // Build a string containing the cases that can be formatted as a keyword.
1094 std::string validCaseKeywordsStr = "{";
1095 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1096 for (const EnumAttrCase &attrCase : cases)
1097 if (canFormatStringAsKeyword(attrCase.getStr()))
1098 validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1099 validCaseKeywordsOS.str().back() = '}';
1101 // If the attribute is not optional, build an error message for the missing
1102 // attribute.
1103 std::string errorMessage;
1104 if (!parseAsOptional) {
1105 llvm::raw_string_ostream errorMessageOS(errorMessage);
1106 errorMessageOS
1107 << "return parser.emitError(loc, \"expected string or "
1108 "keyword containing one of the following enum values for attribute '"
1109 << var->name << "' [";
1110 llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1111 errorMessageOS << attrCase.getStr();
1113 errorMessageOS << "]\");";
1115 std::string attrAssignment;
1116 if (useProperties) {
1117 attrAssignment =
1118 formatv(" "
1119 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1120 var->name, opCppClassName);
1121 } else {
1122 attrAssignment =
1123 formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
1126 body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1127 enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1128 validCaseKeywordsStr, errorMessage, attrAssignment);
1131 // Generate the parser for an attribute.
1132 static void genAttrParser(AttributeVariable *attr, MethodBody &body,
1133 FmtContext &attrTypeCtx, bool parseAsOptional,
1134 bool useProperties, StringRef opCppClassName) {
1135 const NamedAttribute *var = attr->getVar();
1137 // Check to see if we can parse this as an enum attribute.
1138 if (canFormatEnumAttr(var))
1139 return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
1140 useProperties, opCppClassName);
1142 // Check to see if we should parse this as a symbol name attribute.
1143 if (shouldFormatSymbolNameAttr(var)) {
1144 body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
1145 : symbolNameAttrParserCode,
1146 var->name);
1147 } else {
1149 // If this attribute has a buildable type, use that when parsing the
1150 // attribute.
1151 std::string attrTypeStr;
1152 if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1153 llvm::raw_string_ostream os(attrTypeStr);
1154 os << tgfmt(*typeBuilder, &attrTypeCtx);
1155 } else {
1156 attrTypeStr = "::mlir::Type{}";
1158 if (parseAsOptional) {
1159 body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
1160 } else {
1161 if (attr->shouldBeQualified() ||
1162 var->attr.getStorageType() == "::mlir::Attribute")
1163 body << formatv(genericAttrParserCode, var->name, attrTypeStr);
1164 else
1165 body << formatv(attrParserCode, var->name, attrTypeStr);
1168 if (useProperties) {
1169 body << formatv(
1170 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1171 "{0}Attr;\n",
1172 var->name, opCppClassName);
1173 } else {
1174 body << formatv(
1175 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1176 var->name);
1180 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1181 SmallVector<MethodParameter> paramList;
1182 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1183 paramList.emplace_back("::mlir::OperationState &", "result");
1185 auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
1186 std::move(paramList));
1187 auto &body = method->body();
1189 // Generate variables to store the operands and type within the format. This
1190 // allows for referencing these variables in the presence of optional
1191 // groupings.
1192 for (FormatElement *element : elements)
1193 genElementParserStorage(element, op, body);
1195 // A format context used when parsing attributes with buildable types.
1196 FmtContext attrTypeCtx;
1197 attrTypeCtx.withBuilder("parser.getBuilder()");
1199 // Generate parsers for each of the elements.
1200 for (FormatElement *element : elements)
1201 genElementParser(element, body, attrTypeCtx);
1203 // Generate the code to resolve the operand/result types and successors now
1204 // that they have been parsed.
1205 genParserRegionResolution(op, body);
1206 genParserSuccessorResolution(op, body);
1207 genParserVariadicSegmentResolution(op, body);
1208 genParserTypeResolution(op, body);
1210 body << " return ::mlir::success();\n";
1213 void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
1214 FmtContext &attrTypeCtx,
1215 GenContext genCtx) {
1216 /// Optional Group.
1217 if (auto *optional = dyn_cast<OptionalElement>(element)) {
1218 auto genElementParsers = [&](FormatElement *firstElement,
1219 ArrayRef<FormatElement *> elements,
1220 bool thenGroup) {
1221 // If the anchor is a unit attribute, we don't need to print it. When
1222 // parsing, we will add this attribute if this group is present.
1223 FormatElement *elidedAnchorElement = nullptr;
1224 auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
1225 if (anchorAttr && anchorAttr != firstElement &&
1226 anchorAttr->isUnitAttr()) {
1227 elidedAnchorElement = anchorAttr;
1229 if (!thenGroup == optional->isInverted()) {
1230 // Add the anchor unit attribute to the operation state.
1231 if (useProperties) {
1232 body << formatv(
1233 " result.getOrAddProperties<{1}::Properties>().{0} = "
1234 "parser.getBuilder().getUnitAttr();",
1235 anchorAttr->getVar()->name, opCppClassName);
1236 } else {
1237 body << " result.addAttribute(\"" << anchorAttr->getVar()->name
1238 << "\", parser.getBuilder().getUnitAttr());\n";
1243 // Generate the rest of the elements inside an optional group. Elements in
1244 // an optional group after the guard are parsed as required.
1245 for (FormatElement *childElement : elements)
1246 if (childElement != elidedAnchorElement)
1247 genElementParser(childElement, body, attrTypeCtx,
1248 GenContext::Optional);
1251 ArrayRef<FormatElement *> thenElements =
1252 optional->getThenElements(/*parseable=*/true);
1254 // Generate a special optional parser for the first element to gate the
1255 // parsing of the rest of the elements.
1256 FormatElement *firstElement = thenElements.front();
1257 if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1258 genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
1259 useProperties, opCppClassName);
1260 body << " if (" << attrVar->getVar()->name << "Attr) {\n";
1261 } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1262 body << " if (::mlir::succeeded(parser.parseOptional";
1263 genLiteralParser(literal->getSpelling(), body);
1264 body << ")) {\n";
1265 } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1266 genElementParser(opVar, body, attrTypeCtx);
1267 body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1268 } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1269 const NamedRegion *region = regionVar->getVar();
1270 if (region->isVariadic()) {
1271 genElementParser(regionVar, body, attrTypeCtx);
1272 body << " if (!" << region->name << "Regions.empty()) {\n";
1273 } else {
1274 body << llvm::formatv(optionalRegionParserCode, region->name);
1275 body << " if (!" << region->name << "Region->empty()) {\n ";
1276 if (hasImplicitTermTrait)
1277 body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
1278 else if (hasSingleBlockTrait)
1279 body << llvm::formatv(regionEnsureSingleBlockParserCode,
1280 region->name);
1282 } else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
1283 body << " if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
1284 genCustomDirectiveParser(custom, body, useProperties, opCppClassName,
1285 /*isOptional=*/true);
1286 body << " return ::mlir::success();\n"
1287 << " }(); result.has_value() && ::mlir::failed(*result)) {\n"
1288 << " return ::mlir::failure();\n"
1289 << " } else if (result.has_value()) {\n";
1292 genElementParsers(firstElement, thenElements.drop_front(),
1293 /*thenGroup=*/true);
1294 body << " }";
1296 // Generate the else elements.
1297 auto elseElements = optional->getElseElements();
1298 if (!elseElements.empty()) {
1299 body << " else {\n";
1300 ArrayRef<FormatElement *> elseElements =
1301 optional->getElseElements(/*parseable=*/true);
1302 genElementParsers(elseElements.front(), elseElements,
1303 /*thenGroup=*/false);
1304 body << " }";
1306 body << "\n";
1308 /// OIList Directive
1309 } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) {
1310 for (LiteralElement *le : oilist->getLiteralElements())
1311 body << " bool " << le->getSpelling() << "Clause = false;\n";
1313 // Generate the parsing loop
1314 body << " while(true) {\n";
1315 for (auto clause : oilist->getClauses()) {
1316 LiteralElement *lelement = std::get<0>(clause);
1317 ArrayRef<FormatElement *> pelement = std::get<1>(clause);
1318 body << "if (succeeded(parser.parseOptional";
1319 genLiteralParser(lelement->getSpelling(), body);
1320 body << ")) {\n";
1321 StringRef lelementName = lelement->getSpelling();
1322 body << formatv(oilistParserCode, lelementName);
1323 if (AttributeVariable *unitAttrElem =
1324 oilist->getUnitAttrParsingElement(pelement)) {
1325 if (useProperties) {
1326 body << formatv(
1327 " result.getOrAddProperties<{1}::Properties>().{0} = "
1328 "parser.getBuilder().getUnitAttr();",
1329 unitAttrElem->getVar()->name, opCppClassName);
1330 } else {
1331 body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
1332 << "\", UnitAttr::get(parser.getContext()));\n";
1334 } else {
1335 for (FormatElement *el : pelement)
1336 genElementParser(el, body, attrTypeCtx);
1338 body << " } else ";
1340 body << " {\n";
1341 body << " break;\n";
1342 body << " }\n";
1343 body << "}\n";
1345 /// Literals.
1346 } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1347 body << " if (parser.parse";
1348 genLiteralParser(literal->getSpelling(), body);
1349 body << ")\n return ::mlir::failure();\n";
1351 /// Whitespaces.
1352 } else if (isa<WhitespaceElement>(element)) {
1353 // Nothing to parse.
1355 /// Arguments.
1356 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1357 bool parseAsOptional =
1358 (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
1359 genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
1360 opCppClassName);
1362 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1363 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1364 StringRef name = operand->getVar()->name;
1365 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1366 body << llvm::formatv(
1367 variadicOfVariadicOperandParserCode, name,
1368 operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
1369 else if (lengthKind == ArgumentLengthKind::Variadic)
1370 body << llvm::formatv(variadicOperandParserCode, name);
1371 else if (lengthKind == ArgumentLengthKind::Optional)
1372 body << llvm::formatv(optionalOperandParserCode, name);
1373 else
1374 body << formatv(operandParserCode, name);
1376 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1377 bool isVariadic = region->getVar()->isVariadic();
1378 body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
1379 region->getVar()->name);
1380 if (hasImplicitTermTrait)
1381 body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
1382 : regionEnsureTerminatorParserCode,
1383 region->getVar()->name);
1384 else if (hasSingleBlockTrait)
1385 body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
1386 : regionEnsureSingleBlockParserCode,
1387 region->getVar()->name);
1389 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1390 bool isVariadic = successor->getVar()->isVariadic();
1391 body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1392 successor->getVar()->name);
1394 /// Directives.
1395 } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1396 body.indent() << "{\n";
1397 body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1398 << "if (parser.parseOptionalAttrDict"
1399 << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1400 << "(result.attributes))\n"
1401 << " return ::mlir::failure();\n";
1402 if (useProperties) {
1403 body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1404 "[&]() {\n"
1405 << " return parser.emitError(loc) << \"'\" << "
1406 "result.name.getStringRef() << \"' op \";\n"
1407 << " })))\n"
1408 << " return ::mlir::failure();\n";
1410 body.unindent() << "}\n";
1411 body.unindent();
1412 } else if (dyn_cast<PropDictDirective>(element)) {
1413 body << " if (parseProperties(parser, result))\n"
1414 << " return ::mlir::failure();\n";
1415 } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1416 genCustomDirectiveParser(customDir, body, useProperties, opCppClassName);
1417 } else if (isa<OperandsDirective>(element)) {
1418 body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1419 << " if (parser.parseOperandList(allOperands))\n"
1420 << " return ::mlir::failure();\n";
1422 } else if (isa<RegionsDirective>(element)) {
1423 body << llvm::formatv(regionListParserCode, "full");
1424 if (hasImplicitTermTrait)
1425 body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
1426 else if (hasSingleBlockTrait)
1427 body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
1429 } else if (isa<SuccessorsDirective>(element)) {
1430 body << llvm::formatv(successorListParserCode, "full");
1432 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1433 ArgumentLengthKind lengthKind;
1434 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1435 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1436 body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
1437 } else if (lengthKind == ArgumentLengthKind::Variadic) {
1438 body << llvm::formatv(variadicTypeParserCode, listName);
1439 } else if (lengthKind == ArgumentLengthKind::Optional) {
1440 body << llvm::formatv(optionalTypeParserCode, listName);
1441 } else {
1442 const char *parserCode =
1443 dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
1444 TypeSwitch<FormatElement *>(dir->getArg())
1445 .Case<OperandVariable, ResultVariable>([&](auto operand) {
1446 body << formatv(parserCode,
1447 operand->getVar()->constraint.getCPPClassName(),
1448 listName);
1450 .Default([&](auto operand) {
1451 body << formatv(parserCode, "::mlir::Type", listName);
1454 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1455 ArgumentLengthKind ignored;
1456 body << formatv(functionalTypeParserCode,
1457 getTypeListName(dir->getInputs(), ignored),
1458 getTypeListName(dir->getResults(), ignored));
1459 } else {
1460 llvm_unreachable("unknown format element");
1464 void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
1465 // If any of type resolutions use transformed variables, make sure that the
1466 // types of those variables are resolved.
1467 SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1468 FmtContext verifierFCtx;
1469 for (TypeResolution &resolver :
1470 llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1471 std::optional<StringRef> transformer = resolver.getVarTransformer();
1472 if (!transformer)
1473 continue;
1474 // Ensure that we don't verify the same variables twice.
1475 const NamedTypeConstraint *variable = resolver.getVariable();
1476 if (!variable || !verifiedVariables.insert(variable).second)
1477 continue;
1479 auto constraint = variable->constraint;
1480 body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
1481 << " (void)type;\n"
1482 << " if (!("
1483 << tgfmt(constraint.getConditionTemplate(),
1484 &verifierFCtx.withSelf("type"))
1485 << ")) {\n"
1486 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1487 "\"'{0}' must be {1}, but got \" << type;\n",
1488 variable->name, constraint.getSummary())
1489 << " }\n"
1490 << " }\n";
1493 // Initialize the set of buildable types.
1494 if (!buildableTypes.empty()) {
1495 FmtContext typeBuilderCtx;
1496 typeBuilderCtx.withBuilder("parser.getBuilder()");
1497 for (auto &it : buildableTypes)
1498 body << " ::mlir::Type odsBuildableType" << it.second << " = "
1499 << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1502 // Emit the code necessary for a type resolver.
1503 auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1504 if (std::optional<int> val = resolver.getBuilderIdx()) {
1505 body << "odsBuildableType" << *val;
1506 } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1507 if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
1508 FmtContext fmtContext;
1509 fmtContext.addSubst("_ctxt", "parser.getContext()");
1510 if (var->isVariadic())
1511 fmtContext.withSelf(var->name + "Types");
1512 else
1513 fmtContext.withSelf(var->name + "Types[0]");
1514 body << tgfmt(*tform, &fmtContext);
1515 } else {
1516 body << var->name << "Types";
1517 if (!var->isVariadic())
1518 body << "[0]";
1520 } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1521 if (std::optional<StringRef> tform = resolver.getVarTransformer())
1522 body << tgfmt(*tform,
1523 &FmtContext().withSelf(attr->name + "Attr.getType()"));
1524 else
1525 body << attr->name << "Attr.getType()";
1526 } else {
1527 body << curVar << "Types";
1531 // Resolve each of the result types.
1532 if (!infersResultTypes) {
1533 if (allResultTypes) {
1534 body << " result.addTypes(allResultTypes);\n";
1535 } else {
1536 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1537 body << " result.addTypes(";
1538 emitTypeResolver(resultTypes[i], op.getResultName(i));
1539 body << ");\n";
1544 // Emit the operand type resolutions.
1545 genParserOperandTypeResolution(op, body, emitTypeResolver);
1547 // Handle return type inference once all operands have been resolved
1548 if (infersResultTypes)
1549 body << formatv(inferReturnTypesParserCode, op.getCppClassName());
1552 void OperationFormat::genParserOperandTypeResolution(
1553 Operator &op, MethodBody &body,
1554 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1555 // Early exit if there are no operands.
1556 if (op.getNumOperands() == 0)
1557 return;
1559 // Handle the case where all operand types are grouped together with
1560 // "types(operands)".
1561 if (allOperandTypes) {
1562 // If `operands` was specified, use the full operand list directly.
1563 if (allOperands) {
1564 body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
1565 "allOperandLoc, result.operands))\n"
1566 " return ::mlir::failure();\n";
1567 return;
1570 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1571 // llvm::concat does not allow the case of a single range, so guard it here.
1572 body << " if (parser.resolveOperands(";
1573 if (op.getNumOperands() > 1) {
1574 body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1575 llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1576 body << operand.name << "Operands";
1578 body << ")";
1579 } else {
1580 body << op.operand_begin()->name << "Operands";
1582 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1583 << " return ::mlir::failure();\n";
1584 return;
1587 // Handle the case where all operands are grouped together with "operands".
1588 if (allOperands) {
1589 body << " if (parser.resolveOperands(allOperands, ";
1591 // Group all of the operand types together to perform the resolution all at
1592 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1593 // the case of a single range, so guard it here.
1594 if (op.getNumOperands() > 1) {
1595 body << "::llvm::concat<const ::mlir::Type>(";
1596 llvm::interleaveComma(
1597 llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1598 body << "::llvm::ArrayRef<::mlir::Type>(";
1599 emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1600 body << ")";
1602 body << ")";
1603 } else {
1604 emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1607 body << ", allOperandLoc, result.operands))\n return "
1608 "::mlir::failure();\n";
1609 return;
1612 // The final case is the one where each of the operands types are resolved
1613 // separately.
1614 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1615 NamedTypeConstraint &operand = op.getOperand(i);
1616 body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
1618 // Resolve the type of this operand.
1619 TypeResolution &operandType = operandTypes[i];
1620 emitTypeResolver(operandType, operand.name);
1622 body << ", " << operand.name
1623 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1627 void OperationFormat::genParserRegionResolution(Operator &op,
1628 MethodBody &body) {
1629 // Check for the case where all regions were parsed.
1630 bool hasAllRegions = llvm::any_of(
1631 elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
1632 if (hasAllRegions) {
1633 body << " result.addRegions(fullRegions);\n";
1634 return;
1637 // Otherwise, handle each region individually.
1638 for (const NamedRegion &region : op.getRegions()) {
1639 if (region.isVariadic())
1640 body << " result.addRegions(" << region.name << "Regions);\n";
1641 else
1642 body << " result.addRegion(std::move(" << region.name << "Region));\n";
1646 void OperationFormat::genParserSuccessorResolution(Operator &op,
1647 MethodBody &body) {
1648 // Check for the case where all successors were parsed.
1649 bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
1650 return isa<SuccessorsDirective>(elt);
1652 if (hasAllSuccessors) {
1653 body << " result.addSuccessors(fullSuccessors);\n";
1654 return;
1657 // Otherwise, handle each successor individually.
1658 for (const NamedSuccessor &successor : op.getSuccessors()) {
1659 if (successor.isVariadic())
1660 body << " result.addSuccessors(" << successor.name << "Successors);\n";
1661 else
1662 body << " result.addSuccessors(" << successor.name << "Successor);\n";
1666 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1667 MethodBody &body) {
1668 if (!allOperands) {
1669 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1670 auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1671 // If the operand is variadic emit the parsed size.
1672 if (operand.isVariableLength())
1673 body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1674 else
1675 body << "1";
1677 if (op.getDialect().usePropertiesForAttributes()) {
1678 body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1679 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1680 body << formatv("}), "
1681 "result.getOrAddProperties<{0}::Properties>()."
1682 "operandSegmentSizes.begin());\n",
1683 op.getCppClassName());
1684 } else {
1685 body << " result.addAttribute(\"operandSegmentSizes\", "
1686 << "parser.getBuilder().getDenseI32ArrayAttr({";
1687 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1688 body << "}));\n";
1691 for (const NamedTypeConstraint &operand : op.getOperands()) {
1692 if (!operand.isVariadicOfVariadic())
1693 continue;
1694 if (op.getDialect().usePropertiesForAttributes()) {
1695 body << llvm::formatv(
1696 " result.getOrAddProperties<{0}::Properties>().{1} = "
1697 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1698 op.getCppClassName(),
1699 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1700 operand.name);
1701 } else {
1702 body << llvm::formatv(
1703 " result.addAttribute(\"{0}\", "
1704 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1705 "\n",
1706 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1707 operand.name);
1712 if (!allResultTypes &&
1713 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1714 auto interleaveFn = [&](const NamedTypeConstraint &result) {
1715 // If the result is variadic emit the parsed size.
1716 if (result.isVariableLength())
1717 body << "static_cast<int32_t>(" << result.name << "Types.size())";
1718 else
1719 body << "1";
1721 if (op.getDialect().usePropertiesForAttributes()) {
1722 body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1723 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1724 body << formatv("}), "
1725 "result.getOrAddProperties<{0}::Properties>()."
1726 "resultSegmentSizes.begin());\n",
1727 op.getCppClassName());
1728 } else {
1729 body << " result.addAttribute(\"resultSegmentSizes\", "
1730 << "parser.getBuilder().getDenseI32ArrayAttr({";
1731 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1732 body << "}));\n";
1737 //===----------------------------------------------------------------------===//
1738 // PrinterGen
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The emitted code only prints the block terminator when it carries
/// attributes, operands, or results.
///
/// {0}: The name of the region.
///
/// Declared `const char *const` so the pointer cannot be reseated and the
/// symbol has internal linkage in this translation unit.
const char *const regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
      /*printBlockTerminators=*/printTerminator);
  }
)";
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// The snippet opens a scope that the code emitted after it is expected to
/// close.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attributes symbolToString function.
///
/// Declared `const char *const` so the pointer cannot be reseated and the
/// symbol has internal linkage in this translation unit.
const char *const enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1768 /// Generate the printer for the 'prop-dict' directive.
1769 static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
1770 MethodBody &body) {
1771 body << " _odsPrinter << \" \";\n"
1772 << " printProperties(this->getContext(), _odsPrinter, "
1773 "getProperties());\n";
1776 /// Generate the printer for the 'attr-dict' directive.
1777 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1778 MethodBody &body, bool withKeyword) {
1779 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
1780 // Elide the variadic segment size attributes if necessary.
1781 if (!fmt.allOperands &&
1782 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1783 body << " elidedAttrs.push_back(\"operandSegmentSizes\");\n";
1784 if (!fmt.allResultTypes &&
1785 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1786 body << " elidedAttrs.push_back(\"resultSegmentSizes\");\n";
1787 for (const StringRef key : fmt.inferredAttributes.keys())
1788 body << " elidedAttrs.push_back(\"" << key << "\");\n";
1789 for (const NamedAttribute *attr : fmt.usedAttributes)
1790 body << " elidedAttrs.push_back(\"" << attr->name << "\");\n";
1791 // Add code to check attributes for equality with the default value
1792 // for attributes with the elidePrintingDefaultValue bit set.
1793 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1794 const Attribute &attr = namedAttr.attr;
1795 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
1796 const StringRef &name = namedAttr.name;
1797 FmtContext fctx;
1798 fctx.withBuilder("odsBuilder");
1799 std::string defaultValue = std::string(
1800 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1801 body << " {\n";
1802 body << " ::mlir::Builder odsBuilder(getContext());\n";
1803 body << " ::mlir::Attribute attr = " << op.getGetterName(name)
1804 << "Attr();\n";
1805 body << " if(attr && (attr == " << defaultValue << "))\n";
1806 body << " elidedAttrs.push_back(\"" << name << "\");\n";
1807 body << " }\n";
1810 body << " _odsPrinter.printOptionalAttrDict"
1811 << (withKeyword ? "WithKeyword" : "")
1812 << "((*this)->getAttrs(), elidedAttrs);\n";
1815 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1816 /// space should be emitted before this element. `lastWasPunctuation` is true if
1817 /// the previous element was a punctuation literal.
1818 static void genLiteralPrinter(StringRef value, MethodBody &body,
1819 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1820 body << " _odsPrinter";
1822 // Don't insert a space for certain punctuation.
1823 if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
1824 body << " << ' '";
1825 body << " << \"" << value << "\";\n";
1827 // Insert a space after certain literals.
1828 shouldEmitSpace =
1829 value.size() != 1 || !StringRef("<({[").contains(value.front());
1830 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
1833 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1834 /// are set to false.
1835 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
1836 bool &lastWasPunctuation) {
1837 if (value) {
1838 body << " _odsPrinter << ' ';\n";
1839 lastWasPunctuation = false;
1840 } else {
1841 lastWasPunctuation = true;
1843 shouldEmitSpace = false;
1846 /// Generate the printer for a custom directive parameter.
1847 static void genCustomDirectiveParameterPrinter(FormatElement *element,
1848 const Operator &op,
1849 MethodBody &body) {
1850 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1851 body << op.getGetterName(attr->getVar()->name) << "Attr()";
1853 } else if (isa<AttrDictDirective>(element)) {
1854 body << "getOperation()->getAttrDictionary()";
1856 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1857 body << op.getGetterName(operand->getVar()->name) << "()";
1859 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1860 body << op.getGetterName(region->getVar()->name) << "()";
1862 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1863 body << op.getGetterName(successor->getVar()->name) << "()";
1865 } else if (auto *dir = dyn_cast<RefDirective>(element)) {
1866 genCustomDirectiveParameterPrinter(dir->getArg(), op, body);
1868 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1869 auto *typeOperand = dir->getArg();
1870 auto *operand = dyn_cast<OperandVariable>(typeOperand);
1871 auto *var = operand ? operand->getVar()
1872 : cast<ResultVariable>(typeOperand)->getVar();
1873 std::string name = op.getGetterName(var->name);
1874 if (var->isVariadic())
1875 body << name << "().getTypes()";
1876 else if (var->isOptional())
1877 body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
1878 else
1879 body << name << "().getType()";
1881 } else if (auto *string = dyn_cast<StringElement>(element)) {
1882 FmtContext ctx;
1883 ctx.withBuilder("::mlir::Builder(getContext())");
1884 ctx.addSubst("_ctxt", "getContext()");
1885 body << tgfmt(string->getValue(), &ctx);
1887 } else if (auto *property = dyn_cast<PropertyVariable>(element)) {
1888 FmtContext ctx;
1889 ctx.addSubst("_ctxt", "getContext()");
1890 const NamedProperty *namedProperty = property->getVar();
1891 ctx.addSubst("_storage", "getProperties()." + namedProperty->name);
1892 body << tgfmt(namedProperty->prop.getConvertFromStorageCall(), &ctx);
1893 } else {
1894 llvm_unreachable("unknown custom directive parameter");
1898 /// Generate the printer for a custom directive.
1899 static void genCustomDirectivePrinter(CustomDirective *customDir,
1900 const Operator &op, MethodBody &body) {
1901 body << " print" << customDir->getName() << "(_odsPrinter, *this";
1902 for (FormatElement *param : customDir->getArguments()) {
1903 body << ", ";
1904 genCustomDirectiveParameterPrinter(param, op, body);
1906 body << ");\n";
1909 /// Generate the printer for a region with the given variable name.
1910 static void genRegionPrinter(const Twine &regionName, MethodBody &body,
1911 bool hasImplicitTermTrait) {
1912 if (hasImplicitTermTrait)
1913 body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1914 regionName);
1915 else
1916 body << " _odsPrinter.printRegion(" << regionName << ");\n";
1918 static void genVariadicRegionPrinter(const Twine &regionListName,
1919 MethodBody &body,
1920 bool hasImplicitTermTrait) {
1921 body << " llvm::interleaveComma(" << regionListName
1922 << ", _odsPrinter, [&](::mlir::Region &region) {\n ";
1923 genRegionPrinter("region", body, hasImplicitTermTrait);
1924 body << " });\n";
1927 /// Generate the C++ for an operand to a (*-)type directive.
1928 static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
1929 MethodBody &body,
1930 bool useArrayRef = true) {
1931 if (isa<OperandsDirective>(arg))
1932 return body << "getOperation()->getOperandTypes()";
1933 if (isa<ResultsDirective>(arg))
1934 return body << "getOperation()->getResultTypes()";
1935 auto *operand = dyn_cast<OperandVariable>(arg);
1936 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1937 if (var->isVariadicOfVariadic())
1938 return body << llvm::formatv("{0}().join().getTypes()",
1939 op.getGetterName(var->name));
1940 if (var->isVariadic())
1941 return body << op.getGetterName(var->name) << "().getTypes()";
1942 if (var->isOptional())
1943 return body << llvm::formatv(
1944 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1945 "::llvm::ArrayRef<::mlir::Type>())",
1946 op.getGetterName(var->name));
1947 if (useArrayRef)
1948 return body << "::llvm::ArrayRef<::mlir::Type>("
1949 << op.getGetterName(var->name) << "().getType())";
1950 return body << op.getGetterName(var->name) << "().getType()";
1953 /// Generate the printer for an enum attribute.
1954 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
1955 MethodBody &body) {
1956 Attribute baseAttr = var->attr.getBaseAttr();
1957 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1958 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1960 body << llvm::formatv(enumAttrBeginPrinterCode,
1961 (var->attr.isOptional() ? "*" : "") +
1962 op.getGetterName(var->name),
1963 enumAttr.getSymbolToStringFnName());
1965 // Get a string containing all of the cases that can't be represented with a
1966 // keyword.
1967 BitVector nonKeywordCases(cases.size());
1968 for (auto it : llvm::enumerate(cases)) {
1969 if (!canFormatStringAsKeyword(it.value().getStr()))
1970 nonKeywordCases.set(it.index());
1973 // Otherwise if this is a bit enum attribute, don't allow cases that may
1974 // overlap with other cases. For simplicity sake, only allow cases with a
1975 // single bit value.
1976 if (enumAttr.isBitEnum()) {
1977 for (auto it : llvm::enumerate(cases)) {
1978 int64_t value = it.value().getValue();
1979 if (value < 0 || !llvm::isPowerOf2_64(value))
1980 nonKeywordCases.set(it.index());
1984 // If there are any cases that can't be used with a keyword, switch on the
1985 // case value to determine when to print in the string form.
1986 if (nonKeywordCases.any()) {
1987 body << " switch (caseValue) {\n";
1988 StringRef cppNamespace = enumAttr.getCppNamespace();
1989 StringRef enumName = enumAttr.getEnumClassName();
1990 for (auto it : llvm::enumerate(cases)) {
1991 if (nonKeywordCases.test(it.index()))
1992 continue;
1993 StringRef symbol = it.value().getSymbol();
1994 body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
1995 llvm::isDigit(symbol.front()) ? ("_" + symbol)
1996 : symbol);
1998 body << " _odsPrinter << caseValueStr;\n"
1999 " break;\n"
2000 " default:\n"
2001 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
2002 " break;\n"
2003 " }\n"
2004 " }\n";
2005 return;
2008 body << " _odsPrinter << caseValueStr;\n"
2009 " }\n";
2012 /// Generate a check that a DefaultValuedAttr has a value that is non-default.
2013 static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
2014 AttributeVariable &attrElement) {
2015 FmtContext fctx;
2016 Attribute attr = attrElement.getVar()->attr;
2017 fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2018 body << " && " << op.getGetterName(attrElement.getVar()->name) << "Attr() != "
2019 << tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue());
2022 /// Generate the check for the anchor of an optional group.
2023 static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
2024 const Operator &op,
2025 MethodBody &body) {
2026 TypeSwitch<FormatElement *>(anchor)
2027 .Case<OperandVariable, ResultVariable>([&](auto *element) {
2028 const NamedTypeConstraint *var = element->getVar();
2029 std::string name = op.getGetterName(var->name);
2030 if (var->isOptional())
2031 body << name << "()";
2032 else if (var->isVariadic())
2033 body << "!" << name << "().empty()";
2035 .Case([&](RegionVariable *element) {
2036 const NamedRegion *var = element->getVar();
2037 std::string name = op.getGetterName(var->name);
2038 // TODO: Add a check for optional regions here when ODS supports it.
2039 body << "!" << name << "().empty()";
2041 .Case([&](TypeDirective *element) {
2042 genOptionalGroupPrinterAnchor(element->getArg(), op, body);
2044 .Case([&](FunctionalTypeDirective *element) {
2045 genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
2047 .Case([&](AttributeVariable *element) {
2048 Attribute attr = element->getVar()->attr;
2049 body << op.getGetterName(element->getVar()->name) << "Attr()";
2050 if (attr.isOptional())
2051 return; // done
2052 if (attr.hasDefaultValue()) {
2053 // Consider a default-valued attribute as present if it's not the
2054 // default value.
2055 genNonDefaultValueCheck(body, op, *element);
2056 return;
2058 llvm_unreachable("attribute must be optional or default-valued");
2060 .Case([&](CustomDirective *ele) {
2061 body << '(';
2062 llvm::interleave(
2063 ele->getArguments(), body,
2064 [&](FormatElement *child) {
2065 body << '(';
2066 genOptionalGroupPrinterAnchor(child, op, body);
2067 body << ')';
2069 " || ");
2070 body << ')';
2074 void collect(FormatElement *element,
2075 SmallVectorImpl<VariableElement *> &variables) {
2076 TypeSwitch<FormatElement *>(element)
2077 .Case([&](VariableElement *var) { variables.emplace_back(var); })
2078 .Case([&](CustomDirective *ele) {
2079 for (FormatElement *arg : ele->getArguments())
2080 collect(arg, variables);
2082 .Case([&](OptionalElement *ele) {
2083 for (FormatElement *arg : ele->getThenElements())
2084 collect(arg, variables);
2085 for (FormatElement *arg : ele->getElseElements())
2086 collect(arg, variables);
2088 .Case([&](FunctionalTypeDirective *funcType) {
2089 collect(funcType->getInputs(), variables);
2090 collect(funcType->getResults(), variables);
2092 .Case([&](OIListElement *oilist) {
2093 for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
2094 for (FormatElement *arg : arg)
2095 collect(arg, variables);
2099 void OperationFormat::genElementPrinter(FormatElement *element,
2100 MethodBody &body, Operator &op,
2101 bool &shouldEmitSpace,
2102 bool &lastWasPunctuation) {
2103 if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
2104 return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
2105 lastWasPunctuation);
2107 // Emit a whitespace element.
2108 if (auto *space = dyn_cast<WhitespaceElement>(element)) {
2109 if (space->getValue() == "\\n") {
2110 body << " _odsPrinter.printNewline();\n";
2111 } else {
2112 genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace,
2113 lastWasPunctuation);
2115 return;
2118 // Emit an optional group.
2119 if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
2120 // Emit the check for the presence of the anchor element.
2121 FormatElement *anchor = optional->getAnchor();
2122 body << " if (";
2123 if (optional->isInverted())
2124 body << "!";
2125 genOptionalGroupPrinterAnchor(anchor, op, body);
2126 body << ") {\n";
2127 body.indent();
2129 // If the anchor is a unit attribute, we don't need to print it. When
2130 // parsing, we will add this attribute if this group is present.
2131 ArrayRef<FormatElement *> thenElements = optional->getThenElements();
2132 ArrayRef<FormatElement *> elseElements = optional->getElseElements();
2133 FormatElement *elidedAnchorElement = nullptr;
2134 auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
2135 if (anchorAttr && anchorAttr != thenElements.front() &&
2136 (elseElements.empty() || anchorAttr != elseElements.front()) &&
2137 anchorAttr->isUnitAttr()) {
2138 elidedAnchorElement = anchorAttr;
2140 auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
2141 for (FormatElement *childElement : elements) {
2142 if (childElement != elidedAnchorElement) {
2143 genElementPrinter(childElement, body, op, shouldEmitSpace,
2144 lastWasPunctuation);
2149 // Emit each of the elements.
2150 genElementPrinters(thenElements);
2151 body << "}";
2153 // Emit each of the else elements.
2154 if (!elseElements.empty()) {
2155 body << " else {\n";
2156 genElementPrinters(elseElements);
2157 body << "}";
2160 body.unindent() << "\n";
2161 return;
2164 // Emit the OIList
2165 if (auto *oilist = dyn_cast<OIListElement>(element)) {
2166 for (auto clause : oilist->getClauses()) {
2167 LiteralElement *lelement = std::get<0>(clause);
2168 ArrayRef<FormatElement *> pelement = std::get<1>(clause);
2170 SmallVector<VariableElement *> vars;
2171 for (FormatElement *el : pelement)
2172 collect(el, vars);
2173 body << " if (false";
2174 for (VariableElement *var : vars) {
2175 TypeSwitch<FormatElement *>(var)
2176 .Case([&](AttributeVariable *attrEle) {
2177 body << " || (" << op.getGetterName(attrEle->getVar()->name)
2178 << "Attr()";
2179 Attribute attr = attrEle->getVar()->attr;
2180 if (attr.hasDefaultValue()) {
2181 // Don't print default-valued attributes.
2182 genNonDefaultValueCheck(body, op, *attrEle);
2184 body << ")";
2186 .Case([&](OperandVariable *ele) {
2187 if (ele->getVar()->isVariadic()) {
2188 body << " || " << op.getGetterName(ele->getVar()->name)
2189 << "().size()";
2190 } else {
2191 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
2194 .Case([&](ResultVariable *ele) {
2195 if (ele->getVar()->isVariadic()) {
2196 body << " || " << op.getGetterName(ele->getVar()->name)
2197 << "().size()";
2198 } else {
2199 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
2202 .Case([&](RegionVariable *reg) {
2203 body << " || " << op.getGetterName(reg->getVar()->name) << "()";
2207 body << ") {\n";
2208 genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
2209 lastWasPunctuation);
2210 if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
2211 for (FormatElement *element : pelement)
2212 genElementPrinter(element, body, op, shouldEmitSpace,
2213 lastWasPunctuation);
2215 body << " }\n";
2217 return;
2220 // Emit the attribute dictionary.
2221 if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
2222 genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
2223 lastWasPunctuation = false;
2224 return;
2227 // Emit the attribute dictionary.
2228 if (dyn_cast<PropDictDirective>(element)) {
2229 genPropDictPrinter(*this, op, body);
2230 lastWasPunctuation = false;
2231 return;
2234 // Optionally insert a space before the next element. The AttrDict printer
2235 // already adds a space as necessary.
2236 if (shouldEmitSpace || !lastWasPunctuation)
2237 body << " _odsPrinter << ' ';\n";
2238 lastWasPunctuation = false;
2239 shouldEmitSpace = true;
2241 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
2242 const NamedAttribute *var = attr->getVar();
2244 // If we are formatting as an enum, symbolize the attribute as a string.
2245 if (canFormatEnumAttr(var))
2246 return genEnumAttrPrinter(var, op, body);
2248 // If we are formatting as a symbol name, handle it as a symbol name.
2249 if (shouldFormatSymbolNameAttr(var)) {
2250 body << " _odsPrinter.printSymbolName(" << op.getGetterName(var->name)
2251 << "Attr().getValue());\n";
2252 return;
2255 // Elide the attribute type if it is buildable.
2256 if (attr->getTypeBuilder())
2257 body << " _odsPrinter.printAttributeWithoutType("
2258 << op.getGetterName(var->name) << "Attr());\n";
2259 else if (attr->shouldBeQualified() ||
2260 var->attr.getStorageType() == "::mlir::Attribute")
2261 body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
2262 << "Attr());\n";
2263 else
2264 body << "_odsPrinter.printStrippedAttrOrType("
2265 << op.getGetterName(var->name) << "Attr());\n";
2266 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2267 if (operand->getVar()->isVariadicOfVariadic()) {
2268 body << " ::llvm::interleaveComma("
2269 << op.getGetterName(operand->getVar()->name)
2270 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2271 "\"(\" << operands << "
2272 "\")\"; });\n";
2274 } else if (operand->getVar()->isOptional()) {
2275 body << " if (::mlir::Value value = "
2276 << op.getGetterName(operand->getVar()->name) << "())\n"
2277 << " _odsPrinter << value;\n";
2278 } else {
2279 body << " _odsPrinter << " << op.getGetterName(operand->getVar()->name)
2280 << "();\n";
2282 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2283 const NamedRegion *var = region->getVar();
2284 std::string name = op.getGetterName(var->name);
2285 if (var->isVariadic()) {
2286 genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
2287 } else {
2288 genRegionPrinter(name + "()", body, hasImplicitTermTrait);
2290 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2291 const NamedSuccessor *var = successor->getVar();
2292 std::string name = op.getGetterName(var->name);
2293 if (var->isVariadic())
2294 body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
2295 else
2296 body << " _odsPrinter << " << name << "();\n";
2297 } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
2298 genCustomDirectivePrinter(dir, op, body);
2299 } else if (isa<OperandsDirective>(element)) {
2300 body << " _odsPrinter << getOperation()->getOperands();\n";
2301 } else if (isa<RegionsDirective>(element)) {
2302 genVariadicRegionPrinter("getOperation()->getRegions()", body,
2303 hasImplicitTermTrait);
2304 } else if (isa<SuccessorsDirective>(element)) {
2305 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2306 "_odsPrinter);\n";
2307 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2308 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
2309 if (operand->getVar()->isVariadicOfVariadic()) {
2310 body << llvm::formatv(
2311 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2312 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2313 "types << \")\"; });\n",
2314 op.getGetterName(operand->getVar()->name));
2315 return;
2318 const NamedTypeConstraint *var = nullptr;
2320 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg()))
2321 var = operand->getVar();
2322 else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg()))
2323 var = operand->getVar();
2325 if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
2326 !var->isOptional()) {
2327 std::string cppClass = var->constraint.getCPPClassName();
2328 if (dir->shouldBeQualified()) {
2329 body << " _odsPrinter << " << op.getGetterName(var->name)
2330 << "().getType();\n";
2331 return;
2333 body << " {\n"
2334 << " auto type = " << op.getGetterName(var->name)
2335 << "().getType();\n"
2336 << " if (auto validType = ::llvm::dyn_cast<" << cppClass
2337 << ">(type))\n"
2338 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2339 << " else\n"
2340 << " _odsPrinter << type;\n"
2341 << " }\n";
2342 return;
2344 body << " _odsPrinter << ";
2345 genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false)
2346 << ";\n";
2347 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
2348 body << " _odsPrinter.printFunctionalType(";
2349 genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
2350 genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
2351 } else {
2352 llvm_unreachable("unknown format element");
2356 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2357 auto *method = opClass.addMethod(
2358 "void", "print",
2359 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2360 auto &body = method->body();
2362 // Flags for if we should emit a space, and if the last element was
2363 // punctuation.
2364 bool shouldEmitSpace = true, lastWasPunctuation = false;
2365 for (FormatElement *element : elements)
2366 genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
2369 //===----------------------------------------------------------------------===//
2370 // OpFormatParser
2371 //===----------------------------------------------------------------------===//
2373 /// Function to find an element within the given range that has the same name as
2374 /// 'name'.
2375 template <typename RangeT>
2376 static auto findArg(RangeT &&range, StringRef name) {
2377 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2378 return it != range.end() ? &*it : nullptr;
2381 namespace {
2382 /// This class implements a parser for an instance of an operation assembly
2383 /// format.
2384 class OpFormatParser : public FormatParser {
2385 public:
2386 OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
2387 : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
2388 seenOperandTypes(op.getNumOperands()),
2389 seenResultTypes(op.getNumResults()) {}
2391 protected:
2392 /// Verify the format elements.
2393 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
2394 /// Verify the arguments to a custom directive.
2395 LogicalResult
2396 verifyCustomDirectiveArguments(SMLoc loc,
2397 ArrayRef<FormatElement *> arguments) override;
2398 /// Verify the elements of an optional group.
2399 LogicalResult verifyOptionalGroupElements(SMLoc loc,
2400 ArrayRef<FormatElement *> elements,
2401 FormatElement *anchor) override;
2402 LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
2403 bool isAnchor);
2405 /// Parse an operation variable.
2406 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
2407 Context ctx) override;
2408 /// Parse an operation format directive.
2409 FailureOr<FormatElement *>
2410 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
2412 private:
2413 /// This struct represents a type resolution instance. It includes a specific
2414 /// type as well as an optional transformer to apply to that type in order to
2415 /// properly resolve the type of a variable.
2416 struct TypeResolutionInstance {
2417 ConstArgument resolver;
2418 std::optional<StringRef> transformer;
2421 /// Verify the state of operation attributes within the format.
2422 LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);
2424 /// Verify that attributes elements aren't followed by colon literals.
2425 LogicalResult verifyAttributeColonType(SMLoc loc,
2426 ArrayRef<FormatElement *> elements);
2427 /// Verify that the attribute dictionary directive isn't followed by a region.
2428 LogicalResult verifyAttrDictRegion(SMLoc loc,
2429 ArrayRef<FormatElement *> elements);
2431 /// Verify the state of operation operands within the format.
2432 LogicalResult
2433 verifyOperands(SMLoc loc,
2434 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2436 /// Verify the state of operation regions within the format.
2437 LogicalResult verifyRegions(SMLoc loc);
2439 /// Verify the state of operation results within the format.
2440 LogicalResult
2441 verifyResults(SMLoc loc,
2442 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2444 /// Verify the state of operation successors within the format.
2445 LogicalResult verifySuccessors(SMLoc loc);
2447 LogicalResult verifyOIListElements(SMLoc loc,
2448 ArrayRef<FormatElement *> elements);
2450 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2451 /// resolution.
2452 void handleAllTypesMatchConstraint(
2453 ArrayRef<StringRef> values,
2454 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2455 /// Check for inferable type resolution given all operands, and or results,
2456 /// have the same type. If 'includeResults' is true, the results also have the
2457 /// same type as all of the operands.
2458 void handleSameTypesConstraint(
2459 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2460 bool includeResults);
2461 /// Check for inferable type resolution based on another operand, result, or
2462 /// attribute.
2463 void handleTypesMatchConstraint(
2464 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2465 const llvm::Record &def);
2467 /// Returns an argument or attribute with the given name that has been seen
2468 /// within the format.
2469 ConstArgument findSeenArg(StringRef name);
2471 /// Parse the various different directives.
2472 FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
2473 FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
2474 bool withKeyword);
2475 FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
2476 Context context);
2477 FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
2478 LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
2479 FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
2480 FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
2481 Context context);
2482 FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
2483 Context context);
2484 FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
2485 FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
2486 FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
2487 Context context);
2488 FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
2489 FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
2490 bool isRefChild = false);
2492 //===--------------------------------------------------------------------===//
2493 // Fields
2494 //===--------------------------------------------------------------------===//
2496 OperationFormat &fmt;
2497 Operator &op;
2499 // The following are various bits of format state used for verification
2500 // during parsing.
2501 bool hasAttrDict = false;
2502 bool hasPropDict = false;
2503 bool hasAllRegions = false, hasAllSuccessors = false;
2504 bool canInferResultTypes = false;
2505 llvm::SmallBitVector seenOperandTypes, seenResultTypes;
2506 llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
2507 llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
2508 llvm::DenseSet<const NamedRegion *> seenRegions;
2509 llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2510 llvm::DenseSet<const NamedProperty *> seenProperties;
2512 } // namespace
2514 LogicalResult OpFormatParser::verify(SMLoc loc,
2515 ArrayRef<FormatElement *> elements) {
2516 // Check that the attribute dictionary is in the format.
2517 if (!hasAttrDict)
2518 return emitError(loc, "'attr-dict' directive not found in "
2519 "custom assembly format");
2521 // Check for any type traits that we can use for inferring types.
2522 llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2523 for (const Trait &trait : op.getTraits()) {
2524 const llvm::Record &def = trait.getDef();
2525 if (def.isSubClassOf("AllTypesMatch")) {
2526 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2527 variableTyResolver);
2528 } else if (def.getName() == "SameTypeOperands") {
2529 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2530 } else if (def.getName() == "SameOperandsAndResultType") {
2531 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2532 } else if (def.isSubClassOf("TypesMatchWith")) {
2533 handleTypesMatchConstraint(variableTyResolver, def);
2534 } else if (!op.allResultTypesKnown()) {
2535 // This doesn't check the name directly to handle
2536 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2537 // and the like.
2538 // TODO: Add hasCppInterface check.
2539 if (auto name = def.getValueAsOptionalString("cppInterfaceName")) {
2540 if (*name == "InferTypeOpInterface" &&
2541 def.getValueAsString("cppNamespace") == "::mlir")
2542 canInferResultTypes = true;
2547 // Verify the state of the various operation components.
2548 if (failed(verifyAttributes(loc, elements)) ||
2549 failed(verifyResults(loc, variableTyResolver)) ||
2550 failed(verifyOperands(loc, variableTyResolver)) ||
2551 failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) ||
2552 failed(verifyOIListElements(loc, elements)))
2553 return failure();
2555 // Collect the set of used attributes in the format.
2556 fmt.usedAttributes = seenAttrs.takeVector();
2557 return success();
2560 LogicalResult
2561 OpFormatParser::verifyAttributes(SMLoc loc,
2562 ArrayRef<FormatElement *> elements) {
2563 // Check that there are no `:` literals after an attribute without a constant
2564 // type. The attribute grammar contains an optional trailing colon type, which
2565 // can lead to unexpected and generally unintended behavior. Given that, it is
2566 // better to just error out here instead.
2567 if (failed(verifyAttributeColonType(loc, elements)))
2568 return failure();
2569 // Check that there are no region variables following an attribute dicitonary.
2570 // Both start with `{` and so the optional attribute dictionary can cause
2571 // format ambiguities.
2572 if (failed(verifyAttrDictRegion(loc, elements)))
2573 return failure();
2575 // Check for VariadicOfVariadic variables. The segment attribute of those
2576 // variables will be infered.
2577 for (const NamedTypeConstraint *var : seenOperands) {
2578 if (var->constraint.isVariadicOfVariadic()) {
2579 fmt.inferredAttributes.insert(
2580 var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2584 return success();
2587 /// Returns whether the single format element is optionally parsed.
2588 static bool isOptionallyParsed(FormatElement *el) {
2589 if (auto *attrVar = dyn_cast<AttributeVariable>(el)) {
2590 Attribute attr = attrVar->getVar()->attr;
2591 return attr.isOptional() || attr.hasDefaultValue();
2593 if (auto *operandVar = dyn_cast<OperandVariable>(el)) {
2594 const NamedTypeConstraint *operand = operandVar->getVar();
2595 return operand->isOptional() || operand->isVariadic() ||
2596 operand->isVariadicOfVariadic();
2598 if (auto *successorVar = dyn_cast<SuccessorVariable>(el))
2599 return successorVar->getVar()->isVariadic();
2600 if (auto *regionVar = dyn_cast<RegionVariable>(el))
2601 return regionVar->getVar()->isVariadic();
2602 return isa<WhitespaceElement, AttrDictDirective>(el);
2605 /// Scan the given range of elements from the start for an invalid format
2606 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2607 /// If an optional group is encountered, this function recurses into the 'then'
2608 /// and 'else' elements to check if they are invalid. Returns `success` if the
2609 /// range is known to be valid or `std::nullopt` if scanning reached the end.
2611 /// Since the guard element of an optional group is required, this function
2612 /// accepts an optional element pointer to mark it as required.
2613 static std::optional<LogicalResult> checkRangeForElement(
2614 FormatElement *base,
2615 function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2616 iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
2617 FormatElement *optionalGuard = nullptr) {
2618 for (FormatElement *element : elementRange) {
2619 // If we encounter an invalid element, return an error.
2620 if (isInvalid(base, element))
2621 return failure();
2623 // Recurse on optional groups.
2624 if (auto *optional = dyn_cast<OptionalElement>(element)) {
2625 if (std::optional<LogicalResult> result = checkRangeForElement(
2626 base, isInvalid, optional->getThenElements(),
2627 // The optional group guard is required for the group.
2628 optional->getThenElements().front()))
2629 if (failed(*result))
2630 return failure();
2631 if (std::optional<LogicalResult> result = checkRangeForElement(
2632 base, isInvalid, optional->getElseElements()))
2633 if (failed(*result))
2634 return failure();
2635 // Skip the optional group.
2636 continue;
2639 // Skip optionally parsed elements.
2640 if (element != optionalGuard && isOptionallyParsed(element))
2641 continue;
2643 // We found a closing element that is valid.
2644 return success();
2646 // Return std::nullopt to indicate that we reached the end.
2647 return std::nullopt;
2650 /// For the given elements, check whether any attributes are followed by a colon
2651 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2652 /// attribute if verification of said attribute reached the end of the range.
2653 /// Returns null if all attribute elements are verified.
2654 static FailureOr<FormatElement *> verifyAdjacentElements(
2655 function_ref<bool(FormatElement *)> isBase,
2656 function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2657 ArrayRef<FormatElement *> elements) {
2658 for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
2659 // The current attribute being verified.
2660 FormatElement *base;
2662 if (isBase(*it)) {
2663 base = *it;
2664 } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
2665 // Recurse on optional groups.
2666 FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
2667 isBase, isInvalid, optional->getThenElements());
2668 if (failed(thenResult))
2669 return failure();
2670 FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
2671 isBase, isInvalid, optional->getElseElements());
2672 if (failed(elseResult))
2673 return failure();
2674 // If either optional group has an unverified attribute, save it.
2675 // Otherwise, move on to the next element.
2676 if (!(base = *thenResult) && !(base = *elseResult))
2677 continue;
2678 } else {
2679 continue;
2682 // Verify subsequent elements for potential ambiguities.
2683 if (std::optional<LogicalResult> result =
2684 checkRangeForElement(base, isInvalid, {std::next(it), e})) {
2685 if (failed(*result))
2686 return failure();
2687 } else {
2688 // Since we reached the end, return the attribute as unverified.
2689 return base;
2692 // All attribute elements are known to be verified.
2693 return nullptr;
2696 LogicalResult
2697 OpFormatParser::verifyAttributeColonType(SMLoc loc,
2698 ArrayRef<FormatElement *> elements) {
2699 auto isBase = [](FormatElement *el) {
2700 auto *attr = dyn_cast<AttributeVariable>(el);
2701 if (!attr)
2702 return false;
2703 // Check only attributes without type builders or that are known to call
2704 // the generic attribute parser.
2705 return !attr->getTypeBuilder() &&
2706 (attr->shouldBeQualified() ||
2707 attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
2709 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2710 auto *literal = dyn_cast<LiteralElement>(el);
2711 if (!literal || literal->getSpelling() != ":")
2712 return false;
2713 // If we encounter `:`, the range is known to be invalid.
2714 (void)emitError(
2715 loc,
2716 llvm::formatv("format ambiguity caused by `:` literal found after "
2717 "attribute `{0}` which does not have a buildable type",
2718 cast<AttributeVariable>(base)->getVar()->name));
2719 return true;
2721 return verifyAdjacentElements(isBase, isInvalid, elements);
2724 LogicalResult
2725 OpFormatParser::verifyAttrDictRegion(SMLoc loc,
2726 ArrayRef<FormatElement *> elements) {
2727 auto isBase = [](FormatElement *el) {
2728 if (auto *attrDict = dyn_cast<AttrDictDirective>(el))
2729 return !attrDict->isWithKeyword();
2730 return false;
2732 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2733 auto *region = dyn_cast<RegionVariable>(el);
2734 if (!region)
2735 return false;
2736 (void)emitErrorAndNote(
2737 loc,
2738 llvm::formatv("format ambiguity caused by `attr-dict` directive "
2739 "followed by region `{0}`",
2740 region->getVar()->name),
2741 "try using `attr-dict-with-keyword` instead");
2742 return true;
2744 return verifyAdjacentElements(isBase, isInvalid, elements);
2747 LogicalResult OpFormatParser::verifyOperands(
2748 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2749 // Check that all of the operands are within the format, and their types can
2750 // be inferred.
2751 auto &buildableTypes = fmt.buildableTypes;
2752 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2753 NamedTypeConstraint &operand = op.getOperand(i);
2755 // Check that the operand itself is in the format.
2756 if (!fmt.allOperands && !seenOperands.count(&operand)) {
2757 return emitErrorAndNote(loc,
2758 "operand #" + Twine(i) + ", named '" +
2759 operand.name + "', not found",
2760 "suggest adding a '$" + operand.name +
2761 "' directive to the custom assembly format");
2764 // Check that the operand type is in the format, or that it can be inferred.
2765 if (fmt.allOperandTypes || seenOperandTypes.test(i))
2766 continue;
2768 // Check to see if we can infer this type from another variable.
2769 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2770 if (varResolverIt != variableTyResolver.end()) {
2771 TypeResolutionInstance &resolver = varResolverIt->second;
2772 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2773 continue;
2776 // Similarly to results, allow a custom builder for resolving the type if
2777 // we aren't using the 'operands' directive.
2778 std::optional<StringRef> builder = operand.constraint.getBuilderCall();
2779 if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2780 return emitErrorAndNote(
2781 loc,
2782 "type of operand #" + Twine(i) + ", named '" + operand.name +
2783 "', is not buildable and a buildable type cannot be inferred",
2784 "suggest adding a type constraint to the operation or adding a "
2785 "'type($" +
2786 operand.name + ")' directive to the " + "custom assembly format");
2788 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2789 fmt.operandTypes[i].setBuilderIdx(it.first->second);
2791 return success();
2794 LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
2795 // Check that all of the regions are within the format.
2796 if (hasAllRegions)
2797 return success();
2799 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2800 const NamedRegion &region = op.getRegion(i);
2801 if (!seenRegions.count(&region)) {
2802 return emitErrorAndNote(loc,
2803 "region #" + Twine(i) + ", named '" +
2804 region.name + "', not found",
2805 "suggest adding a '$" + region.name +
2806 "' directive to the custom assembly format");
2809 return success();
2812 LogicalResult OpFormatParser::verifyResults(
2813 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2814 // If we format all of the types together, there is nothing to check.
2815 if (fmt.allResultTypes)
2816 return success();
2818 // If no result types are specified and we can infer them, infer all result
2819 // types
2820 if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
2821 canInferResultTypes) {
2822 fmt.infersResultTypes = true;
2823 return success();
2826 // Check that all of the result types can be inferred.
2827 auto &buildableTypes = fmt.buildableTypes;
2828 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2829 if (seenResultTypes.test(i))
2830 continue;
2832 // Check to see if we can infer this type from another variable.
2833 auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2834 if (varResolverIt != variableTyResolver.end()) {
2835 TypeResolutionInstance resolver = varResolverIt->second;
2836 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2837 continue;
2840 // If the result is not variable length, allow for the case where the type
2841 // has a builder that we can use.
2842 NamedTypeConstraint &result = op.getResult(i);
2843 std::optional<StringRef> builder = result.constraint.getBuilderCall();
2844 if (!builder || result.isVariableLength()) {
2845 return emitErrorAndNote(
2846 loc,
2847 "type of result #" + Twine(i) + ", named '" + result.name +
2848 "', is not buildable and a buildable type cannot be inferred",
2849 "suggest adding a type constraint to the operation or adding a "
2850 "'type($" +
2851 result.name + ")' directive to the " + "custom assembly format");
2853 // Note in the format that this result uses the custom builder.
2854 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2855 fmt.resultTypes[i].setBuilderIdx(it.first->second);
2857 return success();
2860 LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
2861 // Check that all of the successors are within the format.
2862 if (hasAllSuccessors)
2863 return success();
2865 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2866 const NamedSuccessor &successor = op.getSuccessor(i);
2867 if (!seenSuccessors.count(&successor)) {
2868 return emitErrorAndNote(loc,
2869 "successor #" + Twine(i) + ", named '" +
2870 successor.name + "', not found",
2871 "suggest adding a '$" + successor.name +
2872 "' directive to the custom assembly format");
2875 return success();
2878 LogicalResult
2879 OpFormatParser::verifyOIListElements(SMLoc loc,
2880 ArrayRef<FormatElement *> elements) {
2881 // Check that all of the successors are within the format.
2882 SmallVector<StringRef> prohibitedLiterals;
2883 for (FormatElement *it : elements) {
2884 if (auto *oilist = dyn_cast<OIListElement>(it)) {
2885 if (!prohibitedLiterals.empty()) {
2886 // We just saw an oilist element in last iteration. Literals should not
2887 // match.
2888 for (LiteralElement *literal : oilist->getLiteralElements()) {
2889 if (find(prohibitedLiterals, literal->getSpelling()) !=
2890 prohibitedLiterals.end()) {
2891 return emitError(
2892 loc, "format ambiguity because " + literal->getSpelling() +
2893 " is used in two adjacent oilist elements.");
2897 for (LiteralElement *literal : oilist->getLiteralElements())
2898 prohibitedLiterals.push_back(literal->getSpelling());
2899 } else if (auto *literal = dyn_cast<LiteralElement>(it)) {
2900 if (find(prohibitedLiterals, literal->getSpelling()) !=
2901 prohibitedLiterals.end()) {
2902 return emitError(
2903 loc,
2904 "format ambiguity because " + literal->getSpelling() +
2905 " is used both in oilist element and the adjacent literal.");
2907 prohibitedLiterals.clear();
2908 } else {
2909 prohibitedLiterals.clear();
2912 return success();
2915 void OpFormatParser::handleAllTypesMatchConstraint(
2916 ArrayRef<StringRef> values,
2917 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2918 for (unsigned i = 0, e = values.size(); i != e; ++i) {
2919 // Check to see if this value matches a resolved operand or result type.
2920 ConstArgument arg = findSeenArg(values[i]);
2921 if (!arg)
2922 continue;
2924 // Mark this value as the type resolver for the other variables.
2925 for (unsigned j = 0; j != i; ++j)
2926 variableTyResolver[values[j]] = {arg, std::nullopt};
2927 for (unsigned j = i + 1; j != e; ++j)
2928 variableTyResolver[values[j]] = {arg, std::nullopt};
2932 void OpFormatParser::handleSameTypesConstraint(
2933 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2934 bool includeResults) {
2935 const NamedTypeConstraint *resolver = nullptr;
2936 int resolvedIt = -1;
2938 // Check to see if there is an operand or result to use for the resolution.
2939 if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2940 resolver = &op.getOperand(resolvedIt);
2941 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2942 resolver = &op.getResult(resolvedIt);
2943 else
2944 return;
2946 // Set the resolvers for each operand and result.
2947 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2948 if (!seenOperandTypes.test(i))
2949 variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt};
2950 if (includeResults) {
2951 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2952 if (!seenResultTypes.test(i))
2953 variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt};
2957 void OpFormatParser::handleTypesMatchConstraint(
2958 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2959 const llvm::Record &def) {
2960 StringRef lhsName = def.getValueAsString("lhs");
2961 StringRef rhsName = def.getValueAsString("rhs");
2962 StringRef transformer = def.getValueAsString("transformer");
2963 if (ConstArgument arg = findSeenArg(lhsName))
2964 variableTyResolver[rhsName] = {arg, transformer};
2967 ConstArgument OpFormatParser::findSeenArg(StringRef name) {
2968 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2969 return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2970 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2971 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2972 if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2973 return seenAttrs.count(attr) ? attr : nullptr;
2974 return nullptr;
2977 FailureOr<FormatElement *>
2978 OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
2979 // Check that the parsed argument is something actually registered on the op.
2980 // Attributes
2981 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2982 if (ctx == TypeDirectiveContext)
2983 return emitError(
2984 loc, "attributes cannot be used as children to a `type` directive");
2985 if (ctx == RefDirectiveContext) {
2986 if (!seenAttrs.count(attr))
2987 return emitError(loc, "attribute '" + name +
2988 "' must be bound before it is referenced");
2989 } else if (!seenAttrs.insert(attr)) {
2990 return emitError(loc, "attribute '" + name + "' is already bound");
2993 return create<AttributeVariable>(attr);
2996 if (const NamedProperty *property = findArg(op.getProperties(), name)) {
2997 if (ctx != CustomDirectiveContext && ctx != RefDirectiveContext)
2998 return emitError(
2999 loc, "properties currently only supported in `custom` directive");
3001 if (ctx == RefDirectiveContext) {
3002 if (!seenProperties.count(property))
3003 return emitError(loc, "property '" + name +
3004 "' must be bound before it is referenced");
3005 } else {
3006 if (!seenProperties.insert(property).second)
3007 return emitError(loc, "property '" + name + "' is already bound");
3010 return create<PropertyVariable>(property);
3013 // Operands
3014 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
3015 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3016 if (fmt.allOperands || !seenOperands.insert(operand).second)
3017 return emitError(loc, "operand '" + name + "' is already bound");
3018 } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) {
3019 return emitError(loc, "operand '" + name +
3020 "' must be bound before it is referenced");
3022 return create<OperandVariable>(operand);
3024 // Regions
3025 if (const NamedRegion *region = findArg(op.getRegions(), name)) {
3026 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3027 if (hasAllRegions || !seenRegions.insert(region).second)
3028 return emitError(loc, "region '" + name + "' is already bound");
3029 } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
3030 return emitError(loc, "region '" + name +
3031 "' must be bound before it is referenced");
3032 } else {
3033 return emitError(loc, "regions can only be used at the top level");
3035 return create<RegionVariable>(region);
3037 // Results.
3038 if (const auto *result = findArg(op.getResults(), name)) {
3039 if (ctx != TypeDirectiveContext)
3040 return emitError(loc, "result variables can can only be used as a child "
3041 "to a 'type' directive");
3042 return create<ResultVariable>(result);
3044 // Successors.
3045 if (const auto *successor = findArg(op.getSuccessors(), name)) {
3046 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3047 if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
3048 return emitError(loc, "successor '" + name + "' is already bound");
3049 } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
3050 return emitError(loc, "successor '" + name +
3051 "' must be bound before it is referenced");
3052 } else {
3053 return emitError(loc, "successors can only be used at the top level");
3056 return create<SuccessorVariable>(successor);
3058 return emitError(loc, "expected variable to refer to an argument, region, "
3059 "result, or successor");
3062 FailureOr<FormatElement *>
3063 OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
3064 Context ctx) {
3065 switch (kind) {
3066 case FormatToken::kw_prop_dict:
3067 return parsePropDictDirective(loc, ctx);
3068 case FormatToken::kw_attr_dict:
3069 return parseAttrDictDirective(loc, ctx,
3070 /*withKeyword=*/false);
3071 case FormatToken::kw_attr_dict_w_keyword:
3072 return parseAttrDictDirective(loc, ctx,
3073 /*withKeyword=*/true);
3074 case FormatToken::kw_functional_type:
3075 return parseFunctionalTypeDirective(loc, ctx);
3076 case FormatToken::kw_operands:
3077 return parseOperandsDirective(loc, ctx);
3078 case FormatToken::kw_qualified:
3079 return parseQualifiedDirective(loc, ctx);
3080 case FormatToken::kw_regions:
3081 return parseRegionsDirective(loc, ctx);
3082 case FormatToken::kw_results:
3083 return parseResultsDirective(loc, ctx);
3084 case FormatToken::kw_successors:
3085 return parseSuccessorsDirective(loc, ctx);
3086 case FormatToken::kw_ref:
3087 return parseReferenceDirective(loc, ctx);
3088 case FormatToken::kw_type:
3089 return parseTypeDirective(loc, ctx);
3090 case FormatToken::kw_oilist:
3091 return parseOIListDirective(loc, ctx);
3093 default:
3094 return emitError(loc, "unsupported directive kind");
3098 FailureOr<FormatElement *>
3099 OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
3100 bool withKeyword) {
3101 if (context == TypeDirectiveContext)
3102 return emitError(loc, "'attr-dict' directive can only be used as a "
3103 "top-level directive");
3105 if (context == RefDirectiveContext) {
3106 if (!hasAttrDict)
3107 return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
3108 "'attr-dict' directive");
3110 // Otherwise, this is a top-level context.
3111 } else {
3112 if (hasAttrDict)
3113 return emitError(loc, "'attr-dict' directive has already been seen");
3114 hasAttrDict = true;
3117 return create<AttrDictDirective>(withKeyword);
3120 FailureOr<FormatElement *>
3121 OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
3122 if (context == TypeDirectiveContext)
3123 return emitError(loc, "'prop-dict' directive can only be used as a "
3124 "top-level directive");
3126 if (context == RefDirectiveContext)
3127 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3128 // Otherwise, this is a top-level context.
3130 if (hasPropDict)
3131 return emitError(loc, "'prop-dict' directive has already been seen");
3132 hasPropDict = true;
3134 return create<PropDictDirective>();
3137 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
3138 SMLoc loc, ArrayRef<FormatElement *> arguments) {
3139 for (FormatElement *argument : arguments) {
3140 if (!isa<AttrDictDirective, AttributeVariable, OperandVariable,
3141 PropertyVariable, RefDirective, RegionVariable, SuccessorVariable,
3142 StringElement, TypeDirective>(argument)) {
3143 // TODO: FormatElement should have location info attached.
3144 return emitError(loc, "only variables and types may be used as "
3145 "parameters to a custom directive");
3147 if (auto *type = dyn_cast<TypeDirective>(argument)) {
3148 if (!isa<OperandVariable, ResultVariable>(type->getArg())) {
3149 return emitError(loc, "type directives within a custom directive may "
3150 "only refer to variables");
3154 return success();
3157 FailureOr<FormatElement *>
3158 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
3159 if (context != TopLevelContext)
3160 return emitError(
3161 loc, "'functional-type' is only valid as a top-level directive");
3163 // Parse the main operand.
3164 FailureOr<FormatElement *> inputs, results;
3165 if (failed(parseToken(FormatToken::l_paren,
3166 "expected '(' before argument list")) ||
3167 failed(inputs = parseTypeDirectiveOperand(loc)) ||
3168 failed(parseToken(FormatToken::comma,
3169 "expected ',' after inputs argument")) ||
3170 failed(results = parseTypeDirectiveOperand(loc)) ||
3171 failed(
3172 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3173 return failure();
3174 return create<FunctionalTypeDirective>(*inputs, *results);
3177 FailureOr<FormatElement *>
3178 OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
3179 if (context == RefDirectiveContext) {
3180 if (!fmt.allOperands)
3181 return emitError(loc, "'ref' of 'operands' is not bound by a prior "
3182 "'operands' directive");
3184 } else if (context == TopLevelContext || context == CustomDirectiveContext) {
3185 if (fmt.allOperands || !seenOperands.empty())
3186 return emitError(loc, "'operands' directive creates overlap in format");
3187 fmt.allOperands = true;
3189 return create<OperandsDirective>();
3192 FailureOr<FormatElement *>
3193 OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
3194 if (context != CustomDirectiveContext)
3195 return emitError(loc, "'ref' is only valid within a `custom` directive");
3197 FailureOr<FormatElement *> arg;
3198 if (failed(parseToken(FormatToken::l_paren,
3199 "expected '(' before argument list")) ||
3200 failed(arg = parseElement(RefDirectiveContext)) ||
3201 failed(
3202 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3203 return failure();
3205 return create<RefDirective>(*arg);
3208 FailureOr<FormatElement *>
3209 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
3210 if (context == TypeDirectiveContext)
3211 return emitError(loc, "'regions' is only valid as a top-level directive");
3212 if (context == RefDirectiveContext) {
3213 if (!hasAllRegions)
3214 return emitError(loc, "'ref' of 'regions' is not bound by a prior "
3215 "'regions' directive");
3217 // Otherwise, this is a TopLevel directive.
3218 } else {
3219 if (hasAllRegions || !seenRegions.empty())
3220 return emitError(loc, "'regions' directive creates overlap in format");
3221 hasAllRegions = true;
3223 return create<RegionsDirective>();
3226 FailureOr<FormatElement *>
3227 OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
3228 if (context != TypeDirectiveContext)
3229 return emitError(loc, "'results' directive can can only be used as a child "
3230 "to a 'type' directive");
3231 return create<ResultsDirective>();
3234 FailureOr<FormatElement *>
3235 OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
3236 if (context == TypeDirectiveContext)
3237 return emitError(loc,
3238 "'successors' is only valid as a top-level directive");
3239 if (context == RefDirectiveContext) {
3240 if (!hasAllSuccessors)
3241 return emitError(loc, "'ref' of 'successors' is not bound by a prior "
3242 "'successors' directive");
3244 // Otherwise, this is a TopLevel directive.
3245 } else {
3246 if (hasAllSuccessors || !seenSuccessors.empty())
3247 return emitError(loc, "'successors' directive creates overlap in format");
3248 hasAllSuccessors = true;
3250 return create<SuccessorsDirective>();
3253 FailureOr<FormatElement *>
3254 OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
3255 if (failed(parseToken(FormatToken::l_paren,
3256 "expected '(' before oilist argument list")))
3257 return failure();
3258 std::vector<FormatElement *> literalElements;
3259 std::vector<std::vector<FormatElement *>> parsingElements;
3260 do {
3261 FailureOr<FormatElement *> lelement = parseLiteral(context);
3262 if (failed(lelement))
3263 return failure();
3264 literalElements.push_back(*lelement);
3265 parsingElements.emplace_back();
3266 std::vector<FormatElement *> &currParsingElements = parsingElements.back();
3267 while (peekToken().getKind() != FormatToken::pipe &&
3268 peekToken().getKind() != FormatToken::r_paren) {
3269 FailureOr<FormatElement *> pelement = parseElement(context);
3270 if (failed(pelement) ||
3271 failed(verifyOIListParsingElement(*pelement, loc)))
3272 return failure();
3273 currParsingElements.push_back(*pelement);
3275 if (peekToken().getKind() == FormatToken::pipe) {
3276 consumeToken();
3277 continue;
3279 if (peekToken().getKind() == FormatToken::r_paren) {
3280 consumeToken();
3281 break;
3283 } while (true);
3285 return create<OIListElement>(std::move(literalElements),
3286 std::move(parsingElements));
3289 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
3290 SMLoc loc) {
3291 SmallVector<VariableElement *> vars;
3292 collect(element, vars);
3293 for (VariableElement *elem : vars) {
3294 LogicalResult res =
3295 TypeSwitch<FormatElement *, LogicalResult>(elem)
3296 // Only optional attributes can be within an oilist parsing group.
3297 .Case([&](AttributeVariable *attrEle) {
3298 if (!attrEle->getVar()->attr.isOptional() &&
3299 !attrEle->getVar()->attr.hasDefaultValue())
3300 return emitError(loc, "only optional attributes can be used in "
3301 "an oilist parsing group");
3302 return success();
3304 // Only optional-like(i.e. variadic) operands can be within an
3305 // oilist parsing group.
3306 .Case([&](OperandVariable *ele) {
3307 if (!ele->getVar()->isVariableLength())
3308 return emitError(loc, "only variable length operands can be "
3309 "used within an oilist parsing group");
3310 return success();
3312 // Only optional-like(i.e. variadic) results can be within an oilist
3313 // parsing group.
3314 .Case([&](ResultVariable *ele) {
3315 if (!ele->getVar()->isVariableLength())
3316 return emitError(loc, "only variable length results can be "
3317 "used within an oilist parsing group");
3318 return success();
3320 .Case([&](RegionVariable *) { return success(); })
3321 .Default([&](FormatElement *) {
3322 return emitError(loc,
3323 "only literals, types, and variables can be "
3324 "used within an oilist group");
3326 if (failed(res))
3327 return failure();
3329 return success();
3332 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
3333 Context context) {
3334 if (context == TypeDirectiveContext)
3335 return emitError(loc, "'type' cannot be used as a child of another `type`");
3337 bool isRefChild = context == RefDirectiveContext;
3338 FailureOr<FormatElement *> operand;
3339 if (failed(parseToken(FormatToken::l_paren,
3340 "expected '(' before argument list")) ||
3341 failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
3342 failed(
3343 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3344 return failure();
3346 return create<TypeDirective>(*operand);
3349 FailureOr<FormatElement *>
3350 OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
3351 FailureOr<FormatElement *> element;
3352 if (failed(parseToken(FormatToken::l_paren,
3353 "expected '(' before argument list")) ||
3354 failed(element = parseElement(context)) ||
3355 failed(
3356 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3357 return failure();
3358 return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
3359 .Case<AttributeVariable, TypeDirective>([](auto *element) {
3360 element->setShouldBeQualified();
3361 return element;
3363 .Default([&](auto *element) {
3364 return this->emitError(
3365 loc,
3366 "'qualified' directive expects an attribute or a `type` directive");
3370 FailureOr<FormatElement *>
3371 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
3372 FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext);
3373 if (failed(result))
3374 return failure();
3376 FormatElement *element = *result;
3377 if (isa<LiteralElement>(element))
3378 return emitError(
3379 loc, "'type' directive operand expects variable or directive operand");
3381 if (auto *var = dyn_cast<OperandVariable>(element)) {
3382 unsigned opIdx = var->getVar() - op.operand_begin();
3383 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3384 return emitError(loc, "'type' of '" + var->getVar()->name +
3385 "' is already bound");
3386 if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3387 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3388 ")' is not bound by a prior 'type' directive");
3389 seenOperandTypes.set(opIdx);
3390 } else if (auto *var = dyn_cast<ResultVariable>(element)) {
3391 unsigned resIdx = var->getVar() - op.result_begin();
3392 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3393 return emitError(loc, "'type' of '" + var->getVar()->name +
3394 "' is already bound");
3395 if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3396 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3397 ")' is not bound by a prior 'type' directive");
3398 seenResultTypes.set(resIdx);
3399 } else if (isa<OperandsDirective>(&*element)) {
3400 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3401 return emitError(loc, "'operands' 'type' is already bound");
3402 if (isRefChild && !fmt.allOperandTypes)
3403 return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
3404 "'type' directive");
3405 fmt.allOperandTypes = true;
3406 } else if (isa<ResultsDirective>(&*element)) {
3407 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3408 return emitError(loc, "'results' 'type' is already bound");
3409 if (isRefChild && !fmt.allResultTypes)
3410 return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
3411 "'type' directive");
3412 fmt.allResultTypes = true;
3413 } else {
3414 return emitError(loc, "invalid argument to 'type' directive");
3416 return element;
3419 LogicalResult OpFormatParser::verifyOptionalGroupElements(
3420 SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) {
3421 for (FormatElement *element : elements) {
3422 if (failed(verifyOptionalGroupElement(loc, element, element == anchor)))
3423 return failure();
3425 return success();
3428 LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
3429 FormatElement *element,
3430 bool isAnchor) {
3431 return TypeSwitch<FormatElement *, LogicalResult>(element)
3432 // All attributes can be within the optional group, but only optional
3433 // attributes can be the anchor.
3434 .Case([&](AttributeVariable *attrEle) {
3435 Attribute attr = attrEle->getVar()->attr;
3436 if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
3437 return emitError(loc, "only optional or default-valued attributes "
3438 "can be used to anchor an optional group");
3439 return success();
3441 // Only optional-like(i.e. variadic) operands can be within an optional
3442 // group.
3443 .Case([&](OperandVariable *ele) {
3444 if (!ele->getVar()->isVariableLength())
3445 return emitError(loc, "only variable length operands can be used "
3446 "within an optional group");
3447 return success();
3449 // Only optional-like(i.e. variadic) results can be within an optional
3450 // group.
3451 .Case([&](ResultVariable *ele) {
3452 if (!ele->getVar()->isVariableLength())
3453 return emitError(loc, "only variable length results can be used "
3454 "within an optional group");
3455 return success();
3457 .Case([&](RegionVariable *) {
3458 // TODO: When ODS has proper support for marking "optional" regions, add
3459 // a check here.
3460 return success();
3462 .Case([&](TypeDirective *ele) {
3463 return verifyOptionalGroupElement(loc, ele->getArg(),
3464 /*isAnchor=*/false);
3466 .Case([&](FunctionalTypeDirective *ele) {
3467 if (failed(verifyOptionalGroupElement(loc, ele->getInputs(),
3468 /*isAnchor=*/false)))
3469 return failure();
3470 return verifyOptionalGroupElement(loc, ele->getResults(),
3471 /*isAnchor=*/false);
3473 .Case([&](CustomDirective *ele) {
3474 if (!isAnchor)
3475 return success();
3476 // Verify each child as being valid in an optional group. They are all
3477 // potential anchors if the custom directive was marked as one.
3478 for (FormatElement *child : ele->getArguments()) {
3479 if (isa<RefDirective>(child))
3480 continue;
3481 if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))
3482 return failure();
3484 return success();
3486 // Literals, whitespace, and custom directives may be used, but they can't
3487 // anchor the group.
3488 .Case<LiteralElement, WhitespaceElement, OptionalElement>(
3489 [&](FormatElement *) {
3490 if (isAnchor)
3491 return emitError(loc, "only variables and types can be used "
3492 "to anchor an optional group");
3493 return success();
3495 .Default([&](FormatElement *) {
3496 return emitError(loc, "only literals, types, and variables can be "
3497 "used within an optional group");
3501 //===----------------------------------------------------------------------===//
3502 // Interface
3503 //===----------------------------------------------------------------------===//
3505 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3506 // TODO: Operator doesn't expose all necessary functionality via
3507 // the const interface.
3508 Operator &op = const_cast<Operator &>(constOp);
3509 if (!op.hasAssemblyFormat())
3510 return;
3512 // Parse the format description.
3513 llvm::SourceMgr mgr;
3514 mgr.AddNewSourceBuffer(
3515 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc());
3516 OperationFormat format(op);
3517 OpFormatParser parser(mgr, format, op);
3518 FailureOr<std::vector<FormatElement *>> elements = parser.parse();
3519 if (failed(elements)) {
3520 // Exit the process if format errors are treated as fatal.
3521 if (formatErrorIsFatal) {
3522 // Invoke the interrupt handlers to run the file cleanup handlers.
3523 llvm::sys::RunInterruptHandlers();
3524 std::exit(1);
3526 return;
3528 format.elements = std::move(*elements);
3530 // Generate the printer and parser based on the parsed format.
3531 format.genParser(op, opClass);
3532 format.genPrinter(op, opClass);