[NFC] Add libcxx python reformat SHA to .git-blame-ignore-revs
[llvm-project.git] / mlir / tools / mlir-tblgen / OpFormatGen.cpp
blobe0472926078d4d80a27a14318af6e965877746cc
1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
11 #include "OpClass.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
29 using namespace mlir;
30 using namespace mlir::tblgen;
32 //===----------------------------------------------------------------------===//
33 // VariableElement
35 namespace {
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT, VariableElement::Kind VariableKind>
40 class OpVariableElement : public VariableElementBase<VariableKind> {
41 public:
42 using Base = OpVariableElement<VarT, VariableKind>;
44 /// Create an op variable element with the variable value.
45 OpVariableElement(const VarT *var) : var(var) {}
47 /// Get the variable.
48 const VarT *getVar() { return var; }
50 protected:
51 /// The op variable, e.g. a type or attribute constraint.
52 const VarT *var;
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57 : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
58 using Base::Base;
60 /// Return the constant builder call for the type of this attribute, or
61 /// std::nullopt if it doesn't have one.
62 std::optional<StringRef> getTypeBuilder() const {
63 std::optional<Type> attrType = var->attr.getValueType();
64 return attrType ? attrType->getBuilderCall() : std::nullopt;
67 /// Return if this attribute refers to a UnitAttr.
68 bool isUnitAttr() const {
69 return var->attr.getBaseAttr().getAttrDefName() == "UnitAttr";
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
74 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
75 void setShouldBeQualified(bool qualified = true) {
76 shouldBeQualifiedFlag = qualified;
79 private:
80 bool shouldBeQualifiedFlag = false;
/// This class represents a variable that refers to an operand argument. The
/// wrapped variable is the operand's `NamedTypeConstraint`.
using OperandVariable =
    OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;

/// This class represents a variable that refers to a result. The wrapped
/// variable is the result's `NamedTypeConstraint`.
using ResultVariable =
    OpVariableElement<NamedTypeConstraint, VariableElement::Result>;

/// This class represents a variable that refers to a region. The wrapped
/// variable is the region's `NamedRegion`.
using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;

/// This class represents a variable that refers to a successor. The wrapped
/// variable is the successor's `NamedSuccessor`.
using SuccessorVariable =
    OpVariableElement<NamedSuccessor, VariableElement::Successor>;
97 } // namespace
99 //===----------------------------------------------------------------------===//
100 // DirectiveElement
102 namespace {
/// This class represents the `operands` directive. This directive represents
/// all of the operands of an operation. It carries no state beyond its
/// directive kind.
using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;

/// This class represents the `results` directive. This directive represents
/// all of the results of an operation. It carries no state beyond its
/// directive kind.
using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;

/// This class represents the `regions` directive. This directive represents
/// all of the regions of an operation. It carries no state beyond its
/// directive kind.
using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;

/// This class represents the `successors` directive. This directive represents
/// all of the successors of an operation. It carries no state beyond its
/// directive kind.
using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
119 /// This class represents the `attr-dict` directive. This directive represents
120 /// the attribute dictionary of the operation.
121 class AttrDictDirective
122 : public DirectiveElementBase<DirectiveElement::AttrDict> {
123 public:
124 explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
126 /// Return whether the dictionary should be printed with the 'attributes'
127 /// keyword.
128 bool isWithKeyword() const { return withKeyword; }
130 private:
131 /// If the dictionary should be printed with the 'attributes' keyword.
132 bool withKeyword;
/// This class represents the `prop-dict` directive. This directive represents
/// the properties of the operation, expressed as a dictionary. It carries no
/// state of its own.
class PropDictDirective
    : public DirectiveElementBase<DirectiveElement::PropDict> {
public:
  explicit PropDictDirective() = default;
};
143 /// This class represents the `functional-type` directive. This directive takes
144 /// two arguments and formats them, respectively, as the inputs and results of a
145 /// FunctionType.
146 class FunctionalTypeDirective
147 : public DirectiveElementBase<DirectiveElement::FunctionalType> {
148 public:
149 FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
150 : inputs(inputs), results(results) {}
152 FormatElement *getInputs() const { return inputs; }
153 FormatElement *getResults() const { return results; }
155 private:
156 /// The input and result arguments.
157 FormatElement *inputs, *results;
160 /// This class represents the `type` directive.
161 class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
162 public:
163 TypeDirective(FormatElement *arg) : arg(arg) {}
165 FormatElement *getArg() const { return arg; }
167 /// Indicate if this type is printed "qualified" (that is it is
168 /// prefixed with the `!dialect.mnemonic`).
169 bool shouldBeQualified() { return shouldBeQualifiedFlag; }
170 void setShouldBeQualified(bool qualified = true) {
171 shouldBeQualifiedFlag = qualified;
174 private:
175 /// The argument that is used to format the directive.
176 FormatElement *arg;
178 bool shouldBeQualifiedFlag = false;
181 /// This class represents a group of order-independent optional clauses. Each
182 /// clause starts with a literal element and has a coressponding parsing
183 /// element. A parsing element is a continous sequence of format elements.
184 /// Each clause can appear 0 or 1 time.
185 class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
186 public:
187 OIListElement(std::vector<FormatElement *> &&literalElements,
188 std::vector<std::vector<FormatElement *>> &&parsingElements)
189 : literalElements(std::move(literalElements)),
190 parsingElements(std::move(parsingElements)) {}
192 /// Returns a range to iterate over the LiteralElements.
193 auto getLiteralElements() const {
194 function_ref<LiteralElement *(FormatElement * el)>
195 literalElementCastConverter =
196 [](FormatElement *el) { return cast<LiteralElement>(el); };
197 return llvm::map_range(literalElements, literalElementCastConverter);
200 /// Returns a range to iterate over the parsing elements corresponding to the
201 /// clauses.
202 ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
203 return parsingElements;
206 /// Returns a range to iterate over tuples of parsing and literal elements.
207 auto getClauses() const {
208 return llvm::zip(getLiteralElements(), getParsingElements());
211 /// If the parsing element is a single UnitAttr element, then it returns the
212 /// attribute variable. Otherwise, returns nullptr.
213 AttributeVariable *
214 getUnitAttrParsingElement(ArrayRef<FormatElement *> pelement) {
215 if (pelement.size() == 1) {
216 auto *attrElem = dyn_cast<AttributeVariable>(pelement[0]);
217 if (attrElem && attrElem->isUnitAttr())
218 return attrElem;
220 return nullptr;
223 private:
224 /// A vector of `LiteralElement` objects. Each element stores the keyword
225 /// for one case of oilist element. For example, an oilist element along with
226 /// the `literalElements` vector:
227 /// ```
228 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
229 /// literalElements = { `keyword`, `otherKeyword` }
230 /// ```
231 std::vector<FormatElement *> literalElements;
233 /// A vector of valid declarative assembly format vectors. Each object in
234 /// parsing elements is a vector of elements in assembly format syntax.
235 /// For example, an oilist element along with the parsingElements vector:
236 /// ```
237 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
238 /// parsingElements = {
239 /// { `=`, `(`, $arg0, `)` },
240 /// { `<`, $arg1, `>` }
241 /// }
242 /// ```
243 std::vector<std::vector<FormatElement *>> parsingElements;
245 } // namespace
247 //===----------------------------------------------------------------------===//
248 // OperationFormat
249 //===----------------------------------------------------------------------===//
251 namespace {
/// A pointer to either a named attribute or a named type constraint of the
/// operation. It is the storage type for what an operand/result type is
/// resolved from in `OperationFormat::TypeResolution`.
using ConstArgument =
    llvm::PointerUnion<const NamedAttribute *, const NamedTypeConstraint *>;
256 struct OperationFormat {
257 /// This class represents a specific resolver for an operand or result type.
258 class TypeResolution {
259 public:
260 TypeResolution() = default;
262 /// Get the index into the buildable types for this type, or std::nullopt.
263 std::optional<int> getBuilderIdx() const { return builderIdx; }
264 void setBuilderIdx(int idx) { builderIdx = idx; }
266 /// Get the variable this type is resolved to, or nullptr.
267 const NamedTypeConstraint *getVariable() const {
268 return resolver.dyn_cast<const NamedTypeConstraint *>();
270 /// Get the attribute this type is resolved to, or nullptr.
271 const NamedAttribute *getAttribute() const {
272 return resolver.dyn_cast<const NamedAttribute *>();
274 /// Get the transformer for the type of the variable, or std::nullopt.
275 std::optional<StringRef> getVarTransformer() const {
276 return variableTransformer;
278 void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
279 resolver = arg;
280 variableTransformer = transformer;
281 assert(getVariable() || getAttribute());
284 private:
285 /// If the type is resolved with a buildable type, this is the index into
286 /// 'buildableTypes' in the parent format.
287 std::optional<int> builderIdx;
288 /// If the type is resolved based upon another operand or result, this is
289 /// the variable or the attribute that this type is resolved to.
290 ConstArgument resolver;
291 /// If the type is resolved based upon another operand or result, this is
292 /// a transformer to apply to the variable when resolving.
293 std::optional<StringRef> variableTransformer;
296 /// The context in which an element is generated.
297 enum class GenContext {
298 /// The element is generated at the top-level or with the same behaviour.
299 Normal,
300 /// The element is generated inside an optional group.
301 Optional
304 OperationFormat(const Operator &op)
305 : useProperties(op.getDialect().usePropertiesForAttributes() &&
306 !op.getAttributes().empty()),
307 opCppClassName(op.getCppClassName()) {
308 operandTypes.resize(op.getNumOperands(), TypeResolution());
309 resultTypes.resize(op.getNumResults(), TypeResolution());
311 hasImplicitTermTrait = llvm::any_of(op.getTraits(), [](const Trait &trait) {
312 return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
315 hasSingleBlockTrait =
316 hasImplicitTermTrait || op.getTrait("::mlir::OpTrait::SingleBlock");
319 /// Generate the operation parser from this format.
320 void genParser(Operator &op, OpClass &opClass);
321 /// Generate the parser code for a specific format element.
322 void genElementParser(FormatElement *element, MethodBody &body,
323 FmtContext &attrTypeCtx,
324 GenContext genCtx = GenContext::Normal);
325 /// Generate the C++ to resolve the types of operands and results during
326 /// parsing.
327 void genParserTypeResolution(Operator &op, MethodBody &body);
328 /// Generate the C++ to resolve the types of the operands during parsing.
329 void genParserOperandTypeResolution(
330 Operator &op, MethodBody &body,
331 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver);
332 /// Generate the C++ to resolve regions during parsing.
333 void genParserRegionResolution(Operator &op, MethodBody &body);
334 /// Generate the C++ to resolve successors during parsing.
335 void genParserSuccessorResolution(Operator &op, MethodBody &body);
336 /// Generate the C++ to handling variadic segment size traits.
337 void genParserVariadicSegmentResolution(Operator &op, MethodBody &body);
339 /// Generate the operation printer from this format.
340 void genPrinter(Operator &op, OpClass &opClass);
342 /// Generate the printer code for a specific format element.
343 void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
344 bool &shouldEmitSpace, bool &lastWasPunctuation);
346 /// The various elements in this format.
347 std::vector<FormatElement *> elements;
349 /// A flag indicating if all operand/result types were seen. If the format
350 /// contains these, it can not contain individual type resolvers.
351 bool allOperands = false, allOperandTypes = false, allResultTypes = false;
353 /// A flag indicating if this operation infers its result types
354 bool infersResultTypes = false;
356 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
357 /// trait.
358 bool hasImplicitTermTrait;
360 /// A flag indicating if this operation has the SingleBlock trait.
361 bool hasSingleBlockTrait;
363 /// Indicate whether attribute are stored in properties.
364 bool useProperties;
366 /// The Operation class name
367 StringRef opCppClassName;
369 /// A map of buildable types to indices.
370 llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
372 /// The index of the buildable type, if valid, for every operand and result.
373 std::vector<TypeResolution> operandTypes, resultTypes;
375 /// The set of attributes explicitly used within the format.
376 SmallVector<const NamedAttribute *, 8> usedAttributes;
377 llvm::StringSet<> inferredAttributes;
379 } // namespace
381 //===----------------------------------------------------------------------===//
382 // Parser Gen
384 /// Returns true if we can format the given attribute as an EnumAttr in the
385 /// parser format.
386 static bool canFormatEnumAttr(const NamedAttribute *attr) {
387 Attribute baseAttr = attr->attr.getBaseAttr();
388 const EnumAttr *enumAttr = dyn_cast<EnumAttr>(&baseAttr);
389 if (!enumAttr)
390 return false;
392 // The attribute must have a valid underlying type and a constant builder.
393 return !enumAttr->getUnderlyingType().empty() &&
394 !enumAttr->getConstBuilderTemplate().empty();
397 /// Returns if we should format the given attribute as an SymbolNameAttr.
398 static bool shouldFormatSymbolNameAttr(const NamedAttribute *attr) {
399 return attr->attr.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
/// The code snippet used to generate a parser call for an attribute, using
/// `parseCustomAttributeWithFallback`. The generated code returns failure when
/// the parse fails. (`{{` is the formatv escape for a literal `{`.)
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
    return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for an attribute, using
/// the generic `parseAttribute`. The generated code returns failure when the
/// parse fails.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional attribute.
/// The generated code returns failure when the attribute is present but fails
/// to parse. The snippet ends with an `if` that has no body, so the statement
/// emitted right after it runs only when the attribute was present and parsed
/// successfully.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  ::mlir::OptionalParseResult parseResult{0}Attr =
    parser.parseOptionalAttribute({0}Attr, {1});
  if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
    return ::mlir::failure();
  if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";
/// The code snippet used to generate a parser call for a symbol name attribute.
/// The generated code returns failure when the symbol name fails to parse.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional symbol name
/// attribute. The result of the parse call is intentionally discarded.
///
/// {0}: The name of the attribute.
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr);
)";
/// The code snippet used to generate a parser call for an enum attribute. The
/// generated code first tries to parse one of the allowed keywords and, if
/// that does not succeed, tries to parse a string attribute instead. A
/// non-empty spelling is then symbolized; an unknown spelling produces an
/// "invalid ... attribute specification" error.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.has_value()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';;

      {0}Attr = {3};
      {6}
    }
  }
)";
/// The code snippet used to generate a parser call for a variadic operand:
/// records the location and parses an operand list.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional operand:
/// the operand is appended to the operand list only when one is present.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
/// The code snippet used to generate a parser call for a single operand,
/// parsed into the first slot of the raw operand storage.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand. Each parenthesized, comma-separated group is parsed into the flat
/// operand list and the size of the group is recorded.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute. (Not referenced by the snippet
///      text itself.)
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";
/// The code snippet used to generate a parser call for a type list made of
/// parenthesized, comma-separated groups. An empty group `()` is accepted.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
/// The code snippet used to generate a parser call for a plain type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional type: the
/// type is appended to the type list only when one is present.
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// The code snippet used to generate a parser call for a single type, using
/// `parseCustomTypeWithFallback`.
///
/// {0}: The C++ type of the local variable the type is parsed into.
/// {1}: The name for the type list.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawTypes[0] = type;
  }
)";
/// The code snippet used to generate a parser call for a single type, using
/// the generic `parseType`. Only `{1}` is referenced, so the snippet takes the
/// same arguments as `typeParserCode`.
///
/// {1}: The name for the type list.
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawTypes[0]))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a functional type. A
/// single `FunctionType` is parsed and its inputs and results are assigned to
/// the two type lists.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";
/// The code snippet used to generate a parser call to infer return types. The
/// generated code calls the op's static `inferReturnTypes` with the state
/// accumulated in `result` and adds the inferred types to `result`, or returns
/// failure if inference fails.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
// NOTE: the snippets below are declared `const char *const`, like the other
// parser snippets in this file, so the pointers are immutable and have
// internal linkage.

/// The code snippet used to generate a parser call for a region list. The
/// first region is optional; every region after a comma is required.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
/// Only a present-but-malformed region is reported as a failure.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
     auto parseResult = parser.parseOptionalRegion(*{0}Region);
     if (parseResult.has_value() && failed(*parseResult))
       return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";
// NOTE: declared `const char *const`, like the other parser snippets in this
// file, so the pointers are immutable and have internal linkage.

/// The code snippet used to generate a parser call for a successor list. The
/// first successor is optional; every successor after a comma is required.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.has_value()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser for OIList. The generated code
/// reports an error if the clause was already seen, and otherwise marks it as
/// seen. Declared `const char *const`, like the other parser snippets in this
/// file, so the pointer is immutable and has internal linkage.
///
/// {0}: literal keyword corresponding to a case for oilist
const char *const oilistParserCode = R"(
  if ({0}Clause) {
    return parser.emitError(parser.getNameLoc())
          << "`{0}` clause can appear at most once in the expansion of the "
             "oilist directive";
  }
  {0}Clause = true;
)";
709 namespace {
/// The type of length for a given parse argument, i.e. how many elements the
/// argument may stand for. `getArgumentLengthKind` computes it for a named
/// type constraint.
enum class ArgumentLengthKind {
  /// The argument is a variadic of a variadic, and may contain 0->N range
  /// elements.
  VariadicOfVariadic,
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
722 } // namespace
724 /// Get the length kind for the given constraint.
725 static ArgumentLengthKind
726 getArgumentLengthKind(const NamedTypeConstraint *var) {
727 if (var->isOptional())
728 return ArgumentLengthKind::Optional;
729 if (var->isVariadicOfVariadic())
730 return ArgumentLengthKind::VariadicOfVariadic;
731 if (var->isVariadic())
732 return ArgumentLengthKind::Variadic;
733 return ArgumentLengthKind::Single;
736 /// Get the name used for the type list for the given type directive operand.
737 /// 'lengthKind' to the corresponding kind for the given argument.
738 static StringRef getTypeListName(FormatElement *arg,
739 ArgumentLengthKind &lengthKind) {
740 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
741 lengthKind = getArgumentLengthKind(operand->getVar());
742 return operand->getVar()->name;
744 if (auto *result = dyn_cast<ResultVariable>(arg)) {
745 lengthKind = getArgumentLengthKind(result->getVar());
746 return result->getVar()->name;
748 lengthKind = ArgumentLengthKind::Variadic;
749 if (isa<OperandsDirective>(arg))
750 return "allOperand";
751 if (isa<ResultsDirective>(arg))
752 return "allResult";
753 llvm_unreachable("unknown 'type' directive argument");
756 /// Generate the parser for a literal value.
757 static void genLiteralParser(StringRef value, MethodBody &body) {
758 // Handle the case of a keyword/identifier.
759 if (value.front() == '_' || isalpha(value.front())) {
760 body << "Keyword(\"" << value << "\")";
761 return;
763 body << (StringRef)StringSwitch<StringRef>(value)
764 .Case("->", "Arrow()")
765 .Case(":", "Colon()")
766 .Case(",", "Comma()")
767 .Case("=", "Equal()")
768 .Case("<", "Less()")
769 .Case(">", "Greater()")
770 .Case("{", "LBrace()")
771 .Case("}", "RBrace()")
772 .Case("(", "LParen()")
773 .Case(")", "RParen()")
774 .Case("[", "LSquare()")
775 .Case("]", "RSquare()")
776 .Case("?", "Question()")
777 .Case("+", "Plus()")
778 .Case("*", "Star()")
779 .Case("...", "Ellipsis()");
782 /// Generate the storage code required for parsing the given element.
783 static void genElementParserStorage(FormatElement *element, const Operator &op,
784 MethodBody &body) {
785 if (auto *optional = dyn_cast<OptionalElement>(element)) {
786 ArrayRef<FormatElement *> elements = optional->getThenElements();
788 // If the anchor is a unit attribute, it won't be parsed directly so elide
789 // it.
790 auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
791 FormatElement *elidedAnchorElement = nullptr;
792 if (anchor && anchor != elements.front() && anchor->isUnitAttr())
793 elidedAnchorElement = anchor;
794 for (FormatElement *childElement : elements)
795 if (childElement != elidedAnchorElement)
796 genElementParserStorage(childElement, op, body);
797 for (FormatElement *childElement : optional->getElseElements())
798 genElementParserStorage(childElement, op, body);
800 } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
801 for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements()) {
802 if (!oilist->getUnitAttrParsingElement(pelement))
803 for (FormatElement *element : pelement)
804 genElementParserStorage(element, op, body);
807 } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
808 for (FormatElement *paramElement : custom->getArguments())
809 genElementParserStorage(paramElement, op, body);
811 } else if (isa<OperandsDirective>(element)) {
812 body << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
813 "allOperands;\n";
815 } else if (isa<RegionsDirective>(element)) {
816 body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
817 "fullRegions;\n";
819 } else if (isa<SuccessorsDirective>(element)) {
820 body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
822 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
823 const NamedAttribute *var = attr->getVar();
824 body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
825 var->name);
827 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
828 StringRef name = operand->getVar()->name;
829 if (operand->getVar()->isVariableLength()) {
830 body
831 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
832 << name << "Operands;\n";
833 if (operand->getVar()->isVariadicOfVariadic()) {
834 body << " llvm::SmallVector<int32_t> " << name
835 << "OperandGroupSizes;\n";
837 } else {
838 body << " ::mlir::OpAsmParser::UnresolvedOperand " << name
839 << "RawOperands[1];\n"
840 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
841 << name << "Operands(" << name << "RawOperands);";
843 body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
844 " (void){0}OperandsLoc;\n",
845 name);
847 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
848 StringRef name = region->getVar()->name;
849 if (region->getVar()->isVariadic()) {
850 body << llvm::formatv(
851 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
852 "{0}Regions;\n",
853 name);
854 } else {
855 body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
856 "std::make_unique<::mlir::Region>();\n",
857 name);
860 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
861 StringRef name = successor->getVar()->name;
862 if (successor->getVar()->isVariadic()) {
863 body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
864 "{0}Successors;\n",
865 name);
866 } else {
867 body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
870 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
871 ArgumentLengthKind lengthKind;
872 StringRef name = getTypeListName(dir->getArg(), lengthKind);
873 if (lengthKind != ArgumentLengthKind::Single)
874 body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
875 else
876 body << llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name)
877 << llvm::formatv(
878 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
879 name);
880 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
881 ArgumentLengthKind ignored;
882 body << " ::llvm::ArrayRef<::mlir::Type> "
883 << getTypeListName(dir->getInputs(), ignored) << "Types;\n";
884 body << " ::llvm::ArrayRef<::mlir::Type> "
885 << getTypeListName(dir->getResults(), ignored) << "Types;\n";
889 /// Generate the parser for a parameter to a custom directive.
890 static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
891 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
892 body << attr->getVar()->name << "Attr";
893 } else if (isa<AttrDictDirective>(param)) {
894 body << "result.attributes";
895 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
896 StringRef name = operand->getVar()->name;
897 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
898 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
899 body << llvm::formatv("{0}OperandGroups", name);
900 else if (lengthKind == ArgumentLengthKind::Variadic)
901 body << llvm::formatv("{0}Operands", name);
902 else if (lengthKind == ArgumentLengthKind::Optional)
903 body << llvm::formatv("{0}Operand", name);
904 else
905 body << formatv("{0}RawOperands[0]", name);
907 } else if (auto *region = dyn_cast<RegionVariable>(param)) {
908 StringRef name = region->getVar()->name;
909 if (region->getVar()->isVariadic())
910 body << llvm::formatv("{0}Regions", name);
911 else
912 body << llvm::formatv("*{0}Region", name);
914 } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
915 StringRef name = successor->getVar()->name;
916 if (successor->getVar()->isVariadic())
917 body << llvm::formatv("{0}Successors", name);
918 else
919 body << llvm::formatv("{0}Successor", name);
921 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
922 genCustomParameterParser(dir->getArg(), body);
924 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
925 ArgumentLengthKind lengthKind;
926 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
927 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
928 body << llvm::formatv("{0}TypeGroups", listName);
929 else if (lengthKind == ArgumentLengthKind::Variadic)
930 body << llvm::formatv("{0}Types", listName);
931 else if (lengthKind == ArgumentLengthKind::Optional)
932 body << llvm::formatv("{0}Type", listName);
933 else
934 body << formatv("{0}RawTypes[0]", listName);
936 } else if (auto *string = dyn_cast<StringElement>(param)) {
937 FmtContext ctx;
938 ctx.withBuilder("parser.getBuilder()");
939 ctx.addSubst("_ctxt", "parser.getContext()");
940 body << tgfmt(string->getValue(), &ctx);
942 } else {
943 llvm_unreachable("unknown custom directive parameter");
947 /// Generate the parser for a custom directive.
948 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
949 bool useProperties,
950 StringRef opCppClassName) {
951 body << " {\n";
953 // Preprocess the directive variables.
954 // * Add a local variable for optional operands and types. This provides a
955 // better API to the user defined parser methods.
956 // * Set the location of operand variables.
957 for (FormatElement *param : dir->getArguments()) {
958 if (auto *operand = dyn_cast<OperandVariable>(param)) {
959 auto *var = operand->getVar();
960 body << " " << var->name
961 << "OperandsLoc = parser.getCurrentLocation();\n";
962 if (var->isOptional()) {
963 body << llvm::formatv(
964 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
965 "{0}Operand;\n",
966 var->name);
967 } else if (var->isVariadicOfVariadic()) {
968 body << llvm::formatv(" "
969 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
970 "OpAsmParser::UnresolvedOperand>> "
971 "{0}OperandGroups;\n",
972 var->name);
974 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
975 ArgumentLengthKind lengthKind;
976 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
977 if (lengthKind == ArgumentLengthKind::Optional) {
978 body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
979 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
980 body << llvm::formatv(
981 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
982 "{0}TypeGroups;\n",
983 listName);
985 } else if (auto *dir = dyn_cast<RefDirective>(param)) {
986 FormatElement *input = dir->getArg();
987 if (auto *operand = dyn_cast<OperandVariable>(input)) {
988 if (!operand->getVar()->isOptional())
989 continue;
990 body << llvm::formatv(
991 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
992 "{1}Operands[0];\n",
993 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
994 operand->getVar()->name);
996 } else if (auto *type = dyn_cast<TypeDirective>(input)) {
997 ArgumentLengthKind lengthKind;
998 StringRef listName = getTypeListName(type->getArg(), lengthKind);
999 if (lengthKind == ArgumentLengthKind::Optional) {
1000 body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
1001 "::mlir::Type() : {0}Types[0];\n",
1002 listName);
1008 body << " if (parse" << dir->getName() << "(parser";
1009 for (FormatElement *param : dir->getArguments()) {
1010 body << ", ";
1011 genCustomParameterParser(param, body);
1014 body << "))\n"
1015 << " return ::mlir::failure();\n";
1017 // After parsing, add handling for any of the optional constructs.
1018 for (FormatElement *param : dir->getArguments()) {
1019 if (auto *attr = dyn_cast<AttributeVariable>(param)) {
1020 const NamedAttribute *var = attr->getVar();
1021 if (var->attr.isOptional() || var->attr.hasDefaultValue())
1022 body << llvm::formatv(" if ({0}Attr)\n ", var->name);
1023 if (useProperties) {
1024 body << formatv(
1025 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1026 var->name, opCppClassName);
1027 } else {
1028 body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1029 var->name);
1032 } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
1033 const NamedTypeConstraint *var = operand->getVar();
1034 if (var->isOptional()) {
1035 body << llvm::formatv(" if ({0}Operand.has_value())\n"
1036 " {0}Operands.push_back(*{0}Operand);\n",
1037 var->name);
1038 } else if (var->isVariadicOfVariadic()) {
1039 body << llvm::formatv(
1040 " for (const auto &subRange : {0}OperandGroups) {{\n"
1041 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1042 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1043 " }\n",
1044 var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
1046 } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
1047 ArgumentLengthKind lengthKind;
1048 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1049 if (lengthKind == ArgumentLengthKind::Optional) {
1050 body << llvm::formatv(" if ({0}Type)\n"
1051 " {0}Types.push_back({0}Type);\n",
1052 listName);
1053 } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1054 body << llvm::formatv(
1055 " for (const auto &subRange : {0}TypeGroups)\n"
1056 " {0}Types.append(subRange.begin(), subRange.end());\n",
1057 listName);
1062 body << " }\n";
1065 /// Generate the parser for a enum attribute.
1066 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
1067 FmtContext &attrTypeCtx, bool parseAsOptional,
1068 bool useProperties, StringRef opCppClassName) {
1069 Attribute baseAttr = var->attr.getBaseAttr();
1070 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1071 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1073 // Generate the code for building an attribute for this enum.
1074 std::string attrBuilderStr;
1076 llvm::raw_string_ostream os(attrBuilderStr);
1077 os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
1078 "*attrOptional");
1081 // Build a string containing the cases that can be formatted as a keyword.
1082 std::string validCaseKeywordsStr = "{";
1083 llvm::raw_string_ostream validCaseKeywordsOS(validCaseKeywordsStr);
1084 for (const EnumAttrCase &attrCase : cases)
1085 if (canFormatStringAsKeyword(attrCase.getStr()))
1086 validCaseKeywordsOS << '"' << attrCase.getStr() << "\",";
1087 validCaseKeywordsOS.str().back() = '}';
1089 // If the attribute is not optional, build an error message for the missing
1090 // attribute.
1091 std::string errorMessage;
1092 if (!parseAsOptional) {
1093 llvm::raw_string_ostream errorMessageOS(errorMessage);
1094 errorMessageOS
1095 << "return parser.emitError(loc, \"expected string or "
1096 "keyword containing one of the following enum values for attribute '"
1097 << var->name << "' [";
1098 llvm::interleaveComma(cases, errorMessageOS, [&](const auto &attrCase) {
1099 errorMessageOS << attrCase.getStr();
1101 errorMessageOS << "]\");";
1103 std::string attrAssignment;
1104 if (useProperties) {
1105 attrAssignment =
1106 formatv(" "
1107 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1108 var->name, opCppClassName);
1109 } else {
1110 attrAssignment =
1111 formatv("result.addAttribute(\"{0}\", {0}Attr);", var->name);
1114 body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
1115 enumAttr.getStringToSymbolFnName(), attrBuilderStr,
1116 validCaseKeywordsStr, errorMessage, attrAssignment);
1119 // Generate the parser for an attribute.
1120 static void genAttrParser(AttributeVariable *attr, MethodBody &body,
1121 FmtContext &attrTypeCtx, bool parseAsOptional,
1122 bool useProperties, StringRef opCppClassName) {
1123 const NamedAttribute *var = attr->getVar();
1125 // Check to see if we can parse this as an enum attribute.
1126 if (canFormatEnumAttr(var))
1127 return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional,
1128 useProperties, opCppClassName);
1130 // Check to see if we should parse this as a symbol name attribute.
1131 if (shouldFormatSymbolNameAttr(var)) {
1132 body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
1133 : symbolNameAttrParserCode,
1134 var->name);
1135 } else {
1137 // If this attribute has a buildable type, use that when parsing the
1138 // attribute.
1139 std::string attrTypeStr;
1140 if (std::optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
1141 llvm::raw_string_ostream os(attrTypeStr);
1142 os << tgfmt(*typeBuilder, &attrTypeCtx);
1143 } else {
1144 attrTypeStr = "::mlir::Type{}";
1146 if (parseAsOptional) {
1147 body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
1148 } else {
1149 if (attr->shouldBeQualified() ||
1150 var->attr.getStorageType() == "::mlir::Attribute")
1151 body << formatv(genericAttrParserCode, var->name, attrTypeStr);
1152 else
1153 body << formatv(attrParserCode, var->name, attrTypeStr);
1156 if (useProperties) {
1157 body << formatv(
1158 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1159 "{0}Attr;\n",
1160 var->name, opCppClassName);
1161 } else {
1162 body << formatv(
1163 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1164 var->name);
1168 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
1169 SmallVector<MethodParameter> paramList;
1170 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
1171 paramList.emplace_back("::mlir::OperationState &", "result");
1173 auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
1174 std::move(paramList));
1175 auto &body = method->body();
1177 // Generate variables to store the operands and type within the format. This
1178 // allows for referencing these variables in the presence of optional
1179 // groupings.
1180 for (FormatElement *element : elements)
1181 genElementParserStorage(element, op, body);
1183 // A format context used when parsing attributes with buildable types.
1184 FmtContext attrTypeCtx;
1185 attrTypeCtx.withBuilder("parser.getBuilder()");
1187 // Generate parsers for each of the elements.
1188 for (FormatElement *element : elements)
1189 genElementParser(element, body, attrTypeCtx);
1191 // Generate the code to resolve the operand/result types and successors now
1192 // that they have been parsed.
1193 genParserRegionResolution(op, body);
1194 genParserSuccessorResolution(op, body);
1195 genParserVariadicSegmentResolution(op, body);
1196 genParserTypeResolution(op, body);
1198 body << " return ::mlir::success();\n";
1201 void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
1202 FmtContext &attrTypeCtx,
1203 GenContext genCtx) {
1204 /// Optional Group.
1205 if (auto *optional = dyn_cast<OptionalElement>(element)) {
1206 auto genElementParsers = [&](FormatElement *firstElement,
1207 ArrayRef<FormatElement *> elements,
1208 bool thenGroup) {
1209 // If the anchor is a unit attribute, we don't need to print it. When
1210 // parsing, we will add this attribute if this group is present.
1211 FormatElement *elidedAnchorElement = nullptr;
1212 auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
1213 if (anchorAttr && anchorAttr != firstElement &&
1214 anchorAttr->isUnitAttr()) {
1215 elidedAnchorElement = anchorAttr;
1217 if (!thenGroup == optional->isInverted()) {
1218 // Add the anchor unit attribute to the operation state.
1219 if (useProperties) {
1220 body << formatv(
1221 " result.getOrAddProperties<{1}::Properties>().{0} = "
1222 "parser.getBuilder().getUnitAttr();",
1223 anchorAttr->getVar()->name, opCppClassName);
1224 } else {
1225 body << " result.addAttribute(\"" << anchorAttr->getVar()->name
1226 << "\", parser.getBuilder().getUnitAttr());\n";
1231 // Generate the rest of the elements inside an optional group. Elements in
1232 // an optional group after the guard are parsed as required.
1233 for (FormatElement *childElement : elements)
1234 if (childElement != elidedAnchorElement)
1235 genElementParser(childElement, body, attrTypeCtx,
1236 GenContext::Optional);
1239 ArrayRef<FormatElement *> thenElements =
1240 optional->getThenElements(/*parseable=*/true);
1242 // Generate a special optional parser for the first element to gate the
1243 // parsing of the rest of the elements.
1244 FormatElement *firstElement = thenElements.front();
1245 if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
1246 genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true,
1247 useProperties, opCppClassName);
1248 body << " if (" << attrVar->getVar()->name << "Attr) {\n";
1249 } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
1250 body << " if (::mlir::succeeded(parser.parseOptional";
1251 genLiteralParser(literal->getSpelling(), body);
1252 body << ")) {\n";
1253 } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
1254 genElementParser(opVar, body, attrTypeCtx);
1255 body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
1256 } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
1257 const NamedRegion *region = regionVar->getVar();
1258 if (region->isVariadic()) {
1259 genElementParser(regionVar, body, attrTypeCtx);
1260 body << " if (!" << region->name << "Regions.empty()) {\n";
1261 } else {
1262 body << llvm::formatv(optionalRegionParserCode, region->name);
1263 body << " if (!" << region->name << "Region->empty()) {\n ";
1264 if (hasImplicitTermTrait)
1265 body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
1266 else if (hasSingleBlockTrait)
1267 body << llvm::formatv(regionEnsureSingleBlockParserCode,
1268 region->name);
1272 genElementParsers(firstElement, thenElements.drop_front(),
1273 /*thenGroup=*/true);
1274 body << " }";
1276 // Generate the else elements.
1277 auto elseElements = optional->getElseElements();
1278 if (!elseElements.empty()) {
1279 body << " else {\n";
1280 ArrayRef<FormatElement *> elseElements =
1281 optional->getElseElements(/*parseable=*/true);
1282 genElementParsers(elseElements.front(), elseElements,
1283 /*thenGroup=*/false);
1284 body << " }";
1286 body << "\n";
1288 /// OIList Directive
1289 } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) {
1290 for (LiteralElement *le : oilist->getLiteralElements())
1291 body << " bool " << le->getSpelling() << "Clause = false;\n";
1293 // Generate the parsing loop
1294 body << " while(true) {\n";
1295 for (auto clause : oilist->getClauses()) {
1296 LiteralElement *lelement = std::get<0>(clause);
1297 ArrayRef<FormatElement *> pelement = std::get<1>(clause);
1298 body << "if (succeeded(parser.parseOptional";
1299 genLiteralParser(lelement->getSpelling(), body);
1300 body << ")) {\n";
1301 StringRef lelementName = lelement->getSpelling();
1302 body << formatv(oilistParserCode, lelementName);
1303 if (AttributeVariable *unitAttrElem =
1304 oilist->getUnitAttrParsingElement(pelement)) {
1305 if (useProperties) {
1306 body << formatv(
1307 " result.getOrAddProperties<{1}::Properties>().{0} = "
1308 "parser.getBuilder().getUnitAttr();",
1309 unitAttrElem->getVar()->name, opCppClassName);
1310 } else {
1311 body << " result.addAttribute(\"" << unitAttrElem->getVar()->name
1312 << "\", UnitAttr::get(parser.getContext()));\n";
1314 } else {
1315 for (FormatElement *el : pelement)
1316 genElementParser(el, body, attrTypeCtx);
1318 body << " } else ";
1320 body << " {\n";
1321 body << " break;\n";
1322 body << " }\n";
1323 body << "}\n";
1325 /// Literals.
1326 } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
1327 body << " if (parser.parse";
1328 genLiteralParser(literal->getSpelling(), body);
1329 body << ")\n return ::mlir::failure();\n";
1331 /// Whitespaces.
1332 } else if (isa<WhitespaceElement>(element)) {
1333 // Nothing to parse.
1335 /// Arguments.
1336 } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1337 bool parseAsOptional =
1338 (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
1339 genAttrParser(attr, body, attrTypeCtx, parseAsOptional, useProperties,
1340 opCppClassName);
1342 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1343 ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
1344 StringRef name = operand->getVar()->name;
1345 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
1346 body << llvm::formatv(
1347 variadicOfVariadicOperandParserCode, name,
1348 operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
1349 else if (lengthKind == ArgumentLengthKind::Variadic)
1350 body << llvm::formatv(variadicOperandParserCode, name);
1351 else if (lengthKind == ArgumentLengthKind::Optional)
1352 body << llvm::formatv(optionalOperandParserCode, name);
1353 else
1354 body << formatv(operandParserCode, name);
1356 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1357 bool isVariadic = region->getVar()->isVariadic();
1358 body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
1359 region->getVar()->name);
1360 if (hasImplicitTermTrait)
1361 body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
1362 : regionEnsureTerminatorParserCode,
1363 region->getVar()->name);
1364 else if (hasSingleBlockTrait)
1365 body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
1366 : regionEnsureSingleBlockParserCode,
1367 region->getVar()->name);
1369 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1370 bool isVariadic = successor->getVar()->isVariadic();
1371 body << formatv(isVariadic ? successorListParserCode : successorParserCode,
1372 successor->getVar()->name);
1374 /// Directives.
1375 } else if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
1376 body.indent() << "{\n";
1377 body.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1378 << "if (parser.parseOptionalAttrDict"
1379 << (attrDict->isWithKeyword() ? "WithKeyword" : "")
1380 << "(result.attributes))\n"
1381 << " return ::mlir::failure();\n";
1382 if (useProperties) {
1383 body << "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1384 "[&]() {\n"
1385 << " return parser.emitError(loc) << \"'\" << "
1386 "result.name.getStringRef() << \"' op \";\n"
1387 << " })))\n"
1388 << " return ::mlir::failure();\n";
1390 body.unindent() << "}\n";
1391 body.unindent();
1392 } else if (auto *attrDict = dyn_cast<PropDictDirective>(element)) {
1393 body << " if (parseProperties(parser, result))\n"
1394 << " return ::mlir::failure();\n";
1395 } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
1396 genCustomDirectiveParser(customDir, body, useProperties, opCppClassName);
1397 } else if (isa<OperandsDirective>(element)) {
1398 body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1399 << " if (parser.parseOperandList(allOperands))\n"
1400 << " return ::mlir::failure();\n";
1402 } else if (isa<RegionsDirective>(element)) {
1403 body << llvm::formatv(regionListParserCode, "full");
1404 if (hasImplicitTermTrait)
1405 body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
1406 else if (hasSingleBlockTrait)
1407 body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
1409 } else if (isa<SuccessorsDirective>(element)) {
1410 body << llvm::formatv(successorListParserCode, "full");
1412 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1413 ArgumentLengthKind lengthKind;
1414 StringRef listName = getTypeListName(dir->getArg(), lengthKind);
1415 if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
1416 body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
1417 } else if (lengthKind == ArgumentLengthKind::Variadic) {
1418 body << llvm::formatv(variadicTypeParserCode, listName);
1419 } else if (lengthKind == ArgumentLengthKind::Optional) {
1420 body << llvm::formatv(optionalTypeParserCode, listName);
1421 } else {
1422 const char *parserCode =
1423 dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
1424 TypeSwitch<FormatElement *>(dir->getArg())
1425 .Case<OperandVariable, ResultVariable>([&](auto operand) {
1426 body << formatv(parserCode,
1427 operand->getVar()->constraint.getCPPClassName(),
1428 listName);
1430 .Default([&](auto operand) {
1431 body << formatv(parserCode, "::mlir::Type", listName);
1434 } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
1435 ArgumentLengthKind ignored;
1436 body << formatv(functionalTypeParserCode,
1437 getTypeListName(dir->getInputs(), ignored),
1438 getTypeListName(dir->getResults(), ignored));
1439 } else {
1440 llvm_unreachable("unknown format element");
1444 void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
1445 // If any of type resolutions use transformed variables, make sure that the
1446 // types of those variables are resolved.
1447 SmallPtrSet<const NamedTypeConstraint *, 8> verifiedVariables;
1448 FmtContext verifierFCtx;
1449 for (TypeResolution &resolver :
1450 llvm::concat<TypeResolution>(resultTypes, operandTypes)) {
1451 std::optional<StringRef> transformer = resolver.getVarTransformer();
1452 if (!transformer)
1453 continue;
1454 // Ensure that we don't verify the same variables twice.
1455 const NamedTypeConstraint *variable = resolver.getVariable();
1456 if (!variable || !verifiedVariables.insert(variable).second)
1457 continue;
1459 auto constraint = variable->constraint;
1460 body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
1461 << " (void)type;\n"
1462 << " if (!("
1463 << tgfmt(constraint.getConditionTemplate(),
1464 &verifierFCtx.withSelf("type"))
1465 << ")) {\n"
1466 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1467 "\"'{0}' must be {1}, but got \" << type;\n",
1468 variable->name, constraint.getSummary())
1469 << " }\n"
1470 << " }\n";
1473 // Initialize the set of buildable types.
1474 if (!buildableTypes.empty()) {
1475 FmtContext typeBuilderCtx;
1476 typeBuilderCtx.withBuilder("parser.getBuilder()");
1477 for (auto &it : buildableTypes)
1478 body << " ::mlir::Type odsBuildableType" << it.second << " = "
1479 << tgfmt(it.first, &typeBuilderCtx) << ";\n";
1482 // Emit the code necessary for a type resolver.
1483 auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
1484 if (std::optional<int> val = resolver.getBuilderIdx()) {
1485 body << "odsBuildableType" << *val;
1486 } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
1487 if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
1488 FmtContext fmtContext;
1489 fmtContext.addSubst("_ctxt", "parser.getContext()");
1490 if (var->isVariadic())
1491 fmtContext.withSelf(var->name + "Types");
1492 else
1493 fmtContext.withSelf(var->name + "Types[0]");
1494 body << tgfmt(*tform, &fmtContext);
1495 } else {
1496 body << var->name << "Types";
1497 if (!var->isVariadic())
1498 body << "[0]";
1500 } else if (const NamedAttribute *attr = resolver.getAttribute()) {
1501 if (std::optional<StringRef> tform = resolver.getVarTransformer())
1502 body << tgfmt(*tform,
1503 &FmtContext().withSelf(attr->name + "Attr.getType()"));
1504 else
1505 body << attr->name << "Attr.getType()";
1506 } else {
1507 body << curVar << "Types";
1511 // Resolve each of the result types.
1512 if (!infersResultTypes) {
1513 if (allResultTypes) {
1514 body << " result.addTypes(allResultTypes);\n";
1515 } else {
1516 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
1517 body << " result.addTypes(";
1518 emitTypeResolver(resultTypes[i], op.getResultName(i));
1519 body << ");\n";
1524 // Emit the operand type resolutions.
1525 genParserOperandTypeResolution(op, body, emitTypeResolver);
1527 // Handle return type inference once all operands have been resolved
1528 if (infersResultTypes)
1529 body << formatv(inferReturnTypesParserCode, op.getCppClassName());
1532 void OperationFormat::genParserOperandTypeResolution(
1533 Operator &op, MethodBody &body,
1534 function_ref<void(TypeResolution &, StringRef)> emitTypeResolver) {
1535 // Early exit if there are no operands.
1536 if (op.getNumOperands() == 0)
1537 return;
1539 // Handle the case where all operand types are grouped together with
1540 // "types(operands)".
1541 if (allOperandTypes) {
1542 // If `operands` was specified, use the full operand list directly.
1543 if (allOperands) {
1544 body << " if (parser.resolveOperands(allOperands, allOperandTypes, "
1545 "allOperandLoc, result.operands))\n"
1546 " return ::mlir::failure();\n";
1547 return;
1550 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1551 // llvm::concat does not allow the case of a single range, so guard it here.
1552 body << " if (parser.resolveOperands(";
1553 if (op.getNumOperands() > 1) {
1554 body << "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1555 llvm::interleaveComma(op.getOperands(), body, [&](auto &operand) {
1556 body << operand.name << "Operands";
1558 body << ")";
1559 } else {
1560 body << op.operand_begin()->name << "Operands";
1562 body << ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1563 << " return ::mlir::failure();\n";
1564 return;
1567 // Handle the case where all operands are grouped together with "operands".
1568 if (allOperands) {
1569 body << " if (parser.resolveOperands(allOperands, ";
1571 // Group all of the operand types together to perform the resolution all at
1572 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1573 // the case of a single range, so guard it here.
1574 if (op.getNumOperands() > 1) {
1575 body << "::llvm::concat<const ::mlir::Type>(";
1576 llvm::interleaveComma(
1577 llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
1578 body << "::llvm::ArrayRef<::mlir::Type>(";
1579 emitTypeResolver(operandTypes[i], op.getOperand(i).name);
1580 body << ")";
1582 body << ")";
1583 } else {
1584 emitTypeResolver(operandTypes.front(), op.getOperand(0).name);
1587 body << ", allOperandLoc, result.operands))\n return "
1588 "::mlir::failure();\n";
1589 return;
1592 // The final case is the one where each of the operands types are resolved
1593 // separately.
1594 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
1595 NamedTypeConstraint &operand = op.getOperand(i);
1596 body << " if (parser.resolveOperands(" << operand.name << "Operands, ";
1598 // Resolve the type of this operand.
1599 TypeResolution &operandType = operandTypes[i];
1600 emitTypeResolver(operandType, operand.name);
1602 body << ", " << operand.name
1603 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1607 void OperationFormat::genParserRegionResolution(Operator &op,
1608 MethodBody &body) {
1609 // Check for the case where all regions were parsed.
1610 bool hasAllRegions = llvm::any_of(
1611 elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
1612 if (hasAllRegions) {
1613 body << " result.addRegions(fullRegions);\n";
1614 return;
1617 // Otherwise, handle each region individually.
1618 for (const NamedRegion &region : op.getRegions()) {
1619 if (region.isVariadic())
1620 body << " result.addRegions(" << region.name << "Regions);\n";
1621 else
1622 body << " result.addRegion(std::move(" << region.name << "Region));\n";
1626 void OperationFormat::genParserSuccessorResolution(Operator &op,
1627 MethodBody &body) {
1628 // Check for the case where all successors were parsed.
1629 bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
1630 return isa<SuccessorsDirective>(elt);
1632 if (hasAllSuccessors) {
1633 body << " result.addSuccessors(fullSuccessors);\n";
1634 return;
1637 // Otherwise, handle each successor individually.
1638 for (const NamedSuccessor &successor : op.getSuccessors()) {
1639 if (successor.isVariadic())
1640 body << " result.addSuccessors(" << successor.name << "Successors);\n";
1641 else
1642 body << " result.addSuccessors(" << successor.name << "Successor);\n";
1646 void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
1647 MethodBody &body) {
1648 if (!allOperands) {
1649 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1650 if (op.getDialect().usePropertiesForAttributes()) {
1651 body << formatv(" "
1652 "result.getOrAddProperties<{0}::Properties>().operand_"
1653 "segment_sizes = "
1654 "(parser.getBuilder().getDenseI32ArrayAttr({{",
1655 op.getCppClassName());
1656 } else {
1657 body << " result.addAttribute(\"operand_segment_sizes\", "
1658 << "parser.getBuilder().getDenseI32ArrayAttr({";
1660 auto interleaveFn = [&](const NamedTypeConstraint &operand) {
1661 // If the operand is variadic emit the parsed size.
1662 if (operand.isVariableLength())
1663 body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
1664 else
1665 body << "1";
1667 llvm::interleaveComma(op.getOperands(), body, interleaveFn);
1668 body << "}));\n";
1670 for (const NamedTypeConstraint &operand : op.getOperands()) {
1671 if (!operand.isVariadicOfVariadic())
1672 continue;
1673 if (op.getDialect().usePropertiesForAttributes()) {
1674 body << llvm::formatv(
1675 " result.getOrAddProperties<{0}::Properties>().{1} = "
1676 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1677 op.getCppClassName(),
1678 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1679 operand.name);
1680 } else {
1681 body << llvm::formatv(
1682 " result.addAttribute(\"{0}\", "
1683 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1684 "\n",
1685 operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
1686 operand.name);
1691 if (!allResultTypes &&
1692 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1693 if (op.getDialect().usePropertiesForAttributes()) {
1694 body << formatv(
1696 "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = "
1697 "(parser.getBuilder().getDenseI32ArrayAttr({{",
1698 op.getCppClassName());
1699 } else {
1700 body << " result.addAttribute(\"result_segment_sizes\", "
1701 << "parser.getBuilder().getDenseI32ArrayAttr({";
1703 auto interleaveFn = [&](const NamedTypeConstraint &result) {
1704 // If the result is variadic emit the parsed size.
1705 if (result.isVariableLength())
1706 body << "static_cast<int32_t>(" << result.name << "Types.size())";
1707 else
1708 body << "1";
1710 llvm::interleaveComma(op.getResults(), body, interleaveFn);
1711 body << "}));\n";
1715 //===----------------------------------------------------------------------===//
1716 // PrinterGen
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// {0}: The name of the region.
1722 const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
1724 bool printTerminator = true;
1725 if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
1726 printTerminator = !term->getAttrDictionary().empty() ||
1727 term->getNumOperands() != 0 ||
1728 term->getNumResults() != 0;
1730 _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
1731 /*printBlockTerminators=*/printTerminator);
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword. The snippet opens a scope that
/// the code emitting it is responsible for closing.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {{
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1746 /// Generate the printer for the 'prop-dict' directive.
1747 static void genPropDictPrinter(OperationFormat &fmt, Operator &op,
1748 MethodBody &body) {
1749 body << " _odsPrinter << \" \";\n"
1750 << " printProperties(this->getContext(), _odsPrinter, "
1751 "getProperties());\n";
1754 /// Generate the printer for the 'attr-dict' directive.
1755 static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
1756 MethodBody &body, bool withKeyword) {
1757 body << " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
1758 // Elide the variadic segment size attributes if necessary.
1759 if (!fmt.allOperands &&
1760 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1761 body << " elidedAttrs.push_back(\"operand_segment_sizes\");\n";
1762 if (!fmt.allResultTypes &&
1763 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1764 body << " elidedAttrs.push_back(\"result_segment_sizes\");\n";
1765 for (const StringRef key : fmt.inferredAttributes.keys())
1766 body << " elidedAttrs.push_back(\"" << key << "\");\n";
1767 for (const NamedAttribute *attr : fmt.usedAttributes)
1768 body << " elidedAttrs.push_back(\"" << attr->name << "\");\n";
1769 // Add code to check attributes for equality with the default value
1770 // for attributes with the elidePrintingDefaultValue bit set.
1771 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1772 const Attribute &attr = namedAttr.attr;
1773 if (!attr.isDerivedAttr() && attr.hasDefaultValue()) {
1774 const StringRef &name = namedAttr.name;
1775 FmtContext fctx;
1776 fctx.withBuilder("odsBuilder");
1777 std::string defaultValue = std::string(
1778 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1779 body << " {\n";
1780 body << " ::mlir::Builder odsBuilder(getContext());\n";
1781 body << " ::mlir::Attribute attr = " << op.getGetterName(name)
1782 << "Attr();\n";
1783 body << " if(attr && (attr == " << defaultValue << "))\n";
1784 body << " elidedAttrs.push_back(\"" << name << "\");\n";
1785 body << " }\n";
1788 body << " _odsPrinter.printOptionalAttrDict"
1789 << (withKeyword ? "WithKeyword" : "")
1790 << "((*this)->getAttrs(), elidedAttrs);\n";
1793 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1794 /// space should be emitted before this element. `lastWasPunctuation` is true if
1795 /// the previous element was a punctuation literal.
1796 static void genLiteralPrinter(StringRef value, MethodBody &body,
1797 bool &shouldEmitSpace, bool &lastWasPunctuation) {
1798 body << " _odsPrinter";
1800 // Don't insert a space for certain punctuation.
1801 if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
1802 body << " << ' '";
1803 body << " << \"" << value << "\";\n";
1805 // Insert a space after certain literals.
1806 shouldEmitSpace =
1807 value.size() != 1 || !StringRef("<({[").contains(value.front());
1808 lastWasPunctuation = value.front() != '_' && !isalpha(value.front());
1811 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1812 /// are set to false.
1813 static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
1814 bool &lastWasPunctuation) {
1815 if (value) {
1816 body << " _odsPrinter << ' ';\n";
1817 lastWasPunctuation = false;
1818 } else {
1819 lastWasPunctuation = true;
1821 shouldEmitSpace = false;
1824 /// Generate the printer for a custom directive parameter.
1825 static void genCustomDirectiveParameterPrinter(FormatElement *element,
1826 const Operator &op,
1827 MethodBody &body) {
1828 if (auto *attr = dyn_cast<AttributeVariable>(element)) {
1829 body << op.getGetterName(attr->getVar()->name) << "Attr()";
1831 } else if (isa<AttrDictDirective>(element)) {
1832 body << "getOperation()->getAttrDictionary()";
1834 } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
1835 body << op.getGetterName(operand->getVar()->name) << "()";
1837 } else if (auto *region = dyn_cast<RegionVariable>(element)) {
1838 body << op.getGetterName(region->getVar()->name) << "()";
1840 } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
1841 body << op.getGetterName(successor->getVar()->name) << "()";
1843 } else if (auto *dir = dyn_cast<RefDirective>(element)) {
1844 genCustomDirectiveParameterPrinter(dir->getArg(), op, body);
1846 } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
1847 auto *typeOperand = dir->getArg();
1848 auto *operand = dyn_cast<OperandVariable>(typeOperand);
1849 auto *var = operand ? operand->getVar()
1850 : cast<ResultVariable>(typeOperand)->getVar();
1851 std::string name = op.getGetterName(var->name);
1852 if (var->isVariadic())
1853 body << name << "().getTypes()";
1854 else if (var->isOptional())
1855 body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
1856 else
1857 body << name << "().getType()";
1859 } else if (auto *string = dyn_cast<StringElement>(element)) {
1860 FmtContext ctx;
1861 ctx.withBuilder("::mlir::Builder(getContext())");
1862 ctx.addSubst("_ctxt", "getContext()");
1863 body << tgfmt(string->getValue(), &ctx);
1865 } else {
1866 llvm_unreachable("unknown custom directive parameter");
1870 /// Generate the printer for a custom directive.
1871 static void genCustomDirectivePrinter(CustomDirective *customDir,
1872 const Operator &op, MethodBody &body) {
1873 body << " print" << customDir->getName() << "(_odsPrinter, *this";
1874 for (FormatElement *param : customDir->getArguments()) {
1875 body << ", ";
1876 genCustomDirectiveParameterPrinter(param, op, body);
1878 body << ");\n";
1881 /// Generate the printer for a region with the given variable name.
1882 static void genRegionPrinter(const Twine &regionName, MethodBody &body,
1883 bool hasImplicitTermTrait) {
1884 if (hasImplicitTermTrait)
1885 body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
1886 regionName);
1887 else
1888 body << " _odsPrinter.printRegion(" << regionName << ");\n";
1890 static void genVariadicRegionPrinter(const Twine &regionListName,
1891 MethodBody &body,
1892 bool hasImplicitTermTrait) {
1893 body << " llvm::interleaveComma(" << regionListName
1894 << ", _odsPrinter, [&](::mlir::Region &region) {\n ";
1895 genRegionPrinter("region", body, hasImplicitTermTrait);
1896 body << " });\n";
1899 /// Generate the C++ for an operand to a (*-)type directive.
1900 static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
1901 MethodBody &body,
1902 bool useArrayRef = true) {
1903 if (isa<OperandsDirective>(arg))
1904 return body << "getOperation()->getOperandTypes()";
1905 if (isa<ResultsDirective>(arg))
1906 return body << "getOperation()->getResultTypes()";
1907 auto *operand = dyn_cast<OperandVariable>(arg);
1908 auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
1909 if (var->isVariadicOfVariadic())
1910 return body << llvm::formatv("{0}().join().getTypes()",
1911 op.getGetterName(var->name));
1912 if (var->isVariadic())
1913 return body << op.getGetterName(var->name) << "().getTypes()";
1914 if (var->isOptional())
1915 return body << llvm::formatv(
1916 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1917 "::llvm::ArrayRef<::mlir::Type>())",
1918 op.getGetterName(var->name));
1919 if (useArrayRef)
1920 return body << "::llvm::ArrayRef<::mlir::Type>("
1921 << op.getGetterName(var->name) << "().getType())";
1922 return body << op.getGetterName(var->name) << "().getType()";
1925 /// Generate the printer for an enum attribute.
1926 static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
1927 MethodBody &body) {
1928 Attribute baseAttr = var->attr.getBaseAttr();
1929 const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
1930 std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
1932 body << llvm::formatv(enumAttrBeginPrinterCode,
1933 (var->attr.isOptional() ? "*" : "") +
1934 op.getGetterName(var->name),
1935 enumAttr.getSymbolToStringFnName());
1937 // Get a string containing all of the cases that can't be represented with a
1938 // keyword.
1939 BitVector nonKeywordCases(cases.size());
1940 for (auto it : llvm::enumerate(cases)) {
1941 if (!canFormatStringAsKeyword(it.value().getStr()))
1942 nonKeywordCases.set(it.index());
1945 // Otherwise if this is a bit enum attribute, don't allow cases that may
1946 // overlap with other cases. For simplicity sake, only allow cases with a
1947 // single bit value.
1948 if (enumAttr.isBitEnum()) {
1949 for (auto it : llvm::enumerate(cases)) {
1950 int64_t value = it.value().getValue();
1951 if (value < 0 || !llvm::isPowerOf2_64(value))
1952 nonKeywordCases.set(it.index());
1956 // If there are any cases that can't be used with a keyword, switch on the
1957 // case value to determine when to print in the string form.
1958 if (nonKeywordCases.any()) {
1959 body << " switch (caseValue) {\n";
1960 StringRef cppNamespace = enumAttr.getCppNamespace();
1961 StringRef enumName = enumAttr.getEnumClassName();
1962 for (auto it : llvm::enumerate(cases)) {
1963 if (nonKeywordCases.test(it.index()))
1964 continue;
1965 StringRef symbol = it.value().getSymbol();
1966 body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
1967 llvm::isDigit(symbol.front()) ? ("_" + symbol)
1968 : symbol);
1970 body << " _odsPrinter << caseValueStr;\n"
1971 " break;\n"
1972 " default:\n"
1973 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
1974 " break;\n"
1975 " }\n"
1976 " }\n";
1977 return;
1980 body << " _odsPrinter << caseValueStr;\n"
1981 " }\n";
1984 /// Generate the check for the anchor of an optional group.
1985 static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
1986 const Operator &op,
1987 MethodBody &body) {
1988 TypeSwitch<FormatElement *>(anchor)
1989 .Case<OperandVariable, ResultVariable>([&](auto *element) {
1990 const NamedTypeConstraint *var = element->getVar();
1991 std::string name = op.getGetterName(var->name);
1992 if (var->isOptional())
1993 body << name << "()";
1994 else if (var->isVariadic())
1995 body << "!" << name << "().empty()";
1997 .Case<RegionVariable>([&](RegionVariable *element) {
1998 const NamedRegion *var = element->getVar();
1999 std::string name = op.getGetterName(var->name);
2000 // TODO: Add a check for optional regions here when ODS supports it.
2001 body << "!" << name << "().empty()";
2003 .Case<TypeDirective>([&](TypeDirective *element) {
2004 genOptionalGroupPrinterAnchor(element->getArg(), op, body);
2006 .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
2007 genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
2009 .Case<AttributeVariable>([&](AttributeVariable *element) {
2010 Attribute attr = element->getVar()->attr;
2011 body << op.getGetterName(element->getVar()->name) << "Attr()";
2012 if (attr.isOptional())
2013 return; // done
2014 if (attr.hasDefaultValue()) {
2015 // Consider a default-valued attribute as present if it's not the
2016 // default value.
2017 FmtContext fctx;
2018 fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2019 body << " && " << op.getGetterName(element->getVar()->name)
2020 << "Attr() != "
2021 << tgfmt(attr.getConstBuilderTemplate(), &fctx,
2022 attr.getDefaultValue());
2023 return;
2025 llvm_unreachable("attribute must be optional or default-valued");
2029 void collect(FormatElement *element,
2030 SmallVectorImpl<VariableElement *> &variables) {
2031 TypeSwitch<FormatElement *>(element)
2032 .Case([&](VariableElement *var) { variables.emplace_back(var); })
2033 .Case([&](CustomDirective *ele) {
2034 for (FormatElement *arg : ele->getArguments())
2035 collect(arg, variables);
2037 .Case([&](OptionalElement *ele) {
2038 for (FormatElement *arg : ele->getThenElements())
2039 collect(arg, variables);
2040 for (FormatElement *arg : ele->getElseElements())
2041 collect(arg, variables);
2043 .Case([&](FunctionalTypeDirective *funcType) {
2044 collect(funcType->getInputs(), variables);
2045 collect(funcType->getResults(), variables);
2047 .Case([&](OIListElement *oilist) {
2048 for (ArrayRef<FormatElement *> arg : oilist->getParsingElements())
2049 for (FormatElement *arg : arg)
2050 collect(arg, variables);
/// Generate the printer code for a single format element and append it to
/// `body`. `shouldEmitSpace` and `lastWasPunctuation` carry the spacing state
/// from one element to the next and are updated in place.
void OperationFormat::genElementPrinter(FormatElement *element,
                                        MethodBody &body, Operator &op,
                                        bool &shouldEmitSpace,
                                        bool &lastWasPunctuation) {
  // Literals manage the spacing state themselves.
  if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
    return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
                             lastWasPunctuation);

  // Emit a whitespace element.
  if (auto *space = dyn_cast<WhitespaceElement>(element)) {
    if (space->getValue() == "\\n") {
      body << " _odsPrinter.printNewline();\n";
    } else {
      // An empty whitespace value emits no space.
      genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace,
                      lastWasPunctuation);
    }
    return;
  }

  // Emit an optional group.
  if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
    // Emit the check for the presence of the anchor element.
    FormatElement *anchor = optional->getAnchor();
    body << " if (";
    if (optional->isInverted())
      body << "!";
    genOptionalGroupPrinterAnchor(anchor, op, body);
    body << ") {\n";
    body.indent();

    // If the anchor is a unit attribute, we don't need to print it. When
    // parsing, we will add this attribute if this group is present. The anchor
    // is only elided when it is not the first element of either branch.
    ArrayRef<FormatElement *> thenElements = optional->getThenElements();
    ArrayRef<FormatElement *> elseElements = optional->getElseElements();
    FormatElement *elidedAnchorElement = nullptr;
    auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
    if (anchorAttr && anchorAttr != thenElements.front() &&
        (elseElements.empty() || anchorAttr != elseElements.front()) &&
        anchorAttr->isUnitAttr()) {
      elidedAnchorElement = anchorAttr;
    }
    // Prints every element of a branch except the elided anchor, if any.
    auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
      for (FormatElement *childElement : elements) {
        if (childElement != elidedAnchorElement) {
          genElementPrinter(childElement, body, op, shouldEmitSpace,
                            lastWasPunctuation);
        }
      }
    };

    // Emit each of the elements.
    genElementPrinters(thenElements);
    body << "}";

    // Emit each of the else elements.
    if (!elseElements.empty()) {
      body << " else {\n";
      genElementPrinters(elseElements);
      body << "}";
    }

    body.unindent() << "\n";
    return;
  }

  // Emit the OIList
  if (auto *oilist = dyn_cast<OIListElement>(element)) {
    genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation);
    for (auto clause : oilist->getClauses()) {
      // Each clause is a literal followed by the elements to parse/print.
      LiteralElement *lelement = std::get<0>(clause);
      ArrayRef<FormatElement *> pelement = std::get<1>(clause);

      // The clause is printed if any variable it mentions is present: the
      // generated condition is `false || <check> || <check> ...`.
      SmallVector<VariableElement *> vars;
      for (FormatElement *el : pelement)
        collect(el, vars);
      body << " if (false";
      for (VariableElement *var : vars) {
        TypeSwitch<FormatElement *>(var)
            .Case([&](AttributeVariable *attrEle) {
              body << " || " << op.getGetterName(attrEle->getVar()->name)
                   << "Attr()";
            })
            .Case([&](OperandVariable *ele) {
              if (ele->getVar()->isVariadic()) {
                body << " || " << op.getGetterName(ele->getVar()->name)
                     << "().size()";
              } else {
                body << " || " << op.getGetterName(ele->getVar()->name) << "()";
              }
            })
            .Case([&](ResultVariable *ele) {
              if (ele->getVar()->isVariadic()) {
                body << " || " << op.getGetterName(ele->getVar()->name)
                     << "().size()";
              } else {
                body << " || " << op.getGetterName(ele->getVar()->name) << "()";
              }
            })
            .Case([&](RegionVariable *reg) {
              body << " || " << op.getGetterName(reg->getVar()->name) << "()";
            });
      }

      body << ") {\n";
      genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
                        lastWasPunctuation);
      // When the clause has a unit attribute parsing element, only the
      // literal is printed.
      if (oilist->getUnitAttrParsingElement(pelement) == nullptr) {
        for (FormatElement *element : pelement)
          genElementPrinter(element, body, op, shouldEmitSpace,
                            lastWasPunctuation);
      }
      body << " }\n";
    }
    return;
  }

  // Emit the attribute dictionary.
  if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
    genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
    lastWasPunctuation = false;
    return;
  }

  // Emit the property dictionary.
  if (auto *propDict = dyn_cast<PropDictDirective>(element)) {
    genPropDictPrinter(*this, op, body);
    lastWasPunctuation = false;
    return;
  }

  // Optionally insert a space before the next element. The AttrDict printer
  // already adds a space as necessary.
  if (shouldEmitSpace || !lastWasPunctuation)
    body << " _odsPrinter << ' ';\n";
  lastWasPunctuation = false;
  shouldEmitSpace = true;

  if (auto *attr = dyn_cast<AttributeVariable>(element)) {
    const NamedAttribute *var = attr->getVar();

    // If we are formatting as an enum, symbolize the attribute as a string.
    if (canFormatEnumAttr(var))
      return genEnumAttrPrinter(var, op, body);

    // If we are formatting as a symbol name, handle it as a symbol name.
    if (shouldFormatSymbolNameAttr(var)) {
      body << " _odsPrinter.printSymbolName(" << op.getGetterName(var->name)
           << "Attr().getValue());\n";
      return;
    }

    // Elide the attribute type if it is buildable.
    if (attr->getTypeBuilder())
      body << " _odsPrinter.printAttributeWithoutType("
           << op.getGetterName(var->name) << "Attr());\n";
    else if (attr->shouldBeQualified() ||
             var->attr.getStorageType() == "::mlir::Attribute")
      body << " _odsPrinter.printAttribute(" << op.getGetterName(var->name)
           << "Attr());\n";
    else
      body << "_odsPrinter.printStrippedAttrOrType("
           << op.getGetterName(var->name) << "Attr());\n";
  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
    if (operand->getVar()->isVariadicOfVariadic()) {
      // Each group of operands is printed in its own parentheses.
      body << " ::llvm::interleaveComma("
           << op.getGetterName(operand->getVar()->name)
           << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
              "\"(\" << operands << "
              "\")\"; });\n";

    } else if (operand->getVar()->isOptional()) {
      // An optional operand is printed only when it is present.
      body << " if (::mlir::Value value = "
           << op.getGetterName(operand->getVar()->name) << "())\n"
           << " _odsPrinter << value;\n";
    } else {
      body << " _odsPrinter << " << op.getGetterName(operand->getVar()->name)
           << "();\n";
    }
  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
    const NamedRegion *var = region->getVar();
    std::string name = op.getGetterName(var->name);
    if (var->isVariadic()) {
      genVariadicRegionPrinter(name + "()", body, hasImplicitTermTrait);
    } else {
      genRegionPrinter(name + "()", body, hasImplicitTermTrait);
    }
  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
    const NamedSuccessor *var = successor->getVar();
    std::string name = op.getGetterName(var->name);
    if (var->isVariadic())
      body << " ::llvm::interleaveComma(" << name << "(), _odsPrinter);\n";
    else
      body << " _odsPrinter << " << name << "();\n";
  } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
    genCustomDirectivePrinter(dir, op, body);
  } else if (isa<OperandsDirective>(element)) {
    body << " _odsPrinter << getOperation()->getOperands();\n";
  } else if (isa<RegionsDirective>(element)) {
    genVariadicRegionPrinter("getOperation()->getRegions()", body,
                             hasImplicitTermTrait);
  } else if (isa<SuccessorsDirective>(element)) {
    body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
            "_odsPrinter);\n";
  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
    // Types of a variadic-of-variadic operand are printed one parenthesized
    // group at a time.
    if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
      if (operand->getVar()->isVariadicOfVariadic()) {
        body << llvm::formatv(
            " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
            "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
            "types << \")\"; });\n",
            op.getGetterName(operand->getVar()->name));
        return;
      }
    }
    // Find the operand or result the directive refers to, if it is one.
    const NamedTypeConstraint *var = nullptr;
    {
      if (auto *operand = dyn_cast<OperandVariable>(dir->getArg()))
        var = operand->getVar();
      else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg()))
        var = operand->getVar();
    }
    // The type of a single required value is printed in stripped form when it
    // has the constraint's C++ class, unless the directive is qualified.
    if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
        !var->isOptional()) {
      std::string cppClass = var->constraint.getCPPClassName();
      if (dir->shouldBeQualified()) {
        body << " _odsPrinter << " << op.getGetterName(var->name)
             << "().getType();\n";
        return;
      }
      body << " {\n"
           << " auto type = " << op.getGetterName(var->name)
           << "().getType();\n"
           << " if (auto validType = type.dyn_cast<" << cppClass << ">())\n"
           << " _odsPrinter.printStrippedAttrOrType(validType);\n"
           << " else\n"
           << " _odsPrinter << type;\n"
           << " }\n";
      return;
    }
    body << " _odsPrinter << ";
    genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false)
        << ";\n";
  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
    body << " _odsPrinter.printFunctionalType(";
    genTypeOperandPrinter(dir->getInputs(), op, body) << ", ";
    genTypeOperandPrinter(dir->getResults(), op, body) << ");\n";
  } else {
    llvm_unreachable("unknown format element");
  }
}
2305 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
2306 auto *method = opClass.addMethod(
2307 "void", "print",
2308 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2309 auto &body = method->body();
2311 // Flags for if we should emit a space, and if the last element was
2312 // punctuation.
2313 bool shouldEmitSpace = true, lastWasPunctuation = false;
2314 for (FormatElement *element : elements)
2315 genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
2318 //===----------------------------------------------------------------------===//
2319 // OpFormatParser
2320 //===----------------------------------------------------------------------===//
2322 /// Function to find an element within the given range that has the same name as
2323 /// 'name'.
2324 template <typename RangeT>
2325 static auto findArg(RangeT &&range, StringRef name) {
2326 auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
2327 return it != range.end() ? &*it : nullptr;
2330 namespace {
2331 /// This class implements a parser for an instance of an operation assembly
2332 /// format.
2333 class OpFormatParser : public FormatParser {
2334 public:
2335 OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
2336 : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
2337 seenOperandTypes(op.getNumOperands()),
2338 seenResultTypes(op.getNumResults()) {}
2340 protected:
2341 /// Verify the format elements.
2342 LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
2343 /// Verify the arguments to a custom directive.
2344 LogicalResult
2345 verifyCustomDirectiveArguments(SMLoc loc,
2346 ArrayRef<FormatElement *> arguments) override;
2347 /// Verify the elements of an optional group.
2348 LogicalResult verifyOptionalGroupElements(SMLoc loc,
2349 ArrayRef<FormatElement *> elements,
2350 FormatElement *anchor) override;
2351 LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
2352 bool isAnchor);
2354 /// Parse an operation variable.
2355 FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
2356 Context ctx) override;
2357 /// Parse an operation format directive.
2358 FailureOr<FormatElement *>
2359 parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
2361 private:
2362 /// This struct represents a type resolution instance. It includes a specific
2363 /// type as well as an optional transformer to apply to that type in order to
2364 /// properly resolve the type of a variable.
2365 struct TypeResolutionInstance {
2366 ConstArgument resolver;
2367 std::optional<StringRef> transformer;
2370 /// Verify the state of operation attributes within the format.
2371 LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);
2373 /// Verify that attributes elements aren't followed by colon literals.
2374 LogicalResult verifyAttributeColonType(SMLoc loc,
2375 ArrayRef<FormatElement *> elements);
2376 /// Verify that the attribute dictionary directive isn't followed by a region.
2377 LogicalResult verifyAttrDictRegion(SMLoc loc,
2378 ArrayRef<FormatElement *> elements);
2380 /// Verify the state of operation operands within the format.
2381 LogicalResult
2382 verifyOperands(SMLoc loc,
2383 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2385 /// Verify the state of operation regions within the format.
2386 LogicalResult verifyRegions(SMLoc loc);
2388 /// Verify the state of operation results within the format.
2389 LogicalResult
2390 verifyResults(SMLoc loc,
2391 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2393 /// Verify the state of operation successors within the format.
2394 LogicalResult verifySuccessors(SMLoc loc);
2396 LogicalResult verifyOIListElements(SMLoc loc,
2397 ArrayRef<FormatElement *> elements);
2399 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2400 /// resolution.
2401 void handleAllTypesMatchConstraint(
2402 ArrayRef<StringRef> values,
2403 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
2404 /// Check for inferable type resolution given all operands, and or results,
2405 /// have the same type. If 'includeResults' is true, the results also have the
2406 /// same type as all of the operands.
2407 void handleSameTypesConstraint(
2408 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2409 bool includeResults);
2410 /// Check for inferable type resolution based on another operand, result, or
2411 /// attribute.
2412 void handleTypesMatchConstraint(
2413 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2414 const llvm::Record &def);
2416 /// Returns an argument or attribute with the given name that has been seen
2417 /// within the format.
2418 ConstArgument findSeenArg(StringRef name);
2420 /// Parse the various different directives.
2421 FailureOr<FormatElement *> parsePropDictDirective(SMLoc loc, Context context);
2422 FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
2423 bool withKeyword);
2424 FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
2425 Context context);
2426 FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
2427 LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
2428 FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
2429 FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
2430 Context context);
2431 FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
2432 Context context);
2433 FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
2434 FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
2435 FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
2436 Context context);
2437 FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
2438 FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
2439 bool isRefChild = false);
2441 //===--------------------------------------------------------------------===//
2442 // Fields
2443 //===--------------------------------------------------------------------===//
2445 OperationFormat &fmt;
2446 Operator &op;
2448 // The following are various bits of format state used for verification
2449 // during parsing.
2450 bool hasAttrDict = false;
2451 bool hasPropDict = false;
2452 bool hasAllRegions = false, hasAllSuccessors = false;
2453 bool canInferResultTypes = false;
2454 llvm::SmallBitVector seenOperandTypes, seenResultTypes;
2455 llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
2456 llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
2457 llvm::DenseSet<const NamedRegion *> seenRegions;
2458 llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
2460 } // namespace
2462 LogicalResult OpFormatParser::verify(SMLoc loc,
2463 ArrayRef<FormatElement *> elements) {
2464 // Check that the attribute dictionary is in the format.
2465 if (!hasAttrDict)
2466 return emitError(loc, "'attr-dict' directive not found in "
2467 "custom assembly format");
2469 // Check for any type traits that we can use for inferring types.
2470 llvm::StringMap<TypeResolutionInstance> variableTyResolver;
2471 for (const Trait &trait : op.getTraits()) {
2472 const llvm::Record &def = trait.getDef();
2473 if (def.isSubClassOf("AllTypesMatch")) {
2474 handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
2475 variableTyResolver);
2476 } else if (def.getName() == "SameTypeOperands") {
2477 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
2478 } else if (def.getName() == "SameOperandsAndResultType") {
2479 handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
2480 } else if (def.isSubClassOf("TypesMatchWith")) {
2481 handleTypesMatchConstraint(variableTyResolver, def);
2482 } else if (!op.allResultTypesKnown()) {
2483 // This doesn't check the name directly to handle
2484 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2485 // and the like.
2486 // TODO: Add hasCppInterface check.
2487 if (auto name = def.getValueAsOptionalString("cppInterfaceName")) {
2488 if (*name == "InferTypeOpInterface" &&
2489 def.getValueAsString("cppNamespace") == "::mlir")
2490 canInferResultTypes = true;
2495 // Verify the state of the various operation components.
2496 if (failed(verifyAttributes(loc, elements)) ||
2497 failed(verifyResults(loc, variableTyResolver)) ||
2498 failed(verifyOperands(loc, variableTyResolver)) ||
2499 failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) ||
2500 failed(verifyOIListElements(loc, elements)))
2501 return failure();
2503 // Collect the set of used attributes in the format.
2504 fmt.usedAttributes = seenAttrs.takeVector();
2505 return success();
2508 LogicalResult
2509 OpFormatParser::verifyAttributes(SMLoc loc,
2510 ArrayRef<FormatElement *> elements) {
2511 // Check that there are no `:` literals after an attribute without a constant
2512 // type. The attribute grammar contains an optional trailing colon type, which
2513 // can lead to unexpected and generally unintended behavior. Given that, it is
2514 // better to just error out here instead.
2515 if (failed(verifyAttributeColonType(loc, elements)))
2516 return failure();
2517 // Check that there are no region variables following an attribute dicitonary.
2518 // Both start with `{` and so the optional attribute dictionary can cause
2519 // format ambiguities.
2520 if (failed(verifyAttrDictRegion(loc, elements)))
2521 return failure();
2523 // Check for VariadicOfVariadic variables. The segment attribute of those
2524 // variables will be infered.
2525 for (const NamedTypeConstraint *var : seenOperands) {
2526 if (var->constraint.isVariadicOfVariadic()) {
2527 fmt.inferredAttributes.insert(
2528 var->constraint.getVariadicOfVariadicSegmentSizeAttr());
2532 return success();
2535 /// Returns whether the single format element is optionally parsed.
2536 static bool isOptionallyParsed(FormatElement *el) {
2537 if (auto *attrVar = dyn_cast<AttributeVariable>(el)) {
2538 Attribute attr = attrVar->getVar()->attr;
2539 return attr.isOptional() || attr.hasDefaultValue();
2541 if (auto *operandVar = dyn_cast<OperandVariable>(el)) {
2542 const NamedTypeConstraint *operand = operandVar->getVar();
2543 return operand->isOptional() || operand->isVariadic() ||
2544 operand->isVariadicOfVariadic();
2546 if (auto *successorVar = dyn_cast<SuccessorVariable>(el))
2547 return successorVar->getVar()->isVariadic();
2548 if (auto *regionVar = dyn_cast<RegionVariable>(el))
2549 return regionVar->getVar()->isVariadic();
2550 return isa<WhitespaceElement, AttrDictDirective>(el);
2553 /// Scan the given range of elements from the start for an invalid format
2554 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2555 /// If an optional group is encountered, this function recurses into the 'then'
2556 /// and 'else' elements to check if they are invalid. Returns `success` if the
2557 /// range is known to be valid or `std::nullopt` if scanning reached the end.
2559 /// Since the guard element of an optional group is required, this function
2560 /// accepts an optional element pointer to mark it as required.
2561 static std::optional<LogicalResult> checkRangeForElement(
2562 FormatElement *base,
2563 function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2564 iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
2565 FormatElement *optionalGuard = nullptr) {
2566 for (FormatElement *element : elementRange) {
2567 // If we encounter an invalid element, return an error.
2568 if (isInvalid(base, element))
2569 return failure();
2571 // Recurse on optional groups.
2572 if (auto *optional = dyn_cast<OptionalElement>(element)) {
2573 if (std::optional<LogicalResult> result = checkRangeForElement(
2574 base, isInvalid, optional->getThenElements(),
2575 // The optional group guard is required for the group.
2576 optional->getThenElements().front()))
2577 if (failed(*result))
2578 return failure();
2579 if (std::optional<LogicalResult> result = checkRangeForElement(
2580 base, isInvalid, optional->getElseElements()))
2581 if (failed(*result))
2582 return failure();
2583 // Skip the optional group.
2584 continue;
2587 // Skip optionally parsed elements.
2588 if (element != optionalGuard && isOptionallyParsed(element))
2589 continue;
2591 // We found a closing element that is valid.
2592 return success();
2594 // Return std::nullopt to indicate that we reached the end.
2595 return std::nullopt;
2598 /// For the given elements, check whether any attributes are followed by a colon
2599 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2600 /// attribute if verification of said attribute reached the end of the range.
2601 /// Returns null if all attribute elements are verified.
2602 static FailureOr<FormatElement *> verifyAdjacentElements(
2603 function_ref<bool(FormatElement *)> isBase,
2604 function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
2605 ArrayRef<FormatElement *> elements) {
2606 for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
2607 // The current attribute being verified.
2608 FormatElement *base;
2610 if (isBase(*it)) {
2611 base = *it;
2612 } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
2613 // Recurse on optional groups.
2614 FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
2615 isBase, isInvalid, optional->getThenElements());
2616 if (failed(thenResult))
2617 return failure();
2618 FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
2619 isBase, isInvalid, optional->getElseElements());
2620 if (failed(elseResult))
2621 return failure();
2622 // If either optional group has an unverified attribute, save it.
2623 // Otherwise, move on to the next element.
2624 if (!(base = *thenResult) && !(base = *elseResult))
2625 continue;
2626 } else {
2627 continue;
2630 // Verify subsequent elements for potential ambiguities.
2631 if (std::optional<LogicalResult> result =
2632 checkRangeForElement(base, isInvalid, {std::next(it), e})) {
2633 if (failed(*result))
2634 return failure();
2635 } else {
2636 // Since we reached the end, return the attribute as unverified.
2637 return base;
2640 // All attribute elements are known to be verified.
2641 return nullptr;
2644 LogicalResult
2645 OpFormatParser::verifyAttributeColonType(SMLoc loc,
2646 ArrayRef<FormatElement *> elements) {
2647 auto isBase = [](FormatElement *el) {
2648 auto *attr = dyn_cast<AttributeVariable>(el);
2649 if (!attr)
2650 return false;
2651 // Check only attributes without type builders or that are known to call
2652 // the generic attribute parser.
2653 return !attr->getTypeBuilder() &&
2654 (attr->shouldBeQualified() ||
2655 attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
2657 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2658 auto *literal = dyn_cast<LiteralElement>(el);
2659 if (!literal || literal->getSpelling() != ":")
2660 return false;
2661 // If we encounter `:`, the range is known to be invalid.
2662 (void)emitError(
2663 loc,
2664 llvm::formatv("format ambiguity caused by `:` literal found after "
2665 "attribute `{0}` which does not have a buildable type",
2666 cast<AttributeVariable>(base)->getVar()->name));
2667 return true;
2669 return verifyAdjacentElements(isBase, isInvalid, elements);
2672 LogicalResult
2673 OpFormatParser::verifyAttrDictRegion(SMLoc loc,
2674 ArrayRef<FormatElement *> elements) {
2675 auto isBase = [](FormatElement *el) {
2676 if (auto *attrDict = dyn_cast<AttrDictDirective>(el))
2677 return !attrDict->isWithKeyword();
2678 return false;
2680 auto isInvalid = [&](FormatElement *base, FormatElement *el) {
2681 auto *region = dyn_cast<RegionVariable>(el);
2682 if (!region)
2683 return false;
2684 (void)emitErrorAndNote(
2685 loc,
2686 llvm::formatv("format ambiguity caused by `attr-dict` directive "
2687 "followed by region `{0}`",
2688 region->getVar()->name),
2689 "try using `attr-dict-with-keyword` instead");
2690 return true;
2692 return verifyAdjacentElements(isBase, isInvalid, elements);
2695 LogicalResult OpFormatParser::verifyOperands(
2696 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2697 // Check that all of the operands are within the format, and their types can
2698 // be inferred.
2699 auto &buildableTypes = fmt.buildableTypes;
2700 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
2701 NamedTypeConstraint &operand = op.getOperand(i);
2703 // Check that the operand itself is in the format.
2704 if (!fmt.allOperands && !seenOperands.count(&operand)) {
2705 return emitErrorAndNote(loc,
2706 "operand #" + Twine(i) + ", named '" +
2707 operand.name + "', not found",
2708 "suggest adding a '$" + operand.name +
2709 "' directive to the custom assembly format");
2712 // Check that the operand type is in the format, or that it can be inferred.
2713 if (fmt.allOperandTypes || seenOperandTypes.test(i))
2714 continue;
2716 // Check to see if we can infer this type from another variable.
2717 auto varResolverIt = variableTyResolver.find(op.getOperand(i).name);
2718 if (varResolverIt != variableTyResolver.end()) {
2719 TypeResolutionInstance &resolver = varResolverIt->second;
2720 fmt.operandTypes[i].setResolver(resolver.resolver, resolver.transformer);
2721 continue;
2724 // Similarly to results, allow a custom builder for resolving the type if
2725 // we aren't using the 'operands' directive.
2726 std::optional<StringRef> builder = operand.constraint.getBuilderCall();
2727 if (!builder || (fmt.allOperands && operand.isVariableLength())) {
2728 return emitErrorAndNote(
2729 loc,
2730 "type of operand #" + Twine(i) + ", named '" + operand.name +
2731 "', is not buildable and a buildable type cannot be inferred",
2732 "suggest adding a type constraint to the operation or adding a "
2733 "'type($" +
2734 operand.name + ")' directive to the " + "custom assembly format");
2736 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2737 fmt.operandTypes[i].setBuilderIdx(it.first->second);
2739 return success();
2742 LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
2743 // Check that all of the regions are within the format.
2744 if (hasAllRegions)
2745 return success();
2747 for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
2748 const NamedRegion &region = op.getRegion(i);
2749 if (!seenRegions.count(&region)) {
2750 return emitErrorAndNote(loc,
2751 "region #" + Twine(i) + ", named '" +
2752 region.name + "', not found",
2753 "suggest adding a '$" + region.name +
2754 "' directive to the custom assembly format");
2757 return success();
2760 LogicalResult OpFormatParser::verifyResults(
2761 SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2762 // If we format all of the types together, there is nothing to check.
2763 if (fmt.allResultTypes)
2764 return success();
2766 // If no result types are specified and we can infer them, infer all result
2767 // types
2768 if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
2769 canInferResultTypes) {
2770 fmt.infersResultTypes = true;
2771 return success();
2774 // Check that all of the result types can be inferred.
2775 auto &buildableTypes = fmt.buildableTypes;
2776 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
2777 if (seenResultTypes.test(i))
2778 continue;
2780 // Check to see if we can infer this type from another variable.
2781 auto varResolverIt = variableTyResolver.find(op.getResultName(i));
2782 if (varResolverIt != variableTyResolver.end()) {
2783 TypeResolutionInstance resolver = varResolverIt->second;
2784 fmt.resultTypes[i].setResolver(resolver.resolver, resolver.transformer);
2785 continue;
2788 // If the result is not variable length, allow for the case where the type
2789 // has a builder that we can use.
2790 NamedTypeConstraint &result = op.getResult(i);
2791 std::optional<StringRef> builder = result.constraint.getBuilderCall();
2792 if (!builder || result.isVariableLength()) {
2793 return emitErrorAndNote(
2794 loc,
2795 "type of result #" + Twine(i) + ", named '" + result.name +
2796 "', is not buildable and a buildable type cannot be inferred",
2797 "suggest adding a type constraint to the operation or adding a "
2798 "'type($" +
2799 result.name + ")' directive to the " + "custom assembly format");
2801 // Note in the format that this result uses the custom builder.
2802 auto it = buildableTypes.insert({*builder, buildableTypes.size()});
2803 fmt.resultTypes[i].setBuilderIdx(it.first->second);
2805 return success();
2808 LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
2809 // Check that all of the successors are within the format.
2810 if (hasAllSuccessors)
2811 return success();
2813 for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
2814 const NamedSuccessor &successor = op.getSuccessor(i);
2815 if (!seenSuccessors.count(&successor)) {
2816 return emitErrorAndNote(loc,
2817 "successor #" + Twine(i) + ", named '" +
2818 successor.name + "', not found",
2819 "suggest adding a '$" + successor.name +
2820 "' directive to the custom assembly format");
2823 return success();
2826 LogicalResult
2827 OpFormatParser::verifyOIListElements(SMLoc loc,
2828 ArrayRef<FormatElement *> elements) {
2829 // Check that all of the successors are within the format.
2830 SmallVector<StringRef> prohibitedLiterals;
2831 for (FormatElement *it : elements) {
2832 if (auto *oilist = dyn_cast<OIListElement>(it)) {
2833 if (!prohibitedLiterals.empty()) {
2834 // We just saw an oilist element in last iteration. Literals should not
2835 // match.
2836 for (LiteralElement *literal : oilist->getLiteralElements()) {
2837 if (find(prohibitedLiterals, literal->getSpelling()) !=
2838 prohibitedLiterals.end()) {
2839 return emitError(
2840 loc, "format ambiguity because " + literal->getSpelling() +
2841 " is used in two adjacent oilist elements.");
2845 for (LiteralElement *literal : oilist->getLiteralElements())
2846 prohibitedLiterals.push_back(literal->getSpelling());
2847 } else if (auto *literal = dyn_cast<LiteralElement>(it)) {
2848 if (find(prohibitedLiterals, literal->getSpelling()) !=
2849 prohibitedLiterals.end()) {
2850 return emitError(
2851 loc,
2852 "format ambiguity because " + literal->getSpelling() +
2853 " is used both in oilist element and the adjacent literal.");
2855 prohibitedLiterals.clear();
2856 } else {
2857 prohibitedLiterals.clear();
2860 return success();
2863 void OpFormatParser::handleAllTypesMatchConstraint(
2864 ArrayRef<StringRef> values,
2865 llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
2866 for (unsigned i = 0, e = values.size(); i != e; ++i) {
2867 // Check to see if this value matches a resolved operand or result type.
2868 ConstArgument arg = findSeenArg(values[i]);
2869 if (!arg)
2870 continue;
2872 // Mark this value as the type resolver for the other variables.
2873 for (unsigned j = 0; j != i; ++j)
2874 variableTyResolver[values[j]] = {arg, std::nullopt};
2875 for (unsigned j = i + 1; j != e; ++j)
2876 variableTyResolver[values[j]] = {arg, std::nullopt};
2880 void OpFormatParser::handleSameTypesConstraint(
2881 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2882 bool includeResults) {
2883 const NamedTypeConstraint *resolver = nullptr;
2884 int resolvedIt = -1;
2886 // Check to see if there is an operand or result to use for the resolution.
2887 if ((resolvedIt = seenOperandTypes.find_first()) != -1)
2888 resolver = &op.getOperand(resolvedIt);
2889 else if (includeResults && (resolvedIt = seenResultTypes.find_first()) != -1)
2890 resolver = &op.getResult(resolvedIt);
2891 else
2892 return;
2894 // Set the resolvers for each operand and result.
2895 for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
2896 if (!seenOperandTypes.test(i))
2897 variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt};
2898 if (includeResults) {
2899 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2900 if (!seenResultTypes.test(i))
2901 variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt};
2905 void OpFormatParser::handleTypesMatchConstraint(
2906 llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
2907 const llvm::Record &def) {
2908 StringRef lhsName = def.getValueAsString("lhs");
2909 StringRef rhsName = def.getValueAsString("rhs");
2910 StringRef transformer = def.getValueAsString("transformer");
2911 if (ConstArgument arg = findSeenArg(lhsName))
2912 variableTyResolver[rhsName] = {arg, transformer};
2915 ConstArgument OpFormatParser::findSeenArg(StringRef name) {
2916 if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
2917 return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
2918 if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
2919 return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
2920 if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
2921 return seenAttrs.count(attr) ? attr : nullptr;
2922 return nullptr;
2925 FailureOr<FormatElement *>
2926 OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
2927 // Check that the parsed argument is something actually registered on the op.
2928 // Attributes
2929 if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
2930 if (ctx == TypeDirectiveContext)
2931 return emitError(
2932 loc, "attributes cannot be used as children to a `type` directive");
2933 if (ctx == RefDirectiveContext) {
2934 if (!seenAttrs.count(attr))
2935 return emitError(loc, "attribute '" + name +
2936 "' must be bound before it is referenced");
2937 } else if (!seenAttrs.insert(attr)) {
2938 return emitError(loc, "attribute '" + name + "' is already bound");
2941 return create<AttributeVariable>(attr);
2943 // Operands
2944 if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
2945 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2946 if (fmt.allOperands || !seenOperands.insert(operand).second)
2947 return emitError(loc, "operand '" + name + "' is already bound");
2948 } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) {
2949 return emitError(loc, "operand '" + name +
2950 "' must be bound before it is referenced");
2952 return create<OperandVariable>(operand);
2954 // Regions
2955 if (const NamedRegion *region = findArg(op.getRegions(), name)) {
2956 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2957 if (hasAllRegions || !seenRegions.insert(region).second)
2958 return emitError(loc, "region '" + name + "' is already bound");
2959 } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
2960 return emitError(loc, "region '" + name +
2961 "' must be bound before it is referenced");
2962 } else {
2963 return emitError(loc, "regions can only be used at the top level");
2965 return create<RegionVariable>(region);
2967 // Results.
2968 if (const auto *result = findArg(op.getResults(), name)) {
2969 if (ctx != TypeDirectiveContext)
2970 return emitError(loc, "result variables can can only be used as a child "
2971 "to a 'type' directive");
2972 return create<ResultVariable>(result);
2974 // Successors.
2975 if (const auto *successor = findArg(op.getSuccessors(), name)) {
2976 if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
2977 if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
2978 return emitError(loc, "successor '" + name + "' is already bound");
2979 } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
2980 return emitError(loc, "successor '" + name +
2981 "' must be bound before it is referenced");
2982 } else {
2983 return emitError(loc, "successors can only be used at the top level");
2986 return create<SuccessorVariable>(successor);
2988 return emitError(loc, "expected variable to refer to an argument, region, "
2989 "result, or successor");
2992 FailureOr<FormatElement *>
2993 OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
2994 Context ctx) {
2995 switch (kind) {
2996 case FormatToken::kw_prop_dict:
2997 return parsePropDictDirective(loc, ctx);
2998 case FormatToken::kw_attr_dict:
2999 return parseAttrDictDirective(loc, ctx,
3000 /*withKeyword=*/false);
3001 case FormatToken::kw_attr_dict_w_keyword:
3002 return parseAttrDictDirective(loc, ctx,
3003 /*withKeyword=*/true);
3004 case FormatToken::kw_functional_type:
3005 return parseFunctionalTypeDirective(loc, ctx);
3006 case FormatToken::kw_operands:
3007 return parseOperandsDirective(loc, ctx);
3008 case FormatToken::kw_qualified:
3009 return parseQualifiedDirective(loc, ctx);
3010 case FormatToken::kw_regions:
3011 return parseRegionsDirective(loc, ctx);
3012 case FormatToken::kw_results:
3013 return parseResultsDirective(loc, ctx);
3014 case FormatToken::kw_successors:
3015 return parseSuccessorsDirective(loc, ctx);
3016 case FormatToken::kw_ref:
3017 return parseReferenceDirective(loc, ctx);
3018 case FormatToken::kw_type:
3019 return parseTypeDirective(loc, ctx);
3020 case FormatToken::kw_oilist:
3021 return parseOIListDirective(loc, ctx);
3023 default:
3024 return emitError(loc, "unsupported directive kind");
3028 FailureOr<FormatElement *>
3029 OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
3030 bool withKeyword) {
3031 if (context == TypeDirectiveContext)
3032 return emitError(loc, "'attr-dict' directive can only be used as a "
3033 "top-level directive");
3035 if (context == RefDirectiveContext) {
3036 if (!hasAttrDict)
3037 return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior "
3038 "'attr-dict' directive");
3040 // Otherwise, this is a top-level context.
3041 } else {
3042 if (hasAttrDict)
3043 return emitError(loc, "'attr-dict' directive has already been seen");
3044 hasAttrDict = true;
3047 return create<AttrDictDirective>(withKeyword);
3050 FailureOr<FormatElement *>
3051 OpFormatParser::parsePropDictDirective(SMLoc loc, Context context) {
3052 if (context == TypeDirectiveContext)
3053 return emitError(loc, "'prop-dict' directive can only be used as a "
3054 "top-level directive");
3056 if (context == RefDirectiveContext)
3057 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3058 // Otherwise, this is a top-level context.
3060 if (hasPropDict)
3061 return emitError(loc, "'prop-dict' directive has already been seen");
3062 hasPropDict = true;
3064 return create<PropDictDirective>();
3067 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
3068 SMLoc loc, ArrayRef<FormatElement *> arguments) {
3069 for (FormatElement *argument : arguments) {
3070 if (!isa<StringElement, RefDirective, TypeDirective, AttrDictDirective,
3071 AttributeVariable, OperandVariable, RegionVariable,
3072 SuccessorVariable>(argument)) {
3073 // TODO: FormatElement should have location info attached.
3074 return emitError(loc, "only variables and types may be used as "
3075 "parameters to a custom directive");
3077 if (auto *type = dyn_cast<TypeDirective>(argument)) {
3078 if (!isa<OperandVariable, ResultVariable>(type->getArg())) {
3079 return emitError(loc, "type directives within a custom directive may "
3080 "only refer to variables");
3084 return success();
3087 FailureOr<FormatElement *>
3088 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
3089 if (context != TopLevelContext)
3090 return emitError(
3091 loc, "'functional-type' is only valid as a top-level directive");
3093 // Parse the main operand.
3094 FailureOr<FormatElement *> inputs, results;
3095 if (failed(parseToken(FormatToken::l_paren,
3096 "expected '(' before argument list")) ||
3097 failed(inputs = parseTypeDirectiveOperand(loc)) ||
3098 failed(parseToken(FormatToken::comma,
3099 "expected ',' after inputs argument")) ||
3100 failed(results = parseTypeDirectiveOperand(loc)) ||
3101 failed(
3102 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3103 return failure();
3104 return create<FunctionalTypeDirective>(*inputs, *results);
3107 FailureOr<FormatElement *>
3108 OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
3109 if (context == RefDirectiveContext) {
3110 if (!fmt.allOperands)
3111 return emitError(loc, "'ref' of 'operands' is not bound by a prior "
3112 "'operands' directive");
3114 } else if (context == TopLevelContext || context == CustomDirectiveContext) {
3115 if (fmt.allOperands || !seenOperands.empty())
3116 return emitError(loc, "'operands' directive creates overlap in format");
3117 fmt.allOperands = true;
3119 return create<OperandsDirective>();
3122 FailureOr<FormatElement *>
3123 OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
3124 if (context != CustomDirectiveContext)
3125 return emitError(loc, "'ref' is only valid within a `custom` directive");
3127 FailureOr<FormatElement *> arg;
3128 if (failed(parseToken(FormatToken::l_paren,
3129 "expected '(' before argument list")) ||
3130 failed(arg = parseElement(RefDirectiveContext)) ||
3131 failed(
3132 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3133 return failure();
3135 return create<RefDirective>(*arg);
3138 FailureOr<FormatElement *>
3139 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
3140 if (context == TypeDirectiveContext)
3141 return emitError(loc, "'regions' is only valid as a top-level directive");
3142 if (context == RefDirectiveContext) {
3143 if (!hasAllRegions)
3144 return emitError(loc, "'ref' of 'regions' is not bound by a prior "
3145 "'regions' directive");
3147 // Otherwise, this is a TopLevel directive.
3148 } else {
3149 if (hasAllRegions || !seenRegions.empty())
3150 return emitError(loc, "'regions' directive creates overlap in format");
3151 hasAllRegions = true;
3153 return create<RegionsDirective>();
3156 FailureOr<FormatElement *>
3157 OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
3158 if (context != TypeDirectiveContext)
3159 return emitError(loc, "'results' directive can can only be used as a child "
3160 "to a 'type' directive");
3161 return create<ResultsDirective>();
3164 FailureOr<FormatElement *>
3165 OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
3166 if (context == TypeDirectiveContext)
3167 return emitError(loc,
3168 "'successors' is only valid as a top-level directive");
3169 if (context == RefDirectiveContext) {
3170 if (!hasAllSuccessors)
3171 return emitError(loc, "'ref' of 'successors' is not bound by a prior "
3172 "'successors' directive");
3174 // Otherwise, this is a TopLevel directive.
3175 } else {
3176 if (hasAllSuccessors || !seenSuccessors.empty())
3177 return emitError(loc, "'successors' directive creates overlap in format");
3178 hasAllSuccessors = true;
3180 return create<SuccessorsDirective>();
3183 FailureOr<FormatElement *>
3184 OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
3185 if (failed(parseToken(FormatToken::l_paren,
3186 "expected '(' before oilist argument list")))
3187 return failure();
3188 std::vector<FormatElement *> literalElements;
3189 std::vector<std::vector<FormatElement *>> parsingElements;
3190 do {
3191 FailureOr<FormatElement *> lelement = parseLiteral(context);
3192 if (failed(lelement))
3193 return failure();
3194 literalElements.push_back(*lelement);
3195 parsingElements.emplace_back();
3196 std::vector<FormatElement *> &currParsingElements = parsingElements.back();
3197 while (peekToken().getKind() != FormatToken::pipe &&
3198 peekToken().getKind() != FormatToken::r_paren) {
3199 FailureOr<FormatElement *> pelement = parseElement(context);
3200 if (failed(pelement) ||
3201 failed(verifyOIListParsingElement(*pelement, loc)))
3202 return failure();
3203 currParsingElements.push_back(*pelement);
3205 if (peekToken().getKind() == FormatToken::pipe) {
3206 consumeToken();
3207 continue;
3209 if (peekToken().getKind() == FormatToken::r_paren) {
3210 consumeToken();
3211 break;
3213 } while (true);
3215 return create<OIListElement>(std::move(literalElements),
3216 std::move(parsingElements));
3219 LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
3220 SMLoc loc) {
3221 SmallVector<VariableElement *> vars;
3222 collect(element, vars);
3223 for (VariableElement *elem : vars) {
3224 LogicalResult res =
3225 TypeSwitch<FormatElement *, LogicalResult>(elem)
3226 // Only optional attributes can be within an oilist parsing group.
3227 .Case([&](AttributeVariable *attrEle) {
3228 if (!attrEle->getVar()->attr.isOptional() &&
3229 !attrEle->getVar()->attr.hasDefaultValue())
3230 return emitError(loc, "only optional attributes can be used in "
3231 "an oilist parsing group");
3232 return success();
3234 // Only optional-like(i.e. variadic) operands can be within an
3235 // oilist parsing group.
3236 .Case([&](OperandVariable *ele) {
3237 if (!ele->getVar()->isVariableLength())
3238 return emitError(loc, "only variable length operands can be "
3239 "used within an oilist parsing group");
3240 return success();
3242 // Only optional-like(i.e. variadic) results can be within an oilist
3243 // parsing group.
3244 .Case([&](ResultVariable *ele) {
3245 if (!ele->getVar()->isVariableLength())
3246 return emitError(loc, "only variable length results can be "
3247 "used within an oilist parsing group");
3248 return success();
3250 .Case([&](RegionVariable *) { return success(); })
3251 .Default([&](FormatElement *) {
3252 return emitError(loc,
3253 "only literals, types, and variables can be "
3254 "used within an oilist group");
3256 if (failed(res))
3257 return failure();
3259 return success();
3262 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
3263 Context context) {
3264 if (context == TypeDirectiveContext)
3265 return emitError(loc, "'type' cannot be used as a child of another `type`");
3267 bool isRefChild = context == RefDirectiveContext;
3268 FailureOr<FormatElement *> operand;
3269 if (failed(parseToken(FormatToken::l_paren,
3270 "expected '(' before argument list")) ||
3271 failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
3272 failed(
3273 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3274 return failure();
3276 return create<TypeDirective>(*operand);
3279 FailureOr<FormatElement *>
3280 OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
3281 FailureOr<FormatElement *> element;
3282 if (failed(parseToken(FormatToken::l_paren,
3283 "expected '(' before argument list")) ||
3284 failed(element = parseElement(context)) ||
3285 failed(
3286 parseToken(FormatToken::r_paren, "expected ')' after argument list")))
3287 return failure();
3288 return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
3289 .Case<AttributeVariable, TypeDirective>([](auto *element) {
3290 element->setShouldBeQualified();
3291 return element;
3293 .Default([&](auto *element) {
3294 return this->emitError(
3295 loc,
3296 "'qualified' directive expects an attribute or a `type` directive");
3300 FailureOr<FormatElement *>
3301 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
3302 FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext);
3303 if (failed(result))
3304 return failure();
3306 FormatElement *element = *result;
3307 if (isa<LiteralElement>(element))
3308 return emitError(
3309 loc, "'type' directive operand expects variable or directive operand");
3311 if (auto *var = dyn_cast<OperandVariable>(element)) {
3312 unsigned opIdx = var->getVar() - op.operand_begin();
3313 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3314 return emitError(loc, "'type' of '" + var->getVar()->name +
3315 "' is already bound");
3316 if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
3317 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3318 ")' is not bound by a prior 'type' directive");
3319 seenOperandTypes.set(opIdx);
3320 } else if (auto *var = dyn_cast<ResultVariable>(element)) {
3321 unsigned resIdx = var->getVar() - op.result_begin();
3322 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
3323 return emitError(loc, "'type' of '" + var->getVar()->name +
3324 "' is already bound");
3325 if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx)))
3326 return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
3327 ")' is not bound by a prior 'type' directive");
3328 seenResultTypes.set(resIdx);
3329 } else if (isa<OperandsDirective>(&*element)) {
3330 if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any()))
3331 return emitError(loc, "'operands' 'type' is already bound");
3332 if (isRefChild && !fmt.allOperandTypes)
3333 return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior "
3334 "'type' directive");
3335 fmt.allOperandTypes = true;
3336 } else if (isa<ResultsDirective>(&*element)) {
3337 if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any()))
3338 return emitError(loc, "'results' 'type' is already bound");
3339 if (isRefChild && !fmt.allResultTypes)
3340 return emitError(loc, "'ref' of 'type(results)' is not bound by a prior "
3341 "'type' directive");
3342 fmt.allResultTypes = true;
3343 } else {
3344 return emitError(loc, "invalid argument to 'type' directive");
3346 return element;
3349 LogicalResult OpFormatParser::verifyOptionalGroupElements(
3350 SMLoc loc, ArrayRef<FormatElement *> elements, FormatElement *anchor) {
3351 for (FormatElement *element : elements) {
3352 if (failed(verifyOptionalGroupElement(loc, element, element == anchor)))
3353 return failure();
3355 return success();
3358 LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
3359 FormatElement *element,
3360 bool isAnchor) {
3361 return TypeSwitch<FormatElement *, LogicalResult>(element)
3362 // All attributes can be within the optional group, but only optional
3363 // attributes can be the anchor.
3364 .Case([&](AttributeVariable *attrEle) {
3365 Attribute attr = attrEle->getVar()->attr;
3366 if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
3367 return emitError(loc, "only optional or default-valued attributes "
3368 "can be used to anchor an optional group");
3369 return success();
3371 // Only optional-like(i.e. variadic) operands can be within an optional
3372 // group.
3373 .Case([&](OperandVariable *ele) {
3374 if (!ele->getVar()->isVariableLength())
3375 return emitError(loc, "only variable length operands can be used "
3376 "within an optional group");
3377 return success();
3379 // Only optional-like(i.e. variadic) results can be within an optional
3380 // group.
3381 .Case([&](ResultVariable *ele) {
3382 if (!ele->getVar()->isVariableLength())
3383 return emitError(loc, "only variable length results can be used "
3384 "within an optional group");
3385 return success();
3387 .Case([&](RegionVariable *) {
3388 // TODO: When ODS has proper support for marking "optional" regions, add
3389 // a check here.
3390 return success();
3392 .Case([&](TypeDirective *ele) {
3393 return verifyOptionalGroupElement(loc, ele->getArg(),
3394 /*isAnchor=*/false);
3396 .Case([&](FunctionalTypeDirective *ele) {
3397 if (failed(verifyOptionalGroupElement(loc, ele->getInputs(),
3398 /*isAnchor=*/false)))
3399 return failure();
3400 return verifyOptionalGroupElement(loc, ele->getResults(),
3401 /*isAnchor=*/false);
3403 // Literals, whitespace, and custom directives may be used, but they can't
3404 // anchor the group.
3405 .Case<LiteralElement, WhitespaceElement, CustomDirective,
3406 FunctionalTypeDirective, OptionalElement>([&](FormatElement *) {
3407 if (isAnchor)
3408 return emitError(loc, "only variables and types can be used "
3409 "to anchor an optional group");
3410 return success();
3412 .Default([&](FormatElement *) {
3413 return emitError(loc, "only literals, types, and variables can be "
3414 "used within an optional group");
3418 //===----------------------------------------------------------------------===//
3419 // Interface
3420 //===----------------------------------------------------------------------===//
3422 void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
3423 // TODO: Operator doesn't expose all necessary functionality via
3424 // the const interface.
3425 Operator &op = const_cast<Operator &>(constOp);
3426 if (!op.hasAssemblyFormat())
3427 return;
3429 // Parse the format description.
3430 llvm::SourceMgr mgr;
3431 mgr.AddNewSourceBuffer(
3432 llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc());
3433 OperationFormat format(op);
3434 OpFormatParser parser(mgr, format, op);
3435 FailureOr<std::vector<FormatElement *>> elements = parser.parse();
3436 if (failed(elements)) {
3437 // Exit the process if format errors are treated as fatal.
3438 if (formatErrorIsFatal) {
3439 // Invoke the interrupt handlers to run the file cleanup handlers.
3440 llvm::sys::RunInterruptHandlers();
3441 std::exit(1);
3443 return;
3445 format.elements = std::move(*elements);
3447 // Generate the printer and parser based on the parsed format.
3448 format.genParser(op, opClass);
3449 format.genPrinter(op, opClass);