1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
30 using namespace mlir::tblgen
;
32 //===----------------------------------------------------------------------===//
36 /// This class represents an instance of an op variable element. A variable
37 /// refers to something registered on the operation itself, e.g. an operand,
38 /// result, attribute, region, or successor.
39 template <typename VarT
, VariableElement::Kind VariableKind
>
40 class OpVariableElement
: public VariableElementBase
<VariableKind
> {
42 using Base
= OpVariableElement
<VarT
, VariableKind
>;
44 /// Create an op variable element with the variable value.
45 OpVariableElement(const VarT
*var
) : var(var
) {}
48 const VarT
*getVar() { return var
; }
51 /// The op variable, e.g. a type or attribute constraint.
55 /// This class represents a variable that refers to an attribute argument.
56 struct AttributeVariable
57 : public OpVariableElement
<NamedAttribute
, VariableElement::Attribute
> {
60 /// Return the constant builder call for the type of this attribute, or
61 /// std::nullopt if it doesn't have one.
62 std::optional
<StringRef
> getTypeBuilder() const {
63 std::optional
<Type
> attrType
= var
->attr
.getValueType();
64 return attrType
? attrType
->getBuilderCall() : std::nullopt
;
67 /// Return if this attribute refers to a UnitAttr.
68 bool isUnitAttr() const {
69 return var
->attr
.getBaseAttr().getAttrDefName() == "UnitAttr";
72 /// Indicate if this attribute is printed "qualified" (that is it is
73 /// prefixed with the `#dialect.mnemonic`).
74 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
75 void setShouldBeQualified(bool qualified
= true) {
76 shouldBeQualifiedFlag
= qualified
;
80 bool shouldBeQualifiedFlag
= false;
83 /// This class represents a variable that refers to an operand argument.
84 using OperandVariable
=
85 OpVariableElement
<NamedTypeConstraint
, VariableElement::Operand
>;
87 /// This class represents a variable that refers to a result.
88 using ResultVariable
=
89 OpVariableElement
<NamedTypeConstraint
, VariableElement::Result
>;
91 /// This class represents a variable that refers to a region.
92 using RegionVariable
= OpVariableElement
<NamedRegion
, VariableElement::Region
>;
94 /// This class represents a variable that refers to a successor.
95 using SuccessorVariable
=
96 OpVariableElement
<NamedSuccessor
, VariableElement::Successor
>;
99 //===----------------------------------------------------------------------===//
103 /// This class represents the `operands` directive. This directive represents
104 /// all of the operands of an operation.
105 using OperandsDirective
= DirectiveElementBase
<DirectiveElement::Operands
>;
107 /// This class represents the `results` directive. This directive represents
108 /// all of the results of an operation.
109 using ResultsDirective
= DirectiveElementBase
<DirectiveElement::Results
>;
111 /// This class represents the `regions` directive. This directive represents
112 /// all of the regions of an operation.
113 using RegionsDirective
= DirectiveElementBase
<DirectiveElement::Regions
>;
115 /// This class represents the `successors` directive. This directive represents
116 /// all of the successors of an operation.
117 using SuccessorsDirective
= DirectiveElementBase
<DirectiveElement::Successors
>;
119 /// This class represents the `attr-dict` directive. This directive represents
120 /// the attribute dictionary of the operation.
121 class AttrDictDirective
122 : public DirectiveElementBase
<DirectiveElement::AttrDict
> {
124 explicit AttrDictDirective(bool withKeyword
) : withKeyword(withKeyword
) {}
126 /// Return whether the dictionary should be printed with the 'attributes'
128 bool isWithKeyword() const { return withKeyword
; }
131 /// If the dictionary should be printed with the 'attributes' keyword.
135 /// This class represents the `prop-dict` directive. This directive represents
136 /// the properties of the operation, expressed as a directionary.
137 class PropDictDirective
138 : public DirectiveElementBase
<DirectiveElement::PropDict
> {
140 explicit PropDictDirective() = default;
143 /// This class represents the `functional-type` directive. This directive takes
144 /// two arguments and formats them, respectively, as the inputs and results of a
146 class FunctionalTypeDirective
147 : public DirectiveElementBase
<DirectiveElement::FunctionalType
> {
149 FunctionalTypeDirective(FormatElement
*inputs
, FormatElement
*results
)
150 : inputs(inputs
), results(results
) {}
152 FormatElement
*getInputs() const { return inputs
; }
153 FormatElement
*getResults() const { return results
; }
156 /// The input and result arguments.
157 FormatElement
*inputs
, *results
;
160 /// This class represents the `type` directive.
161 class TypeDirective
: public DirectiveElementBase
<DirectiveElement::Type
> {
163 TypeDirective(FormatElement
*arg
) : arg(arg
) {}
165 FormatElement
*getArg() const { return arg
; }
167 /// Indicate if this type is printed "qualified" (that is it is
168 /// prefixed with the `!dialect.mnemonic`).
169 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
170 void setShouldBeQualified(bool qualified
= true) {
171 shouldBeQualifiedFlag
= qualified
;
175 /// The argument that is used to format the directive.
178 bool shouldBeQualifiedFlag
= false;
181 /// This class represents a group of order-independent optional clauses. Each
182 /// clause starts with a literal element and has a coressponding parsing
183 /// element. A parsing element is a continous sequence of format elements.
184 /// Each clause can appear 0 or 1 time.
185 class OIListElement
: public DirectiveElementBase
<DirectiveElement::OIList
> {
187 OIListElement(std::vector
<FormatElement
*> &&literalElements
,
188 std::vector
<std::vector
<FormatElement
*>> &&parsingElements
)
189 : literalElements(std::move(literalElements
)),
190 parsingElements(std::move(parsingElements
)) {}
192 /// Returns a range to iterate over the LiteralElements.
193 auto getLiteralElements() const {
194 function_ref
<LiteralElement
*(FormatElement
* el
)>
195 literalElementCastConverter
=
196 [](FormatElement
*el
) { return cast
<LiteralElement
>(el
); };
197 return llvm::map_range(literalElements
, literalElementCastConverter
);
200 /// Returns a range to iterate over the parsing elements corresponding to the
202 ArrayRef
<std::vector
<FormatElement
*>> getParsingElements() const {
203 return parsingElements
;
206 /// Returns a range to iterate over tuples of parsing and literal elements.
207 auto getClauses() const {
208 return llvm::zip(getLiteralElements(), getParsingElements());
211 /// If the parsing element is a single UnitAttr element, then it returns the
212 /// attribute variable. Otherwise, returns nullptr.
214 getUnitAttrParsingElement(ArrayRef
<FormatElement
*> pelement
) {
215 if (pelement
.size() == 1) {
216 auto *attrElem
= dyn_cast
<AttributeVariable
>(pelement
[0]);
217 if (attrElem
&& attrElem
->isUnitAttr())
224 /// A vector of `LiteralElement` objects. Each element stores the keyword
225 /// for one case of oilist element. For example, an oilist element along with
226 /// the `literalElements` vector:
228 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
229 /// literalElements = { `keyword`, `otherKeyword` }
231 std::vector
<FormatElement
*> literalElements
;
233 /// A vector of valid declarative assembly format vectors. Each object in
234 /// parsing elements is a vector of elements in assembly format syntax.
235 /// For example, an oilist element along with the parsingElements vector:
237 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
238 /// parsingElements = {
239 /// { `=`, `(`, $arg0, `)` },
240 /// { `<`, $arg1, `>` }
243 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
247 //===----------------------------------------------------------------------===//
249 //===----------------------------------------------------------------------===//
253 using ConstArgument
=
254 llvm::PointerUnion
<const NamedAttribute
*, const NamedTypeConstraint
*>;
256 struct OperationFormat
{
257 /// This class represents a specific resolver for an operand or result type.
258 class TypeResolution
{
260 TypeResolution() = default;
262 /// Get the index into the buildable types for this type, or std::nullopt.
263 std::optional
<int> getBuilderIdx() const { return builderIdx
; }
264 void setBuilderIdx(int idx
) { builderIdx
= idx
; }
266 /// Get the variable this type is resolved to, or nullptr.
267 const NamedTypeConstraint
*getVariable() const {
268 return resolver
.dyn_cast
<const NamedTypeConstraint
*>();
270 /// Get the attribute this type is resolved to, or nullptr.
271 const NamedAttribute
*getAttribute() const {
272 return resolver
.dyn_cast
<const NamedAttribute
*>();
274 /// Get the transformer for the type of the variable, or std::nullopt.
275 std::optional
<StringRef
> getVarTransformer() const {
276 return variableTransformer
;
278 void setResolver(ConstArgument arg
, std::optional
<StringRef
> transformer
) {
280 variableTransformer
= transformer
;
281 assert(getVariable() || getAttribute());
285 /// If the type is resolved with a buildable type, this is the index into
286 /// 'buildableTypes' in the parent format.
287 std::optional
<int> builderIdx
;
288 /// If the type is resolved based upon another operand or result, this is
289 /// the variable or the attribute that this type is resolved to.
290 ConstArgument resolver
;
291 /// If the type is resolved based upon another operand or result, this is
292 /// a transformer to apply to the variable when resolving.
293 std::optional
<StringRef
> variableTransformer
;
296 /// The context in which an element is generated.
297 enum class GenContext
{
298 /// The element is generated at the top-level or with the same behaviour.
300 /// The element is generated inside an optional group.
304 OperationFormat(const Operator
&op
)
305 : useProperties(op
.getDialect().usePropertiesForAttributes() &&
306 !op
.getAttributes().empty()),
307 opCppClassName(op
.getCppClassName()) {
308 operandTypes
.resize(op
.getNumOperands(), TypeResolution());
309 resultTypes
.resize(op
.getNumResults(), TypeResolution());
311 hasImplicitTermTrait
= llvm::any_of(op
.getTraits(), [](const Trait
&trait
) {
312 return trait
.getDef().isSubClassOf("SingleBlockImplicitTerminator");
315 hasSingleBlockTrait
=
316 hasImplicitTermTrait
|| op
.getTrait("::mlir::OpTrait::SingleBlock");
319 /// Generate the operation parser from this format.
320 void genParser(Operator
&op
, OpClass
&opClass
);
321 /// Generate the parser code for a specific format element.
322 void genElementParser(FormatElement
*element
, MethodBody
&body
,
323 FmtContext
&attrTypeCtx
,
324 GenContext genCtx
= GenContext::Normal
);
325 /// Generate the C++ to resolve the types of operands and results during
327 void genParserTypeResolution(Operator
&op
, MethodBody
&body
);
328 /// Generate the C++ to resolve the types of the operands during parsing.
329 void genParserOperandTypeResolution(
330 Operator
&op
, MethodBody
&body
,
331 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
);
332 /// Generate the C++ to resolve regions during parsing.
333 void genParserRegionResolution(Operator
&op
, MethodBody
&body
);
334 /// Generate the C++ to resolve successors during parsing.
335 void genParserSuccessorResolution(Operator
&op
, MethodBody
&body
);
336 /// Generate the C++ to handling variadic segment size traits.
337 void genParserVariadicSegmentResolution(Operator
&op
, MethodBody
&body
);
339 /// Generate the operation printer from this format.
340 void genPrinter(Operator
&op
, OpClass
&opClass
);
342 /// Generate the printer code for a specific format element.
343 void genElementPrinter(FormatElement
*element
, MethodBody
&body
, Operator
&op
,
344 bool &shouldEmitSpace
, bool &lastWasPunctuation
);
346 /// The various elements in this format.
347 std::vector
<FormatElement
*> elements
;
349 /// A flag indicating if all operand/result types were seen. If the format
350 /// contains these, it can not contain individual type resolvers.
351 bool allOperands
= false, allOperandTypes
= false, allResultTypes
= false;
353 /// A flag indicating if this operation infers its result types
354 bool infersResultTypes
= false;
356 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
358 bool hasImplicitTermTrait
;
360 /// A flag indicating if this operation has the SingleBlock trait.
361 bool hasSingleBlockTrait
;
363 /// Indicate whether attribute are stored in properties.
366 /// The Operation class name
367 StringRef opCppClassName
;
369 /// A map of buildable types to indices.
370 llvm::MapVector
<StringRef
, int, llvm::StringMap
<int>> buildableTypes
;
372 /// The index of the buildable type, if valid, for every operand and result.
373 std::vector
<TypeResolution
> operandTypes
, resultTypes
;
375 /// The set of attributes explicitly used within the format.
376 SmallVector
<const NamedAttribute
*, 8> usedAttributes
;
377 llvm::StringSet
<> inferredAttributes
;
381 //===----------------------------------------------------------------------===//
384 /// Returns true if we can format the given attribute as an EnumAttr in the
386 static bool canFormatEnumAttr(const NamedAttribute
*attr
) {
387 Attribute baseAttr
= attr
->attr
.getBaseAttr();
388 const EnumAttr
*enumAttr
= dyn_cast
<EnumAttr
>(&baseAttr
);
392 // The attribute must have a valid underlying type and a constant builder.
393 return !enumAttr
->getUnderlyingType().empty() &&
394 !enumAttr
->getConstBuilderTemplate().empty();
397 /// Returns if we should format the given attribute as an SymbolNameAttr.
398 static bool shouldFormatSymbolNameAttr(const NamedAttribute
*attr
) {
399 return attr
->attr
.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
  if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
    return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for an attribute.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const genericAttrParserCode = R"(
  if (parser.parseAttribute({0}Attr, {1}))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for an optional attribute.
/// The snippet ends with an `if` that has no body; presumably the caller
/// appends the statement to run when the attribute was present.
///
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
const char *const optionalAttrParserCode = R"(
  ::mlir::OptionalParseResult parseResult{0}Attr =
    parser.parseOptionalAttribute({0}Attr, {1});
  if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
    return ::mlir::failure();
  if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
)";

/// The code snippet used to generate a parser call for a symbol name attribute.
///
/// {0}: The name of the attribute.
const char *const symbolNameAttrParserCode = R"(
  if (parser.parseSymbolName({0}Attr))
    return ::mlir::failure();
)";
const char *const optionalSymbolNameAttrParserCode = R"(
  // Parsing an optional symbol name doesn't fail, so no need to check the
  // result.
  (void)parser.parseOptionalSymbolName({0}Attr);
)";
/// The code snippet used to generate a parser call for an enum attribute.
///
/// {0}: The name of the attribute.
/// {1}: The c++ namespace for the enum symbolize functions.
/// {2}: The function to symbolize a string of the enum.
/// {3}: The constant builder call to create an attribute of the enum type.
/// {4}: The set of allowed enum keywords.
/// {5}: The error message on failure when the enum isn't present.
/// {6}: The attribute assignment expression
///
/// Note: a `{` that is followed by another `{` before the next `}` is emitted
/// literally by formatv; the one `{{` below escapes a brace that would
/// otherwise be paired with the following `}`.
const char *const enumAttrParserCode = R"(
  {
    ::llvm::StringRef attrStr;
    ::mlir::NamedAttrList attrStorage;
    auto loc = parser.getCurrentLocation();
    if (parser.parseOptionalKeyword(&attrStr, {4})) {
      ::mlir::StringAttr attrVal;
      ::mlir::OptionalParseResult parseResult =
        parser.parseOptionalAttribute(attrVal,
                                      parser.getBuilder().getNoneType(),
                                      "{0}", attrStorage);
      if (parseResult.has_value()) {{
        if (failed(*parseResult))
          return ::mlir::failure();
        attrStr = attrVal.getValue();
      } else {
        {5}
      }
    }
    if (!attrStr.empty()) {
      auto attrOptional = {1}::{2}(attrStr);
      if (!attrOptional)
        return parser.emitError(loc, "invalid ")
               << "{0} attribute specification: \"" << attrStr << '"';;

      {0}Attr = {3};
      {6}
    }
  }
)";
/// The code snippet used to generate a parser call for an operand.
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperandList({0}Operands))
    return ::mlir::failure();
)";
const char *const optionalOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    ::mlir::OpAsmParser::UnresolvedOperand operand;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalOperand(operand);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Operands.push_back(operand);
    }
  }
)";
const char *const operandParserCode = R"(
  {0}OperandsLoc = parser.getCurrentLocation();
  if (parser.parseOperand({0}RawOperands[0]))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
const char *const variadicOfVariadicOperandParserCode = R"(
  {
    {0}OperandsLoc = parser.getCurrentLocation();
    int32_t curSize = 0;
    do {
      if (parser.parseOptionalLParen())
        break;
      if (parser.parseOperandList({0}Operands) || parser.parseRParen())
        return ::mlir::failure();
      {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
      curSize = {0}Operands.size();
    } while (succeeded(parser.parseOptionalComma()));
  }
)";

/// The code snippet used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
  do {
    if (parser.parseOptionalLParen())
      break;
    if (parser.parseOptionalRParen() &&
        (parser.parseTypeList({0}Types) || parser.parseRParen()))
      return ::mlir::failure();
  } while (succeeded(parser.parseOptionalComma()));
)";
const char *const variadicTypeParserCode = R"(
  if (parser.parseTypeList({0}Types))
    return ::mlir::failure();
)";
const char *const optionalTypeParserCode = R"(
  {
    ::mlir::Type optionalType;
    ::mlir::OptionalParseResult parseResult =
                                    parser.parseOptionalType(optionalType);
    if (parseResult.has_value()) {
      if (failed(*parseResult))
        return ::mlir::failure();
      {0}Types.push_back(optionalType);
    }
  }
)";
/// {0}: The C++ type of the parsed value (used to declare the local `type`).
/// {1}: The name for the type list.
const char *const typeParserCode = R"(
  {
    {0} type;
    if (parser.parseCustomTypeWithFallback(type))
      return ::mlir::failure();
    {1}RawTypes[0] = type;
  }
)";
const char *const qualifiedTypeParserCode = R"(
  if (parser.parseType({1}RawTypes[0]))
    return ::mlir::failure();
)";

/// The code snippet used to generate a parser call for a functional type.
///
/// {0}: The name for the input type list.
/// {1}: The name for the result type list.
const char *const functionalTypeParserCode = R"(
  ::mlir::FunctionType {0}__{1}_functionType;
  if (parser.parseType({0}__{1}_functionType))
    return ::mlir::failure();
  {0}Types = {0}__{1}_functionType.getInputs();
  {1}Types = {0}__{1}_functionType.getResults();
)";
/// The code snippet used to generate a parser call to infer return types.
///
/// {0}: The operation class name
const char *const inferReturnTypesParserCode = R"(
  ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
  if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
      result.location, result.operands,
      result.attributes.getDictionary(parser.getContext()),
      result.getRawProperties(),
      result.regions, inferredReturnTypes)))
    return ::mlir::failure();
  result.addTypes(inferredReturnTypes);
)";
/// The code snippet used to generate a parser call for a region list.
///
/// {0}: The name for the region list.
const char *const regionListParserCode = R"(
  {
    std::unique_ptr<::mlir::Region> region;
    auto firstRegionResult = parser.parseOptionalRegion(region);
    if (firstRegionResult.has_value()) {
      if (failed(*firstRegionResult))
        return ::mlir::failure();
      {0}Regions.emplace_back(std::move(region));

      // Parse any trailing regions.
      while (succeeded(parser.parseOptionalComma())) {
        region = std::make_unique<::mlir::Region>();
        if (parser.parseRegion(*region))
          return ::mlir::failure();
        {0}Regions.emplace_back(std::move(region));
      }
    }
  }
)";

/// The code snippet used to ensure a list of regions have terminators.
///
/// {0}: The name of the region list.
const char *const regionListEnsureTerminatorParserCode = R"(
  for (auto &region : {0}Regions)
    ensureTerminator(*region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a list of regions have a block.
///
/// {0}: The name of the region list.
const char *const regionListEnsureSingleBlockParserCode = R"(
  for (auto &region : {0}Regions)
    if (region->empty()) region->emplaceBlock();
)";

/// The code snippet used to generate a parser call for an optional region.
///
/// {0}: The name of the region.
const char *const optionalRegionParserCode = R"(
  {
    auto parseResult = parser.parseOptionalRegion(*{0}Region);
    if (parseResult.has_value() && failed(*parseResult))
      return ::mlir::failure();
  }
)";

/// The code snippet used to generate a parser call for a region.
///
/// {0}: The name of the region.
const char *const regionParserCode = R"(
  if (parser.parseRegion(*{0}Region))
    return ::mlir::failure();
)";

/// The code snippet used to ensure a region has a terminator.
///
/// {0}: The name of the region.
const char *const regionEnsureTerminatorParserCode = R"(
  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
)";

/// The code snippet used to ensure a region has a block.
///
/// {0}: The name of the region.
const char *const regionEnsureSingleBlockParserCode = R"(
  if ({0}Region->empty()) {0}Region->emplaceBlock();
)";
/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
const char *const successorListParserCode = R"(
  {
    ::mlir::Block *succ;
    auto firstSucc = parser.parseOptionalSuccessor(succ);
    if (firstSucc.has_value()) {
      if (failed(*firstSucc))
        return ::mlir::failure();
      {0}Successors.emplace_back(succ);

      // Parse any trailing successors.
      while (succeeded(parser.parseOptionalComma())) {
        if (parser.parseSuccessor(succ))
          return ::mlir::failure();
        {0}Successors.emplace_back(succ);
      }
    }
  }
)";

/// The code snippet used to generate a parser call for a successor.
///
/// {0}: The name of the successor.
const char *const successorParserCode = R"(
  if (parser.parseSuccessor({0}Successor))
    return ::mlir::failure();
)";
/// The code snippet used to generate a parser for OIList
///
/// {0}: literal keyword corresponding to a case for oilist
///
/// Emits an error if the clause's `{0}Clause` flag is already set, then sets
/// it, so each clause is accepted at most once.
const char *const oilistParserCode = R"(
  if ({0}Clause) {
    return parser.emitError(parser.getNameLoc())
          << "`{0}` clause can appear at most once in the expansion of the "
             "oilist directive";
  }
  {0}Clause = true;
)";
/// The type of length for a given parse argument.
enum class ArgumentLengthKind {
  /// The argument is a variadic of a variadic, and may contain 0->N range
  /// elements.
  VariadicOfVariadic,
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
724 /// Get the length kind for the given constraint.
725 static ArgumentLengthKind
726 getArgumentLengthKind(const NamedTypeConstraint *var) {
727 if (var->isOptional())
728 return ArgumentLengthKind::Optional;
729 if (var->isVariadicOfVariadic())
730 return ArgumentLengthKind::VariadicOfVariadic;
731 if (var->isVariadic())
732 return ArgumentLengthKind::Variadic;
733 return ArgumentLengthKind::Single;
736 /// Get the name used for the type list for the given type directive operand.
737 /// 'lengthKind
' to the corresponding kind for the given argument.
738 static StringRef getTypeListName(FormatElement *arg,
739 ArgumentLengthKind &lengthKind) {
740 if (auto *operand = dyn_cast<OperandVariable>(arg)) {
741 lengthKind = getArgumentLengthKind(operand->getVar());
742 return operand->getVar()->name;
744 if (auto *result = dyn_cast<ResultVariable>(arg)) {
745 lengthKind = getArgumentLengthKind(result->getVar());
746 return result->getVar()->name;
748 lengthKind = ArgumentLengthKind::Variadic;
749 if (isa<OperandsDirective>(arg))
751 if (isa<ResultsDirective>(arg))
753 llvm_unreachable("unknown 'type
' directive argument");
756 /// Generate the parser for a literal value.
757 static void genLiteralParser(StringRef value, MethodBody &body) {
758 // Handle the case of a keyword/identifier.
759 if (value.front() == '_
' || isalpha(value.front())) {
760 body << "Keyword(\"" << value << "\")";
763 body << (StringRef)StringSwitch<StringRef>(value)
764 .Case("->", "Arrow()")
765 .Case(":", "Colon()")
766 .Case(",", "Comma()")
767 .Case("=", "Equal()")
769 .Case(">", "Greater()")
770 .Case("{", "LBrace()")
771 .Case("}", "RBrace()")
772 .Case("(", "LParen()")
773 .Case(")", "RParen()")
774 .Case("[", "LSquare()")
775 .Case("]", "RSquare()")
776 .Case("?", "Question()")
779 .Case("...", "Ellipsis()");
782 /// Generate the storage code required for parsing the given element.
783 static void genElementParserStorage(FormatElement *element, const Operator &op,
785 if (auto *optional = dyn_cast<OptionalElement>(element)) {
786 ArrayRef<FormatElement *> elements = optional->getThenElements();
788 // If the anchor is a unit attribute, it won't be parsed directly so elide
790 auto *anchor
= dyn_cast
<AttributeVariable
>(optional
->getAnchor());
791 FormatElement
*elidedAnchorElement
= nullptr;
792 if (anchor
&& anchor
!= elements
.front() && anchor
->isUnitAttr())
793 elidedAnchorElement
= anchor
;
794 for (FormatElement
*childElement
: elements
)
795 if (childElement
!= elidedAnchorElement
)
796 genElementParserStorage(childElement
, op
, body
);
797 for (FormatElement
*childElement
: optional
->getElseElements())
798 genElementParserStorage(childElement
, op
, body
);
800 } else if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
801 for (ArrayRef
<FormatElement
*> pelement
: oilist
->getParsingElements()) {
802 if (!oilist
->getUnitAttrParsingElement(pelement
))
803 for (FormatElement
*element
: pelement
)
804 genElementParserStorage(element
, op
, body
);
807 } else if (auto *custom
= dyn_cast
<CustomDirective
>(element
)) {
808 for (FormatElement
*paramElement
: custom
->getArguments())
809 genElementParserStorage(paramElement
, op
, body
);
811 } else if (isa
<OperandsDirective
>(element
)) {
812 body
<< " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
815 } else if (isa
<RegionsDirective
>(element
)) {
816 body
<< " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
819 } else if (isa
<SuccessorsDirective
>(element
)) {
820 body
<< " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
822 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
823 const NamedAttribute
*var
= attr
->getVar();
824 body
<< llvm::formatv(" {0} {1}Attr;\n", var
->attr
.getStorageType(),
827 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
828 StringRef name
= operand
->getVar()->name
;
829 if (operand
->getVar()->isVariableLength()) {
831 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
832 << name
<< "Operands;\n";
833 if (operand
->getVar()->isVariadicOfVariadic()) {
834 body
<< " llvm::SmallVector<int32_t> " << name
835 << "OperandGroupSizes;\n";
838 body
<< " ::mlir::OpAsmParser::UnresolvedOperand " << name
839 << "RawOperands[1];\n"
840 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
841 << name
<< "Operands(" << name
<< "RawOperands);";
843 body
<< llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
844 " (void){0}OperandsLoc;\n",
847 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
848 StringRef name
= region
->getVar()->name
;
849 if (region
->getVar()->isVariadic()) {
850 body
<< llvm::formatv(
851 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
855 body
<< llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
856 "std::make_unique<::mlir::Region>();\n",
860 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
861 StringRef name
= successor
->getVar()->name
;
862 if (successor
->getVar()->isVariadic()) {
863 body
<< llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
867 body
<< llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name
);
870 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
871 ArgumentLengthKind lengthKind
;
872 StringRef name
= getTypeListName(dir
->getArg(), lengthKind
);
873 if (lengthKind
!= ArgumentLengthKind::Single
)
874 body
<< " ::llvm::SmallVector<::mlir::Type, 1> " << name
<< "Types;\n";
876 body
<< llvm::formatv(" ::mlir::Type {0}RawTypes[1];\n", name
)
878 " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n",
880 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
881 ArgumentLengthKind ignored
;
882 body
<< " ::llvm::ArrayRef<::mlir::Type> "
883 << getTypeListName(dir
->getInputs(), ignored
) << "Types;\n";
884 body
<< " ::llvm::ArrayRef<::mlir::Type> "
885 << getTypeListName(dir
->getResults(), ignored
) << "Types;\n";
889 /// Generate the parser for a parameter to a custom directive.
890 static void genCustomParameterParser(FormatElement
*param
, MethodBody
&body
) {
891 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
892 body
<< attr
->getVar()->name
<< "Attr";
893 } else if (isa
<AttrDictDirective
>(param
)) {
894 body
<< "result.attributes";
895 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
896 StringRef name
= operand
->getVar()->name
;
897 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
898 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
899 body
<< llvm::formatv("{0}OperandGroups", name
);
900 else if (lengthKind
== ArgumentLengthKind::Variadic
)
901 body
<< llvm::formatv("{0}Operands", name
);
902 else if (lengthKind
== ArgumentLengthKind::Optional
)
903 body
<< llvm::formatv("{0}Operand", name
);
905 body
<< formatv("{0}RawOperands[0]", name
);
907 } else if (auto *region
= dyn_cast
<RegionVariable
>(param
)) {
908 StringRef name
= region
->getVar()->name
;
909 if (region
->getVar()->isVariadic())
910 body
<< llvm::formatv("{0}Regions", name
);
912 body
<< llvm::formatv("*{0}Region", name
);
914 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(param
)) {
915 StringRef name
= successor
->getVar()->name
;
916 if (successor
->getVar()->isVariadic())
917 body
<< llvm::formatv("{0}Successors", name
);
919 body
<< llvm::formatv("{0}Successor", name
);
921 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
922 genCustomParameterParser(dir
->getArg(), body
);
924 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
925 ArgumentLengthKind lengthKind
;
926 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
927 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
928 body
<< llvm::formatv("{0}TypeGroups", listName
);
929 else if (lengthKind
== ArgumentLengthKind::Variadic
)
930 body
<< llvm::formatv("{0}Types", listName
);
931 else if (lengthKind
== ArgumentLengthKind::Optional
)
932 body
<< llvm::formatv("{0}Type", listName
);
934 body
<< formatv("{0}RawTypes[0]", listName
);
936 } else if (auto *string
= dyn_cast
<StringElement
>(param
)) {
938 ctx
.withBuilder("parser.getBuilder()");
939 ctx
.addSubst("_ctxt", "parser.getContext()");
940 body
<< tgfmt(string
->getValue(), &ctx
);
943 llvm_unreachable("unknown custom directive parameter");
947 /// Generate the parser for a custom directive.
948 static void genCustomDirectiveParser(CustomDirective
*dir
, MethodBody
&body
,
950 StringRef opCppClassName
) {
953 // Preprocess the directive variables.
954 // * Add a local variable for optional operands and types. This provides a
955 // better API to the user defined parser methods.
956 // * Set the location of operand variables.
957 for (FormatElement
*param
: dir
->getArguments()) {
958 if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
959 auto *var
= operand
->getVar();
960 body
<< " " << var
->name
961 << "OperandsLoc = parser.getCurrentLocation();\n";
962 if (var
->isOptional()) {
963 body
<< llvm::formatv(
964 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
967 } else if (var
->isVariadicOfVariadic()) {
968 body
<< llvm::formatv(" "
969 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
970 "OpAsmParser::UnresolvedOperand>> "
971 "{0}OperandGroups;\n",
974 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
975 ArgumentLengthKind lengthKind
;
976 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
977 if (lengthKind
== ArgumentLengthKind::Optional
) {
978 body
<< llvm::formatv(" ::mlir::Type {0}Type;\n", listName
);
979 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
980 body
<< llvm::formatv(
981 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
985 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
986 FormatElement
*input
= dir
->getArg();
987 if (auto *operand
= dyn_cast
<OperandVariable
>(input
)) {
988 if (!operand
->getVar()->isOptional())
990 body
<< llvm::formatv(
991 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
993 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
994 operand
->getVar()->name
);
996 } else if (auto *type
= dyn_cast
<TypeDirective
>(input
)) {
997 ArgumentLengthKind lengthKind
;
998 StringRef listName
= getTypeListName(type
->getArg(), lengthKind
);
999 if (lengthKind
== ArgumentLengthKind::Optional
) {
1000 body
<< llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
1001 "::mlir::Type() : {0}Types[0];\n",
1008 body
<< " if (parse" << dir
->getName() << "(parser";
1009 for (FormatElement
*param
: dir
->getArguments()) {
1011 genCustomParameterParser(param
, body
);
1015 << " return ::mlir::failure();\n";
1017 // After parsing, add handling for any of the optional constructs.
1018 for (FormatElement
*param
: dir
->getArguments()) {
1019 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
1020 const NamedAttribute
*var
= attr
->getVar();
1021 if (var
->attr
.isOptional() || var
->attr
.hasDefaultValue())
1022 body
<< llvm::formatv(" if ({0}Attr)\n ", var
->name
);
1023 if (useProperties
) {
1025 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1026 var
->name
, opCppClassName
);
1028 body
<< llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1032 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
1033 const NamedTypeConstraint
*var
= operand
->getVar();
1034 if (var
->isOptional()) {
1035 body
<< llvm::formatv(" if ({0}Operand.has_value())\n"
1036 " {0}Operands.push_back(*{0}Operand);\n",
1038 } else if (var
->isVariadicOfVariadic()) {
1039 body
<< llvm::formatv(
1040 " for (const auto &subRange : {0}OperandGroups) {{\n"
1041 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1042 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1044 var
->name
, var
->constraint
.getVariadicOfVariadicSegmentSizeAttr());
1046 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
1047 ArgumentLengthKind lengthKind
;
1048 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1049 if (lengthKind
== ArgumentLengthKind::Optional
) {
1050 body
<< llvm::formatv(" if ({0}Type)\n"
1051 " {0}Types.push_back({0}Type);\n",
1053 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1054 body
<< llvm::formatv(
1055 " for (const auto &subRange : {0}TypeGroups)\n"
1056 " {0}Types.append(subRange.begin(), subRange.end());\n",
1065 /// Generate the parser for a enum attribute.
1066 static void genEnumAttrParser(const NamedAttribute
*var
, MethodBody
&body
,
1067 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1068 bool useProperties
, StringRef opCppClassName
) {
1069 Attribute baseAttr
= var
->attr
.getBaseAttr();
1070 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
1071 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
1073 // Generate the code for building an attribute for this enum.
1074 std::string attrBuilderStr
;
1076 llvm::raw_string_ostream
os(attrBuilderStr
);
1077 os
<< tgfmt(enumAttr
.getConstBuilderTemplate(), &attrTypeCtx
,
1081 // Build a string containing the cases that can be formatted as a keyword.
1082 std::string validCaseKeywordsStr
= "{";
1083 llvm::raw_string_ostream
validCaseKeywordsOS(validCaseKeywordsStr
);
1084 for (const EnumAttrCase
&attrCase
: cases
)
1085 if (canFormatStringAsKeyword(attrCase
.getStr()))
1086 validCaseKeywordsOS
<< '"' << attrCase
.getStr() << "\",";
1087 validCaseKeywordsOS
.str().back() = '}';
1089 // If the attribute is not optional, build an error message for the missing
1091 std::string errorMessage
;
1092 if (!parseAsOptional
) {
1093 llvm::raw_string_ostream
errorMessageOS(errorMessage
);
1095 << "return parser.emitError(loc, \"expected string or "
1096 "keyword containing one of the following enum values for attribute '"
1097 << var
->name
<< "' [";
1098 llvm::interleaveComma(cases
, errorMessageOS
, [&](const auto &attrCase
) {
1099 errorMessageOS
<< attrCase
.getStr();
1101 errorMessageOS
<< "]\");";
1103 std::string attrAssignment
;
1104 if (useProperties
) {
1107 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1108 var
->name
, opCppClassName
);
1111 formatv("result.addAttribute(\"{0}\", {0}Attr);", var
->name
);
1114 body
<< formatv(enumAttrParserCode
, var
->name
, enumAttr
.getCppNamespace(),
1115 enumAttr
.getStringToSymbolFnName(), attrBuilderStr
,
1116 validCaseKeywordsStr
, errorMessage
, attrAssignment
);
1119 // Generate the parser for an attribute.
1120 static void genAttrParser(AttributeVariable
*attr
, MethodBody
&body
,
1121 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1122 bool useProperties
, StringRef opCppClassName
) {
1123 const NamedAttribute
*var
= attr
->getVar();
1125 // Check to see if we can parse this as an enum attribute.
1126 if (canFormatEnumAttr(var
))
1127 return genEnumAttrParser(var
, body
, attrTypeCtx
, parseAsOptional
,
1128 useProperties
, opCppClassName
);
1130 // Check to see if we should parse this as a symbol name attribute.
1131 if (shouldFormatSymbolNameAttr(var
)) {
1132 body
<< formatv(parseAsOptional
? optionalSymbolNameAttrParserCode
1133 : symbolNameAttrParserCode
,
1137 // If this attribute has a buildable type, use that when parsing the
1139 std::string attrTypeStr
;
1140 if (std::optional
<StringRef
> typeBuilder
= attr
->getTypeBuilder()) {
1141 llvm::raw_string_ostream
os(attrTypeStr
);
1142 os
<< tgfmt(*typeBuilder
, &attrTypeCtx
);
1144 attrTypeStr
= "::mlir::Type{}";
1146 if (parseAsOptional
) {
1147 body
<< formatv(optionalAttrParserCode
, var
->name
, attrTypeStr
);
1149 if (attr
->shouldBeQualified() ||
1150 var
->attr
.getStorageType() == "::mlir::Attribute")
1151 body
<< formatv(genericAttrParserCode
, var
->name
, attrTypeStr
);
1153 body
<< formatv(attrParserCode
, var
->name
, attrTypeStr
);
1156 if (useProperties
) {
1158 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1160 var
->name
, opCppClassName
);
1163 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1168 void OperationFormat::genParser(Operator
&op
, OpClass
&opClass
) {
1169 SmallVector
<MethodParameter
> paramList
;
1170 paramList
.emplace_back("::mlir::OpAsmParser &", "parser");
1171 paramList
.emplace_back("::mlir::OperationState &", "result");
1173 auto *method
= opClass
.addStaticMethod("::mlir::ParseResult", "parse",
1174 std::move(paramList
));
1175 auto &body
= method
->body();
1177 // Generate variables to store the operands and type within the format. This
1178 // allows for referencing these variables in the presence of optional
1180 for (FormatElement
*element
: elements
)
1181 genElementParserStorage(element
, op
, body
);
1183 // A format context used when parsing attributes with buildable types.
1184 FmtContext attrTypeCtx
;
1185 attrTypeCtx
.withBuilder("parser.getBuilder()");
1187 // Generate parsers for each of the elements.
1188 for (FormatElement
*element
: elements
)
1189 genElementParser(element
, body
, attrTypeCtx
);
1191 // Generate the code to resolve the operand/result types and successors now
1192 // that they have been parsed.
1193 genParserRegionResolution(op
, body
);
1194 genParserSuccessorResolution(op
, body
);
1195 genParserVariadicSegmentResolution(op
, body
);
1196 genParserTypeResolution(op
, body
);
1198 body
<< " return ::mlir::success();\n";
1201 void OperationFormat::genElementParser(FormatElement
*element
, MethodBody
&body
,
1202 FmtContext
&attrTypeCtx
,
1203 GenContext genCtx
) {
1205 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
1206 auto genElementParsers
= [&](FormatElement
*firstElement
,
1207 ArrayRef
<FormatElement
*> elements
,
1209 // If the anchor is a unit attribute, we don't need to print it. When
1210 // parsing, we will add this attribute if this group is present.
1211 FormatElement
*elidedAnchorElement
= nullptr;
1212 auto *anchorAttr
= dyn_cast
<AttributeVariable
>(optional
->getAnchor());
1213 if (anchorAttr
&& anchorAttr
!= firstElement
&&
1214 anchorAttr
->isUnitAttr()) {
1215 elidedAnchorElement
= anchorAttr
;
1217 if (!thenGroup
== optional
->isInverted()) {
1218 // Add the anchor unit attribute to the operation state.
1219 if (useProperties
) {
1221 " result.getOrAddProperties<{1}::Properties>().{0} = "
1222 "parser.getBuilder().getUnitAttr();",
1223 anchorAttr
->getVar()->name
, opCppClassName
);
1225 body
<< " result.addAttribute(\"" << anchorAttr
->getVar()->name
1226 << "\", parser.getBuilder().getUnitAttr());\n";
1231 // Generate the rest of the elements inside an optional group. Elements in
1232 // an optional group after the guard are parsed as required.
1233 for (FormatElement
*childElement
: elements
)
1234 if (childElement
!= elidedAnchorElement
)
1235 genElementParser(childElement
, body
, attrTypeCtx
,
1236 GenContext::Optional
);
1239 ArrayRef
<FormatElement
*> thenElements
=
1240 optional
->getThenElements(/*parseable=*/true);
1242 // Generate a special optional parser for the first element to gate the
1243 // parsing of the rest of the elements.
1244 FormatElement
*firstElement
= thenElements
.front();
1245 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(firstElement
)) {
1246 genAttrParser(attrVar
, body
, attrTypeCtx
, /*parseAsOptional=*/true,
1247 useProperties
, opCppClassName
);
1248 body
<< " if (" << attrVar
->getVar()->name
<< "Attr) {\n";
1249 } else if (auto *literal
= dyn_cast
<LiteralElement
>(firstElement
)) {
1250 body
<< " if (::mlir::succeeded(parser.parseOptional";
1251 genLiteralParser(literal
->getSpelling(), body
);
1253 } else if (auto *opVar
= dyn_cast
<OperandVariable
>(firstElement
)) {
1254 genElementParser(opVar
, body
, attrTypeCtx
);
1255 body
<< " if (!" << opVar
->getVar()->name
<< "Operands.empty()) {\n";
1256 } else if (auto *regionVar
= dyn_cast
<RegionVariable
>(firstElement
)) {
1257 const NamedRegion
*region
= regionVar
->getVar();
1258 if (region
->isVariadic()) {
1259 genElementParser(regionVar
, body
, attrTypeCtx
);
1260 body
<< " if (!" << region
->name
<< "Regions.empty()) {\n";
1262 body
<< llvm::formatv(optionalRegionParserCode
, region
->name
);
1263 body
<< " if (!" << region
->name
<< "Region->empty()) {\n ";
1264 if (hasImplicitTermTrait
)
1265 body
<< llvm::formatv(regionEnsureTerminatorParserCode
, region
->name
);
1266 else if (hasSingleBlockTrait
)
1267 body
<< llvm::formatv(regionEnsureSingleBlockParserCode
,
1272 genElementParsers(firstElement
, thenElements
.drop_front(),
1273 /*thenGroup=*/true);
1276 // Generate the else elements.
1277 auto elseElements
= optional
->getElseElements();
1278 if (!elseElements
.empty()) {
1279 body
<< " else {\n";
1280 ArrayRef
<FormatElement
*> elseElements
=
1281 optional
->getElseElements(/*parseable=*/true);
1282 genElementParsers(elseElements
.front(), elseElements
,
1283 /*thenGroup=*/false);
1288 /// OIList Directive
1289 } else if (OIListElement
*oilist
= dyn_cast
<OIListElement
>(element
)) {
1290 for (LiteralElement
*le
: oilist
->getLiteralElements())
1291 body
<< " bool " << le
->getSpelling() << "Clause = false;\n";
1293 // Generate the parsing loop
1294 body
<< " while(true) {\n";
1295 for (auto clause
: oilist
->getClauses()) {
1296 LiteralElement
*lelement
= std::get
<0>(clause
);
1297 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
1298 body
<< "if (succeeded(parser.parseOptional";
1299 genLiteralParser(lelement
->getSpelling(), body
);
1301 StringRef lelementName
= lelement
->getSpelling();
1302 body
<< formatv(oilistParserCode
, lelementName
);
1303 if (AttributeVariable
*unitAttrElem
=
1304 oilist
->getUnitAttrParsingElement(pelement
)) {
1305 if (useProperties
) {
1307 " result.getOrAddProperties<{1}::Properties>().{0} = "
1308 "parser.getBuilder().getUnitAttr();",
1309 unitAttrElem
->getVar()->name
, opCppClassName
);
1311 body
<< " result.addAttribute(\"" << unitAttrElem
->getVar()->name
1312 << "\", UnitAttr::get(parser.getContext()));\n";
1315 for (FormatElement
*el
: pelement
)
1316 genElementParser(el
, body
, attrTypeCtx
);
1321 body
<< " break;\n";
1326 } else if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
)) {
1327 body
<< " if (parser.parse";
1328 genLiteralParser(literal
->getSpelling(), body
);
1329 body
<< ")\n return ::mlir::failure();\n";
1332 } else if (isa
<WhitespaceElement
>(element
)) {
1333 // Nothing to parse.
1336 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
1337 bool parseAsOptional
=
1338 (genCtx
== GenContext::Normal
&& attr
->getVar()->attr
.isOptional());
1339 genAttrParser(attr
, body
, attrTypeCtx
, parseAsOptional
, useProperties
,
1342 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
1343 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
1344 StringRef name
= operand
->getVar()->name
;
1345 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
1346 body
<< llvm::formatv(
1347 variadicOfVariadicOperandParserCode
, name
,
1348 operand
->getVar()->constraint
.getVariadicOfVariadicSegmentSizeAttr());
1349 else if (lengthKind
== ArgumentLengthKind::Variadic
)
1350 body
<< llvm::formatv(variadicOperandParserCode
, name
);
1351 else if (lengthKind
== ArgumentLengthKind::Optional
)
1352 body
<< llvm::formatv(optionalOperandParserCode
, name
);
1354 body
<< formatv(operandParserCode
, name
);
1356 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
1357 bool isVariadic
= region
->getVar()->isVariadic();
1358 body
<< llvm::formatv(isVariadic
? regionListParserCode
: regionParserCode
,
1359 region
->getVar()->name
);
1360 if (hasImplicitTermTrait
)
1361 body
<< llvm::formatv(isVariadic
? regionListEnsureTerminatorParserCode
1362 : regionEnsureTerminatorParserCode
,
1363 region
->getVar()->name
);
1364 else if (hasSingleBlockTrait
)
1365 body
<< llvm::formatv(isVariadic
? regionListEnsureSingleBlockParserCode
1366 : regionEnsureSingleBlockParserCode
,
1367 region
->getVar()->name
);
1369 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
1370 bool isVariadic
= successor
->getVar()->isVariadic();
1371 body
<< formatv(isVariadic
? successorListParserCode
: successorParserCode
,
1372 successor
->getVar()->name
);
1375 } else if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
1376 body
.indent() << "{\n";
1377 body
.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1378 << "if (parser.parseOptionalAttrDict"
1379 << (attrDict
->isWithKeyword() ? "WithKeyword" : "")
1380 << "(result.attributes))\n"
1381 << " return ::mlir::failure();\n";
1382 if (useProperties
) {
1383 body
<< "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1385 << " return parser.emitError(loc) << \"'\" << "
1386 "result.name.getStringRef() << \"' op \";\n"
1388 << " return ::mlir::failure();\n";
1390 body
.unindent() << "}\n";
1392 } else if (auto *attrDict
= dyn_cast
<PropDictDirective
>(element
)) {
1393 body
<< " if (parseProperties(parser, result))\n"
1394 << " return ::mlir::failure();\n";
1395 } else if (auto *customDir
= dyn_cast
<CustomDirective
>(element
)) {
1396 genCustomDirectiveParser(customDir
, body
, useProperties
, opCppClassName
);
1397 } else if (isa
<OperandsDirective
>(element
)) {
1398 body
<< " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
1399 << " if (parser.parseOperandList(allOperands))\n"
1400 << " return ::mlir::failure();\n";
1402 } else if (isa
<RegionsDirective
>(element
)) {
1403 body
<< llvm::formatv(regionListParserCode
, "full");
1404 if (hasImplicitTermTrait
)
1405 body
<< llvm::formatv(regionListEnsureTerminatorParserCode
, "full");
1406 else if (hasSingleBlockTrait
)
1407 body
<< llvm::formatv(regionListEnsureSingleBlockParserCode
, "full");
1409 } else if (isa
<SuccessorsDirective
>(element
)) {
1410 body
<< llvm::formatv(successorListParserCode
, "full");
1412 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
1413 ArgumentLengthKind lengthKind
;
1414 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1415 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1416 body
<< llvm::formatv(variadicOfVariadicTypeParserCode
, listName
);
1417 } else if (lengthKind
== ArgumentLengthKind::Variadic
) {
1418 body
<< llvm::formatv(variadicTypeParserCode
, listName
);
1419 } else if (lengthKind
== ArgumentLengthKind::Optional
) {
1420 body
<< llvm::formatv(optionalTypeParserCode
, listName
);
1422 const char *parserCode
=
1423 dir
->shouldBeQualified() ? qualifiedTypeParserCode
: typeParserCode
;
1424 TypeSwitch
<FormatElement
*>(dir
->getArg())
1425 .Case
<OperandVariable
, ResultVariable
>([&](auto operand
) {
1426 body
<< formatv(parserCode
,
1427 operand
->getVar()->constraint
.getCPPClassName(),
1430 .Default([&](auto operand
) {
1431 body
<< formatv(parserCode
, "::mlir::Type", listName
);
1434 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
1435 ArgumentLengthKind ignored
;
1436 body
<< formatv(functionalTypeParserCode
,
1437 getTypeListName(dir
->getInputs(), ignored
),
1438 getTypeListName(dir
->getResults(), ignored
));
1440 llvm_unreachable("unknown format element");
1444 void OperationFormat::genParserTypeResolution(Operator
&op
, MethodBody
&body
) {
1445 // If any of type resolutions use transformed variables, make sure that the
1446 // types of those variables are resolved.
1447 SmallPtrSet
<const NamedTypeConstraint
*, 8> verifiedVariables
;
1448 FmtContext verifierFCtx
;
1449 for (TypeResolution
&resolver
:
1450 llvm::concat
<TypeResolution
>(resultTypes
, operandTypes
)) {
1451 std::optional
<StringRef
> transformer
= resolver
.getVarTransformer();
1454 // Ensure that we don't verify the same variables twice.
1455 const NamedTypeConstraint
*variable
= resolver
.getVariable();
1456 if (!variable
|| !verifiedVariables
.insert(variable
).second
)
1459 auto constraint
= variable
->constraint
;
1460 body
<< " for (::mlir::Type type : " << variable
->name
<< "Types) {\n"
1463 << tgfmt(constraint
.getConditionTemplate(),
1464 &verifierFCtx
.withSelf("type"))
1466 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1467 "\"'{0}' must be {1}, but got \" << type;\n",
1468 variable
->name
, constraint
.getSummary())
1473 // Initialize the set of buildable types.
1474 if (!buildableTypes
.empty()) {
1475 FmtContext typeBuilderCtx
;
1476 typeBuilderCtx
.withBuilder("parser.getBuilder()");
1477 for (auto &it
: buildableTypes
)
1478 body
<< " ::mlir::Type odsBuildableType" << it
.second
<< " = "
1479 << tgfmt(it
.first
, &typeBuilderCtx
) << ";\n";
1482 // Emit the code necessary for a type resolver.
1483 auto emitTypeResolver
= [&](TypeResolution
&resolver
, StringRef curVar
) {
1484 if (std::optional
<int> val
= resolver
.getBuilderIdx()) {
1485 body
<< "odsBuildableType" << *val
;
1486 } else if (const NamedTypeConstraint
*var
= resolver
.getVariable()) {
1487 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer()) {
1488 FmtContext fmtContext
;
1489 fmtContext
.addSubst("_ctxt", "parser.getContext()");
1490 if (var
->isVariadic())
1491 fmtContext
.withSelf(var
->name
+ "Types");
1493 fmtContext
.withSelf(var
->name
+ "Types[0]");
1494 body
<< tgfmt(*tform
, &fmtContext
);
1496 body
<< var
->name
<< "Types";
1497 if (!var
->isVariadic())
1500 } else if (const NamedAttribute
*attr
= resolver
.getAttribute()) {
1501 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer())
1502 body
<< tgfmt(*tform
,
1503 &FmtContext().withSelf(attr
->name
+ "Attr.getType()"));
1505 body
<< attr
->name
<< "Attr.getType()";
1507 body
<< curVar
<< "Types";
1511 // Resolve each of the result types.
1512 if (!infersResultTypes
) {
1513 if (allResultTypes
) {
1514 body
<< " result.addTypes(allResultTypes);\n";
1516 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
1517 body
<< " result.addTypes(";
1518 emitTypeResolver(resultTypes
[i
], op
.getResultName(i
));
1524 // Emit the operand type resolutions.
1525 genParserOperandTypeResolution(op
, body
, emitTypeResolver
);
1527 // Handle return type inference once all operands have been resolved
1528 if (infersResultTypes
)
1529 body
<< formatv(inferReturnTypesParserCode
, op
.getCppClassName());
1532 void OperationFormat::genParserOperandTypeResolution(
1533 Operator
&op
, MethodBody
&body
,
1534 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
) {
1535 // Early exit if there are no operands.
1536 if (op
.getNumOperands() == 0)
1539 // Handle the case where all operand types are grouped together with
1540 // "types(operands)".
1541 if (allOperandTypes
) {
1542 // If `operands` was specified, use the full operand list directly.
1544 body
<< " if (parser.resolveOperands(allOperands, allOperandTypes, "
1545 "allOperandLoc, result.operands))\n"
1546 " return ::mlir::failure();\n";
1550 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1551 // llvm::concat does not allow the case of a single range, so guard it here.
1552 body
<< " if (parser.resolveOperands(";
1553 if (op
.getNumOperands() > 1) {
1554 body
<< "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1555 llvm::interleaveComma(op
.getOperands(), body
, [&](auto &operand
) {
1556 body
<< operand
.name
<< "Operands";
1560 body
<< op
.operand_begin()->name
<< "Operands";
1562 body
<< ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1563 << " return ::mlir::failure();\n";
1567 // Handle the case where all operands are grouped together with "operands".
1569 body
<< " if (parser.resolveOperands(allOperands, ";
1571 // Group all of the operand types together to perform the resolution all at
1572 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1573 // the case of a single range, so guard it here.
1574 if (op
.getNumOperands() > 1) {
1575 body
<< "::llvm::concat<const ::mlir::Type>(";
1576 llvm::interleaveComma(
1577 llvm::seq
<int>(0, op
.getNumOperands()), body
, [&](int i
) {
1578 body
<< "::llvm::ArrayRef<::mlir::Type>(";
1579 emitTypeResolver(operandTypes
[i
], op
.getOperand(i
).name
);
1584 emitTypeResolver(operandTypes
.front(), op
.getOperand(0).name
);
1587 body
<< ", allOperandLoc, result.operands))\n return "
1588 "::mlir::failure();\n";
1592 // The final case is the one where each of the operands types are resolved
1594 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
1595 NamedTypeConstraint
&operand
= op
.getOperand(i
);
1596 body
<< " if (parser.resolveOperands(" << operand
.name
<< "Operands, ";
1598 // Resolve the type of this operand.
1599 TypeResolution
&operandType
= operandTypes
[i
];
1600 emitTypeResolver(operandType
, operand
.name
);
1602 body
<< ", " << operand
.name
1603 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1607 void OperationFormat::genParserRegionResolution(Operator
&op
,
1609 // Check for the case where all regions were parsed.
1610 bool hasAllRegions
= llvm::any_of(
1611 elements
, [](FormatElement
*elt
) { return isa
<RegionsDirective
>(elt
); });
1612 if (hasAllRegions
) {
1613 body
<< " result.addRegions(fullRegions);\n";
1617 // Otherwise, handle each region individually.
1618 for (const NamedRegion
®ion
: op
.getRegions()) {
1619 if (region
.isVariadic())
1620 body
<< " result.addRegions(" << region
.name
<< "Regions);\n";
1622 body
<< " result.addRegion(std::move(" << region
.name
<< "Region));\n";
1626 void OperationFormat::genParserSuccessorResolution(Operator
&op
,
1628 // Check for the case where all successors were parsed.
1629 bool hasAllSuccessors
= llvm::any_of(elements
, [](FormatElement
*elt
) {
1630 return isa
<SuccessorsDirective
>(elt
);
1632 if (hasAllSuccessors
) {
1633 body
<< " result.addSuccessors(fullSuccessors);\n";
1637 // Otherwise, handle each successor individually.
1638 for (const NamedSuccessor
&successor
: op
.getSuccessors()) {
1639 if (successor
.isVariadic())
1640 body
<< " result.addSuccessors(" << successor
.name
<< "Successors);\n";
1642 body
<< " result.addSuccessors(" << successor
.name
<< "Successor);\n";
1646 void OperationFormat::genParserVariadicSegmentResolution(Operator
&op
,
1649 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1650 if (op
.getDialect().usePropertiesForAttributes()) {
1652 "result.getOrAddProperties<{0}::Properties>().operand_"
1654 "(parser.getBuilder().getDenseI32ArrayAttr({{",
1655 op
.getCppClassName());
1657 body
<< " result.addAttribute(\"operand_segment_sizes\", "
1658 << "parser.getBuilder().getDenseI32ArrayAttr({";
1660 auto interleaveFn
= [&](const NamedTypeConstraint
&operand
) {
1661 // If the operand is variadic emit the parsed size.
1662 if (operand
.isVariableLength())
1663 body
<< "static_cast<int32_t>(" << operand
.name
<< "Operands.size())";
1667 llvm::interleaveComma(op
.getOperands(), body
, interleaveFn
);
1670 for (const NamedTypeConstraint
&operand
: op
.getOperands()) {
1671 if (!operand
.isVariadicOfVariadic())
1673 if (op
.getDialect().usePropertiesForAttributes()) {
1674 body
<< llvm::formatv(
1675 " result.getOrAddProperties<{0}::Properties>().{1} = "
1676 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1677 op
.getCppClassName(),
1678 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1681 body
<< llvm::formatv(
1682 " result.addAttribute(\"{0}\", "
1683 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1685 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1691 if (!allResultTypes
&&
1692 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1693 if (op
.getDialect().usePropertiesForAttributes()) {
1696 "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = "
1697 "(parser.getBuilder().getDenseI32ArrayAttr({{",
1698 op
.getCppClassName());
1700 body
<< " result.addAttribute(\"result_segment_sizes\", "
1701 << "parser.getBuilder().getDenseI32ArrayAttr({";
1703 auto interleaveFn
= [&](const NamedTypeConstraint
&result
) {
1704 // If the result is variadic emit the parsed size.
1705 if (result
.isVariableLength())
1706 body
<< "static_cast<int32_t>(" << result
.name
<< "Types.size())";
1710 llvm::interleaveComma(op
.getResults(), body
, interleaveFn
);
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
/// The code snippet used to generate a printer call for a region of an
/// operation that has the SingleBlockImplicitTerminator trait.
///
/// The terminator of the region's first block is printed only if it has
/// attributes, operands or results; otherwise it is left implicit. The
/// snippet is a `formatv` pattern, so `{{` denotes a literal `{`.
///
/// {0}: The name of the region.
const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
  {
    bool printTerminator = true;
    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
      printTerminator = !term->getAttrDictionary().empty() ||
                        term->getNumOperands() != 0 ||
                        term->getNumResults() != 0;
    }
    _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
      /*printBlockTerminators=*/printTerminator);
  }
)";
/// The code snippet used to generate a printer call for an enum that has cases
/// that can't be represented with a keyword.
///
/// NOTE(review): the snippet opens a `{` scope that it does not close;
/// presumably the code emitted after it closes the scope. Confirm at the use
/// site.
///
/// {0}: The name of the enum attribute.
/// {1}: The name of the enum attribute's symbolToString function.
const char *enumAttrBeginPrinterCode = R"(
  {
    auto caseValue = {0}();
    auto caseValueStr = {1}(caseValue);
)";
1746 /// Generate the printer for the 'prop-dict' directive.
1747 static void genPropDictPrinter(OperationFormat
&fmt
, Operator
&op
,
1749 body
<< " _odsPrinter << \" \";\n"
1750 << " printProperties(this->getContext(), _odsPrinter, "
1751 "getProperties());\n";
1754 /// Generate the printer for the 'attr-dict' directive.
1755 static void genAttrDictPrinter(OperationFormat
&fmt
, Operator
&op
,
1756 MethodBody
&body
, bool withKeyword
) {
1757 body
<< " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
1758 // Elide the variadic segment size attributes if necessary.
1759 if (!fmt
.allOperands
&&
1760 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
1761 body
<< " elidedAttrs.push_back(\"operand_segment_sizes\");\n";
1762 if (!fmt
.allResultTypes
&&
1763 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
1764 body
<< " elidedAttrs.push_back(\"result_segment_sizes\");\n";
1765 for (const StringRef key
: fmt
.inferredAttributes
.keys())
1766 body
<< " elidedAttrs.push_back(\"" << key
<< "\");\n";
1767 for (const NamedAttribute
*attr
: fmt
.usedAttributes
)
1768 body
<< " elidedAttrs.push_back(\"" << attr
->name
<< "\");\n";
1769 // Add code to check attributes for equality with the default value
1770 // for attributes with the elidePrintingDefaultValue bit set.
1771 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1772 const Attribute
&attr
= namedAttr
.attr
;
1773 if (!attr
.isDerivedAttr() && attr
.hasDefaultValue()) {
1774 const StringRef
&name
= namedAttr
.name
;
1776 fctx
.withBuilder("odsBuilder");
1777 std::string defaultValue
= std::string(
1778 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
1780 body
<< " ::mlir::Builder odsBuilder(getContext());\n";
1781 body
<< " ::mlir::Attribute attr = " << op
.getGetterName(name
)
1783 body
<< " if(attr && (attr == " << defaultValue
<< "))\n";
1784 body
<< " elidedAttrs.push_back(\"" << name
<< "\");\n";
1788 body
<< " _odsPrinter.printOptionalAttrDict"
1789 << (withKeyword
? "WithKeyword" : "")
1790 << "((*this)->getAttrs(), elidedAttrs);\n";
1793 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
1794 /// space should be emitted before this element. `lastWasPunctuation` is true if
1795 /// the previous element was a punctuation literal.
1796 static void genLiteralPrinter(StringRef value
, MethodBody
&body
,
1797 bool &shouldEmitSpace
, bool &lastWasPunctuation
) {
1798 body
<< " _odsPrinter";
1800 // Don't insert a space for certain punctuation.
1801 if (shouldEmitSpace
&& shouldEmitSpaceBefore(value
, lastWasPunctuation
))
1803 body
<< " << \"" << value
<< "\";\n";
1805 // Insert a space after certain literals.
1807 value
.size() != 1 || !StringRef("<({[").contains(value
.front());
1808 lastWasPunctuation
= value
.front() != '_' && !isalpha(value
.front());
1811 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
1812 /// are set to false.
1813 static void genSpacePrinter(bool value
, MethodBody
&body
, bool &shouldEmitSpace
,
1814 bool &lastWasPunctuation
) {
1816 body
<< " _odsPrinter << ' ';\n";
1817 lastWasPunctuation
= false;
1819 lastWasPunctuation
= true;
1821 shouldEmitSpace
= false;
1824 /// Generate the printer for a custom directive parameter.
1825 static void genCustomDirectiveParameterPrinter(FormatElement
*element
,
1828 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
1829 body
<< op
.getGetterName(attr
->getVar()->name
) << "Attr()";
1831 } else if (isa
<AttrDictDirective
>(element
)) {
1832 body
<< "getOperation()->getAttrDictionary()";
1834 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
1835 body
<< op
.getGetterName(operand
->getVar()->name
) << "()";
1837 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
1838 body
<< op
.getGetterName(region
->getVar()->name
) << "()";
1840 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
1841 body
<< op
.getGetterName(successor
->getVar()->name
) << "()";
1843 } else if (auto *dir
= dyn_cast
<RefDirective
>(element
)) {
1844 genCustomDirectiveParameterPrinter(dir
->getArg(), op
, body
);
1846 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
1847 auto *typeOperand
= dir
->getArg();
1848 auto *operand
= dyn_cast
<OperandVariable
>(typeOperand
);
1849 auto *var
= operand
? operand
->getVar()
1850 : cast
<ResultVariable
>(typeOperand
)->getVar();
1851 std::string name
= op
.getGetterName(var
->name
);
1852 if (var
->isVariadic())
1853 body
<< name
<< "().getTypes()";
1854 else if (var
->isOptional())
1855 body
<< llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name
);
1857 body
<< name
<< "().getType()";
1859 } else if (auto *string
= dyn_cast
<StringElement
>(element
)) {
1861 ctx
.withBuilder("::mlir::Builder(getContext())");
1862 ctx
.addSubst("_ctxt", "getContext()");
1863 body
<< tgfmt(string
->getValue(), &ctx
);
1866 llvm_unreachable("unknown custom directive parameter");
1870 /// Generate the printer for a custom directive.
1871 static void genCustomDirectivePrinter(CustomDirective
*customDir
,
1872 const Operator
&op
, MethodBody
&body
) {
1873 body
<< " print" << customDir
->getName() << "(_odsPrinter, *this";
1874 for (FormatElement
*param
: customDir
->getArguments()) {
1876 genCustomDirectiveParameterPrinter(param
, op
, body
);
1881 /// Generate the printer for a region with the given variable name.
1882 static void genRegionPrinter(const Twine
®ionName
, MethodBody
&body
,
1883 bool hasImplicitTermTrait
) {
1884 if (hasImplicitTermTrait
)
1885 body
<< llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode
,
1888 body
<< " _odsPrinter.printRegion(" << regionName
<< ");\n";
1890 static void genVariadicRegionPrinter(const Twine
®ionListName
,
1892 bool hasImplicitTermTrait
) {
1893 body
<< " llvm::interleaveComma(" << regionListName
1894 << ", _odsPrinter, [&](::mlir::Region ®ion) {\n ";
1895 genRegionPrinter("region", body
, hasImplicitTermTrait
);
1899 /// Generate the C++ for an operand to a (*-)type directive.
1900 static MethodBody
&genTypeOperandPrinter(FormatElement
*arg
, const Operator
&op
,
1902 bool useArrayRef
= true) {
1903 if (isa
<OperandsDirective
>(arg
))
1904 return body
<< "getOperation()->getOperandTypes()";
1905 if (isa
<ResultsDirective
>(arg
))
1906 return body
<< "getOperation()->getResultTypes()";
1907 auto *operand
= dyn_cast
<OperandVariable
>(arg
);
1908 auto *var
= operand
? operand
->getVar() : cast
<ResultVariable
>(arg
)->getVar();
1909 if (var
->isVariadicOfVariadic())
1910 return body
<< llvm::formatv("{0}().join().getTypes()",
1911 op
.getGetterName(var
->name
));
1912 if (var
->isVariadic())
1913 return body
<< op
.getGetterName(var
->name
) << "().getTypes()";
1914 if (var
->isOptional())
1915 return body
<< llvm::formatv(
1916 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
1917 "::llvm::ArrayRef<::mlir::Type>())",
1918 op
.getGetterName(var
->name
));
1920 return body
<< "::llvm::ArrayRef<::mlir::Type>("
1921 << op
.getGetterName(var
->name
) << "().getType())";
1922 return body
<< op
.getGetterName(var
->name
) << "().getType()";
1925 /// Generate the printer for an enum attribute.
1926 static void genEnumAttrPrinter(const NamedAttribute
*var
, const Operator
&op
,
1928 Attribute baseAttr
= var
->attr
.getBaseAttr();
1929 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
1930 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
1932 body
<< llvm::formatv(enumAttrBeginPrinterCode
,
1933 (var
->attr
.isOptional() ? "*" : "") +
1934 op
.getGetterName(var
->name
),
1935 enumAttr
.getSymbolToStringFnName());
1937 // Get a string containing all of the cases that can't be represented with a
1939 BitVector
nonKeywordCases(cases
.size());
1940 for (auto it
: llvm::enumerate(cases
)) {
1941 if (!canFormatStringAsKeyword(it
.value().getStr()))
1942 nonKeywordCases
.set(it
.index());
1945 // Otherwise if this is a bit enum attribute, don't allow cases that may
1946 // overlap with other cases. For simplicity sake, only allow cases with a
1947 // single bit value.
1948 if (enumAttr
.isBitEnum()) {
1949 for (auto it
: llvm::enumerate(cases
)) {
1950 int64_t value
= it
.value().getValue();
1951 if (value
< 0 || !llvm::isPowerOf2_64(value
))
1952 nonKeywordCases
.set(it
.index());
1956 // If there are any cases that can't be used with a keyword, switch on the
1957 // case value to determine when to print in the string form.
1958 if (nonKeywordCases
.any()) {
1959 body
<< " switch (caseValue) {\n";
1960 StringRef cppNamespace
= enumAttr
.getCppNamespace();
1961 StringRef enumName
= enumAttr
.getEnumClassName();
1962 for (auto it
: llvm::enumerate(cases
)) {
1963 if (nonKeywordCases
.test(it
.index()))
1965 StringRef symbol
= it
.value().getSymbol();
1966 body
<< llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace
, enumName
,
1967 llvm::isDigit(symbol
.front()) ? ("_" + symbol
)
1970 body
<< " _odsPrinter << caseValueStr;\n"
1973 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
1980 body
<< " _odsPrinter << caseValueStr;\n"
1984 /// Generate the check for the anchor of an optional group.
1985 static void genOptionalGroupPrinterAnchor(FormatElement
*anchor
,
1988 TypeSwitch
<FormatElement
*>(anchor
)
1989 .Case
<OperandVariable
, ResultVariable
>([&](auto *element
) {
1990 const NamedTypeConstraint
*var
= element
->getVar();
1991 std::string name
= op
.getGetterName(var
->name
);
1992 if (var
->isOptional())
1993 body
<< name
<< "()";
1994 else if (var
->isVariadic())
1995 body
<< "!" << name
<< "().empty()";
1997 .Case
<RegionVariable
>([&](RegionVariable
*element
) {
1998 const NamedRegion
*var
= element
->getVar();
1999 std::string name
= op
.getGetterName(var
->name
);
2000 // TODO: Add a check for optional regions here when ODS supports it.
2001 body
<< "!" << name
<< "().empty()";
2003 .Case
<TypeDirective
>([&](TypeDirective
*element
) {
2004 genOptionalGroupPrinterAnchor(element
->getArg(), op
, body
);
2006 .Case
<FunctionalTypeDirective
>([&](FunctionalTypeDirective
*element
) {
2007 genOptionalGroupPrinterAnchor(element
->getInputs(), op
, body
);
2009 .Case
<AttributeVariable
>([&](AttributeVariable
*element
) {
2010 Attribute attr
= element
->getVar()->attr
;
2011 body
<< op
.getGetterName(element
->getVar()->name
) << "Attr()";
2012 if (attr
.isOptional())
2014 if (attr
.hasDefaultValue()) {
2015 // Consider a default-valued attribute as present if it's not the
2018 fctx
.withBuilder("::mlir::OpBuilder((*this)->getContext())");
2019 body
<< " && " << op
.getGetterName(element
->getVar()->name
)
2021 << tgfmt(attr
.getConstBuilderTemplate(), &fctx
,
2022 attr
.getDefaultValue());
2025 llvm_unreachable("attribute must be optional or default-valued");
2029 void collect(FormatElement
*element
,
2030 SmallVectorImpl
<VariableElement
*> &variables
) {
2031 TypeSwitch
<FormatElement
*>(element
)
2032 .Case([&](VariableElement
*var
) { variables
.emplace_back(var
); })
2033 .Case([&](CustomDirective
*ele
) {
2034 for (FormatElement
*arg
: ele
->getArguments())
2035 collect(arg
, variables
);
2037 .Case([&](OptionalElement
*ele
) {
2038 for (FormatElement
*arg
: ele
->getThenElements())
2039 collect(arg
, variables
);
2040 for (FormatElement
*arg
: ele
->getElseElements())
2041 collect(arg
, variables
);
2043 .Case([&](FunctionalTypeDirective
*funcType
) {
2044 collect(funcType
->getInputs(), variables
);
2045 collect(funcType
->getResults(), variables
);
2047 .Case([&](OIListElement
*oilist
) {
2048 for (ArrayRef
<FormatElement
*> arg
: oilist
->getParsingElements())
2049 for (FormatElement
*arg
: arg
)
2050 collect(arg
, variables
);
2054 void OperationFormat::genElementPrinter(FormatElement
*element
,
2055 MethodBody
&body
, Operator
&op
,
2056 bool &shouldEmitSpace
,
2057 bool &lastWasPunctuation
) {
2058 if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
))
2059 return genLiteralPrinter(literal
->getSpelling(), body
, shouldEmitSpace
,
2060 lastWasPunctuation
);
2062 // Emit a whitespace element.
2063 if (auto *space
= dyn_cast
<WhitespaceElement
>(element
)) {
2064 if (space
->getValue() == "\\n") {
2065 body
<< " _odsPrinter.printNewline();\n";
2067 genSpacePrinter(!space
->getValue().empty(), body
, shouldEmitSpace
,
2068 lastWasPunctuation
);
2073 // Emit an optional group.
2074 if (OptionalElement
*optional
= dyn_cast
<OptionalElement
>(element
)) {
2075 // Emit the check for the presence of the anchor element.
2076 FormatElement
*anchor
= optional
->getAnchor();
2078 if (optional
->isInverted())
2080 genOptionalGroupPrinterAnchor(anchor
, op
, body
);
2084 // If the anchor is a unit attribute, we don't need to print it. When
2085 // parsing, we will add this attribute if this group is present.
2086 ArrayRef
<FormatElement
*> thenElements
= optional
->getThenElements();
2087 ArrayRef
<FormatElement
*> elseElements
= optional
->getElseElements();
2088 FormatElement
*elidedAnchorElement
= nullptr;
2089 auto *anchorAttr
= dyn_cast
<AttributeVariable
>(anchor
);
2090 if (anchorAttr
&& anchorAttr
!= thenElements
.front() &&
2091 (elseElements
.empty() || anchorAttr
!= elseElements
.front()) &&
2092 anchorAttr
->isUnitAttr()) {
2093 elidedAnchorElement
= anchorAttr
;
2095 auto genElementPrinters
= [&](ArrayRef
<FormatElement
*> elements
) {
2096 for (FormatElement
*childElement
: elements
) {
2097 if (childElement
!= elidedAnchorElement
) {
2098 genElementPrinter(childElement
, body
, op
, shouldEmitSpace
,
2099 lastWasPunctuation
);
2104 // Emit each of the elements.
2105 genElementPrinters(thenElements
);
2108 // Emit each of the else elements.
2109 if (!elseElements
.empty()) {
2110 body
<< " else {\n";
2111 genElementPrinters(elseElements
);
2115 body
.unindent() << "\n";
2120 if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
2121 genLiteralPrinter(" ", body
, shouldEmitSpace
, lastWasPunctuation
);
2122 for (auto clause
: oilist
->getClauses()) {
2123 LiteralElement
*lelement
= std::get
<0>(clause
);
2124 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
2126 SmallVector
<VariableElement
*> vars
;
2127 for (FormatElement
*el
: pelement
)
2129 body
<< " if (false";
2130 for (VariableElement
*var
: vars
) {
2131 TypeSwitch
<FormatElement
*>(var
)
2132 .Case([&](AttributeVariable
*attrEle
) {
2133 body
<< " || " << op
.getGetterName(attrEle
->getVar()->name
)
2136 .Case([&](OperandVariable
*ele
) {
2137 if (ele
->getVar()->isVariadic()) {
2138 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2141 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2144 .Case([&](ResultVariable
*ele
) {
2145 if (ele
->getVar()->isVariadic()) {
2146 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2149 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2152 .Case([&](RegionVariable
*reg
) {
2153 body
<< " || " << op
.getGetterName(reg
->getVar()->name
) << "()";
2158 genLiteralPrinter(lelement
->getSpelling(), body
, shouldEmitSpace
,
2159 lastWasPunctuation
);
2160 if (oilist
->getUnitAttrParsingElement(pelement
) == nullptr) {
2161 for (FormatElement
*element
: pelement
)
2162 genElementPrinter(element
, body
, op
, shouldEmitSpace
,
2163 lastWasPunctuation
);
2170 // Emit the attribute dictionary.
2171 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
2172 genAttrDictPrinter(*this, op
, body
, attrDict
->isWithKeyword());
2173 lastWasPunctuation
= false;
2177 // Emit the attribute dictionary.
2178 if (auto *propDict
= dyn_cast
<PropDictDirective
>(element
)) {
2179 genPropDictPrinter(*this, op
, body
);
2180 lastWasPunctuation
= false;
2184 // Optionally insert a space before the next element. The AttrDict printer
2185 // already adds a space as necessary.
2186 if (shouldEmitSpace
|| !lastWasPunctuation
)
2187 body
<< " _odsPrinter << ' ';\n";
2188 lastWasPunctuation
= false;
2189 shouldEmitSpace
= true;
2191 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
2192 const NamedAttribute
*var
= attr
->getVar();
2194 // If we are formatting as an enum, symbolize the attribute as a string.
2195 if (canFormatEnumAttr(var
))
2196 return genEnumAttrPrinter(var
, op
, body
);
2198 // If we are formatting as a symbol name, handle it as a symbol name.
2199 if (shouldFormatSymbolNameAttr(var
)) {
2200 body
<< " _odsPrinter.printSymbolName(" << op
.getGetterName(var
->name
)
2201 << "Attr().getValue());\n";
2205 // Elide the attribute type if it is buildable.
2206 if (attr
->getTypeBuilder())
2207 body
<< " _odsPrinter.printAttributeWithoutType("
2208 << op
.getGetterName(var
->name
) << "Attr());\n";
2209 else if (attr
->shouldBeQualified() ||
2210 var
->attr
.getStorageType() == "::mlir::Attribute")
2211 body
<< " _odsPrinter.printAttribute(" << op
.getGetterName(var
->name
)
2214 body
<< "_odsPrinter.printStrippedAttrOrType("
2215 << op
.getGetterName(var
->name
) << "Attr());\n";
2216 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
2217 if (operand
->getVar()->isVariadicOfVariadic()) {
2218 body
<< " ::llvm::interleaveComma("
2219 << op
.getGetterName(operand
->getVar()->name
)
2220 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2221 "\"(\" << operands << "
2224 } else if (operand
->getVar()->isOptional()) {
2225 body
<< " if (::mlir::Value value = "
2226 << op
.getGetterName(operand
->getVar()->name
) << "())\n"
2227 << " _odsPrinter << value;\n";
2229 body
<< " _odsPrinter << " << op
.getGetterName(operand
->getVar()->name
)
2232 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
2233 const NamedRegion
*var
= region
->getVar();
2234 std::string name
= op
.getGetterName(var
->name
);
2235 if (var
->isVariadic()) {
2236 genVariadicRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2238 genRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2240 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
2241 const NamedSuccessor
*var
= successor
->getVar();
2242 std::string name
= op
.getGetterName(var
->name
);
2243 if (var
->isVariadic())
2244 body
<< " ::llvm::interleaveComma(" << name
<< "(), _odsPrinter);\n";
2246 body
<< " _odsPrinter << " << name
<< "();\n";
2247 } else if (auto *dir
= dyn_cast
<CustomDirective
>(element
)) {
2248 genCustomDirectivePrinter(dir
, op
, body
);
2249 } else if (isa
<OperandsDirective
>(element
)) {
2250 body
<< " _odsPrinter << getOperation()->getOperands();\n";
2251 } else if (isa
<RegionsDirective
>(element
)) {
2252 genVariadicRegionPrinter("getOperation()->getRegions()", body
,
2253 hasImplicitTermTrait
);
2254 } else if (isa
<SuccessorsDirective
>(element
)) {
2255 body
<< " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2257 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
2258 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg())) {
2259 if (operand
->getVar()->isVariadicOfVariadic()) {
2260 body
<< llvm::formatv(
2261 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2262 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2263 "types << \")\"; });\n",
2264 op
.getGetterName(operand
->getVar()->name
));
2268 const NamedTypeConstraint
*var
= nullptr;
2270 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg()))
2271 var
= operand
->getVar();
2272 else if (auto *operand
= dyn_cast
<ResultVariable
>(dir
->getArg()))
2273 var
= operand
->getVar();
2275 if (var
&& !var
->isVariadicOfVariadic() && !var
->isVariadic() &&
2276 !var
->isOptional()) {
2277 std::string cppClass
= var
->constraint
.getCPPClassName();
2278 if (dir
->shouldBeQualified()) {
2279 body
<< " _odsPrinter << " << op
.getGetterName(var
->name
)
2280 << "().getType();\n";
2284 << " auto type = " << op
.getGetterName(var
->name
)
2285 << "().getType();\n"
2286 << " if (auto validType = type.dyn_cast<" << cppClass
<< ">())\n"
2287 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2289 << " _odsPrinter << type;\n"
2293 body
<< " _odsPrinter << ";
2294 genTypeOperandPrinter(dir
->getArg(), op
, body
, /*useArrayRef=*/false)
2296 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
2297 body
<< " _odsPrinter.printFunctionalType(";
2298 genTypeOperandPrinter(dir
->getInputs(), op
, body
) << ", ";
2299 genTypeOperandPrinter(dir
->getResults(), op
, body
) << ");\n";
2301 llvm_unreachable("unknown format element");
2305 void OperationFormat::genPrinter(Operator
&op
, OpClass
&opClass
) {
2306 auto *method
= opClass
.addMethod(
2308 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2309 auto &body
= method
->body();
2311 // Flags for if we should emit a space, and if the last element was
2313 bool shouldEmitSpace
= true, lastWasPunctuation
= false;
2314 for (FormatElement
*element
: elements
)
2315 genElementPrinter(element
, body
, op
, shouldEmitSpace
, lastWasPunctuation
);
2318 //===----------------------------------------------------------------------===//
2320 //===----------------------------------------------------------------------===//
2322 /// Function to find an element within the given range that has the same name as
2324 template <typename RangeT
>
2325 static auto findArg(RangeT
&&range
, StringRef name
) {
2326 auto it
= llvm::find_if(range
, [=](auto &arg
) { return arg
.name
== name
; });
2327 return it
!= range
.end() ? &*it
: nullptr;
2331 /// This class implements a parser for an instance of an operation assembly
2333 class OpFormatParser
: public FormatParser
{
2335 OpFormatParser(llvm::SourceMgr
&mgr
, OperationFormat
&format
, Operator
&op
)
2336 : FormatParser(mgr
, op
.getLoc()[0]), fmt(format
), op(op
),
2337 seenOperandTypes(op
.getNumOperands()),
2338 seenResultTypes(op
.getNumResults()) {}
2341 /// Verify the format elements.
2342 LogicalResult
verify(SMLoc loc
, ArrayRef
<FormatElement
*> elements
) override
;
2343 /// Verify the arguments to a custom directive.
2345 verifyCustomDirectiveArguments(SMLoc loc
,
2346 ArrayRef
<FormatElement
*> arguments
) override
;
2347 /// Verify the elements of an optional group.
2348 LogicalResult
verifyOptionalGroupElements(SMLoc loc
,
2349 ArrayRef
<FormatElement
*> elements
,
2350 FormatElement
*anchor
) override
;
2351 LogicalResult
verifyOptionalGroupElement(SMLoc loc
, FormatElement
*element
,
2354 /// Parse an operation variable.
2355 FailureOr
<FormatElement
*> parseVariableImpl(SMLoc loc
, StringRef name
,
2356 Context ctx
) override
;
2357 /// Parse an operation format directive.
2358 FailureOr
<FormatElement
*>
2359 parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
, Context ctx
) override
;
2362 /// This struct represents a type resolution instance. It includes a specific
2363 /// type as well as an optional transformer to apply to that type in order to
2364 /// properly resolve the type of a variable.
2365 struct TypeResolutionInstance
{
2366 ConstArgument resolver
;
2367 std::optional
<StringRef
> transformer
;
2370 /// Verify the state of operation attributes within the format.
2371 LogicalResult
verifyAttributes(SMLoc loc
, ArrayRef
<FormatElement
*> elements
);
2373 /// Verify that attributes elements aren't followed by colon literals.
2374 LogicalResult
verifyAttributeColonType(SMLoc loc
,
2375 ArrayRef
<FormatElement
*> elements
);
2376 /// Verify that the attribute dictionary directive isn't followed by a region.
2377 LogicalResult
verifyAttrDictRegion(SMLoc loc
,
2378 ArrayRef
<FormatElement
*> elements
);
2380 /// Verify the state of operation operands within the format.
2382 verifyOperands(SMLoc loc
,
2383 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2385 /// Verify the state of operation regions within the format.
2386 LogicalResult
verifyRegions(SMLoc loc
);
2388 /// Verify the state of operation results within the format.
2390 verifyResults(SMLoc loc
,
2391 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2393 /// Verify the state of operation successors within the format.
2394 LogicalResult
verifySuccessors(SMLoc loc
);
2396 LogicalResult
verifyOIListElements(SMLoc loc
,
2397 ArrayRef
<FormatElement
*> elements
);
2399 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2401 void handleAllTypesMatchConstraint(
2402 ArrayRef
<StringRef
> values
,
2403 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2404 /// Check for inferable type resolution given all operands, and or results,
2405 /// have the same type. If 'includeResults' is true, the results also have the
2406 /// same type as all of the operands.
2407 void handleSameTypesConstraint(
2408 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2409 bool includeResults
);
2410 /// Check for inferable type resolution based on another operand, result, or
2412 void handleTypesMatchConstraint(
2413 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2414 const llvm::Record
&def
);
2416 /// Returns an argument or attribute with the given name that has been seen
2417 /// within the format.
2418 ConstArgument
findSeenArg(StringRef name
);
2420 /// Parse the various different directives.
2421 FailureOr
<FormatElement
*> parsePropDictDirective(SMLoc loc
, Context context
);
2422 FailureOr
<FormatElement
*> parseAttrDictDirective(SMLoc loc
, Context context
,
2424 FailureOr
<FormatElement
*> parseFunctionalTypeDirective(SMLoc loc
,
2426 FailureOr
<FormatElement
*> parseOIListDirective(SMLoc loc
, Context context
);
2427 LogicalResult
verifyOIListParsingElement(FormatElement
*element
, SMLoc loc
);
2428 FailureOr
<FormatElement
*> parseOperandsDirective(SMLoc loc
, Context context
);
2429 FailureOr
<FormatElement
*> parseQualifiedDirective(SMLoc loc
,
2431 FailureOr
<FormatElement
*> parseReferenceDirective(SMLoc loc
,
2433 FailureOr
<FormatElement
*> parseRegionsDirective(SMLoc loc
, Context context
);
2434 FailureOr
<FormatElement
*> parseResultsDirective(SMLoc loc
, Context context
);
2435 FailureOr
<FormatElement
*> parseSuccessorsDirective(SMLoc loc
,
2437 FailureOr
<FormatElement
*> parseTypeDirective(SMLoc loc
, Context context
);
2438 FailureOr
<FormatElement
*> parseTypeDirectiveOperand(SMLoc loc
,
2439 bool isRefChild
= false);
2441 //===--------------------------------------------------------------------===//
2443 //===--------------------------------------------------------------------===//
2445 OperationFormat
&fmt
;
2448 // The following are various bits of format state used for verification
2450 bool hasAttrDict
= false;
2451 bool hasPropDict
= false;
2452 bool hasAllRegions
= false, hasAllSuccessors
= false;
2453 bool canInferResultTypes
= false;
2454 llvm::SmallBitVector seenOperandTypes
, seenResultTypes
;
2455 llvm::SmallSetVector
<const NamedAttribute
*, 8> seenAttrs
;
2456 llvm::DenseSet
<const NamedTypeConstraint
*> seenOperands
;
2457 llvm::DenseSet
<const NamedRegion
*> seenRegions
;
2458 llvm::DenseSet
<const NamedSuccessor
*> seenSuccessors
;
2462 LogicalResult
OpFormatParser::verify(SMLoc loc
,
2463 ArrayRef
<FormatElement
*> elements
) {
2464 // Check that the attribute dictionary is in the format.
2466 return emitError(loc
, "'attr-dict' directive not found in "
2467 "custom assembly format");
2469 // Check for any type traits that we can use for inferring types.
2470 llvm::StringMap
<TypeResolutionInstance
> variableTyResolver
;
2471 for (const Trait
&trait
: op
.getTraits()) {
2472 const llvm::Record
&def
= trait
.getDef();
2473 if (def
.isSubClassOf("AllTypesMatch")) {
2474 handleAllTypesMatchConstraint(def
.getValueAsListOfStrings("values"),
2475 variableTyResolver
);
2476 } else if (def
.getName() == "SameTypeOperands") {
2477 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/false);
2478 } else if (def
.getName() == "SameOperandsAndResultType") {
2479 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/true);
2480 } else if (def
.isSubClassOf("TypesMatchWith")) {
2481 handleTypesMatchConstraint(variableTyResolver
, def
);
2482 } else if (!op
.allResultTypesKnown()) {
2483 // This doesn't check the name directly to handle
2484 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2486 // TODO: Add hasCppInterface check.
2487 if (auto name
= def
.getValueAsOptionalString("cppInterfaceName")) {
2488 if (*name
== "InferTypeOpInterface" &&
2489 def
.getValueAsString("cppNamespace") == "::mlir")
2490 canInferResultTypes
= true;
2495 // Verify the state of the various operation components.
2496 if (failed(verifyAttributes(loc
, elements
)) ||
2497 failed(verifyResults(loc
, variableTyResolver
)) ||
2498 failed(verifyOperands(loc
, variableTyResolver
)) ||
2499 failed(verifyRegions(loc
)) || failed(verifySuccessors(loc
)) ||
2500 failed(verifyOIListElements(loc
, elements
)))
2503 // Collect the set of used attributes in the format.
2504 fmt
.usedAttributes
= seenAttrs
.takeVector();
2509 OpFormatParser::verifyAttributes(SMLoc loc
,
2510 ArrayRef
<FormatElement
*> elements
) {
2511 // Check that there are no `:` literals after an attribute without a constant
2512 // type. The attribute grammar contains an optional trailing colon type, which
2513 // can lead to unexpected and generally unintended behavior. Given that, it is
2514 // better to just error out here instead.
2515 if (failed(verifyAttributeColonType(loc
, elements
)))
2517 // Check that there are no region variables following an attribute dicitonary.
2518 // Both start with `{` and so the optional attribute dictionary can cause
2519 // format ambiguities.
2520 if (failed(verifyAttrDictRegion(loc
, elements
)))
2523 // Check for VariadicOfVariadic variables. The segment attribute of those
2524 // variables will be infered.
2525 for (const NamedTypeConstraint
*var
: seenOperands
) {
2526 if (var
->constraint
.isVariadicOfVariadic()) {
2527 fmt
.inferredAttributes
.insert(
2528 var
->constraint
.getVariadicOfVariadicSegmentSizeAttr());
2535 /// Returns whether the single format element is optionally parsed.
2536 static bool isOptionallyParsed(FormatElement
*el
) {
2537 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(el
)) {
2538 Attribute attr
= attrVar
->getVar()->attr
;
2539 return attr
.isOptional() || attr
.hasDefaultValue();
2541 if (auto *operandVar
= dyn_cast
<OperandVariable
>(el
)) {
2542 const NamedTypeConstraint
*operand
= operandVar
->getVar();
2543 return operand
->isOptional() || operand
->isVariadic() ||
2544 operand
->isVariadicOfVariadic();
2546 if (auto *successorVar
= dyn_cast
<SuccessorVariable
>(el
))
2547 return successorVar
->getVar()->isVariadic();
2548 if (auto *regionVar
= dyn_cast
<RegionVariable
>(el
))
2549 return regionVar
->getVar()->isVariadic();
2550 return isa
<WhitespaceElement
, AttrDictDirective
>(el
);
2553 /// Scan the given range of elements from the start for an invalid format
2554 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2555 /// If an optional group is encountered, this function recurses into the 'then'
2556 /// and 'else' elements to check if they are invalid. Returns `success` if the
2557 /// range is known to be valid or `std::nullopt` if scanning reached the end.
2559 /// Since the guard element of an optional group is required, this function
2560 /// accepts an optional element pointer to mark it as required.
2561 static std::optional
<LogicalResult
> checkRangeForElement(
2562 FormatElement
*base
,
2563 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2564 iterator_range
<ArrayRef
<FormatElement
*>::iterator
> elementRange
,
2565 FormatElement
*optionalGuard
= nullptr) {
2566 for (FormatElement
*element
: elementRange
) {
2567 // If we encounter an invalid element, return an error.
2568 if (isInvalid(base
, element
))
2571 // Recurse on optional groups.
2572 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
2573 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2574 base
, isInvalid
, optional
->getThenElements(),
2575 // The optional group guard is required for the group.
2576 optional
->getThenElements().front()))
2577 if (failed(*result
))
2579 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2580 base
, isInvalid
, optional
->getElseElements()))
2581 if (failed(*result
))
2583 // Skip the optional group.
2587 // Skip optionally parsed elements.
2588 if (element
!= optionalGuard
&& isOptionallyParsed(element
))
2591 // We found a closing element that is valid.
2594 // Return std::nullopt to indicate that we reached the end.
2595 return std::nullopt
;
2598 /// For the given elements, check whether any attributes are followed by a colon
2599 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2600 /// attribute if verification of said attribute reached the end of the range.
2601 /// Returns null if all attribute elements are verified.
2602 static FailureOr
<FormatElement
*> verifyAdjacentElements(
2603 function_ref
<bool(FormatElement
*)> isBase
,
2604 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2605 ArrayRef
<FormatElement
*> elements
) {
2606 for (auto *it
= elements
.begin(), *e
= elements
.end(); it
!= e
; ++it
) {
2607 // The current attribute being verified.
2608 FormatElement
*base
;
2612 } else if (auto *optional
= dyn_cast
<OptionalElement
>(*it
)) {
2613 // Recurse on optional groups.
2614 FailureOr
<FormatElement
*> thenResult
= verifyAdjacentElements(
2615 isBase
, isInvalid
, optional
->getThenElements());
2616 if (failed(thenResult
))
2618 FailureOr
<FormatElement
*> elseResult
= verifyAdjacentElements(
2619 isBase
, isInvalid
, optional
->getElseElements());
2620 if (failed(elseResult
))
2622 // If either optional group has an unverified attribute, save it.
2623 // Otherwise, move on to the next element.
2624 if (!(base
= *thenResult
) && !(base
= *elseResult
))
2630 // Verify subsequent elements for potential ambiguities.
2631 if (std::optional
<LogicalResult
> result
=
2632 checkRangeForElement(base
, isInvalid
, {std::next(it
), e
})) {
2633 if (failed(*result
))
2636 // Since we reached the end, return the attribute as unverified.
2640 // All attribute elements are known to be verified.
2645 OpFormatParser::verifyAttributeColonType(SMLoc loc
,
2646 ArrayRef
<FormatElement
*> elements
) {
2647 auto isBase
= [](FormatElement
*el
) {
2648 auto *attr
= dyn_cast
<AttributeVariable
>(el
);
2651 // Check only attributes without type builders or that are known to call
2652 // the generic attribute parser.
2653 return !attr
->getTypeBuilder() &&
2654 (attr
->shouldBeQualified() ||
2655 attr
->getVar()->attr
.getStorageType() == "::mlir::Attribute");
2657 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
2658 auto *literal
= dyn_cast
<LiteralElement
>(el
);
2659 if (!literal
|| literal
->getSpelling() != ":")
2661 // If we encounter `:`, the range is known to be invalid.
2664 llvm::formatv("format ambiguity caused by `:` literal found after "
2665 "attribute `{0}` which does not have a buildable type",
2666 cast
<AttributeVariable
>(base
)->getVar()->name
));
2669 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
2673 OpFormatParser::verifyAttrDictRegion(SMLoc loc
,
2674 ArrayRef
<FormatElement
*> elements
) {
2675 auto isBase
= [](FormatElement
*el
) {
2676 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(el
))
2677 return !attrDict
->isWithKeyword();
2680 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
2681 auto *region
= dyn_cast
<RegionVariable
>(el
);
2684 (void)emitErrorAndNote(
2686 llvm::formatv("format ambiguity caused by `attr-dict` directive "
2687 "followed by region `{0}`",
2688 region
->getVar()->name
),
2689 "try using `attr-dict-with-keyword` instead");
2692 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
2695 LogicalResult
OpFormatParser::verifyOperands(
2696 SMLoc loc
, llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2697 // Check that all of the operands are within the format, and their types can
2699 auto &buildableTypes
= fmt
.buildableTypes
;
2700 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
2701 NamedTypeConstraint
&operand
= op
.getOperand(i
);
2703 // Check that the operand itself is in the format.
2704 if (!fmt
.allOperands
&& !seenOperands
.count(&operand
)) {
2705 return emitErrorAndNote(loc
,
2706 "operand #" + Twine(i
) + ", named '" +
2707 operand
.name
+ "', not found",
2708 "suggest adding a '$" + operand
.name
+
2709 "' directive to the custom assembly format");
2712 // Check that the operand type is in the format, or that it can be inferred.
2713 if (fmt
.allOperandTypes
|| seenOperandTypes
.test(i
))
2716 // Check to see if we can infer this type from another variable.
2717 auto varResolverIt
= variableTyResolver
.find(op
.getOperand(i
).name
);
2718 if (varResolverIt
!= variableTyResolver
.end()) {
2719 TypeResolutionInstance
&resolver
= varResolverIt
->second
;
2720 fmt
.operandTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
2724 // Similarly to results, allow a custom builder for resolving the type if
2725 // we aren't using the 'operands' directive.
2726 std::optional
<StringRef
> builder
= operand
.constraint
.getBuilderCall();
2727 if (!builder
|| (fmt
.allOperands
&& operand
.isVariableLength())) {
2728 return emitErrorAndNote(
2730 "type of operand #" + Twine(i
) + ", named '" + operand
.name
+
2731 "', is not buildable and a buildable type cannot be inferred",
2732 "suggest adding a type constraint to the operation or adding a "
2734 operand
.name
+ ")' directive to the " + "custom assembly format");
2736 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
2737 fmt
.operandTypes
[i
].setBuilderIdx(it
.first
->second
);
2742 LogicalResult
OpFormatParser::verifyRegions(SMLoc loc
) {
2743 // Check that all of the regions are within the format.
2747 for (unsigned i
= 0, e
= op
.getNumRegions(); i
!= e
; ++i
) {
2748 const NamedRegion
®ion
= op
.getRegion(i
);
2749 if (!seenRegions
.count(®ion
)) {
2750 return emitErrorAndNote(loc
,
2751 "region #" + Twine(i
) + ", named '" +
2752 region
.name
+ "', not found",
2753 "suggest adding a '$" + region
.name
+
2754 "' directive to the custom assembly format");
2760 LogicalResult
OpFormatParser::verifyResults(
2761 SMLoc loc
, llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2762 // If we format all of the types together, there is nothing to check.
2763 if (fmt
.allResultTypes
)
2766 // If no result types are specified and we can infer them, infer all result
2768 if (op
.getNumResults() > 0 && seenResultTypes
.count() == 0 &&
2769 canInferResultTypes
) {
2770 fmt
.infersResultTypes
= true;
2774 // Check that all of the result types can be inferred.
2775 auto &buildableTypes
= fmt
.buildableTypes
;
2776 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
2777 if (seenResultTypes
.test(i
))
2780 // Check to see if we can infer this type from another variable.
2781 auto varResolverIt
= variableTyResolver
.find(op
.getResultName(i
));
2782 if (varResolverIt
!= variableTyResolver
.end()) {
2783 TypeResolutionInstance resolver
= varResolverIt
->second
;
2784 fmt
.resultTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
2788 // If the result is not variable length, allow for the case where the type
2789 // has a builder that we can use.
2790 NamedTypeConstraint
&result
= op
.getResult(i
);
2791 std::optional
<StringRef
> builder
= result
.constraint
.getBuilderCall();
2792 if (!builder
|| result
.isVariableLength()) {
2793 return emitErrorAndNote(
2795 "type of result #" + Twine(i
) + ", named '" + result
.name
+
2796 "', is not buildable and a buildable type cannot be inferred",
2797 "suggest adding a type constraint to the operation or adding a "
2799 result
.name
+ ")' directive to the " + "custom assembly format");
2801 // Note in the format that this result uses the custom builder.
2802 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
2803 fmt
.resultTypes
[i
].setBuilderIdx(it
.first
->second
);
2808 LogicalResult
OpFormatParser::verifySuccessors(SMLoc loc
) {
2809 // Check that all of the successors are within the format.
2810 if (hasAllSuccessors
)
2813 for (unsigned i
= 0, e
= op
.getNumSuccessors(); i
!= e
; ++i
) {
2814 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
2815 if (!seenSuccessors
.count(&successor
)) {
2816 return emitErrorAndNote(loc
,
2817 "successor #" + Twine(i
) + ", named '" +
2818 successor
.name
+ "', not found",
2819 "suggest adding a '$" + successor
.name
+
2820 "' directive to the custom assembly format");
2827 OpFormatParser::verifyOIListElements(SMLoc loc
,
2828 ArrayRef
<FormatElement
*> elements
) {
2829 // Check that all of the successors are within the format.
2830 SmallVector
<StringRef
> prohibitedLiterals
;
2831 for (FormatElement
*it
: elements
) {
2832 if (auto *oilist
= dyn_cast
<OIListElement
>(it
)) {
2833 if (!prohibitedLiterals
.empty()) {
2834 // We just saw an oilist element in last iteration. Literals should not
2836 for (LiteralElement
*literal
: oilist
->getLiteralElements()) {
2837 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
2838 prohibitedLiterals
.end()) {
2840 loc
, "format ambiguity because " + literal
->getSpelling() +
2841 " is used in two adjacent oilist elements.");
2845 for (LiteralElement
*literal
: oilist
->getLiteralElements())
2846 prohibitedLiterals
.push_back(literal
->getSpelling());
2847 } else if (auto *literal
= dyn_cast
<LiteralElement
>(it
)) {
2848 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
2849 prohibitedLiterals
.end()) {
2852 "format ambiguity because " + literal
->getSpelling() +
2853 " is used both in oilist element and the adjacent literal.");
2855 prohibitedLiterals
.clear();
2857 prohibitedLiterals
.clear();
2863 void OpFormatParser::handleAllTypesMatchConstraint(
2864 ArrayRef
<StringRef
> values
,
2865 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
2866 for (unsigned i
= 0, e
= values
.size(); i
!= e
; ++i
) {
2867 // Check to see if this value matches a resolved operand or result type.
2868 ConstArgument arg
= findSeenArg(values
[i
]);
2872 // Mark this value as the type resolver for the other variables.
2873 for (unsigned j
= 0; j
!= i
; ++j
)
2874 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
2875 for (unsigned j
= i
+ 1; j
!= e
; ++j
)
2876 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
2880 void OpFormatParser::handleSameTypesConstraint(
2881 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2882 bool includeResults
) {
2883 const NamedTypeConstraint
*resolver
= nullptr;
2884 int resolvedIt
= -1;
2886 // Check to see if there is an operand or result to use for the resolution.
2887 if ((resolvedIt
= seenOperandTypes
.find_first()) != -1)
2888 resolver
= &op
.getOperand(resolvedIt
);
2889 else if (includeResults
&& (resolvedIt
= seenResultTypes
.find_first()) != -1)
2890 resolver
= &op
.getResult(resolvedIt
);
2894 // Set the resolvers for each operand and result.
2895 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
)
2896 if (!seenOperandTypes
.test(i
))
2897 variableTyResolver
[op
.getOperand(i
).name
] = {resolver
, std::nullopt
};
2898 if (includeResults
) {
2899 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
)
2900 if (!seenResultTypes
.test(i
))
2901 variableTyResolver
[op
.getResultName(i
)] = {resolver
, std::nullopt
};
2905 void OpFormatParser::handleTypesMatchConstraint(
2906 llvm::StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2907 const llvm::Record
&def
) {
2908 StringRef lhsName
= def
.getValueAsString("lhs");
2909 StringRef rhsName
= def
.getValueAsString("rhs");
2910 StringRef transformer
= def
.getValueAsString("transformer");
2911 if (ConstArgument arg
= findSeenArg(lhsName
))
2912 variableTyResolver
[rhsName
] = {arg
, transformer
};
2915 ConstArgument
OpFormatParser::findSeenArg(StringRef name
) {
2916 if (const NamedTypeConstraint
*arg
= findArg(op
.getOperands(), name
))
2917 return seenOperandTypes
.test(arg
- op
.operand_begin()) ? arg
: nullptr;
2918 if (const NamedTypeConstraint
*arg
= findArg(op
.getResults(), name
))
2919 return seenResultTypes
.test(arg
- op
.result_begin()) ? arg
: nullptr;
2920 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
))
2921 return seenAttrs
.count(attr
) ? attr
: nullptr;
2925 FailureOr
<FormatElement
*>
2926 OpFormatParser::parseVariableImpl(SMLoc loc
, StringRef name
, Context ctx
) {
2927 // Check that the parsed argument is something actually registered on the op.
2929 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
)) {
2930 if (ctx
== TypeDirectiveContext
)
2932 loc
, "attributes cannot be used as children to a `type` directive");
2933 if (ctx
== RefDirectiveContext
) {
2934 if (!seenAttrs
.count(attr
))
2935 return emitError(loc
, "attribute '" + name
+
2936 "' must be bound before it is referenced");
2937 } else if (!seenAttrs
.insert(attr
)) {
2938 return emitError(loc
, "attribute '" + name
+ "' is already bound");
2941 return create
<AttributeVariable
>(attr
);
2944 if (const NamedTypeConstraint
*operand
= findArg(op
.getOperands(), name
)) {
2945 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
2946 if (fmt
.allOperands
|| !seenOperands
.insert(operand
).second
)
2947 return emitError(loc
, "operand '" + name
+ "' is already bound");
2948 } else if (ctx
== RefDirectiveContext
&& !seenOperands
.count(operand
)) {
2949 return emitError(loc
, "operand '" + name
+
2950 "' must be bound before it is referenced");
2952 return create
<OperandVariable
>(operand
);
2955 if (const NamedRegion
*region
= findArg(op
.getRegions(), name
)) {
2956 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
2957 if (hasAllRegions
|| !seenRegions
.insert(region
).second
)
2958 return emitError(loc
, "region '" + name
+ "' is already bound");
2959 } else if (ctx
== RefDirectiveContext
&& !seenRegions
.count(region
)) {
2960 return emitError(loc
, "region '" + name
+
2961 "' must be bound before it is referenced");
2963 return emitError(loc
, "regions can only be used at the top level");
2965 return create
<RegionVariable
>(region
);
2968 if (const auto *result
= findArg(op
.getResults(), name
)) {
2969 if (ctx
!= TypeDirectiveContext
)
2970 return emitError(loc
, "result variables can can only be used as a child "
2971 "to a 'type' directive");
2972 return create
<ResultVariable
>(result
);
2975 if (const auto *successor
= findArg(op
.getSuccessors(), name
)) {
2976 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
2977 if (hasAllSuccessors
|| !seenSuccessors
.insert(successor
).second
)
2978 return emitError(loc
, "successor '" + name
+ "' is already bound");
2979 } else if (ctx
== RefDirectiveContext
&& !seenSuccessors
.count(successor
)) {
2980 return emitError(loc
, "successor '" + name
+
2981 "' must be bound before it is referenced");
2983 return emitError(loc
, "successors can only be used at the top level");
2986 return create
<SuccessorVariable
>(successor
);
2988 return emitError(loc
, "expected variable to refer to an argument, region, "
2989 "result, or successor");
2992 FailureOr
<FormatElement
*>
2993 OpFormatParser::parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
,
2996 case FormatToken::kw_prop_dict
:
2997 return parsePropDictDirective(loc
, ctx
);
2998 case FormatToken::kw_attr_dict
:
2999 return parseAttrDictDirective(loc
, ctx
,
3000 /*withKeyword=*/false);
3001 case FormatToken::kw_attr_dict_w_keyword
:
3002 return parseAttrDictDirective(loc
, ctx
,
3003 /*withKeyword=*/true);
3004 case FormatToken::kw_functional_type
:
3005 return parseFunctionalTypeDirective(loc
, ctx
);
3006 case FormatToken::kw_operands
:
3007 return parseOperandsDirective(loc
, ctx
);
3008 case FormatToken::kw_qualified
:
3009 return parseQualifiedDirective(loc
, ctx
);
3010 case FormatToken::kw_regions
:
3011 return parseRegionsDirective(loc
, ctx
);
3012 case FormatToken::kw_results
:
3013 return parseResultsDirective(loc
, ctx
);
3014 case FormatToken::kw_successors
:
3015 return parseSuccessorsDirective(loc
, ctx
);
3016 case FormatToken::kw_ref
:
3017 return parseReferenceDirective(loc
, ctx
);
3018 case FormatToken::kw_type
:
3019 return parseTypeDirective(loc
, ctx
);
3020 case FormatToken::kw_oilist
:
3021 return parseOIListDirective(loc
, ctx
);
3024 return emitError(loc
, "unsupported directive kind");
3028 FailureOr
<FormatElement
*>
3029 OpFormatParser::parseAttrDictDirective(SMLoc loc
, Context context
,
3031 if (context
== TypeDirectiveContext
)
3032 return emitError(loc
, "'attr-dict' directive can only be used as a "
3033 "top-level directive");
3035 if (context
== RefDirectiveContext
) {
3037 return emitError(loc
, "'ref' of 'attr-dict' is not bound by a prior "
3038 "'attr-dict' directive");
3040 // Otherwise, this is a top-level context.
3043 return emitError(loc
, "'attr-dict' directive has already been seen");
3047 return create
<AttrDictDirective
>(withKeyword
);
3050 FailureOr
<FormatElement
*>
3051 OpFormatParser::parsePropDictDirective(SMLoc loc
, Context context
) {
3052 if (context
== TypeDirectiveContext
)
3053 return emitError(loc
, "'prop-dict' directive can only be used as a "
3054 "top-level directive");
3056 if (context
== RefDirectiveContext
)
3057 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3058 // Otherwise, this is a top-level context.
3061 return emitError(loc
, "'prop-dict' directive has already been seen");
3064 return create
<PropDictDirective
>();
3067 LogicalResult
OpFormatParser::verifyCustomDirectiveArguments(
3068 SMLoc loc
, ArrayRef
<FormatElement
*> arguments
) {
3069 for (FormatElement
*argument
: arguments
) {
3070 if (!isa
<StringElement
, RefDirective
, TypeDirective
, AttrDictDirective
,
3071 AttributeVariable
, OperandVariable
, RegionVariable
,
3072 SuccessorVariable
>(argument
)) {
3073 // TODO: FormatElement should have location info attached.
3074 return emitError(loc
, "only variables and types may be used as "
3075 "parameters to a custom directive");
3077 if (auto *type
= dyn_cast
<TypeDirective
>(argument
)) {
3078 if (!isa
<OperandVariable
, ResultVariable
>(type
->getArg())) {
3079 return emitError(loc
, "type directives within a custom directive may "
3080 "only refer to variables");
3087 FailureOr
<FormatElement
*>
3088 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc
, Context context
) {
3089 if (context
!= TopLevelContext
)
3091 loc
, "'functional-type' is only valid as a top-level directive");
3093 // Parse the main operand.
3094 FailureOr
<FormatElement
*> inputs
, results
;
3095 if (failed(parseToken(FormatToken::l_paren
,
3096 "expected '(' before argument list")) ||
3097 failed(inputs
= parseTypeDirectiveOperand(loc
)) ||
3098 failed(parseToken(FormatToken::comma
,
3099 "expected ',' after inputs argument")) ||
3100 failed(results
= parseTypeDirectiveOperand(loc
)) ||
3102 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3104 return create
<FunctionalTypeDirective
>(*inputs
, *results
);
3107 FailureOr
<FormatElement
*>
3108 OpFormatParser::parseOperandsDirective(SMLoc loc
, Context context
) {
3109 if (context
== RefDirectiveContext
) {
3110 if (!fmt
.allOperands
)
3111 return emitError(loc
, "'ref' of 'operands' is not bound by a prior "
3112 "'operands' directive");
3114 } else if (context
== TopLevelContext
|| context
== CustomDirectiveContext
) {
3115 if (fmt
.allOperands
|| !seenOperands
.empty())
3116 return emitError(loc
, "'operands' directive creates overlap in format");
3117 fmt
.allOperands
= true;
3119 return create
<OperandsDirective
>();
3122 FailureOr
<FormatElement
*>
3123 OpFormatParser::parseReferenceDirective(SMLoc loc
, Context context
) {
3124 if (context
!= CustomDirectiveContext
)
3125 return emitError(loc
, "'ref' is only valid within a `custom` directive");
3127 FailureOr
<FormatElement
*> arg
;
3128 if (failed(parseToken(FormatToken::l_paren
,
3129 "expected '(' before argument list")) ||
3130 failed(arg
= parseElement(RefDirectiveContext
)) ||
3132 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3135 return create
<RefDirective
>(*arg
);
3138 FailureOr
<FormatElement
*>
3139 OpFormatParser::parseRegionsDirective(SMLoc loc
, Context context
) {
3140 if (context
== TypeDirectiveContext
)
3141 return emitError(loc
, "'regions' is only valid as a top-level directive");
3142 if (context
== RefDirectiveContext
) {
3144 return emitError(loc
, "'ref' of 'regions' is not bound by a prior "
3145 "'regions' directive");
3147 // Otherwise, this is a TopLevel directive.
3149 if (hasAllRegions
|| !seenRegions
.empty())
3150 return emitError(loc
, "'regions' directive creates overlap in format");
3151 hasAllRegions
= true;
3153 return create
<RegionsDirective
>();
3156 FailureOr
<FormatElement
*>
3157 OpFormatParser::parseResultsDirective(SMLoc loc
, Context context
) {
3158 if (context
!= TypeDirectiveContext
)
3159 return emitError(loc
, "'results' directive can can only be used as a child "
3160 "to a 'type' directive");
3161 return create
<ResultsDirective
>();
3164 FailureOr
<FormatElement
*>
3165 OpFormatParser::parseSuccessorsDirective(SMLoc loc
, Context context
) {
3166 if (context
== TypeDirectiveContext
)
3167 return emitError(loc
,
3168 "'successors' is only valid as a top-level directive");
3169 if (context
== RefDirectiveContext
) {
3170 if (!hasAllSuccessors
)
3171 return emitError(loc
, "'ref' of 'successors' is not bound by a prior "
3172 "'successors' directive");
3174 // Otherwise, this is a TopLevel directive.
3176 if (hasAllSuccessors
|| !seenSuccessors
.empty())
3177 return emitError(loc
, "'successors' directive creates overlap in format");
3178 hasAllSuccessors
= true;
3180 return create
<SuccessorsDirective
>();
3183 FailureOr
<FormatElement
*>
3184 OpFormatParser::parseOIListDirective(SMLoc loc
, Context context
) {
3185 if (failed(parseToken(FormatToken::l_paren
,
3186 "expected '(' before oilist argument list")))
3188 std::vector
<FormatElement
*> literalElements
;
3189 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
3191 FailureOr
<FormatElement
*> lelement
= parseLiteral(context
);
3192 if (failed(lelement
))
3194 literalElements
.push_back(*lelement
);
3195 parsingElements
.emplace_back();
3196 std::vector
<FormatElement
*> &currParsingElements
= parsingElements
.back();
3197 while (peekToken().getKind() != FormatToken::pipe
&&
3198 peekToken().getKind() != FormatToken::r_paren
) {
3199 FailureOr
<FormatElement
*> pelement
= parseElement(context
);
3200 if (failed(pelement
) ||
3201 failed(verifyOIListParsingElement(*pelement
, loc
)))
3203 currParsingElements
.push_back(*pelement
);
3205 if (peekToken().getKind() == FormatToken::pipe
) {
3209 if (peekToken().getKind() == FormatToken::r_paren
) {
3215 return create
<OIListElement
>(std::move(literalElements
),
3216 std::move(parsingElements
));
3219 LogicalResult
OpFormatParser::verifyOIListParsingElement(FormatElement
*element
,
3221 SmallVector
<VariableElement
*> vars
;
3222 collect(element
, vars
);
3223 for (VariableElement
*elem
: vars
) {
3225 TypeSwitch
<FormatElement
*, LogicalResult
>(elem
)
3226 // Only optional attributes can be within an oilist parsing group.
3227 .Case([&](AttributeVariable
*attrEle
) {
3228 if (!attrEle
->getVar()->attr
.isOptional() &&
3229 !attrEle
->getVar()->attr
.hasDefaultValue())
3230 return emitError(loc
, "only optional attributes can be used in "
3231 "an oilist parsing group");
3234 // Only optional-like(i.e. variadic) operands can be within an
3235 // oilist parsing group.
3236 .Case([&](OperandVariable
*ele
) {
3237 if (!ele
->getVar()->isVariableLength())
3238 return emitError(loc
, "only variable length operands can be "
3239 "used within an oilist parsing group");
3242 // Only optional-like(i.e. variadic) results can be within an oilist
3244 .Case([&](ResultVariable
*ele
) {
3245 if (!ele
->getVar()->isVariableLength())
3246 return emitError(loc
, "only variable length results can be "
3247 "used within an oilist parsing group");
3250 .Case([&](RegionVariable
*) { return success(); })
3251 .Default([&](FormatElement
*) {
3252 return emitError(loc
,
3253 "only literals, types, and variables can be "
3254 "used within an oilist group");
3262 FailureOr
<FormatElement
*> OpFormatParser::parseTypeDirective(SMLoc loc
,
3264 if (context
== TypeDirectiveContext
)
3265 return emitError(loc
, "'type' cannot be used as a child of another `type`");
3267 bool isRefChild
= context
== RefDirectiveContext
;
3268 FailureOr
<FormatElement
*> operand
;
3269 if (failed(parseToken(FormatToken::l_paren
,
3270 "expected '(' before argument list")) ||
3271 failed(operand
= parseTypeDirectiveOperand(loc
, isRefChild
)) ||
3273 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3276 return create
<TypeDirective
>(*operand
);
3279 FailureOr
<FormatElement
*>
3280 OpFormatParser::parseQualifiedDirective(SMLoc loc
, Context context
) {
3281 FailureOr
<FormatElement
*> element
;
3282 if (failed(parseToken(FormatToken::l_paren
,
3283 "expected '(' before argument list")) ||
3284 failed(element
= parseElement(context
)) ||
3286 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3288 return TypeSwitch
<FormatElement
*, FailureOr
<FormatElement
*>>(*element
)
3289 .Case
<AttributeVariable
, TypeDirective
>([](auto *element
) {
3290 element
->setShouldBeQualified();
3293 .Default([&](auto *element
) {
3294 return this->emitError(
3296 "'qualified' directive expects an attribute or a `type` directive");
3300 FailureOr
<FormatElement
*>
3301 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc
, bool isRefChild
) {
3302 FailureOr
<FormatElement
*> result
= parseElement(TypeDirectiveContext
);
3306 FormatElement
*element
= *result
;
3307 if (isa
<LiteralElement
>(element
))
3309 loc
, "'type' directive operand expects variable or directive operand");
3311 if (auto *var
= dyn_cast
<OperandVariable
>(element
)) {
3312 unsigned opIdx
= var
->getVar() - op
.operand_begin();
3313 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3314 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3315 "' is already bound");
3316 if (isRefChild
&& !(fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3317 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3318 ")' is not bound by a prior 'type' directive");
3319 seenOperandTypes
.set(opIdx
);
3320 } else if (auto *var
= dyn_cast
<ResultVariable
>(element
)) {
3321 unsigned resIdx
= var
->getVar() - op
.result_begin();
3322 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3323 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3324 "' is already bound");
3325 if (isRefChild
&& !(fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3326 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3327 ")' is not bound by a prior 'type' directive");
3328 seenResultTypes
.set(resIdx
);
3329 } else if (isa
<OperandsDirective
>(&*element
)) {
3330 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.any()))
3331 return emitError(loc
, "'operands' 'type' is already bound");
3332 if (isRefChild
&& !fmt
.allOperandTypes
)
3333 return emitError(loc
, "'ref' of 'type(operands)' is not bound by a prior "
3334 "'type' directive");
3335 fmt
.allOperandTypes
= true;
3336 } else if (isa
<ResultsDirective
>(&*element
)) {
3337 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.any()))
3338 return emitError(loc
, "'results' 'type' is already bound");
3339 if (isRefChild
&& !fmt
.allResultTypes
)
3340 return emitError(loc
, "'ref' of 'type(results)' is not bound by a prior "
3341 "'type' directive");
3342 fmt
.allResultTypes
= true;
3344 return emitError(loc
, "invalid argument to 'type' directive");
3349 LogicalResult
OpFormatParser::verifyOptionalGroupElements(
3350 SMLoc loc
, ArrayRef
<FormatElement
*> elements
, FormatElement
*anchor
) {
3351 for (FormatElement
*element
: elements
) {
3352 if (failed(verifyOptionalGroupElement(loc
, element
, element
== anchor
)))
3358 LogicalResult
OpFormatParser::verifyOptionalGroupElement(SMLoc loc
,
3359 FormatElement
*element
,
3361 return TypeSwitch
<FormatElement
*, LogicalResult
>(element
)
3362 // All attributes can be within the optional group, but only optional
3363 // attributes can be the anchor.
3364 .Case([&](AttributeVariable
*attrEle
) {
3365 Attribute attr
= attrEle
->getVar()->attr
;
3366 if (isAnchor
&& !(attr
.isOptional() || attr
.hasDefaultValue()))
3367 return emitError(loc
, "only optional or default-valued attributes "
3368 "can be used to anchor an optional group");
3371 // Only optional-like(i.e. variadic) operands can be within an optional
3373 .Case([&](OperandVariable
*ele
) {
3374 if (!ele
->getVar()->isVariableLength())
3375 return emitError(loc
, "only variable length operands can be used "
3376 "within an optional group");
3379 // Only optional-like(i.e. variadic) results can be within an optional
3381 .Case([&](ResultVariable
*ele
) {
3382 if (!ele
->getVar()->isVariableLength())
3383 return emitError(loc
, "only variable length results can be used "
3384 "within an optional group");
3387 .Case([&](RegionVariable
*) {
3388 // TODO: When ODS has proper support for marking "optional" regions, add
3392 .Case([&](TypeDirective
*ele
) {
3393 return verifyOptionalGroupElement(loc
, ele
->getArg(),
3394 /*isAnchor=*/false);
3396 .Case([&](FunctionalTypeDirective
*ele
) {
3397 if (failed(verifyOptionalGroupElement(loc
, ele
->getInputs(),
3398 /*isAnchor=*/false)))
3400 return verifyOptionalGroupElement(loc
, ele
->getResults(),
3401 /*isAnchor=*/false);
3403 // Literals, whitespace, and custom directives may be used, but they can't
3404 // anchor the group.
3405 .Case
<LiteralElement
, WhitespaceElement
, CustomDirective
,
3406 FunctionalTypeDirective
, OptionalElement
>([&](FormatElement
*) {
3408 return emitError(loc
, "only variables and types can be used "
3409 "to anchor an optional group");
3412 .Default([&](FormatElement
*) {
3413 return emitError(loc
, "only literals, types, and variables can be "
3414 "used within an optional group");
3418 //===----------------------------------------------------------------------===//
3420 //===----------------------------------------------------------------------===//
3422 void mlir::tblgen::generateOpFormat(const Operator
&constOp
, OpClass
&opClass
) {
3423 // TODO: Operator doesn't expose all necessary functionality via
3424 // the const interface.
3425 Operator
&op
= const_cast<Operator
&>(constOp
);
3426 if (!op
.hasAssemblyFormat())
3429 // Parse the format description.
3430 llvm::SourceMgr mgr
;
3431 mgr
.AddNewSourceBuffer(
3432 llvm::MemoryBuffer::getMemBuffer(op
.getAssemblyFormat()), SMLoc());
3433 OperationFormat
format(op
);
3434 OpFormatParser
parser(mgr
, format
, op
);
3435 FailureOr
<std::vector
<FormatElement
*>> elements
= parser
.parse();
3436 if (failed(elements
)) {
3437 // Exit the process if format errors are treated as fatal.
3438 if (formatErrorIsFatal
) {
3439 // Invoke the interrupt handlers to run the file cleanup handlers.
3440 llvm::sys::RunInterruptHandlers();
3445 format
.elements
= std::move(*elements
);
3447 // Generate the printer and parser based on the parsed format.
3448 format
.genParser(op
, opClass
);
3449 format
.genPrinter(op
, opClass
);