//===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
30 using namespace mlir::tblgen
;
32 //===----------------------------------------------------------------------===//
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT
, VariableElement::Kind VariableKind
>
40 class OpVariableElement
: public VariableElementBase
<VariableKind
> {
42 using Base
= OpVariableElement
<VarT
, VariableKind
>;
44 /// Create an op variable element with the variable value.
45 OpVariableElement(const VarT
*var
) : var(var
) {}
48 const VarT
*getVar() { return var
; }
51 /// The op variable, e.g. a type or attribute constraint.
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57 : public OpVariableElement
<NamedAttribute
, VariableElement::Attribute
> {
60 /// Return the constant builder call for the type of this attribute, or
61 /// std::nullopt if it doesn't have one.
62 std::optional
<StringRef
> getTypeBuilder() const {
63 std::optional
<Type
> attrType
= var
->attr
.getValueType();
64 return attrType
? attrType
->getBuilderCall() : std::nullopt
;
67 /// Return if this attribute refers to a UnitAttr.
68 bool isUnitAttr() const {
69 return var
->attr
.getBaseAttr().getAttrDefName() == "UnitAttr";
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
74 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
75 void setShouldBeQualified(bool qualified
= true) {
76 shouldBeQualifiedFlag
= qualified
;
80 bool shouldBeQualifiedFlag
= false;
83 /// This class represents a variable that refers to an operand argument.
84 using OperandVariable
=
85 OpVariableElement
<NamedTypeConstraint
, VariableElement::Operand
>;
87 /// This class represents a variable that refers to a result.
88 using ResultVariable
=
89 OpVariableElement
<NamedTypeConstraint
, VariableElement::Result
>;
91 /// This class represents a variable that refers to a region.
92 using RegionVariable
= OpVariableElement
<NamedRegion
, VariableElement::Region
>;
94 /// This class represents a variable that refers to a successor.
95 using SuccessorVariable
=
96 OpVariableElement
<NamedSuccessor
, VariableElement::Successor
>;
98 /// This class represents a variable that refers to a property argument.
99 using PropertyVariable
=
100 OpVariableElement
<NamedProperty
, VariableElement::Property
>;
103 //===----------------------------------------------------------------------===//
107 /// This class represents the `operands` directive. This directive represents
108 /// all of the operands of an operation.
109 using OperandsDirective
= DirectiveElementBase
<DirectiveElement::Operands
>;
111 /// This class represents the `results` directive. This directive represents
112 /// all of the results of an operation.
113 using ResultsDirective
= DirectiveElementBase
<DirectiveElement::Results
>;
115 /// This class represents the `regions` directive. This directive represents
116 /// all of the regions of an operation.
117 using RegionsDirective
= DirectiveElementBase
<DirectiveElement::Regions
>;
119 /// This class represents the `successors` directive. This directive represents
120 /// all of the successors of an operation.
121 using SuccessorsDirective
= DirectiveElementBase
<DirectiveElement::Successors
>;
123 /// This class represents the `attr-dict` directive. This directive represents
124 /// the attribute dictionary of the operation.
125 class AttrDictDirective
126 : public DirectiveElementBase
<DirectiveElement::AttrDict
> {
128 explicit AttrDictDirective(bool withKeyword
) : withKeyword(withKeyword
) {}
130 /// Return whether the dictionary should be printed with the 'attributes'
132 bool isWithKeyword() const { return withKeyword
; }
135 /// If the dictionary should be printed with the 'attributes' keyword.
139 /// This class represents the `prop-dict` directive. This directive represents
140 /// the properties of the operation, expressed as a directionary.
141 class PropDictDirective
142 : public DirectiveElementBase
<DirectiveElement::PropDict
> {
144 explicit PropDictDirective() = default;
147 /// This class represents the `functional-type` directive. This directive takes
148 /// two arguments and formats them, respectively, as the inputs and results of a
150 class FunctionalTypeDirective
151 : public DirectiveElementBase
<DirectiveElement::FunctionalType
> {
153 FunctionalTypeDirective(FormatElement
*inputs
, FormatElement
*results
)
154 : inputs(inputs
), results(results
) {}
156 FormatElement
*getInputs() const { return inputs
; }
157 FormatElement
*getResults() const { return results
; }
160 /// The input and result arguments.
161 FormatElement
*inputs
, *results
;
164 /// This class represents the `type` directive.
165 class TypeDirective
: public DirectiveElementBase
<DirectiveElement::Type
> {
167 TypeDirective(FormatElement
*arg
) : arg(arg
) {}
169 FormatElement
*getArg() const { return arg
; }
171 /// Indicate if this type is printed "qualified" (that is it is
172 /// prefixed with the `!dialect.mnemonic`).
173 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
174 void setShouldBeQualified(bool qualified
= true) {
175 shouldBeQualifiedFlag
= qualified
;
179 /// The argument that is used to format the directive.
182 bool shouldBeQualifiedFlag
= false;
185 /// This class represents a group of order-independent optional clauses. Each
186 /// clause starts with a literal element and has a coressponding parsing
187 /// element. A parsing element is a continous sequence of format elements.
188 /// Each clause can appear 0 or 1 time.
189 class OIListElement
: public DirectiveElementBase
<DirectiveElement::OIList
> {
191 OIListElement(std::vector
<FormatElement
*> &&literalElements
,
192 std::vector
<std::vector
<FormatElement
*>> &&parsingElements
)
193 : literalElements(std::move(literalElements
)),
194 parsingElements(std::move(parsingElements
)) {}
196 /// Returns a range to iterate over the LiteralElements.
197 auto getLiteralElements() const {
198 function_ref
<LiteralElement
*(FormatElement
* el
)>
199 literalElementCastConverter
=
200 [](FormatElement
*el
) { return cast
<LiteralElement
>(el
); };
201 return llvm::map_range(literalElements
, literalElementCastConverter
);
204 /// Returns a range to iterate over the parsing elements corresponding to the
206 ArrayRef
<std::vector
<FormatElement
*>> getParsingElements() const {
207 return parsingElements
;
210 /// Returns a range to iterate over tuples of parsing and literal elements.
211 auto getClauses() const {
212 return llvm::zip(getLiteralElements(), getParsingElements());
215 /// If the parsing element is a single UnitAttr element, then it returns the
216 /// attribute variable. Otherwise, returns nullptr.
218 getUnitAttrParsingElement(ArrayRef
<FormatElement
*> pelement
) {
219 if (pelement
.size() == 1) {
220 auto *attrElem
= dyn_cast
<AttributeVariable
>(pelement
[0]);
221 if (attrElem
&& attrElem
->isUnitAttr())
228 /// A vector of `LiteralElement` objects. Each element stores the keyword
229 /// for one case of oilist element. For example, an oilist element along with
230 /// the `literalElements` vector:
232 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
233 /// literalElements = { `keyword`, `otherKeyword` }
235 std::vector
<FormatElement
*> literalElements
;
237 /// A vector of valid declarative assembly format vectors. Each object in
238 /// parsing elements is a vector of elements in assembly format syntax.
239 /// For example, an oilist element along with the parsingElements vector:
241 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
242 /// parsingElements = {
243 /// { `=`, `(`, $arg0, `)` },
244 /// { `<`, $arg1, `>` }
247 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
251 //===----------------------------------------------------------------------===//
253 //===----------------------------------------------------------------------===//
257 using ConstArgument
=
258 llvm::PointerUnion
<const NamedAttribute
*, const NamedTypeConstraint
*>;
260 struct OperationFormat
{
261 /// This class represents a specific resolver for an operand or result type.
262 class TypeResolution
{
264 TypeResolution() = default;
266 /// Get the index into the buildable types for this type, or std::nullopt.
267 std::optional
<int> getBuilderIdx() const { return builderIdx
; }
268 void setBuilderIdx(int idx
) { builderIdx
= idx
; }
270 /// Get the variable this type is resolved to, or nullptr.
271 const NamedTypeConstraint
*getVariable() const {
272 return llvm::dyn_cast_if_present
<const NamedTypeConstraint
*>(resolver
);
274 /// Get the attribute this type is resolved to, or nullptr.
275 const NamedAttribute
*getAttribute() const {
276 return llvm::dyn_cast_if_present
<const NamedAttribute
*>(resolver
);
278 /// Get the transformer for the type of the variable, or std::nullopt.
279 std::optional
<StringRef
> getVarTransformer() const {
280 return variableTransformer
;
282 void setResolver(ConstArgument arg
, std::optional
<StringRef
> transformer
) {
284 variableTransformer
= transformer
;
285 assert(getVariable() || getAttribute());
289 /// If the type is resolved with a buildable type, this is the index into
290 /// 'buildableTypes' in the parent format.
291 std::optional
<int> builderIdx
;
292 /// If the type is resolved based upon another operand or result, this is
293 /// the variable or the attribute that this type is resolved to.
294 ConstArgument resolver
;
295 /// If the type is resolved based upon another operand or result, this is
296 /// a transformer to apply to the variable when resolving.
297 std::optional
<StringRef
> variableTransformer
;
300 /// The context in which an element is generated.
301 enum class GenContext
{
302 /// The element is generated at the top-level or with the same behaviour.
304 /// The element is generated inside an optional group.
308 OperationFormat(const Operator
&op
)
309 : useProperties(op
.getDialect().usePropertiesForAttributes() &&
310 !op
.getAttributes().empty()),
311 opCppClassName(op
.getCppClassName()) {
312 operandTypes
.resize(op
.getNumOperands(), TypeResolution());
313 resultTypes
.resize(op
.getNumResults(), TypeResolution());
315 hasImplicitTermTrait
= llvm::any_of(op
.getTraits(), [](const Trait
&trait
) {
316 return trait
.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
319 hasSingleBlockTrait
= op
.getTrait("::mlir::OpTrait::SingleBlock");
322 /// Generate the operation parser from this format.
323 void genParser(Operator
&op
, OpClass
&opClass
);
324 /// Generate the parser code for a specific format element.
325 void genElementParser(FormatElement
*element
, MethodBody
&body
,
326 FmtContext
&attrTypeCtx
,
327 GenContext genCtx
= GenContext::Normal
);
328 /// Generate the C++ to resolve the types of operands and results during
330 void genParserTypeResolution(Operator
&op
, MethodBody
&body
);
331 /// Generate the C++ to resolve the types of the operands during parsing.
332 void genParserOperandTypeResolution(
333 Operator
&op
, MethodBody
&body
,
334 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
);
335 /// Generate the C++ to resolve regions during parsing.
336 void genParserRegionResolution(Operator
&op
, MethodBody
&body
);
337 /// Generate the C++ to resolve successors during parsing.
338 void genParserSuccessorResolution(Operator
&op
, MethodBody
&body
);
339 /// Generate the C++ to handling variadic segment size traits.
340 void genParserVariadicSegmentResolution(Operator
&op
, MethodBody
&body
);
342 /// Generate the operation printer from this format.
343 void genPrinter(Operator
&op
, OpClass
&opClass
);
345 /// Generate the printer code for a specific format element.
346 void genElementPrinter(FormatElement
*element
, MethodBody
&body
, Operator
&op
,
347 bool &shouldEmitSpace
, bool &lastWasPunctuation
);
349 /// The various elements in this format.
350 std::vector
<FormatElement
*> elements
;
352 /// A flag indicating if all operand/result types were seen. If the format
353 /// contains these, it can not contain individual type resolvers.
354 bool allOperands
= false, allOperandTypes
= false, allResultTypes
= false;
356 /// A flag indicating if this operation infers its result types
357 bool infersResultTypes
= false;
359 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
361 bool hasImplicitTermTrait
;
363 /// A flag indicating if this operation has the SingleBlock trait.
364 bool hasSingleBlockTrait
;
366 /// Indicate whether attribute are stored in properties.
369 /// The Operation class name
370 StringRef opCppClassName
;
372 /// A map of buildable types to indices.
373 llvm::MapVector
<StringRef
, int, llvm::StringMap
<int>> buildableTypes
;
375 /// The index of the buildable type, if valid, for every operand and result.
376 std::vector
<TypeResolution
> operandTypes
, resultTypes
;
378 /// The set of attributes explicitly used within the format.
379 SmallVector
<const NamedAttribute
*, 8> usedAttributes
;
380 llvm::StringSet
<> inferredAttributes
;
384 //===----------------------------------------------------------------------===//
387 /// Returns true if we can format the given attribute as an EnumAttr in the
389 static bool canFormatEnumAttr(const NamedAttribute
*attr
) {
390 Attribute baseAttr
= attr
->attr
.getBaseAttr();
391 const EnumAttr
*enumAttr
= dyn_cast
<EnumAttr
>(&baseAttr
);
395 // The attribute must have a valid underlying type and a constant builder.
396 return !enumAttr
->getUnderlyingType().empty() &&
397 !enumAttr
->getConstBuilderTemplate().empty();
400 /// Returns if we should format the given attribute as an SymbolNameAttr.
401 static bool shouldFormatSymbolNameAttr(const NamedAttribute
*attr
) {
402 return attr
->attr
.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
405 /// The code snippet used to generate a parser call for an attribute.
407 /// {0}: The name of the attribute.
408 /// {1}: The type for the attribute.
409 const char *const attrParserCode
= R
"(
410 if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
411 return ::mlir::failure();
415 /// The code snippet used to generate a parser call for an attribute.
417 /// {0}: The name of the attribute.
418 /// {1}: The type for the attribute.
419 const char *const genericAttrParserCode
= R
"(
420 if (parser.parseAttribute({0}Attr, {1}))
421 return ::mlir::failure();
424 const char *const optionalAttrParserCode
= R
"(
425 ::mlir::OptionalParseResult parseResult{0}Attr =
426 parser.parseOptionalAttribute({0}Attr, {1});
427 if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
428 return ::mlir::failure();
429 if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
432 /// The code snippet used to generate a parser call for a symbol name attribute.
434 /// {0}: The name of the attribute.
435 const char *const symbolNameAttrParserCode
= R
"(
436 if (parser.parseSymbolName({0}Attr))
437 return ::mlir::failure();
439 const char *const optionalSymbolNameAttrParserCode
= R
"(
440 // Parsing an optional symbol name doesn't fail, so no need to check the
442 (void)parser.parseOptionalSymbolName({0}Attr);
445 /// The code snippet used to generate a parser call for an enum attribute.
447 /// {0}: The name of the attribute.
448 /// {1}: The c++ namespace for the enum symbolize functions.
449 /// {2}: The function to symbolize a string of the enum.
450 /// {3}: The constant builder call to create an attribute of the enum type.
451 /// {4}: The set of allowed enum keywords.
452 /// {5}: The error message on failure when the enum isn't present.
453 /// {6}: The attribute assignment expression
454 const char *const enumAttrParserCode
= R
"(
456 ::llvm::StringRef attrStr;
457 ::mlir::NamedAttrList attrStorage;
458 auto loc = parser.getCurrentLocation();
459 if (parser.parseOptionalKeyword(&attrStr, {4})) {
460 ::mlir::StringAttr attrVal;
461 ::mlir::OptionalParseResult parseResult =
462 parser.parseOptionalAttribute(attrVal,
463 parser.getBuilder().getNoneType(),
465 if (parseResult.has_value()) {{
466 if (failed(*parseResult))
467 return ::mlir::failure();
468 attrStr = attrVal.getValue();
473 if (!attrStr.empty()) {
474 auto attrOptional = {1}::{2}(attrStr);
476 return parser.emitError(loc, "invalid
")
477 << "{0} attribute specification
: \"" << attrStr << '"';;
485 /// The code snippet used to generate a parser call for an operand.
487 /// {0}: The name of the operand.
488 const char *const variadicOperandParserCode = R"(
489 {0}OperandsLoc = parser.getCurrentLocation();
490 if (parser.parseOperandList({0}Operands))
491 return ::mlir::failure();
493 const char *const optionalOperandParserCode = R"(
495 {0}OperandsLoc = parser.getCurrentLocation();
496 ::mlir::OpAsmParser::UnresolvedOperand operand;
497 ::mlir::OptionalParseResult parseResult =
498 parser.parseOptionalOperand(operand);
499 if (parseResult.has_value()) {
500 if (failed(*parseResult))
501 return ::mlir::failure();
502 {0}Operands.push_back(operand);
506 const char *const operandParserCode = R"(
507 {0}OperandsLoc = parser.getCurrentLocation();
508 if (parser.parseOperand({0}RawOperands[0]))
509 return ::mlir::failure();
511 /// The code snippet used to generate a parser call for a VariadicOfVariadic
514 /// {0}: The name of the operand.
515 /// {1}: The name of segment size attribute.
516 const char *const variadicOfVariadicOperandParserCode = R"(
518 {0}OperandsLoc = parser.getCurrentLocation();
521 if (parser.parseOptionalLParen())
523 if (parser.parseOperandList({0}Operands) || parser.parseRParen())
524 return ::mlir::failure();
525 {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
526 curSize = {0}Operands.size();
527 } while (succeeded(parser.parseOptionalComma()));
531 /// The code snippet used to generate a parser call for a type list.
533 /// {0}: The name for the type list.
534 const char *const variadicOfVariadicTypeParserCode = R"(
536 if (parser.parseOptionalLParen())
538 if (parser.parseOptionalRParen() &&
539 (parser.parseTypeList({0}Types) || parser.parseRParen()))
540 return ::mlir::failure();
541 } while (succeeded(parser.parseOptionalComma()));
543 const char *const variadicTypeParserCode = R"(
544 if (parser.parseTypeList({0}Types))
545 return ::mlir::failure();
547 const char *const optionalTypeParserCode = R"(
549 ::mlir::Type optionalType;
550 ::mlir::OptionalParseResult parseResult =
551 parser.parseOptionalType(optionalType);
552 if (parseResult.has_value()) {
553 if (failed(*parseResult))
554 return ::mlir::failure();
555 {0}Types.push_back(optionalType);
559 const char *const typeParserCode = R"(
562 if (parser.parseCustomTypeWithFallback(type))
563 return ::mlir::failure();
564 {1}RawTypes[0] = type;
567 const char *const qualifiedTypeParserCode = R"(
568 if (parser.parseType({1}RawTypes[0]))
569 return ::mlir::failure();
572 /// The code snippet used to generate a parser call for a functional type.
574 /// {0}: The name for the input type list.
575 /// {1}: The name for the result type list.
576 const char *const functionalTypeParserCode = R"(
577 ::mlir::FunctionType {0}__{1}_functionType;
578 if (parser.parseType({0}__{1}_functionType))
579 return ::mlir::failure();
580 {0}Types = {0}__{1}_functionType.getInputs();
581 {1}Types = {0}__{1}_functionType.getResults();
584 /// The code snippet used to generate a parser call to infer return types.
586 /// {0}: The operation class name
587 const char *const inferReturnTypesParserCode = R"(
588 ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
589 if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
590 result.location, result.operands,
591 result.attributes.getDictionary(parser.getContext()),
592 result.getRawProperties(),
593 result.regions, inferredReturnTypes)))
594 return ::mlir::failure();
595 result.addTypes(inferredReturnTypes);
598 /// The code snippet used to generate a parser call for a region list.
600 /// {0}: The name for the region list.
601 const char *regionListParserCode = R"(
603 std::unique_ptr<::mlir::Region> region;
604 auto firstRegionResult = parser.parseOptionalRegion(region);
605 if (firstRegionResult.has_value()) {
606 if (failed(*firstRegionResult))
607 return ::mlir::failure();
608 {0}Regions.emplace_back(std::move(region));
610 // Parse any trailing regions.
611 while (succeeded(parser.parseOptionalComma())) {
612 region = std::make_unique<::mlir::Region>();
613 if (parser.parseRegion(*region))
614 return ::mlir::failure();
615 {0}Regions.emplace_back(std::move(region));
621 /// The code snippet used to ensure a list of regions have terminators.
623 /// {0}: The name of the region list.
624 const char *regionListEnsureTerminatorParserCode = R"(
625 for (auto ®ion : {0}Regions)
626 ensureTerminator(*region, parser.getBuilder(), result.location);
629 /// The code snippet used to ensure a list of regions have a block.
631 /// {0}: The name of the region list.
632 const char *regionListEnsureSingleBlockParserCode = R"(
633 for (auto ®ion : {0}Regions)
634 if (region->empty()) region->emplaceBlock();
637 /// The code snippet used to generate a parser call for an optional region.
639 /// {0}: The name of the region.
640 const char *optionalRegionParserCode = R"(
642 auto parseResult = parser.parseOptionalRegion(*{0}Region);
643 if (parseResult.has_value() && failed(*parseResult))
644 return ::mlir::failure();
648 /// The code snippet used to generate a parser call for a region.
650 /// {0}: The name of the region.
651 const char *regionParserCode = R"(
652 if (parser.parseRegion(*{0}Region))
653 return ::mlir::failure();
656 /// The code snippet used to ensure a region has a terminator.
658 /// {0}: The name of the region.
659 const char *regionEnsureTerminatorParserCode = R"(
660 ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
663 /// The code snippet used to ensure a region has a block.
665 /// {0}: The name of the region.
666 const char *regionEnsureSingleBlockParserCode = R"(
667 if ({0}Region->empty()) {0}Region->emplaceBlock();
670 /// The code snippet used to generate a parser call for a successor list.
672 /// {0}: The name for the successor list.
673 const char *successorListParserCode = R"(
676 auto firstSucc = parser.parseOptionalSuccessor(succ);
677 if (firstSucc.has_value()) {
678 if (failed(*firstSucc))
679 return ::mlir::failure();
680 {0}Successors.emplace_back(succ);
682 // Parse any trailing successors.
683 while (succeeded(parser.parseOptionalComma())) {
684 if (parser.parseSuccessor(succ))
685 return ::mlir::failure();
686 {0}Successors.emplace_back(succ);
692 /// The code snippet used to generate a parser call for a successor.
694 /// {0}: The name of the successor.
695 const char *successorParserCode = R"(
696 if (parser.parseSuccessor({0}Successor))
697 return ::mlir::failure();
700 /// The code snippet used to generate a parser for OIList
702 /// {0}: literal keyword corresponding to a case for oilist
703 const char *oilistParserCode = R"(
705 return parser.emitError(parser.getNameLoc())
706 << "`{0}` clause can appear at most once in the expansion of the "
713 /// The type of length for a given parse argument.
714 enum class ArgumentLengthKind {
715 /// The argument is a variadic of a variadic, and may contain 0->N range
718 /// The argument is variadic, and may contain 0->N elements.
720 /// The argument is optional, and may contain 0 or 1 elements.
722 /// The argument is a single element, i.e. always represents 1 element.
727 /// Get the length kind for the given constraint.
728 static ArgumentLengthKind
729 getArgumentLengthKind(const NamedTypeConstraint *var) {
730 if (var->isOptional())
731 return ArgumentLengthKind::Optional;
732 if (var->isVariadicOfVariadic())
733 return ArgumentLengthKind::VariadicOfVariadic;
734 if (var->isVariadic())
735 return ArgumentLengthKind::Variadic;
736 return ArgumentLengthKind::Single;
739 /// Get the name used for the type list for the given type directive operand.
740 /// 'lengthKind
' to the corresponding kind for the given argument.
741 static StringRef getTypeListName(FormatElement *arg,
742 ArgumentLengthKind &lengthKind) {
743 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
744 lengthKind = getArgumentLengthKind(operand->getVar());
745 return operand->getVar()->name;
747 if (auto *result = dyn_cast<ResultVariable>(arg)) {
748 lengthKind = getArgumentLengthKind(result->getVar());
749 return result->getVar()->name;
751 lengthKind = ArgumentLengthKind::Variadic;
752 if (isa<OperandsDirective>(arg))
754 if (isa<ResultsDirective>(arg))
756 llvm_unreachable("unknown 'type
' directive argument");
759 /// Generate the parser for a literal value.
760 static void genLiteralParser(StringRef value, MethodBody &body) {
761 // Handle the case of a keyword/identifier.
762 if (value.front() == '_
' || isalpha(value.front())) {
763 body << "Keyword(\"" << value << "\")";
766 body << (StringRef)StringSwitch<StringRef>(value)
767 .Case("->", "Arrow()")
768 .Case(":", "Colon()")
769 .Case(",", "Comma()")
770 .Case("=", "Equal()")
772 .Case(">", "Greater()")
773 .Case("{", "LBrace()")
774 .Case("}", "RBrace()")
775 .Case("(", "LParen()")
776 .Case(")", "RParen()")
777 .Case("[", "LSquare()")
778 .Case("]", "RSquare()")
779 .Case("?", "Question()")
782 .Case("...", "Ellipsis()");
785 /// Generate the storage code required for parsing the given element.
786 static void genElementParserStorage(FormatElement *element, const Operator &op,
788 if (auto *optional = dyn_cast<OptionalElement>(element)) {
789 ArrayRef<FormatElement *> elements = optional->getThenElements();
791 // If the anchor is a unit attribute, it won't be parsed directly so elide
793 auto *anchor
= dyn_cast
<AttributeVariable
>(optional
->getAnchor());
794 FormatElement
*elidedAnchorElement
= nullptr;
795 if (anchor
&& anchor
!= elements
.front() && anchor
->isUnitAttr())
796 elidedAnchorElement
= anchor
;
797 for (FormatElement
*childElement
: elements
)
798 if (childElement
!= elidedAnchorElement
)
799 genElementParserStorage(childElement
, op
, body
);
800 for (FormatElement
*childElement
: optional
->getElseElements())
801 genElementParserStorage(childElement
, op
, body
);
803 } else if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
804 for (ArrayRef
<FormatElement
*> pelement
: oilist
->getParsingElements()) {
805 if (!oilist
->getUnitAttrParsingElement(pelement
))
806 for (FormatElement
*element
: pelement
)
807 genElementParserStorage(element
, op
, body
);
810 } else if (auto *custom
= dyn_cast
<CustomDirective
>(element
)) {
811 for (FormatElement
*paramElement
: custom
->getArguments())
812 genElementParserStorage(paramElement
, op
, body
);
814 } else if (isa
<OperandsDirective
>(element
)) {
815 body
<< " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
818 } else if (isa
<RegionsDirective
>(element
)) {
819 body
<< " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
822 } else if (isa
<SuccessorsDirective
>(element
)) {
823 body
<< " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
825 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
826 const NamedAttribute
*var
= attr
->getVar();
827 body
<< llvm::formatv(" {0} {1}Attr;\n", var
->attr
.getStorageType(),
830 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
831 StringRef name
= operand
->getVar()->name
;
832 if (operand
->getVar()->isVariableLength()) {
834 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
835 << name
<< "Operands;\n";
836 if (operand
->getVar()->isVariadicOfVariadic()) {
837 body
<< " llvm::SmallVector<int32_t> " << name
838 << "OperandGroupSizes;\n";
841 body
<< " ::mlir::OpAsmParser::UnresolvedOperand " << name
842 << "RawOperands[1];\n"
843 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
844 << name
<< "Operands(" << name
<< "RawOperands);";
846 body
<< llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
847 " (void){0}OperandsLoc;\n",
850 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
851 StringRef name
= region
->getVar()->name
;
852 if (region
->getVar()->isVariadic()) {
853 body
<< llvm::formatv(
854 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
858 body
<< llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
859 "std::make_unique<::mlir::Region>();\n",
863 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
864 StringRef name
= successor
->getVar()->name
;
865 if (successor
->getVar()->isVariadic()) {
866 body
<< llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
870 body
<< llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name
);
873 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
874 ArgumentLengthKind lengthKind
;
875 StringRef name
= getTypeListName(dir
->getArg(), lengthKind
);
876 if (lengthKind
!= ArgumentLengthKind::Single
)
877 body
<< " ::llvm::SmallVector<::mlir::Type, 1> " << name
<< "Types;\n";
879 body
<< llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name
)
881 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
883 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
884 ArgumentLengthKind ignored
;
885 body
<< " ::llvm::ArrayRef<::mlir::Type> "
886 << getTypeListName(dir
->getInputs(), ignored
) << "Types;\n";
887 body
<< " ::llvm::ArrayRef<::mlir::Type> "
888 << getTypeListName(dir
->getResults(), ignored
) << "Types;\n";
892 /// Generate the parser for a parameter to a custom directive.
893 static void genCustomParameterParser(FormatElement
*param
, MethodBody
&body
) {
894 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
895 body
<< attr
->getVar()->name
<< "Attr";
896 } else if (isa
<AttrDictDirective
>(param
)) {
897 body
<< "result.attributes";
898 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
899 StringRef name
= operand
->getVar()->name
;
900 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
901 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
902 body
<< llvm::formatv("{0}OperandGroups", name
);
903 else if (lengthKind
== ArgumentLengthKind::Variadic
)
904 body
<< llvm::formatv("{0}Operands", name
);
905 else if (lengthKind
== ArgumentLengthKind::Optional
)
906 body
<< llvm::formatv("{0}Operand", name
);
908 body
<< formatv("{0}RawOperands[0]", name
);
910 } else if (auto *region
= dyn_cast
<RegionVariable
>(param
)) {
911 StringRef name
= region
->getVar()->name
;
912 if (region
->getVar()->isVariadic())
913 body
<< llvm::formatv("{0}Regions", name
);
915 body
<< llvm::formatv("*{0}Region", name
);
917 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(param
)) {
918 StringRef name
= successor
->getVar()->name
;
919 if (successor
->getVar()->isVariadic())
920 body
<< llvm::formatv("{0}Successors", name
);
922 body
<< llvm::formatv("{0}Successor", name
);
924 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
925 genCustomParameterParser(dir
->getArg(), body
);
927 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
928 ArgumentLengthKind lengthKind
;
929 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
930 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
931 body
<< llvm::formatv("{0}TypeGroups", listName
);
932 else if (lengthKind
== ArgumentLengthKind::Variadic
)
933 body
<< llvm::formatv("{0}Types", listName
);
934 else if (lengthKind
== ArgumentLengthKind::Optional
)
935 body
<< llvm::formatv("{0}Type", listName
);
937 body
<< formatv("{0}RawTypes[0]", listName
);
939 } else if (auto *string
= dyn_cast
<StringElement
>(param
)) {
941 ctx
.withBuilder("parser.getBuilder()");
942 ctx
.addSubst("_ctxt", "parser.getContext()");
943 body
<< tgfmt(string
->getValue(), &ctx
);
945 } else if (auto *property
= dyn_cast
<PropertyVariable
>(param
)) {
946 body
<< llvm::formatv("result.getOrAddProperties<Properties>().{0}",
947 property
->getVar()->name
);
949 llvm_unreachable("unknown custom directive parameter");
953 /// Generate the parser for a custom directive.
954 static void genCustomDirectiveParser(CustomDirective
*dir
, MethodBody
&body
,
956 StringRef opCppClassName
,
957 bool isOptional
= false) {
960 // Preprocess the directive variables.
961 // * Add a local variable for optional operands and types. This provides a
962 // better API to the user defined parser methods.
963 // * Set the location of operand variables.
964 for (FormatElement
*param
: dir
->getArguments()) {
965 if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
966 auto *var
= operand
->getVar();
967 body
<< " " << var
->name
968 << "OperandsLoc = parser.getCurrentLocation();\n";
969 if (var
->isOptional()) {
970 body
<< llvm::formatv(
971 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
974 } else if (var
->isVariadicOfVariadic()) {
975 body
<< llvm::formatv(" "
976 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
977 "OpAsmParser::UnresolvedOperand>> "
978 "{0}OperandGroups;\n",
981 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
982 ArgumentLengthKind lengthKind
;
983 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
984 if (lengthKind
== ArgumentLengthKind::Optional
) {
985 body
<< llvm::formatv(" ::mlir::Type {0}Type;\n", listName
);
986 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
987 body
<< llvm::formatv(
988 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
992 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
993 FormatElement
*input
= dir
->getArg();
994 if (auto *operand
= dyn_cast
<OperandVariable
>(input
)) {
995 if (!operand
->getVar()->isOptional())
997 body
<< llvm::formatv(
998 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
1000 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
1001 operand
->getVar()->name
);
1003 } else if (auto *type
= dyn_cast
<TypeDirective
>(input
)) {
1004 ArgumentLengthKind lengthKind
;
1005 StringRef listName
= getTypeListName(type
->getArg(), lengthKind
);
1006 if (lengthKind
== ArgumentLengthKind::Optional
) {
1007 body
<< llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
1008 "::mlir::Type() : {0}Types[0];\n",
1015 body
<< " auto odsResult = parse" << dir
->getName() << "(parser";
1016 for (FormatElement
*param
: dir
->getArguments()) {
1018 genCustomParameterParser(param
, body
);
1023 body
<< " if (!odsResult) return {};\n"
1024 << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
1026 body
<< " if (odsResult) return ::mlir::failure();\n";
1029 // After parsing, add handling for any of the optional constructs.
1030 for (FormatElement
*param
: dir
->getArguments()) {
1031 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
1032 const NamedAttribute
*var
= attr
->getVar();
1033 if (var
->attr
.isOptional() || var
->attr
.hasDefaultValue())
1034 body
<< llvm::formatv(" if ({0}Attr)\n ", var
->name
);
1035 if (useProperties
) {
1037 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1038 var
->name
, opCppClassName
);
1040 body
<< llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1044 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
1045 const NamedTypeConstraint
*var
= operand
->getVar();
1046 if (var
->isOptional()) {
1047 body
<< llvm::formatv(" if ({0}Operand.has_value())\n"
1048 " {0}Operands.push_back(*{0}Operand);\n",
1050 } else if (var
->isVariadicOfVariadic()) {
1051 body
<< llvm::formatv(
1052 " for (const auto &subRange : {0}OperandGroups) {{\n"
1053 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1054 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1056 var
->name
, var
->constraint
.getVariadicOfVariadicSegmentSizeAttr());
1058 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
1059 ArgumentLengthKind lengthKind
;
1060 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1061 if (lengthKind
== ArgumentLengthKind::Optional
) {
1062 body
<< llvm::formatv(" if ({0}Type)\n"
1063 " {0}Types.push_back({0}Type);\n",
1065 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1066 body
<< llvm::formatv(
1067 " for (const auto &subRange : {0}TypeGroups)\n"
1068 " {0}Types.append(subRange.begin(), subRange.end());\n",
1077 /// Generate the parser for a enum attribute.
1078 static void genEnumAttrParser(const NamedAttribute
*var
, MethodBody
&body
,
1079 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1080 bool useProperties
, StringRef opCppClassName
) {
1081 Attribute baseAttr
= var
->attr
.getBaseAttr();
1082 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
1083 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
1085 // Generate the code for building an attribute for this enum.
1086 std::string attrBuilderStr
;
1088 llvm::raw_string_ostream
os(attrBuilderStr
);
1089 os
<< tgfmt(enumAttr
.getConstBuilderTemplate(), &attrTypeCtx
,
1093 // Build a string containing the cases that can be formatted as a keyword.
1094 std::string validCaseKeywordsStr
= "{";
1095 llvm::raw_string_ostream
validCaseKeywordsOS(validCaseKeywordsStr
);
1096 for (const EnumAttrCase
&attrCase
: cases
)
1097 if (canFormatStringAsKeyword(attrCase
.getStr()))
1098 validCaseKeywordsOS
<< '"' << attrCase
.getStr() << "\",";
1099 validCaseKeywordsOS
.str().back() = '}';
1101 // If the attribute is not optional, build an error message for the missing
1103 std::string errorMessage
;
1104 if (!parseAsOptional
) {
1105 llvm::raw_string_ostream
errorMessageOS(errorMessage
);
1107 << "return parser.emitError(loc, \"expected string or "
1108 "keyword containing one of the following enum values for attribute '"
1109 << var
->name
<< "' [";
1110 llvm::interleaveComma(cases
, errorMessageOS
, [&](const auto &attrCase
) {
1111 errorMessageOS
<< attrCase
.getStr();
1113 errorMessageOS
<< "]\");";
1115 std::string attrAssignment
;
1116 if (useProperties
) {
1119 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1120 var
->name
, opCppClassName
);
1123 formatv("result.addAttribute(\"{0}\", {0}Attr);", var
->name
);
1126 body
<< formatv(enumAttrParserCode
, var
->name
, enumAttr
.getCppNamespace(),
1127 enumAttr
.getStringToSymbolFnName(), attrBuilderStr
,
1128 validCaseKeywordsStr
, errorMessage
, attrAssignment
);
1131 // Generate the parser for an attribute.
1132 static void genAttrParser(AttributeVariable
*attr
, MethodBody
&body
,
1133 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1134 bool useProperties
, StringRef opCppClassName
) {
1135 const NamedAttribute
*var
= attr
->getVar();
1137 // Check to see if we can parse this as an enum attribute.
1138 if (canFormatEnumAttr(var
))
1139 return genEnumAttrParser(var
, body
, attrTypeCtx
, parseAsOptional
,
1140 useProperties
, opCppClassName
);
1142 // Check to see if we should parse this as a symbol name attribute.
1143 if (shouldFormatSymbolNameAttr(var
)) {
1144 body
<< formatv(parseAsOptional
? optionalSymbolNameAttrParserCode
1145 : symbolNameAttrParserCode
,
1149 // If this attribute has a buildable type, use that when parsing the
1151 std::string attrTypeStr
;
1152 if (std::optional
<StringRef
> typeBuilder
= attr
->getTypeBuilder()) {
1153 llvm::raw_string_ostream
os(attrTypeStr
);
1154 os
<< tgfmt(*typeBuilder
, &attrTypeCtx
);
1156 attrTypeStr
= "::mlir::Type{}";
1158 if (parseAsOptional
) {
1159 body
<< formatv(optionalAttrParserCode
, var
->name
, attrTypeStr
);
1161 if (attr
->shouldBeQualified() ||
1162 var
->attr
.getStorageType() == "::mlir::Attribute")
1163 body
<< formatv(genericAttrParserCode
, var
->name
, attrTypeStr
);
1165 body
<< formatv(attrParserCode
, var
->name
, attrTypeStr
);
1168 if (useProperties
) {
1170 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1172 var
->name
, opCppClassName
);
1175 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1180 void OperationFormat::genParser(Operator
&op
, OpClass
&opClass
) {
1181 SmallVector
<MethodParameter
> paramList
;
1182 paramList
.emplace_back("::mlir::OpAsmParser &", "parser");
1183 paramList
.emplace_back("::mlir::OperationState &", "result");
1185 auto *method
= opClass
.addStaticMethod("::mlir::ParseResult", "parse",
1186 std::move(paramList
));
1187 auto &body
= method
->body();
1189 // Generate variables to store the operands and type within the format. This
1190 // allows for referencing these variables in the presence of optional
1192 for (FormatElement
*element
: elements
)
1193 genElementParserStorage(element
, op
, body
);
1195 // A format context used when parsing attributes with buildable types.
1196 FmtContext attrTypeCtx
;
1197 attrTypeCtx
.withBuilder("parser.getBuilder()");
1199 // Generate parsers for each of the elements.
1200 for (FormatElement
*element
: elements
)
1201 genElementParser(element
, body
, attrTypeCtx
);
1203 // Generate the code to resolve the operand/result types and successors now
1204 // that they have been parsed.
1205 genParserRegionResolution(op
, body
);
1206 genParserSuccessorResolution(op
, body
);
1207 genParserVariadicSegmentResolution(op
, body
);
1208 genParserTypeResolution(op
, body
);
1210 body
<< " return ::mlir::success();\n";
1213 void OperationFormat::genElementParser(FormatElement
*element
, MethodBody
&body
,
1214 FmtContext
&attrTypeCtx
,
1215 GenContext genCtx
) {
1217 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
1218 auto genElementParsers
= [&](FormatElement
*firstElement
,
1219 ArrayRef
<FormatElement
*> elements
,
1221 // If the anchor is a unit attribute, we don't need to print it. When
1222 // parsing, we will add this attribute if this group is present.
1223 FormatElement
*elidedAnchorElement
= nullptr;
1224 auto *anchorAttr
= dyn_cast
<AttributeVariable
>(optional
->getAnchor());
1225 if (anchorAttr
&& anchorAttr
!= firstElement
&&
1226 anchorAttr
->isUnitAttr()) {
1227 elidedAnchorElement
= anchorAttr
;
1229 if (!thenGroup
== optional
->isInverted()) {
1230 // Add the anchor unit attribute to the operation state.
1231 if (useProperties
) {
1233 " result.getOrAddProperties<{1}::Properties>().{0} = "
1234 "parser.getBuilder().getUnitAttr();",
1235 anchorAttr
->getVar()->name
, opCppClassName
);
1237 body
<< " result.addAttribute(\"" << anchorAttr
->getVar()->name
1238 << "\", parser.getBuilder().getUnitAttr());\n";
1243 // Generate the rest of the elements inside an optional group. Elements in
1244 // an optional group after the guard are parsed as required.
1245 for (FormatElement
*childElement
: elements
)
1246 if (childElement
!= elidedAnchorElement
)
1247 genElementParser(childElement
, body
, attrTypeCtx
,
1248 GenContext::Optional
);
1251 ArrayRef
<FormatElement
*> thenElements
=
1252 optional
->getThenElements(/*parseable=*/true);
1254 // Generate a special optional parser for the first element to gate the
1255 // parsing of the rest of the elements.
1256 FormatElement
*firstElement
= thenElements
.front();
1257 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(firstElement
)) {
1258 genAttrParser(attrVar
, body
, attrTypeCtx
, /*parseAsOptional=*/true,
1259 useProperties
, opCppClassName
);
1260 body
<< " if (" << attrVar
->getVar()->name
<< "Attr) {\n";
1261 } else if (auto *literal
= dyn_cast
<LiteralElement
>(firstElement
)) {
1262 body
<< " if (::mlir::succeeded(parser.parseOptional";
1263 genLiteralParser(literal
->getSpelling(), body
);
1265 } else if (auto *opVar
= dyn_cast
<OperandVariable
>(firstElement
)) {
1266 genElementParser(opVar
, body
, attrTypeCtx
);
1267 body
<< " if (!" << opVar
->getVar()->name
<< "Operands.empty()) {\n";
1268 } else if (auto *regionVar
= dyn_cast
<RegionVariable
>(firstElement
)) {
1269 const NamedRegion
*region
= regionVar
->getVar();
1270 if (region
->isVariadic()) {
1271 genElementParser(regionVar
, body
, attrTypeCtx
);
1272 body
<< " if (!" << region
->name
<< "Regions.empty()) {\n";
1274 body
<< llvm::formatv(optionalRegionParserCode
, region
->name
);
1275 body
<< " if (!" << region
->name
<< "Region->empty()) {\n ";
1276 if (hasImplicitTermTrait
)
1277 body
<< llvm::formatv(regionEnsureTerminatorParserCode
, region
->name
);
1278 else if (hasSingleBlockTrait
)
1279 body
<< llvm::formatv(regionEnsureSingleBlockParserCode
,
1282 } else if (auto *custom
= dyn_cast
<CustomDirective
>(firstElement
)) {
1283 body
<< " if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
1284 genCustomDirectiveParser(custom
, body
, useProperties
, opCppClassName
,
1285 /*isOptional=*/true);
1286 body
<< " return ::mlir::success();\n"
1287 << " }(); result.has_value() && ::mlir::failed(*result)) {\n"
1288 << " return ::mlir::failure();\n"
1289 << " } else if (result.has_value()) {\n";
1292 genElementParsers(firstElement
, thenElements
.drop_front(),
1293 /*thenGroup=*/true);
1296 // Generate the else elements.
1297 auto elseElements
= optional
->getElseElements();
1298 if (!elseElements
.empty()) {
1299 body
<< " else {\n";
1300 ArrayRef
<FormatElement
*> elseElements
=
1301 optional
->getElseElements(/*parseable=*/true);
1302 genElementParsers(elseElements
.front(), elseElements
,
1303 /*thenGroup=*/false);
1308 /// OIList Directive
1309 } else if (OIListElement
*oilist
= dyn_cast
<OIListElement
>(element
)) {
1310 for (LiteralElement
*le
: oilist
->getLiteralElements())
1311 body
<< " bool " << le
->getSpelling() << "Clause = false;\n";
1313 // Generate the parsing loop
1314 body
<< " while(true) {\n";
1315 for (auto clause
: oilist
->getClauses()) {
1316 LiteralElement
*lelement
= std::get
<0>(clause
);
1317 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
1318 body
<< "if (succeeded(parser.parseOptional";
1319 genLiteralParser(lelement
->getSpelling(), body
);
1321 StringRef lelementName
= lelement
->getSpelling();
1322 body
<< formatv(oilistParserCode
, lelementName
);
1323 if (AttributeVariable
*unitAttrElem
=
1324 oilist
->getUnitAttrParsingElement(pelement
)) {
1325 if (useProperties
) {
1327 " result.getOrAddProperties<{1}::Properties>().{0} = "
1328 "parser.getBuilder().getUnitAttr();",
1329 unitAttrElem
->getVar()->name
, opCppClassName
);
1331 body
<< " result.addAttribute(\"" << unitAttrElem
->getVar()->name
1332 << "\", UnitAttr::get(parser.getContext()));\n";
1335 for (FormatElement
*el
: pelement
)
1336 genElementParser(el
, body
, attrTypeCtx
);
1341 body
<< " break;\n";
1346 } else if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
)) {
1347 body
<< " if (parser.parse";
1348 genLiteralParser(literal
->getSpelling(), body
);
1349 body
<< ")\n return ::mlir::failure();\n";
1352 } else if (isa
<WhitespaceElement
>(element
)) {
1353 // Nothing to parse.
1356 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
1357 bool parseAsOptional
=
1358 (genCtx
== GenContext::Normal
&& attr
->getVar()->attr
.isOptional());
1359 genAttrParser(attr
, body
, attrTypeCtx
, parseAsOptional
, useProperties
,
1362 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
1363 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
1364 StringRef name
= operand
->getVar()->name
;
1365 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
1366 body
<< llvm::formatv(
1367 variadicOfVariadicOperandParserCode
, name
,
1368 operand
->getVar()->constraint
.getVariadicOfVariadicSegmentSizeAttr());
1369 else if (lengthKind
== ArgumentLengthKind::Variadic
)
1370 body
<< llvm::formatv(variadicOperandParserCode
, name
);
1371 else if (lengthKind
== ArgumentLengthKind::Optional
)
1372 body
<< llvm::formatv(optionalOperandParserCode
, name
);
1374 body
<< formatv(operandParserCode
, name
);
1376 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
1377 bool isVariadic
= region
->getVar()->isVariadic();
1378 body
<< llvm::formatv(isVariadic
? regionListParserCode
: regionParserCode
,
1379 region
->getVar()->name
);
1380 if (hasImplicitTermTrait
)
1381 body
<< llvm::formatv(isVariadic
? regionListEnsureTerminatorParserCode
1382 : regionEnsureTerminatorParserCode
,
1383 region
->getVar()->name
);
1384 else if (hasSingleBlockTrait
)
1385 body
<< llvm::formatv(isVariadic
? regionListEnsureSingleBlockParserCode
1386 : regionEnsureSingleBlockParserCode
,
1387 region
->getVar()->name
);
1389 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
1390 bool isVariadic
= successor
->getVar()->isVariadic();
1391 body
<< formatv(isVariadic
? successorListParserCode
: successorParserCode
,
1392 successor
->getVar()->name
);
1395 } else if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
1396 body
.indent() << "{\n";
1397 body
.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1398 << "if (parser.parseOptionalAttrDict"
1399 << (attrDict
->isWithKeyword() ? "WithKeyword" : "")
1400 << "(result.attributes))\n"
1401 << " return ::mlir::failure();\n";
1402 if (useProperties
) {
1403 body
<< "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1405 << " return parser.emitError(loc) << \"'\" << "
1406 "result.name.getStringRef() << \"' op \";\n"
1408 << " return ::mlir::failure();\n";
1410 body
.unindent() << "}\n";
1412 } else if (dyn_cast
<PropDictDirective
>(element
)) {
1413 body
<< " if (parseProperties(parser, result))\n"
1414 << " return ::mlir::failure();\n";
1415 } else if (auto *customDir
= dyn_cast
<CustomDirective
>(element
)) {
1416 genCustomDirectiveParser(customDir
, body
, useProperties
, opCppClassName
);
1417 } else if (isa
<OperandsDirective
>(element
)) {
1418 body
<< " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1419 << " if (parser.parseOperandList(allOperands))\n"
1420 << " return ::mlir::failure();\n";
1422 } else if (isa
<RegionsDirective
>(element
)) {
1423 body
<< llvm::formatv(regionListParserCode
, "full");
1424 if (hasImplicitTermTrait
)
1425 body
<< llvm::formatv(regionListEnsureTerminatorParserCode
, "full");
1426 else if (hasSingleBlockTrait
)
1427 body
<< llvm::formatv(regionListEnsureSingleBlockParserCode
, "full");
1429 } else if (isa
<SuccessorsDirective
>(element
)) {
1430 body
<< llvm::formatv(successorListParserCode
, "full");
1432 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
1433 ArgumentLengthKind lengthKind
;
1434 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1435 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1436 body
<< llvm::formatv(variadicOfVariadicTypeParserCode
, listName
);
1437 } else if (lengthKind
== ArgumentLengthKind::Variadic
) {
1438 body
<< llvm::formatv(variadicTypeParserCode
, listName
);
1439 } else if (lengthKind
== ArgumentLengthKind::Optional
) {
1440 body
<< llvm::formatv(optionalTypeParserCode
, listName
);
1442 const char *parserCode
=
1443 dir
->shouldBeQualified() ? qualifiedTypeParserCode
: typeParserCode
;
1444 TypeSwitch
<FormatElement
*>(dir
->getArg())
1445 .Case
<OperandVariable
, ResultVariable
>([&](auto operand
) {
1446 body
<< formatv(parserCode
,
1447 operand
->getVar()->constraint
.getCPPClassName(),
1450 .Default([&](auto operand
) {
1451 body
<< formatv(parserCode
, "::mlir::Type", listName
);
1454 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
1455 ArgumentLengthKind ignored
;
1456 body
<< formatv(functionalTypeParserCode
,
1457 getTypeListName(dir
->getInputs(), ignored
),
1458 getTypeListName(dir
->getResults(), ignored
));
1460 llvm_unreachable("unknown format element");
1464 void OperationFormat::genParserTypeResolution(Operator
&op
, MethodBody
&body
) {
1465 // If any of type resolutions use transformed variables, make sure that the
1466 // types of those variables are resolved.
1467 SmallPtrSet
<const NamedTypeConstraint
*, 8> verifiedVariables
;
1468 FmtContext verifierFCtx
;
1469 for (TypeResolution
&resolver
:
1470 llvm::concat
<TypeResolution
>(resultTypes
, operandTypes
)) {
1471 std::optional
<StringRef
> transformer
= resolver
.getVarTransformer();
1474 // Ensure that we don't verify the same variables twice.
1475 const NamedTypeConstraint
*variable
= resolver
.getVariable();
1476 if (!variable
|| !verifiedVariables
.insert(variable
).second
)
1479 auto constraint
= variable
->constraint
;
1480 body
<< " for (::mlir::Type type : " << variable
->name
<< "Types) {\n"
1483 << tgfmt(constraint
.getConditionTemplate(),
1484 &verifierFCtx
.withSelf("type"))
1486 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1487 "\"'{0}' must be {1}, but got \" << type;\n",
1488 variable
->name
, constraint
.getSummary())
1493 // Initialize the set of buildable types.
1494 if (!buildableTypes
.empty()) {
1495 FmtContext typeBuilderCtx
;
1496 typeBuilderCtx
.withBuilder("parser.getBuilder()");
1497 for (auto &it
: buildableTypes
)
1498 body
<< " ::mlir::Type odsBuildableType" << it
.second
<< " = "
1499 << tgfmt(it
.first
, &typeBuilderCtx
) << ";\n";
1502 // Emit the code necessary for a type resolver.
1503 auto emitTypeResolver
= [&](TypeResolution
&resolver
, StringRef curVar
) {
1504 if (std::optional
<int> val
= resolver
.getBuilderIdx()) {
1505 body
<< "odsBuildableType" << *val
;
1506 } else if (const NamedTypeConstraint
*var
= resolver
.getVariable()) {
1507 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer()) {
1508 FmtContext fmtContext
;
1509 fmtContext
.addSubst("_ctxt", "parser.getContext()");
1510 if (var
->isVariadic())
1511 fmtContext
.withSelf(var
->name
+ "Types");
1513 fmtContext
.withSelf(var
->name
+ "Types[0]");
1514 body
<< tgfmt(*tform
, &fmtContext
);
1516 body
<< var
->name
<< "Types";
1517 if (!var
->isVariadic())
1520 } else if (const NamedAttribute
*attr
= resolver
.getAttribute()) {
1521 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer())
1522 body
<< tgfmt(*tform
,
1523 &FmtContext().withSelf(attr
->name
+ "Attr.getType()"));
1525 body
<< attr
->name
<< "Attr.getType()";
1527 body
<< curVar
<< "Types";
1531 // Resolve each of the result types.
1532 if (!infersResultTypes
) {
1533 if (allResultTypes
) {
1534 body
<< " result.addTypes(allResultTypes);\n";
1536 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
1537 body
<< " result.addTypes(";
1538 emitTypeResolver(resultTypes
[i
], op
.getResultName(i
));
1544 // Emit the operand type resolutions.
1545 genParserOperandTypeResolution(op
, body
, emitTypeResolver
);
1547 // Handle return type inference once all operands have been resolved
1548 if (infersResultTypes
)
1549 body
<< formatv(inferReturnTypesParserCode
, op
.getCppClassName());
1552 void OperationFormat::genParserOperandTypeResolution(
1553 Operator
&op
, MethodBody
&body
,
1554 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
) {
1555 // Early exit if there are no operands.
1556 if (op
.getNumOperands() == 0)
1559 // Handle the case where all operand types are grouped together with
1560 // "types(operands)".
1561 if (allOperandTypes
) {
1562 // If `operands` was specified, use the full operand list directly.
1564 body
<< " if (parser.resolveOperands(allOperands, allOperandTypes, "
1565 "allOperandLoc, result.operands))\n"
1566 " return ::mlir::failure();\n";
1570 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1571 // llvm::concat does not allow the case of a single range, so guard it here.
1572 body
<< " if (parser.resolveOperands(";
1573 if (op
.getNumOperands() > 1) {
1574 body
<< "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1575 llvm::interleaveComma(op
.getOperands(), body
, [&](auto &operand
) {
1576 body
<< operand
.name
<< "Operands";
1580 body
<< op
.operand_begin()->name
<< "Operands";
1582 body
<< ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1583 << " return ::mlir::failure();\n";
1587 // Handle the case where all operands are grouped together with "operands".
1589 body
<< " if (parser.resolveOperands(allOperands, ";
1591 // Group all of the operand types together to perform the resolution all at
1592 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1593 // the case of a single range, so guard it here.
1594 if (op
.getNumOperands() > 1) {
1595 body
<< "::llvm::concat<const ::mlir::Type>(";
1596 llvm::interleaveComma(
1597 llvm::seq
<int>(0, op
.getNumOperands()), body
, [&](int i
) {
1598 body
<< "::llvm::ArrayRef<::mlir::Type>(";
1599 emitTypeResolver(operandTypes
[i
], op
.getOperand(i
).name
);
1604 emitTypeResolver(operandTypes
.front(), op
.getOperand(0).name
);
1607 body
<< ", allOperandLoc, result.operands))\n return "
1608 "::mlir::failure();\n";
1612 // The final case is the one where each of the operands types are resolved
1614 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
1615 NamedTypeConstraint
&operand
= op
.getOperand(i
);
1616 body
<< " if (parser.resolveOperands(" << operand
.name
<< "Operands, ";
1618 // Resolve the type of this operand.
1619 TypeResolution
&operandType
= operandTypes
[i
];
1620 emitTypeResolver(operandType
, operand
.name
);
1622 body
<< ", " << operand
.name
1623 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1627 void OperationFormat::genParserRegionResolution(Operator
&op
,
1629 // Check for the case where all regions were parsed.
1630 bool hasAllRegions
= llvm::any_of(
1631 elements
, [](FormatElement
*elt
) { return isa
<RegionsDirective
>(elt
); });
1632 if (hasAllRegions
) {
1633 body
<< " result.addRegions(fullRegions);\n";
1637 // Otherwise, handle each region individually.
1638 for (const NamedRegion
®ion
: op
.getRegions()) {
1639 if (region
.isVariadic())
1640 body
<< " result.addRegions(" << region
.name
<< "Regions);\n";
1642 body
<< " result.addRegion(std::move(" << region
.name
<< "Region));\n";
1646 void OperationFormat::genParserSuccessorResolution(Operator
&op
,
1648 // Check for the case where all successors were parsed.
1649 bool hasAllSuccessors
= llvm::any_of(elements
, [](FormatElement
*elt
) {
1650 return isa
<SuccessorsDirective
>(elt
);
1652 if (hasAllSuccessors
) {
1653 body
<< " result.addSuccessors(fullSuccessors);\n";
1657 // Otherwise, handle each successor individually.
1658 for (const NamedSuccessor
&successor
: op
.getSuccessors()) {
1659 if (successor
.isVariadic())
1660 body
<< " result.addSuccessors(" << successor
.name
<< "Successors);\n";
1662 body
<< " result.addSuccessors(" << successor
.name
<< "Successor);\n";
1666 void OperationFormat::genParserVariadicSegmentResolution(Operator
&op
,
1669 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1670 auto interleaveFn
= [&](const NamedTypeConstraint
&operand
) {
1671 // If the operand is variadic emit the parsed size.
1672 if (operand
.isVariableLength())
1673 body
<< "static_cast<int32_t>(" << operand
.name
<< "Operands.size())";
1677 if (op
.getDialect().usePropertiesForAttributes()) {
1678 body
<< "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1679 llvm::interleaveComma(op
.getOperands(), body
, interleaveFn
);
1680 body
<< formatv("}), "
1681 "result.getOrAddProperties<{0}::Properties>()."
1682 "operandSegmentSizes.begin());\n",
1683 op
.getCppClassName());
1685 body
<< " result.addAttribute(\"operandSegmentSizes\", "
1686 << "parser.getBuilder().getDenseI32ArrayAttr({";
1687 llvm::interleaveComma(op
.getOperands(), body
, interleaveFn
);
1691 for (const NamedTypeConstraint
&operand
: op
.getOperands()) {
1692 if (!operand
.isVariadicOfVariadic())
1694 if (op
.getDialect().usePropertiesForAttributes()) {
1695 body
<< llvm::formatv(
1696 " result.getOrAddProperties<{0}::Properties>().{1} = "
1697 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1698 op
.getCppClassName(),
1699 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1702 body
<< llvm::formatv(
1703 " result.addAttribute(\"{0}\", "
1704 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1706 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1712 if (!allResultTypes
&&
1713 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1714 auto interleaveFn
= [&](const NamedTypeConstraint
&result
) {
1715 // If the result is variadic emit the parsed size.
1716 if (result
.isVariableLength())
1717 body
<< "static_cast<int32_t>(" << result
.name
<< "Types.size())";
1721 if (op
.getDialect().usePropertiesForAttributes()) {
1722 body
<< "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1723 llvm::interleaveComma(op
.getResults(), body
, interleaveFn
);
1724 body
<< formatv("}), "
1725 "result.getOrAddProperties<{0}::Properties>()."
1726 "resultSegmentSizes.begin());\n",
1727 op
.getCppClassName());
1729 body
<< " result.addAttribute(\"resultSegmentSizes\", "
1730 << "parser.getBuilder().getDenseI32ArrayAttr({";
1731 llvm::interleaveComma(op
.getResults(), body
, interleaveFn
);
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The terminator is printed only when it carries information (attributes,
/// operands or results); otherwise it is elided as implicit.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
      /*printBlockTerminators=*/printTerminator);
  }
)";
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// The snippet opens a scope that the caller is expected to close.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1768 /// Generate the printer for the 'prop-dict' directive.
1769 static void genPropDictPrinter(OperationFormat
&fmt
, Operator
&op
,
1771 body
<< " _odsPrinter << \" \";\n"
1772 << " printProperties(this->getContext(), _odsPrinter, "
1773 "getProperties());\n";
1776 /// Generate the printer for the 'attr-dict' directive.
1777 static void genAttrDictPrinter(OperationFormat
&fmt
, Operator
&op
,
1778 MethodBody
&body
, bool withKeyword
) {
1779 body
<< " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
1780 // Elide the variadic segment size attributes if necessary.
1781 if (!fmt
.allOperands
&&
1782 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1783 body
<< " elidedAttrs.push_back(\"operandSegmentSizes\");\n";
1784 if (!fmt
.allResultTypes
&&
1785 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1786 body
<< " elidedAttrs.push_back(\"resultSegmentSizes\");\n";
1787 for (const StringRef key
: fmt
.inferredAttributes
.keys())
1788 body
<< " elidedAttrs.push_back(\"" << key
<< "\");\n";
1789 for (const NamedAttribute
*attr
: fmt
.usedAttributes
)
1790 body
<< " elidedAttrs.push_back(\"" << attr
->name
<< "\");\n";
1791 // Add code to check attributes for equality with the default value
1792 // for attributes with the elidePrintingDefaultValue bit set.
1793 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1794 const Attribute
&attr
= namedAttr
.attr
;
1795 if (!attr
.isDerivedAttr() && attr
.hasDefaultValue()) {
1796 const StringRef
&name
= namedAttr
.name
;
1798 fctx
.withBuilder("odsBuilder");
1799 std::string defaultValue
= std::string(
1800 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
1802 body
<< " ::mlir::Builder odsBuilder(getContext());\n";
1803 body
<< " ::mlir::Attribute attr = " << op
.getGetterName(name
)
1805 body
<< " if(attr && (attr == " << defaultValue
<< "))\n";
1806 body
<< " elidedAttrs.push_back(\"" << name
<< "\");\n";
1810 body
<< " _odsPrinter.printOptionalAttrDict"
1811 << (withKeyword
? "WithKeyword" : "")
1812 << "((*this)->getAttrs(), elidedAttrs);\n";
1815 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1816 /// space should be emitted before this element. `lastWasPunctuation` is true if
1817 /// the previous element was a punctuation literal.
1818 static void genLiteralPrinter(StringRef value
, MethodBody
&body
,
1819 bool &shouldEmitSpace
, bool &lastWasPunctuation
) {
1820 body
<< " _odsPrinter";
1822 // Don't insert a space for certain punctuation.
1823 if (shouldEmitSpace
&& shouldEmitSpaceBefore(value
, lastWasPunctuation
))
1825 body
<< " << \"" << value
<< "\";\n";
1827 // Insert a space after certain literals.
1829 value
.size() != 1 || !StringRef("<({[").contains(value
.front());
1830 lastWasPunctuation
= value
.front() != '_' && !isalpha(value
.front());
1833 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1834 /// are set to false.
1835 static void genSpacePrinter(bool value
, MethodBody
&body
, bool &shouldEmitSpace
,
1836 bool &lastWasPunctuation
) {
1838 body
<< " _odsPrinter << ' ';\n";
1839 lastWasPunctuation
= false;
1841 lastWasPunctuation
= true;
1843 shouldEmitSpace
= false;
1846 /// Generate the printer for a custom directive parameter.
1847 static void genCustomDirectiveParameterPrinter(FormatElement
*element
,
1850 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
1851 body
<< op
.getGetterName(attr
->getVar()->name
) << "Attr()";
1853 } else if (isa
<AttrDictDirective
>(element
)) {
1854 body
<< "getOperation()->getAttrDictionary()";
1856 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
1857 body
<< op
.getGetterName(operand
->getVar()->name
) << "()";
1859 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
1860 body
<< op
.getGetterName(region
->getVar()->name
) << "()";
1862 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
1863 body
<< op
.getGetterName(successor
->getVar()->name
) << "()";
1865 } else if (auto *dir
= dyn_cast
<RefDirective
>(element
)) {
1866 genCustomDirectiveParameterPrinter(dir
->getArg(), op
, body
);
1868 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
1869 auto *typeOperand
= dir
->getArg();
1870 auto *operand
= dyn_cast
<OperandVariable
>(typeOperand
);
1871 auto *var
= operand
? operand
->getVar()
1872 : cast
<ResultVariable
>(typeOperand
)->getVar();
1873 std::string name
= op
.getGetterName(var
->name
);
1874 if (var
->isVariadic())
1875 body
<< name
<< "().getTypes()";
1876 else if (var
->isOptional())
1877 body
<< llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name
);
1879 body
<< name
<< "().getType()";
1881 } else if (auto *string
= dyn_cast
<StringElement
>(element
)) {
1883 ctx
.withBuilder("::mlir::Builder(getContext())");
1884 ctx
.addSubst("_ctxt", "getContext()");
1885 body
<< tgfmt(string
->getValue(), &ctx
);
1887 } else if (auto *property
= dyn_cast
<PropertyVariable
>(element
)) {
1889 ctx
.addSubst("_ctxt", "getContext()");
1890 const NamedProperty
*namedProperty
= property
->getVar();
1891 ctx
.addSubst("_storage", "getProperties()." + namedProperty
->name
);
1892 body
<< tgfmt(namedProperty
->prop
.getConvertFromStorageCall(), &ctx
);
1894 llvm_unreachable("unknown custom directive parameter");
1898 /// Generate the printer for a custom directive.
1899 static void genCustomDirectivePrinter(CustomDirective
*customDir
,
1900 const Operator
&op
, MethodBody
&body
) {
1901 body
<< " print" << customDir
->getName() << "(_odsPrinter, *this";
1902 for (FormatElement
*param
: customDir
->getArguments()) {
1904 genCustomDirectiveParameterPrinter(param
, op
, body
);
1909 /// Generate the printer for a region with the given variable name.
1910 static void genRegionPrinter(const Twine
®ionName
, MethodBody
&body
,
1911 bool hasImplicitTermTrait
) {
1912 if (hasImplicitTermTrait
)
1913 body
<< llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode
,
1916 body
<< " _odsPrinter.printRegion(" << regionName
<< ");\n";
1918 static void genVariadicRegionPrinter(const Twine
®ionListName
,
1920 bool hasImplicitTermTrait
) {
1921 body
<< " llvm::interleaveComma(" << regionListName
1922 << ", _odsPrinter, [&](::mlir::Region ®ion) {\n ";
1923 genRegionPrinter("region", body
, hasImplicitTermTrait
);
1927 /// Generate the C++ for an operand to a (*-)type directive.
1928 static MethodBody
&genTypeOperandPrinter(FormatElement
*arg
, const Operator
&op
,
1930 bool useArrayRef
= true) {
1931 if (isa
<OperandsDirective
>(arg
))
1932 return body
<< "getOperation()->getOperandTypes()";
1933 if (isa
<ResultsDirective
>(arg
))
1934 return body
<< "getOperation()->getResultTypes()";
1935 auto *operand
= dyn_cast
<OperandVariable
>(arg
);
1936 auto *var
= operand
? operand
->getVar() : cast
<ResultVariable
>(arg
)->getVar();
1937 if (var
->isVariadicOfVariadic())
1938 return body
<< llvm::formatv("{0}().join().getTypes()",
1939 op
.getGetterName(var
->name
));
1940 if (var
->isVariadic())
1941 return body
<< op
.getGetterName(var
->name
) << "().getTypes()";
1942 if (var
->isOptional())
1943 return body
<< llvm::formatv(
1944 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1945 "::llvm::ArrayRef<::mlir::Type>())",
1946 op
.getGetterName(var
->name
));
1948 return body
<< "::llvm::ArrayRef<::mlir::Type>("
1949 << op
.getGetterName(var
->name
) << "().getType())";
1950 return body
<< op
.getGetterName(var
->name
) << "().getType()";
1953 /// Generate the printer for an enum attribute.
1954 static void genEnumAttrPrinter(const NamedAttribute
*var
, const Operator
&op
,
1956 Attribute baseAttr
= var
->attr
.getBaseAttr();
1957 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
1958 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
1960 body
<< llvm::formatv(enumAttrBeginPrinterCode
,
1961 (var
->attr
.isOptional() ? "*" : "") +
1962 op
.getGetterName(var
->name
),
1963 enumAttr
.getSymbolToStringFnName());
1965 // Get a string containing all of the cases that can't be represented with a
1967 BitVector
nonKeywordCases(cases
.size());
1968 for (auto it
: llvm::enumerate(cases
)) {
1969 if (!canFormatStringAsKeyword(it
.value().getStr()))
1970 nonKeywordCases
.set(it
.index());
1973 // Otherwise if this is a bit enum attribute, don't allow cases that may
1974 // overlap with other cases. For simplicity sake, only allow cases with a
1975 // single bit value.
1976 if (enumAttr
.isBitEnum()) {
1977 for (auto it
: llvm::enumerate(cases
)) {
1978 int64_t value
= it
.value().getValue();
1979 if (value
< 0 || !llvm::isPowerOf2_64(value
))
1980 nonKeywordCases
.set(it
.index());
1984 // If there are any cases that can't be used with a keyword, switch on the
1985 // case value to determine when to print in the string form.
1986 if (nonKeywordCases
.any()) {
1987 body
<< " switch (caseValue) {\n";
1988 StringRef cppNamespace
= enumAttr
.getCppNamespace();
1989 StringRef enumName
= enumAttr
.getEnumClassName();
1990 for (auto it
: llvm::enumerate(cases
)) {
1991 if (nonKeywordCases
.test(it
.index()))
1993 StringRef symbol
= it
.value().getSymbol();
1994 body
<< llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace
, enumName
,
1995 llvm::isDigit(symbol
.front()) ? ("_" + symbol
)
1998 body
<< " _odsPrinter << caseValueStr;\n"
2001 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
2008 body
<< " _odsPrinter << caseValueStr;\n"
2012 /// Generate a check that a DefaultValuedAttr has a value that is non-default.
2013 static void genNonDefaultValueCheck(MethodBody
&body
, const Operator
&op
,
2014 AttributeVariable
&attrElement
) {
2016 Attribute attr
= attrElement
.getVar()->attr
;
2017 fctx
.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2018 body
<< " && " << op
.getGetterName(attrElement
.getVar()->name
) << "Attr() != "
2019 << tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue());
2022 /// Generate the check for the anchor of an optional group.
2023 static void genOptionalGroupPrinterAnchor(FormatElement
*anchor
,
2026 TypeSwitch
<FormatElement
*>(anchor
)
2027 .Case
<OperandVariable
, ResultVariable
>([&](auto *element
) {
2028 const NamedTypeConstraint
*var
= element
->getVar();
2029 std::string name
= op
.getGetterName(var
->name
);
2030 if (var
->isOptional())
2031 body
<< name
<< "()";
2032 else if (var
->isVariadic())
2033 body
<< "!" << name
<< "().empty()";
2035 .Case([&](RegionVariable
*element
) {
2036 const NamedRegion
*var
= element
->getVar();
2037 std::string name
= op
.getGetterName(var
->name
);
2038 // TODO: Add a check for optional regions here when ODS supports it.
2039 body
<< "!" << name
<< "().empty()";
2041 .Case([&](TypeDirective
*element
) {
2042 genOptionalGroupPrinterAnchor(element
->getArg(), op
, body
);
2044 .Case([&](FunctionalTypeDirective
*element
) {
2045 genOptionalGroupPrinterAnchor(element
->getInputs(), op
, body
);
2047 .Case([&](AttributeVariable
*element
) {
2048 Attribute attr
= element
->getVar()->attr
;
2049 body
<< op
.getGetterName(element
->getVar()->name
) << "Attr()";
2050 if (attr
.isOptional())
2052 if (attr
.hasDefaultValue()) {
2053 // Consider a default-valued attribute as present if it's not the
2055 genNonDefaultValueCheck(body
, op
, *element
);
2058 llvm_unreachable("attribute must be optional or default-valued");
2060 .Case([&](CustomDirective
*ele
) {
2063 ele
->getArguments(), body
,
2064 [&](FormatElement
*child
) {
2066 genOptionalGroupPrinterAnchor(child
, op
, body
);
2074 void collect(FormatElement
*element
,
2075 SmallVectorImpl
<VariableElement
*> &variables
) {
2076 TypeSwitch
<FormatElement
*>(element
)
2077 .Case([&](VariableElement
*var
) { variables
.emplace_back(var
); })
2078 .Case([&](CustomDirective
*ele
) {
2079 for (FormatElement
*arg
: ele
->getArguments())
2080 collect(arg
, variables
);
2082 .Case([&](OptionalElement
*ele
) {
2083 for (FormatElement
*arg
: ele
->getThenElements())
2084 collect(arg
, variables
);
2085 for (FormatElement
*arg
: ele
->getElseElements())
2086 collect(arg
, variables
);
2088 .Case([&](FunctionalTypeDirective
*funcType
) {
2089 collect(funcType
->getInputs(), variables
);
2090 collect(funcType
->getResults(), variables
);
2092 .Case([&](OIListElement
*oilist
) {
2093 for (ArrayRef
<FormatElement
*> arg
: oilist
->getParsingElements())
2094 for (FormatElement
*arg
: arg
)
2095 collect(arg
, variables
);
2099 void OperationFormat::genElementPrinter(FormatElement
*element
,
2100 MethodBody
&body
, Operator
&op
,
2101 bool &shouldEmitSpace
,
2102 bool &lastWasPunctuation
) {
2103 if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
))
2104 return genLiteralPrinter(literal
->getSpelling(), body
, shouldEmitSpace
,
2105 lastWasPunctuation
);
2107 // Emit a whitespace element.
2108 if (auto *space
= dyn_cast
<WhitespaceElement
>(element
)) {
2109 if (space
->getValue() == "\\n") {
2110 body
<< " _odsPrinter.printNewline();\n";
2112 genSpacePrinter(!space
->getValue().empty(), body
, shouldEmitSpace
,
2113 lastWasPunctuation
);
2118 // Emit an optional group.
2119 if (OptionalElement
*optional
= dyn_cast
<OptionalElement
>(element
)) {
2120 // Emit the check for the presence of the anchor element.
2121 FormatElement
*anchor
= optional
->getAnchor();
2123 if (optional
->isInverted())
2125 genOptionalGroupPrinterAnchor(anchor
, op
, body
);
2129 // If the anchor is a unit attribute, we don't need to print it. When
2130 // parsing, we will add this attribute if this group is present.
2131 ArrayRef
<FormatElement
*> thenElements
= optional
->getThenElements();
2132 ArrayRef
<FormatElement
*> elseElements
= optional
->getElseElements();
2133 FormatElement
*elidedAnchorElement
= nullptr;
2134 auto *anchorAttr
= dyn_cast
<AttributeVariable
>(anchor
);
2135 if (anchorAttr
&& anchorAttr
!= thenElements
.front() &&
2136 (elseElements
.empty() || anchorAttr
!= elseElements
.front()) &&
2137 anchorAttr
->isUnitAttr()) {
2138 elidedAnchorElement
= anchorAttr
;
2140 auto genElementPrinters
= [&](ArrayRef
<FormatElement
*> elements
) {
2141 for (FormatElement
*childElement
: elements
) {
2142 if (childElement
!= elidedAnchorElement
) {
2143 genElementPrinter(childElement
, body
, op
, shouldEmitSpace
,
2144 lastWasPunctuation
);
2149 // Emit each of the elements.
2150 genElementPrinters(thenElements
);
2153 // Emit each of the else elements.
2154 if (!elseElements
.empty()) {
2155 body
<< " else {\n";
2156 genElementPrinters(elseElements
);
2160 body
.unindent() << "\n";
2165 if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
2166 for (auto clause
: oilist
->getClauses()) {
2167 LiteralElement
*lelement
= std::get
<0>(clause
);
2168 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
2170 SmallVector
<VariableElement
*> vars
;
2171 for (FormatElement
*el
: pelement
)
2173 body
<< " if (false";
2174 for (VariableElement
*var
: vars
) {
2175 TypeSwitch
<FormatElement
*>(var
)
2176 .Case([&](AttributeVariable
*attrEle
) {
2177 body
<< " || (" << op
.getGetterName(attrEle
->getVar()->name
)
2179 Attribute attr
= attrEle
->getVar()->attr
;
2180 if (attr
.hasDefaultValue()) {
2181 // Don't print default-valued attributes.
2182 genNonDefaultValueCheck(body
, op
, *attrEle
);
2186 .Case([&](OperandVariable
*ele
) {
2187 if (ele
->getVar()->isVariadic()) {
2188 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2191 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2194 .Case([&](ResultVariable
*ele
) {
2195 if (ele
->getVar()->isVariadic()) {
2196 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2199 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2202 .Case([&](RegionVariable
*reg
) {
2203 body
<< " || " << op
.getGetterName(reg
->getVar()->name
) << "()";
2208 genLiteralPrinter(lelement
->getSpelling(), body
, shouldEmitSpace
,
2209 lastWasPunctuation
);
2210 if (oilist
->getUnitAttrParsingElement(pelement
) == nullptr) {
2211 for (FormatElement
*element
: pelement
)
2212 genElementPrinter(element
, body
, op
, shouldEmitSpace
,
2213 lastWasPunctuation
);
2220 // Emit the attribute dictionary.
2221 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
2222 genAttrDictPrinter(*this, op
, body
, attrDict
->isWithKeyword());
2223 lastWasPunctuation
= false;
2227 // Emit the attribute dictionary.
2228 if (dyn_cast
<PropDictDirective
>(element
)) {
2229 genPropDictPrinter(*this, op
, body
);
2230 lastWasPunctuation
= false;
2234 // Optionally insert a space before the next element. The AttrDict printer
2235 // already adds a space as necessary.
2236 if (shouldEmitSpace
|| !lastWasPunctuation
)
2237 body
<< " _odsPrinter << ' ';\n";
2238 lastWasPunctuation
= false;
2239 shouldEmitSpace
= true;
2241 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
2242 const NamedAttribute
*var
= attr
->getVar();
2244 // If we are formatting as an enum, symbolize the attribute as a string.
2245 if (canFormatEnumAttr(var
))
2246 return genEnumAttrPrinter(var
, op
, body
);
2248 // If we are formatting as a symbol name, handle it as a symbol name.
2249 if (shouldFormatSymbolNameAttr(var
)) {
2250 body
<< " _odsPrinter.printSymbolName(" << op
.getGetterName(var
->name
)
2251 << "Attr().getValue());\n";
2255 // Elide the attribute type if it is buildable.
2256 if (attr
->getTypeBuilder())
2257 body
<< " _odsPrinter.printAttributeWithoutType("
2258 << op
.getGetterName(var
->name
) << "Attr());\n";
2259 else if (attr
->shouldBeQualified() ||
2260 var
->attr
.getStorageType() == "::mlir::Attribute")
2261 body
<< " _odsPrinter.printAttribute(" << op
.getGetterName(var
->name
)
2264 body
<< "_odsPrinter.printStrippedAttrOrType("
2265 << op
.getGetterName(var
->name
) << "Attr());\n";
2266 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
2267 if (operand
->getVar()->isVariadicOfVariadic()) {
2268 body
<< " ::llvm::interleaveComma("
2269 << op
.getGetterName(operand
->getVar()->name
)
2270 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2271 "\"(\" << operands << "
2274 } else if (operand
->getVar()->isOptional()) {
2275 body
<< " if (::mlir::Value value = "
2276 << op
.getGetterName(operand
->getVar()->name
) << "())\n"
2277 << " _odsPrinter << value;\n";
2279 body
<< " _odsPrinter << " << op
.getGetterName(operand
->getVar()->name
)
2282 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
2283 const NamedRegion
*var
= region
->getVar();
2284 std::string name
= op
.getGetterName(var
->name
);
2285 if (var
->isVariadic()) {
2286 genVariadicRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2288 genRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2290 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
2291 const NamedSuccessor
*var
= successor
->getVar();
2292 std::string name
= op
.getGetterName(var
->name
);
2293 if (var
->isVariadic())
2294 body
<< " ::llvm::interleaveComma(" << name
<< "(), _odsPrinter);\n";
2296 body
<< " _odsPrinter << " << name
<< "();\n";
2297 } else if (auto *dir
= dyn_cast
<CustomDirective
>(element
)) {
2298 genCustomDirectivePrinter(dir
, op
, body
);
2299 } else if (isa
<OperandsDirective
>(element
)) {
2300 body
<< " _odsPrinter << getOperation()->getOperands();\n";
2301 } else if (isa
<RegionsDirective
>(element
)) {
2302 genVariadicRegionPrinter("getOperation()->getRegions()", body
,
2303 hasImplicitTermTrait
);
2304 } else if (isa
<SuccessorsDirective
>(element
)) {
2305 body
<< " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2307 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
2308 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg())) {
2309 if (operand
->getVar()->isVariadicOfVariadic()) {
2310 body
<< llvm::formatv(
2311 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2312 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2313 "types << \")\"; });\n",
2314 op
.getGetterName(operand
->getVar()->name
));
2318 const NamedTypeConstraint
*var
= nullptr;
2320 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg()))
2321 var
= operand
->getVar();
2322 else if (auto *operand
= dyn_cast
<ResultVariable
>(dir
->getArg()))
2323 var
= operand
->getVar();
2325 if (var
&& !var
->isVariadicOfVariadic() && !var
->isVariadic() &&
2326 !var
->isOptional()) {
2327 std::string cppClass
= var
->constraint
.getCPPClassName();
2328 if (dir
->shouldBeQualified()) {
2329 body
<< " _odsPrinter << " << op
.getGetterName(var
->name
)
2330 << "().getType();\n";
2334 << " auto type = " << op
.getGetterName(var
->name
)
2335 << "().getType();\n"
2336 << " if (auto validType = ::llvm::dyn_cast<" << cppClass
2338 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2340 << " _odsPrinter << type;\n"
2344 body
<< " _odsPrinter << ";
2345 genTypeOperandPrinter(dir
->getArg(), op
, body
, /*useArrayRef=*/false)
2347 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
2348 body
<< " _odsPrinter.printFunctionalType(";
2349 genTypeOperandPrinter(dir
->getInputs(), op
, body
) << ", ";
2350 genTypeOperandPrinter(dir
->getResults(), op
, body
) << ");\n";
2352 llvm_unreachable("unknown format element");
2356 void OperationFormat::genPrinter(Operator
&op
, OpClass
&opClass
) {
2357 auto *method
= opClass
.addMethod(
2359 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2360 auto &body
= method
->body();
2362 // Flags for if we should emit a space, and if the last element was
2364 bool shouldEmitSpace
= true, lastWasPunctuation
= false;
2365 for (FormatElement
*element
: elements
)
2366 genElementPrinter(element
, body
, op
, shouldEmitSpace
, lastWasPunctuation
);
2369 //===----------------------------------------------------------------------===//
2371 //===----------------------------------------------------------------------===//
2373 /// Function to find an element within the given range that has the same name as
2375 template <typename RangeT
>
2376 static auto findArg(RangeT
&&range
, StringRef name
) {
2377 auto it
= llvm::find_if(range
, [=](auto &arg
) { return arg
.name
== name
; });
2378 return it
!= range
.end() ? &*it
: nullptr;
2382 /// This class implements a parser for an instance of an operation assembly
2384 class OpFormatParser
: public FormatParser
{
2386 OpFormatParser(llvm::SourceMgr
&mgr
, OperationFormat
&format
, Operator
&op
)
2387 : FormatParser(mgr
, op
.getLoc()[0]), fmt(format
), op(op
),
2388 seenOperandTypes(op
.getNumOperands()),
2389 seenResultTypes(op
.getNumResults()) {}
2392 /// Verify the format elements.
2393 LogicalResult
verify(SMLoc loc
, ArrayRef
<FormatElement
*> elements
) override
;
2394 /// Verify the arguments to a custom directive.
2396 verifyCustomDirectiveArguments(SMLoc loc
,
2397 ArrayRef
<FormatElement
*> arguments
) override
;
2398 /// Verify the elements of an optional group.
2399 LogicalResult
verifyOptionalGroupElements(SMLoc loc
,
2400 ArrayRef
<FormatElement
*> elements
,
2401 FormatElement
*anchor
) override
;
2402 LogicalResult
verifyOptionalGroupElement(SMLoc loc
, FormatElement
*element
,
2405 /// Parse an operation variable.
2406 FailureOr
<FormatElement
*> parseVariableImpl(SMLoc loc
, StringRef name
,
2407 Context ctx
) override
;
2408 /// Parse an operation format directive.
2409 FailureOr
<FormatElement
*>
2410 parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
, Context ctx
) override
;
2413 /// This struct represents a type resolution instance. It includes a specific
2414 /// type as well as an optional transformer to apply to that type in order to
2415 /// properly resolve the type of a variable.
2416 struct TypeResolutionInstance
{
2417 ConstArgument resolver
;
2418 std::optional
<StringRef
> transformer
;
2421 /// Verify the state of operation attributes within the format.
2422 LogicalResult
verifyAttributes(SMLoc loc
, ArrayRef
<FormatElement
*> elements
);
2424 /// Verify that attributes elements aren't followed by colon literals.
2425 LogicalResult
verifyAttributeColonType(SMLoc loc
,
2426 ArrayRef
<FormatElement
*> elements
);
2427 /// Verify that the attribute dictionary directive isn't followed by a region.
2428 LogicalResult
verifyAttrDictRegion(SMLoc loc
,
2429 ArrayRef
<FormatElement
*> elements
);
2431 /// Verify the state of operation operands within the format.
2433 verifyOperands(SMLoc loc
,
2434 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2436 /// Verify the state of operation regions within the format.
2437 LogicalResult
verifyRegions(SMLoc loc
);
2439 /// Verify the state of operation results within the format.
2441 verifyResults(SMLoc loc
,
2442 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2444 /// Verify the state of operation successors within the format.
2445 LogicalResult
verifySuccessors(SMLoc loc
);
2447 LogicalResult
verifyOIListElements(SMLoc loc
,
2448 ArrayRef
<FormatElement
*> elements
);
2450 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2452 void handleAllTypesMatchConstraint(
2453 ArrayRef
<StringRef
> values
,
2454 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2455 /// Check for inferable type resolution given all operands, and or results,
2456 /// have the same type. If 'includeResults' is true, the results also have the
2457 /// same type as all of the operands.
2458 void handleSameTypesConstraint(
2459 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2460 bool includeResults
);
2461 /// Check for inferable type resolution based on another operand, result, or
2463 void handleTypesMatchConstraint(
2464 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2465 const llvm::Record
&def
);
2467 /// Returns an argument or attribute with the given name that has been seen
2468 /// within the format.
2469 ConstArgument
findSeenArg(StringRef name
);
2471 /// Parse the various different directives.
2472 FailureOr
<FormatElement
*> parsePropDictDirective(SMLoc loc
, Context context
);
2473 FailureOr
<FormatElement
*> parseAttrDictDirective(SMLoc loc
, Context context
,
2475 FailureOr
<FormatElement
*> parseFunctionalTypeDirective(SMLoc loc
,
2477 FailureOr
<FormatElement
*> parseOIListDirective(SMLoc loc
, Context context
);
2478 LogicalResult
verifyOIListParsingElement(FormatElement
*element
, SMLoc loc
);
2479 FailureOr
<FormatElement
*> parseOperandsDirective(SMLoc loc
, Context context
);
2480 FailureOr
<FormatElement
*> parseQualifiedDirective(SMLoc loc
,
2482 FailureOr
<FormatElement
*> parseReferenceDirective(SMLoc loc
,
2484 FailureOr
<FormatElement
*> parseRegionsDirective(SMLoc loc
, Context context
);
2485 FailureOr
<FormatElement
*> parseResultsDirective(SMLoc loc
, Context context
);
2486 FailureOr
<FormatElement
*> parseSuccessorsDirective(SMLoc loc
,
2488 FailureOr
<FormatElement
*> parseTypeDirective(SMLoc loc
, Context context
);
2489 FailureOr
<FormatElement
*> parseTypeDirectiveOperand(SMLoc loc
,
2490 bool isRefChild
= false);
2492 //===--------------------------------------------------------------------===//
2494 //===--------------------------------------------------------------------===//
2496 OperationFormat
&fmt
;
2499 // The following are various bits of format state used for verification
2501 bool hasAttrDict
= false;
2502 bool hasPropDict
= false;
2503 bool hasAllRegions
= false, hasAllSuccessors
= false;
2504 bool canInferResultTypes
= false;
2505 llvm::SmallBitVector seenOperandTypes
, seenResultTypes
;
2506 llvm::SmallSetVector
<const NamedAttribute
*, 8> seenAttrs
;
2507 llvm::DenseSet
<const NamedTypeConstraint
*> seenOperands
;
2508 llvm::DenseSet
<const NamedRegion
*> seenRegions
;
2509 llvm::DenseSet
<const NamedSuccessor
*> seenSuccessors
;
2510 llvm::DenseSet
<const NamedProperty
*> seenProperties
;
2514 LogicalResult
OpFormatParser::verify(SMLoc loc
,
2515 ArrayRef
<FormatElement
*> elements
) {
2516 // Check that the attribute dictionary is in the format.
2518 return emitError(loc
, "'attr-dict' directive not found in "
2519 "custom assembly format");
2521 // Check for any type traits that we can use for inferring types.
2522 llvm::StringMap
<TypeResolutionInstance
> variableTyResolver
;
2523 for (const Trait
&trait
: op
.getTraits()) {
2524 const llvm::Record
&def
= trait
.getDef();
2525 if (def
.isSubClassOf("AllTypesMatch")) {
2526 handleAllTypesMatchConstraint(def
.getValueAsListOfStrings("values"),
2527 variableTyResolver
);
2528 } else if (def
.getName() == "SameTypeOperands") {
2529 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/false);
2530 } else if (def
.getName() == "SameOperandsAndResultType") {
2531 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/true);
2532 } else if (def
.isSubClassOf("TypesMatchWith")) {
2533 handleTypesMatchConstraint(variableTyResolver
, def
);
2534 } else if (!op
.allResultTypesKnown()) {
2535 // This doesn't check the name directly to handle
2536 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2538 // TODO: Add hasCppInterface check.
2539 if (auto name
= def
.getValueAsOptionalString("cppInterfaceName")) {
2540 if (*name
== "InferTypeOpInterface" &&
2541 def
.getValueAsString("cppNamespace") == "::mlir")
2542 canInferResultTypes
= true;
2547 // Verify the state of the various operation components.
2548 if (failed(verifyAttributes(loc
, elements
)) ||
2549 failed(verifyResults(loc
, variableTyResolver
)) ||
2550 failed(verifyOperands(loc
, variableTyResolver
)) ||
2551 failed(verifyRegions(loc
)) || failed(verifySuccessors(loc
)) ||
2552 failed(verifyOIListElements(loc
, elements
)))
2555 // Collect the set of used attributes in the format.
2556 fmt
.usedAttributes
= seenAttrs
.takeVector();
2561 OpFormatParser::verifyAttributes(SMLoc loc
,
2562 ArrayRef
<FormatElement
*> elements
) {
2563 // Check that there are no `:` literals after an attribute without a constant
2564 // type. The attribute grammar contains an optional trailing colon type, which
2565 // can lead to unexpected and generally unintended behavior. Given that, it is
2566 // better to just error out here instead.
2567 if (failed(verifyAttributeColonType(loc
, elements
)))
2569 // Check that there are no region variables following an attribute dicitonary.
2570 // Both start with `{` and so the optional attribute dictionary can cause
2571 // format ambiguities.
2572 if (failed(verifyAttrDictRegion(loc
, elements
)))
2575 // Check for VariadicOfVariadic variables. The segment attribute of those
2576 // variables will be infered.
2577 for (const NamedTypeConstraint
*var
: seenOperands
) {
2578 if (var
->constraint
.isVariadicOfVariadic()) {
2579 fmt
.inferredAttributes
.insert(
2580 var
->constraint
.getVariadicOfVariadicSegmentSizeAttr());
2587 /// Returns whether the single format element is optionally parsed.
2588 static bool isOptionallyParsed(FormatElement
*el
) {
2589 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(el
)) {
2590 Attribute attr
= attrVar
->getVar()->attr
;
2591 return attr
.isOptional() || attr
.hasDefaultValue();
2593 if (auto *operandVar
= dyn_cast
<OperandVariable
>(el
)) {
2594 const NamedTypeConstraint
*operand
= operandVar
->getVar();
2595 return operand
->isOptional() || operand
->isVariadic() ||
2596 operand
->isVariadicOfVariadic();
2598 if (auto *successorVar
= dyn_cast
<SuccessorVariable
>(el
))
2599 return successorVar
->getVar()->isVariadic();
2600 if (auto *regionVar
= dyn_cast
<RegionVariable
>(el
))
2601 return regionVar
->getVar()->isVariadic();
2602 return isa
<WhitespaceElement
, AttrDictDirective
>(el
);
2605 /// Scan the given range of elements from the start for an invalid format
2606 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2607 /// If an optional group is encountered, this function recurses into the 'then'
2608 /// and 'else' elements to check if they are invalid. Returns `success` if the
2609 /// range is known to be valid or `std::nullopt` if scanning reached the end.
2611 /// Since the guard element of an optional group is required, this function
2612 /// accepts an optional element pointer to mark it as required.
2613 static std::optional
<LogicalResult
> checkRangeForElement(
2614 FormatElement
*base
,
2615 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2616 iterator_range
<ArrayRef
<FormatElement
*>::iterator
> elementRange
,
2617 FormatElement
*optionalGuard
= nullptr) {
2618 for (FormatElement
*element
: elementRange
) {
2619 // If we encounter an invalid element, return an error.
2620 if (isInvalid(base
, element
))
2623 // Recurse on optional groups.
2624 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
2625 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2626 base
, isInvalid
, optional
->getThenElements(),
2627 // The optional group guard is required for the group.
2628 optional
->getThenElements().front()))
2629 if (failed(*result
))
2631 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2632 base
, isInvalid
, optional
->getElseElements()))
2633 if (failed(*result
))
2635 // Skip the optional group.
2639 // Skip optionally parsed elements.
2640 if (element
!= optionalGuard
&& isOptionallyParsed(element
))
2643 // We found a closing element that is valid.
2646 // Return std::nullopt to indicate that we reached the end.
2647 return std::nullopt
;
2650 /// For the given elements, check whether any attributes are followed by a colon
2651 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2652 /// attribute if verification of said attribute reached the end of the range.
2653 /// Returns null if all attribute elements are verified.
2654 static FailureOr
<FormatElement
*> verifyAdjacentElements(
2655 function_ref
<bool(FormatElement
*)> isBase
,
2656 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2657 ArrayRef
<FormatElement
*> elements
) {
2658 for (auto *it
= elements
.begin(), *e
= elements
.end(); it
!= e
; ++it
) {
2659 // The current attribute being verified.
2660 FormatElement
*base
;
2664 } else if (auto *optional
= dyn_cast
<OptionalElement
>(*it
)) {
2665 // Recurse on optional groups.
2666 FailureOr
<FormatElement
*> thenResult
= verifyAdjacentElements(
2667 isBase
, isInvalid
, optional
->getThenElements());
2668 if (failed(thenResult
))
2670 FailureOr
<FormatElement
*> elseResult
= verifyAdjacentElements(
2671 isBase
, isInvalid
, optional
->getElseElements());
2672 if (failed(elseResult
))
2674 // If either optional group has an unverified attribute, save it.
2675 // Otherwise, move on to the next element.
2676 if (!(base
= *thenResult
) && !(base
= *elseResult
))
2682 // Verify subsequent elements for potential ambiguities.
2683 if (std::optional
<LogicalResult
> result
=
2684 checkRangeForElement(base
, isInvalid
, {std::next(it
), e
})) {
2685 if (failed(*result
))
2688 // Since we reached the end, return the attribute as unverified.
2692 // All attribute elements are known to be verified.
2697 OpFormatParser::verifyAttributeColonType(SMLoc loc
,
2698 ArrayRef
<FormatElement
*> elements
) {
2699 auto isBase
= [](FormatElement
*el
) {
2700 auto *attr
= dyn_cast
<AttributeVariable
>(el
);
2703 // Check only attributes without type builders or that are known to call
2704 // the generic attribute parser.
2705 return !attr
->getTypeBuilder() &&
2706 (attr
->shouldBeQualified() ||
2707 attr
->getVar()->attr
.getStorageType() == "::mlir::Attribute");
2709 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
2710 auto *literal
= dyn_cast
<LiteralElement
>(el
);
2711 if (!literal
|| literal
->getSpelling() != ":")
2713 // If we encounter `:`, the range is known to be invalid.
2716 llvm::formatv("format ambiguity caused by `:` literal found after "
2717 "attribute `{0}` which does not have a buildable type",
2718 cast
<AttributeVariable
>(base
)->getVar()->name
));
2721 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
2725 OpFormatParser::verifyAttrDictRegion(SMLoc loc
,
2726 ArrayRef
<FormatElement
*> elements
) {
2727 auto isBase
= [](FormatElement
*el
) {
2728 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(el
))
2729 return !attrDict
->isWithKeyword();
2732 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
2733 auto *region
= dyn_cast
<RegionVariable
>(el
);
2736 (void)emitErrorAndNote(
2738 llvm::formatv("format ambiguity caused by `attr-dict` directive "
2739 "followed by region `{0}`",
2740 region
->getVar()->name
),
2741 "try using `attr-dict-with-keyword` instead");
2744 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
2747 LogicalResult
OpFormatParser::verifyOperands(
2748 SMLoc loc
, llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2749 // Check that all of the operands are within the format, and their types can
2751 auto &buildableTypes
= fmt
.buildableTypes
;
2752 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
2753 NamedTypeConstraint
&operand
= op
.getOperand(i
);
2755 // Check that the operand itself is in the format.
2756 if (!fmt
.allOperands
&& !seenOperands
.count(&operand
)) {
2757 return emitErrorAndNote(loc
,
2758 "operand #" + Twine(i
) + ", named '" +
2759 operand
.name
+ "', not found",
2760 "suggest adding a '$" + operand
.name
+
2761 "' directive to the custom assembly format");
2764 // Check that the operand type is in the format, or that it can be inferred.
2765 if (fmt
.allOperandTypes
|| seenOperandTypes
.test(i
))
2768 // Check to see if we can infer this type from another variable.
2769 auto varResolverIt
= variableTyResolver
.find(op
.getOperand(i
).name
);
2770 if (varResolverIt
!= variableTyResolver
.end()) {
2771 TypeResolutionInstance
&resolver
= varResolverIt
->second
;
2772 fmt
.operandTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
2776 // Similarly to results, allow a custom builder for resolving the type if
2777 // we aren't using the 'operands' directive.
2778 std::optional
<StringRef
> builder
= operand
.constraint
.getBuilderCall();
2779 if (!builder
|| (fmt
.allOperands
&& operand
.isVariableLength())) {
2780 return emitErrorAndNote(
2782 "type of operand #" + Twine(i
) + ", named '" + operand
.name
+
2783 "', is not buildable and a buildable type cannot be inferred",
2784 "suggest adding a type constraint to the operation or adding a "
2786 operand
.name
+ ")' directive to the " + "custom assembly format");
2788 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
2789 fmt
.operandTypes
[i
].setBuilderIdx(it
.first
->second
);
2794 LogicalResult
OpFormatParser::verifyRegions(SMLoc loc
) {
2795 // Check that all of the regions are within the format.
2799 for (unsigned i
= 0, e
= op
.getNumRegions(); i
!= e
; ++i
) {
2800 const NamedRegion
®ion
= op
.getRegion(i
);
2801 if (!seenRegions
.count(®ion
)) {
2802 return emitErrorAndNote(loc
,
2803 "region #" + Twine(i
) + ", named '" +
2804 region
.name
+ "', not found",
2805 "suggest adding a '$" + region
.name
+
2806 "' directive to the custom assembly format");
2812 LogicalResult
OpFormatParser::verifyResults(
2813 SMLoc loc
, llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2814 // If we format all of the types together, there is nothing to check.
2815 if (fmt
.allResultTypes
)
2818 // If no result types are specified and we can infer them, infer all result
2820 if (op
.getNumResults() > 0 && seenResultTypes
.count() == 0 &&
2821 canInferResultTypes
) {
2822 fmt
.infersResultTypes
= true;
2826 // Check that all of the result types can be inferred.
2827 auto &buildableTypes
= fmt
.buildableTypes
;
2828 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
2829 if (seenResultTypes
.test(i
))
2832 // Check to see if we can infer this type from another variable.
2833 auto varResolverIt
= variableTyResolver
.find(op
.getResultName(i
));
2834 if (varResolverIt
!= variableTyResolver
.end()) {
2835 TypeResolutionInstance resolver
= varResolverIt
->second
;
2836 fmt
.resultTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
2840 // If the result is not variable length, allow for the case where the type
2841 // has a builder that we can use.
2842 NamedTypeConstraint
&result
= op
.getResult(i
);
2843 std::optional
<StringRef
> builder
= result
.constraint
.getBuilderCall();
2844 if (!builder
|| result
.isVariableLength()) {
2845 return emitErrorAndNote(
2847 "type of result #" + Twine(i
) + ", named '" + result
.name
+
2848 "', is not buildable and a buildable type cannot be inferred",
2849 "suggest adding a type constraint to the operation or adding a "
2851 result
.name
+ ")' directive to the " + "custom assembly format");
2853 // Note in the format that this result uses the custom builder.
2854 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
2855 fmt
.resultTypes
[i
].setBuilderIdx(it
.first
->second
);
2860 LogicalResult
OpFormatParser::verifySuccessors(SMLoc loc
) {
2861 // Check that all of the successors are within the format.
2862 if (hasAllSuccessors
)
2865 for (unsigned i
= 0, e
= op
.getNumSuccessors(); i
!= e
; ++i
) {
2866 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
2867 if (!seenSuccessors
.count(&successor
)) {
2868 return emitErrorAndNote(loc
,
2869 "successor #" + Twine(i
) + ", named '" +
2870 successor
.name
+ "', not found",
2871 "suggest adding a '$" + successor
.name
+
2872 "' directive to the custom assembly format");
2879 OpFormatParser::verifyOIListElements(SMLoc loc
,
2880 ArrayRef
<FormatElement
*> elements
) {
2881 // Check that all of the successors are within the format.
2882 SmallVector
<StringRef
> prohibitedLiterals
;
2883 for (FormatElement
*it
: elements
) {
2884 if (auto *oilist
= dyn_cast
<OIListElement
>(it
)) {
2885 if (!prohibitedLiterals
.empty()) {
2886 // We just saw an oilist element in last iteration. Literals should not
2888 for (LiteralElement
*literal
: oilist
->getLiteralElements()) {
2889 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
2890 prohibitedLiterals
.end()) {
2892 loc
, "format ambiguity because " + literal
->getSpelling() +
2893 " is used in two adjacent oilist elements.");
2897 for (LiteralElement
*literal
: oilist
->getLiteralElements())
2898 prohibitedLiterals
.push_back(literal
->getSpelling());
2899 } else if (auto *literal
= dyn_cast
<LiteralElement
>(it
)) {
2900 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
2901 prohibitedLiterals
.end()) {
2904 "format ambiguity because " + literal
->getSpelling() +
2905 " is used both in oilist element and the adjacent literal.");
2907 prohibitedLiterals
.clear();
2909 prohibitedLiterals
.clear();
2915 void OpFormatParser::handleAllTypesMatchConstraint(
2916 ArrayRef
<StringRef
> values
,
2917 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2918 for (unsigned i
= 0, e
= values
.size(); i
!= e
; ++i
) {
2919 // Check to see if this value matches a resolved operand or result type.
2920 ConstArgument arg
= findSeenArg(values
[i
]);
2924 // Mark this value as the type resolver for the other variables.
2925 for (unsigned j
= 0; j
!= i
; ++j
)
2926 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
2927 for (unsigned j
= i
+ 1; j
!= e
; ++j
)
2928 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
2932 void OpFormatParser::handleSameTypesConstraint(
2933 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2934 bool includeResults
) {
2935 const NamedTypeConstraint
*resolver
= nullptr;
2936 int resolvedIt
= -1;
2938 // Check to see if there is an operand or result to use for the resolution.
2939 if ((resolvedIt
= seenOperandTypes
.find_first()) != -1)
2940 resolver
= &op
.getOperand(resolvedIt
);
2941 else if (includeResults
&& (resolvedIt
= seenResultTypes
.find_first()) != -1)
2942 resolver
= &op
.getResult(resolvedIt
);
2946 // Set the resolvers for each operand and result.
2947 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
)
2948 if (!seenOperandTypes
.test(i
))
2949 variableTyResolver
[op
.getOperand(i
).name
] = {resolver
, std::nullopt
};
2950 if (includeResults
) {
2951 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
)
2952 if (!seenResultTypes
.test(i
))
2953 variableTyResolver
[op
.getResultName(i
)] = {resolver
, std::nullopt
};
2957 void OpFormatParser::handleTypesMatchConstraint(
2958 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2959 const llvm::Record
&def
) {
2960 StringRef lhsName
= def
.getValueAsString("lhs");
2961 StringRef rhsName
= def
.getValueAsString("rhs");
2962 StringRef transformer
= def
.getValueAsString("transformer");
2963 if (ConstArgument arg
= findSeenArg(lhsName
))
2964 variableTyResolver
[rhsName
] = {arg
, transformer
};
2967 ConstArgument
OpFormatParser::findSeenArg(StringRef name
) {
2968 if (const NamedTypeConstraint
*arg
= findArg(op
.getOperands(), name
))
2969 return seenOperandTypes
.test(arg
- op
.operand_begin()) ? arg
: nullptr;
2970 if (const NamedTypeConstraint
*arg
= findArg(op
.getResults(), name
))
2971 return seenResultTypes
.test(arg
- op
.result_begin()) ? arg
: nullptr;
2972 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
))
2973 return seenAttrs
.count(attr
) ? attr
: nullptr;
2977 FailureOr
<FormatElement
*>
2978 OpFormatParser::parseVariableImpl(SMLoc loc
, StringRef name
, Context ctx
) {
2979 // Check that the parsed argument is something actually registered on the op.
2981 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
)) {
2982 if (ctx
== TypeDirectiveContext
)
2984 loc
, "attributes cannot be used as children to a `type` directive");
2985 if (ctx
== RefDirectiveContext
) {
2986 if (!seenAttrs
.count(attr
))
2987 return emitError(loc
, "attribute '" + name
+
2988 "' must be bound before it is referenced");
2989 } else if (!seenAttrs
.insert(attr
)) {
2990 return emitError(loc
, "attribute '" + name
+ "' is already bound");
2993 return create
<AttributeVariable
>(attr
);
2996 if (const NamedProperty
*property
= findArg(op
.getProperties(), name
)) {
2997 if (ctx
!= CustomDirectiveContext
&& ctx
!= RefDirectiveContext
)
2999 loc
, "properties currently only supported in `custom` directive");
3001 if (ctx
== RefDirectiveContext
) {
3002 if (!seenProperties
.count(property
))
3003 return emitError(loc
, "property '" + name
+
3004 "' must be bound before it is referenced");
3006 if (!seenProperties
.insert(property
).second
)
3007 return emitError(loc
, "property '" + name
+ "' is already bound");
3010 return create
<PropertyVariable
>(property
);
3014 if (const NamedTypeConstraint
*operand
= findArg(op
.getOperands(), name
)) {
3015 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3016 if (fmt
.allOperands
|| !seenOperands
.insert(operand
).second
)
3017 return emitError(loc
, "operand '" + name
+ "' is already bound");
3018 } else if (ctx
== RefDirectiveContext
&& !seenOperands
.count(operand
)) {
3019 return emitError(loc
, "operand '" + name
+
3020 "' must be bound before it is referenced");
3022 return create
<OperandVariable
>(operand
);
3025 if (const NamedRegion
*region
= findArg(op
.getRegions(), name
)) {
3026 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3027 if (hasAllRegions
|| !seenRegions
.insert(region
).second
)
3028 return emitError(loc
, "region '" + name
+ "' is already bound");
3029 } else if (ctx
== RefDirectiveContext
&& !seenRegions
.count(region
)) {
3030 return emitError(loc
, "region '" + name
+
3031 "' must be bound before it is referenced");
3033 return emitError(loc
, "regions can only be used at the top level");
3035 return create
<RegionVariable
>(region
);
3038 if (const auto *result
= findArg(op
.getResults(), name
)) {
3039 if (ctx
!= TypeDirectiveContext
)
3040 return emitError(loc
, "result variables can can only be used as a child "
3041 "to a 'type' directive");
3042 return create
<ResultVariable
>(result
);
3045 if (const auto *successor
= findArg(op
.getSuccessors(), name
)) {
3046 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3047 if (hasAllSuccessors
|| !seenSuccessors
.insert(successor
).second
)
3048 return emitError(loc
, "successor '" + name
+ "' is already bound");
3049 } else if (ctx
== RefDirectiveContext
&& !seenSuccessors
.count(successor
)) {
3050 return emitError(loc
, "successor '" + name
+
3051 "' must be bound before it is referenced");
3053 return emitError(loc
, "successors can only be used at the top level");
3056 return create
<SuccessorVariable
>(successor
);
3058 return emitError(loc
, "expected variable to refer to an argument, region, "
3059 "result, or successor");
3062 FailureOr
<FormatElement
*>
3063 OpFormatParser::parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
,
3066 case FormatToken::kw_prop_dict
:
3067 return parsePropDictDirective(loc
, ctx
);
3068 case FormatToken::kw_attr_dict
:
3069 return parseAttrDictDirective(loc
, ctx
,
3070 /*withKeyword=*/false);
3071 case FormatToken::kw_attr_dict_w_keyword
:
3072 return parseAttrDictDirective(loc
, ctx
,
3073 /*withKeyword=*/true);
3074 case FormatToken::kw_functional_type
:
3075 return parseFunctionalTypeDirective(loc
, ctx
);
3076 case FormatToken::kw_operands
:
3077 return parseOperandsDirective(loc
, ctx
);
3078 case FormatToken::kw_qualified
:
3079 return parseQualifiedDirective(loc
, ctx
);
3080 case FormatToken::kw_regions
:
3081 return parseRegionsDirective(loc
, ctx
);
3082 case FormatToken::kw_results
:
3083 return parseResultsDirective(loc
, ctx
);
3084 case FormatToken::kw_successors
:
3085 return parseSuccessorsDirective(loc
, ctx
);
3086 case FormatToken::kw_ref
:
3087 return parseReferenceDirective(loc
, ctx
);
3088 case FormatToken::kw_type
:
3089 return parseTypeDirective(loc
, ctx
);
3090 case FormatToken::kw_oilist
:
3091 return parseOIListDirective(loc
, ctx
);
3094 return emitError(loc
, "unsupported directive kind");
3098 FailureOr
<FormatElement
*>
3099 OpFormatParser::parseAttrDictDirective(SMLoc loc
, Context context
,
3101 if (context
== TypeDirectiveContext
)
3102 return emitError(loc
, "'attr-dict' directive can only be used as a "
3103 "top-level directive");
3105 if (context
== RefDirectiveContext
) {
3107 return emitError(loc
, "'ref' of 'attr-dict' is not bound by a prior "
3108 "'attr-dict' directive");
3110 // Otherwise, this is a top-level context.
3113 return emitError(loc
, "'attr-dict' directive has already been seen");
3117 return create
<AttrDictDirective
>(withKeyword
);
3120 FailureOr
<FormatElement
*>
3121 OpFormatParser::parsePropDictDirective(SMLoc loc
, Context context
) {
3122 if (context
== TypeDirectiveContext
)
3123 return emitError(loc
, "'prop-dict' directive can only be used as a "
3124 "top-level directive");
3126 if (context
== RefDirectiveContext
)
3127 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3128 // Otherwise, this is a top-level context.
3131 return emitError(loc
, "'prop-dict' directive has already been seen");
3134 return create
<PropDictDirective
>();
3137 LogicalResult
OpFormatParser::verifyCustomDirectiveArguments(
3138 SMLoc loc
, ArrayRef
<FormatElement
*> arguments
) {
3139 for (FormatElement
*argument
: arguments
) {
3140 if (!isa
<AttrDictDirective
, AttributeVariable
, OperandVariable
,
3141 PropertyVariable
, RefDirective
, RegionVariable
, SuccessorVariable
,
3142 StringElement
, TypeDirective
>(argument
)) {
3143 // TODO: FormatElement should have location info attached.
3144 return emitError(loc
, "only variables and types may be used as "
3145 "parameters to a custom directive");
3147 if (auto *type
= dyn_cast
<TypeDirective
>(argument
)) {
3148 if (!isa
<OperandVariable
, ResultVariable
>(type
->getArg())) {
3149 return emitError(loc
, "type directives within a custom directive may "
3150 "only refer to variables");
3157 FailureOr
<FormatElement
*>
3158 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc
, Context context
) {
3159 if (context
!= TopLevelContext
)
3161 loc
, "'functional-type' is only valid as a top-level directive");
3163 // Parse the main operand.
3164 FailureOr
<FormatElement
*> inputs
, results
;
3165 if (failed(parseToken(FormatToken::l_paren
,
3166 "expected '(' before argument list")) ||
3167 failed(inputs
= parseTypeDirectiveOperand(loc
)) ||
3168 failed(parseToken(FormatToken::comma
,
3169 "expected ',' after inputs argument")) ||
3170 failed(results
= parseTypeDirectiveOperand(loc
)) ||
3172 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3174 return create
<FunctionalTypeDirective
>(*inputs
, *results
);
3177 FailureOr
<FormatElement
*>
3178 OpFormatParser::parseOperandsDirective(SMLoc loc
, Context context
) {
3179 if (context
== RefDirectiveContext
) {
3180 if (!fmt
.allOperands
)
3181 return emitError(loc
, "'ref' of 'operands' is not bound by a prior "
3182 "'operands' directive");
3184 } else if (context
== TopLevelContext
|| context
== CustomDirectiveContext
) {
3185 if (fmt
.allOperands
|| !seenOperands
.empty())
3186 return emitError(loc
, "'operands' directive creates overlap in format");
3187 fmt
.allOperands
= true;
3189 return create
<OperandsDirective
>();
3192 FailureOr
<FormatElement
*>
3193 OpFormatParser::parseReferenceDirective(SMLoc loc
, Context context
) {
3194 if (context
!= CustomDirectiveContext
)
3195 return emitError(loc
, "'ref' is only valid within a `custom` directive");
3197 FailureOr
<FormatElement
*> arg
;
3198 if (failed(parseToken(FormatToken::l_paren
,
3199 "expected '(' before argument list")) ||
3200 failed(arg
= parseElement(RefDirectiveContext
)) ||
3202 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3205 return create
<RefDirective
>(*arg
);
3208 FailureOr
<FormatElement
*>
3209 OpFormatParser::parseRegionsDirective(SMLoc loc
, Context context
) {
3210 if (context
== TypeDirectiveContext
)
3211 return emitError(loc
, "'regions' is only valid as a top-level directive");
3212 if (context
== RefDirectiveContext
) {
3214 return emitError(loc
, "'ref' of 'regions' is not bound by a prior "
3215 "'regions' directive");
3217 // Otherwise, this is a TopLevel directive.
3219 if (hasAllRegions
|| !seenRegions
.empty())
3220 return emitError(loc
, "'regions' directive creates overlap in format");
3221 hasAllRegions
= true;
3223 return create
<RegionsDirective
>();
3226 FailureOr
<FormatElement
*>
3227 OpFormatParser::parseResultsDirective(SMLoc loc
, Context context
) {
3228 if (context
!= TypeDirectiveContext
)
3229 return emitError(loc
, "'results' directive can can only be used as a child "
3230 "to a 'type' directive");
3231 return create
<ResultsDirective
>();
3234 FailureOr
<FormatElement
*>
3235 OpFormatParser::parseSuccessorsDirective(SMLoc loc
, Context context
) {
3236 if (context
== TypeDirectiveContext
)
3237 return emitError(loc
,
3238 "'successors' is only valid as a top-level directive");
3239 if (context
== RefDirectiveContext
) {
3240 if (!hasAllSuccessors
)
3241 return emitError(loc
, "'ref' of 'successors' is not bound by a prior "
3242 "'successors' directive");
3244 // Otherwise, this is a TopLevel directive.
3246 if (hasAllSuccessors
|| !seenSuccessors
.empty())
3247 return emitError(loc
, "'successors' directive creates overlap in format");
3248 hasAllSuccessors
= true;
3250 return create
<SuccessorsDirective
>();
3253 FailureOr
<FormatElement
*>
3254 OpFormatParser::parseOIListDirective(SMLoc loc
, Context context
) {
3255 if (failed(parseToken(FormatToken::l_paren
,
3256 "expected '(' before oilist argument list")))
3258 std::vector
<FormatElement
*> literalElements
;
3259 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
3261 FailureOr
<FormatElement
*> lelement
= parseLiteral(context
);
3262 if (failed(lelement
))
3264 literalElements
.push_back(*lelement
);
3265 parsingElements
.emplace_back();
3266 std::vector
<FormatElement
*> &currParsingElements
= parsingElements
.back();
3267 while (peekToken().getKind() != FormatToken::pipe
&&
3268 peekToken().getKind() != FormatToken::r_paren
) {
3269 FailureOr
<FormatElement
*> pelement
= parseElement(context
);
3270 if (failed(pelement
) ||
3271 failed(verifyOIListParsingElement(*pelement
, loc
)))
3273 currParsingElements
.push_back(*pelement
);
3275 if (peekToken().getKind() == FormatToken::pipe
) {
3279 if (peekToken().getKind() == FormatToken::r_paren
) {
3285 return create
<OIListElement
>(std::move(literalElements
),
3286 std::move(parsingElements
));
3289 LogicalResult
OpFormatParser::verifyOIListParsingElement(FormatElement
*element
,
3291 SmallVector
<VariableElement
*> vars
;
3292 collect(element
, vars
);
3293 for (VariableElement
*elem
: vars
) {
3295 TypeSwitch
<FormatElement
*, LogicalResult
>(elem
)
3296 // Only optional attributes can be within an oilist parsing group.
3297 .Case([&](AttributeVariable
*attrEle
) {
3298 if (!attrEle
->getVar()->attr
.isOptional() &&
3299 !attrEle
->getVar()->attr
.hasDefaultValue())
3300 return emitError(loc
, "only optional attributes can be used in "
3301 "an oilist parsing group");
3304 // Only optional-like(i.e. variadic) operands can be within an
3305 // oilist parsing group.
3306 .Case([&](OperandVariable
*ele
) {
3307 if (!ele
->getVar()->isVariableLength())
3308 return emitError(loc
, "only variable length operands can be "
3309 "used within an oilist parsing group");
3312 // Only optional-like(i.e. variadic) results can be within an oilist
3314 .Case([&](ResultVariable
*ele
) {
3315 if (!ele
->getVar()->isVariableLength())
3316 return emitError(loc
, "only variable length results can be "
3317 "used within an oilist parsing group");
3320 .Case([&](RegionVariable
*) { return success(); })
3321 .Default([&](FormatElement
*) {
3322 return emitError(loc
,
3323 "only literals, types, and variables can be "
3324 "used within an oilist group");
3332 FailureOr
<FormatElement
*> OpFormatParser::parseTypeDirective(SMLoc loc
,
3334 if (context
== TypeDirectiveContext
)
3335 return emitError(loc
, "'type' cannot be used as a child of another `type`");
3337 bool isRefChild
= context
== RefDirectiveContext
;
3338 FailureOr
<FormatElement
*> operand
;
3339 if (failed(parseToken(FormatToken::l_paren
,
3340 "expected '(' before argument list")) ||
3341 failed(operand
= parseTypeDirectiveOperand(loc
, isRefChild
)) ||
3343 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3346 return create
<TypeDirective
>(*operand
);
3349 FailureOr
<FormatElement
*>
3350 OpFormatParser::parseQualifiedDirective(SMLoc loc
, Context context
) {
3351 FailureOr
<FormatElement
*> element
;
3352 if (failed(parseToken(FormatToken::l_paren
,
3353 "expected '(' before argument list")) ||
3354 failed(element
= parseElement(context
)) ||
3356 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3358 return TypeSwitch
<FormatElement
*, FailureOr
<FormatElement
*>>(*element
)
3359 .Case
<AttributeVariable
, TypeDirective
>([](auto *element
) {
3360 element
->setShouldBeQualified();
3363 .Default([&](auto *element
) {
3364 return this->emitError(
3366 "'qualified' directive expects an attribute or a `type` directive");
3370 FailureOr
<FormatElement
*>
3371 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc
, bool isRefChild
) {
3372 FailureOr
<FormatElement
*> result
= parseElement(TypeDirectiveContext
);
3376 FormatElement
*element
= *result
;
3377 if (isa
<LiteralElement
>(element
))
3379 loc
, "'type' directive operand expects variable or directive operand");
3381 if (auto *var
= dyn_cast
<OperandVariable
>(element
)) {
3382 unsigned opIdx
= var
->getVar() - op
.operand_begin();
3383 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3384 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3385 "' is already bound");
3386 if (isRefChild
&& !(fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3387 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3388 ")' is not bound by a prior 'type' directive");
3389 seenOperandTypes
.set(opIdx
);
3390 } else if (auto *var
= dyn_cast
<ResultVariable
>(element
)) {
3391 unsigned resIdx
= var
->getVar() - op
.result_begin();
3392 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3393 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3394 "' is already bound");
3395 if (isRefChild
&& !(fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3396 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3397 ")' is not bound by a prior 'type' directive");
3398 seenResultTypes
.set(resIdx
);
3399 } else if (isa
<OperandsDirective
>(&*element
)) {
3400 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.any()))
3401 return emitError(loc
, "'operands' 'type' is already bound");
3402 if (isRefChild
&& !fmt
.allOperandTypes
)
3403 return emitError(loc
, "'ref' of 'type(operands)' is not bound by a prior "
3404 "'type' directive");
3405 fmt
.allOperandTypes
= true;
3406 } else if (isa
<ResultsDirective
>(&*element
)) {
3407 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.any()))
3408 return emitError(loc
, "'results' 'type' is already bound");
3409 if (isRefChild
&& !fmt
.allResultTypes
)
3410 return emitError(loc
, "'ref' of 'type(results)' is not bound by a prior "
3411 "'type' directive");
3412 fmt
.allResultTypes
= true;
3414 return emitError(loc
, "invalid argument to 'type' directive");
3419 LogicalResult
OpFormatParser::verifyOptionalGroupElements(
3420 SMLoc loc
, ArrayRef
<FormatElement
*> elements
, FormatElement
*anchor
) {
3421 for (FormatElement
*element
: elements
) {
3422 if (failed(verifyOptionalGroupElement(loc
, element
, element
== anchor
)))
3428 LogicalResult
OpFormatParser::verifyOptionalGroupElement(SMLoc loc
,
3429 FormatElement
*element
,
3431 return TypeSwitch
<FormatElement
*, LogicalResult
>(element
)
3432 // All attributes can be within the optional group, but only optional
3433 // attributes can be the anchor.
3434 .Case([&](AttributeVariable
*attrEle
) {
3435 Attribute attr
= attrEle
->getVar()->attr
;
3436 if (isAnchor
&& !(attr
.isOptional() || attr
.hasDefaultValue()))
3437 return emitError(loc
, "only optional or default-valued attributes "
3438 "can be used to anchor an optional group");
3441 // Only optional-like(i.e. variadic) operands can be within an optional
3443 .Case([&](OperandVariable
*ele
) {
3444 if (!ele
->getVar()->isVariableLength())
3445 return emitError(loc
, "only variable length operands can be used "
3446 "within an optional group");
3449 // Only optional-like(i.e. variadic) results can be within an optional
3451 .Case([&](ResultVariable
*ele
) {
3452 if (!ele
->getVar()->isVariableLength())
3453 return emitError(loc
, "only variable length results can be used "
3454 "within an optional group");
3457 .Case([&](RegionVariable
*) {
3458 // TODO: When ODS has proper support for marking "optional" regions, add
3462 .Case([&](TypeDirective
*ele
) {
3463 return verifyOptionalGroupElement(loc
, ele
->getArg(),
3464 /*isAnchor=*/false);
3466 .Case([&](FunctionalTypeDirective
*ele
) {
3467 if (failed(verifyOptionalGroupElement(loc
, ele
->getInputs(),
3468 /*isAnchor=*/false)))
3470 return verifyOptionalGroupElement(loc
, ele
->getResults(),
3471 /*isAnchor=*/false);
3473 .Case([&](CustomDirective
*ele
) {
3476 // Verify each child as being valid in an optional group. They are all
3477 // potential anchors if the custom directive was marked as one.
3478 for (FormatElement
*child
: ele
->getArguments()) {
3479 if (isa
<RefDirective
>(child
))
3481 if (failed(verifyOptionalGroupElement(loc
, child
, /*isAnchor=*/true)))
3486 // Literals, whitespace, and custom directives may be used, but they can't
3487 // anchor the group.
3488 .Case
<LiteralElement
, WhitespaceElement
, OptionalElement
>(
3489 [&](FormatElement
*) {
3491 return emitError(loc
, "only variables and types can be used "
3492 "to anchor an optional group");
3495 .Default([&](FormatElement
*) {
3496 return emitError(loc
, "only literals, types, and variables can be "
3497 "used within an optional group");
3501 //===----------------------------------------------------------------------===//
3503 //===----------------------------------------------------------------------===//
3505 void mlir::tblgen::generateOpFormat(const Operator
&constOp
, OpClass
&opClass
) {
3506 // TODO: Operator doesn't expose all necessary functionality via
3507 // the const interface.
3508 Operator
&op
= const_cast<Operator
&>(constOp
);
3509 if (!op
.hasAssemblyFormat())
3512 // Parse the format description.
3513 llvm::SourceMgr mgr
;
3514 mgr
.AddNewSourceBuffer(
3515 llvm::MemoryBuffer::getMemBuffer(op
.getAssemblyFormat()), SMLoc());
3516 OperationFormat
format(op
);
3517 OpFormatParser
parser(mgr
, format
, op
);
3518 FailureOr
<std::vector
<FormatElement
*>> elements
= parser
.parse();
3519 if (failed(elements
)) {
3520 // Exit the process if format errors are treated as fatal.
3521 if (formatErrorIsFatal
) {
3522 // Invoke the interrupt handlers to run the file cleanup handlers.
3523 llvm::sys::RunInterruptHandlers();
3528 format
.elements
= std::move(*elements
);
3530 // Generate the printer and parser based on the parsed format.
3531 format
.genParser(op
, opClass
);
3532 format
.genPrinter(op
, opClass
);