1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
30 using namespace mlir::tblgen
;
32 //===----------------------------------------------------------------------===//
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT
, VariableElement::Kind VariableKind
>
40 class OpVariableElement
: public VariableElementBase
<VariableKind
> {
42 using Base
= OpVariableElement
<VarT
, VariableKind
>;
44 /// Create an op variable element with the variable value.
45 OpVariableElement(const VarT
*var
) : var(var
) {}
48 const VarT
*getVar() { return var
; }
51 /// The op variable, e.g. a type or attribute constraint.
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57 : public OpVariableElement
<NamedAttribute
, VariableElement::Attribute
> {
60 /// Return the constant builder call for the type of this attribute, or
61 /// std::nullopt if it doesn't have one.
62 std::optional
<StringRef
> getTypeBuilder() const {
63 std::optional
<Type
> attrType
= var
->attr
.getValueType();
64 return attrType
? attrType
->getBuilderCall() : std::nullopt
;
67 /// Return if this attribute refers to a UnitAttr.
68 bool isUnitAttr() const {
69 return var
->attr
.getBaseAttr().getAttrDefName() == "UnitAttr";
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
74 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
75 void setShouldBeQualified(bool qualified
= true) {
76 shouldBeQualifiedFlag
= qualified
;
80 bool shouldBeQualifiedFlag
= false;
83 /// This class represents a variable that refers to an operand argument.
84 using OperandVariable
=
85 OpVariableElement
<NamedTypeConstraint
, VariableElement::Operand
>;
87 /// This class represents a variable that refers to a result.
88 using ResultVariable
=
89 OpVariableElement
<NamedTypeConstraint
, VariableElement::Result
>;
91 /// This class represents a variable that refers to a region.
92 using RegionVariable
= OpVariableElement
<NamedRegion
, VariableElement::Region
>;
94 /// This class represents a variable that refers to a successor.
95 using SuccessorVariable
=
96 OpVariableElement
<NamedSuccessor
, VariableElement::Successor
>;
98 /// This class represents a variable that refers to a property argument.
99 using PropertyVariable
=
100 OpVariableElement
<NamedProperty
, VariableElement::Property
>;
103 //===----------------------------------------------------------------------===//
107 /// This class represents the `operands` directive. This directive represents
108 /// all of the operands of an operation.
109 using OperandsDirective
= DirectiveElementBase
<DirectiveElement::Operands
>;
111 /// This class represents the `results` directive. This directive represents
112 /// all of the results of an operation.
113 using ResultsDirective
= DirectiveElementBase
<DirectiveElement::Results
>;
115 /// This class represents the `regions` directive. This directive represents
116 /// all of the regions of an operation.
117 using RegionsDirective
= DirectiveElementBase
<DirectiveElement::Regions
>;
119 /// This class represents the `successors` directive. This directive represents
120 /// all of the successors of an operation.
121 using SuccessorsDirective
= DirectiveElementBase
<DirectiveElement::Successors
>;
123 /// This class represents the `attr-dict` directive. This directive represents
124 /// the attribute dictionary of the operation.
125 class AttrDictDirective
126 : public DirectiveElementBase
<DirectiveElement::AttrDict
> {
128 explicit AttrDictDirective(bool withKeyword
) : withKeyword(withKeyword
) {}
130 /// Return whether the dictionary should be printed with the 'attributes'
132 bool isWithKeyword() const { return withKeyword
; }
135 /// If the dictionary should be printed with the 'attributes' keyword.
139 /// This class represents the `prop-dict` directive. This directive represents
140 /// the properties of the operation, expressed as a directionary.
141 class PropDictDirective
142 : public DirectiveElementBase
<DirectiveElement::PropDict
> {
144 explicit PropDictDirective() = default;
147 /// This class represents the `functional-type` directive. This directive takes
148 /// two arguments and formats them, respectively, as the inputs and results of a
150 class FunctionalTypeDirective
151 : public DirectiveElementBase
<DirectiveElement::FunctionalType
> {
153 FunctionalTypeDirective(FormatElement
*inputs
, FormatElement
*results
)
154 : inputs(inputs
), results(results
) {}
156 FormatElement
*getInputs() const { return inputs
; }
157 FormatElement
*getResults() const { return results
; }
160 /// The input and result arguments.
161 FormatElement
*inputs
, *results
;
164 /// This class represents the `type` directive.
165 class TypeDirective
: public DirectiveElementBase
<DirectiveElement::Type
> {
167 TypeDirective(FormatElement
*arg
) : arg(arg
) {}
169 FormatElement
*getArg() const { return arg
; }
171 /// Indicate if this type is printed "qualified" (that is it is
172 /// prefixed with the `!dialect.mnemonic`).
173 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
174 void setShouldBeQualified(bool qualified
= true) {
175 shouldBeQualifiedFlag
= qualified
;
179 /// The argument that is used to format the directive.
182 bool shouldBeQualifiedFlag
= false;
185 /// This class represents a group of order-independent optional clauses. Each
186 /// clause starts with a literal element and has a coressponding parsing
187 /// element. A parsing element is a continous sequence of format elements.
188 /// Each clause can appear 0 or 1 time.
189 class OIListElement
: public DirectiveElementBase
<DirectiveElement::OIList
> {
191 OIListElement(std::vector
<FormatElement
*> &&literalElements
,
192 std::vector
<std::vector
<FormatElement
*>> &&parsingElements
)
193 : literalElements(std::move(literalElements
)),
194 parsingElements(std::move(parsingElements
)) {}
196 /// Returns a range to iterate over the LiteralElements.
197 auto getLiteralElements() const {
198 function_ref
<LiteralElement
*(FormatElement
* el
)>
199 literalElementCastConverter
=
200 [](FormatElement
*el
) { return cast
<LiteralElement
>(el
); };
201 return llvm::map_range(literalElements
, literalElementCastConverter
);
204 /// Returns a range to iterate over the parsing elements corresponding to the
206 ArrayRef
<std::vector
<FormatElement
*>> getParsingElements() const {
207 return parsingElements
;
210 /// Returns a range to iterate over tuples of parsing and literal elements.
211 auto getClauses() const {
212 return llvm::zip(getLiteralElements(), getParsingElements());
215 /// If the parsing element is a single UnitAttr element, then it returns the
216 /// attribute variable. Otherwise, returns nullptr.
218 getUnitAttrParsingElement(ArrayRef
<FormatElement
*> pelement
) {
219 if (pelement
.size() == 1) {
220 auto *attrElem
= dyn_cast
<AttributeVariable
>(pelement
[0]);
221 if (attrElem
&& attrElem
->isUnitAttr())
228 /// A vector of `LiteralElement` objects. Each element stores the keyword
229 /// for one case of oilist element. For example, an oilist element along with
230 /// the `literalElements` vector:
232 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
233 /// literalElements = { `keyword`, `otherKeyword` }
235 std::vector
<FormatElement
*> literalElements
;
237 /// A vector of valid declarative assembly format vectors. Each object in
238 /// parsing elements is a vector of elements in assembly format syntax.
239 /// For example, an oilist element along with the parsingElements vector:
241 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
242 /// parsingElements = {
243 /// { `=`, `(`, $arg0, `)` },
244 /// { `<`, $arg1, `>` }
247 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
251 //===----------------------------------------------------------------------===//
253 //===----------------------------------------------------------------------===//
// NOTE(review): this block appears to be a damaged copy of the original
// source. Each line starts with a number from the original file, tokens are
// split across lines, and several original lines are missing (the numbers
// jump). The missing lines include access specifiers, closing braces, the
// GenContext enumerators, and the member declarations under the "properties"
// and "prop-dict" comments. Restore it from the canonical source before
// building.
/// A type can be resolved from either an attribute or an operand/result type
/// constraint; TypeResolution stores one of these as its resolver.
257 using ConstArgument
=
258 llvm::PointerUnion
<const NamedAttribute
*, const NamedTypeConstraint
*>;
/// Holds the parsed elements of an operation's declarative assembly format,
/// together with the bookkeeping used to generate the op's parser and printer.
260 struct OperationFormat
{
261 /// This class represents a specific resolver for an operand or result type.
262 class TypeResolution
{
264 TypeResolution() = default;
266 /// Get the index into the buildable types for this type, or std::nullopt.
267 std::optional
<int> getBuilderIdx() const { return builderIdx
; }
268 void setBuilderIdx(int idx
) { builderIdx
= idx
; }
270 /// Get the variable this type is resolved to, or nullptr.
271 const NamedTypeConstraint
*getVariable() const {
272 return llvm::dyn_cast_if_present
<const NamedTypeConstraint
*>(resolver
);
274 /// Get the attribute this type is resolved to, or nullptr.
275 const NamedAttribute
*getAttribute() const {
276 return llvm::dyn_cast_if_present
<const NamedAttribute
*>(resolver
);
278 /// Get the transformer for the type of the variable, or std::nullopt.
279 std::optional
<StringRef
> getVarTransformer() const {
280 return variableTransformer
;
// NOTE(review): no assignment to `resolver` is visible in setResolver (the
// original line between 282 and 284 is missing), yet the assert below expects
// a resolver to have been set.
282 void setResolver(ConstArgument arg
, std::optional
<StringRef
> transformer
) {
284 variableTransformer
= transformer
;
285 assert(getVariable() || getAttribute());
289 /// If the type is resolved with a buildable type, this is the index into
290 /// 'buildableTypes' in the parent format.
291 std::optional
<int> builderIdx
;
292 /// If the type is resolved based upon another operand or result, this is
293 /// the variable or the attribute that this type is resolved to.
294 ConstArgument resolver
;
295 /// If the type is resolved based upon another operand or result, this is
296 /// a transformer to apply to the variable when resolving.
297 std::optional
<StringRef
> variableTransformer
;
// NOTE(review): only `GenContext::Normal` is named anywhere in this copy (as
// the default argument of genElementParser); the enumerator lines are missing.
300 /// The context in which an element is generated.
301 enum class GenContext
{
302 /// The element is generated at the top-level or with the same behaviour.
304 /// The element is generated inside an optional group.
/// Builds the format state for `op`:
/// - attributes are stored in properties only when the dialect opts in and
///   the op has attributes;
/// - one TypeResolution slot is reserved per operand and per result;
/// - the SingleBlockImplicitTerminatorImpl and SingleBlock traits are
///   recorded.
308 OperationFormat(const Operator
&op
)
309 : useProperties(op
.getDialect().usePropertiesForAttributes() &&
310 !op
.getAttributes().empty()),
311 opCppClassName(op
.getCppClassName()) {
312 operandTypes
.resize(op
.getNumOperands(), TypeResolution());
313 resultTypes
.resize(op
.getNumResults(), TypeResolution());
315 hasImplicitTermTrait
= llvm::any_of(op
.getTraits(), [](const Trait
&trait
) {
316 return trait
.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
319 hasSingleBlockTrait
= op
.getTrait("::mlir::OpTrait::SingleBlock");
322 /// Generate the operation parser from this format.
323 void genParser(Operator
&op
, OpClass
&opClass
);
324 /// Generate the parser code for a specific format element.
325 void genElementParser(FormatElement
*element
, MethodBody
&body
,
326 FmtContext
&attrTypeCtx
,
327 GenContext genCtx
= GenContext::Normal
);
328 /// Generate the C++ to resolve the types of operands and results during
/// parsing.
330 void genParserTypeResolution(Operator
&op
, MethodBody
&body
);
331 /// Generate the C++ to resolve the types of the operands during parsing.
332 void genParserOperandTypeResolution(
333 Operator
&op
, MethodBody
&body
,
334 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
);
335 /// Generate the C++ to resolve regions during parsing.
336 void genParserRegionResolution(Operator
&op
, MethodBody
&body
);
337 /// Generate the C++ to resolve successors during parsing.
338 void genParserSuccessorResolution(Operator
&op
, MethodBody
&body
);
339 /// Generate the C++ to handle variadic segment size traits.
340 void genParserVariadicSegmentResolution(Operator
&op
, MethodBody
&body
);
342 /// Generate the operation printer from this format.
343 void genPrinter(Operator
&op
, OpClass
&opClass
);
345 /// Generate the printer code for a specific format element.
346 void genElementPrinter(FormatElement
*element
, MethodBody
&body
, Operator
&op
,
347 bool &shouldEmitSpace
, bool &lastWasPunctuation
);
349 /// The various elements in this format.
350 std::vector
<FormatElement
*> elements
;
352 /// A flag indicating if all operand/result types were seen. If the format
353 /// contains these, it can not contain individual type resolvers.
354 bool allOperands
= false, allOperandTypes
= false, allResultTypes
= false;
356 /// A flag indicating if this operation infers its result types
357 bool infersResultTypes
= false;
359 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
/// trait.
361 bool hasImplicitTermTrait
;
363 /// A flag indicating if this operation has the SingleBlock trait.
364 bool hasSingleBlockTrait
;
366 /// Indicate whether attributes are stored in properties.
// NOTE(review): the declaration documented above (`useProperties`, which the
// constructor initializes) is missing from this copy.
369 /// Indicate whether prop-dict is used in the format
// NOTE(review): the declaration for the prop-dict flag is missing from this
// copy.
372 /// The Operation class name
373 StringRef opCppClassName
;
375 /// A map of buildable types to indices.
376 llvm::MapVector
<StringRef
, int, llvm::StringMap
<int>> buildableTypes
;
378 /// The index of the buildable type, if valid, for every operand and result.
379 std::vector
<TypeResolution
> operandTypes
, resultTypes
;
381 /// The set of attributes explicitly used within the format.
382 llvm::SmallSetVector
<const NamedAttribute
*, 8> usedAttributes
;
383 llvm::StringSet
<> inferredAttributes
;
385 /// The set of properties explicitly used within the format.
386 llvm::SmallSetVector
<const NamedProperty
*, 8> usedProperties
;
390 //===----------------------------------------------------------------------===//
393 /// Returns true if we can format the given attribute as an EnumAttr in the
395 static bool canFormatEnumAttr(const NamedAttribute
*attr
) {
396 Attribute baseAttr
= attr
->attr
.getBaseAttr();
397 const EnumAttr
*enumAttr
= dyn_cast
<EnumAttr
>(&baseAttr
);
401 // The attribute must have a valid underlying type and a constant builder.
402 return !enumAttr
->getUnderlyingType().empty() &&
403 !enumAttr
->getConstBuilderTemplate().empty();
406 /// Returns if we should format the given attribute as an SymbolNameAttr.
407 static bool shouldFormatSymbolNameAttr(const NamedAttribute
*attr
) {
408 return attr
->attr
.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
/// The code snippet used to generate a parser call for an attribute, using
/// the custom syntax with a generic fallback.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
    return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for an attribute using the
/// generic attribute syntax.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
///
/// NOTE(review): in this copy, `R` is split from `"(` and the closing `)";` is
/// missing. The visible snippet ends in an `if` with no body; presumably the
/// generator appends the body. Confirm against the canonical source.
430 const char *const optionalAttrParserCode
= R
"(
431 ::mlir::OptionalParseResult parseResult{0}Attr =
432 parser.parseOptionalAttribute({0}Attr, {1});
433 if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
434 return ::mlir::failure();
435 if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for an optional symbol
/// name attribute.
///
/// {0}: The name of the attribute.
///
/// NOTE(review): in this copy, `R` is split from `"(`, part of the embedded
/// comment text is missing, and the closing `)";` is missing. Restore from the
/// canonical source.
445 const char *const optionalSymbolNameAttrParserCode
= R
"(
446 // Parsing an optional symbol name doesn't fail, so no need to check the
448 (void)parser.parseOptionalSymbolName({0}Attr);
451 /// The code snippet used to generate a parser call for an enum attribute.
453 /// {0}: The name of the attribute.
454 /// {1}: The C++ namespace for the enum symbolize functions.
455 /// {2}: The function to symbolize a string of the enum.
456 /// {3}: The constant builder call to create an attribute of the enum type.
457 /// {4}: The set of allowed enum keywords.
458 /// {5}: The error message on failure when the enum isn't present.
459 /// {6}: The attribute assignment expression
/// NOTE(review): much of the snippet below is missing from this copy: the
/// code that uses {3}, {5} and {6}, several braces, and the closing `)";`.
/// The visible text also splits the `R"(` prefix and the error-message string
/// literals across lines. Restore from the canonical source.
460 const char *const enumAttrParserCode
= R
"(
462 ::llvm::StringRef attrStr;
463 ::mlir::NamedAttrList attrStorage;
464 auto loc = parser.getCurrentLocation();
465 if (parser.parseOptionalKeyword(&attrStr, {4})) {
466 ::mlir::StringAttr attrVal;
467 ::mlir::OptionalParseResult parseResult =
468 parser.parseOptionalAttribute(attrVal,
469 parser.getBuilder().getNoneType(),
471 if (parseResult.has_value()) {{
472 if (failed(*parseResult))
473 return ::mlir::failure();
474 attrStr = attrVal.getValue();
479 if (!attrStr.empty()) {
480 auto attrOptional = {1}::{2}(attrStr);
482 return parser.emitError(loc, "invalid
")
483 << "{0} attribute specification
: \"" << attrStr << '"';;
/// The code snippet used to generate a parser call for a variadic operand
/// list.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional operand.
/// The generated code is wrapped in its own block scope so that its fixed
/// local names (`operand`, `parseResult`) cannot clash with other generated
/// code in the same function.
///
/// {0}: The name of the operand.
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";

/// The code snippet used to generate a parser call for a single operand.
///
/// {0}: The name of the operand.
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperand))
    return ::mlir::failure();
)";
// NOTE(review): both snippets in this block are incomplete in this copy. The
// missing lines include the declaration of `curSize`, the `do {` heads that
// match the visible `} while (...)` lines, the statements guarded by the
// `parseOptionalLParen()` checks, and each closing `)";`. Restore them from the
// canonical source. {1} is documented for the operand snippet but does not
// appear in its visible text.
517 /// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand.
520 /// {0}: The name of the operand.
521 /// {1}: The name of segment size attribute.
522 const char *const variadicOfVariadicOperandParserCode = R"(
524 {0}OperandsLoc = parser.getCurrentLocation();
527 if (parser.parseOptionalLParen())
529 if (parser.parseOperandList({0}Operands) || parser.parseRParen())
530 return ::mlir::failure();
531 {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
532 curSize = {0}Operands.size();
533 } while (succeeded(parser.parseOptionalComma()));
537 /// The code snippet used to generate a parser call for a type list.
539 /// {0}: The name for the type list.
540 const char *const variadicOfVariadicTypeParserCode = R"(
542 if (parser.parseOptionalLParen())
544 if (parser.parseOptionalRParen() &&
545 (parser.parseTypeList({0}Types) || parser.parseRParen()))
546 return ::mlir::failure();
547 } while (succeeded(parser.parseOptionalComma()));
/// The code snippet used to generate a parser call for a variadic type list.
///
/// {0}: The name for the type list.
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional type. The
/// generated code gets its own block scope because it declares fixed local
/// names (`optionalType`, `parseResult`).
///
/// {0}: The name for the type list.
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// The code snippet used to generate a parser call for a type using the
/// custom syntax with a generic fallback; a parse failure fails the op parse.
/// NOTE(review): in this copy, several parts are missing: the snippet's
/// declaration of `type`, whatever follows the parse call, and the closing
/// `)";`. The embedded line numbers jump from 565 to 568 and stop at 569.
/// Restore from the canonical source.
565 const char *const typeParserCode = R"(
568 if (parser.parseCustomTypeWithFallback(type))
569 return ::mlir::failure();
/// The code snippet used to generate a parser call for a type in its fully
/// qualified form.
///
/// {1}: The name of the type list; the parsed type is stored into
///      `{1}RawType`. (Only placeholder {1} is referenced.)
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawType))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";

/// The code snippet used to generate a parser call to infer return types.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
/// The code snippet used to generate a parser call for a region list. The
/// generated code has its own block scope because it declares the fixed local
/// names `region` and `firstRegionResult`.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
    auto parseResult = parser.parseOptionalRegion(*{0}Region);
    if (parseResult.has_value() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";
/// The code snippet used to generate a parser call for a successor list. The
/// generated code has its own block scope because it declares the fixed local
/// names `succ` and `firstSucc`.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.has_value()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
/// Snippet that reports an error at the operation name location when the
/// oilist clause {0} appears more than once.
/// NOTE(review): damaged copy. The condition guarding the error, the end of
/// the message, and the closing `)";` are missing. Restore from the canonical
/// source.
706 /// The code snippet used to generate a parser for OIList
708 /// {0}: literal keyword corresponding to a case for oilist
709 const char *oilistParserCode = R"(
711 return parser.emitError(parser.getNameLoc())
712 << "`{0}` clause can appear at most once in the expansion of the "
/// The type of length for a given parse argument.
enum class ArgumentLengthKind {
  /// The argument is a variadic of a variadic, and may contain 0->N range
  /// elements.
  VariadicOfVariadic,
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
733 /// Get the length kind for the given constraint.
734 static ArgumentLengthKind
735 getArgumentLengthKind(const NamedTypeConstraint *var) {
736 if (var->isOptional())
737 return ArgumentLengthKind::Optional;
738 if (var->isVariadicOfVariadic())
739 return ArgumentLengthKind::VariadicOfVariadic;
740 if (var->isVariadic())
741 return ArgumentLengthKind::Variadic;
742 return ArgumentLengthKind::Single;
// NOTE(review): damaged copy. The `return` statements after the
// `isa<OperandsDirective>` / `isa<ResultsDirective>` checks and several
// closing braces are missing. The doc comment and the final string literal
// are also split across lines. Restore from the canonical source.
//
// What the visible code does:
// - For operand/result variables, returns the variable's name and sets
//   `lengthKind` from its constraint.
// - For the `operands`/`results` directives, sets `lengthKind` to Variadic.
// - Any other argument is unreachable.
745 /// Get the name used for the type list for the given type directive operand.
746 /// 'lengthKind
' to the corresponding kind for the given argument.
747 static StringRef getTypeListName(FormatElement *arg,
748 ArgumentLengthKind &lengthKind) {
749 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
750 lengthKind = getArgumentLengthKind(operand->getVar());
751 return operand->getVar()->name;
753 if (auto *result = dyn_cast<ResultVariable>(arg)) {
754 lengthKind = getArgumentLengthKind(result->getVar());
755 return result->getVar()->name;
757 lengthKind = ArgumentLengthKind::Variadic;
758 if (isa<OperandsDirective>(arg))
760 if (isa<ResultsDirective>(arg))
762 llvm_unreachable("unknown 'type
' directive argument");
// NOTE(review): damaged copy. Lines carry original line numbers and tokens
// are split. Missing lines include the keyword branch's early exit and closing
// brace, one `Case` between `=` and `>`, two `Case`s between `?` and `...`,
// and the final closing brace. Restore from the canonical source.
//
// Emits the parser-method suffix for literal `value`:
// - A keyword or identifier (first char `_` or a letter) becomes
//   `Keyword("value")`.
// - Known punctuation maps to a name such as `Arrow()` or `LParen()`.
// NOTE(review): assumes `value` is non-empty ASCII. Calling `value.front()` on
// an empty literal, or `isalpha` on a negative `char`, would be undefined.
// Presumably the format parser guarantees this; confirm.
765 /// Generate the parser for a literal value.
766 static void genLiteralParser(StringRef value, MethodBody &body) {
767 // Handle the case of a keyword/identifier.
768 if (value.front() == '_
' || isalpha(value.front())) {
769 body << "Keyword(\"" << value << "\")";
// Punctuation: map the literal to the matching parse-method name.
772 body << (StringRef)StringSwitch<StringRef>(value)
773 .Case("->", "Arrow()")
774 .Case(":", "Colon()")
775 .Case(",", "Comma()")
776 .Case("=", "Equal()")
778 .Case(">", "Greater()")
779 .Case("{", "LBrace()")
780 .Case("}", "RBrace()")
781 .Case("(", "LParen()")
782 .Case(")", "RParen()")
783 .Case("[", "LSquare()")
784 .Case("]", "RSquare()")
785 .Case("?", "Question()")
788 .Case("...", "Ellipsis()");
// NOTE(review): damaged copy of this function. Every line starts with a number
// from the original file, tokens are split across lines, and many original
// lines are missing (the numbers jump). The missing lines include:
// - the line declaring the `body` parameter;
// - else-branches;
// - the variable names in the `operands`/`regions` declarations;
// - several formatv arguments and closing braces.
// Restore it from the canonical source before building.
//
// What the visible code does: for each kind of format element, it streams
// C++ local-variable declarations into `body`. The generated parser uses
// these as storage, e.g. `<name>Attr`, `<name>Operands`, `<name>OperandsLoc`,
// `<name>Region(s)`, `<name>Successor(s)` and `<name>Types`. Optional groups,
// oilist clauses and custom directives recurse into their child elements.
791 /// Generate the storage code required for parsing the given element.
792 static void genElementParserStorage(FormatElement *element, const Operator &op,
794 if (auto *optional = dyn_cast<OptionalElement>(element)) {
795 ArrayRef<FormatElement *> elements = optional->getThenElements();
797 // If the anchor is a unit attribute, it won't be parsed directly so elide
// NOTE(review): the rest of the comment above is missing. As written, the
// anchor is skipped only when it is a UnitAttr that is not the first element
// of the `then` group.
799 auto *anchor
= dyn_cast
<AttributeVariable
>(optional
->getAnchor());
800 FormatElement
*elidedAnchorElement
= nullptr;
801 if (anchor
&& anchor
!= elements
.front() && anchor
->isUnitAttr())
802 elidedAnchorElement
= anchor
;
803 for (FormatElement
*childElement
: elements
)
804 if (childElement
!= elidedAnchorElement
)
805 genElementParserStorage(childElement
, op
, body
);
806 for (FormatElement
*childElement
: optional
->getElseElements())
807 genElementParserStorage(childElement
, op
, body
);
809 } else if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
810 for (ArrayRef
<FormatElement
*> pelement
: oilist
->getParsingElements()) {
811 if (!oilist
->getUnitAttrParsingElement(pelement
))
// Clauses consisting of a lone UnitAttr get no storage. The loop variable
// below shadows the `element` parameter, which is harmless but easy to misread.
812 for (FormatElement
*element
: pelement
)
813 genElementParserStorage(element
, op
, body
);
816 } else if (auto *custom
= dyn_cast
<CustomDirective
>(element
)) {
817 for (FormatElement
*paramElement
: custom
->getArguments())
818 genElementParserStorage(paramElement
, op
, body
);
// NOTE(review): the variable names that complete the `operands` and `regions`
// declarations below are missing from this copy.
820 } else if (isa
<OperandsDirective
>(element
)) {
821 body
<< " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
824 } else if (isa
<RegionsDirective
>(element
)) {
825 body
<< " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
828 } else if (isa
<SuccessorsDirective
>(element
)) {
829 body
<< " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
831 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
832 const NamedAttribute
*var
= attr
->getVar();
833 body
<< llvm::formatv(" {0} {1}Attr;\n", var
->attr
.getStorageType(),
// NOTE(review): the formatv call above is missing its `{1}` argument and its
// closing in this copy.
836 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
837 StringRef name
= operand
->getVar()->name
;
838 if (operand
->getVar()->isVariableLength()) {
840 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
841 << name
<< "Operands;\n";
842 if (operand
->getVar()->isVariadicOfVariadic()) {
843 body
<< " llvm::SmallVector<int32_t> " << name
844 << "OperandGroupSizes;\n";
847 body
<< " ::mlir::OpAsmParser::UnresolvedOperand " << name
849 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
850 << name
<< "Operands(&" << name
<< "RawOperand, 1);";
852 body
<< llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
853 " (void){0}OperandsLoc;\n",
856 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
857 StringRef name
= region
->getVar()->name
;
858 if (region
->getVar()->isVariadic()) {
859 body
<< llvm::formatv(
860 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
864 body
<< llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
865 "std::make_unique<::mlir::Region>();\n",
869 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
870 StringRef name
= successor
->getVar()->name
;
871 if (successor
->getVar()->isVariadic()) {
872 body
<< llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
876 body
<< llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name
);
879 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
880 ArgumentLengthKind lengthKind
;
881 StringRef name
= getTypeListName(dir
->getArg(), lengthKind
);
882 if (lengthKind
!= ArgumentLengthKind::Single
)
883 body
<< " ::llvm::SmallVector<::mlir::Type, 1> " << name
<< "Types;\n";
886 << llvm::formatv(" ::mlir::Type {0}RawType{{};\n", name
)
888 " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
890 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
891 ArgumentLengthKind ignored
;
892 body
<< " ::llvm::ArrayRef<::mlir::Type> "
893 << getTypeListName(dir
->getInputs(), ignored
) << "Types;\n";
894 body
<< " ::llvm::ArrayRef<::mlir::Type> "
895 << getTypeListName(dir
->getResults(), ignored
) << "Types;\n";
// NOTE(review): damaged copy. Each line starts with a number from the original
// file, tokens are split across lines, and original lines are missing (the
// numbers jump). The missing lines include:
// - the body of the `prop-dict` branch;
// - the `else` keywords before the fallback
//   `RawOperand`/`Region`/`Successor`/`RawType` cases;
// - the declaration of `ctx`;
// - the final `else` and closing braces.
// Restore from the canonical source before building.
//
// What the visible code does: streams into `body` the C++ expression that the
// generated parser passes for one custom-directive argument. That expression
// is one of:
// - the storage variable named after the element (e.g. `<name>Attr`,
//   `<name>Operands`, `*<name>Region`);
// - `result.attributes` for `attr-dict`;
// - a field of `result.getOrAddProperties<Properties>()` for property
//   variables;
// - a substituted string literal.
// `ref(...)` arguments recurse on their inner element.
899 /// Generate the parser for a parameter to a custom directive.
900 static void genCustomParameterParser(FormatElement
*param
, MethodBody
&body
) {
901 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
902 body
<< attr
->getVar()->name
<< "Attr";
903 } else if (isa
<AttrDictDirective
>(param
)) {
904 body
<< "result.attributes";
905 } else if (isa
<PropDictDirective
>(param
)) {
// NOTE(review): the body of the `prop-dict` branch appears to be missing here.
907 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
908 StringRef name
= operand
->getVar()->name
;
909 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
910 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
911 body
<< llvm::formatv("{0}OperandGroups", name
);
912 else if (lengthKind
== ArgumentLengthKind::Variadic
)
913 body
<< llvm::formatv("{0}Operands", name
);
914 else if (lengthKind
== ArgumentLengthKind::Optional
)
915 body
<< llvm::formatv("{0}Operand", name
);
917 body
<< formatv("{0}RawOperand", name
);
919 } else if (auto *region
= dyn_cast
<RegionVariable
>(param
)) {
920 StringRef name
= region
->getVar()->name
;
921 if (region
->getVar()->isVariadic())
922 body
<< llvm::formatv("{0}Regions", name
);
924 body
<< llvm::formatv("*{0}Region", name
);
926 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(param
)) {
927 StringRef name
= successor
->getVar()->name
;
928 if (successor
->getVar()->isVariadic())
929 body
<< llvm::formatv("{0}Successors", name
);
931 body
<< llvm::formatv("{0}Successor", name
);
933 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
934 genCustomParameterParser(dir
->getArg(), body
);
936 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
937 ArgumentLengthKind lengthKind
;
938 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
939 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
940 body
<< llvm::formatv("{0}TypeGroups", listName
);
941 else if (lengthKind
== ArgumentLengthKind::Variadic
)
942 body
<< llvm::formatv("{0}Types", listName
);
943 else if (lengthKind
== ArgumentLengthKind::Optional
)
944 body
<< llvm::formatv("{0}Type", listName
);
946 body
<< formatv("{0}RawType", listName
);
948 } else if (auto *string
= dyn_cast
<StringElement
>(param
)) {
// NOTE(review): `ctx` is used below, but its declaration is missing from this
// copy. Judging by withBuilder/addSubst/tgfmt, it is presumably an FmtContext;
// confirm.
950 ctx
.withBuilder("parser.getBuilder()");
951 ctx
.addSubst("_ctxt", "parser.getContext()");
952 body
<< tgfmt(string
->getValue(), &ctx
);
954 } else if (auto *property
= dyn_cast
<PropertyVariable
>(param
)) {
955 body
<< llvm::formatv("result.getOrAddProperties<Properties>().{0}",
956 property
->getVar()->name
);
// NOTE(review): the `} else {` that should guard the next line is missing. As
// shown, it would run unconditionally at the end of the property branch.
958 llvm_unreachable("unknown custom directive parameter");
962 /// Generate the parser for a custom directive.
963 static void genCustomDirectiveParser(CustomDirective
*dir
, MethodBody
&body
,
965 StringRef opCppClassName
,
966 bool isOptional
= false) {
969 // Preprocess the directive variables.
970 // * Add a local variable for optional operands and types. This provides a
971 // better API to the user defined parser methods.
972 // * Set the location of operand variables.
973 for (FormatElement
*param
: dir
->getArguments()) {
974 if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
975 auto *var
= operand
->getVar();
976 body
<< " " << var
->name
977 << "OperandsLoc = parser.getCurrentLocation();\n";
978 if (var
->isOptional()) {
979 body
<< llvm::formatv(
980 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
983 } else if (var
->isVariadicOfVariadic()) {
984 body
<< llvm::formatv(" "
985 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
986 "OpAsmParser::UnresolvedOperand>> "
987 "{0}OperandGroups;\n",
990 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
991 ArgumentLengthKind lengthKind
;
992 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
993 if (lengthKind
== ArgumentLengthKind::Optional
) {
994 body
<< llvm::formatv(" ::mlir::Type {0}Type;\n", listName
);
995 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
996 body
<< llvm::formatv(
997 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
1001 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
1002 FormatElement
*input
= dir
->getArg();
1003 if (auto *operand
= dyn_cast
<OperandVariable
>(input
)) {
1004 if (!operand
->getVar()->isOptional())
1006 body
<< llvm::formatv(
1007 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
1008 "{1}Operands[0];\n",
1009 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
1010 operand
->getVar()->name
);
1012 } else if (auto *type
= dyn_cast
<TypeDirective
>(input
)) {
1013 ArgumentLengthKind lengthKind
;
1014 StringRef listName
= getTypeListName(type
->getArg(), lengthKind
);
1015 if (lengthKind
== ArgumentLengthKind::Optional
) {
1016 body
<< llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
1017 "::mlir::Type() : {0}Types[0];\n",
1024 body
<< " auto odsResult = parse" << dir
->getName() << "(parser";
1025 for (FormatElement
*param
: dir
->getArguments()) {
1027 genCustomParameterParser(param
, body
);
1032 body
<< " if (!odsResult.has_value()) return {};\n"
1033 << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
1035 body
<< " if (odsResult) return ::mlir::failure();\n";
1038 // After parsing, add handling for any of the optional constructs.
1039 for (FormatElement
*param
: dir
->getArguments()) {
1040 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
1041 const NamedAttribute
*var
= attr
->getVar();
1042 if (var
->attr
.isOptional() || var
->attr
.hasDefaultValue())
1043 body
<< llvm::formatv(" if ({0}Attr)\n ", var
->name
);
1044 if (useProperties
) {
1046 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1047 var
->name
, opCppClassName
);
1049 body
<< llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1053 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
1054 const NamedTypeConstraint
*var
= operand
->getVar();
1055 if (var
->isOptional()) {
1056 body
<< llvm::formatv(" if ({0}Operand.has_value())\n"
1057 " {0}Operands.push_back(*{0}Operand);\n",
1059 } else if (var
->isVariadicOfVariadic()) {
1060 body
<< llvm::formatv(
1061 " for (const auto &subRange : {0}OperandGroups) {{\n"
1062 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1063 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1065 var
->name
, var
->constraint
.getVariadicOfVariadicSegmentSizeAttr());
1067 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
1068 ArgumentLengthKind lengthKind
;
1069 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1070 if (lengthKind
== ArgumentLengthKind::Optional
) {
1071 body
<< llvm::formatv(" if ({0}Type)\n"
1072 " {0}Types.push_back({0}Type);\n",
1074 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1075 body
<< llvm::formatv(
1076 " for (const auto &subRange : {0}TypeGroups)\n"
1077 " {0}Types.append(subRange.begin(), subRange.end());\n",
1086 /// Generate the parser for a enum attribute.
1087 static void genEnumAttrParser(const NamedAttribute
*var
, MethodBody
&body
,
1088 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1089 bool useProperties
, StringRef opCppClassName
) {
1090 Attribute baseAttr
= var
->attr
.getBaseAttr();
1091 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
1092 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
1094 // Generate the code for building an attribute for this enum.
1095 std::string attrBuilderStr
;
1097 llvm::raw_string_ostream
os(attrBuilderStr
);
1098 os
<< tgfmt(enumAttr
.getConstBuilderTemplate(), &attrTypeCtx
,
1102 // Build a string containing the cases that can be formatted as a keyword.
1103 std::string validCaseKeywordsStr
= "{";
1104 llvm::raw_string_ostream
validCaseKeywordsOS(validCaseKeywordsStr
);
1105 for (const EnumAttrCase
&attrCase
: cases
)
1106 if (canFormatStringAsKeyword(attrCase
.getStr()))
1107 validCaseKeywordsOS
<< '"' << attrCase
.getStr() << "\",";
1108 validCaseKeywordsOS
.str().back() = '}';
1110 // If the attribute is not optional, build an error message for the missing
1112 std::string errorMessage
;
1113 if (!parseAsOptional
) {
1114 llvm::raw_string_ostream
errorMessageOS(errorMessage
);
1116 << "return parser.emitError(loc, \"expected string or "
1117 "keyword containing one of the following enum values for attribute '"
1118 << var
->name
<< "' [";
1119 llvm::interleaveComma(cases
, errorMessageOS
, [&](const auto &attrCase
) {
1120 errorMessageOS
<< attrCase
.getStr();
1122 errorMessageOS
<< "]\");";
1124 std::string attrAssignment
;
1125 if (useProperties
) {
1128 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1129 var
->name
, opCppClassName
);
1132 formatv("result.addAttribute(\"{0}\", {0}Attr);", var
->name
);
1135 body
<< formatv(enumAttrParserCode
, var
->name
, enumAttr
.getCppNamespace(),
1136 enumAttr
.getStringToSymbolFnName(), attrBuilderStr
,
1137 validCaseKeywordsStr
, errorMessage
, attrAssignment
);
1140 // Generate the parser for an attribute.
1141 static void genAttrParser(AttributeVariable
*attr
, MethodBody
&body
,
1142 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1143 bool useProperties
, StringRef opCppClassName
) {
1144 const NamedAttribute
*var
= attr
->getVar();
1146 // Check to see if we can parse this as an enum attribute.
1147 if (canFormatEnumAttr(var
))
1148 return genEnumAttrParser(var
, body
, attrTypeCtx
, parseAsOptional
,
1149 useProperties
, opCppClassName
);
1151 // Check to see if we should parse this as a symbol name attribute.
1152 if (shouldFormatSymbolNameAttr(var
)) {
1153 body
<< formatv(parseAsOptional
? optionalSymbolNameAttrParserCode
1154 : symbolNameAttrParserCode
,
1158 // If this attribute has a buildable type, use that when parsing the
1160 std::string attrTypeStr
;
1161 if (std::optional
<StringRef
> typeBuilder
= attr
->getTypeBuilder()) {
1162 llvm::raw_string_ostream
os(attrTypeStr
);
1163 os
<< tgfmt(*typeBuilder
, &attrTypeCtx
);
1165 attrTypeStr
= "::mlir::Type{}";
1167 if (parseAsOptional
) {
1168 body
<< formatv(optionalAttrParserCode
, var
->name
, attrTypeStr
);
1170 if (attr
->shouldBeQualified() ||
1171 var
->attr
.getStorageType() == "::mlir::Attribute")
1172 body
<< formatv(genericAttrParserCode
, var
->name
, attrTypeStr
);
1174 body
<< formatv(attrParserCode
, var
->name
, attrTypeStr
);
1177 if (useProperties
) {
1179 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1181 var
->name
, opCppClassName
);
1184 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1189 // Generates the 'setPropertiesFromParsedAttr' used to set properties from a
1190 // 'prop-dict' dictionary attr.
1191 static void genParsedAttrPropertiesSetter(OperationFormat
&fmt
, Operator
&op
,
1193 // Not required unless 'prop-dict' is present.
1194 if (!fmt
.hasPropDict
)
1197 SmallVector
<MethodParameter
> paramList
;
1198 paramList
.emplace_back("Properties &", "prop");
1199 paramList
.emplace_back("::mlir::Attribute", "attr");
1200 paramList
.emplace_back("::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1203 Method
*method
= opClass
.addStaticMethod("::llvm::LogicalResult",
1204 "setPropertiesFromParsedAttr",
1205 std::move(paramList
));
1206 MethodBody
&body
= method
->body().indent();
1209 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1211 emitError() << "expected DictionaryAttr to set properties
";
1212 return ::mlir::failure();
1216 // TODO: properties might be optional as well.
1217 const char *propFromAttrFmt
= R
"decl(
1218 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1219 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
1222 auto attr = dict.get("{1}");
1224 emitError() << "expected key entry
for {1} in DictionaryAttr to set
"
1226 return ::mlir::failure();
1228 if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
1229 return ::mlir::failure();
1232 // Generate the setter for any property not parsed elsewhere.
1233 for (const NamedProperty
&namedProperty
: op
.getProperties()) {
1234 if (fmt
.usedProperties
.contains(&namedProperty
))
1237 auto scope
= body
.scope("{\n", "}\n", /*indent=*/true);
1239 StringRef name
= namedProperty
.name
;
1240 const Property
&prop
= namedProperty
.prop
;
1242 body
<< formatv(propFromAttrFmt
,
1243 tgfmt(prop
.getConvertFromAttributeCall(),
1244 &fctx
.addSubst("_attr", "propAttr")
1245 .addSubst("_storage", "propStorage")
1246 .addSubst("_diag", "emitError")),
1250 // Generate the setter for any attribute not parsed elsewhere.
1251 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1252 if (fmt
.usedAttributes
.contains(&namedAttr
))
1255 const Attribute
&attr
= namedAttr
.attr
;
1256 // Derived attributes do not need to be parsed.
1257 if (attr
.isDerivedAttr())
1260 auto scope
= body
.scope("{\n", "}\n", /*indent=*/true);
1262 // If the attribute has a default value or is optional, it does not need to
1263 // be present in the parsed dictionary attribute.
1264 bool isRequired
= !attr
.isOptional() && !attr
.hasDefaultValue();
1265 body
<< formatv(R
"decl(
1266 auto &propStorage = prop.{0};
1267 auto attr = dict.get("{0}");
1268 if (attr || /*isRequired=*/{1}) {{
1270 emitError() << "expected key entry
for {0} in DictionaryAttr to set
"
1272 return ::mlir::failure();
1274 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1275 if (convertedAttr) {{
1276 propStorage = convertedAttr;
1278 emitError() << "Invalid attribute `
{0}` in property conversion
: " << attr;
1279 return ::mlir::failure();
1283 namedAttr
.name
, isRequired
);
1285 body
<< "return ::mlir::success();\n";
1288 void OperationFormat::genParser(Operator
&op
, OpClass
&opClass
) {
1289 SmallVector
<MethodParameter
> paramList
;
1290 paramList
.emplace_back("::mlir::OpAsmParser &", "parser");
1291 paramList
.emplace_back("::mlir::OperationState &", "result");
1293 auto *method
= opClass
.addStaticMethod("::mlir::ParseResult", "parse",
1294 std::move(paramList
));
1295 auto &body
= method
->body();
1297 // Generate variables to store the operands and type within the format. This
1298 // allows for referencing these variables in the presence of optional
1300 for (FormatElement
*element
: elements
)
1301 genElementParserStorage(element
, op
, body
);
1303 // A format context used when parsing attributes with buildable types.
1304 FmtContext attrTypeCtx
;
1305 attrTypeCtx
.withBuilder("parser.getBuilder()");
1307 // Generate parsers for each of the elements.
1308 for (FormatElement
*element
: elements
)
1309 genElementParser(element
, body
, attrTypeCtx
);
1311 // Generate the code to resolve the operand/result types and successors now
1312 // that they have been parsed.
1313 genParserRegionResolution(op
, body
);
1314 genParserSuccessorResolution(op
, body
);
1315 genParserVariadicSegmentResolution(op
, body
);
1316 genParserTypeResolution(op
, body
);
1318 body
<< " return ::mlir::success();\n";
1320 genParsedAttrPropertiesSetter(*this, op
, opClass
);
1323 void OperationFormat::genElementParser(FormatElement
*element
, MethodBody
&body
,
1324 FmtContext
&attrTypeCtx
,
1325 GenContext genCtx
) {
1327 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
1328 auto genElementParsers
= [&](FormatElement
*firstElement
,
1329 ArrayRef
<FormatElement
*> elements
,
1331 // If the anchor is a unit attribute, we don't need to print it. When
1332 // parsing, we will add this attribute if this group is present.
1333 FormatElement
*elidedAnchorElement
= nullptr;
1334 auto *anchorAttr
= dyn_cast
<AttributeVariable
>(optional
->getAnchor());
1335 if (anchorAttr
&& anchorAttr
!= firstElement
&&
1336 anchorAttr
->isUnitAttr()) {
1337 elidedAnchorElement
= anchorAttr
;
1339 if (!thenGroup
== optional
->isInverted()) {
1340 // Add the anchor unit attribute to the operation state.
1341 if (useProperties
) {
1343 " result.getOrAddProperties<{1}::Properties>().{0} = "
1344 "parser.getBuilder().getUnitAttr();",
1345 anchorAttr
->getVar()->name
, opCppClassName
);
1347 body
<< " result.addAttribute(\"" << anchorAttr
->getVar()->name
1348 << "\", parser.getBuilder().getUnitAttr());\n";
1353 // Generate the rest of the elements inside an optional group. Elements in
1354 // an optional group after the guard are parsed as required.
1355 for (FormatElement
*childElement
: elements
)
1356 if (childElement
!= elidedAnchorElement
)
1357 genElementParser(childElement
, body
, attrTypeCtx
,
1358 GenContext::Optional
);
1361 ArrayRef
<FormatElement
*> thenElements
=
1362 optional
->getThenElements(/*parseable=*/true);
1364 // Generate a special optional parser for the first element to gate the
1365 // parsing of the rest of the elements.
1366 FormatElement
*firstElement
= thenElements
.front();
1367 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(firstElement
)) {
1368 genAttrParser(attrVar
, body
, attrTypeCtx
, /*parseAsOptional=*/true,
1369 useProperties
, opCppClassName
);
1370 body
<< " if (" << attrVar
->getVar()->name
<< "Attr) {\n";
1371 } else if (auto *literal
= dyn_cast
<LiteralElement
>(firstElement
)) {
1372 body
<< " if (::mlir::succeeded(parser.parseOptional";
1373 genLiteralParser(literal
->getSpelling(), body
);
1375 } else if (auto *opVar
= dyn_cast
<OperandVariable
>(firstElement
)) {
1376 genElementParser(opVar
, body
, attrTypeCtx
);
1377 body
<< " if (!" << opVar
->getVar()->name
<< "Operands.empty()) {\n";
1378 } else if (auto *regionVar
= dyn_cast
<RegionVariable
>(firstElement
)) {
1379 const NamedRegion
*region
= regionVar
->getVar();
1380 if (region
->isVariadic()) {
1381 genElementParser(regionVar
, body
, attrTypeCtx
);
1382 body
<< " if (!" << region
->name
<< "Regions.empty()) {\n";
1384 body
<< llvm::formatv(optionalRegionParserCode
, region
->name
);
1385 body
<< " if (!" << region
->name
<< "Region->empty()) {\n ";
1386 if (hasImplicitTermTrait
)
1387 body
<< llvm::formatv(regionEnsureTerminatorParserCode
, region
->name
);
1388 else if (hasSingleBlockTrait
)
1389 body
<< llvm::formatv(regionEnsureSingleBlockParserCode
,
1392 } else if (auto *custom
= dyn_cast
<CustomDirective
>(firstElement
)) {
1393 body
<< " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
1394 genCustomDirectiveParser(custom
, body
, useProperties
, opCppClassName
,
1395 /*isOptional=*/true);
1396 body
<< " return ::mlir::success();\n"
1397 << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n"
1398 << " return ::mlir::failure();\n"
1399 << " } else if (optResult.has_value()) {\n";
1402 genElementParsers(firstElement
, thenElements
.drop_front(),
1403 /*thenGroup=*/true);
1406 // Generate the else elements.
1407 auto elseElements
= optional
->getElseElements();
1408 if (!elseElements
.empty()) {
1409 body
<< " else {\n";
1410 ArrayRef
<FormatElement
*> elseElements
=
1411 optional
->getElseElements(/*parseable=*/true);
1412 genElementParsers(elseElements
.front(), elseElements
,
1413 /*thenGroup=*/false);
1418 /// OIList Directive
1419 } else if (OIListElement
*oilist
= dyn_cast
<OIListElement
>(element
)) {
1420 for (LiteralElement
*le
: oilist
->getLiteralElements())
1421 body
<< " bool " << le
->getSpelling() << "Clause = false;\n";
1423 // Generate the parsing loop
1424 body
<< " while(true) {\n";
1425 for (auto clause
: oilist
->getClauses()) {
1426 LiteralElement
*lelement
= std::get
<0>(clause
);
1427 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
1428 body
<< "if (succeeded(parser.parseOptional";
1429 genLiteralParser(lelement
->getSpelling(), body
);
1431 StringRef lelementName
= lelement
->getSpelling();
1432 body
<< formatv(oilistParserCode
, lelementName
);
1433 if (AttributeVariable
*unitAttrElem
=
1434 oilist
->getUnitAttrParsingElement(pelement
)) {
1435 if (useProperties
) {
1437 " result.getOrAddProperties<{1}::Properties>().{0} = "
1438 "parser.getBuilder().getUnitAttr();",
1439 unitAttrElem
->getVar()->name
, opCppClassName
);
1441 body
<< " result.addAttribute(\"" << unitAttrElem
->getVar()->name
1442 << "\", UnitAttr::get(parser.getContext()));\n";
1445 for (FormatElement
*el
: pelement
)
1446 genElementParser(el
, body
, attrTypeCtx
);
1451 body
<< " break;\n";
1456 } else if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
)) {
1457 body
<< " if (parser.parse";
1458 genLiteralParser(literal
->getSpelling(), body
);
1459 body
<< ")\n return ::mlir::failure();\n";
1462 } else if (isa
<WhitespaceElement
>(element
)) {
1463 // Nothing to parse.
1466 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
1467 bool parseAsOptional
=
1468 (genCtx
== GenContext::Normal
&& attr
->getVar()->attr
.isOptional());
1469 genAttrParser(attr
, body
, attrTypeCtx
, parseAsOptional
, useProperties
,
1472 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
1473 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
1474 StringRef name
= operand
->getVar()->name
;
1475 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
1476 body
<< llvm::formatv(
1477 variadicOfVariadicOperandParserCode
, name
,
1478 operand
->getVar()->constraint
.getVariadicOfVariadicSegmentSizeAttr());
1479 else if (lengthKind
== ArgumentLengthKind::Variadic
)
1480 body
<< llvm::formatv(variadicOperandParserCode
, name
);
1481 else if (lengthKind
== ArgumentLengthKind::Optional
)
1482 body
<< llvm::formatv(optionalOperandParserCode
, name
);
1484 body
<< formatv(operandParserCode
, name
);
1486 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
1487 bool isVariadic
= region
->getVar()->isVariadic();
1488 body
<< llvm::formatv(isVariadic
? regionListParserCode
: regionParserCode
,
1489 region
->getVar()->name
);
1490 if (hasImplicitTermTrait
)
1491 body
<< llvm::formatv(isVariadic
? regionListEnsureTerminatorParserCode
1492 : regionEnsureTerminatorParserCode
,
1493 region
->getVar()->name
);
1494 else if (hasSingleBlockTrait
)
1495 body
<< llvm::formatv(isVariadic
? regionListEnsureSingleBlockParserCode
1496 : regionEnsureSingleBlockParserCode
,
1497 region
->getVar()->name
);
1499 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
1500 bool isVariadic
= successor
->getVar()->isVariadic();
1501 body
<< formatv(isVariadic
? successorListParserCode
: successorParserCode
,
1502 successor
->getVar()->name
);
1505 } else if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
1506 body
.indent() << "{\n";
1507 body
.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1508 << "if (parser.parseOptionalAttrDict"
1509 << (attrDict
->isWithKeyword() ? "WithKeyword" : "")
1510 << "(result.attributes))\n"
1511 << " return ::mlir::failure();\n";
1512 if (useProperties
) {
1513 body
<< "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1515 << " return parser.emitError(loc) << \"'\" << "
1516 "result.name.getStringRef() << \"' op \";\n"
1518 << " return ::mlir::failure();\n";
1520 body
.unindent() << "}\n";
1522 } else if (isa
<PropDictDirective
>(element
)) {
1523 body
<< " if (parseProperties(parser, result))\n"
1524 << " return ::mlir::failure();\n";
1525 } else if (auto *customDir
= dyn_cast
<CustomDirective
>(element
)) {
1526 genCustomDirectiveParser(customDir
, body
, useProperties
, opCppClassName
);
1527 } else if (isa
<OperandsDirective
>(element
)) {
1528 body
<< " [[maybe_unused]] ::llvm::SMLoc allOperandLoc ="
1529 << " parser.getCurrentLocation();\n"
1530 << " if (parser.parseOperandList(allOperands))\n"
1531 << " return ::mlir::failure();\n";
1533 } else if (isa
<RegionsDirective
>(element
)) {
1534 body
<< llvm::formatv(regionListParserCode
, "full");
1535 if (hasImplicitTermTrait
)
1536 body
<< llvm::formatv(regionListEnsureTerminatorParserCode
, "full");
1537 else if (hasSingleBlockTrait
)
1538 body
<< llvm::formatv(regionListEnsureSingleBlockParserCode
, "full");
1540 } else if (isa
<SuccessorsDirective
>(element
)) {
1541 body
<< llvm::formatv(successorListParserCode
, "full");
1543 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
1544 ArgumentLengthKind lengthKind
;
1545 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1546 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1547 body
<< llvm::formatv(variadicOfVariadicTypeParserCode
, listName
);
1548 } else if (lengthKind
== ArgumentLengthKind::Variadic
) {
1549 body
<< llvm::formatv(variadicTypeParserCode
, listName
);
1550 } else if (lengthKind
== ArgumentLengthKind::Optional
) {
1551 body
<< llvm::formatv(optionalTypeParserCode
, listName
);
1553 const char *parserCode
=
1554 dir
->shouldBeQualified() ? qualifiedTypeParserCode
: typeParserCode
;
1555 TypeSwitch
<FormatElement
*>(dir
->getArg())
1556 .Case
<OperandVariable
, ResultVariable
>([&](auto operand
) {
1557 body
<< formatv(parserCode
,
1558 operand
->getVar()->constraint
.getCPPClassName(),
1561 .Default([&](auto operand
) {
1562 body
<< formatv(parserCode
, "::mlir::Type", listName
);
1565 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
1566 ArgumentLengthKind ignored
;
1567 body
<< formatv(functionalTypeParserCode
,
1568 getTypeListName(dir
->getInputs(), ignored
),
1569 getTypeListName(dir
->getResults(), ignored
));
1571 llvm_unreachable("unknown format element");
1575 void OperationFormat::genParserTypeResolution(Operator
&op
, MethodBody
&body
) {
1576 // If any of type resolutions use transformed variables, make sure that the
1577 // types of those variables are resolved.
1578 SmallPtrSet
<const NamedTypeConstraint
*, 8> verifiedVariables
;
1579 FmtContext verifierFCtx
;
1580 for (TypeResolution
&resolver
:
1581 llvm::concat
<TypeResolution
>(resultTypes
, operandTypes
)) {
1582 std::optional
<StringRef
> transformer
= resolver
.getVarTransformer();
1585 // Ensure that we don't verify the same variables twice.
1586 const NamedTypeConstraint
*variable
= resolver
.getVariable();
1587 if (!variable
|| !verifiedVariables
.insert(variable
).second
)
1590 auto constraint
= variable
->constraint
;
1591 body
<< " for (::mlir::Type type : " << variable
->name
<< "Types) {\n"
1594 << tgfmt(constraint
.getConditionTemplate(),
1595 &verifierFCtx
.withSelf("type"))
1597 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1598 "\"'{0}' must be {1}, but got \" << type;\n",
1599 variable
->name
, constraint
.getSummary())
1604 // Initialize the set of buildable types.
1605 if (!buildableTypes
.empty()) {
1606 FmtContext typeBuilderCtx
;
1607 typeBuilderCtx
.withBuilder("parser.getBuilder()");
1608 for (auto &it
: buildableTypes
)
1609 body
<< " ::mlir::Type odsBuildableType" << it
.second
<< " = "
1610 << tgfmt(it
.first
, &typeBuilderCtx
) << ";\n";
1613 // Emit the code necessary for a type resolver.
1614 auto emitTypeResolver
= [&](TypeResolution
&resolver
, StringRef curVar
) {
1615 if (std::optional
<int> val
= resolver
.getBuilderIdx()) {
1616 body
<< "odsBuildableType" << *val
;
1617 } else if (const NamedTypeConstraint
*var
= resolver
.getVariable()) {
1618 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer()) {
1619 FmtContext fmtContext
;
1620 fmtContext
.addSubst("_ctxt", "parser.getContext()");
1621 if (var
->isVariadic())
1622 fmtContext
.withSelf(var
->name
+ "Types");
1624 fmtContext
.withSelf(var
->name
+ "Types[0]");
1625 body
<< tgfmt(*tform
, &fmtContext
);
1627 body
<< var
->name
<< "Types";
1628 if (!var
->isVariadic())
1631 } else if (const NamedAttribute
*attr
= resolver
.getAttribute()) {
1632 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer())
1633 body
<< tgfmt(*tform
,
1634 &FmtContext().withSelf(attr
->name
+ "Attr.getType()"));
1636 body
<< attr
->name
<< "Attr.getType()";
1638 body
<< curVar
<< "Types";
1642 // Resolve each of the result types.
1643 if (!infersResultTypes
) {
1644 if (allResultTypes
) {
1645 body
<< " result.addTypes(allResultTypes);\n";
1647 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
1648 body
<< " result.addTypes(";
1649 emitTypeResolver(resultTypes
[i
], op
.getResultName(i
));
1655 // Emit the operand type resolutions.
1656 genParserOperandTypeResolution(op
, body
, emitTypeResolver
);
1658 // Handle return type inference once all operands have been resolved
1659 if (infersResultTypes
)
1660 body
<< formatv(inferReturnTypesParserCode
, op
.getCppClassName());
1663 void OperationFormat::genParserOperandTypeResolution(
1664 Operator
&op
, MethodBody
&body
,
1665 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
) {
1666 // Early exit if there are no operands.
1667 if (op
.getNumOperands() == 0)
1670 // Handle the case where all operand types are grouped together with
1671 // "types(operands)".
1672 if (allOperandTypes
) {
1673 // If `operands` was specified, use the full operand list directly.
1675 body
<< " if (parser.resolveOperands(allOperands, allOperandTypes, "
1676 "allOperandLoc, result.operands))\n"
1677 " return ::mlir::failure();\n";
1681 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1682 // llvm::concat does not allow the case of a single range, so guard it here.
1683 body
<< " if (parser.resolveOperands(";
1684 if (op
.getNumOperands() > 1) {
1685 body
<< "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1686 llvm::interleaveComma(op
.getOperands(), body
, [&](auto &operand
) {
1687 body
<< operand
.name
<< "Operands";
1691 body
<< op
.operand_begin()->name
<< "Operands";
1693 body
<< ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1694 << " return ::mlir::failure();\n";
1698 // Handle the case where all operands are grouped together with "operands".
1700 body
<< " if (parser.resolveOperands(allOperands, ";
1702 // Group all of the operand types together to perform the resolution all at
1703 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1704 // the case of a single range, so guard it here.
1705 if (op
.getNumOperands() > 1) {
1706 body
<< "::llvm::concat<const ::mlir::Type>(";
1707 llvm::interleaveComma(
1708 llvm::seq
<int>(0, op
.getNumOperands()), body
, [&](int i
) {
1709 body
<< "::llvm::ArrayRef<::mlir::Type>(";
1710 emitTypeResolver(operandTypes
[i
], op
.getOperand(i
).name
);
1715 emitTypeResolver(operandTypes
.front(), op
.getOperand(0).name
);
1718 body
<< ", allOperandLoc, result.operands))\n return "
1719 "::mlir::failure();\n";
1723 // The final case is the one where each of the operands types are resolved
1725 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
1726 NamedTypeConstraint
&operand
= op
.getOperand(i
);
1727 body
<< " if (parser.resolveOperands(" << operand
.name
<< "Operands, ";
1729 // Resolve the type of this operand.
1730 TypeResolution
&operandType
= operandTypes
[i
];
1731 emitTypeResolver(operandType
, operand
.name
);
1733 body
<< ", " << operand
.name
1734 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1738 void OperationFormat::genParserRegionResolution(Operator
&op
,
1740 // Check for the case where all regions were parsed.
1741 bool hasAllRegions
= llvm::any_of(
1742 elements
, [](FormatElement
*elt
) { return isa
<RegionsDirective
>(elt
); });
1743 if (hasAllRegions
) {
1744 body
<< " result.addRegions(fullRegions);\n";
1748 // Otherwise, handle each region individually.
1749 for (const NamedRegion
®ion
: op
.getRegions()) {
1750 if (region
.isVariadic())
1751 body
<< " result.addRegions(" << region
.name
<< "Regions);\n";
1753 body
<< " result.addRegion(std::move(" << region
.name
<< "Region));\n";
1757 void OperationFormat::genParserSuccessorResolution(Operator
&op
,
1759 // Check for the case where all successors were parsed.
1760 bool hasAllSuccessors
= llvm::any_of(elements
, [](FormatElement
*elt
) {
1761 return isa
<SuccessorsDirective
>(elt
);
1763 if (hasAllSuccessors
) {
1764 body
<< " result.addSuccessors(fullSuccessors);\n";
1768 // Otherwise, handle each successor individually.
1769 for (const NamedSuccessor
&successor
: op
.getSuccessors()) {
1770 if (successor
.isVariadic())
1771 body
<< " result.addSuccessors(" << successor
.name
<< "Successors);\n";
1773 body
<< " result.addSuccessors(" << successor
.name
<< "Successor);\n";
1777 void OperationFormat::genParserVariadicSegmentResolution(Operator
&op
,
1780 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1781 auto interleaveFn
= [&](const NamedTypeConstraint
&operand
) {
1782 // If the operand is variadic emit the parsed size.
1783 if (operand
.isVariableLength())
1784 body
<< "static_cast<int32_t>(" << operand
.name
<< "Operands.size())";
1788 if (op
.getDialect().usePropertiesForAttributes()) {
1789 body
<< "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1790 llvm::interleaveComma(op
.getOperands(), body
, interleaveFn
);
1791 body
<< formatv("}), "
1792 "result.getOrAddProperties<{0}::Properties>()."
1793 "operandSegmentSizes.begin());\n",
1794 op
.getCppClassName());
1796 body
<< " result.addAttribute(\"operandSegmentSizes\", "
1797 << "parser.getBuilder().getDenseI32ArrayAttr({";
1798 llvm::interleaveComma(op
.getOperands(), body
, interleaveFn
);
1802 for (const NamedTypeConstraint
&operand
: op
.getOperands()) {
1803 if (!operand
.isVariadicOfVariadic())
1805 if (op
.getDialect().usePropertiesForAttributes()) {
1806 body
<< llvm::formatv(
1807 " result.getOrAddProperties<{0}::Properties>().{1} = "
1808 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1809 op
.getCppClassName(),
1810 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1813 body
<< llvm::formatv(
1814 " result.addAttribute(\"{0}\", "
1815 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1817 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1823 if (!allResultTypes
&&
1824 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1825 auto interleaveFn
= [&](const NamedTypeConstraint
&result
) {
1826 // If the result is variadic emit the parsed size.
1827 if (result
.isVariableLength())
1828 body
<< "static_cast<int32_t>(" << result
.name
<< "Types.size())";
1832 if (op
.getDialect().usePropertiesForAttributes()) {
1833 body
<< "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1834 llvm::interleaveComma(op
.getResults(), body
, interleaveFn
);
1835 body
<< formatv("}), "
1836 "result.getOrAddProperties<{0}::Properties>()."
1837 "resultSegmentSizes.begin());\n",
1838 op
.getCppClassName());
1840 body
<< " result.addAttribute(\"resultSegmentSizes\", "
1841 << "parser.getBuilder().getDenseI32ArrayAttr({";
1842 llvm::interleaveComma(op
.getResults(), body
, interleaveFn
);
1848 //===----------------------------------------------------------------------===//
1851 /// The code snippet used to generate a printer call for a region of an
1852 // operation that has the SingleBlockImplicitTerminator trait.
1854 /// {0}: The name of the region.
1855 const char *regionSingleBlockImplicitTerminatorPrinterCode
= R
"(
1857 bool printTerminator = true;
1858 if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
1859 printTerminator = !term->getAttrDictionary().empty() ||
1860 term->getNumOperands() != 0 ||
1861 term->getNumResults() != 0;
1863 _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
1864 /*printBlockTerminators=*/printTerminator);
1868 /// The code snippet used to generate a printer call for an enum that has cases
1869 /// that can't be represented with a keyword.
1871 /// {0}: The name of the enum attribute.
1872 /// {1}: The name of the enum attributes symbolToString function.
1873 const char *enumAttrBeginPrinterCode
= R
"(
1875 auto caseValue = {0}();
1876 auto caseValueStr = {1}(caseValue);
1879 /// Generate the printer for the 'prop-dict' directive.
1880 static void genPropDictPrinter(OperationFormat
&fmt
, Operator
&op
,
1882 body
<< " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n";
1883 for (const NamedProperty
*namedProperty
: fmt
.usedProperties
)
1884 body
<< " elidedProps.push_back(\"" << namedProperty
->name
<< "\");\n";
1885 for (const NamedAttribute
*namedAttr
: fmt
.usedAttributes
)
1886 body
<< " elidedProps.push_back(\"" << namedAttr
->name
<< "\");\n";
1888 // Add code to check attributes for equality with the default value
1889 // for attributes with the elidePrintingDefaultValue bit set.
1890 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1891 const Attribute
&attr
= namedAttr
.attr
;
1892 if (!attr
.isDerivedAttr() && attr
.hasDefaultValue()) {
1893 const StringRef
&name
= namedAttr
.name
;
1895 fctx
.withBuilder("odsBuilder");
1896 std::string defaultValue
= std::string(
1897 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
1899 body
<< " ::mlir::Builder odsBuilder(getContext());\n";
1900 body
<< " ::mlir::Attribute attr = " << op
.getGetterName(name
)
1902 body
<< " if(attr && (attr == " << defaultValue
<< "))\n";
1903 body
<< " elidedProps.push_back(\"" << name
<< "\");\n";
1908 body
<< " _odsPrinter << \" \";\n"
1909 << " printProperties(this->getContext(), _odsPrinter, "
1910 "getProperties(), elidedProps);\n";
1913 /// Generate the printer for the 'attr-dict' directive.
1914 static void genAttrDictPrinter(OperationFormat
&fmt
, Operator
&op
,
1915 MethodBody
&body
, bool withKeyword
) {
1916 body
<< " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
1917 // Elide the variadic segment size attributes if necessary.
1918 if (!fmt
.allOperands
&&
1919 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1920 body
<< " elidedAttrs.push_back(\"operandSegmentSizes\");\n";
1921 if (!fmt
.allResultTypes
&&
1922 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1923 body
<< " elidedAttrs.push_back(\"resultSegmentSizes\");\n";
1924 for (const StringRef key
: fmt
.inferredAttributes
.keys())
1925 body
<< " elidedAttrs.push_back(\"" << key
<< "\");\n";
1926 for (const NamedAttribute
*attr
: fmt
.usedAttributes
)
1927 body
<< " elidedAttrs.push_back(\"" << attr
->name
<< "\");\n";
1928 // Add code to check attributes for equality with the default value
1929 // for attributes with the elidePrintingDefaultValue bit set.
1930 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1931 const Attribute
&attr
= namedAttr
.attr
;
1932 if (!attr
.isDerivedAttr() && attr
.hasDefaultValue()) {
1933 const StringRef
&name
= namedAttr
.name
;
1935 fctx
.withBuilder("odsBuilder");
1936 std::string defaultValue
= std::string(
1937 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
1939 body
<< " ::mlir::Builder odsBuilder(getContext());\n";
1940 body
<< " ::mlir::Attribute attr = " << op
.getGetterName(name
)
1942 body
<< " if(attr && (attr == " << defaultValue
<< "))\n";
1943 body
<< " elidedAttrs.push_back(\"" << name
<< "\");\n";
1947 if (fmt
.hasPropDict
)
1948 body
<< " _odsPrinter.printOptionalAttrDict"
1949 << (withKeyword
? "WithKeyword" : "")
1950 << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n";
1952 body
<< " _odsPrinter.printOptionalAttrDict"
1953 << (withKeyword
? "WithKeyword" : "")
1954 << "((*this)->getAttrs(), elidedAttrs);\n";
1957 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1958 /// space should be emitted before this element. `lastWasPunctuation` is true if
1959 /// the previous element was a punctuation literal.
1960 static void genLiteralPrinter(StringRef value
, MethodBody
&body
,
1961 bool &shouldEmitSpace
, bool &lastWasPunctuation
) {
1962 body
<< " _odsPrinter";
1964 // Don't insert a space for certain punctuation.
1965 if (shouldEmitSpace
&& shouldEmitSpaceBefore(value
, lastWasPunctuation
))
1967 body
<< " << \"" << value
<< "\";\n";
1969 // Insert a space after certain literals.
1971 value
.size() != 1 || !StringRef("<({[").contains(value
.front());
1972 lastWasPunctuation
= value
.front() != '_' && !isalpha(value
.front());
1975 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1976 /// are set to false.
1977 static void genSpacePrinter(bool value
, MethodBody
&body
, bool &shouldEmitSpace
,
1978 bool &lastWasPunctuation
) {
1980 body
<< " _odsPrinter << ' ';\n";
1981 lastWasPunctuation
= false;
1983 lastWasPunctuation
= true;
1985 shouldEmitSpace
= false;
1988 /// Generate the printer for a custom directive parameter.
1989 static void genCustomDirectiveParameterPrinter(FormatElement
*element
,
1992 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
1993 body
<< op
.getGetterName(attr
->getVar()->name
) << "Attr()";
1995 } else if (isa
<AttrDictDirective
>(element
)) {
1996 body
<< "getOperation()->getAttrDictionary()";
1998 } else if (isa
<PropDictDirective
>(element
)) {
1999 body
<< "getProperties()";
2001 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
2002 body
<< op
.getGetterName(operand
->getVar()->name
) << "()";
2004 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
2005 body
<< op
.getGetterName(region
->getVar()->name
) << "()";
2007 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
2008 body
<< op
.getGetterName(successor
->getVar()->name
) << "()";
2010 } else if (auto *dir
= dyn_cast
<RefDirective
>(element
)) {
2011 genCustomDirectiveParameterPrinter(dir
->getArg(), op
, body
);
2013 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
2014 auto *typeOperand
= dir
->getArg();
2015 auto *operand
= dyn_cast
<OperandVariable
>(typeOperand
);
2016 auto *var
= operand
? operand
->getVar()
2017 : cast
<ResultVariable
>(typeOperand
)->getVar();
2018 std::string name
= op
.getGetterName(var
->name
);
2019 if (var
->isVariadic())
2020 body
<< name
<< "().getTypes()";
2021 else if (var
->isOptional())
2022 body
<< llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name
);
2024 body
<< name
<< "().getType()";
2026 } else if (auto *string
= dyn_cast
<StringElement
>(element
)) {
2028 ctx
.withBuilder("::mlir::Builder(getContext())");
2029 ctx
.addSubst("_ctxt", "getContext()");
2030 body
<< tgfmt(string
->getValue(), &ctx
);
2032 } else if (auto *property
= dyn_cast
<PropertyVariable
>(element
)) {
2034 ctx
.addSubst("_ctxt", "getContext()");
2035 const NamedProperty
*namedProperty
= property
->getVar();
2036 ctx
.addSubst("_storage", "getProperties()." + namedProperty
->name
);
2037 body
<< tgfmt(namedProperty
->prop
.getConvertFromStorageCall(), &ctx
);
2039 llvm_unreachable("unknown custom directive parameter");
2043 /// Generate the printer for a custom directive.
2044 static void genCustomDirectivePrinter(CustomDirective
*customDir
,
2045 const Operator
&op
, MethodBody
&body
) {
2046 body
<< " print" << customDir
->getName() << "(_odsPrinter, *this";
2047 for (FormatElement
*param
: customDir
->getArguments()) {
2049 genCustomDirectiveParameterPrinter(param
, op
, body
);
2054 /// Generate the printer for a region with the given variable name.
2055 static void genRegionPrinter(const Twine
®ionName
, MethodBody
&body
,
2056 bool hasImplicitTermTrait
) {
2057 if (hasImplicitTermTrait
)
2058 body
<< llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode
,
2061 body
<< " _odsPrinter.printRegion(" << regionName
<< ");\n";
2063 static void genVariadicRegionPrinter(const Twine
®ionListName
,
2065 bool hasImplicitTermTrait
) {
2066 body
<< " llvm::interleaveComma(" << regionListName
2067 << ", _odsPrinter, [&](::mlir::Region ®ion) {\n ";
2068 genRegionPrinter("region", body
, hasImplicitTermTrait
);
2072 /// Generate the C++ for an operand to a (*-)type directive.
2073 static MethodBody
&genTypeOperandPrinter(FormatElement
*arg
, const Operator
&op
,
2075 bool useArrayRef
= true) {
2076 if (isa
<OperandsDirective
>(arg
))
2077 return body
<< "getOperation()->getOperandTypes()";
2078 if (isa
<ResultsDirective
>(arg
))
2079 return body
<< "getOperation()->getResultTypes()";
2080 auto *operand
= dyn_cast
<OperandVariable
>(arg
);
2081 auto *var
= operand
? operand
->getVar() : cast
<ResultVariable
>(arg
)->getVar();
2082 if (var
->isVariadicOfVariadic())
2083 return body
<< llvm::formatv("{0}().join().getTypes()",
2084 op
.getGetterName(var
->name
));
2085 if (var
->isVariadic())
2086 return body
<< op
.getGetterName(var
->name
) << "().getTypes()";
2087 if (var
->isOptional())
2088 return body
<< llvm::formatv(
2089 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
2090 "::llvm::ArrayRef<::mlir::Type>())",
2091 op
.getGetterName(var
->name
));
2093 return body
<< "::llvm::ArrayRef<::mlir::Type>("
2094 << op
.getGetterName(var
->name
) << "().getType())";
2095 return body
<< op
.getGetterName(var
->name
) << "().getType()";
2098 /// Generate the printer for an enum attribute.
2099 static void genEnumAttrPrinter(const NamedAttribute
*var
, const Operator
&op
,
2101 Attribute baseAttr
= var
->attr
.getBaseAttr();
2102 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
2103 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
2105 body
<< llvm::formatv(enumAttrBeginPrinterCode
,
2106 (var
->attr
.isOptional() ? "*" : "") +
2107 op
.getGetterName(var
->name
),
2108 enumAttr
.getSymbolToStringFnName());
2110 // Get a string containing all of the cases that can't be represented with a
2112 BitVector
nonKeywordCases(cases
.size());
2113 for (auto it
: llvm::enumerate(cases
)) {
2114 if (!canFormatStringAsKeyword(it
.value().getStr()))
2115 nonKeywordCases
.set(it
.index());
2118 // Otherwise if this is a bit enum attribute, don't allow cases that may
2119 // overlap with other cases. For simplicity sake, only allow cases with a
2120 // single bit value.
2121 if (enumAttr
.isBitEnum()) {
2122 for (auto it
: llvm::enumerate(cases
)) {
2123 int64_t value
= it
.value().getValue();
2124 if (value
< 0 || !llvm::isPowerOf2_64(value
))
2125 nonKeywordCases
.set(it
.index());
2129 // If there are any cases that can't be used with a keyword, switch on the
2130 // case value to determine when to print in the string form.
2131 if (nonKeywordCases
.any()) {
2132 body
<< " switch (caseValue) {\n";
2133 StringRef cppNamespace
= enumAttr
.getCppNamespace();
2134 StringRef enumName
= enumAttr
.getEnumClassName();
2135 for (auto it
: llvm::enumerate(cases
)) {
2136 if (nonKeywordCases
.test(it
.index()))
2138 StringRef symbol
= it
.value().getSymbol();
2139 body
<< llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace
, enumName
,
2140 llvm::isDigit(symbol
.front()) ? ("_" + symbol
)
2143 body
<< " _odsPrinter << caseValueStr;\n"
2146 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
2153 body
<< " _odsPrinter << caseValueStr;\n"
2157 /// Generate a check that a DefaultValuedAttr has a value that is non-default.
2158 static void genNonDefaultValueCheck(MethodBody
&body
, const Operator
&op
,
2159 AttributeVariable
&attrElement
) {
2161 Attribute attr
= attrElement
.getVar()->attr
;
2162 fctx
.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2163 body
<< " && " << op
.getGetterName(attrElement
.getVar()->name
) << "Attr() != "
2164 << tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue());
2167 /// Generate the check for the anchor of an optional group.
2168 static void genOptionalGroupPrinterAnchor(FormatElement
*anchor
,
2171 TypeSwitch
<FormatElement
*>(anchor
)
2172 .Case
<OperandVariable
, ResultVariable
>([&](auto *element
) {
2173 const NamedTypeConstraint
*var
= element
->getVar();
2174 std::string name
= op
.getGetterName(var
->name
);
2175 if (var
->isOptional())
2176 body
<< name
<< "()";
2177 else if (var
->isVariadic())
2178 body
<< "!" << name
<< "().empty()";
2180 .Case([&](RegionVariable
*element
) {
2181 const NamedRegion
*var
= element
->getVar();
2182 std::string name
= op
.getGetterName(var
->name
);
2183 // TODO: Add a check for optional regions here when ODS supports it.
2184 body
<< "!" << name
<< "().empty()";
2186 .Case([&](TypeDirective
*element
) {
2187 genOptionalGroupPrinterAnchor(element
->getArg(), op
, body
);
2189 .Case([&](FunctionalTypeDirective
*element
) {
2190 genOptionalGroupPrinterAnchor(element
->getInputs(), op
, body
);
2192 .Case([&](AttributeVariable
*element
) {
2193 Attribute attr
= element
->getVar()->attr
;
2194 body
<< op
.getGetterName(element
->getVar()->name
) << "Attr()";
2195 if (attr
.isOptional())
2197 if (attr
.hasDefaultValue()) {
2198 // Consider a default-valued attribute as present if it's not the
2200 genNonDefaultValueCheck(body
, op
, *element
);
2203 llvm_unreachable("attribute must be optional or default-valued");
2205 .Case([&](CustomDirective
*ele
) {
2208 ele
->getArguments(), body
,
2209 [&](FormatElement
*child
) {
2211 genOptionalGroupPrinterAnchor(child
, op
, body
);
2219 void collect(FormatElement
*element
,
2220 SmallVectorImpl
<VariableElement
*> &variables
) {
2221 TypeSwitch
<FormatElement
*>(element
)
2222 .Case([&](VariableElement
*var
) { variables
.emplace_back(var
); })
2223 .Case([&](CustomDirective
*ele
) {
2224 for (FormatElement
*arg
: ele
->getArguments())
2225 collect(arg
, variables
);
2227 .Case([&](OptionalElement
*ele
) {
2228 for (FormatElement
*arg
: ele
->getThenElements())
2229 collect(arg
, variables
);
2230 for (FormatElement
*arg
: ele
->getElseElements())
2231 collect(arg
, variables
);
2233 .Case([&](FunctionalTypeDirective
*funcType
) {
2234 collect(funcType
->getInputs(), variables
);
2235 collect(funcType
->getResults(), variables
);
2237 .Case([&](OIListElement
*oilist
) {
2238 for (ArrayRef
<FormatElement
*> arg
: oilist
->getParsingElements())
2239 for (FormatElement
*arg
: arg
)
2240 collect(arg
, variables
);
2244 void OperationFormat::genElementPrinter(FormatElement
*element
,
2245 MethodBody
&body
, Operator
&op
,
2246 bool &shouldEmitSpace
,
2247 bool &lastWasPunctuation
) {
2248 if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
))
2249 return genLiteralPrinter(literal
->getSpelling(), body
, shouldEmitSpace
,
2250 lastWasPunctuation
);
2252 // Emit a whitespace element.
2253 if (auto *space
= dyn_cast
<WhitespaceElement
>(element
)) {
2254 if (space
->getValue() == "\\n") {
2255 body
<< " _odsPrinter.printNewline();\n";
2257 genSpacePrinter(!space
->getValue().empty(), body
, shouldEmitSpace
,
2258 lastWasPunctuation
);
2263 // Emit an optional group.
2264 if (OptionalElement
*optional
= dyn_cast
<OptionalElement
>(element
)) {
2265 // Emit the check for the presence of the anchor element.
2266 FormatElement
*anchor
= optional
->getAnchor();
2268 if (optional
->isInverted())
2270 genOptionalGroupPrinterAnchor(anchor
, op
, body
);
2274 // If the anchor is a unit attribute, we don't need to print it. When
2275 // parsing, we will add this attribute if this group is present.
2276 ArrayRef
<FormatElement
*> thenElements
= optional
->getThenElements();
2277 ArrayRef
<FormatElement
*> elseElements
= optional
->getElseElements();
2278 FormatElement
*elidedAnchorElement
= nullptr;
2279 auto *anchorAttr
= dyn_cast
<AttributeVariable
>(anchor
);
2280 if (anchorAttr
&& anchorAttr
!= thenElements
.front() &&
2281 (elseElements
.empty() || anchorAttr
!= elseElements
.front()) &&
2282 anchorAttr
->isUnitAttr()) {
2283 elidedAnchorElement
= anchorAttr
;
2285 auto genElementPrinters
= [&](ArrayRef
<FormatElement
*> elements
) {
2286 for (FormatElement
*childElement
: elements
) {
2287 if (childElement
!= elidedAnchorElement
) {
2288 genElementPrinter(childElement
, body
, op
, shouldEmitSpace
,
2289 lastWasPunctuation
);
2294 // Emit each of the elements.
2295 genElementPrinters(thenElements
);
2298 // Emit each of the else elements.
2299 if (!elseElements
.empty()) {
2300 body
<< " else {\n";
2301 genElementPrinters(elseElements
);
2305 body
.unindent() << "\n";
2310 if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
2311 for (auto clause
: oilist
->getClauses()) {
2312 LiteralElement
*lelement
= std::get
<0>(clause
);
2313 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
2315 SmallVector
<VariableElement
*> vars
;
2316 for (FormatElement
*el
: pelement
)
2318 body
<< " if (false";
2319 for (VariableElement
*var
: vars
) {
2320 TypeSwitch
<FormatElement
*>(var
)
2321 .Case([&](AttributeVariable
*attrEle
) {
2322 body
<< " || (" << op
.getGetterName(attrEle
->getVar()->name
)
2324 Attribute attr
= attrEle
->getVar()->attr
;
2325 if (attr
.hasDefaultValue()) {
2326 // Don't print default-valued attributes.
2327 genNonDefaultValueCheck(body
, op
, *attrEle
);
2331 .Case([&](OperandVariable
*ele
) {
2332 if (ele
->getVar()->isVariadic()) {
2333 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2336 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2339 .Case([&](ResultVariable
*ele
) {
2340 if (ele
->getVar()->isVariadic()) {
2341 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2344 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2347 .Case([&](RegionVariable
*reg
) {
2348 body
<< " || " << op
.getGetterName(reg
->getVar()->name
) << "()";
2353 genLiteralPrinter(lelement
->getSpelling(), body
, shouldEmitSpace
,
2354 lastWasPunctuation
);
2355 if (oilist
->getUnitAttrParsingElement(pelement
) == nullptr) {
2356 for (FormatElement
*element
: pelement
)
2357 genElementPrinter(element
, body
, op
, shouldEmitSpace
,
2358 lastWasPunctuation
);
2365 // Emit the attribute dictionary.
2366 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
2367 genAttrDictPrinter(*this, op
, body
, attrDict
->isWithKeyword());
2368 lastWasPunctuation
= false;
2372 // Emit the attribute dictionary.
2373 if (isa
<PropDictDirective
>(element
)) {
2374 genPropDictPrinter(*this, op
, body
);
2375 lastWasPunctuation
= false;
2379 // Optionally insert a space before the next element. The AttrDict printer
2380 // already adds a space as necessary.
2381 if (shouldEmitSpace
|| !lastWasPunctuation
)
2382 body
<< " _odsPrinter << ' ';\n";
2383 lastWasPunctuation
= false;
2384 shouldEmitSpace
= true;
2386 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
2387 const NamedAttribute
*var
= attr
->getVar();
2389 // If we are formatting as an enum, symbolize the attribute as a string.
2390 if (canFormatEnumAttr(var
))
2391 return genEnumAttrPrinter(var
, op
, body
);
2393 // If we are formatting as a symbol name, handle it as a symbol name.
2394 if (shouldFormatSymbolNameAttr(var
)) {
2395 body
<< " _odsPrinter.printSymbolName(" << op
.getGetterName(var
->name
)
2396 << "Attr().getValue());\n";
2400 // Elide the attribute type if it is buildable.
2401 if (attr
->getTypeBuilder())
2402 body
<< " _odsPrinter.printAttributeWithoutType("
2403 << op
.getGetterName(var
->name
) << "Attr());\n";
2404 else if (attr
->shouldBeQualified() ||
2405 var
->attr
.getStorageType() == "::mlir::Attribute")
2406 body
<< " _odsPrinter.printAttribute(" << op
.getGetterName(var
->name
)
2409 body
<< "_odsPrinter.printStrippedAttrOrType("
2410 << op
.getGetterName(var
->name
) << "Attr());\n";
2411 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
2412 if (operand
->getVar()->isVariadicOfVariadic()) {
2413 body
<< " ::llvm::interleaveComma("
2414 << op
.getGetterName(operand
->getVar()->name
)
2415 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2416 "\"(\" << operands << "
2419 } else if (operand
->getVar()->isOptional()) {
2420 body
<< " if (::mlir::Value value = "
2421 << op
.getGetterName(operand
->getVar()->name
) << "())\n"
2422 << " _odsPrinter << value;\n";
2424 body
<< " _odsPrinter << " << op
.getGetterName(operand
->getVar()->name
)
2427 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
2428 const NamedRegion
*var
= region
->getVar();
2429 std::string name
= op
.getGetterName(var
->name
);
2430 if (var
->isVariadic()) {
2431 genVariadicRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2433 genRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2435 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
2436 const NamedSuccessor
*var
= successor
->getVar();
2437 std::string name
= op
.getGetterName(var
->name
);
2438 if (var
->isVariadic())
2439 body
<< " ::llvm::interleaveComma(" << name
<< "(), _odsPrinter);\n";
2441 body
<< " _odsPrinter << " << name
<< "();\n";
2442 } else if (auto *dir
= dyn_cast
<CustomDirective
>(element
)) {
2443 genCustomDirectivePrinter(dir
, op
, body
);
2444 } else if (isa
<OperandsDirective
>(element
)) {
2445 body
<< " _odsPrinter << getOperation()->getOperands();\n";
2446 } else if (isa
<RegionsDirective
>(element
)) {
2447 genVariadicRegionPrinter("getOperation()->getRegions()", body
,
2448 hasImplicitTermTrait
);
2449 } else if (isa
<SuccessorsDirective
>(element
)) {
2450 body
<< " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2452 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
2453 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg())) {
2454 if (operand
->getVar()->isVariadicOfVariadic()) {
2455 body
<< llvm::formatv(
2456 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2457 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2458 "types << \")\"; });\n",
2459 op
.getGetterName(operand
->getVar()->name
));
2463 const NamedTypeConstraint
*var
= nullptr;
2465 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg()))
2466 var
= operand
->getVar();
2467 else if (auto *operand
= dyn_cast
<ResultVariable
>(dir
->getArg()))
2468 var
= operand
->getVar();
2470 if (var
&& !var
->isVariadicOfVariadic() && !var
->isVariadic() &&
2471 !var
->isOptional()) {
2472 std::string cppClass
= var
->constraint
.getCPPClassName();
2473 if (dir
->shouldBeQualified()) {
2474 body
<< " _odsPrinter << " << op
.getGetterName(var
->name
)
2475 << "().getType();\n";
2479 << " auto type = " << op
.getGetterName(var
->name
)
2480 << "().getType();\n"
2481 << " if (auto validType = ::llvm::dyn_cast<" << cppClass
2483 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2485 << " _odsPrinter << type;\n"
2489 body
<< " _odsPrinter << ";
2490 genTypeOperandPrinter(dir
->getArg(), op
, body
, /*useArrayRef=*/false)
2492 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
2493 body
<< " _odsPrinter.printFunctionalType(";
2494 genTypeOperandPrinter(dir
->getInputs(), op
, body
) << ", ";
2495 genTypeOperandPrinter(dir
->getResults(), op
, body
) << ");\n";
2497 llvm_unreachable("unknown format element");
2501 void OperationFormat::genPrinter(Operator
&op
, OpClass
&opClass
) {
2502 auto *method
= opClass
.addMethod(
2504 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2505 auto &body
= method
->body();
2507 // Flags for if we should emit a space, and if the last element was
2509 bool shouldEmitSpace
= true, lastWasPunctuation
= false;
2510 for (FormatElement
*element
: elements
)
2511 genElementPrinter(element
, body
, op
, shouldEmitSpace
, lastWasPunctuation
);
2514 //===----------------------------------------------------------------------===//
2516 //===----------------------------------------------------------------------===//
2518 /// Function to find an element within the given range that has the same name as
2520 template <typename RangeT
>
2521 static auto findArg(RangeT
&&range
, StringRef name
) {
2522 auto it
= llvm::find_if(range
, [=](auto &arg
) { return arg
.name
== name
; });
2523 return it
!= range
.end() ? &*it
: nullptr;
2527 /// This class implements a parser for an instance of an operation assembly
2529 class OpFormatParser
: public FormatParser
{
2531 OpFormatParser(llvm::SourceMgr
&mgr
, OperationFormat
&format
, Operator
&op
)
2532 : FormatParser(mgr
, op
.getLoc()[0]), fmt(format
), op(op
),
2533 seenOperandTypes(op
.getNumOperands()),
2534 seenResultTypes(op
.getNumResults()) {}
2537 /// Verify the format elements.
2538 LogicalResult
verify(SMLoc loc
, ArrayRef
<FormatElement
*> elements
) override
;
2539 /// Verify the arguments to a custom directive.
2541 verifyCustomDirectiveArguments(SMLoc loc
,
2542 ArrayRef
<FormatElement
*> arguments
) override
;
2543 /// Verify the elements of an optional group.
2544 LogicalResult
verifyOptionalGroupElements(SMLoc loc
,
2545 ArrayRef
<FormatElement
*> elements
,
2546 FormatElement
*anchor
) override
;
2547 LogicalResult
verifyOptionalGroupElement(SMLoc loc
, FormatElement
*element
,
2550 LogicalResult
markQualified(SMLoc loc
, FormatElement
*element
) override
;
2552 /// Parse an operation variable.
2553 FailureOr
<FormatElement
*> parseVariableImpl(SMLoc loc
, StringRef name
,
2554 Context ctx
) override
;
2555 /// Parse an operation format directive.
2556 FailureOr
<FormatElement
*>
2557 parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
, Context ctx
) override
;
2560 /// This struct represents a type resolution instance. It includes a specific
2561 /// type as well as an optional transformer to apply to that type in order to
2562 /// properly resolve the type of a variable.
2563 struct TypeResolutionInstance
{
2564 ConstArgument resolver
;
2565 std::optional
<StringRef
> transformer
;
2568 /// Verify the state of operation attributes within the format.
2569 LogicalResult
verifyAttributes(SMLoc loc
, ArrayRef
<FormatElement
*> elements
);
2571 /// Verify that attributes elements aren't followed by colon literals.
2572 LogicalResult
verifyAttributeColonType(SMLoc loc
,
2573 ArrayRef
<FormatElement
*> elements
);
2574 /// Verify that the attribute dictionary directive isn't followed by a region.
2575 LogicalResult
verifyAttrDictRegion(SMLoc loc
,
2576 ArrayRef
<FormatElement
*> elements
);
2578 /// Verify the state of operation operands within the format.
2580 verifyOperands(SMLoc loc
,
2581 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2583 /// Verify the state of operation regions within the format.
2584 LogicalResult
verifyRegions(SMLoc loc
);
2586 /// Verify the state of operation results within the format.
2588 verifyResults(SMLoc loc
,
2589 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2591 /// Verify the state of operation successors within the format.
2592 LogicalResult
verifySuccessors(SMLoc loc
);
2594 LogicalResult
verifyOIListElements(SMLoc loc
,
2595 ArrayRef
<FormatElement
*> elements
);
2597 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2599 void handleAllTypesMatchConstraint(
2600 ArrayRef
<StringRef
> values
,
2601 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2602 /// Check for inferable type resolution given all operands, and or results,
2603 /// have the same type. If 'includeResults' is true, the results also have the
2604 /// same type as all of the operands.
2605 void handleSameTypesConstraint(
2606 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2607 bool includeResults
);
2608 /// Check for inferable type resolution based on another operand, result, or
2610 void handleTypesMatchConstraint(
2611 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2612 const llvm::Record
&def
);
2614 /// Returns an argument or attribute with the given name that has been seen
2615 /// within the format.
2616 ConstArgument
findSeenArg(StringRef name
);
2618 /// Parse the various different directives.
2619 FailureOr
<FormatElement
*> parsePropDictDirective(SMLoc loc
, Context context
);
2620 FailureOr
<FormatElement
*> parseAttrDictDirective(SMLoc loc
, Context context
,
2622 FailureOr
<FormatElement
*> parseFunctionalTypeDirective(SMLoc loc
,
2624 FailureOr
<FormatElement
*> parseOIListDirective(SMLoc loc
, Context context
);
2625 LogicalResult
verifyOIListParsingElement(FormatElement
*element
, SMLoc loc
);
2626 FailureOr
<FormatElement
*> parseOperandsDirective(SMLoc loc
, Context context
);
2627 FailureOr
<FormatElement
*> parseRegionsDirective(SMLoc loc
, Context context
);
2628 FailureOr
<FormatElement
*> parseResultsDirective(SMLoc loc
, Context context
);
2629 FailureOr
<FormatElement
*> parseSuccessorsDirective(SMLoc loc
,
2631 FailureOr
<FormatElement
*> parseTypeDirective(SMLoc loc
, Context context
);
2632 FailureOr
<FormatElement
*> parseTypeDirectiveOperand(SMLoc loc
,
2633 bool isRefChild
= false);
2635 //===--------------------------------------------------------------------===//
2637 //===--------------------------------------------------------------------===//
2639 OperationFormat
&fmt
;
2642 // The following are various bits of format state used for verification
2644 bool hasAttrDict
= false;
2645 bool hasPropDict
= false;
2646 bool hasAllRegions
= false, hasAllSuccessors
= false;
2647 bool canInferResultTypes
= false;
2648 llvm::SmallBitVector seenOperandTypes
, seenResultTypes
;
2649 llvm::SmallSetVector
<const NamedAttribute
*, 8> seenAttrs
;
2650 llvm::DenseSet
<const NamedTypeConstraint
*> seenOperands
;
2651 llvm::DenseSet
<const NamedRegion
*> seenRegions
;
2652 llvm::DenseSet
<const NamedSuccessor
*> seenSuccessors
;
2653 llvm::SmallSetVector
<const NamedProperty
*, 8> seenProperties
;
2657 LogicalResult
OpFormatParser::verify(SMLoc loc
,
2658 ArrayRef
<FormatElement
*> elements
) {
2659 // Check that the attribute dictionary is in the format.
2661 return emitError(loc
, "'attr-dict' directive not found in "
2662 "custom assembly format");
2664 // Check for any type traits that we can use for inferring types.
2665 llvm::StringMap
<TypeResolutionInstance
> variableTyResolver
;
2666 for (const Trait
&trait
: op
.getTraits()) {
2667 const llvm::Record
&def
= trait
.getDef();
2668 if (def
.isSubClassOf("AllTypesMatch")) {
2669 handleAllTypesMatchConstraint(def
.getValueAsListOfStrings("values"),
2670 variableTyResolver
);
2671 } else if (def
.getName() == "SameTypeOperands") {
2672 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/false);
2673 } else if (def
.getName() == "SameOperandsAndResultType") {
2674 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/true);
2675 } else if (def
.isSubClassOf("TypesMatchWith")) {
2676 handleTypesMatchConstraint(variableTyResolver
, def
);
2677 } else if (!op
.allResultTypesKnown()) {
2678 // This doesn't check the name directly to handle
2679 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2681 // TODO: Add hasCppInterface check.
2682 if (auto name
= def
.getValueAsOptionalString("cppInterfaceName")) {
2683 if (*name
== "InferTypeOpInterface" &&
2684 def
.getValueAsString("cppNamespace") == "::mlir")
2685 canInferResultTypes
= true;
2690 // Verify the state of the various operation components.
2691 if (failed(verifyAttributes(loc
, elements
)) ||
2692 failed(verifyResults(loc
, variableTyResolver
)) ||
2693 failed(verifyOperands(loc
, variableTyResolver
)) ||
2694 failed(verifyRegions(loc
)) || failed(verifySuccessors(loc
)) ||
2695 failed(verifyOIListElements(loc
, elements
)))
2698 // Collect the set of used attributes in the format.
2699 fmt
.usedAttributes
= std::move(seenAttrs
);
2700 fmt
.usedProperties
= std::move(seenProperties
);
2702 // Set whether prop-dict is used in the format
2703 fmt
.hasPropDict
= hasPropDict
;
2708 OpFormatParser::verifyAttributes(SMLoc loc
,
2709 ArrayRef
<FormatElement
*> elements
) {
2710 // Check that there are no `:` literals after an attribute without a constant
2711 // type. The attribute grammar contains an optional trailing colon type, which
2712 // can lead to unexpected and generally unintended behavior. Given that, it is
2713 // better to just error out here instead.
2714 if (failed(verifyAttributeColonType(loc
, elements
)))
2716 // Check that there are no region variables following an attribute dicitonary.
2717 // Both start with `{` and so the optional attribute dictionary can cause
2718 // format ambiguities.
2719 if (failed(verifyAttrDictRegion(loc
, elements
)))
2722 // Check for VariadicOfVariadic variables. The segment attribute of those
2723 // variables will be infered.
2724 for (const NamedTypeConstraint
*var
: seenOperands
) {
2725 if (var
->constraint
.isVariadicOfVariadic()) {
2726 fmt
.inferredAttributes
.insert(
2727 var
->constraint
.getVariadicOfVariadicSegmentSizeAttr());
2734 /// Returns whether the single format element is optionally parsed.
2735 static bool isOptionallyParsed(FormatElement
*el
) {
2736 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(el
)) {
2737 Attribute attr
= attrVar
->getVar()->attr
;
2738 return attr
.isOptional() || attr
.hasDefaultValue();
2740 if (auto *operandVar
= dyn_cast
<OperandVariable
>(el
)) {
2741 const NamedTypeConstraint
*operand
= operandVar
->getVar();
2742 return operand
->isOptional() || operand
->isVariadic() ||
2743 operand
->isVariadicOfVariadic();
2745 if (auto *successorVar
= dyn_cast
<SuccessorVariable
>(el
))
2746 return successorVar
->getVar()->isVariadic();
2747 if (auto *regionVar
= dyn_cast
<RegionVariable
>(el
))
2748 return regionVar
->getVar()->isVariadic();
2749 return isa
<WhitespaceElement
, AttrDictDirective
>(el
);
2752 /// Scan the given range of elements from the start for an invalid format
2753 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2754 /// If an optional group is encountered, this function recurses into the 'then'
2755 /// and 'else' elements to check if they are invalid. Returns `success` if the
2756 /// range is known to be valid or `std::nullopt` if scanning reached the end.
2758 /// Since the guard element of an optional group is required, this function
2759 /// accepts an optional element pointer to mark it as required.
2760 static std::optional
<LogicalResult
> checkRangeForElement(
2761 FormatElement
*base
,
2762 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2763 iterator_range
<ArrayRef
<FormatElement
*>::iterator
> elementRange
,
2764 FormatElement
*optionalGuard
= nullptr) {
2765 for (FormatElement
*element
: elementRange
) {
2766 // If we encounter an invalid element, return an error.
2767 if (isInvalid(base
, element
))
2770 // Recurse on optional groups.
2771 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
2772 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2773 base
, isInvalid
, optional
->getThenElements(),
2774 // The optional group guard is required for the group.
2775 optional
->getThenElements().front()))
2776 if (failed(*result
))
2778 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2779 base
, isInvalid
, optional
->getElseElements()))
2780 if (failed(*result
))
2782 // Skip the optional group.
2786 // Skip optionally parsed elements.
2787 if (element
!= optionalGuard
&& isOptionallyParsed(element
))
2790 // We found a closing element that is valid.
2793 // Return std::nullopt to indicate that we reached the end.
2794 return std::nullopt
;
2797 /// For the given elements, check whether any attributes are followed by a colon
2798 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2799 /// attribute if verification of said attribute reached the end of the range.
2800 /// Returns null if all attribute elements are verified.
2801 static FailureOr
<FormatElement
*> verifyAdjacentElements(
2802 function_ref
<bool(FormatElement
*)> isBase
,
2803 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2804 ArrayRef
<FormatElement
*> elements
) {
2805 for (auto *it
= elements
.begin(), *e
= elements
.end(); it
!= e
; ++it
) {
2806 // The current attribute being verified.
2807 FormatElement
*base
;
2811 } else if (auto *optional
= dyn_cast
<OptionalElement
>(*it
)) {
2812 // Recurse on optional groups.
2813 FailureOr
<FormatElement
*> thenResult
= verifyAdjacentElements(
2814 isBase
, isInvalid
, optional
->getThenElements());
2815 if (failed(thenResult
))
2817 FailureOr
<FormatElement
*> elseResult
= verifyAdjacentElements(
2818 isBase
, isInvalid
, optional
->getElseElements());
2819 if (failed(elseResult
))
2821 // If either optional group has an unverified attribute, save it.
2822 // Otherwise, move on to the next element.
2823 if (!(base
= *thenResult
) && !(base
= *elseResult
))
2829 // Verify subsequent elements for potential ambiguities.
2830 if (std::optional
<LogicalResult
> result
=
2831 checkRangeForElement(base
, isInvalid
, {std::next(it
), e
})) {
2832 if (failed(*result
))
2835 // Since we reached the end, return the attribute as unverified.
2839 // All attribute elements are known to be verified.
2844 OpFormatParser::verifyAttributeColonType(SMLoc loc
,
2845 ArrayRef
<FormatElement
*> elements
) {
2846 auto isBase
= [](FormatElement
*el
) {
2847 auto *attr
= dyn_cast
<AttributeVariable
>(el
);
2850 // Check only attributes without type builders or that are known to call
2851 // the generic attribute parser.
2852 return !attr
->getTypeBuilder() &&
2853 (attr
->shouldBeQualified() ||
2854 attr
->getVar()->attr
.getStorageType() == "::mlir::Attribute");
2856 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
2857 auto *literal
= dyn_cast
<LiteralElement
>(el
);
2858 if (!literal
|| literal
->getSpelling() != ":")
2860 // If we encounter `:`, the range is known to be invalid.
2863 llvm::formatv("format ambiguity caused by `:` literal found after "
2864 "attribute `{0}` which does not have a buildable type",
2865 cast
<AttributeVariable
>(base
)->getVar()->name
));
2868 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
2872 OpFormatParser::verifyAttrDictRegion(SMLoc loc
,
2873 ArrayRef
<FormatElement
*> elements
) {
2874 auto isBase
= [](FormatElement
*el
) {
2875 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(el
))
2876 return !attrDict
->isWithKeyword();
2879 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
2880 auto *region
= dyn_cast
<RegionVariable
>(el
);
2883 (void)emitErrorAndNote(
2885 llvm::formatv("format ambiguity caused by `attr-dict` directive "
2886 "followed by region `{0}`",
2887 region
->getVar()->name
),
2888 "try using `attr-dict-with-keyword` instead");
2891 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
2894 LogicalResult
OpFormatParser::verifyOperands(
2895 SMLoc loc
, llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2896 // Check that all of the operands are within the format, and their types can
2898 auto &buildableTypes
= fmt
.buildableTypes
;
2899 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
2900 NamedTypeConstraint
&operand
= op
.getOperand(i
);
2902 // Check that the operand itself is in the format.
2903 if (!fmt
.allOperands
&& !seenOperands
.count(&operand
)) {
2904 return emitErrorAndNote(loc
,
2905 "operand #" + Twine(i
) + ", named '" +
2906 operand
.name
+ "', not found",
2907 "suggest adding a '$" + operand
.name
+
2908 "' directive to the custom assembly format");
2911 // Check that the operand type is in the format, or that it can be inferred.
2912 if (fmt
.allOperandTypes
|| seenOperandTypes
.test(i
))
2915 // Check to see if we can infer this type from another variable.
2916 auto varResolverIt
= variableTyResolver
.find(op
.getOperand(i
).name
);
2917 if (varResolverIt
!= variableTyResolver
.end()) {
2918 TypeResolutionInstance
&resolver
= varResolverIt
->second
;
2919 fmt
.operandTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
2923 // Similarly to results, allow a custom builder for resolving the type if
2924 // we aren't using the 'operands' directive.
2925 std::optional
<StringRef
> builder
= operand
.constraint
.getBuilderCall();
2926 if (!builder
|| (fmt
.allOperands
&& operand
.isVariableLength())) {
2927 return emitErrorAndNote(
2929 "type of operand #" + Twine(i
) + ", named '" + operand
.name
+
2930 "', is not buildable and a buildable type cannot be inferred",
2931 "suggest adding a type constraint to the operation or adding a "
2933 operand
.name
+ ")' directive to the " + "custom assembly format");
2935 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
2936 fmt
.operandTypes
[i
].setBuilderIdx(it
.first
->second
);
2941 LogicalResult
OpFormatParser::verifyRegions(SMLoc loc
) {
2942 // Check that all of the regions are within the format.
2946 for (unsigned i
= 0, e
= op
.getNumRegions(); i
!= e
; ++i
) {
2947 const NamedRegion
®ion
= op
.getRegion(i
);
2948 if (!seenRegions
.count(®ion
)) {
2949 return emitErrorAndNote(loc
,
2950 "region #" + Twine(i
) + ", named '" +
2951 region
.name
+ "', not found",
2952 "suggest adding a '$" + region
.name
+
2953 "' directive to the custom assembly format");
2959 LogicalResult
OpFormatParser::verifyResults(
2960 SMLoc loc
, llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2961 // If we format all of the types together, there is nothing to check.
2962 if (fmt
.allResultTypes
)
2965 // If no result types are specified and we can infer them, infer all result
2967 if (op
.getNumResults() > 0 && seenResultTypes
.count() == 0 &&
2968 canInferResultTypes
) {
2969 fmt
.infersResultTypes
= true;
2973 // Check that all of the result types can be inferred.
2974 auto &buildableTypes
= fmt
.buildableTypes
;
2975 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
2976 if (seenResultTypes
.test(i
))
2979 // Check to see if we can infer this type from another variable.
2980 auto varResolverIt
= variableTyResolver
.find(op
.getResultName(i
));
2981 if (varResolverIt
!= variableTyResolver
.end()) {
2982 TypeResolutionInstance resolver
= varResolverIt
->second
;
2983 fmt
.resultTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
2987 // If the result is not variable length, allow for the case where the type
2988 // has a builder that we can use.
2989 NamedTypeConstraint
&result
= op
.getResult(i
);
2990 std::optional
<StringRef
> builder
= result
.constraint
.getBuilderCall();
2991 if (!builder
|| result
.isVariableLength()) {
2992 return emitErrorAndNote(
2994 "type of result #" + Twine(i
) + ", named '" + result
.name
+
2995 "', is not buildable and a buildable type cannot be inferred",
2996 "suggest adding a type constraint to the operation or adding a "
2998 result
.name
+ ")' directive to the " + "custom assembly format");
3000 // Note in the format that this result uses the custom builder.
3001 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
3002 fmt
.resultTypes
[i
].setBuilderIdx(it
.first
->second
);
3007 LogicalResult
OpFormatParser::verifySuccessors(SMLoc loc
) {
3008 // Check that all of the successors are within the format.
3009 if (hasAllSuccessors
)
3012 for (unsigned i
= 0, e
= op
.getNumSuccessors(); i
!= e
; ++i
) {
3013 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
3014 if (!seenSuccessors
.count(&successor
)) {
3015 return emitErrorAndNote(loc
,
3016 "successor #" + Twine(i
) + ", named '" +
3017 successor
.name
+ "', not found",
3018 "suggest adding a '$" + successor
.name
+
3019 "' directive to the custom assembly format");
3026 OpFormatParser::verifyOIListElements(SMLoc loc
,
3027 ArrayRef
<FormatElement
*> elements
) {
3028 // Check that all of the successors are within the format.
3029 SmallVector
<StringRef
> prohibitedLiterals
;
3030 for (FormatElement
*it
: elements
) {
3031 if (auto *oilist
= dyn_cast
<OIListElement
>(it
)) {
3032 if (!prohibitedLiterals
.empty()) {
3033 // We just saw an oilist element in last iteration. Literals should not
3035 for (LiteralElement
*literal
: oilist
->getLiteralElements()) {
3036 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
3037 prohibitedLiterals
.end()) {
3039 loc
, "format ambiguity because " + literal
->getSpelling() +
3040 " is used in two adjacent oilist elements.");
3044 for (LiteralElement
*literal
: oilist
->getLiteralElements())
3045 prohibitedLiterals
.push_back(literal
->getSpelling());
3046 } else if (auto *literal
= dyn_cast
<LiteralElement
>(it
)) {
3047 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
3048 prohibitedLiterals
.end()) {
3051 "format ambiguity because " + literal
->getSpelling() +
3052 " is used both in oilist element and the adjacent literal.");
3054 prohibitedLiterals
.clear();
3056 prohibitedLiterals
.clear();
3062 void OpFormatParser::handleAllTypesMatchConstraint(
3063 ArrayRef
<StringRef
> values
,
3064 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
3065 for (unsigned i
= 0, e
= values
.size(); i
!= e
; ++i
) {
3066 // Check to see if this value matches a resolved operand or result type.
3067 ConstArgument arg
= findSeenArg(values
[i
]);
3071 // Mark this value as the type resolver for the other variables.
3072 for (unsigned j
= 0; j
!= i
; ++j
)
3073 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
3074 for (unsigned j
= i
+ 1; j
!= e
; ++j
)
3075 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
3079 void OpFormatParser::handleSameTypesConstraint(
3080 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
3081 bool includeResults
) {
3082 const NamedTypeConstraint
*resolver
= nullptr;
3083 int resolvedIt
= -1;
3085 // Check to see if there is an operand or result to use for the resolution.
3086 if ((resolvedIt
= seenOperandTypes
.find_first()) != -1)
3087 resolver
= &op
.getOperand(resolvedIt
);
3088 else if (includeResults
&& (resolvedIt
= seenResultTypes
.find_first()) != -1)
3089 resolver
= &op
.getResult(resolvedIt
);
3093 // Set the resolvers for each operand and result.
3094 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
)
3095 if (!seenOperandTypes
.test(i
))
3096 variableTyResolver
[op
.getOperand(i
).name
] = {resolver
, std::nullopt
};
3097 if (includeResults
) {
3098 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
)
3099 if (!seenResultTypes
.test(i
))
3100 variableTyResolver
[op
.getResultName(i
)] = {resolver
, std::nullopt
};
3104 void OpFormatParser::handleTypesMatchConstraint(
3105 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
3106 const llvm::Record
&def
) {
3107 StringRef lhsName
= def
.getValueAsString("lhs");
3108 StringRef rhsName
= def
.getValueAsString("rhs");
3109 StringRef transformer
= def
.getValueAsString("transformer");
3110 if (ConstArgument arg
= findSeenArg(lhsName
))
3111 variableTyResolver
[rhsName
] = {arg
, transformer
};
3114 ConstArgument
OpFormatParser::findSeenArg(StringRef name
) {
3115 if (const NamedTypeConstraint
*arg
= findArg(op
.getOperands(), name
))
3116 return seenOperandTypes
.test(arg
- op
.operand_begin()) ? arg
: nullptr;
3117 if (const NamedTypeConstraint
*arg
= findArg(op
.getResults(), name
))
3118 return seenResultTypes
.test(arg
- op
.result_begin()) ? arg
: nullptr;
3119 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
))
3120 return seenAttrs
.count(attr
) ? attr
: nullptr;
3124 FailureOr
<FormatElement
*>
3125 OpFormatParser::parseVariableImpl(SMLoc loc
, StringRef name
, Context ctx
) {
3126 // Check that the parsed argument is something actually registered on the op.
3128 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
)) {
3129 if (ctx
== TypeDirectiveContext
)
3131 loc
, "attributes cannot be used as children to a `type` directive");
3132 if (ctx
== RefDirectiveContext
) {
3133 if (!seenAttrs
.count(attr
))
3134 return emitError(loc
, "attribute '" + name
+
3135 "' must be bound before it is referenced");
3136 } else if (!seenAttrs
.insert(attr
)) {
3137 return emitError(loc
, "attribute '" + name
+ "' is already bound");
3140 return create
<AttributeVariable
>(attr
);
3143 if (const NamedProperty
*property
= findArg(op
.getProperties(), name
)) {
3144 if (ctx
!= CustomDirectiveContext
&& ctx
!= RefDirectiveContext
)
3146 loc
, "properties currently only supported in `custom` directive");
3148 if (ctx
== RefDirectiveContext
) {
3149 if (!seenProperties
.count(property
))
3150 return emitError(loc
, "property '" + name
+
3151 "' must be bound before it is referenced");
3153 if (!seenProperties
.insert(property
))
3154 return emitError(loc
, "property '" + name
+ "' is already bound");
3157 return create
<PropertyVariable
>(property
);
3161 if (const NamedTypeConstraint
*operand
= findArg(op
.getOperands(), name
)) {
3162 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3163 if (fmt
.allOperands
|| !seenOperands
.insert(operand
).second
)
3164 return emitError(loc
, "operand '" + name
+ "' is already bound");
3165 } else if (ctx
== RefDirectiveContext
&& !seenOperands
.count(operand
)) {
3166 return emitError(loc
, "operand '" + name
+
3167 "' must be bound before it is referenced");
3169 return create
<OperandVariable
>(operand
);
3172 if (const NamedRegion
*region
= findArg(op
.getRegions(), name
)) {
3173 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3174 if (hasAllRegions
|| !seenRegions
.insert(region
).second
)
3175 return emitError(loc
, "region '" + name
+ "' is already bound");
3176 } else if (ctx
== RefDirectiveContext
&& !seenRegions
.count(region
)) {
3177 return emitError(loc
, "region '" + name
+
3178 "' must be bound before it is referenced");
3180 return emitError(loc
, "regions can only be used at the top level");
3182 return create
<RegionVariable
>(region
);
3185 if (const auto *result
= findArg(op
.getResults(), name
)) {
3186 if (ctx
!= TypeDirectiveContext
)
3187 return emitError(loc
, "result variables can can only be used as a child "
3188 "to a 'type' directive");
3189 return create
<ResultVariable
>(result
);
3192 if (const auto *successor
= findArg(op
.getSuccessors(), name
)) {
3193 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3194 if (hasAllSuccessors
|| !seenSuccessors
.insert(successor
).second
)
3195 return emitError(loc
, "successor '" + name
+ "' is already bound");
3196 } else if (ctx
== RefDirectiveContext
&& !seenSuccessors
.count(successor
)) {
3197 return emitError(loc
, "successor '" + name
+
3198 "' must be bound before it is referenced");
3200 return emitError(loc
, "successors can only be used at the top level");
3203 return create
<SuccessorVariable
>(successor
);
3205 return emitError(loc
, "expected variable to refer to an argument, region, "
3206 "result, or successor");
3209 FailureOr
<FormatElement
*>
3210 OpFormatParser::parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
,
3213 case FormatToken::kw_prop_dict
:
3214 return parsePropDictDirective(loc
, ctx
);
3215 case FormatToken::kw_attr_dict
:
3216 return parseAttrDictDirective(loc
, ctx
,
3217 /*withKeyword=*/false);
3218 case FormatToken::kw_attr_dict_w_keyword
:
3219 return parseAttrDictDirective(loc
, ctx
,
3220 /*withKeyword=*/true);
3221 case FormatToken::kw_functional_type
:
3222 return parseFunctionalTypeDirective(loc
, ctx
);
3223 case FormatToken::kw_operands
:
3224 return parseOperandsDirective(loc
, ctx
);
3225 case FormatToken::kw_regions
:
3226 return parseRegionsDirective(loc
, ctx
);
3227 case FormatToken::kw_results
:
3228 return parseResultsDirective(loc
, ctx
);
3229 case FormatToken::kw_successors
:
3230 return parseSuccessorsDirective(loc
, ctx
);
3231 case FormatToken::kw_type
:
3232 return parseTypeDirective(loc
, ctx
);
3233 case FormatToken::kw_oilist
:
3234 return parseOIListDirective(loc
, ctx
);
3237 return emitError(loc
, "unsupported directive kind");
3241 FailureOr
<FormatElement
*>
3242 OpFormatParser::parseAttrDictDirective(SMLoc loc
, Context context
,
3244 if (context
== TypeDirectiveContext
)
3245 return emitError(loc
, "'attr-dict' directive can only be used as a "
3246 "top-level directive");
3248 if (context
== RefDirectiveContext
) {
3250 return emitError(loc
, "'ref' of 'attr-dict' is not bound by a prior "
3251 "'attr-dict' directive");
3253 // Otherwise, this is a top-level context.
3256 return emitError(loc
, "'attr-dict' directive has already been seen");
3260 return create
<AttrDictDirective
>(withKeyword
);
3263 FailureOr
<FormatElement
*>
3264 OpFormatParser::parsePropDictDirective(SMLoc loc
, Context context
) {
3265 if (context
== TypeDirectiveContext
)
3266 return emitError(loc
, "'prop-dict' directive can only be used as a "
3267 "top-level directive");
3269 if (context
== RefDirectiveContext
)
3270 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3271 // Otherwise, this is a top-level context.
3274 return emitError(loc
, "'prop-dict' directive has already been seen");
3277 return create
<PropDictDirective
>();
3280 LogicalResult
OpFormatParser::verifyCustomDirectiveArguments(
3281 SMLoc loc
, ArrayRef
<FormatElement
*> arguments
) {
3282 for (FormatElement
*argument
: arguments
) {
3283 if (!isa
<AttrDictDirective
, PropDictDirective
, AttributeVariable
,
3284 OperandVariable
, PropertyVariable
, RefDirective
, RegionVariable
,
3285 SuccessorVariable
, StringElement
, TypeDirective
>(argument
)) {
3286 // TODO: FormatElement should have location info attached.
3287 return emitError(loc
, "only variables and types may be used as "
3288 "parameters to a custom directive");
3290 if (auto *type
= dyn_cast
<TypeDirective
>(argument
)) {
3291 if (!isa
<OperandVariable
, ResultVariable
>(type
->getArg())) {
3292 return emitError(loc
, "type directives within a custom directive may "
3293 "only refer to variables");
3300 FailureOr
<FormatElement
*>
3301 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc
, Context context
) {
3302 if (context
!= TopLevelContext
)
3304 loc
, "'functional-type' is only valid as a top-level directive");
3306 // Parse the main operand.
3307 FailureOr
<FormatElement
*> inputs
, results
;
3308 if (failed(parseToken(FormatToken::l_paren
,
3309 "expected '(' before argument list")) ||
3310 failed(inputs
= parseTypeDirectiveOperand(loc
)) ||
3311 failed(parseToken(FormatToken::comma
,
3312 "expected ',' after inputs argument")) ||
3313 failed(results
= parseTypeDirectiveOperand(loc
)) ||
3315 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3317 return create
<FunctionalTypeDirective
>(*inputs
, *results
);
3320 FailureOr
<FormatElement
*>
3321 OpFormatParser::parseOperandsDirective(SMLoc loc
, Context context
) {
3322 if (context
== RefDirectiveContext
) {
3323 if (!fmt
.allOperands
)
3324 return emitError(loc
, "'ref' of 'operands' is not bound by a prior "
3325 "'operands' directive");
3327 } else if (context
== TopLevelContext
|| context
== CustomDirectiveContext
) {
3328 if (fmt
.allOperands
|| !seenOperands
.empty())
3329 return emitError(loc
, "'operands' directive creates overlap in format");
3330 fmt
.allOperands
= true;
3332 return create
<OperandsDirective
>();
3335 FailureOr
<FormatElement
*>
3336 OpFormatParser::parseRegionsDirective(SMLoc loc
, Context context
) {
3337 if (context
== TypeDirectiveContext
)
3338 return emitError(loc
, "'regions' is only valid as a top-level directive");
3339 if (context
== RefDirectiveContext
) {
3341 return emitError(loc
, "'ref' of 'regions' is not bound by a prior "
3342 "'regions' directive");
3344 // Otherwise, this is a TopLevel directive.
3346 if (hasAllRegions
|| !seenRegions
.empty())
3347 return emitError(loc
, "'regions' directive creates overlap in format");
3348 hasAllRegions
= true;
3350 return create
<RegionsDirective
>();
3353 FailureOr
<FormatElement
*>
3354 OpFormatParser::parseResultsDirective(SMLoc loc
, Context context
) {
3355 if (context
!= TypeDirectiveContext
)
3356 return emitError(loc
, "'results' directive can can only be used as a child "
3357 "to a 'type' directive");
3358 return create
<ResultsDirective
>();
3361 FailureOr
<FormatElement
*>
3362 OpFormatParser::parseSuccessorsDirective(SMLoc loc
, Context context
) {
3363 if (context
== TypeDirectiveContext
)
3364 return emitError(loc
,
3365 "'successors' is only valid as a top-level directive");
3366 if (context
== RefDirectiveContext
) {
3367 if (!hasAllSuccessors
)
3368 return emitError(loc
, "'ref' of 'successors' is not bound by a prior "
3369 "'successors' directive");
3371 // Otherwise, this is a TopLevel directive.
3373 if (hasAllSuccessors
|| !seenSuccessors
.empty())
3374 return emitError(loc
, "'successors' directive creates overlap in format");
3375 hasAllSuccessors
= true;
3377 return create
<SuccessorsDirective
>();
3380 FailureOr
<FormatElement
*>
3381 OpFormatParser::parseOIListDirective(SMLoc loc
, Context context
) {
3382 if (failed(parseToken(FormatToken::l_paren
,
3383 "expected '(' before oilist argument list")))
3385 std::vector
<FormatElement
*> literalElements
;
3386 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
3388 FailureOr
<FormatElement
*> lelement
= parseLiteral(context
);
3389 if (failed(lelement
))
3391 literalElements
.push_back(*lelement
);
3392 parsingElements
.emplace_back();
3393 std::vector
<FormatElement
*> &currParsingElements
= parsingElements
.back();
3394 while (peekToken().getKind() != FormatToken::pipe
&&
3395 peekToken().getKind() != FormatToken::r_paren
) {
3396 FailureOr
<FormatElement
*> pelement
= parseElement(context
);
3397 if (failed(pelement
) ||
3398 failed(verifyOIListParsingElement(*pelement
, loc
)))
3400 currParsingElements
.push_back(*pelement
);
3402 if (peekToken().getKind() == FormatToken::pipe
) {
3406 if (peekToken().getKind() == FormatToken::r_paren
) {
3412 return create
<OIListElement
>(std::move(literalElements
),
3413 std::move(parsingElements
));
3416 LogicalResult
OpFormatParser::verifyOIListParsingElement(FormatElement
*element
,
3418 SmallVector
<VariableElement
*> vars
;
3419 collect(element
, vars
);
3420 for (VariableElement
*elem
: vars
) {
3422 TypeSwitch
<FormatElement
*, LogicalResult
>(elem
)
3423 // Only optional attributes can be within an oilist parsing group.
3424 .Case([&](AttributeVariable
*attrEle
) {
3425 if (!attrEle
->getVar()->attr
.isOptional() &&
3426 !attrEle
->getVar()->attr
.hasDefaultValue())
3427 return emitError(loc
, "only optional attributes can be used in "
3428 "an oilist parsing group");
3431 // Only optional-like(i.e. variadic) operands can be within an
3432 // oilist parsing group.
3433 .Case([&](OperandVariable
*ele
) {
3434 if (!ele
->getVar()->isVariableLength())
3435 return emitError(loc
, "only variable length operands can be "
3436 "used within an oilist parsing group");
3439 // Only optional-like(i.e. variadic) results can be within an oilist
3441 .Case([&](ResultVariable
*ele
) {
3442 if (!ele
->getVar()->isVariableLength())
3443 return emitError(loc
, "only variable length results can be "
3444 "used within an oilist parsing group");
3447 .Case([&](RegionVariable
*) { return success(); })
3448 .Default([&](FormatElement
*) {
3449 return emitError(loc
,
3450 "only literals, types, and variables can be "
3451 "used within an oilist group");
3459 FailureOr
<FormatElement
*> OpFormatParser::parseTypeDirective(SMLoc loc
,
3461 if (context
== TypeDirectiveContext
)
3462 return emitError(loc
, "'type' cannot be used as a child of another `type`");
3464 bool isRefChild
= context
== RefDirectiveContext
;
3465 FailureOr
<FormatElement
*> operand
;
3466 if (failed(parseToken(FormatToken::l_paren
,
3467 "expected '(' before argument list")) ||
3468 failed(operand
= parseTypeDirectiveOperand(loc
, isRefChild
)) ||
3470 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3473 return create
<TypeDirective
>(*operand
);
3476 LogicalResult
OpFormatParser::markQualified(SMLoc loc
, FormatElement
*element
) {
3477 return TypeSwitch
<FormatElement
*, LogicalResult
>(element
)
3478 .Case
<AttributeVariable
, TypeDirective
>([](auto *element
) {
3479 element
->setShouldBeQualified();
3482 .Default([&](auto *element
) {
3483 return this->emitError(
3485 "'qualified' directive expects an attribute or a `type` directive");
3489 FailureOr
<FormatElement
*>
3490 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc
, bool isRefChild
) {
3491 FailureOr
<FormatElement
*> result
= parseElement(TypeDirectiveContext
);
3495 FormatElement
*element
= *result
;
3496 if (isa
<LiteralElement
>(element
))
3498 loc
, "'type' directive operand expects variable or directive operand");
3500 if (auto *var
= dyn_cast
<OperandVariable
>(element
)) {
3501 unsigned opIdx
= var
->getVar() - op
.operand_begin();
3502 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3503 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3504 "' is already bound");
3505 if (isRefChild
&& !(fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3506 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3507 ")' is not bound by a prior 'type' directive");
3508 seenOperandTypes
.set(opIdx
);
3509 } else if (auto *var
= dyn_cast
<ResultVariable
>(element
)) {
3510 unsigned resIdx
= var
->getVar() - op
.result_begin();
3511 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3512 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3513 "' is already bound");
3514 if (isRefChild
&& !(fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3515 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3516 ")' is not bound by a prior 'type' directive");
3517 seenResultTypes
.set(resIdx
);
3518 } else if (isa
<OperandsDirective
>(&*element
)) {
3519 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.any()))
3520 return emitError(loc
, "'operands' 'type' is already bound");
3521 if (isRefChild
&& !fmt
.allOperandTypes
)
3522 return emitError(loc
, "'ref' of 'type(operands)' is not bound by a prior "
3523 "'type' directive");
3524 fmt
.allOperandTypes
= true;
3525 } else if (isa
<ResultsDirective
>(&*element
)) {
3526 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.any()))
3527 return emitError(loc
, "'results' 'type' is already bound");
3528 if (isRefChild
&& !fmt
.allResultTypes
)
3529 return emitError(loc
, "'ref' of 'type(results)' is not bound by a prior "
3530 "'type' directive");
3531 fmt
.allResultTypes
= true;
3533 return emitError(loc
, "invalid argument to 'type' directive");
3538 LogicalResult
OpFormatParser::verifyOptionalGroupElements(
3539 SMLoc loc
, ArrayRef
<FormatElement
*> elements
, FormatElement
*anchor
) {
3540 for (FormatElement
*element
: elements
) {
3541 if (failed(verifyOptionalGroupElement(loc
, element
, element
== anchor
)))
3547 LogicalResult
OpFormatParser::verifyOptionalGroupElement(SMLoc loc
,
3548 FormatElement
*element
,
3550 return TypeSwitch
<FormatElement
*, LogicalResult
>(element
)
3551 // All attributes can be within the optional group, but only optional
3552 // attributes can be the anchor.
3553 .Case([&](AttributeVariable
*attrEle
) {
3554 Attribute attr
= attrEle
->getVar()->attr
;
3555 if (isAnchor
&& !(attr
.isOptional() || attr
.hasDefaultValue()))
3556 return emitError(loc
, "only optional or default-valued attributes "
3557 "can be used to anchor an optional group");
3560 // Only optional-like(i.e. variadic) operands can be within an optional
3562 .Case([&](OperandVariable
*ele
) {
3563 if (!ele
->getVar()->isVariableLength())
3564 return emitError(loc
, "only variable length operands can be used "
3565 "within an optional group");
3568 // Only optional-like(i.e. variadic) results can be within an optional
3570 .Case([&](ResultVariable
*ele
) {
3571 if (!ele
->getVar()->isVariableLength())
3572 return emitError(loc
, "only variable length results can be used "
3573 "within an optional group");
3576 .Case([&](RegionVariable
*) {
3577 // TODO: When ODS has proper support for marking "optional" regions, add
3581 .Case([&](TypeDirective
*ele
) {
3582 return verifyOptionalGroupElement(loc
, ele
->getArg(),
3583 /*isAnchor=*/false);
3585 .Case([&](FunctionalTypeDirective
*ele
) {
3586 if (failed(verifyOptionalGroupElement(loc
, ele
->getInputs(),
3587 /*isAnchor=*/false)))
3589 return verifyOptionalGroupElement(loc
, ele
->getResults(),
3590 /*isAnchor=*/false);
3592 .Case([&](CustomDirective
*ele
) {
3595 // Verify each child as being valid in an optional group. They are all
3596 // potential anchors if the custom directive was marked as one.
3597 for (FormatElement
*child
: ele
->getArguments()) {
3598 if (isa
<RefDirective
>(child
))
3600 if (failed(verifyOptionalGroupElement(loc
, child
, /*isAnchor=*/true)))
3605 // Literals, whitespace, and custom directives may be used, but they can't
3606 // anchor the group.
3607 .Case
<LiteralElement
, WhitespaceElement
, OptionalElement
>(
3608 [&](FormatElement
*) {
3610 return emitError(loc
, "only variables and types can be used "
3611 "to anchor an optional group");
3614 .Default([&](FormatElement
*) {
3615 return emitError(loc
, "only literals, types, and variables can be "
3616 "used within an optional group");
3620 //===----------------------------------------------------------------------===//
3622 //===----------------------------------------------------------------------===//
3624 void mlir::tblgen::generateOpFormat(const Operator
&constOp
, OpClass
&opClass
) {
3625 // TODO: Operator doesn't expose all necessary functionality via
3626 // the const interface.
3627 Operator
&op
= const_cast<Operator
&>(constOp
);
3628 if (!op
.hasAssemblyFormat())
3631 // Parse the format description.
3632 llvm::SourceMgr mgr
;
3633 mgr
.AddNewSourceBuffer(
3634 llvm::MemoryBuffer::getMemBuffer(op
.getAssemblyFormat()), SMLoc());
3635 OperationFormat
format(op
);
3636 OpFormatParser
parser(mgr
, format
, op
);
3637 FailureOr
<std::vector
<FormatElement
*>> elements
= parser
.parse();
3638 if (failed(elements
)) {
3639 // Exit the process if format errors are treated as fatal.
3640 if (formatErrorIsFatal
) {
3641 // Invoke the interrupt handlers to run the file cleanup handlers.
3642 llvm::sys::RunInterruptHandlers();
3647 format
.elements
= std::move(*elements
);
3649 // Generate the printer and parser based on the parsed format.
3650 format
.genParser(op
, opClass
);
3651 format
.genPrinter(op
, opClass
);