[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-tblgen / OpFormatGen.cpp
bloba97d8760842a98aa929cd1439ae3ae1d6a4ce319
1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
11 #include "OpClass.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
29 using namespace mlir;
30 using namespace mlir::tblgen;
32 //===----------------------------------------------------------------------===//
33 // VariableElement
35 namespace {
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT, VariableElement::Kind VariableKind>
40 class OpVariableElement : public VariableElementBase<VariableKind> {
41 public:
42 using Base = OpVariableElement<VarT, VariableKind>;
44 /// Create an op variable element with the variable value.
45 OpVariableElement(const VarT *var) : var(var) {}
47 /// Get the variable.
48 const VarT *getVar() { return var; }
50 protected:
51 /// The op variable, e.g. a type or attribute constraint.
52 const VarT *var;
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57 : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
58 using Base::Base;
60 /// Return the constant builder call for the type of this attribute, or
61 /// std::nullopt if it doesn't have one.
62 std::optional<StringRef> getTypeBuilder() const {
63 std::optional<Type> attrType = var->attr.getValueType();
64 return attrType ? attrType->getBuilderCall() : std::nullopt;
67 /// Return if this attribute refers to a UnitAttr.
68 bool isUnitAttr() const {
69 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
74 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
75 void setShouldBeQualified(bool qualified = true) {
76 shouldBeQualifiedFlag = qualified;
79 private:
80 bool shouldBeQualifiedFlag = false;
83 /// This class represents a variable that refers to an operand argument.
84 using OperandVariable =
85 OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;
87 /// This class represents a variable that refers to a result.
88 using ResultVariable =
89 OpVariableElement<NamedTypeConstraint, VariableElement::Result>;
91 /// This class represents a variable that refers to a region.
92 using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;
94 /// This class represents a variable that refers to a successor.
95 using SuccessorVariable =
96 OpVariableElement<NamedSuccessor, VariableElement::Successor>;
98 /// This class represents a variable that refers to a property argument.
99 using PropertyVariable =
100 OpVariableElement<NamedProperty, VariableElement::Property>;
101 } // namespace
103 //===----------------------------------------------------------------------===//
104 // DirectiveElement
106 namespace {
107 /// This class represents the `operands` directive. This directive represents
108 /// all of the operands of an operation.
109 using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;
111 /// This class represents the `results` directive. This directive represents
112 /// all of the results of an operation.
113 using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;
115 /// This class represents the `regions` directive. This directive represents
116 /// all of the regions of an operation.
117 using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;
119 /// This class represents the `successors` directive. This directive represents
120 /// all of the successors of an operation.
121 using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
123 /// This class represents the `attr-dict` directive. This directive represents
124 /// the attribute dictionary of the operation.
125 class AttrDictDirective
126 : public DirectiveElementBase<DirectiveElement::AttrDict> {
127 public:
128 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
130 /// Return whether the dictionary should be printed with the 'attributes'
131 /// keyword.
132 bool isWithKeyword() const { return withKeyword; }
134 private:
135 /// If the dictionary should be printed with the 'attributes' keyword.
136 bool withKeyword;
139 /// This class represents the `prop-dict` directive. This directive represents
140 /// the properties of the operation, expressed as a directionary.
141 class PropDictDirective
142 : public DirectiveElementBase<DirectiveElement::PropDict> {
143 public:
144 explicit PropDictDirective() = default;
147 /// This class represents the `functional-type` directive. This directive takes
148 /// two arguments and formats them, respectively, as the inputs and results of a
149 /// FunctionType.
150 class FunctionalTypeDirective
151 : public DirectiveElementBase<DirectiveElement::FunctionalType> {
152 public:
153 FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
154 : inputs(inputs), results(results) {}
156 FormatElement *getInputs() const { return inputs; }
157 FormatElement *getResults() const { return results; }
159 private:
160 /// The input and result arguments.
161 FormatElement *inputs, *results;
164 /// This class represents the `type` directive.
165 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
166 public:
167 TypeDirective(FormatElement *arg) : arg(arg) {}
169 FormatElement *getArg() const { return arg; }
171 /// Indicate if this type is printed "qualified" (that is it is
172 /// prefixed with the `!dialect.mnemonic`).
173 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
174 void setShouldBeQualified(bool qualified = true) {
175 shouldBeQualifiedFlag = qualified;
178 private:
179 /// The argument that is used to format the directive.
180 FormatElement *arg;
182 bool shouldBeQualifiedFlag = false;
185 /// This class represents a group of order-independent optional clauses. Each
186 /// clause starts with a literal element and has a coressponding parsing
187 /// element. A parsing element is a continous sequence of format elements.
188 /// Each clause can appear 0 or 1 time.
189 class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
190 public:
191 OIListElement(std::vector<FormatElement *> &&literalElements,
192 std::vector<std::vector<FormatElement *>> &&parsingElements)
193 : literalElements(std::move(literalElements)),
194 parsingElements(std::move(parsingElements)) {}
196 /// Returns a range to iterate over the LiteralElements.
197 auto getLiteralElements() const {
198 function_ref<LiteralElement *(FormatElement * el)>
199 literalElementCastConverter =
200 [](FormatElement *el) { return cast<LiteralElement>(el); };
201 return llvm::map_range(literalElements, literalElementCastConverter);
204 /// Returns a range to iterate over the parsing elements corresponding to the
205 /// clauses.
206 ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
207 return parsingElements;
210 /// Returns a range to iterate over tuples of parsing and literal elements.
211 auto getClauses() const {
212 return llvm::zip(getLiteralElements(), getParsingElements());
215 /// If the parsing element is a single UnitAttr element, then it returns the
216 /// attribute variable. Otherwise, returns nullptr.
217 AttributeVariable *
218 getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
219 if (pelement.size() == 1) {
220 auto *attrElem = dyn_cast<AttributeVariable>(pelement[0]);
221 if (attrElem && attrElem->isUnitAttr())
222 return attrElem;
224 return nullptr;
227 private:
228 /// A vector of `LiteralElement` objects. Each element stores the keyword
229 /// for one case of oilist element. For example, an oilist element along with
230 /// the `literalElements` vector:
231 /// ```
232 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
233 /// literalElements = { `keyword`, `otherKeyword` }
234 /// ```
235 std::vector<FormatElement *> literalElements;
237 /// A vector of valid declarative assembly format vectors. Each object in
238 /// parsing elements is a vector of elements in assembly format syntax.
239 /// For example, an oilist element along with the parsingElements vector:
240 /// ```
241 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
242 /// parsingElements = {
243 /// { `=`, `(`, $arg0, `)` },
244 /// { `<`, $arg1, `>` }
245 /// }
246 /// ```
247 std::vector<std::vector<FormatElement *>> parsingElements;
249 } // namespace
251 //===----------------------------------------------------------------------===//
252 // OperationFormat
253 //===----------------------------------------------------------------------===//
255 namespace {
257 using ConstArgument =
258 llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
260 struct OperationFormat {
261 /// This class represents a specific resolver for an operand or result type.
262 class TypeResolution {
263 public:
264 TypeResolution() = default;
266 /// Get the index into the buildable types for this type, or std::nullopt.
267 std::optional<int> getBuilderIdx() const { return builderIdx; }
268 void setBuilderIdx(int idx) { builderIdx = idx; }
270 /// Get the variable this type is resolved to, or nullptr.
271 const NamedTypeConstraint *getVariable() const {
272 return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
274 /// Get the attribute this type is resolved to, or nullptr.
275 const NamedAttribute *getAttribute() const {
276 return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
278 /// Get the transformer for the type of the variable, or std::nullopt.
279 std::optional<StringRef> getVarTransformer() const {
280 return variableTransformer;
282 void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
283 resolver = arg;
284 variableTransformer = transformer;
285 assert(getVariable() || getAttribute());
288 private:
289 /// If the type is resolved with a buildable type, this is the index into
290 /// 'buildableTypes' in the parent format.
291 std::optional<int> builderIdx;
292 /// If the type is resolved based upon another operand or result, this is
293 /// the variable or the attribute that this type is resolved to.
294 ConstArgument resolver;
295 /// If the type is resolved based upon another operand or result, this is
296 /// a transformer to apply to the variable when resolving.
297 std::optional<StringRef> variableTransformer;
300 /// The context in which an element is generated.
301 enum class GenContext {
302 /// The element is generated at the top-level or with the same behaviour.
303 Normal,
304 /// The element is generated inside an optional group.
305 Optional
308 OperationFormat(const Operator &op)
309 : useProperties(op.getDialect().usePropertiesForAttributes() &&
310 !op.getAttributes().empty()),
311 opCppClassName(op.getCppClassName()) {
312 operandTypes.resize(op.getNumOperands(), TypeResolution());
313 resultTypes.resize(op.getNumResults(), TypeResolution());
315 hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
316 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
319 hasSingleBlockTrait = op.getTrait("::mlir::OpTrait::SingleBlock");
322 /// Generate the operation parser from this format.
323 void genParser(Operator &op, OpClass &opClass);
324 /// Generate the parser code for a specific format element.
325 void genElementParser(FormatElement *element, MethodBody &body,
326 FmtContext &attrTypeCtx,
327 GenContext genCtx = GenContext::Normal);
328 /// Generate the C++ to resolve the types of operands and results during
329 /// parsing.
330 void genParserTypeResolution(Operator &op, MethodBody &body);
331 /// Generate the C++ to resolve the types of the operands during parsing.
332 void genParserOperandTypeResolution(
333 Operator &op, MethodBody &body,
334 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
335 /// Generate the C++ to resolve regions during parsing.
336 void genParserRegionResolution(Operator &op, MethodBody &body);
337 /// Generate the C++ to resolve successors during parsing.
338 void genParserSuccessorResolution(Operator &op, MethodBody &body);
339 /// Generate the C++ to handling variadic segment size traits.
340 void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
342 /// Generate the operation printer from this format.
343 void genPrinter(Operator &op, OpClass &opClass);
345 /// Generate the printer code for a specific format element.
346 void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
347 bool &shouldEmitSpace, bool &lastWasPunctuation);
349 /// The various elements in this format.
350 std::vector<FormatElement *> elements;
352 /// A flag indicating if all operand/result types were seen. If the format
353 /// contains these, it can not contain individual type resolvers.
354 bool allOperands = false, allOperandTypes = false, allResultTypes = false;
356 /// A flag indicating if this operation infers its result types
357 bool infersResultTypes = false;
359 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
360 /// trait.
361 bool hasImplicitTermTrait;
363 /// A flag indicating if this operation has the SingleBlock trait.
364 bool hasSingleBlockTrait;
366 /// Indicate whether attribute are stored in properties.
367 bool useProperties;
369 /// Indicate whether prop-dict is used in the format
370 bool hasPropDict;
372 /// The Operation class name
373 StringRef opCppClassName;
375 /// A map of buildable types to indices.
376 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
378 /// The index of the buildable type, if valid, for every operand and result.
379 std::vector<TypeResolution> operandTypes, resultTypes;
381 /// The set of attributes explicitly used within the format.
382 llvm::SmallSetVector<const NamedAttribute *, 8> usedAttributes;
383 llvm::StringSet<> inferredAttributes;
385 /// The set of properties explicitly used within the format.
386 llvm::SmallSetVector<const NamedProperty *, 8> usedProperties;
388 } // namespace
390 //===----------------------------------------------------------------------===//
391 // Parser Gen
393 /// Returns true if we can format the given attribute as an EnumAttr in the
394 /// parser format.
395 static bool canFormatEnumAttr(const NamedAttribute *attr) {
396 Attribute baseAttr = attr->attr.getBaseAttr();
397 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
398 if (!enumAttr)
399 return false;
401 // The attribute must have a valid underlying type and a constant builder.
402 return !enumAttr->getUnderlyingType().empty() &&
403 !enumAttr->getConstBuilderTemplate().empty();
406 /// Returns if we should format the given attribute as an SymbolNameAttr.
407 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
408 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
/// Parser snippet for an attribute, parsed through
/// `parseCustomAttributeWithFallback`.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
    return ::mlir::failure();
  }
)";

/// Parser snippet for an attribute, parsed through the generic
/// `parseAttribute` entry point.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}))
    return ::mlir::failure();
)";

/// Parser snippet for an optional attribute. The snippet ends with an
/// `if (...)` that has no body; its condition holds when the attribute was
/// present and parsed successfully. Presumably the generator appends the
/// guarded statement after it — confirm at the use site.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  ::mlir::OptionalParseResult parseResult{0}Attr =
    parser.parseOptionalAttribute({0}Attr, {1});
  if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
    return ::mlir::failure();
  if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";

/// Parser snippet for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";

/// Parser snippet for an optional symbol name attribute. The generated code
/// ignores the result of the optional parse.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr);
)";
/// The code snippet used to generate a parser call for an enum attribute.
///
/// The generated code first tries to parse a bare keyword from the allowed
/// set. Failing that, it tries an optional string attribute. If neither is
/// present it runs the {5} fallback. A non-empty spelling is then symbolized
/// into the enum, and an error is emitted if symbolization fails.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.has_value()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';

      {0}Attr = {3};
      {6}
    }
  }
)";
/// Parser snippet for a variadic operand list. Records the start location in
/// `{0}OperandsLoc` before parsing.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";

/// Parser snippet for an optional operand. When present, the operand is
/// appended to `{0}Operands`.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";

/// Parser snippet for a single (non-variable-length) operand, parsed into
/// `{0}RawOperand`.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperand))
    return ::mlir::failure();
)";

/// Parser snippet for a VariadicOfVariadic operand: a comma-separated sequence
/// of parenthesized operand lists. All operands are accumulated into
/// `{0}Operands`, and the size of each group is pushed onto
/// `{0}OperandGroupSizes`.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
/// Parser snippet for a VariadicOfVariadic type list: a comma-separated
/// sequence of parenthesized (possibly empty) type lists, all appended to
/// `{0}Types`.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";

/// Parser snippet for a variadic type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";

/// Parser snippet for an optional type. When present, the type is appended to
/// `{0}Types`.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";

/// Parser snippet for a single type, parsed via `parseCustomTypeWithFallback`.
///
/// {0}: The C++ type of the local `type` variable that is parsed into.
/// {1}: The name whose `{1}RawType` receives the parsed type.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawType = type;
  }
)";

/// Parser snippet for a single type printed in qualified form. Only {1} is
/// referenced.
///
/// {1}: The name whose `{1}RawType` receives the parsed type.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawType))
    return ::mlir::failure();
)";

/// Parser snippet for a functional type: parses one FunctionType and splits it
/// into input and result type lists.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// Parser snippet that calls the op's `inferReturnTypes` on the parsed
/// operands, attributes, properties and regions, and adds the inferred types
/// to the result.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
// NOTE: These snippets are declared `const char *const`, matching the other
// parser snippets in this file. A plain `const char *` would be a mutable
// global with external linkage.

/// The code snippet used to generate a parser call for a region list. The
/// first region is optional; if it is present, further comma-separated regions
/// are required to parse.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.has_value() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for a successor list. The
/// first successor is optional; if it is present, further comma-separated
/// successors are required to parse.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.has_value()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser for OIList. It rejects a clause
/// that was already seen and otherwise marks it as seen.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
  if ({0}Clause) {
    return parser.emitError(parser.getNameLoc())
          << "`{0}` clause can appear at most once in the expansion of the "
             "oilist directive";
  }
  {0}Clause = true;
)";
namespace {
/// Describes how many values a single parse argument may stand for.
enum class ArgumentLengthKind {
  /// A variadic of variadics: zero or more groups, each holding a range of
  /// zero or more elements.
  VariadicOfVariadic,
  /// A plain variadic: zero or more elements.
  Variadic,
  /// An optional argument: zero or one element.
  Optional,
  /// A single argument: always exactly one element.
  Single
};
} // namespace
733 /// Get the length kind for the given constraint.
734 static ArgumentLengthKind
735 getArgumentLengthKind(const NamedTypeConstraint *var) {
736 if (var->isOptional())
737 return ArgumentLengthKind::Optional;
738 if (var->isVariadicOfVariadic())
739 return ArgumentLengthKind::VariadicOfVariadic;
740 if (var->isVariadic())
741 return ArgumentLengthKind::Variadic;
742 return ArgumentLengthKind::Single;
745 /// Get the name used for the type list for the given type directive operand.
746 /// 'lengthKind' to the corresponding kind for the given argument.
747 static StringRef getTypeListName(FormatElement *arg,
748 ArgumentLengthKind &lengthKind) {
749 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
750 lengthKind = getArgumentLengthKind(operand->getVar());
751 return operand->getVar()->name;
753 if (auto *result = dyn_cast<ResultVariable>(arg)) {
754 lengthKind = getArgumentLengthKind(result->getVar());
755 return result->getVar()->name;
757 lengthKind = ArgumentLengthKind::Variadic;
758 if (isa<OperandsDirective>(arg))
759 return "allOperand";
760 if (isa<ResultsDirective>(arg))
761 return "allResult";
762 llvm_unreachable("unknown 'type' directive argument");
765 /// Generate the parser for a literal value.
766 static void genLiteralParser(StringRef value, MethodBody &body) {
767 // Handle the case of a keyword/identifier.
768 if (value.front() == '_' || isalpha(value.front())) {
769 body << "Keyword(\"" << value << "\")";
770 return;
772 body << (StringRef)StringSwitch<StringRef>(value)
773 .Case("->", "Arrow()")
774 .Case(":", "Colon()")
775 .Case(",", "Comma()")
776 .Case("=", "Equal()")
777 .Case("<", "Less()")
778 .Case(">", "Greater()")
779 .Case("{", "LBrace()")
780 .Case("}", "RBrace()")
781 .Case("(", "LParen()")
782 .Case(")", "RParen()")
783 .Case("[", "LSquare()")
784 .Case("]", "RSquare()")
785 .Case("?", "Question()")
786 .Case("+", "Plus()")
787 .Case("*", "Star()")
788 .Case("...", "Ellipsis()");
791 /// Generate the storage code required for parsing the given element.
792 static void genElementParserStorage(FormatElement *element, const Operator &op,
793 MethodBody &body) {
794 if (auto *optional = dyn_cast<OptionalElement>(element)) {
795 ArrayRef<FormatElement *> elements = optional->getThenElements();
797 // If the anchor is a unit attribute, it won't be parsed directly so elide
798 // it.
799 auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
800 FormatElement *elidedAnchorElement = nullptr;
801 if (anchor && anchor != elements.front() && anchor->isUnitAttr())
802 elidedAnchorElement = anchor;
803 for (FormatElement *childElement : elements)
804 if (childElement != elidedAnchorElement)
805 genElementParserStorage(childElement, op, body);
806 for (FormatElement *childElement : optional->getElseElements())
807 genElementParserStorage(childElement, op, body);
809 } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
810 for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
811 if (!oilist->getUnitAttrParsingElement(pelement))
812 for (FormatElement *element : pelement)
813 genElementParserStorage(element, op, body);
816 } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
817 for (FormatElement *paramElement : custom->getArguments())
818 genElementParserStorage(paramElement, op, body);
820 } else if (isa<OperandsDirective>(element)) {
821 body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
822 "allOperands;\n";
824 } else if (isa<RegionsDirective>(element)) {
825 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
826 "fullRegions;\n";
828 } else if (isa<SuccessorsDirective>(element)) {
829 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
831 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
832 const NamedAttribute *var = attr->getVar();
833 body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
834 var->name);
836 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
837 StringRef name = operand->getVar()->name;
838 if (operand->getVar()->isVariableLength()) {
839 body
840 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
841 << name << "Operands;\n";
842 if (operand->getVar()->isVariadicOfVariadic()) {
843 body << " llvm::SmallVector<int32_t> " << name
844 << "OperandGroupSizes;\n";
846 } else {
847 body << " ::mlir::OpAsmParser::UnresolvedOperand " << name
848 << "RawOperand{};\n"
849 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
850 << name << "Operands(&" << name << "RawOperand, 1);";
852 body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
853 " (void){0}OperandsLoc;\n",
854 name);
856 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
857 StringRef name = region->getVar()->name;
858 if (region->getVar()->isVariadic()) {
859 body << llvm::formatv(
860 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
861 "{0}Regions;\n",
862 name);
863 } else {
864 body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
865 "std::make_unique<::mlir::Region>();\n",
866 name);
869 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
870 StringRef name = successor->getVar()->name;
871 if (successor->getVar()->isVariadic()) {
872 body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
873 "{0}Successors;\n",
874 name);
875 } else {
876 body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
879 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
880 ArgumentLengthKind lengthKind;
881 StringRef name = getTypeListName(dir->getArg(), lengthKind);
882 if (lengthKind != ArgumentLengthKind::Single)
883 body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
884 else
885 body
886 << llvm::formatv(" ::mlir::Type {0}RawType{{};\n", name)
887 << llvm::formatv(
888 " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
889 name);
890 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
891 ArgumentLengthKind ignored;
892 body << " ::llvm::ArrayRef<::mlir::Type> "
893 << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
894 body << " ::llvm::ArrayRef<::mlir::Type> "
895 << getTypeListName(dir->getResults(), ignored) << "Types;\n";
899 /// Generate the parser for a parameter to a custom directive.
900 static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
901 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
902 body << attr->getVar()->name << "Attr";
903 } else if (isa<AttrDictDirective>(param)) {
904 body << "result.attributes";
905 } else if (isa<PropDictDirective>(param)) {
906 body << "result";
907 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
908 StringRef name = operand->getVar()->name;
909 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
910 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
911 body << llvm::formatv("{0}OperandGroups", name);
912 else if (lengthKind == ArgumentLengthKind::Variadic)
913 body << llvm::formatv("{0}Operands", name);
914 else if (lengthKind == ArgumentLengthKind::Optional)
915 body << llvm::formatv("{0}Operand", name);
916 else
917 body << formatv("{0}RawOperand", name);
919 } else if (auto *region = dyn_cast<RegionVariable>(param)) {
920 StringRef name = region->getVar()->name;
921 if (region->getVar()->isVariadic())
922 body << llvm::formatv("{0}Regions", name);
923 else
924 body << llvm::formatv("*{0}Region", name);
926 } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
927 StringRef name = successor->getVar()->name;
928 if (successor->getVar()->isVariadic())
929 body << llvm::formatv("{0}Successors", name);
930 else
931 body << llvm::formatv("{0}Successor", name);
933 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
934 genCustomParameterParser(dir->getArg(), body);
936 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
937 ArgumentLengthKind lengthKind;
938 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
939 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
940 body << llvm::formatv("{0}TypeGroups", listName);
941 else if (lengthKind == ArgumentLengthKind::Variadic)
942 body << llvm::formatv("{0}Types", listName);
943 else if (lengthKind == ArgumentLengthKind::Optional)
944 body << llvm::formatv("{0}Type", listName);
945 else
946 body << formatv("{0}RawType", listName);
948 } else if (auto *string = dyn_cast<StringElement>(param)) {
949 FmtContext ctx;
950 ctx.withBuilder("parser.getBuilder()");
951 ctx.addSubst("_ctxt", "parser.getContext()");
952 body << tgfmt(string->getValue(), &ctx);
954 } else if (auto *property = dyn_cast<PropertyVariable>(param)) {
955 body << llvm::formatv("result.getOrAddProperties<Properties>().{0}",
956 property->getVar()->name);
957 } else {
958 llvm_unreachable("unknown custom directive parameter");
962 /// Generate the parser for a custom directive.
963 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
964 bool useProperties,
965 StringRef opCppClassName,
966 bool isOptional = false) {
967 body << " {\n";
969 // Preprocess the directive variables.
970 // * Add a local variable for optional operands and types. This provides a
971 // better API to the user defined parser methods.
972 // * Set the location of operand variables.
973 for (FormatElement *param : dir->getArguments()) {
974 if (auto *operand = dyn_cast<OperandVariable>(param)) {
975 auto *var = operand->getVar();
976 body << " " << var->name
977 << "OperandsLoc = parser.getCurrentLocation();\n";
978 if (var->isOptional()) {
979 body << llvm::formatv(
980 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
981 "{0}Operand;\n",
982 var->name);
983 } else if (var->isVariadicOfVariadic()) {
984 body << llvm::formatv(" "
985 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
986 "OpAsmParser::UnresolvedOperand>> "
987 "{0}OperandGroups;\n",
988 var->name);
990 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
991 ArgumentLengthKind lengthKind;
992 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
993 if (lengthKind == ArgumentLengthKind::Optional) {
994 body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
995 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
996 body << llvm::formatv(
997 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
998 "{0}TypeGroups;\n",
999 listName);
1001 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
1002 FormatElement *input = dir->getArg();
1003 if (auto *operand = dyn_cast<OperandVariable>(input)) {
1004 if (!operand->getVar()->isOptional())
1005 continue;
1006 body << llvm::formatv(
1007 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
1008 "{1}Operands[0];\n",
1009 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
1010 operand->getVar()->name);
1012 } else if (auto *type = dyn_cast<TypeDirective>(input)) {
1013 ArgumentLengthKind lengthKind;
1014 StringRef listName = getTypeListName(type->getArg(), lengthKind);
1015 if (lengthKind == ArgumentLengthKind::Optional) {
1016 body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
1017 "::mlir::Type() : {0}Types[0];\n",
1018 listName);
1024 body << " auto odsResult = parse" << dir->getName() << "(parser";
1025 for (FormatElement *param : dir->getArguments()) {
1026 body << ", ";
1027 genCustomParameterParser(param, body);
1029 body << ");\n";
1031 if (isOptional) {
1032 body << " if (!odsResult.has_value()) return {};\n"
1033 << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
1034 } else {
1035 body << " if (odsResult) return ::mlir::failure();\n";
1038 // After parsing, add handling for any of the optional constructs.
1039 for (FormatElement *param : dir->getArguments()) {
1040 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
1041 const NamedAttribute *var = attr->getVar();
1042 if (var->attr.isOptional() || var->attr.hasDefaultValue())
1043 body << llvm::formatv(" if ({0}Attr)\n ", var->name);
1044 if (useProperties) {
1045 body << formatv(
1046 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1047 var->name, opCppClassName);
1048 } else {
1049 body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1050 var->name);
1053 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
1054 const NamedTypeConstraint *var = operand->getVar();
1055 if (var->isOptional()) {
1056 body << llvm::formatv(" if ({0}Operand.has_value())\n"
1057 " {0}Operands.push_back(*{0}Operand);\n",
1058 var->name);
1059 } else if (var->isVariadicOfVariadic()) {
1060 body << llvm::formatv(
1061 " for (const auto &subRange : {0}OperandGroups) {{\n"
1062 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1063 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1064 " }\n",
1065 var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
1067 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1068 ArgumentLengthKind lengthKind;
1069 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1070 if (lengthKind == ArgumentLengthKind::Optional) {
1071 body << llvm::formatv(" if ({0}Type)\n"
1072 " {0}Types.push_back({0}Type);\n",
1073 listName);
1074 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1075 body << llvm::formatv(
1076 " for (const auto &subRange : {0}TypeGroups)\n"
1077 " {0}Types.append(subRange.begin(), subRange.end());\n",
1078 listName);
1083 body << " }\n";
1086 /// Generate the parser for a enum attribute.
1087 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1088 FmtContext &attrTypeCtx, bool parseAsOptional,
1089 bool useProperties, StringRef opCppClassName) {
1090 Attribute baseAttr = var->attr.getBaseAttr();
1091 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1092 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1094 // Generate the code for building an attribute for this enum.
1095 std::string attrBuilderStr;
1097 llvm::raw_string_ostream os(attrBuilderStr);
1098 os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1099 "*attrOptional");
1102 // Build a string containing the cases that can be formatted as a keyword.
1103 std::string validCaseKeywordsStr = "{";
1104 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1105 for (const EnumAttrCase &attrCase : cases)
1106 if (canFormatStringAsKeyword(attrCase.getStr()))
1107 validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1108 validCaseKeywordsOS.str().back() = '}';
1110 // If the attribute is not optional, build an error message for the missing
1111 // attribute.
1112 std::string errorMessage;
1113 if (!parseAsOptional) {
1114 llvm::raw_string_ostream errorMessageOS(errorMessage);
1115 errorMessageOS
1116 << "return parser.emitError(loc, \"expected string or "
1117 "keyword containing one of the following enum values for attribute '"
1118 << var->name << "' [";
1119 llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1120 errorMessageOS << attrCase.getStr();
1122 errorMessageOS << "]\");";
1124 std::string attrAssignment;
1125 if (useProperties) {
1126 attrAssignment =
1127 formatv(" "
1128 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1129 var->name, opCppClassName);
1130 } else {
1131 attrAssignment =
1132 formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
1135 body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1136 enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1137 validCaseKeywordsStr, errorMessage, attrAssignment);
1140 // Generate the parser for an attribute.
1141 static void genAttrParser(AttributeVariable *attr, MethodBody &body,
1142 FmtContext &attrTypeCtx, bool parseAsOptional,
1143 bool useProperties, StringRef opCppClassName) {
1144 const NamedAttribute *var = attr->getVar();
1146 // Check to see if we can parse this as an enum attribute.
1147 if (canFormatEnumAttr(var))
1148 return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
1149 useProperties, opCppClassName);
1151 // Check to see if we should parse this as a symbol name attribute.
1152 if (shouldFormatSymbolNameAttr(var)) {
1153 body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
1154 : symbolNameAttrParserCode,
1155 var->name);
1156 } else {
1158 // If this attribute has a buildable type, use that when parsing the
1159 // attribute.
1160 std::string attrTypeStr;
1161 if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1162 llvm::raw_string_ostream os(attrTypeStr);
1163 os << tgfmt(*typeBuilder, &attrTypeCtx);
1164 } else {
1165 attrTypeStr = "::mlir::Type{}";
1167 if (parseAsOptional) {
1168 body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
1169 } else {
1170 if (attr->shouldBeQualified() ||
1171 var->attr.getStorageType() == "::mlir::Attribute")
1172 body << formatv(genericAttrParserCode, var->name, attrTypeStr);
1173 else
1174 body << formatv(attrParserCode, var->name, attrTypeStr);
1177 if (useProperties) {
1178 body << formatv(
1179 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1180 "{0}Attr;\n",
1181 var->name, opCppClassName);
1182 } else {
1183 body << formatv(
1184 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1185 var->name);
1189 // Generates the 'setPropertiesFromParsedAttr' used to set properties from a
1190 // 'prop-dict' dictionary attr.
1191 static void genParsedAttrPropertiesSetter(OperationFormat &fmt, Operator &op,
1192 OpClass &opClass) {
1193 // Not required unless 'prop-dict' is present.
1194 if (!fmt.hasPropDict)
1195 return;
1197 SmallVector<MethodParameter> paramList;
1198 paramList.emplace_back("Properties &", "prop");
1199 paramList.emplace_back("::mlir::Attribute", "attr");
1200 paramList.emplace_back("::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1201 "emitError");
1203 Method *method = opClass.addStaticMethod("::llvm::LogicalResult",
1204 "setPropertiesFromParsedAttr",
1205 std::move(paramList));
1206 MethodBody &body = method->body().indent();
1208 body << R"decl(
1209 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1210 if (!dict) {
1211 emitError() << "expected DictionaryAttr to set properties";
1212 return ::mlir::failure();
1214 )decl";
1216 // TODO: properties might be optional as well.
1217 const char *propFromAttrFmt = R"decl(
1218 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1219 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
1220 {0};
1222 auto attr = dict.get("{1}");
1223 if (!attr) {{
1224 emitError() << "expected key entry for {1} in DictionaryAttr to set "
1225 "Properties.";
1226 return ::mlir::failure();
1228 if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
1229 return ::mlir::failure();
1230 )decl";
1232 // Generate the setter for any property not parsed elsewhere.
1233 for (const NamedProperty &namedProperty : op.getProperties()) {
1234 if (fmt.usedProperties.contains(&namedProperty))
1235 continue;
1237 auto scope = body.scope("{\n", "}\n", /*indent=*/true);
1239 StringRef name = namedProperty.name;
1240 const Property &prop = namedProperty.prop;
1241 FmtContext fctx;
1242 body << formatv(propFromAttrFmt,
1243 tgfmt(prop.getConvertFromAttributeCall(),
1244 &fctx.addSubst("_attr", "propAttr")
1245 .addSubst("_storage", "propStorage")
1246 .addSubst("_diag", "emitError")),
1247 name);
1250 // Generate the setter for any attribute not parsed elsewhere.
1251 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1252 if (fmt.usedAttributes.contains(&namedAttr))
1253 continue;
1255 const Attribute &attr = namedAttr.attr;
1256 // Derived attributes do not need to be parsed.
1257 if (attr.isDerivedAttr())
1258 continue;
1260 auto scope = body.scope("{\n", "}\n", /*indent=*/true);
1262 // If the attribute has a default value or is optional, it does not need to
1263 // be present in the parsed dictionary attribute.
1264 bool isRequired = !attr.isOptional() && !attr.hasDefaultValue();
1265 body << formatv(R"decl(
1266 auto &propStorage = prop.{0};
1267 auto attr = dict.get("{0}");
1268 if (attr || /*isRequired=*/{1}) {{
1269 if (!attr) {{
1270 emitError() << "expected key entry for {0} in DictionaryAttr to set "
1271 "Properties.";
1272 return ::mlir::failure();
1274 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1275 if (convertedAttr) {{
1276 propStorage = convertedAttr;
1277 } else {{
1278 emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1279 return ::mlir::failure();
1282 )decl",
1283 namedAttr.name, isRequired);
1285 body << "return ::mlir::success();\n";
1288 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1289 SmallVector<MethodParameter> paramList;
1290 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1291 paramList.emplace_back("::mlir::OperationState &", "result");
1293 auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
1294 std::move(paramList));
1295 auto &body = method->body();
1297 // Generate variables to store the operands and type within the format. This
1298 // allows for referencing these variables in the presence of optional
1299 // groupings.
1300 for (FormatElement *element : elements)
1301 genElementParserStorage(element, op, body);
1303 // A format context used when parsing attributes with buildable types.
1304 FmtContext attrTypeCtx;
1305 attrTypeCtx.withBuilder("parser.getBuilder()");
1307 // Generate parsers for each of the elements.
1308 for (FormatElement *element : elements)
1309 genElementParser(element, body, attrTypeCtx);
1311 // Generate the code to resolve the operand/result types and successors now
1312 // that they have been parsed.
1313 genParserRegionResolution(op, body);
1314 genParserSuccessorResolution(op, body);
1315 genParserVariadicSegmentResolution(op, body);
1316 genParserTypeResolution(op, body);
1318 body << " return ::mlir::success();\n";
1320 genParsedAttrPropertiesSetter(*this, op, opClass);
1323 void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
1324 FmtContext &attrTypeCtx,
1325 GenContext genCtx) {
1326 /// Optional Group.
1327 if (auto *optional = dyn_cast<OptionalElement>(element)) {
1328 auto genElementParsers = [&](FormatElement *firstElement,
1329 ArrayRef<FormatElement *> elements,
1330 bool thenGroup) {
1331 // If the anchor is a unit attribute, we don't need to print it. When
1332 // parsing, we will add this attribute if this group is present.
1333 FormatElement *elidedAnchorElement = nullptr;
1334 auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
1335 if (anchorAttr && anchorAttr != firstElement &&
1336 anchorAttr->isUnitAttr()) {
1337 elidedAnchorElement = anchorAttr;
1339 if (!thenGroup == optional->isInverted()) {
1340 // Add the anchor unit attribute to the operation state.
1341 if (useProperties) {
1342 body << formatv(
1343 " result.getOrAddProperties<{1}::Properties>().{0} = "
1344 "parser.getBuilder().getUnitAttr();",
1345 anchorAttr->getVar()->name, opCppClassName);
1346 } else {
1347 body << " result.addAttribute(\"" << anchorAttr->getVar()->name
1348 << "\", parser.getBuilder().getUnitAttr());\n";
1353 // Generate the rest of the elements inside an optional group. Elements in
1354 // an optional group after the guard are parsed as required.
1355 for (FormatElement *childElement : elements)
1356 if (childElement != elidedAnchorElement)
1357 genElementParser(childElement, body, attrTypeCtx,
1358 GenContext::Optional);
1361 ArrayRef<FormatElement *> thenElements =
1362 optional->getThenElements(/*parseable=*/true);
1364 // Generate a special optional parser for the first element to gate the
1365 // parsing of the rest of the elements.
1366 FormatElement *firstElement = thenElements.front();
1367 if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1368 genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
1369 useProperties, opCppClassName);
1370 body << " if (" << attrVar->getVar()->name << "Attr) {\n";
1371 } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1372 body << " if (::mlir::succeeded(parser.parseOptional";
1373 genLiteralParser(literal->getSpelling(), body);
1374 body << ")) {\n";
1375 } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1376 genElementParser(opVar, body, attrTypeCtx);
1377 body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1378 } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1379 const NamedRegion *region = regionVar->getVar();
1380 if (region->isVariadic()) {
1381 genElementParser(regionVar, body, attrTypeCtx);
1382 body << " if (!" << region->name << "Regions.empty()) {\n";
1383 } else {
1384 body << llvm::formatv(optionalRegionParserCode, region->name);
1385 body << " if (!" << region->name << "Region->empty()) {\n ";
1386 if (hasImplicitTermTrait)
1387 body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
1388 else if (hasSingleBlockTrait)
1389 body << llvm::formatv(regionEnsureSingleBlockParserCode,
1390 region->name);
1392 } else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
1393 body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
1394 genCustomDirectiveParser(custom, body, useProperties, opCppClassName,
1395 /*isOptional=*/true);
1396 body << " return ::mlir::success();\n"
1397 << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n"
1398 << " return ::mlir::failure();\n"
1399 << " } else if (optResult.has_value()) {\n";
1402 genElementParsers(firstElement, thenElements.drop_front(),
1403 /*thenGroup=*/true);
1404 body << " }";
1406 // Generate the else elements.
1407 auto elseElements = optional->getElseElements();
1408 if (!elseElements.empty()) {
1409 body << " else {\n";
1410 ArrayRef<FormatElement *> elseElements =
1411 optional->getElseElements(/*parseable=*/true);
1412 genElementParsers(elseElements.front(), elseElements,
1413 /*thenGroup=*/false);
1414 body << " }";
1416 body << "\n";
1418 /// OIList Directive
1419 } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) {
1420 for (LiteralElement *le : oilist->getLiteralElements())
1421 body << " bool " << le->getSpelling() << "Clause = false;\n";
1423 // Generate the parsing loop
1424 body << " while(true) {\n";
1425 for (auto clause : oilist->getClauses()) {
1426 LiteralElement *lelement = std::get<0>(clause);
1427 ArrayRef<FormatElement *> pelement = std::get<1>(clause);
1428 body << "if (succeeded(parser.parseOptional";
1429 genLiteralParser(lelement->getSpelling(), body);
1430 body << ")) {\n";
1431 StringRef lelementName = lelement->getSpelling();
1432 body << formatv(oilistParserCode, lelementName);
1433 if (AttributeVariable *unitAttrElem =
1434 oilist->getUnitAttrParsingElement(pelement)) {
1435 if (useProperties) {
1436 body << formatv(
1437 " result.getOrAddProperties<{1}::Properties>().{0} = "
1438 "parser.getBuilder().getUnitAttr();",
1439 unitAttrElem->getVar()->name, opCppClassName);
1440 } else {
1441 body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
1442 << "\", UnitAttr::get(parser.getContext()));\n";
1444 } else {
1445 for (FormatElement *el : pelement)
1446 genElementParser(el, body, attrTypeCtx);
1448 body << " } else ";
1450 body << " {\n";
1451 body << " break;\n";
1452 body << " }\n";
1453 body << "}\n";
1455 /// Literals.
1456 } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1457 body << " if (parser.parse";
1458 genLiteralParser(literal->getSpelling(), body);
1459 body << ")\n return ::mlir::failure();\n";
1461 /// Whitespaces.
1462 } else if (isa<WhitespaceElement>(element)) {
1463 // Nothing to parse.
1465 /// Arguments.
1466 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1467 bool parseAsOptional =
1468 (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
1469 genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
1470 opCppClassName);
1472 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1473 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1474 StringRef name = operand->getVar()->name;
1475 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1476 body << llvm::formatv(
1477 variadicOfVariadicOperandParserCode, name,
1478 operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
1479 else if (lengthKind == ArgumentLengthKind::Variadic)
1480 body << llvm::formatv(variadicOperandParserCode, name);
1481 else if (lengthKind == ArgumentLengthKind::Optional)
1482 body << llvm::formatv(optionalOperandParserCode, name);
1483 else
1484 body << formatv(operandParserCode, name);
1486 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1487 bool isVariadic = region->getVar()->isVariadic();
1488 body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
1489 region->getVar()->name);
1490 if (hasImplicitTermTrait)
1491 body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
1492 : regionEnsureTerminatorParserCode,
1493 region->getVar()->name);
1494 else if (hasSingleBlockTrait)
1495 body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
1496 : regionEnsureSingleBlockParserCode,
1497 region->getVar()->name);
1499 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1500 bool isVariadic = successor->getVar()->isVariadic();
1501 body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1502 successor->getVar()->name);
1504 /// Directives.
1505 } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1506 body.indent() << "{\n";
1507 body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1508 << "if (parser.parseOptionalAttrDict"
1509 << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1510 << "(result.attributes))\n"
1511 << " return ::mlir::failure();\n";
1512 if (useProperties) {
1513 body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1514 "[&]() {\n"
1515 << " return parser.emitError(loc) << \"'\" << "
1516 "result.name.getStringRef() << \"' op \";\n"
1517 << " })))\n"
1518 << " return ::mlir::failure();\n";
1520 body.unindent() << "}\n";
1521 body.unindent();
1522 } else if (isa<PropDictDirective>(element)) {
1523 body << " if (parseProperties(parser, result))\n"
1524 << " return ::mlir::failure();\n";
1525 } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1526 genCustomDirectiveParser(customDir, body, useProperties, opCppClassName);
1527 } else if (isa<OperandsDirective>(element)) {
1528 body << " [[maybe_unused]] ::llvm::SMLoc allOperandLoc ="
1529 << " parser.getCurrentLocation();\n"
1530 << " if (parser.parseOperandList(allOperands))\n"
1531 << " return ::mlir::failure();\n";
1533 } else if (isa<RegionsDirective>(element)) {
1534 body << llvm::formatv(regionListParserCode, "full");
1535 if (hasImplicitTermTrait)
1536 body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
1537 else if (hasSingleBlockTrait)
1538 body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
1540 } else if (isa<SuccessorsDirective>(element)) {
1541 body << llvm::formatv(successorListParserCode, "full");
1543 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1544 ArgumentLengthKind lengthKind;
1545 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1546 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1547 body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
1548 } else if (lengthKind == ArgumentLengthKind::Variadic) {
1549 body << llvm::formatv(variadicTypeParserCode, listName);
1550 } else if (lengthKind == ArgumentLengthKind::Optional) {
1551 body << llvm::formatv(optionalTypeParserCode, listName);
1552 } else {
1553 const char *parserCode =
1554 dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
1555 TypeSwitch<FormatElement *>(dir->getArg())
1556 .Case<OperandVariable, ResultVariable>([&](auto operand) {
1557 body << formatv(parserCode,
1558 operand->getVar()->constraint.getCPPClassName(),
1559 listName);
1561 .Default([&](auto operand) {
1562 body << formatv(parserCode, "::mlir::Type", listName);
1565 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1566 ArgumentLengthKind ignored;
1567 body << formatv(functionalTypeParserCode,
1568 getTypeListName(dir->getInputs(), ignored),
1569 getTypeListName(dir->getResults(), ignored));
1570 } else {
1571 llvm_unreachable("unknown format element");
1575 void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
1576 // If any of type resolutions use transformed variables, make sure that the
1577 // types of those variables are resolved.
1578 SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1579 FmtContext verifierFCtx;
1580 for (TypeResolution &resolver :
1581 llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1582 std::optional<StringRef> transformer = resolver.getVarTransformer();
1583 if (!transformer)
1584 continue;
1585 // Ensure that we don't verify the same variables twice.
1586 const NamedTypeConstraint *variable = resolver.getVariable();
1587 if (!variable || !verifiedVariables.insert(variable).second)
1588 continue;
1590 auto constraint = variable->constraint;
1591 body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
1592 << " (void)type;\n"
1593 << " if (!("
1594 << tgfmt(constraint.getConditionTemplate(),
1595 &verifierFCtx.withSelf("type"))
1596 << ")) {\n"
1597 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1598 "\"'{0}' must be {1}, but got \" << type;\n",
1599 variable->name, constraint.getSummary())
1600 << " }\n"
1601 << " }\n";
1604 // Initialize the set of buildable types.
1605 if (!buildableTypes.empty()) {
1606 FmtContext typeBuilderCtx;
1607 typeBuilderCtx.withBuilder("parser.getBuilder()");
1608 for (auto &it : buildableTypes)
1609 body << " ::mlir::Type odsBuildableType" << it.second << " = "
1610 << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1613 // Emit the code necessary for a type resolver.
1614 auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1615 if (std::optional<int> val = resolver.getBuilderIdx()) {
1616 body << "odsBuildableType" << *val;
1617 } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1618 if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
1619 FmtContext fmtContext;
1620 fmtContext.addSubst("_ctxt", "parser.getContext()");
1621 if (var->isVariadic())
1622 fmtContext.withSelf(var->name + "Types");
1623 else
1624 fmtContext.withSelf(var->name + "Types[0]");
1625 body << tgfmt(*tform, &fmtContext);
1626 } else {
1627 body << var->name << "Types";
1628 if (!var->isVariadic())
1629 body << "[0]";
1631 } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1632 if (std::optional<StringRef> tform = resolver.getVarTransformer())
1633 body << tgfmt(*tform,
1634 &FmtContext().withSelf(attr->name + "Attr.getType()"));
1635 else
1636 body << attr->name << "Attr.getType()";
1637 } else {
1638 body << curVar << "Types";
1642 // Resolve each of the result types.
1643 if (!infersResultTypes) {
1644 if (allResultTypes) {
1645 body << " result.addTypes(allResultTypes);\n";
1646 } else {
1647 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1648 body << " result.addTypes(";
1649 emitTypeResolver(resultTypes[i], op.getResultName(i));
1650 body << ");\n";
1655 // Emit the operand type resolutions.
1656 genParserOperandTypeResolution(op, body, emitTypeResolver);
1658 // Handle return type inference once all operands have been resolved
1659 if (infersResultTypes)
1660 body << formatv(inferReturnTypesParserCode, op.getCppClassName());
1663 void OperationFormat::genParserOperandTypeResolution(
1664 Operator &op, MethodBody &body,
1665 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1666 // Early exit if there are no operands.
1667 if (op.getNumOperands() == 0)
1668 return;
1670 // Handle the case where all operand types are grouped together with
1671 // "types(operands)".
1672 if (allOperandTypes) {
1673 // If `operands` was specified, use the full operand list directly.
1674 if (allOperands) {
1675 body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
1676 "allOperandLoc, result.operands))\n"
1677 " return ::mlir::failure();\n";
1678 return;
1681 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1682 // llvm::concat does not allow the case of a single range, so guard it here.
1683 body << " if (parser.resolveOperands(";
1684 if (op.getNumOperands() > 1) {
1685 body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1686 llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1687 body << operand.name << "Operands";
1689 body << ")";
1690 } else {
1691 body << op.operand_begin()->name << "Operands";
1693 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1694 << " return ::mlir::failure();\n";
1695 return;
1698 // Handle the case where all operands are grouped together with "operands".
1699 if (allOperands) {
1700 body << " if (parser.resolveOperands(allOperands, ";
1702 // Group all of the operand types together to perform the resolution all at
1703 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1704 // the case of a single range, so guard it here.
1705 if (op.getNumOperands() > 1) {
1706 body << "::llvm::concat<const ::mlir::Type>(";
1707 llvm::interleaveComma(
1708 llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1709 body << "::llvm::ArrayRef<::mlir::Type>(";
1710 emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1711 body << ")";
1713 body << ")";
1714 } else {
1715 emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1718 body << ", allOperandLoc, result.operands))\n return "
1719 "::mlir::failure();\n";
1720 return;
1723 // The final case is the one where each of the operands types are resolved
1724 // separately.
1725 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1726 NamedTypeConstraint &operand = op.getOperand(i);
1727 body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
1729 // Resolve the type of this operand.
1730 TypeResolution &operandType = operandTypes[i];
1731 emitTypeResolver(operandType, operand.name);
1733 body << ", " << operand.name
1734 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1738 void OperationFormat::genParserRegionResolution(Operator &op,
1739 MethodBody &body) {
1740 // Check for the case where all regions were parsed.
1741 bool hasAllRegions = llvm::any_of(
1742 elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
1743 if (hasAllRegions) {
1744 body << " result.addRegions(fullRegions);\n";
1745 return;
1748 // Otherwise, handle each region individually.
1749 for (const NamedRegion &region : op.getRegions()) {
1750 if (region.isVariadic())
1751 body << " result.addRegions(" << region.name << "Regions);\n";
1752 else
1753 body << " result.addRegion(std::move(" << region.name << "Region));\n";
1757 void OperationFormat::genParserSuccessorResolution(Operator &op,
1758 MethodBody &body) {
1759 // Check for the case where all successors were parsed.
1760 bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
1761 return isa<SuccessorsDirective>(elt);
1763 if (hasAllSuccessors) {
1764 body << " result.addSuccessors(fullSuccessors);\n";
1765 return;
1768 // Otherwise, handle each successor individually.
1769 for (const NamedSuccessor &successor : op.getSuccessors()) {
1770 if (successor.isVariadic())
1771 body << " result.addSuccessors(" << successor.name << "Successors);\n";
1772 else
1773 body << " result.addSuccessors(" << successor.name << "Successor);\n";
1777 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1778 MethodBody &body) {
1779 if (!allOperands) {
1780 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1781 auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1782 // If the operand is variadic emit the parsed size.
1783 if (operand.isVariableLength())
1784 body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1785 else
1786 body << "1";
1788 if (op.getDialect().usePropertiesForAttributes()) {
1789 body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1790 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1791 body << formatv("}), "
1792 "result.getOrAddProperties<{0}::Properties>()."
1793 "operandSegmentSizes.begin());\n",
1794 op.getCppClassName());
1795 } else {
1796 body << " result.addAttribute(\"operandSegmentSizes\", "
1797 << "parser.getBuilder().getDenseI32ArrayAttr({";
1798 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1799 body << "}));\n";
1802 for (const NamedTypeConstraint &operand : op.getOperands()) {
1803 if (!operand.isVariadicOfVariadic())
1804 continue;
1805 if (op.getDialect().usePropertiesForAttributes()) {
1806 body << llvm::formatv(
1807 " result.getOrAddProperties<{0}::Properties>().{1} = "
1808 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1809 op.getCppClassName(),
1810 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1811 operand.name);
1812 } else {
1813 body << llvm::formatv(
1814 " result.addAttribute(\"{0}\", "
1815 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1816 "\n",
1817 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1818 operand.name);
1823 if (!allResultTypes &&
1824 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1825 auto interleaveFn = [&](const NamedTypeConstraint &result) {
1826 // If the result is variadic emit the parsed size.
1827 if (result.isVariableLength())
1828 body << "static_cast<int32_t>(" << result.name << "Types.size())";
1829 else
1830 body << "1";
1832 if (op.getDialect().usePropertiesForAttributes()) {
1833 body << "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1834 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1835 body << formatv("}), "
1836 "result.getOrAddProperties<{0}::Properties>()."
1837 "resultSegmentSizes.begin());\n",
1838 op.getCppClassName());
1839 } else {
1840 body << " result.addAttribute(\"resultSegmentSizes\", "
1841 << "parser.getBuilder().getDenseI32ArrayAttr({";
1842 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1843 body << "}));\n";
1848 //===----------------------------------------------------------------------===//
1849 // PrinterGen
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The terminator of the region's first block is printed only if it carries
/// attributes, operands, or results. The snippet is wrapped in its own scope
/// so the local `printTerminator` does not collide when several regions are
/// printed in the same generated `print` method. `{{` is the formatv escape
/// for a literal `{`.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {{
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
                            /*printBlockTerminators=*/printTerminator);
  }
)";
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// The snippet opens a scope (`{{` is the formatv escape for `{`) that the
/// code emitted by `genEnumAttrPrinter` closes after printing `caseValueStr`.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attributes symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {{
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1879 /// Generate the printer for the 'prop-dict' directive.
1880 static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
1881 MethodBody &body) {
1882 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n";
1883 for (const NamedProperty *namedProperty : fmt.usedProperties)
1884 body << " elidedProps.push_back(\"" << namedProperty->name << "\");\n";
1885 for (const NamedAttribute *namedAttr : fmt.usedAttributes)
1886 body << " elidedProps.push_back(\"" << namedAttr->name << "\");\n";
1888 // Add code to check attributes for equality with the default value
1889 // for attributes with the elidePrintingDefaultValue bit set.
1890 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1891 const Attribute &attr = namedAttr.attr;
1892 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
1893 const StringRef &name = namedAttr.name;
1894 FmtContext fctx;
1895 fctx.withBuilder("odsBuilder");
1896 std::string defaultValue = std::string(
1897 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1898 body << " {\n";
1899 body << " ::mlir::Builder odsBuilder(getContext());\n";
1900 body << " ::mlir::Attribute attr = " << op.getGetterName(name)
1901 << "Attr();\n";
1902 body << " if(attr && (attr == " << defaultValue << "))\n";
1903 body << " elidedProps.push_back(\"" << name << "\");\n";
1904 body << " }\n";
1908 body << " _odsPrinter << \" \";\n"
1909 << " printProperties(this->getContext(), _odsPrinter, "
1910 "getProperties(), elidedProps);\n";
1913 /// Generate the printer for the 'attr-dict' directive.
1914 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1915 MethodBody &body, bool withKeyword) {
1916 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
1917 // Elide the variadic segment size attributes if necessary.
1918 if (!fmt.allOperands &&
1919 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1920 body << " elidedAttrs.push_back(\"operandSegmentSizes\");\n";
1921 if (!fmt.allResultTypes &&
1922 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1923 body << " elidedAttrs.push_back(\"resultSegmentSizes\");\n";
1924 for (const StringRef key : fmt.inferredAttributes.keys())
1925 body << " elidedAttrs.push_back(\"" << key << "\");\n";
1926 for (const NamedAttribute *attr : fmt.usedAttributes)
1927 body << " elidedAttrs.push_back(\"" << attr->name << "\");\n";
1928 // Add code to check attributes for equality with the default value
1929 // for attributes with the elidePrintingDefaultValue bit set.
1930 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1931 const Attribute &attr = namedAttr.attr;
1932 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
1933 const StringRef &name = namedAttr.name;
1934 FmtContext fctx;
1935 fctx.withBuilder("odsBuilder");
1936 std::string defaultValue = std::string(
1937 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1938 body << " {\n";
1939 body << " ::mlir::Builder odsBuilder(getContext());\n";
1940 body << " ::mlir::Attribute attr = " << op.getGetterName(name)
1941 << "Attr();\n";
1942 body << " if(attr && (attr == " << defaultValue << "))\n";
1943 body << " elidedAttrs.push_back(\"" << name << "\");\n";
1944 body << " }\n";
1947 if (fmt.hasPropDict)
1948 body << " _odsPrinter.printOptionalAttrDict"
1949 << (withKeyword ? "WithKeyword" : "")
1950 << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n";
1951 else
1952 body << " _odsPrinter.printOptionalAttrDict"
1953 << (withKeyword ? "WithKeyword" : "")
1954 << "((*this)->getAttrs(), elidedAttrs);\n";
1957 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1958 /// space should be emitted before this element. `lastWasPunctuation` is true if
1959 /// the previous element was a punctuation literal.
1960 static void genLiteralPrinter(StringRef value, MethodBody &body,
1961 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1962 body << " _odsPrinter";
1964 // Don't insert a space for certain punctuation.
1965 if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
1966 body << " << ' '";
1967 body << " << \"" << value << "\";\n";
1969 // Insert a space after certain literals.
1970 shouldEmitSpace =
1971 value.size() != 1 || !StringRef("<({[").contains(value.front());
1972 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
1975 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1976 /// are set to false.
1977 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
1978 bool &lastWasPunctuation) {
1979 if (value) {
1980 body << " _odsPrinter << ' ';\n";
1981 lastWasPunctuation = false;
1982 } else {
1983 lastWasPunctuation = true;
1985 shouldEmitSpace = false;
1988 /// Generate the printer for a custom directive parameter.
1989 static void genCustomDirectiveParameterPrinter(FormatElement *element,
1990 const Operator &op,
1991 MethodBody &body) {
1992 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1993 body << op.getGetterName(attr->getVar()->name) << "Attr()";
1995 } else if (isa<AttrDictDirective>(element)) {
1996 body << "getOperation()->getAttrDictionary()";
1998 } else if (isa<PropDictDirective>(element)) {
1999 body << "getProperties()";
2001 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2002 body << op.getGetterName(operand->getVar()->name) << "()";
2004 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2005 body << op.getGetterName(region->getVar()->name) << "()";
2007 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2008 body << op.getGetterName(successor->getVar()->name) << "()";
2010 } else if (auto *dir = dyn_cast<RefDirective>(element)) {
2011 genCustomDirectiveParameterPrinter(dir->getArg(), op, body);
2013 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2014 auto *typeOperand = dir->getArg();
2015 auto *operand = dyn_cast<OperandVariable>(typeOperand);
2016 auto *var = operand ? operand->getVar()
2017 : cast<ResultVariable>(typeOperand)->getVar();
2018 std::string name = op.getGetterName(var->name);
2019 if (var->isVariadic())
2020 body << name << "().getTypes()";
2021 else if (var->isOptional())
2022 body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
2023 else
2024 body << name << "().getType()";
2026 } else if (auto *string = dyn_cast<StringElement>(element)) {
2027 FmtContext ctx;
2028 ctx.withBuilder("::mlir::Builder(getContext())");
2029 ctx.addSubst("_ctxt", "getContext()");
2030 body << tgfmt(string->getValue(), &ctx);
2032 } else if (auto *property = dyn_cast<PropertyVariable>(element)) {
2033 FmtContext ctx;
2034 ctx.addSubst("_ctxt", "getContext()");
2035 const NamedProperty *namedProperty = property->getVar();
2036 ctx.addSubst("_storage", "getProperties()." + namedProperty->name);
2037 body << tgfmt(namedProperty->prop.getConvertFromStorageCall(), &ctx);
2038 } else {
2039 llvm_unreachable("unknown custom directive parameter");
2043 /// Generate the printer for a custom directive.
2044 static void genCustomDirectivePrinter(CustomDirective *customDir,
2045 const Operator &op, MethodBody &body) {
2046 body << " print" << customDir->getName() << "(_odsPrinter, *this";
2047 for (FormatElement *param : customDir->getArguments()) {
2048 body << ", ";
2049 genCustomDirectiveParameterPrinter(param, op, body);
2051 body << ");\n";
2054 /// Generate the printer for a region with the given variable name.
2055 static void genRegionPrinter(const Twine &regionName, MethodBody &body,
2056 bool hasImplicitTermTrait) {
2057 if (hasImplicitTermTrait)
2058 body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
2059 regionName);
2060 else
2061 body << " _odsPrinter.printRegion(" << regionName << ");\n";
2063 static void genVariadicRegionPrinter(const Twine &regionListName,
2064 MethodBody &body,
2065 bool hasImplicitTermTrait) {
2066 body << " llvm::interleaveComma(" << regionListName
2067 << ", _odsPrinter, [&](::mlir::Region &region) {\n ";
2068 genRegionPrinter("region", body, hasImplicitTermTrait);
2069 body << " });\n";
2072 /// Generate the C++ for an operand to a (*-)type directive.
2073 static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
2074 MethodBody &body,
2075 bool useArrayRef = true) {
2076 if (isa<OperandsDirective>(arg))
2077 return body << "getOperation()->getOperandTypes()";
2078 if (isa<ResultsDirective>(arg))
2079 return body << "getOperation()->getResultTypes()";
2080 auto *operand = dyn_cast<OperandVariable>(arg);
2081 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
2082 if (var->isVariadicOfVariadic())
2083 return body << llvm::formatv("{0}().join().getTypes()",
2084 op.getGetterName(var->name));
2085 if (var->isVariadic())
2086 return body << op.getGetterName(var->name) << "().getTypes()";
2087 if (var->isOptional())
2088 return body << llvm::formatv(
2089 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
2090 "::llvm::ArrayRef<::mlir::Type>())",
2091 op.getGetterName(var->name));
2092 if (useArrayRef)
2093 return body << "::llvm::ArrayRef<::mlir::Type>("
2094 << op.getGetterName(var->name) << "().getType())";
2095 return body << op.getGetterName(var->name) << "().getType()";
2098 /// Generate the printer for an enum attribute.
2099 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
2100 MethodBody &body) {
2101 Attribute baseAttr = var->attr.getBaseAttr();
2102 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
2103 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
2105 body << llvm::formatv(enumAttrBeginPrinterCode,
2106 (var->attr.isOptional() ? "*" : "") +
2107 op.getGetterName(var->name),
2108 enumAttr.getSymbolToStringFnName());
2110 // Get a string containing all of the cases that can't be represented with a
2111 // keyword.
2112 BitVector nonKeywordCases(cases.size());
2113 for (auto it : llvm::enumerate(cases)) {
2114 if (!canFormatStringAsKeyword(it.value().getStr()))
2115 nonKeywordCases.set(it.index());
2118 // Otherwise if this is a bit enum attribute, don't allow cases that may
2119 // overlap with other cases. For simplicity sake, only allow cases with a
2120 // single bit value.
2121 if (enumAttr.isBitEnum()) {
2122 for (auto it : llvm::enumerate(cases)) {
2123 int64_t value = it.value().getValue();
2124 if (value < 0 || !llvm::isPowerOf2_64(value))
2125 nonKeywordCases.set(it.index());
2129 // If there are any cases that can't be used with a keyword, switch on the
2130 // case value to determine when to print in the string form.
2131 if (nonKeywordCases.any()) {
2132 body << " switch (caseValue) {\n";
2133 StringRef cppNamespace = enumAttr.getCppNamespace();
2134 StringRef enumName = enumAttr.getEnumClassName();
2135 for (auto it : llvm::enumerate(cases)) {
2136 if (nonKeywordCases.test(it.index()))
2137 continue;
2138 StringRef symbol = it.value().getSymbol();
2139 body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
2140 llvm::isDigit(symbol.front()) ? ("_" + symbol)
2141 : symbol);
2143 body << " _odsPrinter << caseValueStr;\n"
2144 " break;\n"
2145 " default:\n"
2146 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
2147 " break;\n"
2148 " }\n"
2149 " }\n";
2150 return;
2153 body << " _odsPrinter << caseValueStr;\n"
2154 " }\n";
2157 /// Generate a check that a DefaultValuedAttr has a value that is non-default.
2158 static void genNonDefaultValueCheck(MethodBody &body, const Operator &op,
2159 AttributeVariable &attrElement) {
2160 FmtContext fctx;
2161 Attribute attr = attrElement.getVar()->attr;
2162 fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2163 body << " && " << op.getGetterName(attrElement.getVar()->name) << "Attr() != "
2164 << tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue());
2167 /// Generate the check for the anchor of an optional group.
2168 static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
2169 const Operator &op,
2170 MethodBody &body) {
2171 TypeSwitch<FormatElement *>(anchor)
2172 .Case<OperandVariable, ResultVariable>([&](auto *element) {
2173 const NamedTypeConstraint *var = element->getVar();
2174 std::string name = op.getGetterName(var->name);
2175 if (var->isOptional())
2176 body << name << "()";
2177 else if (var->isVariadic())
2178 body << "!" << name << "().empty()";
2180 .Case([&](RegionVariable *element) {
2181 const NamedRegion *var = element->getVar();
2182 std::string name = op.getGetterName(var->name);
2183 // TODO: Add a check for optional regions here when ODS supports it.
2184 body << "!" << name << "().empty()";
2186 .Case([&](TypeDirective *element) {
2187 genOptionalGroupPrinterAnchor(element->getArg(), op, body);
2189 .Case([&](FunctionalTypeDirective *element) {
2190 genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
2192 .Case([&](AttributeVariable *element) {
2193 Attribute attr = element->getVar()->attr;
2194 body << op.getGetterName(element->getVar()->name) << "Attr()";
2195 if (attr.isOptional())
2196 return; // done
2197 if (attr.hasDefaultValue()) {
2198 // Consider a default-valued attribute as present if it's not the
2199 // default value.
2200 genNonDefaultValueCheck(body, op, *element);
2201 return;
2203 llvm_unreachable("attribute must be optional or default-valued");
2205 .Case([&](CustomDirective *ele) {
2206 body << '(';
2207 llvm::interleave(
2208 ele->getArguments(), body,
2209 [&](FormatElement *child) {
2210 body << '(';
2211 genOptionalGroupPrinterAnchor(child, op, body);
2212 body << ')';
2214 " || ");
2215 body << ')';
2219 void collect(FormatElement *element,
2220 SmallVectorImpl<VariableElement *> &variables) {
2221 TypeSwitch<FormatElement *>(element)
2222 .Case([&](VariableElement *var) { variables.emplace_back(var); })
2223 .Case([&](CustomDirective *ele) {
2224 for (FormatElement *arg : ele->getArguments())
2225 collect(arg, variables);
2227 .Case([&](OptionalElement *ele) {
2228 for (FormatElement *arg : ele->getThenElements())
2229 collect(arg, variables);
2230 for (FormatElement *arg : ele->getElseElements())
2231 collect(arg, variables);
2233 .Case([&](FunctionalTypeDirective *funcType) {
2234 collect(funcType->getInputs(), variables);
2235 collect(funcType->getResults(), variables);
2237 .Case([&](OIListElement *oilist) {
2238 for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
2239 for (FormatElement *arg : arg)
2240 collect(arg, variables);
2244 void OperationFormat::genElementPrinter(FormatElement *element,
2245 MethodBody &body, Operator &op,
2246 bool &shouldEmitSpace,
2247 bool &lastWasPunctuation) {
2248 if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
2249 return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
2250 lastWasPunctuation);
2252 // Emit a whitespace element.
2253 if (auto *space = dyn_cast<WhitespaceElement>(element)) {
2254 if (space->getValue() == "\\n") {
2255 body << " _odsPrinter.printNewline();\n";
2256 } else {
2257 genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace,
2258 lastWasPunctuation);
2260 return;
2263 // Emit an optional group.
2264 if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
2265 // Emit the check for the presence of the anchor element.
2266 FormatElement *anchor = optional->getAnchor();
2267 body << " if (";
2268 if (optional->isInverted())
2269 body << "!";
2270 genOptionalGroupPrinterAnchor(anchor, op, body);
2271 body << ") {\n";
2272 body.indent();
2274 // If the anchor is a unit attribute, we don't need to print it. When
2275 // parsing, we will add this attribute if this group is present.
2276 ArrayRef<FormatElement *> thenElements = optional->getThenElements();
2277 ArrayRef<FormatElement *> elseElements = optional->getElseElements();
2278 FormatElement *elidedAnchorElement = nullptr;
2279 auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
2280 if (anchorAttr && anchorAttr != thenElements.front() &&
2281 (elseElements.empty() || anchorAttr != elseElements.front()) &&
2282 anchorAttr->isUnitAttr()) {
2283 elidedAnchorElement = anchorAttr;
2285 auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
2286 for (FormatElement *childElement : elements) {
2287 if (childElement != elidedAnchorElement) {
2288 genElementPrinter(childElement, body, op, shouldEmitSpace,
2289 lastWasPunctuation);
2294 // Emit each of the elements.
2295 genElementPrinters(thenElements);
2296 body << "}";
2298 // Emit each of the else elements.
2299 if (!elseElements.empty()) {
2300 body << " else {\n";
2301 genElementPrinters(elseElements);
2302 body << "}";
2305 body.unindent() << "\n";
2306 return;
2309 // Emit the OIList
2310 if (auto *oilist = dyn_cast<OIListElement>(element)) {
2311 for (auto clause : oilist->getClauses()) {
2312 LiteralElement *lelement = std::get<0>(clause);
2313 ArrayRef<FormatElement *> pelement = std::get<1>(clause);
2315 SmallVector<VariableElement *> vars;
2316 for (FormatElement *el : pelement)
2317 collect(el, vars);
2318 body << " if (false";
2319 for (VariableElement *var : vars) {
2320 TypeSwitch<FormatElement *>(var)
2321 .Case([&](AttributeVariable *attrEle) {
2322 body << " || (" << op.getGetterName(attrEle->getVar()->name)
2323 << "Attr()";
2324 Attribute attr = attrEle->getVar()->attr;
2325 if (attr.hasDefaultValue()) {
2326 // Don't print default-valued attributes.
2327 genNonDefaultValueCheck(body, op, *attrEle);
2329 body << ")";
2331 .Case([&](OperandVariable *ele) {
2332 if (ele->getVar()->isVariadic()) {
2333 body << " || " << op.getGetterName(ele->getVar()->name)
2334 << "().size()";
2335 } else {
2336 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
2339 .Case([&](ResultVariable *ele) {
2340 if (ele->getVar()->isVariadic()) {
2341 body << " || " << op.getGetterName(ele->getVar()->name)
2342 << "().size()";
2343 } else {
2344 body << " || " << op.getGetterName(ele->getVar()->name) << "()";
2347 .Case([&](RegionVariable *reg) {
2348 body << " || " << op.getGetterName(reg->getVar()->name) << "()";
2352 body << ") {\n";
2353 genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
2354 lastWasPunctuation);
2355 if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
2356 for (FormatElement *element : pelement)
2357 genElementPrinter(element, body, op, shouldEmitSpace,
2358 lastWasPunctuation);
2360 body << " }\n";
2362 return;
2365 // Emit the attribute dictionary.
2366 if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
2367 genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
2368 lastWasPunctuation = false;
2369 return;
2372 // Emit the attribute dictionary.
2373 if (isa<PropDictDirective>(element)) {
2374 genPropDictPrinter(*this, op, body);
2375 lastWasPunctuation = false;
2376 return;
2379 // Optionally insert a space before the next element. The AttrDict printer
2380 // already adds a space as necessary.
2381 if (shouldEmitSpace || !lastWasPunctuation)
2382 body << " _odsPrinter << ' ';\n";
2383 lastWasPunctuation = false;
2384 shouldEmitSpace = true;
2386 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
2387 const NamedAttribute *var = attr->getVar();
2389 // If we are formatting as an enum, symbolize the attribute as a string.
2390 if (canFormatEnumAttr(var))
2391 return genEnumAttrPrinter(var, op, body);
2393 // If we are formatting as a symbol name, handle it as a symbol name.
2394 if (shouldFormatSymbolNameAttr(var)) {
2395 body << " _odsPrinter.printSymbolName(" << op.getGetterName(var->name)
2396 << "Attr().getValue());\n";
2397 return;
2400 // Elide the attribute type if it is buildable.
2401 if (attr->getTypeBuilder())
2402 body << " _odsPrinter.printAttributeWithoutType("
2403 << op.getGetterName(var->name) << "Attr());\n";
2404 else if (attr->shouldBeQualified() ||
2405 var->attr.getStorageType() == "::mlir::Attribute")
2406 body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
2407 << "Attr());\n";
2408 else
2409 body << "_odsPrinter.printStrippedAttrOrType("
2410 << op.getGetterName(var->name) << "Attr());\n";
2411 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
2412 if (operand->getVar()->isVariadicOfVariadic()) {
2413 body << " ::llvm::interleaveComma("
2414 << op.getGetterName(operand->getVar()->name)
2415 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2416 "\"(\" << operands << "
2417 "\")\"; });\n";
2419 } else if (operand->getVar()->isOptional()) {
2420 body << " if (::mlir::Value value = "
2421 << op.getGetterName(operand->getVar()->name) << "())\n"
2422 << " _odsPrinter << value;\n";
2423 } else {
2424 body << " _odsPrinter << " << op.getGetterName(operand->getVar()->name)
2425 << "();\n";
2427 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
2428 const NamedRegion *var = region->getVar();
2429 std::string name = op.getGetterName(var->name);
2430 if (var->isVariadic()) {
2431 genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
2432 } else {
2433 genRegionPrinter(name + "()", body, hasImplicitTermTrait);
2435 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
2436 const NamedSuccessor *var = successor->getVar();
2437 std::string name = op.getGetterName(var->name);
2438 if (var->isVariadic())
2439 body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
2440 else
2441 body << " _odsPrinter << " << name << "();\n";
2442 } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
2443 genCustomDirectivePrinter(dir, op, body);
2444 } else if (isa<OperandsDirective>(element)) {
2445 body << " _odsPrinter << getOperation()->getOperands();\n";
2446 } else if (isa<RegionsDirective>(element)) {
2447 genVariadicRegionPrinter("getOperation()->getRegions()", body,
2448 hasImplicitTermTrait);
2449 } else if (isa<SuccessorsDirective>(element)) {
2450 body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2451 "_odsPrinter);\n";
2452 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
2453 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
2454 if (operand->getVar()->isVariadicOfVariadic()) {
2455 body << llvm::formatv(
2456 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2457 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2458 "types << \")\"; });\n",
2459 op.getGetterName(operand->getVar()->name));
2460 return;
2463 const NamedTypeConstraint *var = nullptr;
2465 if (auto *operand = dyn_cast<OperandVariable>(dir->getArg()))
2466 var = operand->getVar();
2467 else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg()))
2468 var = operand->getVar();
2470 if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
2471 !var->isOptional()) {
2472 std::string cppClass = var->constraint.getCPPClassName();
2473 if (dir->shouldBeQualified()) {
2474 body << " _odsPrinter << " << op.getGetterName(var->name)
2475 << "().getType();\n";
2476 return;
2478 body << " {\n"
2479 << " auto type = " << op.getGetterName(var->name)
2480 << "().getType();\n"
2481 << " if (auto validType = ::llvm::dyn_cast<" << cppClass
2482 << ">(type))\n"
2483 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2484 << " else\n"
2485 << " _odsPrinter << type;\n"
2486 << " }\n";
2487 return;
2489 body << " _odsPrinter << ";
2490 genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false)
2491 << ";\n";
2492 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
2493 body << " _odsPrinter.printFunctionalType(";
2494 genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
2495 genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
2496 } else {
2497 llvm_unreachable("unknown format element");
2501 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2502 auto *method = opClass.addMethod(
2503 "void", "print",
2504 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2505 auto &body = method->body();
2507 // Flags for if we should emit a space, and if the last element was
2508 // punctuation.
2509 bool shouldEmitSpace = true, lastWasPunctuation = false;
2510 for (FormatElement *element : elements)
2511 genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
2514 //===----------------------------------------------------------------------===//
2515 // OpFormatParser
2516 //===----------------------------------------------------------------------===//
2518 /// Function to find an element within the given range that has the same name as
2519 /// 'name'.
2520 template <typename RangeT>
2521 static auto findArg(RangeT &&range, StringRef name) {
2522 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2523 return it != range.end() ? &*it : nullptr;
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format.
///
/// Beyond building the format elements, it records which operation
/// components (attributes, properties, operands, regions, successors, and
/// operand/result types) the format binds. `verify` then checks that state
/// and transfers the results into the `OperationFormat` being built.
class OpFormatParser : public FormatParser {
public:
  /// Create a parser that populates `format` for the operation `op`. The
  /// seen-type bit vectors are sized to the operation's operand and result
  /// counts.
  OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
      : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
        seenOperandTypes(op.getNumOperands()),
        seenResultTypes(op.getNumResults()) {}

protected:
  /// Verify the format elements.
  LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
  /// Verify the arguments to a custom directive.
  LogicalResult
  verifyCustomDirectiveArguments(SMLoc loc,
                                 ArrayRef<FormatElement *> arguments) override;
  /// Verify the elements of an optional group.
  LogicalResult verifyOptionalGroupElements(SMLoc loc,
                                            ArrayRef<FormatElement *> elements,
                                            FormatElement *anchor) override;
  /// Verify a single element of an optional group; `isAnchor` indicates
  /// whether it is the group's anchor element.
  LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
                                           bool isAnchor);

  LogicalResult markQualified(SMLoc loc, FormatElement *element) override;

  /// Parse an operation variable.
  FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                               Context ctx) override;
  /// Parse an operation format directive.
  FailureOr<FormatElement *>
  parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;

private:
  /// This struct represents a type resolution instance. It includes a specific
  /// type as well as an optional transformer to apply to that type in order to
  /// properly resolve the type of a variable.
  struct TypeResolutionInstance {
    /// The argument (operand, result, or attribute) whose type is the source.
    ConstArgument resolver;
    /// Optional transformer applied to the resolver's type (populated from a
    /// `TypesMatchWith` trait's `transformer` field).
    std::optional<StringRef> transformer;
  };

  /// Verify the state of operation attributes within the format.
  LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);

  /// Verify that attributes elements aren't followed by colon literals.
  LogicalResult verifyAttributeColonType(SMLoc loc,
                                         ArrayRef<FormatElement *> elements);
  /// Verify that the attribute dictionary directive isn't followed by a region.
  LogicalResult verifyAttrDictRegion(SMLoc loc,
                                     ArrayRef<FormatElement *> elements);

  /// Verify the state of operation operands within the format.
  LogicalResult
  verifyOperands(SMLoc loc,
                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation regions within the format.
  LogicalResult verifyRegions(SMLoc loc);

  /// Verify the state of operation results within the format.
  LogicalResult
  verifyResults(SMLoc loc,
                llvm::StringMap<TypeResolutionInstance> &variableTyResolver);

  /// Verify the state of operation successors within the format.
  LogicalResult verifySuccessors(SMLoc loc);

  /// Verify that adjacent `oilist` elements (and an `oilist` followed by a
  /// literal) do not share literals, which would make parsing ambiguous.
  LogicalResult verifyOIListElements(SMLoc loc,
                                     ArrayRef<FormatElement *> elements);

  /// Given the values of an `AllTypesMatch` trait, check for inferable type
  /// resolution.
  void handleAllTypesMatchConstraint(
      ArrayRef<StringRef> values,
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
  /// Check for inferable type resolution given all operands, and or results,
  /// have the same type. If 'includeResults' is true, the results also have the
  /// same type as all of the operands.
  void handleSameTypesConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      bool includeResults);
  /// Check for inferable type resolution based on another operand, result, or
  /// attribute.
  void handleTypesMatchConstraint(
      llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
      const llvm::Record &def);

  /// Returns an argument or attribute with the given name that has been seen
  /// within the format.
  ConstArgument findSeenArg(StringRef name);

  /// Parse the various different directives.
  FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
                                                    bool withKeyword);
  FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
                                                          Context context);
  FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
  LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
  FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
                                                      Context context);
  FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
  FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
                                                       bool isRefChild = false);

  //===--------------------------------------------------------------------===//
  // Fields
  //===--------------------------------------------------------------------===//

  /// The format being populated by this parser.
  OperationFormat &fmt;
  /// The operation whose assembly format is being parsed.
  Operator &op;

  // The following are various bits of format state used for verification
  // during parsing.
  /// Set once an `attr-dict` or `attr-dict-with-keyword` directive is bound.
  bool hasAttrDict = false;
  /// Set once a `prop-dict` directive is bound.
  bool hasPropDict = false;
  /// Set once a `regions` / `successors` directive binds all of them at once.
  bool hasAllRegions = false, hasAllSuccessors = false;
  /// Set in `verify` when, with not all result types known, a trait declares
  /// `::mlir::InferTypeOpInterface`.
  bool canInferResultTypes = false;
  /// Bit i is tested to decide whether the type of operand/result i is
  /// written in the format.
  llvm::SmallBitVector seenOperandTypes, seenResultTypes;
  /// Variables bound so far; these detect duplicate bindings and validate
  /// `ref(...)` uses.
  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
  llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
  llvm::DenseSet<const NamedRegion *> seenRegions;
  llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
  llvm::SmallSetVector<const NamedProperty *, 8> seenProperties;
};
} // namespace
2657 LogicalResult OpFormatParser::verify(SMLoc loc,
2658 ArrayRef<FormatElement *> elements) {
2659 // Check that the attribute dictionary is in the format.
2660 if (!hasAttrDict)
2661 return emitError(loc, "'attr-dict' directive not found in "
2662 "custom assembly format");
2664 // Check for any type traits that we can use for inferring types.
2665 llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2666 for (const Trait &trait : op.getTraits()) {
2667 const llvm::Record &def = trait.getDef();
2668 if (def.isSubClassOf("AllTypesMatch")) {
2669 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2670 variableTyResolver);
2671 } else if (def.getName() == "SameTypeOperands") {
2672 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2673 } else if (def.getName() == "SameOperandsAndResultType") {
2674 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2675 } else if (def.isSubClassOf("TypesMatchWith")) {
2676 handleTypesMatchConstraint(variableTyResolver, def);
2677 } else if (!op.allResultTypesKnown()) {
2678 // This doesn't check the name directly to handle
2679 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2680 // and the like.
2681 // TODO: Add hasCppInterface check.
2682 if (auto name = def.getValueAsOptionalString("cppInterfaceName")) {
2683 if (*name == "InferTypeOpInterface" &&
2684 def.getValueAsString("cppNamespace") == "::mlir")
2685 canInferResultTypes = true;
2690 // Verify the state of the various operation components.
2691 if (failed(verifyAttributes(loc, elements)) ||
2692 failed(verifyResults(loc, variableTyResolver)) ||
2693 failed(verifyOperands(loc, variableTyResolver)) ||
2694 failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) ||
2695 failed(verifyOIListElements(loc, elements)))
2696 return failure();
2698 // Collect the set of used attributes in the format.
2699 fmt.usedAttributes = std::move(seenAttrs);
2700 fmt.usedProperties = std::move(seenProperties);
2702 // Set whether prop-dict is used in the format
2703 fmt.hasPropDict = hasPropDict;
2704 return success();
2707 LogicalResult
2708 OpFormatParser::verifyAttributes(SMLoc loc,
2709 ArrayRef<FormatElement *> elements) {
2710 // Check that there are no `:` literals after an attribute without a constant
2711 // type. The attribute grammar contains an optional trailing colon type, which
2712 // can lead to unexpected and generally unintended behavior. Given that, it is
2713 // better to just error out here instead.
2714 if (failed(verifyAttributeColonType(loc, elements)))
2715 return failure();
2716 // Check that there are no region variables following an attribute dicitonary.
2717 // Both start with `{` and so the optional attribute dictionary can cause
2718 // format ambiguities.
2719 if (failed(verifyAttrDictRegion(loc, elements)))
2720 return failure();
2722 // Check for VariadicOfVariadic variables. The segment attribute of those
2723 // variables will be infered.
2724 for (const NamedTypeConstraint *var : seenOperands) {
2725 if (var->constraint.isVariadicOfVariadic()) {
2726 fmt.inferredAttributes.insert(
2727 var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2731 return success();
2734 /// Returns whether the single format element is optionally parsed.
2735 static bool isOptionallyParsed(FormatElement *el) {
2736 if (auto *attrVar = dyn_cast<AttributeVariable>(el)) {
2737 Attribute attr = attrVar->getVar()->attr;
2738 return attr.isOptional() || attr.hasDefaultValue();
2740 if (auto *operandVar = dyn_cast<OperandVariable>(el)) {
2741 const NamedTypeConstraint *operand = operandVar->getVar();
2742 return operand->isOptional() || operand->isVariadic() ||
2743 operand->isVariadicOfVariadic();
2745 if (auto *successorVar = dyn_cast<SuccessorVariable>(el))
2746 return successorVar->getVar()->isVariadic();
2747 if (auto *regionVar = dyn_cast<RegionVariable>(el))
2748 return regionVar->getVar()->isVariadic();
2749 return isa<WhitespaceElement, AttrDictDirective>(el);
/// Scan the given range of elements from the start for an invalid format
/// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
/// If an optional group is encountered, this function recurses into the 'then'
/// and 'else' elements to check if they are invalid. Returns `success` if the
/// range is known to be valid or `std::nullopt` if scanning reached the end.
/// Returns `failure` as soon as `isInvalid(base, element)` returns true for
/// some scanned element. The callers in this file emit their diagnostic from
/// inside `isInvalid`.
///
/// Since the guard element of an optional group is required, this function
/// accepts an optional element pointer to mark it as required.
///
/// \param base          The element whose followers are being checked; it is
///                      forwarded unchanged to `isInvalid`.
/// \param isInvalid     Predicate identifying an element that must not follow
///                      `base`.
/// \param elementRange  The elements following `base`.
/// \param optionalGuard An element that must be treated as required even if
///                      `isOptionallyParsed` says otherwise.
static std::optional<LogicalResult> checkRangeForElement(
    FormatElement *base,
    function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
    iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
    FormatElement *optionalGuard = nullptr) {
  for (FormatElement *element : elementRange) {
    // If we encounter an invalid element, return an error.
    if (isInvalid(base, element))
      return failure();

    // Recurse on optional groups.
    if (auto *optional = dyn_cast<OptionalElement>(element)) {
      if (std::optional<LogicalResult> result = checkRangeForElement(
              base, isInvalid, optional->getThenElements(),
              // The optional group guard is required for the group.
              optional->getThenElements().front()))
        if (failed(*result))
          return failure();
      if (std::optional<LogicalResult> result = checkRangeForElement(
              base, isInvalid, optional->getElseElements()))
        if (failed(*result))
          return failure();
      // Skip the optional group. Even if one of its branches ended in a valid
      // closing element, the group as a whole may be absent, so scanning must
      // continue with the elements after it.
      continue;
    }

    // Skip optionally parsed elements.
    if (element != optionalGuard && isOptionallyParsed(element))
      continue;

    // We found a closing element that is valid.
    return success();
  }
  // Return std::nullopt to indicate that we reached the end.
  return std::nullopt;
}
/// For the given elements, check whether any "base" element (one for which
/// `isBase` returns true, e.g. an attribute) is followed by an element that
/// `isInvalid` rejects, resulting in an ambiguous assembly format. Optional
/// groups are searched recursively.
///
/// Returns failure if an ambiguity was found. Returns a non-null base element
/// if verification of that element reached the end of `elements` without
/// finding a closing element, so the caller (e.g. an enclosing recursion) must
/// keep checking what follows. Returns null if all base elements are verified.
static FailureOr<FormatElement *> verifyAdjacentElements(
    function_ref<bool(FormatElement *)> isBase,
    function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
    ArrayRef<FormatElement *> elements) {
  for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
    // The current base element being verified.
    FormatElement *base;

    if (isBase(*it)) {
      base = *it;
    } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
      // Recurse on optional groups.
      FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
          isBase, isInvalid, optional->getThenElements());
      if (failed(thenResult))
        return failure();
      FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
          isBase, isInvalid, optional->getElseElements());
      if (failed(elseResult))
        return failure();
      // If either optional group has an unverified attribute, save it.
      // Otherwise, move on to the next element.
      // NOTE(review): because of the short-circuit, when both branches return
      // an unverified base only the 'then' one is checked against the
      // following elements; confirm that ignoring the 'else' one is intended.
      if (!(base = *thenResult) && !(base = *elseResult))
        continue;
    } else {
      continue;
    }

    // Verify subsequent elements for potential ambiguities.
    if (std::optional<LogicalResult> result =
            checkRangeForElement(base, isInvalid, {std::next(it), e})) {
      if (failed(*result))
        return failure();
    } else {
      // Since we reached the end, return the attribute as unverified.
      return base;
    }
  }
  // All attribute elements are known to be verified.
  return nullptr;
}
2843 LogicalResult
2844 OpFormatParser::verifyAttributeColonType(SMLoc loc,
2845 ArrayRef<FormatElement *> elements) {
2846 auto isBase = [](FormatElement *el) {
2847 auto *attr = dyn_cast<AttributeVariable>(el);
2848 if (!attr)
2849 return false;
2850 // Check only attributes without type builders or that are known to call
2851 // the generic attribute parser.
2852 return !attr->getTypeBuilder() &&
2853 (attr->shouldBeQualified() ||
2854 attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
2856 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2857 auto *literal = dyn_cast<LiteralElement>(el);
2858 if (!literal || literal->getSpelling() != ":")
2859 return false;
2860 // If we encounter `:`, the range is known to be invalid.
2861 (void)emitError(
2862 loc,
2863 llvm::formatv("format ambiguity caused by `:` literal found after "
2864 "attribute `{0}` which does not have a buildable type",
2865 cast<AttributeVariable>(base)->getVar()->name));
2866 return true;
2868 return verifyAdjacentElements(isBase, isInvalid, elements);
2871 LogicalResult
2872 OpFormatParser::verifyAttrDictRegion(SMLoc loc,
2873 ArrayRef<FormatElement *> elements) {
2874 auto isBase = [](FormatElement *el) {
2875 if (auto *attrDict = dyn_cast<AttrDictDirective>(el))
2876 return !attrDict->isWithKeyword();
2877 return false;
2879 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2880 auto *region = dyn_cast<RegionVariable>(el);
2881 if (!region)
2882 return false;
2883 (void)emitErrorAndNote(
2884 loc,
2885 llvm::formatv("format ambiguity caused by `attr-dict` directive "
2886 "followed by region `{0}`",
2887 region->getVar()->name),
2888 "try using `attr-dict-with-keyword` instead");
2889 return true;
2891 return verifyAdjacentElements(isBase, isInvalid, elements);
2894 LogicalResult OpFormatParser::verifyOperands(
2895 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2896 // Check that all of the operands are within the format, and their types can
2897 // be inferred.
2898 auto &buildableTypes = fmt.buildableTypes;
2899 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2900 NamedTypeConstraint &operand = op.getOperand(i);
2902 // Check that the operand itself is in the format.
2903 if (!fmt.allOperands && !seenOperands.count(&operand)) {
2904 return emitErrorAndNote(loc,
2905 "operand #" + Twine(i) + ", named '" +
2906 operand.name + "', not found",
2907 "suggest adding a '$" + operand.name +
2908 "' directive to the custom assembly format");
2911 // Check that the operand type is in the format, or that it can be inferred.
2912 if (fmt.allOperandTypes || seenOperandTypes.test(i))
2913 continue;
2915 // Check to see if we can infer this type from another variable.
2916 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2917 if (varResolverIt != variableTyResolver.end()) {
2918 TypeResolutionInstance &resolver = varResolverIt->second;
2919 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2920 continue;
2923 // Similarly to results, allow a custom builder for resolving the type if
2924 // we aren't using the 'operands' directive.
2925 std::optional<StringRef> builder = operand.constraint.getBuilderCall();
2926 if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2927 return emitErrorAndNote(
2928 loc,
2929 "type of operand #" + Twine(i) + ", named '" + operand.name +
2930 "', is not buildable and a buildable type cannot be inferred",
2931 "suggest adding a type constraint to the operation or adding a "
2932 "'type($" +
2933 operand.name + ")' directive to the " + "custom assembly format");
2935 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2936 fmt.operandTypes[i].setBuilderIdx(it.first->second);
2938 return success();
2941 LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
2942 // Check that all of the regions are within the format.
2943 if (hasAllRegions)
2944 return success();
2946 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2947 const NamedRegion &region = op.getRegion(i);
2948 if (!seenRegions.count(&region)) {
2949 return emitErrorAndNote(loc,
2950 "region #" + Twine(i) + ", named '" +
2951 region.name + "', not found",
2952 "suggest adding a '$" + region.name +
2953 "' directive to the custom assembly format");
2956 return success();
2959 LogicalResult OpFormatParser::verifyResults(
2960 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2961 // If we format all of the types together, there is nothing to check.
2962 if (fmt.allResultTypes)
2963 return success();
2965 // If no result types are specified and we can infer them, infer all result
2966 // types
2967 if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
2968 canInferResultTypes) {
2969 fmt.infersResultTypes = true;
2970 return success();
2973 // Check that all of the result types can be inferred.
2974 auto &buildableTypes = fmt.buildableTypes;
2975 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2976 if (seenResultTypes.test(i))
2977 continue;
2979 // Check to see if we can infer this type from another variable.
2980 auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2981 if (varResolverIt != variableTyResolver.end()) {
2982 TypeResolutionInstance resolver = varResolverIt->second;
2983 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2984 continue;
2987 // If the result is not variable length, allow for the case where the type
2988 // has a builder that we can use.
2989 NamedTypeConstraint &result = op.getResult(i);
2990 std::optional<StringRef> builder = result.constraint.getBuilderCall();
2991 if (!builder || result.isVariableLength()) {
2992 return emitErrorAndNote(
2993 loc,
2994 "type of result #" + Twine(i) + ", named '" + result.name +
2995 "', is not buildable and a buildable type cannot be inferred",
2996 "suggest adding a type constraint to the operation or adding a "
2997 "'type($" +
2998 result.name + ")' directive to the " + "custom assembly format");
3000 // Note in the format that this result uses the custom builder.
3001 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
3002 fmt.resultTypes[i].setBuilderIdx(it.first->second);
3004 return success();
3007 LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
3008 // Check that all of the successors are within the format.
3009 if (hasAllSuccessors)
3010 return success();
3012 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
3013 const NamedSuccessor &successor = op.getSuccessor(i);
3014 if (!seenSuccessors.count(&successor)) {
3015 return emitErrorAndNote(loc,
3016 "successor #" + Twine(i) + ", named '" +
3017 successor.name + "', not found",
3018 "suggest adding a '$" + successor.name +
3019 "' directive to the custom assembly format");
3022 return success();
3025 LogicalResult
3026 OpFormatParser::verifyOIListElements(SMLoc loc,
3027 ArrayRef<FormatElement *> elements) {
3028 // Check that all of the successors are within the format.
3029 SmallVector<StringRef> prohibitedLiterals;
3030 for (FormatElement *it : elements) {
3031 if (auto *oilist = dyn_cast<OIListElement>(it)) {
3032 if (!prohibitedLiterals.empty()) {
3033 // We just saw an oilist element in last iteration. Literals should not
3034 // match.
3035 for (LiteralElement *literal : oilist->getLiteralElements()) {
3036 if (find(prohibitedLiterals, literal->getSpelling()) !=
3037 prohibitedLiterals.end()) {
3038 return emitError(
3039 loc, "format ambiguity because " + literal->getSpelling() +
3040 " is used in two adjacent oilist elements.");
3044 for (LiteralElement *literal : oilist->getLiteralElements())
3045 prohibitedLiterals.push_back(literal->getSpelling());
3046 } else if (auto *literal = dyn_cast<LiteralElement>(it)) {
3047 if (find(prohibitedLiterals, literal->getSpelling()) !=
3048 prohibitedLiterals.end()) {
3049 return emitError(
3050 loc,
3051 "format ambiguity because " + literal->getSpelling() +
3052 " is used both in oilist element and the adjacent literal.");
3054 prohibitedLiterals.clear();
3055 } else {
3056 prohibitedLiterals.clear();
3059 return success();
3062 void OpFormatParser::handleAllTypesMatchConstraint(
3063 ArrayRef<StringRef> values,
3064 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
3065 for (unsigned i = 0, e = values.size(); i != e; ++i) {
3066 // Check to see if this value matches a resolved operand or result type.
3067 ConstArgument arg = findSeenArg(values[i]);
3068 if (!arg)
3069 continue;
3071 // Mark this value as the type resolver for the other variables.
3072 for (unsigned j = 0; j != i; ++j)
3073 variableTyResolver[values[j]] = {arg, std::nullopt};
3074 for (unsigned j = i + 1; j != e; ++j)
3075 variableTyResolver[values[j]] = {arg, std::nullopt};
3079 void OpFormatParser::handleSameTypesConstraint(
3080 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
3081 bool includeResults) {
3082 const NamedTypeConstraint *resolver = nullptr;
3083 int resolvedIt = -1;
3085 // Check to see if there is an operand or result to use for the resolution.
3086 if ((resolvedIt = seenOperandTypes.find_first()) != -1)
3087 resolver = &op.getOperand(resolvedIt);
3088 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
3089 resolver = &op.getResult(resolvedIt);
3090 else
3091 return;
3093 // Set the resolvers for each operand and result.
3094 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
3095 if (!seenOperandTypes.test(i))
3096 variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt};
3097 if (includeResults) {
3098 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3099 if (!seenResultTypes.test(i))
3100 variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt};
3104 void OpFormatParser::handleTypesMatchConstraint(
3105 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
3106 const llvm::Record &def) {
3107 StringRef lhsName = def.getValueAsString("lhs");
3108 StringRef rhsName = def.getValueAsString("rhs");
3109 StringRef transformer = def.getValueAsString("transformer");
3110 if (ConstArgument arg = findSeenArg(lhsName))
3111 variableTyResolver[rhsName] = {arg, transformer};
3114 ConstArgument OpFormatParser::findSeenArg(StringRef name) {
3115 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
3116 return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
3117 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
3118 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
3119 if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
3120 return seenAttrs.count(attr) ? attr : nullptr;
3121 return nullptr;
3124 FailureOr<FormatElement *>
3125 OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
3126 // Check that the parsed argument is something actually registered on the op.
3127 // Attributes
3128 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
3129 if (ctx == TypeDirectiveContext)
3130 return emitError(
3131 loc, "attributes cannot be used as children to a `type` directive");
3132 if (ctx == RefDirectiveContext) {
3133 if (!seenAttrs.count(attr))
3134 return emitError(loc, "attribute '" + name +
3135 "' must be bound before it is referenced");
3136 } else if (!seenAttrs.insert(attr)) {
3137 return emitError(loc, "attribute '" + name + "' is already bound");
3140 return create<AttributeVariable>(attr);
3143 if (const NamedProperty *property = findArg(op.getProperties(), name)) {
3144 if (ctx != CustomDirectiveContext && ctx != RefDirectiveContext)
3145 return emitError(
3146 loc, "properties currently only supported in `custom` directive");
3148 if (ctx == RefDirectiveContext) {
3149 if (!seenProperties.count(property))
3150 return emitError(loc, "property '" + name +
3151 "' must be bound before it is referenced");
3152 } else {
3153 if (!seenProperties.insert(property))
3154 return emitError(loc, "property '" + name + "' is already bound");
3157 return create<PropertyVariable>(property);
3160 // Operands
3161 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
3162 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3163 if (fmt.allOperands || !seenOperands.insert(operand).second)
3164 return emitError(loc, "operand '" + name + "' is already bound");
3165 } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) {
3166 return emitError(loc, "operand '" + name +
3167 "' must be bound before it is referenced");
3169 return create<OperandVariable>(operand);
3171 // Regions
3172 if (const NamedRegion *region = findArg(op.getRegions(), name)) {
3173 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3174 if (hasAllRegions || !seenRegions.insert(region).second)
3175 return emitError(loc, "region '" + name + "' is already bound");
3176 } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
3177 return emitError(loc, "region '" + name +
3178 "' must be bound before it is referenced");
3179 } else {
3180 return emitError(loc, "regions can only be used at the top level");
3182 return create<RegionVariable>(region);
3184 // Results.
3185 if (const auto *result = findArg(op.getResults(), name)) {
3186 if (ctx != TypeDirectiveContext)
3187 return emitError(loc, "result variables can can only be used as a child "
3188 "to a 'type' directive");
3189 return create<ResultVariable>(result);
3191 // Successors.
3192 if (const auto *successor = findArg(op.getSuccessors(), name)) {
3193 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
3194 if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
3195 return emitError(loc, "successor '" + name + "' is already bound");
3196 } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
3197 return emitError(loc, "successor '" + name +
3198 "' must be bound before it is referenced");
3199 } else {
3200 return emitError(loc, "successors can only be used at the top level");
3203 return create<SuccessorVariable>(successor);
3205 return emitError(loc, "expected variable to refer to an argument, region, "
3206 "result, or successor");
3209 FailureOr<FormatElement *>
3210 OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
3211 Context ctx) {
3212 switch (kind) {
3213 case FormatToken::kw_prop_dict:
3214 return parsePropDictDirective(loc, ctx);
3215 case FormatToken::kw_attr_dict:
3216 return parseAttrDictDirective(loc, ctx,
3217 /*withKeyword=*/false);
3218 case FormatToken::kw_attr_dict_w_keyword:
3219 return parseAttrDictDirective(loc, ctx,
3220 /*withKeyword=*/true);
3221 case FormatToken::kw_functional_type:
3222 return parseFunctionalTypeDirective(loc, ctx);
3223 case FormatToken::kw_operands:
3224 return parseOperandsDirective(loc, ctx);
3225 case FormatToken::kw_regions:
3226 return parseRegionsDirective(loc, ctx);
3227 case FormatToken::kw_results:
3228 return parseResultsDirective(loc, ctx);
3229 case FormatToken::kw_successors:
3230 return parseSuccessorsDirective(loc, ctx);
3231 case FormatToken::kw_type:
3232 return parseTypeDirective(loc, ctx);
3233 case FormatToken::kw_oilist:
3234 return parseOIListDirective(loc, ctx);
3236 default:
3237 return emitError(loc, "unsupported directive kind");
3241 FailureOr<FormatElement *>
3242 OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
3243 bool withKeyword) {
3244 if (context == TypeDirectiveContext)
3245 return emitError(loc, "'attr-dict' directive can only be used as a "
3246 "top-level directive");
3248 if (context == RefDirectiveContext) {
3249 if (!hasAttrDict)
3250 return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
3251 "'attr-dict' directive");
3253 // Otherwise, this is a top-level context.
3254 } else {
3255 if (hasAttrDict)
3256 return emitError(loc, "'attr-dict' directive has already been seen");
3257 hasAttrDict = true;
3260 return create<AttrDictDirective>(withKeyword);
3263 FailureOr<FormatElement *>
3264 OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
3265 if (context == TypeDirectiveContext)
3266 return emitError(loc, "'prop-dict' directive can only be used as a "
3267 "top-level directive");
3269 if (context == RefDirectiveContext)
3270 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3271 // Otherwise, this is a top-level context.
3273 if (hasPropDict)
3274 return emitError(loc, "'prop-dict' directive has already been seen");
3275 hasPropDict = true;
3277 return create<PropDictDirective>();
3280 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
3281 SMLoc loc, ArrayRef<FormatElement *> arguments) {
3282 for (FormatElement *argument : arguments) {
3283 if (!isa<AttrDictDirective, PropDictDirective, AttributeVariable,
3284 OperandVariable, PropertyVariable, RefDirective, RegionVariable,
3285 SuccessorVariable, StringElement, TypeDirective>(argument)) {
3286 // TODO: FormatElement should have location info attached.
3287 return emitError(loc, "only variables and types may be used as "
3288 "parameters to a custom directive");
3290 if (auto *type = dyn_cast<TypeDirective>(argument)) {
3291 if (!isa<OperandVariable, ResultVariable>(type->getArg())) {
3292 return emitError(loc, "type directives within a custom directive may "
3293 "only refer to variables");
3297 return success();
3300 FailureOr<FormatElement *>
3301 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
3302 if (context != TopLevelContext)
3303 return emitError(
3304 loc, "'functional-type' is only valid as a top-level directive");
3306 // Parse the main operand.
3307 FailureOr<FormatElement *> inputs, results;
3308 if (failed(parseToken(FormatToken::l_paren,
3309 "expected '(' before argument list")) ||
3310 failed(inputs = parseTypeDirectiveOperand(loc)) ||
3311 failed(parseToken(FormatToken::comma,
3312 "expected ',' after inputs argument")) ||
3313 failed(results = parseTypeDirectiveOperand(loc)) ||
3314 failed(
3315 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3316 return failure();
3317 return create<FunctionalTypeDirective>(*inputs, *results);
3320 FailureOr<FormatElement *>
3321 OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
3322 if (context == RefDirectiveContext) {
3323 if (!fmt.allOperands)
3324 return emitError(loc, "'ref' of 'operands' is not bound by a prior "
3325 "'operands' directive");
3327 } else if (context == TopLevelContext || context == CustomDirectiveContext) {
3328 if (fmt.allOperands || !seenOperands.empty())
3329 return emitError(loc, "'operands' directive creates overlap in format");
3330 fmt.allOperands = true;
3332 return create<OperandsDirective>();
3335 FailureOr<FormatElement *>
3336 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
3337 if (context == TypeDirectiveContext)
3338 return emitError(loc, "'regions' is only valid as a top-level directive");
3339 if (context == RefDirectiveContext) {
3340 if (!hasAllRegions)
3341 return emitError(loc, "'ref' of 'regions' is not bound by a prior "
3342 "'regions' directive");
3344 // Otherwise, this is a TopLevel directive.
3345 } else {
3346 if (hasAllRegions || !seenRegions.empty())
3347 return emitError(loc, "'regions' directive creates overlap in format");
3348 hasAllRegions = true;
3350 return create<RegionsDirective>();
3353 FailureOr<FormatElement *>
3354 OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
3355 if (context != TypeDirectiveContext)
3356 return emitError(loc, "'results' directive can can only be used as a child "
3357 "to a 'type' directive");
3358 return create<ResultsDirective>();
3361 FailureOr<FormatElement *>
3362 OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
3363 if (context == TypeDirectiveContext)
3364 return emitError(loc,
3365 "'successors' is only valid as a top-level directive");
3366 if (context == RefDirectiveContext) {
3367 if (!hasAllSuccessors)
3368 return emitError(loc, "'ref' of 'successors' is not bound by a prior "
3369 "'successors' directive");
3371 // Otherwise, this is a TopLevel directive.
3372 } else {
3373 if (hasAllSuccessors || !seenSuccessors.empty())
3374 return emitError(loc, "'successors' directive creates overlap in format");
3375 hasAllSuccessors = true;
3377 return create<SuccessorsDirective>();
3380 FailureOr<FormatElement *>
3381 OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
3382 if (failed(parseToken(FormatToken::l_paren,
3383 "expected '(' before oilist argument list")))
3384 return failure();
3385 std::vector<FormatElement *> literalElements;
3386 std::vector<std::vector<FormatElement *>> parsingElements;
3387 do {
3388 FailureOr<FormatElement *> lelement = parseLiteral(context);
3389 if (failed(lelement))
3390 return failure();
3391 literalElements.push_back(*lelement);
3392 parsingElements.emplace_back();
3393 std::vector<FormatElement *> &currParsingElements = parsingElements.back();
3394 while (peekToken().getKind() != FormatToken::pipe &&
3395 peekToken().getKind() != FormatToken::r_paren) {
3396 FailureOr<FormatElement *> pelement = parseElement(context);
3397 if (failed(pelement) ||
3398 failed(verifyOIListParsingElement(*pelement, loc)))
3399 return failure();
3400 currParsingElements.push_back(*pelement);
3402 if (peekToken().getKind() == FormatToken::pipe) {
3403 consumeToken();
3404 continue;
3406 if (peekToken().getKind() == FormatToken::r_paren) {
3407 consumeToken();
3408 break;
3410 } while (true);
3412 return create<OIListElement>(std::move(literalElements),
3413 std::move(parsingElements));
3416 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
3417 SMLoc loc) {
3418 SmallVector<VariableElement *> vars;
3419 collect(element, vars);
3420 for (VariableElement *elem : vars) {
3421 LogicalResult res =
3422 TypeSwitch<FormatElement *, LogicalResult>(elem)
3423 // Only optional attributes can be within an oilist parsing group.
3424 .Case([&](AttributeVariable *attrEle) {
3425 if (!attrEle->getVar()->attr.isOptional() &&
3426 !attrEle->getVar()->attr.hasDefaultValue())
3427 return emitError(loc, "only optional attributes can be used in "
3428 "an oilist parsing group");
3429 return success();
3431 // Only optional-like(i.e. variadic) operands can be within an
3432 // oilist parsing group.
3433 .Case([&](OperandVariable *ele) {
3434 if (!ele->getVar()->isVariableLength())
3435 return emitError(loc, "only variable length operands can be "
3436 "used within an oilist parsing group");
3437 return success();
3439 // Only optional-like(i.e. variadic) results can be within an oilist
3440 // parsing group.
3441 .Case([&](ResultVariable *ele) {
3442 if (!ele->getVar()->isVariableLength())
3443 return emitError(loc, "only variable length results can be "
3444 "used within an oilist parsing group");
3445 return success();
3447 .Case([&](RegionVariable *) { return success(); })
3448 .Default([&](FormatElement *) {
3449 return emitError(loc,
3450 "only literals, types, and variables can be "
3451 "used within an oilist group");
3453 if (failed(res))
3454 return failure();
3456 return success();
3459 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
3460 Context context) {
3461 if (context == TypeDirectiveContext)
3462 return emitError(loc, "'type' cannot be used as a child of another `type`");
3464 bool isRefChild = context == RefDirectiveContext;
3465 FailureOr<FormatElement *> operand;
3466 if (failed(parseToken(FormatToken::l_paren,
3467 "expected '(' before argument list")) ||
3468 failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
3469 failed(
3470 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3471 return failure();
3473 return create<TypeDirective>(*operand);
3476 LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
3477 return TypeSwitch<FormatElement *, LogicalResult>(element)
3478 .Case<AttributeVariable, TypeDirective>([](auto *element) {
3479 element->setShouldBeQualified();
3480 return success();
3482 .Default([&](auto *element) {
3483 return this->emitError(
3484 loc,
3485 "'qualified' directive expects an attribute or a `type` directive");
3489 FailureOr<FormatElement *>
3490 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
3491 FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext);
3492 if (failed(result))
3493 return failure();
3495 FormatElement *element = *result;
3496 if (isa<LiteralElement>(element))
3497 return emitError(
3498 loc, "'type' directive operand expects variable or directive operand");
3500 if (auto *var = dyn_cast<OperandVariable>(element)) {
3501 unsigned opIdx = var->getVar() - op.operand_begin();
3502 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3503 return emitError(loc, "'type' of '" + var->getVar()->name +
3504 "' is already bound");
3505 if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3506 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3507 ")' is not bound by a prior 'type' directive");
3508 seenOperandTypes.set(opIdx);
3509 } else if (auto *var = dyn_cast<ResultVariable>(element)) {
3510 unsigned resIdx = var->getVar() - op.result_begin();
3511 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3512 return emitError(loc, "'type' of '" + var->getVar()->name +
3513 "' is already bound");
3514 if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3515 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3516 ")' is not bound by a prior 'type' directive");
3517 seenResultTypes.set(resIdx);
3518 } else if (isa<OperandsDirective>(&*element)) {
3519 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3520 return emitError(loc, "'operands' 'type' is already bound");
3521 if (isRefChild && !fmt.allOperandTypes)
3522 return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
3523 "'type' directive");
3524 fmt.allOperandTypes = true;
3525 } else if (isa<ResultsDirective>(&*element)) {
3526 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3527 return emitError(loc, "'results' 'type' is already bound");
3528 if (isRefChild && !fmt.allResultTypes)
3529 return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
3530 "'type' directive");
3531 fmt.allResultTypes = true;
3532 } else {
3533 return emitError(loc, "invalid argument to 'type' directive");
3535 return element;
3538 LogicalResult OpFormatParser::verifyOptionalGroupElements(
3539 SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) {
3540 for (FormatElement *element : elements) {
3541 if (failed(verifyOptionalGroupElement(loc, element, element == anchor)))
3542 return failure();
3544 return success();
3547 LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
3548 FormatElement *element,
3549 bool isAnchor) {
3550 return TypeSwitch<FormatElement *, LogicalResult>(element)
3551 // All attributes can be within the optional group, but only optional
3552 // attributes can be the anchor.
3553 .Case([&](AttributeVariable *attrEle) {
3554 Attribute attr = attrEle->getVar()->attr;
3555 if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
3556 return emitError(loc, "only optional or default-valued attributes "
3557 "can be used to anchor an optional group");
3558 return success();
3560 // Only optional-like(i.e. variadic) operands can be within an optional
3561 // group.
3562 .Case([&](OperandVariable *ele) {
3563 if (!ele->getVar()->isVariableLength())
3564 return emitError(loc, "only variable length operands can be used "
3565 "within an optional group");
3566 return success();
3568 // Only optional-like(i.e. variadic) results can be within an optional
3569 // group.
3570 .Case([&](ResultVariable *ele) {
3571 if (!ele->getVar()->isVariableLength())
3572 return emitError(loc, "only variable length results can be used "
3573 "within an optional group");
3574 return success();
3576 .Case([&](RegionVariable *) {
3577 // TODO: When ODS has proper support for marking "optional" regions, add
3578 // a check here.
3579 return success();
3581 .Case([&](TypeDirective *ele) {
3582 return verifyOptionalGroupElement(loc, ele->getArg(),
3583 /*isAnchor=*/false);
3585 .Case([&](FunctionalTypeDirective *ele) {
3586 if (failed(verifyOptionalGroupElement(loc, ele->getInputs(),
3587 /*isAnchor=*/false)))
3588 return failure();
3589 return verifyOptionalGroupElement(loc, ele->getResults(),
3590 /*isAnchor=*/false);
3592 .Case([&](CustomDirective *ele) {
3593 if (!isAnchor)
3594 return success();
3595 // Verify each child as being valid in an optional group. They are all
3596 // potential anchors if the custom directive was marked as one.
3597 for (FormatElement *child : ele->getArguments()) {
3598 if (isa<RefDirective>(child))
3599 continue;
3600 if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))
3601 return failure();
3603 return success();
3605 // Literals, whitespace, and custom directives may be used, but they can't
3606 // anchor the group.
3607 .Case<LiteralElement, WhitespaceElement, OptionalElement>(
3608 [&](FormatElement *) {
3609 if (isAnchor)
3610 return emitError(loc, "only variables and types can be used "
3611 "to anchor an optional group");
3612 return success();
3614 .Default([&](FormatElement *) {
3615 return emitError(loc, "only literals, types, and variables can be "
3616 "used within an optional group");
3620 //===----------------------------------------------------------------------===//
3621 // Interface
3622 //===----------------------------------------------------------------------===//
3624 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3625 // TODO: Operator doesn't expose all necessary functionality via
3626 // the const interface.
3627 Operator &op = const_cast<Operator &>(constOp);
3628 if (!op.hasAssemblyFormat())
3629 return;
3631 // Parse the format description.
3632 llvm::SourceMgr mgr;
3633 mgr.AddNewSourceBuffer(
3634 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc());
3635 OperationFormat format(op);
3636 OpFormatParser parser(mgr, format, op);
3637 FailureOr<std::vector<FormatElement *>> elements = parser.parse();
3638 if (failed(elements)) {
3639 // Exit the process if format errors are treated as fatal.
3640 if (formatErrorIsFatal) {
3641 // Invoke the interrupt handlers to run the file cleanup handlers.
3642 llvm::sys::RunInterruptHandlers();
3643 std::exit(1);
3645 return;
3647 format.elements = std::move(*elements);
3649 // Generate the printer and parser based on the parsed format.
3650 format.genParser(op, opClass);
3651 format.genPrinter(op, opClass);