1 //===- OpFormatGen.cpp - MLIR operation asm format generator --------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "OpFormatGen.h"
10 #include "FormatGen.h"
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/TableGen/Class.h"
14 #include "mlir/TableGen/Format.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "mlir/TableGen/Trait.h"
17 #include "llvm/ADT/MapVector.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Signals.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/TableGen/Record.h"
27 #define DEBUG_TYPE "mlir-tblgen-opformatgen"
30 using namespace mlir::tblgen
;
33 using llvm::StringMap
;
35 //===----------------------------------------------------------------------===//
39 /// This class represents an instance of an op variable element. A variable
40 /// refers to something registered on the operation itself, e.g. an operand,
41 /// result, attribute, region, or successor.
42 template <typename VarT
, VariableElement::Kind VariableKind
>
43 class OpVariableElement
: public VariableElementBase
<VariableKind
> {
45 using Base
= OpVariableElement
<VarT
, VariableKind
>;
47 /// Create an op variable element with the variable value.
48 OpVariableElement(const VarT
*var
) : var(var
) {}
51 const VarT
*getVar() const { return var
; }
54 /// The op variable, e.g. a type or attribute constraint.
58 /// This class represents a variable that refers to an attribute argument.
59 struct AttributeVariable
60 : public OpVariableElement
<NamedAttribute
, VariableElement::Attribute
> {
63 /// Return the constant builder call for the type of this attribute, or
64 /// std::nullopt if it doesn't have one.
65 std::optional
<StringRef
> getTypeBuilder() const {
66 std::optional
<Type
> attrType
= var
->attr
.getValueType();
67 return attrType
? attrType
->getBuilderCall() : std::nullopt
;
70 /// Indicate if this attribute is printed "qualified" (that is it is
71 /// prefixed with the `#dialect.mnemonic`).
72 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
73 void setShouldBeQualified(bool qualified
= true) {
74 shouldBeQualifiedFlag
= qualified
;
78 bool shouldBeQualifiedFlag
= false;
81 /// This class represents a variable that refers to an operand argument.
82 using OperandVariable
=
83 OpVariableElement
<NamedTypeConstraint
, VariableElement::Operand
>;
85 /// This class represents a variable that refers to a result.
86 using ResultVariable
=
87 OpVariableElement
<NamedTypeConstraint
, VariableElement::Result
>;
89 /// This class represents a variable that refers to a region.
90 using RegionVariable
= OpVariableElement
<NamedRegion
, VariableElement::Region
>;
92 /// This class represents a variable that refers to a successor.
93 using SuccessorVariable
=
94 OpVariableElement
<NamedSuccessor
, VariableElement::Successor
>;
96 /// This class represents a variable that refers to a property argument.
97 using PropertyVariable
=
98 OpVariableElement
<NamedProperty
, VariableElement::Property
>;
100 /// LLVM RTTI helper for attribute-like variables, that is, attributes or
101 /// properties. This allows for common handling of attributes and properties in
102 /// parts of the code that are oblivious to whether something is stored as an
103 /// attribute or a property.
104 struct AttributeLikeVariable
: public VariableElement
{
// NOTE(review): this enumerator is not referenced anywhere in the visible
// code; confirm whether it is still needed.
105 enum { AttributeLike
= 1 << 0 };
// RTTI hook: matches any variable element whose kind is Attribute or Property.
107 static bool classof(const VariableElement
*ve
) {
108 return ve
->getKind() == VariableElement::Attribute
||
109 ve
->getKind() == VariableElement::Property
;
// RTTI hook for generic format elements: the element must first be a
// VariableElement, and then satisfy the overload above.
112 static bool classof(const FormatElement
*fe
) {
113 return isa
<VariableElement
>(fe
) && classof(cast
<VariableElement
>(fe
));
116 /// Returns true if the variable is a UnitAttr or a UnitProperty.
// Attributes are checked by their base attribute definition name, properties by
// their base property definition name. No other kind can get here, because
// classof() only admits attributes and properties.
117 bool isUnit() const {
118 if (const auto *attr
= dyn_cast
<AttributeVariable
>(this))
119 return attr
->getVar()->attr
.getBaseAttr().getAttrDefName() == "UnitAttr";
120 if (const auto *prop
= dyn_cast
<PropertyVariable
>(this)) {
121 return prop
->getVar()->prop
.getBaseProperty().getPropertyDefName() ==
124 llvm_unreachable("Type that wasn't listed in classof()");
/// Returns the name of the underlying attribute or property variable.
127 StringRef
getName() const {
128 if (const auto *attr
= dyn_cast
<AttributeVariable
>(this))
129 return attr
->getVar()->name
;
130 if (const auto *prop
= dyn_cast
<PropertyVariable
>(this))
131 return prop
->getVar()->name
;
132 llvm_unreachable("Type that wasn't listed in classof()");
137 //===----------------------------------------------------------------------===//
141 /// This class represents the `operands` directive. This directive represents
142 /// all of the operands of an operation.
143 using OperandsDirective
= DirectiveElementBase
<DirectiveElement::Operands
>;
145 /// This class represents the `results` directive. This directive represents
146 /// all of the results of an operation.
147 using ResultsDirective
= DirectiveElementBase
<DirectiveElement::Results
>;
149 /// This class represents the `regions` directive. This directive represents
150 /// all of the regions of an operation.
151 using RegionsDirective
= DirectiveElementBase
<DirectiveElement::Regions
>;
153 /// This class represents the `successors` directive. This directive represents
154 /// all of the successors of an operation.
155 using SuccessorsDirective
= DirectiveElementBase
<DirectiveElement::Successors
>;
157 /// This class represents the `attr-dict` directive. This directive represents
158 /// the attribute dictionary of the operation.
159 class AttrDictDirective
160 : public DirectiveElementBase
<DirectiveElement::AttrDict
> {
162 explicit AttrDictDirective(bool withKeyword
) : withKeyword(withKeyword
) {}
164 /// Return whether the dictionary should be printed with the 'attributes'
166 bool isWithKeyword() const { return withKeyword
; }
169 /// If the dictionary should be printed with the 'attributes' keyword.
173 /// This class represents the `prop-dict` directive. This directive represents
174 /// the properties of the operation, expressed as a directionary.
175 class PropDictDirective
176 : public DirectiveElementBase
<DirectiveElement::PropDict
> {
178 explicit PropDictDirective() = default;
181 /// This class represents the `functional-type` directive. This directive takes
182 /// two arguments and formats them, respectively, as the inputs and results of a
184 class FunctionalTypeDirective
185 : public DirectiveElementBase
<DirectiveElement::FunctionalType
> {
187 FunctionalTypeDirective(FormatElement
*inputs
, FormatElement
*results
)
188 : inputs(inputs
), results(results
) {}
190 FormatElement
*getInputs() const { return inputs
; }
191 FormatElement
*getResults() const { return results
; }
194 /// The input and result arguments.
195 FormatElement
*inputs
, *results
;
198 /// This class represents the `type` directive.
199 class TypeDirective
: public DirectiveElementBase
<DirectiveElement::Type
> {
201 TypeDirective(FormatElement
*arg
) : arg(arg
) {}
203 FormatElement
*getArg() const { return arg
; }
205 /// Indicate if this type is printed "qualified" (that is it is
206 /// prefixed with the `!dialect.mnemonic`).
207 bool shouldBeQualified() { return shouldBeQualifiedFlag
; }
208 void setShouldBeQualified(bool qualified
= true) {
209 shouldBeQualifiedFlag
= qualified
;
213 /// The argument that is used to format the directive.
216 bool shouldBeQualifiedFlag
= false;
219 /// This class represents a group of order-independent optional clauses. Each
220 /// clause starts with a literal element and has a coressponding parsing
221 /// element. A parsing element is a continous sequence of format elements.
222 /// Each clause can appear 0 or 1 time.
223 class OIListElement
: public DirectiveElementBase
<DirectiveElement::OIList
> {
225 OIListElement(std::vector
<FormatElement
*> &&literalElements
,
226 std::vector
<std::vector
<FormatElement
*>> &&parsingElements
)
227 : literalElements(std::move(literalElements
)),
228 parsingElements(std::move(parsingElements
)) {}
230 /// Returns a range to iterate over the LiteralElements.
231 auto getLiteralElements() const {
232 return llvm::map_range(literalElements
, [](FormatElement
*el
) {
233 return cast
<LiteralElement
>(el
);
237 /// Returns a range to iterate over the parsing elements corresponding to the
239 ArrayRef
<std::vector
<FormatElement
*>> getParsingElements() const {
240 return parsingElements
;
243 /// Returns a range to iterate over tuples of parsing and literal elements.
244 auto getClauses() const {
245 return llvm::zip(getLiteralElements(), getParsingElements());
248 /// If the parsing element is a single UnitAttr element, then it returns the
249 /// attribute variable. Otherwise, returns nullptr.
250 AttributeLikeVariable
*
251 getUnitVariableParsingElement(ArrayRef
<FormatElement
*> pelement
) {
252 if (pelement
.size() == 1) {
253 auto *attrElem
= dyn_cast
<AttributeLikeVariable
>(pelement
[0]);
254 if (attrElem
&& attrElem
->isUnit())
261 /// A vector of `LiteralElement` objects. Each element stores the keyword
262 /// for one case of oilist element. For example, an oilist element along with
263 /// the `literalElements` vector:
265 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
266 /// literalElements = { `keyword`, `otherKeyword` }
268 std::vector
<FormatElement
*> literalElements
;
270 /// A vector of valid declarative assembly format vectors. Each object in
271 /// parsing elements is a vector of elements in assembly format syntax.
272 /// For example, an oilist element along with the parsingElements vector:
274 /// oilist [ `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>`]
275 /// parsingElements = {
276 /// { `=`, `(`, $arg0, `)` },
277 /// { `<`, $arg1, `>` }
280 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
284 //===----------------------------------------------------------------------===//
286 //===----------------------------------------------------------------------===//
290 using ConstArgument
=
291 llvm::PointerUnion
<const NamedAttribute
*, const NamedTypeConstraint
*>;
// Holds the state collected from an operation's declarative assembly format,
// and generates the C++ parser and printer for the operation from it.
293 struct OperationFormat
{
294 /// This class represents a specific resolver for an operand or result type.
295 class TypeResolution
{
297 TypeResolution() = default;
299 /// Get the index into the buildable types for this type, or std::nullopt.
300 std::optional
<int> getBuilderIdx() const { return builderIdx
; }
301 void setBuilderIdx(int idx
) { builderIdx
= idx
; }
303 /// Get the variable this type is resolved to, or nullptr.
304 const NamedTypeConstraint
*getVariable() const {
305 return llvm::dyn_cast_if_present
<const NamedTypeConstraint
*>(resolver
);
307 /// Get the attribute this type is resolved to, or nullptr.
308 const NamedAttribute
*getAttribute() const {
309 return llvm::dyn_cast_if_present
<const NamedAttribute
*>(resolver
);
311 /// Get the transformer for the type of the variable, or std::nullopt.
312 std::optional
<StringRef
> getVarTransformer() const {
313 return variableTransformer
;
/// Resolve this type from the given attribute or operand/result type
/// constraint, optionally applying `transformer` to it.
// NOTE(review): in this chunk no assignment of `arg` to `resolver` is visible,
// yet the assert below reads `resolver` through getVariable()/getAttribute();
// confirm that setResolver stores `arg`.
315 void setResolver(ConstArgument arg
, std::optional
<StringRef
> transformer
) {
317 variableTransformer
= transformer
;
318 assert(getVariable() || getAttribute());
322 /// If the type is resolved with a buildable type, this is the index into
323 /// 'buildableTypes' in the parent format.
324 std::optional
<int> builderIdx
;
325 /// If the type is resolved based upon another operand or result, this is
326 /// the variable or the attribute that this type is resolved to.
327 ConstArgument resolver
;
328 /// If the type is resolved based upon another operand or result, this is
329 /// a transformer to apply to the variable when resolving.
330 std::optional
<StringRef
> variableTransformer
;
333 /// The context in which an element is generated.
334 enum class GenContext
{
335 /// The element is generated at the top-level or with the same behaviour.
337 /// The element is generated inside an optional group.
/// Create the format state for `op`: one default-constructed TypeResolution
/// per operand and per result. `hasProperties` initializes `useProperties`.
341 OperationFormat(const Operator
&op
, bool hasProperties
)
342 : useProperties(hasProperties
), opCppClassName(op
.getCppClassName()) {
343 operandTypes
.resize(op
.getNumOperands(), TypeResolution());
344 resultTypes
.resize(op
.getNumResults(), TypeResolution());
// Record whether the op has a trait derived from
// SingleBlockImplicitTerminatorImpl, and whether it has the SingleBlock trait.
346 hasImplicitTermTrait
= llvm::any_of(op
.getTraits(), [](const Trait
&trait
) {
347 return trait
.getDef().isSubClassOf("SingleBlockImplicitTerminatorImpl");
350 hasSingleBlockTrait
= op
.getTrait("::mlir::OpTrait::SingleBlock");
353 /// Generate the operation parser from this format.
354 void genParser(Operator
&op
, OpClass
&opClass
);
355 /// Generate the parser code for a specific format element.
356 void genElementParser(FormatElement
*element
, MethodBody
&body
,
357 FmtContext
&attrTypeCtx
,
358 GenContext genCtx
= GenContext::Normal
);
359 /// Generate the C++ to resolve the types of operands and results during
/// parsing.
361 void genParserTypeResolution(Operator
&op
, MethodBody
&body
);
362 /// Generate the C++ to resolve the types of the operands during parsing.
363 void genParserOperandTypeResolution(
364 Operator
&op
, MethodBody
&body
,
365 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
);
366 /// Generate the C++ to resolve regions during parsing.
367 void genParserRegionResolution(Operator
&op
, MethodBody
&body
);
368 /// Generate the C++ to resolve successors during parsing.
369 void genParserSuccessorResolution(Operator
&op
, MethodBody
&body
);
370 /// Generate the C++ for handling variadic segment size traits.
371 void genParserVariadicSegmentResolution(Operator
&op
, MethodBody
&body
);
373 /// Generate the operation printer from this format.
374 void genPrinter(Operator
&op
, OpClass
&opClass
);
376 /// Generate the printer code for a specific format element.
377 void genElementPrinter(FormatElement
*element
, MethodBody
&body
, Operator
&op
,
378 bool &shouldEmitSpace
, bool &lastWasPunctuation
);
380 /// The various elements in this format.
381 std::vector
<FormatElement
*> elements
;
383 /// A flag indicating if all operand/result types were seen. If the format
384 /// contains these, it can not contain individual type resolvers.
385 bool allOperands
= false, allOperandTypes
= false, allResultTypes
= false;
387 /// A flag indicating if this operation infers its result types.
388 bool infersResultTypes
= false;
390 /// A flag indicating if this operation has the SingleBlockImplicitTerminator
/// trait (any trait derived from SingleBlockImplicitTerminatorImpl).
392 bool hasImplicitTermTrait
;
394 /// A flag indicating if this operation has the SingleBlock trait.
395 bool hasSingleBlockTrait
;
397 /// Indicate whether we need to use properties for the current operator.
400 /// Indicate whether prop-dict is used in the format
403 /// The Operation class name
404 StringRef opCppClassName
;
406 /// A map of buildable types to indices.
407 llvm::MapVector
<StringRef
, int, StringMap
<int>> buildableTypes
;
409 /// The index of the buildable type, if valid, for every operand and result.
410 std::vector
<TypeResolution
> operandTypes
, resultTypes
;
412 /// The set of attributes explicitly used within the format.
413 llvm::SmallSetVector
<const NamedAttribute
*, 8> usedAttributes
;
// NOTE(review): presumably the names of attributes whose values are inferred
// rather than spelled out in the format; confirm against the uses.
414 llvm::StringSet
<> inferredAttributes
;
416 /// The set of properties explicitly used within the format.
417 llvm::SmallSetVector
<const NamedProperty
*, 8> usedProperties
;
421 //===----------------------------------------------------------------------===//
424 /// Returns true if we can format the given attribute as an EnumAttr in the
426 static bool canFormatEnumAttr(const NamedAttribute
*attr
) {
427 Attribute baseAttr
= attr
->attr
.getBaseAttr();
428 const EnumAttr
*enumAttr
= dyn_cast
<EnumAttr
>(&baseAttr
);
432 // The attribute must have a valid underlying type and a constant builder.
433 return !enumAttr
->getUnderlyingType().empty() &&
434 !enumAttr
->getConstBuilderTemplate().empty();
437 /// Returns if we should format the given attribute as an SymbolNameAttr.
438 static bool shouldFormatSymbolNameAttr(const NamedAttribute
*attr
) {
439 return attr
->attr
.getBaseAttr().getAttrDefName() == "SymbolNameAttr";
// The strings below are C++ snippet templates for the generated parser. `{N}`
// are positional placeholders and `{{` is an escaped literal brace; presumably
// they are expanded with formatv — confirm at the use sites.
442 /// The code snippet used to generate a parser call for an attribute.
444 /// {0}: The name of the attribute.
445 /// {1}: The type for the attribute.
446 const char *const attrParserCode
= R
"(
447 if (parser.parseCustomAttributeWithFallback({0}Attr, {1})) {{
448 return ::mlir::failure();
452 /// The code snippet used to generate a generic `parser.parseAttribute` call for an attribute.
454 /// {0}: The name of the attribute.
455 /// {1}: The type for the attribute.
456 const char *const genericAttrParserCode
= R
"(
457 if (parser.parseAttribute({0}Attr, {1}))
458 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional attribute.
/// {0}: The name of the attribute.
/// {1}: The type for the attribute.
461 const char *const optionalAttrParserCode
= R
"(
462 ::mlir::OptionalParseResult parseResult{0}Attr =
463 parser.parseOptionalAttribute({0}Attr, {1});
464 if (parseResult{0}Attr.has_value() && failed(*parseResult{0}Attr))
465 return ::mlir::failure();
466 if (parseResult{0}Attr.has_value() && succeeded(*parseResult{0}Attr))
469 /// The code snippet used to generate a parser call for a symbol name attribute.
471 /// {0}: The name of the attribute.
472 const char *const symbolNameAttrParserCode
= R
"(
473 if (parser.parseSymbolName({0}Attr))
474 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional symbol name
/// attribute.
/// {0}: The name of the attribute.
476 const char *const optionalSymbolNameAttrParserCode
= R
"(
477 // Parsing an optional symbol name doesn't fail, so no need to check the
479 (void)parser.parseOptionalSymbolName({0}Attr);
482 /// The code snippet used to generate a parser call for an enum attribute.
484 /// {0}: The name of the attribute.
485 /// {1}: The c++ namespace for the enum symbolize functions.
486 /// {2}: The function to symbolize a string of the enum.
487 /// {3}: The constant builder call to create an attribute of the enum type.
488 /// {4}: The set of allowed enum keywords.
489 /// {5}: The error message on failure when the enum isn't present.
490 /// {6}: The attribute assignment expression
491 const char *const enumAttrParserCode
= R
"(
493 ::llvm::StringRef attrStr;
494 ::mlir::NamedAttrList attrStorage;
495 auto loc = parser.getCurrentLocation();
496 if (parser.parseOptionalKeyword(&attrStr, {4})) {
497 ::mlir::StringAttr attrVal;
498 ::mlir::OptionalParseResult parseResult =
499 parser.parseOptionalAttribute(attrVal,
500 parser.getBuilder().getNoneType(),
502 if (parseResult.has_value()) {{
503 if (failed(*parseResult))
504 return ::mlir::failure();
505 attrStr = attrVal.getValue();
510 if (!attrStr.empty()) {
511 auto attrOptional = {1}::{2}(attrStr);
513 return parser.emitError(loc, "invalid
")
514 << "{0} attribute specification
: \"" << attrStr << '"';;
522 /// The code snippet used to generate a parser call for a property.
523 /// {0}: The name of the property
524 /// {1}: The C++ class name of the operation
525 /// {2}: The property's parser code with appropriate substitutions performed
526 /// {3}: The description of the expected property for the error message.
527 const char *const propertyParserCode
= R
"(
528 auto {0}PropLoc = parser.getCurrentLocation();
529 auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::ParseResult {{
531 return ::mlir::success();
532 }(result.getOrAddProperties<{1}::Properties>().{0});
533 if (failed({0}PropParseResult)) {{
534 return parser.emitError({0}PropLoc, "invalid value
for property
{0}, expected
{3}");
538 /// The code snippet used to generate a parser call for an optional property.
539 /// {0}: The name of the property
540 /// {1}: The C++ class name of the operation
541 /// {2}: The property's parser code with appropriate substitutions performed
542 const char *const optionalPropertyParserCode
= R
"(
543 auto {0}PropParseResult = [&](auto& propStorage) -> ::mlir::OptionalParseResult {{
545 return ::mlir::success();
546 }(result.getOrAddProperties<{1}::Properties>().{0});
547 if ({0}PropParseResult.has_value() && failed(*{0}PropParseResult)) {{
548 return ::mlir::failure();
552 /// The code snippet used to generate a parser call for an operand.
554 /// {0}: The name of the operand.
555 const char *const variadicOperandParserCode
= R
"(
556 {0}OperandsLoc = parser.getCurrentLocation();
557 if (parser.parseOperandList({0}Operands))
558 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional operand.
/// {0}: The name of the operand.
560 const char *const optionalOperandParserCode
= R
"(
562 {0}OperandsLoc = parser.getCurrentLocation();
563 ::mlir::OpAsmParser::UnresolvedOperand operand;
564 ::mlir::OptionalParseResult parseResult =
565 parser.parseOptionalOperand(operand);
566 if (parseResult.has_value()) {
567 if (failed(*parseResult))
568 return ::mlir::failure();
569 {0}Operands.push_back(operand);
/// The code snippet used to generate a parser call for a single operand.
/// {0}: The name of the operand.
573 const char *const operandParserCode
= R
"(
574 {0}OperandsLoc = parser.getCurrentLocation();
575 if (parser.parseOperand({0}RawOperand))
576 return ::mlir::failure();
578 /// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand, written as a comma-separated list of parenthesized operand groups.
581 /// {0}: The name of the operand.
582 /// {1}: The name of segment size attribute.
583 const char *const variadicOfVariadicOperandParserCode
= R
"(
585 {0}OperandsLoc = parser.getCurrentLocation();
588 if (parser.parseOptionalLParen())
590 if (parser.parseOperandList({0}Operands) || parser.parseRParen())
591 return ::mlir::failure();
592 {0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
593 curSize = {0}Operands.size();
594 } while (succeeded(parser.parseOptionalComma()));
598 /// The code snippet used to generate a parser call for a variadic-of-variadic type list (comma-separated, parenthesized groups).
600 /// {0}: The name for the type list.
601 const char *const variadicOfVariadicTypeParserCode
= R
"(
603 if (parser.parseOptionalLParen())
605 if (parser.parseOptionalRParen() &&
606 (parser.parseTypeList({0}Types) || parser.parseRParen()))
607 return ::mlir::failure();
608 } while (succeeded(parser.parseOptionalComma()));
/// The code snippet used to generate a parser call for a variadic type list.
/// {0}: The name for the type list.
610 const char *const variadicTypeParserCode
= R
"(
611 if (parser.parseTypeList({0}Types))
612 return ::mlir::failure();
/// The code snippet used to generate a parser call for an optional type.
/// {0}: The name for the type list.
614 const char *const optionalTypeParserCode
= R
"(
616 ::mlir::Type optionalType;
617 ::mlir::OptionalParseResult parseResult =
618 parser.parseOptionalType(optionalType);
619 if (parseResult.has_value()) {
620 if (failed(*parseResult))
621 return ::mlir::failure();
622 {0}Types.push_back(optionalType);
/// The code snippet used to generate a parser call for a single type via
/// `parser.parseCustomTypeWithFallback`.
626 const char *const typeParserCode
= R
"(
629 if (parser.parseCustomTypeWithFallback(type))
630 return ::mlir::failure();
/// The code snippet used to generate a parser call for a single type via
/// `parser.parseType`, storing the result into `{1}RawType`.
634 const char *const qualifiedTypeParserCode
= R
"(
635 if (parser.parseType({1}RawType))
636 return ::mlir::failure();
639 /// The code snippet used to generate a parser call for a functional type.
641 /// {0}: The name for the input type list.
642 /// {1}: The name for the result type list.
643 const char *const functionalTypeParserCode
= R
"(
644 ::mlir::FunctionType {0}__{1}_functionType;
645 if (parser.parseType({0}__{1}_functionType))
646 return ::mlir::failure();
647 {0}Types = {0}__{1}_functionType.getInputs();
648 {1}Types = {0}__{1}_functionType.getResults();
651 /// The code snippet used to generate a parser call to infer return types.
653 /// {0}: The operation class name
654 const char *const inferReturnTypesParserCode
= R
"(
655 ::llvm::SmallVector<::mlir::Type> inferredReturnTypes;
656 if (::mlir::failed({0}::inferReturnTypes(parser.getContext(),
657 result.location, result.operands,
658 result.attributes.getDictionary(parser.getContext()),
659 result.getRawProperties(),
660 result.regions, inferredReturnTypes)))
661 return ::mlir::failure();
662 result.addTypes(inferredReturnTypes);
665 /// The code snippet used to generate a parser call for a region list.
// The pointers below are `const char *const`, like the snippets above, so
// none of them can be reseated.
667 /// {0}: The name for the region list.
668 const char *const regionListParserCode
= R
"(
670 std::unique_ptr<::mlir::Region> region;
671 auto firstRegionResult = parser.parseOptionalRegion(region);
672 if (firstRegionResult.has_value()) {
673 if (failed(*firstRegionResult))
674 return ::mlir::failure();
675 {0}Regions.emplace_back(std::move(region));
677 // Parse any trailing regions.
678 while (succeeded(parser.parseOptionalComma())) {
679 region = std::make_unique<::mlir::Region>();
680 if (parser.parseRegion(*region))
681 return ::mlir::failure();
682 {0}Regions.emplace_back(std::move(region));
688 /// The code snippet used to ensure every region in a list has a terminator.
690 /// {0}: The name of the region list.
691 const char *const regionListEnsureTerminatorParserCode
= R
"(
692 for (auto &region : {0}Regions)
693 ensureTerminator(*region, parser.getBuilder(), result.location);
696 /// The code snippet used to ensure every region in a list has a block.
698 /// {0}: The name of the region list.
699 const char *const regionListEnsureSingleBlockParserCode
= R
"(
700 for (auto &region : {0}Regions)
701 if (region->empty()) region->emplaceBlock();
704 /// The code snippet used to generate a parser call for an optional region.
706 /// {0}: The name of the region.
707 const char *const optionalRegionParserCode
= R
"(
709 auto parseResult = parser.parseOptionalRegion(*{0}Region);
710 if (parseResult.has_value() && failed(*parseResult))
711 return ::mlir::failure();
715 /// The code snippet used to generate a parser call for a region.
717 /// {0}: The name of the region.
718 const char *const regionParserCode
= R
"(
719 if (parser.parseRegion(*{0}Region))
720 return ::mlir::failure();
723 /// The code snippet used to ensure a region has a terminator.
725 /// {0}: The name of the region.
726 const char *const regionEnsureTerminatorParserCode
= R
"(
727 ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
730 /// The code snippet used to ensure a region has a block.
732 /// {0}: The name of the region.
733 const char *const regionEnsureSingleBlockParserCode
= R
"(
734 if ({0}Region->empty()) {0}Region->emplaceBlock();
737 /// The code snippet used to generate a parser call for a successor list.
739 /// {0}: The name for the successor list.
// Declared `const char *const` for consistency with the other snippets, so the
// pointer cannot be reseated.
740 const char *const successorListParserCode
= R
"(
743 auto firstSucc = parser.parseOptionalSuccessor(succ);
744 if (firstSucc.has_value()) {
745 if (failed(*firstSucc))
746 return ::mlir::failure();
747 {0}Successors.emplace_back(succ);
749 // Parse any trailing successors.
750 while (succeeded(parser.parseOptionalComma())) {
751 if (parser.parseSuccessor(succ))
752 return ::mlir::failure();
753 {0}Successors.emplace_back(succ);
759 /// The code snippet used to generate a parser call for a successor.
761 /// {0}: The name of the successor.
762 const char *const successorParserCode
= R
"(
763 if (parser.parseSuccessor({0}Successor))
764 return ::mlir::failure();
767 /// The code snippet used to generate a parser for OIList
769 /// {0}: literal keyword corresponding to a case for oilist
770 const char *const oilistParserCode
= R
"(
772 return parser.emitError(parser.getNameLoc())
773 << "`
{0}` clause can appear at most once in the expansion of the
"
/// How many values a parsed argument may stand for.
enum class ArgumentLengthKind {
  /// The argument is a variadic of a variadic, and may contain 0->N ranges of
  /// elements.
  VariadicOfVariadic,
  /// The argument is variadic, and may contain 0->N elements.
  Variadic,
  /// The argument is optional, and may contain 0 or 1 elements.
  Optional,
  /// The argument is a single element, i.e. always represents 1 element.
  Single
};
794 /// Get the length kind for the given constraint.
795 static ArgumentLengthKind
796 getArgumentLengthKind(const NamedTypeConstraint
*var
) {
797 if (var
->isOptional())
798 return ArgumentLengthKind::Optional
;
799 if (var
->isVariadicOfVariadic())
800 return ArgumentLengthKind::VariadicOfVariadic
;
801 if (var
->isVariadic())
802 return ArgumentLengthKind::Variadic
;
803 return ArgumentLengthKind::Single
;
806 /// Get the name used for the type list for the given type directive operand.
807 /// Sets 'lengthKind' to the corresponding kind for the given argument.
808 static StringRef
getTypeListName(FormatElement
*arg
,
809 ArgumentLengthKind
&lengthKind
) {
810 if (auto *operand
= dyn_cast
<OperandVariable
>(arg
)) {
811 lengthKind
= getArgumentLengthKind(operand
->getVar());
812 return operand
->getVar()->name
;
814 if (auto *result
= dyn_cast
<ResultVariable
>(arg
)) {
815 lengthKind
= getArgumentLengthKind(result
->getVar());
816 return result
->getVar()->name
;
// The `operands`/`results` directives stand for every operand/result of the
// operation, so their type lists are always treated as variadic.
818 lengthKind
= ArgumentLengthKind::Variadic
;
// NOTE(review): the return statements for the two directive cases below are
// not visible in this chunk; confirm that each `isa` check returns its list
// name rather than falling into the next check.
819 if (isa
<OperandsDirective
>(arg
))
821 if (isa
<ResultsDirective
>(arg
))
823 llvm_unreachable("unknown 'type' directive argument");
826 /// Generate the parser for a literal value.
/// Appends to `body` the tail of a parser call for `value`: `Keyword("...")`
/// for keywords/identifiers, or the named punctuation hook (e.g. `Arrow()`)
/// otherwise. Presumably the caller has already emitted the `parser.parse` /
/// `parser.parseOptional` prefix — confirm at the call sites.
// NOTE(review): `value.front()` assumes `value` is non-empty.
827 static void genLiteralParser(StringRef value
, MethodBody
&body
) {
828 // Handle the case of a keyword/identifier.
// llvm::isAlpha is used instead of ::isalpha: passing a plain `char` holding a
// negative value (a non-ASCII byte) to ::isalpha is undefined behavior.
829 if (value
.front() == '_' || llvm::isAlpha(value
.front())) {
830 body
<< "Keyword(\"" << value
<< "\")";
// NOTE(review): no early return is visible here in this chunk. The keyword
// emission above and the punctuation switch below appear to be alternatives;
// confirm that the keyword branch returns before reaching the switch.
// StringSwitch has no `.Default` here, so a literal not listed below trips its
// "fell off the end" assertion on conversion.
833 body
<< (StringRef
)StringSwitch
<StringRef
>(value
)
834 .Case("->", "Arrow()")
835 .Case(":", "Colon()")
836 .Case(",", "Comma()")
837 .Case("=", "Equal()")
839 .Case(">", "Greater()")
840 .Case("{", "LBrace()")
841 .Case("}", "RBrace()")
842 .Case("(", "LParen()")
843 .Case(")", "RParen()")
844 .Case("[", "LSquare()")
845 .Case("]", "RSquare()")
846 .Case("?", "Question()")
849 .Case("...", "Ellipsis()");
852 /// Generate the storage code required for parsing the given element.
/// Declarations are appended to `body`. Composite elements (optional groups,
/// oilists, custom directives) recurse into their children.
853 static void genElementParserStorage(FormatElement
*element
, const Operator
&op
,
855 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
856 ArrayRef
<FormatElement
*> elements
= optional
->getThenElements();
858 // If the anchor is a unit attribute, it won't be parsed directly so elide
// its storage declaration as well.
860 auto *anchor
= dyn_cast
<AttributeLikeVariable
>(optional
->getAnchor());
861 FormatElement
*elidedAnchorElement
= nullptr;
862 if (anchor
&& anchor
!= elements
.front() && anchor
->isUnit())
863 elidedAnchorElement
= anchor
;
864 for (FormatElement
*childElement
: elements
)
865 if (childElement
!= elidedAnchorElement
)
866 genElementParserStorage(childElement
, op
, body
);
867 for (FormatElement
*childElement
: optional
->getElseElements())
868 genElementParserStorage(childElement
, op
, body
);
870 } else if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
// Clauses consisting of a single unit attribute/property variable get no
// storage here. The loop variable is named `childElement` so it does not
// shadow the `element` parameter.
871 for (ArrayRef
<FormatElement
*> pelement
: oilist
->getParsingElements()) {
872 if (!oilist
->getUnitVariableParsingElement(pelement
))
873 for (FormatElement
*childElement
: pelement
)
874 genElementParserStorage(childElement
, op
, body
);
877 } else if (auto *custom
= dyn_cast
<CustomDirective
>(element
)) {
878 for (FormatElement
*paramElement
: custom
->getArguments())
879 genElementParserStorage(paramElement
, op
, body
);
881 } else if (isa
<OperandsDirective
>(element
)) {
882 body
<< " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
885 } else if (isa
<RegionsDirective
>(element
)) {
886 body
<< " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
889 } else if (isa
<SuccessorsDirective
>(element
)) {
890 body
<< " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
892 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
893 const NamedAttribute
*var
= attr
->getVar();
894 body
<< formatv(" {0} {1}Attr;\n", var
->attr
.getStorageType(), var
->name
);
896 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
897 StringRef name
= operand
->getVar()->name
;
898 if (operand
->getVar()->isVariableLength()) {
900 << " ::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> "
901 << name
<< "Operands;\n";
902 if (operand
->getVar()->isVariadicOfVariadic()) {
903 body
<< " llvm::SmallVector<int32_t> " << name
904 << "OperandGroupSizes;\n";
// NOTE(review): the statement below ends its last emitted declaration with
// ";" but no "\n", so the following SMLoc declaration is emitted on the same
// line of generated code.
907 body
<< " ::mlir::OpAsmParser::UnresolvedOperand " << name
909 << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
910 << name
<< "Operands(&" << name
<< "RawOperand, 1);";
// Every operand also gets a source location local. The (void) cast keeps the
// generated local from triggering unused-variable warnings.
912 body
<< formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
913 " (void){0}OperandsLoc;\n",
916 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
917 StringRef name
= region
->getVar()->name
;
918 if (region
->getVar()->isVariadic()) {
920 " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
924 body
<< formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
925 "std::make_unique<::mlir::Region>();\n",
929 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
930 StringRef name
= successor
->getVar()->name
;
931 if (successor
->getVar()->isVariadic()) {
932 body
<< formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
936 body
<< formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name
);
939 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
940 ArgumentLengthKind lengthKind
;
941 StringRef name
= getTypeListName(dir
->getArg(), lengthKind
);
// Variable-length type lists get an owned vector. A single type gets a raw
// local plus a one-element ArrayRef view over it.
942 if (lengthKind
!= ArgumentLengthKind::Single
)
943 body
<< " ::llvm::SmallVector<::mlir::Type, 1> " << name
<< "Types;\n";
946 << formatv(" ::mlir::Type {0}RawType{{};\n", name
)
948 " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
950 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
// Only the list names are needed here, so the computed length kinds are
// discarded.
951 ArgumentLengthKind ignored
;
952 body
<< " ::llvm::ArrayRef<::mlir::Type> "
953 << getTypeListName(dir
->getInputs(), ignored
) << "Types;\n";
954 body
<< " ::llvm::ArrayRef<::mlir::Type> "
955 << getTypeListName(dir
->getResults(), ignored
) << "Types;\n";
959 /// Generate the parser for a parameter to a custom directive.
960 static void genCustomParameterParser(FormatElement
*param
, MethodBody
&body
) {
961 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
962 body
<< attr
->getVar()->name
<< "Attr";
963 } else if (isa
<AttrDictDirective
>(param
)) {
964 body
<< "result.attributes";
965 } else if (isa
<PropDictDirective
>(param
)) {
967 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
968 StringRef name
= operand
->getVar()->name
;
969 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
970 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
971 body
<< formatv("{0}OperandGroups", name
);
972 else if (lengthKind
== ArgumentLengthKind::Variadic
)
973 body
<< formatv("{0}Operands", name
);
974 else if (lengthKind
== ArgumentLengthKind::Optional
)
975 body
<< formatv("{0}Operand", name
);
977 body
<< formatv("{0}RawOperand", name
);
979 } else if (auto *region
= dyn_cast
<RegionVariable
>(param
)) {
980 StringRef name
= region
->getVar()->name
;
981 if (region
->getVar()->isVariadic())
982 body
<< formatv("{0}Regions", name
);
984 body
<< formatv("*{0}Region", name
);
986 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(param
)) {
987 StringRef name
= successor
->getVar()->name
;
988 if (successor
->getVar()->isVariadic())
989 body
<< formatv("{0}Successors", name
);
991 body
<< formatv("{0}Successor", name
);
993 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
994 genCustomParameterParser(dir
->getArg(), body
);
996 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
997 ArgumentLengthKind lengthKind
;
998 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
999 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
1000 body
<< formatv("{0}TypeGroups", listName
);
1001 else if (lengthKind
== ArgumentLengthKind::Variadic
)
1002 body
<< formatv("{0}Types", listName
);
1003 else if (lengthKind
== ArgumentLengthKind::Optional
)
1004 body
<< formatv("{0}Type", listName
);
1006 body
<< formatv("{0}RawType", listName
);
1008 } else if (auto *string
= dyn_cast
<StringElement
>(param
)) {
1010 ctx
.withBuilder("parser.getBuilder()");
1011 ctx
.addSubst("_ctxt", "parser.getContext()");
1012 body
<< tgfmt(string
->getValue(), &ctx
);
1014 } else if (auto *property
= dyn_cast
<PropertyVariable
>(param
)) {
1015 body
<< formatv("result.getOrAddProperties<Properties>().{0}",
1016 property
->getVar()->name
);
1018 llvm_unreachable("unknown custom directive parameter");
1022 /// Generate the parser for a custom directive.
1023 static void genCustomDirectiveParser(CustomDirective
*dir
, MethodBody
&body
,
1025 StringRef opCppClassName
,
1026 bool isOptional
= false) {
1029 // Preprocess the directive variables.
1030 // * Add a local variable for optional operands and types. This provides a
1031 // better API to the user defined parser methods.
1032 // * Set the location of operand variables.
1033 for (FormatElement
*param
: dir
->getArguments()) {
1034 if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
1035 auto *var
= operand
->getVar();
1036 body
<< " " << var
->name
1037 << "OperandsLoc = parser.getCurrentLocation();\n";
1038 if (var
->isOptional()) {
1040 " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
1043 } else if (var
->isVariadicOfVariadic()) {
1045 "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
1046 "OpAsmParser::UnresolvedOperand>> "
1047 "{0}OperandGroups;\n",
1050 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
1051 ArgumentLengthKind lengthKind
;
1052 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1053 if (lengthKind
== ArgumentLengthKind::Optional
) {
1054 body
<< formatv(" ::mlir::Type {0}Type;\n", listName
);
1055 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1057 " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
1061 } else if (auto *dir
= dyn_cast
<RefDirective
>(param
)) {
1062 FormatElement
*input
= dir
->getArg();
1063 if (auto *operand
= dyn_cast
<OperandVariable
>(input
)) {
1064 if (!operand
->getVar()->isOptional())
1067 " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
1068 "{1}Operands[0];\n",
1069 "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
1070 operand
->getVar()->name
);
1072 } else if (auto *type
= dyn_cast
<TypeDirective
>(input
)) {
1073 ArgumentLengthKind lengthKind
;
1074 StringRef listName
= getTypeListName(type
->getArg(), lengthKind
);
1075 if (lengthKind
== ArgumentLengthKind::Optional
) {
1076 body
<< formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
1077 "::mlir::Type() : {0}Types[0];\n",
1084 body
<< " auto odsResult = parse" << dir
->getName() << "(parser";
1085 for (FormatElement
*param
: dir
->getArguments()) {
1087 genCustomParameterParser(param
, body
);
1092 body
<< " if (!odsResult.has_value()) return {};\n"
1093 << " if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
1095 body
<< " if (odsResult) return ::mlir::failure();\n";
1098 // After parsing, add handling for any of the optional constructs.
1099 for (FormatElement
*param
: dir
->getArguments()) {
1100 if (auto *attr
= dyn_cast
<AttributeVariable
>(param
)) {
1101 const NamedAttribute
*var
= attr
->getVar();
1102 if (var
->attr
.isOptional() || var
->attr
.hasDefaultValue())
1103 body
<< formatv(" if ({0}Attr)\n ", var
->name
);
1104 if (useProperties
) {
1106 " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
1107 var
->name
, opCppClassName
);
1109 body
<< formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
1112 } else if (auto *operand
= dyn_cast
<OperandVariable
>(param
)) {
1113 const NamedTypeConstraint
*var
= operand
->getVar();
1114 if (var
->isOptional()) {
1115 body
<< formatv(" if ({0}Operand.has_value())\n"
1116 " {0}Operands.push_back(*{0}Operand);\n",
1118 } else if (var
->isVariadicOfVariadic()) {
1120 " for (const auto &subRange : {0}OperandGroups) {{\n"
1121 " {0}Operands.append(subRange.begin(), subRange.end());\n"
1122 " {0}OperandGroupSizes.push_back(subRange.size());\n"
1126 } else if (auto *dir
= dyn_cast
<TypeDirective
>(param
)) {
1127 ArgumentLengthKind lengthKind
;
1128 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1129 if (lengthKind
== ArgumentLengthKind::Optional
) {
1130 body
<< formatv(" if ({0}Type)\n"
1131 " {0}Types.push_back({0}Type);\n",
1133 } else if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1135 " for (const auto &subRange : {0}TypeGroups)\n"
1136 " {0}Types.append(subRange.begin(), subRange.end());\n",
1145 /// Generate the parser for a enum attribute.
1146 static void genEnumAttrParser(const NamedAttribute
*var
, MethodBody
&body
,
1147 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1148 bool useProperties
, StringRef opCppClassName
) {
1149 Attribute baseAttr
= var
->attr
.getBaseAttr();
1150 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
1151 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
1153 // Generate the code for building an attribute for this enum.
1154 std::string attrBuilderStr
;
1156 llvm::raw_string_ostream
os(attrBuilderStr
);
1157 os
<< tgfmt(enumAttr
.getConstBuilderTemplate(), &attrTypeCtx
,
1161 // Build a string containing the cases that can be formatted as a keyword.
1162 std::string validCaseKeywordsStr
= "{";
1163 llvm::raw_string_ostream
validCaseKeywordsOS(validCaseKeywordsStr
);
1164 for (const EnumAttrCase
&attrCase
: cases
)
1165 if (canFormatStringAsKeyword(attrCase
.getStr()))
1166 validCaseKeywordsOS
<< '"' << attrCase
.getStr() << "\",";
1167 validCaseKeywordsOS
.str().back() = '}';
1169 // If the attribute is not optional, build an error message for the missing
1171 std::string errorMessage
;
1172 if (!parseAsOptional
) {
1173 llvm::raw_string_ostream
errorMessageOS(errorMessage
);
1175 << "return parser.emitError(loc, \"expected string or "
1176 "keyword containing one of the following enum values for attribute '"
1177 << var
->name
<< "' [";
1178 llvm::interleaveComma(cases
, errorMessageOS
, [&](const auto &attrCase
) {
1179 errorMessageOS
<< attrCase
.getStr();
1181 errorMessageOS
<< "]\");";
1183 std::string attrAssignment
;
1184 if (useProperties
) {
1187 "result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;",
1188 var
->name
, opCppClassName
);
1191 formatv("result.addAttribute(\"{0}\", {0}Attr);", var
->name
);
1194 body
<< formatv(enumAttrParserCode
, var
->name
, enumAttr
.getCppNamespace(),
1195 enumAttr
.getStringToSymbolFnName(), attrBuilderStr
,
1196 validCaseKeywordsStr
, errorMessage
, attrAssignment
);
1199 // Generate the parser for a property.
1200 static void genPropertyParser(PropertyVariable
*propVar
, MethodBody
&body
,
1201 StringRef opCppClassName
,
1202 bool requireParse
= true) {
1203 StringRef name
= propVar
->getVar()->name
;
1204 const Property
&prop
= propVar
->getVar()->prop
;
1205 bool parseOptionally
=
1206 prop
.hasDefaultValue() && !requireParse
&& prop
.hasOptionalParser();
1207 FmtContext fmtContext
;
1208 fmtContext
.addSubst("_parser", "parser");
1209 fmtContext
.addSubst("_ctxt", "parser.getContext()");
1210 fmtContext
.addSubst("_storage", "propStorage");
1212 if (parseOptionally
) {
1213 body
<< formatv(optionalPropertyParserCode
, name
, opCppClassName
,
1214 tgfmt(prop
.getOptionalParserCall(), &fmtContext
));
1216 body
<< formatv(propertyParserCode
, name
, opCppClassName
,
1217 tgfmt(prop
.getParserCall(), &fmtContext
),
1222 // Generate the parser for an attribute.
1223 static void genAttrParser(AttributeVariable
*attr
, MethodBody
&body
,
1224 FmtContext
&attrTypeCtx
, bool parseAsOptional
,
1225 bool useProperties
, StringRef opCppClassName
) {
1226 const NamedAttribute
*var
= attr
->getVar();
1228 // Check to see if we can parse this as an enum attribute.
1229 if (canFormatEnumAttr(var
))
1230 return genEnumAttrParser(var
, body
, attrTypeCtx
, parseAsOptional
,
1231 useProperties
, opCppClassName
);
1233 // Check to see if we should parse this as a symbol name attribute.
1234 if (shouldFormatSymbolNameAttr(var
)) {
1235 body
<< formatv(parseAsOptional
? optionalSymbolNameAttrParserCode
1236 : symbolNameAttrParserCode
,
1240 // If this attribute has a buildable type, use that when parsing the
1242 std::string attrTypeStr
;
1243 if (std::optional
<StringRef
> typeBuilder
= attr
->getTypeBuilder()) {
1244 llvm::raw_string_ostream
os(attrTypeStr
);
1245 os
<< tgfmt(*typeBuilder
, &attrTypeCtx
);
1247 attrTypeStr
= "::mlir::Type{}";
1249 if (parseAsOptional
) {
1250 body
<< formatv(optionalAttrParserCode
, var
->name
, attrTypeStr
);
1252 if (attr
->shouldBeQualified() ||
1253 var
->attr
.getStorageType() == "::mlir::Attribute")
1254 body
<< formatv(genericAttrParserCode
, var
->name
, attrTypeStr
);
1256 body
<< formatv(attrParserCode
, var
->name
, attrTypeStr
);
1259 if (useProperties
) {
1261 " if ({0}Attr) result.getOrAddProperties<{1}::Properties>().{0} = "
1263 var
->name
, opCppClassName
);
1266 " if ({0}Attr) result.attributes.append(\"{0}\", {0}Attr);\n",
1271 // Generates the 'setPropertiesFromParsedAttr' used to set properties from a
1272 // 'prop-dict' dictionary attr.
1273 static void genParsedAttrPropertiesSetter(OperationFormat
&fmt
, Operator
&op
,
1275 // Not required unless 'prop-dict' is present or we are not using properties.
1276 if (!fmt
.hasPropDict
|| !fmt
.useProperties
)
1279 SmallVector
<MethodParameter
> paramList
;
1280 paramList
.emplace_back("Properties &", "prop");
1281 paramList
.emplace_back("::mlir::Attribute", "attr");
1282 paramList
.emplace_back("::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1285 Method
*method
= opClass
.addStaticMethod("::llvm::LogicalResult",
1286 "setPropertiesFromParsedAttr",
1287 std::move(paramList
));
1288 MethodBody
&body
= method
->body().indent();
1291 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1293 emitError() << "expected DictionaryAttr to set properties
";
1294 return ::mlir::failure();
1298 // {0}: fromAttribute call
1299 // {1}: property name
1301 const char *propFromAttrFmt
= R
"decl(
1302 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1303 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{
1306 auto attr = dict.get("{1}");
1307 if (!attr && {2}) {{
1308 emitError() << "expected key entry
for {1} in DictionaryAttr to set
"
1310 return ::mlir::failure();
1312 if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
1313 return ::mlir::failure();
1316 // Generate the setter for any property not parsed elsewhere.
1317 for (const NamedProperty
&namedProperty
: op
.getProperties()) {
1318 if (fmt
.usedProperties
.contains(&namedProperty
))
1321 auto scope
= body
.scope("{\n", "}\n", /*indent=*/true);
1323 StringRef name
= namedProperty
.name
;
1324 const Property
&prop
= namedProperty
.prop
;
1325 bool isRequired
= !prop
.hasDefaultValue();
1327 body
<< formatv(propFromAttrFmt
,
1328 tgfmt(prop
.getConvertFromAttributeCall(),
1329 &fctx
.addSubst("_attr", "propAttr")
1330 .addSubst("_storage", "propStorage")
1331 .addSubst("_diag", "emitError")),
1335 // Generate the setter for any attribute not parsed elsewhere.
1336 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1337 if (fmt
.usedAttributes
.contains(&namedAttr
))
1340 const Attribute
&attr
= namedAttr
.attr
;
1341 // Derived attributes do not need to be parsed.
1342 if (attr
.isDerivedAttr())
1345 auto scope
= body
.scope("{\n", "}\n", /*indent=*/true);
1347 // If the attribute has a default value or is optional, it does not need to
1348 // be present in the parsed dictionary attribute.
1349 bool isRequired
= !attr
.isOptional() && !attr
.hasDefaultValue();
1350 body
<< formatv(R
"decl(
1351 auto &propStorage = prop.{0};
1352 auto attr = dict.get("{0}");
1353 if (attr || /*isRequired=*/{1}) {{
1355 emitError() << "expected key entry
for {0} in DictionaryAttr to set
"
1357 return ::mlir::failure();
1359 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1360 if (convertedAttr) {{
1361 propStorage = convertedAttr;
1363 emitError() << "Invalid attribute `
{0}` in property conversion
: " << attr;
1364 return ::mlir::failure();
1368 namedAttr
.name
, isRequired
);
1370 body
<< "return ::mlir::success();\n";
1373 void OperationFormat::genParser(Operator
&op
, OpClass
&opClass
) {
1374 SmallVector
<MethodParameter
> paramList
;
1375 paramList
.emplace_back("::mlir::OpAsmParser &", "parser");
1376 paramList
.emplace_back("::mlir::OperationState &", "result");
1378 auto *method
= opClass
.addStaticMethod("::mlir::ParseResult", "parse",
1379 std::move(paramList
));
1380 auto &body
= method
->body();
1382 // Generate variables to store the operands and type within the format. This
1383 // allows for referencing these variables in the presence of optional
1385 for (FormatElement
*element
: elements
)
1386 genElementParserStorage(element
, op
, body
);
1388 // A format context used when parsing attributes with buildable types.
1389 FmtContext attrTypeCtx
;
1390 attrTypeCtx
.withBuilder("parser.getBuilder()");
1392 // Generate parsers for each of the elements.
1393 for (FormatElement
*element
: elements
)
1394 genElementParser(element
, body
, attrTypeCtx
);
1396 // Generate the code to resolve the operand/result types and successors now
1397 // that they have been parsed.
1398 genParserRegionResolution(op
, body
);
1399 genParserSuccessorResolution(op
, body
);
1400 genParserVariadicSegmentResolution(op
, body
);
1401 genParserTypeResolution(op
, body
);
1403 body
<< " return ::mlir::success();\n";
1405 genParsedAttrPropertiesSetter(*this, op
, opClass
);
1408 void OperationFormat::genElementParser(FormatElement
*element
, MethodBody
&body
,
1409 FmtContext
&attrTypeCtx
,
1410 GenContext genCtx
) {
1412 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
1413 auto genElementParsers
= [&](FormatElement
*firstElement
,
1414 ArrayRef
<FormatElement
*> elements
,
1416 // If the anchor is a unit attribute, we don't need to print it. When
1417 // parsing, we will add this attribute if this group is present.
1418 FormatElement
*elidedAnchorElement
= nullptr;
1419 auto *anchorVar
= dyn_cast
<AttributeLikeVariable
>(optional
->getAnchor());
1420 if (anchorVar
&& anchorVar
!= firstElement
&& anchorVar
->isUnit()) {
1421 elidedAnchorElement
= anchorVar
;
1423 if (!thenGroup
== optional
->isInverted()) {
1424 // Add the anchor unit attribute or property to the operation state
1425 // or set the property to true.
1426 if (isa
<PropertyVariable
>(anchorVar
)) {
1428 " result.getOrAddProperties<{1}::Properties>().{0} = true;",
1429 anchorVar
->getName(), opCppClassName
);
1430 } else if (useProperties
) {
1432 " result.getOrAddProperties<{1}::Properties>().{0} = "
1433 "parser.getBuilder().getUnitAttr();",
1434 anchorVar
->getName(), opCppClassName
);
1436 body
<< " result.addAttribute(\"" << anchorVar
->getName()
1437 << "\", parser.getBuilder().getUnitAttr());\n";
1442 // Generate the rest of the elements inside an optional group. Elements in
1443 // an optional group after the guard are parsed as required.
1444 for (FormatElement
*childElement
: elements
)
1445 if (childElement
!= elidedAnchorElement
)
1446 genElementParser(childElement
, body
, attrTypeCtx
,
1447 GenContext::Optional
);
1450 ArrayRef
<FormatElement
*> thenElements
=
1451 optional
->getThenElements(/*parseable=*/true);
1453 // Generate a special optional parser for the first element to gate the
1454 // parsing of the rest of the elements.
1455 FormatElement
*firstElement
= thenElements
.front();
1456 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(firstElement
)) {
1457 genAttrParser(attrVar
, body
, attrTypeCtx
, /*parseAsOptional=*/true,
1458 useProperties
, opCppClassName
);
1459 body
<< " if (" << attrVar
->getVar()->name
<< "Attr) {\n";
1460 } else if (auto *propVar
= dyn_cast
<PropertyVariable
>(firstElement
)) {
1461 genPropertyParser(propVar
, body
, opCppClassName
, /*requireParse=*/false);
1462 body
<< formatv("if ({0}PropParseResult.has_value() && "
1463 "succeeded(*{0}PropParseResult)) ",
1464 propVar
->getVar()->name
)
1466 } else if (auto *literal
= dyn_cast
<LiteralElement
>(firstElement
)) {
1467 body
<< " if (::mlir::succeeded(parser.parseOptional";
1468 genLiteralParser(literal
->getSpelling(), body
);
1470 } else if (auto *opVar
= dyn_cast
<OperandVariable
>(firstElement
)) {
1471 genElementParser(opVar
, body
, attrTypeCtx
);
1472 body
<< " if (!" << opVar
->getVar()->name
<< "Operands.empty()) {\n";
1473 } else if (auto *regionVar
= dyn_cast
<RegionVariable
>(firstElement
)) {
1474 const NamedRegion
*region
= regionVar
->getVar();
1475 if (region
->isVariadic()) {
1476 genElementParser(regionVar
, body
, attrTypeCtx
);
1477 body
<< " if (!" << region
->name
<< "Regions.empty()) {\n";
1479 body
<< formatv(optionalRegionParserCode
, region
->name
);
1480 body
<< " if (!" << region
->name
<< "Region->empty()) {\n ";
1481 if (hasImplicitTermTrait
)
1482 body
<< formatv(regionEnsureTerminatorParserCode
, region
->name
);
1483 else if (hasSingleBlockTrait
)
1484 body
<< formatv(regionEnsureSingleBlockParserCode
, region
->name
);
1486 } else if (auto *custom
= dyn_cast
<CustomDirective
>(firstElement
)) {
1487 body
<< " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
1488 genCustomDirectiveParser(custom
, body
, useProperties
, opCppClassName
,
1489 /*isOptional=*/true);
1490 body
<< " return ::mlir::success();\n"
1491 << " }(); optResult.has_value() && ::mlir::failed(*optResult)) {\n"
1492 << " return ::mlir::failure();\n"
1493 << " } else if (optResult.has_value()) {\n";
1496 genElementParsers(firstElement
, thenElements
.drop_front(),
1497 /*thenGroup=*/true);
1500 // Generate the else elements.
1501 auto elseElements
= optional
->getElseElements();
1502 if (!elseElements
.empty()) {
1503 body
<< " else {\n";
1504 ArrayRef
<FormatElement
*> elseElements
=
1505 optional
->getElseElements(/*parseable=*/true);
1506 genElementParsers(elseElements
.front(), elseElements
,
1507 /*thenGroup=*/false);
1512 /// OIList Directive
1513 } else if (OIListElement
*oilist
= dyn_cast
<OIListElement
>(element
)) {
1514 for (LiteralElement
*le
: oilist
->getLiteralElements())
1515 body
<< " bool " << le
->getSpelling() << "Clause = false;\n";
1517 // Generate the parsing loop
1518 body
<< " while(true) {\n";
1519 for (auto clause
: oilist
->getClauses()) {
1520 LiteralElement
*lelement
= std::get
<0>(clause
);
1521 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
1522 body
<< "if (succeeded(parser.parseOptional";
1523 genLiteralParser(lelement
->getSpelling(), body
);
1525 StringRef lelementName
= lelement
->getSpelling();
1526 body
<< formatv(oilistParserCode
, lelementName
);
1527 if (AttributeLikeVariable
*unitVarElem
=
1528 oilist
->getUnitVariableParsingElement(pelement
)) {
1529 if (isa
<PropertyVariable
>(unitVarElem
)) {
1531 " result.getOrAddProperties<{1}::Properties>().{0} = true;",
1532 unitVarElem
->getName(), opCppClassName
);
1533 } else if (useProperties
) {
1535 " result.getOrAddProperties<{1}::Properties>().{0} = "
1536 "parser.getBuilder().getUnitAttr();",
1537 unitVarElem
->getName(), opCppClassName
);
1539 body
<< " result.addAttribute(\"" << unitVarElem
->getName()
1540 << "\", UnitAttr::get(parser.getContext()));\n";
1543 for (FormatElement
*el
: pelement
)
1544 genElementParser(el
, body
, attrTypeCtx
);
1549 body
<< " break;\n";
1554 } else if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
)) {
1555 body
<< " if (parser.parse";
1556 genLiteralParser(literal
->getSpelling(), body
);
1557 body
<< ")\n return ::mlir::failure();\n";
1560 } else if (isa
<WhitespaceElement
>(element
)) {
1561 // Nothing to parse.
1564 } else if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
1565 bool parseAsOptional
=
1566 (genCtx
== GenContext::Normal
&& attr
->getVar()->attr
.isOptional());
1567 genAttrParser(attr
, body
, attrTypeCtx
, parseAsOptional
, useProperties
,
1569 } else if (auto *prop
= dyn_cast
<PropertyVariable
>(element
)) {
1570 genPropertyParser(prop
, body
, opCppClassName
);
1572 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
1573 ArgumentLengthKind lengthKind
= getArgumentLengthKind(operand
->getVar());
1574 StringRef name
= operand
->getVar()->name
;
1575 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
)
1576 body
<< formatv(variadicOfVariadicOperandParserCode
, name
);
1577 else if (lengthKind
== ArgumentLengthKind::Variadic
)
1578 body
<< formatv(variadicOperandParserCode
, name
);
1579 else if (lengthKind
== ArgumentLengthKind::Optional
)
1580 body
<< formatv(optionalOperandParserCode
, name
);
1582 body
<< formatv(operandParserCode
, name
);
1584 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
1585 bool isVariadic
= region
->getVar()->isVariadic();
1586 body
<< formatv(isVariadic
? regionListParserCode
: regionParserCode
,
1587 region
->getVar()->name
);
1588 if (hasImplicitTermTrait
)
1589 body
<< formatv(isVariadic
? regionListEnsureTerminatorParserCode
1590 : regionEnsureTerminatorParserCode
,
1591 region
->getVar()->name
);
1592 else if (hasSingleBlockTrait
)
1593 body
<< formatv(isVariadic
? regionListEnsureSingleBlockParserCode
1594 : regionEnsureSingleBlockParserCode
,
1595 region
->getVar()->name
);
1597 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
1598 bool isVariadic
= successor
->getVar()->isVariadic();
1599 body
<< formatv(isVariadic
? successorListParserCode
: successorParserCode
,
1600 successor
->getVar()->name
);
1603 } else if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
1604 body
.indent() << "{\n";
1605 body
.indent() << "auto loc = parser.getCurrentLocation();(void)loc;\n"
1606 << "if (parser.parseOptionalAttrDict"
1607 << (attrDict
->isWithKeyword() ? "WithKeyword" : "")
1608 << "(result.attributes))\n"
1609 << " return ::mlir::failure();\n";
1610 if (useProperties
) {
1611 body
<< "if (failed(verifyInherentAttrs(result.name, result.attributes, "
1613 << " return parser.emitError(loc) << \"'\" << "
1614 "result.name.getStringRef() << \"' op \";\n"
1616 << " return ::mlir::failure();\n";
1618 body
.unindent() << "}\n";
1620 } else if (isa
<PropDictDirective
>(element
)) {
1621 if (useProperties
) {
1622 body
<< " if (parseProperties(parser, result))\n"
1623 << " return ::mlir::failure();\n";
1625 } else if (auto *customDir
= dyn_cast
<CustomDirective
>(element
)) {
1626 genCustomDirectiveParser(customDir
, body
, useProperties
, opCppClassName
);
1627 } else if (isa
<OperandsDirective
>(element
)) {
1628 body
<< " [[maybe_unused]] ::llvm::SMLoc allOperandLoc ="
1629 << " parser.getCurrentLocation();\n"
1630 << " if (parser.parseOperandList(allOperands))\n"
1631 << " return ::mlir::failure();\n";
1633 } else if (isa
<RegionsDirective
>(element
)) {
1634 body
<< formatv(regionListParserCode
, "full");
1635 if (hasImplicitTermTrait
)
1636 body
<< formatv(regionListEnsureTerminatorParserCode
, "full");
1637 else if (hasSingleBlockTrait
)
1638 body
<< formatv(regionListEnsureSingleBlockParserCode
, "full");
1640 } else if (isa
<SuccessorsDirective
>(element
)) {
1641 body
<< formatv(successorListParserCode
, "full");
1643 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
1644 ArgumentLengthKind lengthKind
;
1645 StringRef listName
= getTypeListName(dir
->getArg(), lengthKind
);
1646 if (lengthKind
== ArgumentLengthKind::VariadicOfVariadic
) {
1647 body
<< formatv(variadicOfVariadicTypeParserCode
, listName
);
1648 } else if (lengthKind
== ArgumentLengthKind::Variadic
) {
1649 body
<< formatv(variadicTypeParserCode
, listName
);
1650 } else if (lengthKind
== ArgumentLengthKind::Optional
) {
1651 body
<< formatv(optionalTypeParserCode
, listName
);
1653 const char *parserCode
=
1654 dir
->shouldBeQualified() ? qualifiedTypeParserCode
: typeParserCode
;
1655 TypeSwitch
<FormatElement
*>(dir
->getArg())
1656 .Case
<OperandVariable
, ResultVariable
>([&](auto operand
) {
1657 body
<< formatv(false, parserCode
,
1658 operand
->getVar()->constraint
.getCppType(),
1661 .Default([&](auto operand
) {
1662 body
<< formatv(false, parserCode
, "::mlir::Type", listName
);
1665 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
1666 ArgumentLengthKind ignored
;
1667 body
<< formatv(functionalTypeParserCode
,
1668 getTypeListName(dir
->getInputs(), ignored
),
1669 getTypeListName(dir
->getResults(), ignored
));
1671 llvm_unreachable("unknown format element");
1675 void OperationFormat::genParserTypeResolution(Operator
&op
, MethodBody
&body
) {
1676 // If any of type resolutions use transformed variables, make sure that the
1677 // types of those variables are resolved.
1678 SmallPtrSet
<const NamedTypeConstraint
*, 8> verifiedVariables
;
1679 FmtContext verifierFCtx
;
1680 for (TypeResolution
&resolver
:
1681 llvm::concat
<TypeResolution
>(resultTypes
, operandTypes
)) {
1682 std::optional
<StringRef
> transformer
= resolver
.getVarTransformer();
1685 // Ensure that we don't verify the same variables twice.
1686 const NamedTypeConstraint
*variable
= resolver
.getVariable();
1687 if (!variable
|| !verifiedVariables
.insert(variable
).second
)
1690 auto constraint
= variable
->constraint
;
1691 body
<< " for (::mlir::Type type : " << variable
->name
<< "Types) {\n"
1694 << tgfmt(constraint
.getConditionTemplate(),
1695 &verifierFCtx
.withSelf("type"))
1697 << formatv(" return parser.emitError(parser.getNameLoc()) << "
1698 "\"'{0}' must be {1}, but got \" << type;\n",
1699 variable
->name
, constraint
.getSummary())
1704 // Initialize the set of buildable types.
1705 if (!buildableTypes
.empty()) {
1706 FmtContext typeBuilderCtx
;
1707 typeBuilderCtx
.withBuilder("parser.getBuilder()");
1708 for (auto &it
: buildableTypes
)
1709 body
<< " ::mlir::Type odsBuildableType" << it
.second
<< " = "
1710 << tgfmt(it
.first
, &typeBuilderCtx
) << ";\n";
1713 // Emit the code necessary for a type resolver.
1714 auto emitTypeResolver
= [&](TypeResolution
&resolver
, StringRef curVar
) {
1715 if (std::optional
<int> val
= resolver
.getBuilderIdx()) {
1716 body
<< "odsBuildableType" << *val
;
1717 } else if (const NamedTypeConstraint
*var
= resolver
.getVariable()) {
1718 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer()) {
1719 FmtContext fmtContext
;
1720 fmtContext
.addSubst("_ctxt", "parser.getContext()");
1721 if (var
->isVariadic())
1722 fmtContext
.withSelf(var
->name
+ "Types");
1724 fmtContext
.withSelf(var
->name
+ "Types[0]");
1725 body
<< tgfmt(*tform
, &fmtContext
);
1727 body
<< var
->name
<< "Types";
1728 if (!var
->isVariadic())
1731 } else if (const NamedAttribute
*attr
= resolver
.getAttribute()) {
1732 if (std::optional
<StringRef
> tform
= resolver
.getVarTransformer())
1733 body
<< tgfmt(*tform
,
1734 &FmtContext().withSelf(attr
->name
+ "Attr.getType()"));
1736 body
<< attr
->name
<< "Attr.getType()";
1738 body
<< curVar
<< "Types";
1742 // Resolve each of the result types.
1743 if (!infersResultTypes
) {
1744 if (allResultTypes
) {
1745 body
<< " result.addTypes(allResultTypes);\n";
1747 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
1748 body
<< " result.addTypes(";
1749 emitTypeResolver(resultTypes
[i
], op
.getResultName(i
));
1755 // Emit the operand type resolutions.
1756 genParserOperandTypeResolution(op
, body
, emitTypeResolver
);
1758 // Handle return type inference once all operands have been resolved
1759 if (infersResultTypes
)
1760 body
<< formatv(inferReturnTypesParserCode
, op
.getCppClassName());
1763 void OperationFormat::genParserOperandTypeResolution(
1764 Operator
&op
, MethodBody
&body
,
1765 function_ref
<void(TypeResolution
&, StringRef
)> emitTypeResolver
) {
1766 // Early exit if there are no operands.
1767 if (op
.getNumOperands() == 0)
1770 // Handle the case where all operand types are grouped together with
1771 // "types(operands)".
1772 if (allOperandTypes
) {
1773 // If `operands` was specified, use the full operand list directly.
1775 body
<< " if (parser.resolveOperands(allOperands, allOperandTypes, "
1776 "allOperandLoc, result.operands))\n"
1777 " return ::mlir::failure();\n";
1781 // Otherwise, use llvm::concat to merge the disjoint operand lists together.
1782 // llvm::concat does not allow the case of a single range, so guard it here.
1783 body
<< " if (parser.resolveOperands(";
1784 if (op
.getNumOperands() > 1) {
1785 body
<< "::llvm::concat<const ::mlir::OpAsmParser::UnresolvedOperand>(";
1786 llvm::interleaveComma(op
.getOperands(), body
, [&](auto &operand
) {
1787 body
<< operand
.name
<< "Operands";
1791 body
<< op
.operand_begin()->name
<< "Operands";
1793 body
<< ", allOperandTypes, parser.getNameLoc(), result.operands))\n"
1794 << " return ::mlir::failure();\n";
1798 // Handle the case where all operands are grouped together with "operands".
1800 body
<< " if (parser.resolveOperands(allOperands, ";
1802 // Group all of the operand types together to perform the resolution all at
1803 // once. Use llvm::concat to perform the merge. llvm::concat does not allow
1804 // the case of a single range, so guard it here.
1805 if (op
.getNumOperands() > 1) {
1806 body
<< "::llvm::concat<const ::mlir::Type>(";
1807 llvm::interleaveComma(
1808 llvm::seq
<int>(0, op
.getNumOperands()), body
, [&](int i
) {
1809 body
<< "::llvm::ArrayRef<::mlir::Type>(";
1810 emitTypeResolver(operandTypes
[i
], op
.getOperand(i
).name
);
1815 emitTypeResolver(operandTypes
.front(), op
.getOperand(0).name
);
1818 body
<< ", allOperandLoc, result.operands))\n return "
1819 "::mlir::failure();\n";
1823 // The final case is the one where each of the operands types are resolved
1825 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
1826 NamedTypeConstraint
&operand
= op
.getOperand(i
);
1827 body
<< " if (parser.resolveOperands(" << operand
.name
<< "Operands, ";
1829 // Resolve the type of this operand.
1830 TypeResolution
&operandType
= operandTypes
[i
];
1831 emitTypeResolver(operandType
, operand
.name
);
1833 body
<< ", " << operand
.name
1834 << "OperandsLoc, result.operands))\n return ::mlir::failure();\n";
1838 void OperationFormat::genParserRegionResolution(Operator
&op
,
1840 // Check for the case where all regions were parsed.
1841 bool hasAllRegions
= llvm::any_of(
1842 elements
, [](FormatElement
*elt
) { return isa
<RegionsDirective
>(elt
); });
1843 if (hasAllRegions
) {
1844 body
<< " result.addRegions(fullRegions);\n";
1848 // Otherwise, handle each region individually.
1849 for (const NamedRegion
®ion
: op
.getRegions()) {
1850 if (region
.isVariadic())
1851 body
<< " result.addRegions(" << region
.name
<< "Regions);\n";
1853 body
<< " result.addRegion(std::move(" << region
.name
<< "Region));\n";
1857 void OperationFormat::genParserSuccessorResolution(Operator
&op
,
1859 // Check for the case where all successors were parsed.
1860 bool hasAllSuccessors
= llvm::any_of(elements
, [](FormatElement
*elt
) {
1861 return isa
<SuccessorsDirective
>(elt
);
1863 if (hasAllSuccessors
) {
1864 body
<< " result.addSuccessors(fullSuccessors);\n";
1868 // Otherwise, handle each successor individually.
1869 for (const NamedSuccessor
&successor
: op
.getSuccessors()) {
1870 if (successor
.isVariadic())
1871 body
<< " result.addSuccessors(" << successor
.name
<< "Successors);\n";
1873 body
<< " result.addSuccessors(" << successor
.name
<< "Successor);\n";
1877 void OperationFormat::genParserVariadicSegmentResolution(Operator
&op
,
1880 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1881 auto interleaveFn
= [&](const NamedTypeConstraint
&operand
) {
1882 // If the operand is variadic emit the parsed size.
1883 if (operand
.isVariableLength())
1884 body
<< "static_cast<int32_t>(" << operand
.name
<< "Operands.size())";
1888 if (op
.getDialect().usePropertiesForAttributes()) {
1889 body
<< "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1890 llvm::interleaveComma(op
.getOperands(), body
, interleaveFn
);
1891 body
<< formatv("}), "
1892 "result.getOrAddProperties<{0}::Properties>()."
1893 "operandSegmentSizes.begin());\n",
1894 op
.getCppClassName());
1896 body
<< " result.addAttribute(\"operandSegmentSizes\", "
1897 << "parser.getBuilder().getDenseI32ArrayAttr({";
1898 llvm::interleaveComma(op
.getOperands(), body
, interleaveFn
);
1902 for (const NamedTypeConstraint
&operand
: op
.getOperands()) {
1903 if (!operand
.isVariadicOfVariadic())
1905 if (op
.getDialect().usePropertiesForAttributes()) {
1907 " result.getOrAddProperties<{0}::Properties>().{1} = "
1908 "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
1909 op
.getCppClassName(),
1910 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1914 " result.addAttribute(\"{0}\", "
1915 "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
1917 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr(),
1923 if (!allResultTypes
&&
1924 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
1925 auto interleaveFn
= [&](const NamedTypeConstraint
&result
) {
1926 // If the result is variadic emit the parsed size.
1927 if (result
.isVariableLength())
1928 body
<< "static_cast<int32_t>(" << result
.name
<< "Types.size())";
1932 if (op
.getDialect().usePropertiesForAttributes()) {
1933 body
<< "::llvm::copy(::llvm::ArrayRef<int32_t>({";
1934 llvm::interleaveComma(op
.getResults(), body
, interleaveFn
);
1935 body
<< formatv("}), "
1936 "result.getOrAddProperties<{0}::Properties>()."
1937 "resultSegmentSizes.begin());\n",
1938 op
.getCppClassName());
1940 body
<< " result.addAttribute(\"resultSegmentSizes\", "
1941 << "parser.getBuilder().getDenseI32ArrayAttr({";
1942 llvm::interleaveComma(op
.getResults(), body
, interleaveFn
);
1948 //===----------------------------------------------------------------------===//
1951 /// The code snippet used to generate a printer call for a region of an
1952 // operation that has the SingleBlockImplicitTerminator trait.
1954 /// {0}: The name of the region.
1955 const char *regionSingleBlockImplicitTerminatorPrinterCode
= R
"(
1957 bool printTerminator = true;
1958 if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
1959 printTerminator = !term->getAttrDictionary().empty() ||
1960 term->getNumOperands() != 0 ||
1961 term->getNumResults() != 0;
1963 _odsPrinter.printRegion({0}, /*printEntryBlockArgs=*/true,
1964 /*printBlockTerminators=*/printTerminator);
1968 /// The code snippet used to generate a printer call for an enum that has cases
1969 /// that can't be represented with a keyword.
1971 /// {0}: The name of the enum attribute.
1972 /// {1}: The name of the enum attributes symbolToString function.
1973 const char *enumAttrBeginPrinterCode
= R
"(
1975 auto caseValue = {0}();
1976 auto caseValueStr = {1}(caseValue);
1979 /// Generate a check that an optional or default-valued attribute or property
1980 /// has a non-default value. For these purposes, the default value of an
1981 /// optional attribute is its presence, even if the attribute itself has a
1983 static void genNonDefaultValueCheck(MethodBody
&body
, const Operator
&op
,
1984 AttributeVariable
&attrElement
) {
1985 Attribute attr
= attrElement
.getVar()->attr
;
1986 std::string getter
= op
.getGetterName(attrElement
.getVar()->name
);
1987 bool optionalAndDefault
= attr
.isOptional() && attr
.hasDefaultValue();
1988 if (optionalAndDefault
)
1990 if (attr
.isOptional())
1991 body
<< getter
<< "Attr()";
1992 if (optionalAndDefault
)
1994 if (attr
.hasDefaultValue()) {
1996 fctx
.withBuilder("::mlir::OpBuilder((*this)->getContext())");
1997 body
<< getter
<< "Attr() != "
1998 << tgfmt(attr
.getConstBuilderTemplate(), &fctx
,
1999 attr
.getDefaultValue());
2001 if (optionalAndDefault
)
2005 static void genNonDefaultValueCheck(MethodBody
&body
, const Operator
&op
,
2006 PropertyVariable
&propElement
) {
2007 body
<< op
.getGetterName(propElement
.getVar()->name
)
2008 << "() != " << propElement
.getVar()->prop
.getDefaultValue();
2011 /// Generate the printer for the 'prop-dict' directive.
2012 static void genPropDictPrinter(OperationFormat
&fmt
, Operator
&op
,
2014 body
<< " ::llvm::SmallVector<::llvm::StringRef, 2> elidedProps;\n";
2015 for (const NamedProperty
*namedProperty
: fmt
.usedProperties
)
2016 body
<< " elidedProps.push_back(\"" << namedProperty
->name
<< "\");\n";
2017 for (const NamedAttribute
*namedAttr
: fmt
.usedAttributes
)
2018 body
<< " elidedProps.push_back(\"" << namedAttr
->name
<< "\");\n";
2020 // Add code to check attributes for equality with the default value
2021 // for attributes with the elidePrintingDefaultValue bit set.
2022 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
2023 const Attribute
&attr
= namedAttr
.attr
;
2024 if (!attr
.isDerivedAttr() && attr
.hasDefaultValue()) {
2025 const StringRef
&name
= namedAttr
.name
;
2027 fctx
.withBuilder("odsBuilder");
2028 std::string defaultValue
= std::string(
2029 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
2031 body
<< " ::mlir::Builder odsBuilder(getContext());\n";
2032 body
<< " ::mlir::Attribute attr = " << op
.getGetterName(name
)
2034 body
<< " if(attr && (attr == " << defaultValue
<< "))\n";
2035 body
<< " elidedProps.push_back(\"" << name
<< "\");\n";
2039 // Similarly, elide default-valued properties.
2040 for (const NamedProperty
&prop
: op
.getProperties()) {
2041 if (prop
.prop
.hasDefaultValue()) {
2042 body
<< " if (" << op
.getGetterName(prop
.name
)
2043 << "() == " << prop
.prop
.getDefaultValue() << ") {";
2044 body
<< " elidedProps.push_back(\"" << prop
.name
<< "\");\n";
2049 if (fmt
.useProperties
) {
2050 body
<< " _odsPrinter << \" \";\n"
2051 << " printProperties(this->getContext(), _odsPrinter, "
2052 "getProperties(), elidedProps);\n";
2056 /// Generate the printer for the 'attr-dict' directive.
2057 static void genAttrDictPrinter(OperationFormat
&fmt
, Operator
&op
,
2058 MethodBody
&body
, bool withKeyword
) {
2059 body
<< " ::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;\n";
2060 // Elide the variadic segment size attributes if necessary.
2061 if (!fmt
.allOperands
&&
2062 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
2063 body
<< " elidedAttrs.push_back(\"operandSegmentSizes\");\n";
2064 if (!fmt
.allResultTypes
&&
2065 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
2066 body
<< " elidedAttrs.push_back(\"resultSegmentSizes\");\n";
2067 for (const StringRef key
: fmt
.inferredAttributes
.keys())
2068 body
<< " elidedAttrs.push_back(\"" << key
<< "\");\n";
2069 for (const NamedAttribute
*attr
: fmt
.usedAttributes
)
2070 body
<< " elidedAttrs.push_back(\"" << attr
->name
<< "\");\n";
2071 // Add code to check attributes for equality with the default value
2072 // for attributes with the elidePrintingDefaultValue bit set.
2073 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
2074 const Attribute
&attr
= namedAttr
.attr
;
2075 if (!attr
.isDerivedAttr() && attr
.hasDefaultValue()) {
2076 const StringRef
&name
= namedAttr
.name
;
2078 fctx
.withBuilder("odsBuilder");
2079 std::string defaultValue
= std::string(
2080 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
2082 body
<< " ::mlir::Builder odsBuilder(getContext());\n";
2083 body
<< " ::mlir::Attribute attr = " << op
.getGetterName(name
)
2085 body
<< " if(attr && (attr == " << defaultValue
<< "))\n";
2086 body
<< " elidedAttrs.push_back(\"" << name
<< "\");\n";
2090 if (fmt
.hasPropDict
)
2091 body
<< " _odsPrinter.printOptionalAttrDict"
2092 << (withKeyword
? "WithKeyword" : "")
2093 << "(llvm::to_vector((*this)->getDiscardableAttrs()), elidedAttrs);\n";
2095 body
<< " _odsPrinter.printOptionalAttrDict"
2096 << (withKeyword
? "WithKeyword" : "")
2097 << "((*this)->getAttrs(), elidedAttrs);\n";
2100 /// Generate the printer for a literal value. `shouldEmitSpace` is true if a
2101 /// space should be emitted before this element. `lastWasPunctuation` is true if
2102 /// the previous element was a punctuation literal.
2103 static void genLiteralPrinter(StringRef value
, MethodBody
&body
,
2104 bool &shouldEmitSpace
, bool &lastWasPunctuation
) {
2105 body
<< " _odsPrinter";
2107 // Don't insert a space for certain punctuation.
2108 if (shouldEmitSpace
&& shouldEmitSpaceBefore(value
, lastWasPunctuation
))
2110 body
<< " << \"" << value
<< "\";\n";
2112 // Insert a space after certain literals.
2114 value
.size() != 1 || !StringRef("<({[").contains(value
.front());
2115 lastWasPunctuation
= value
.front() != '_' && !isalpha(value
.front());
2118 /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation`
2119 /// are set to false.
2120 static void genSpacePrinter(bool value
, MethodBody
&body
, bool &shouldEmitSpace
,
2121 bool &lastWasPunctuation
) {
2123 body
<< " _odsPrinter << ' ';\n";
2124 lastWasPunctuation
= false;
2126 lastWasPunctuation
= true;
2128 shouldEmitSpace
= false;
2131 /// Generate the printer for a custom directive parameter.
2132 static void genCustomDirectiveParameterPrinter(FormatElement
*element
,
2135 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
2136 body
<< op
.getGetterName(attr
->getVar()->name
) << "Attr()";
2138 } else if (isa
<AttrDictDirective
>(element
)) {
2139 body
<< "getOperation()->getAttrDictionary()";
2141 } else if (isa
<PropDictDirective
>(element
)) {
2142 body
<< "getProperties()";
2144 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
2145 body
<< op
.getGetterName(operand
->getVar()->name
) << "()";
2147 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
2148 body
<< op
.getGetterName(region
->getVar()->name
) << "()";
2150 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
2151 body
<< op
.getGetterName(successor
->getVar()->name
) << "()";
2153 } else if (auto *dir
= dyn_cast
<RefDirective
>(element
)) {
2154 genCustomDirectiveParameterPrinter(dir
->getArg(), op
, body
);
2156 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
2157 auto *typeOperand
= dir
->getArg();
2158 auto *operand
= dyn_cast
<OperandVariable
>(typeOperand
);
2159 auto *var
= operand
? operand
->getVar()
2160 : cast
<ResultVariable
>(typeOperand
)->getVar();
2161 std::string name
= op
.getGetterName(var
->name
);
2162 if (var
->isVariadic())
2163 body
<< name
<< "().getTypes()";
2164 else if (var
->isOptional())
2165 body
<< formatv("({0}() ? {0}().getType() : ::mlir::Type())", name
);
2167 body
<< name
<< "().getType()";
2169 } else if (auto *string
= dyn_cast
<StringElement
>(element
)) {
2171 ctx
.withBuilder("::mlir::Builder(getContext())");
2172 ctx
.addSubst("_ctxt", "getContext()");
2173 body
<< tgfmt(string
->getValue(), &ctx
);
2175 } else if (auto *property
= dyn_cast
<PropertyVariable
>(element
)) {
2177 const NamedProperty
*namedProperty
= property
->getVar();
2178 ctx
.addSubst("_storage", "getProperties()." + namedProperty
->name
);
2179 body
<< tgfmt(namedProperty
->prop
.getConvertFromStorageCall(), &ctx
);
2181 llvm_unreachable("unknown custom directive parameter");
2185 /// Generate the printer for a custom directive.
2186 static void genCustomDirectivePrinter(CustomDirective
*customDir
,
2187 const Operator
&op
, MethodBody
&body
) {
2188 body
<< " print" << customDir
->getName() << "(_odsPrinter, *this";
2189 for (FormatElement
*param
: customDir
->getArguments()) {
2191 genCustomDirectiveParameterPrinter(param
, op
, body
);
2196 /// Generate the printer for a region with the given variable name.
2197 static void genRegionPrinter(const Twine
®ionName
, MethodBody
&body
,
2198 bool hasImplicitTermTrait
) {
2199 if (hasImplicitTermTrait
)
2200 body
<< formatv(regionSingleBlockImplicitTerminatorPrinterCode
, regionName
);
2202 body
<< " _odsPrinter.printRegion(" << regionName
<< ");\n";
2204 static void genVariadicRegionPrinter(const Twine
®ionListName
,
2206 bool hasImplicitTermTrait
) {
2207 body
<< " llvm::interleaveComma(" << regionListName
2208 << ", _odsPrinter, [&](::mlir::Region ®ion) {\n ";
2209 genRegionPrinter("region", body
, hasImplicitTermTrait
);
2213 /// Generate the C++ for an operand to a (*-)type directive.
2214 static MethodBody
&genTypeOperandPrinter(FormatElement
*arg
, const Operator
&op
,
2216 bool useArrayRef
= true) {
2217 if (isa
<OperandsDirective
>(arg
))
2218 return body
<< "getOperation()->getOperandTypes()";
2219 if (isa
<ResultsDirective
>(arg
))
2220 return body
<< "getOperation()->getResultTypes()";
2221 auto *operand
= dyn_cast
<OperandVariable
>(arg
);
2222 auto *var
= operand
? operand
->getVar() : cast
<ResultVariable
>(arg
)->getVar();
2223 if (var
->isVariadicOfVariadic())
2224 return body
<< formatv("{0}().join().getTypes()",
2225 op
.getGetterName(var
->name
));
2226 if (var
->isVariadic())
2227 return body
<< op
.getGetterName(var
->name
) << "().getTypes()";
2228 if (var
->isOptional())
2229 return body
<< formatv(
2230 "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
2231 "::llvm::ArrayRef<::mlir::Type>())",
2232 op
.getGetterName(var
->name
));
2234 return body
<< "::llvm::ArrayRef<::mlir::Type>("
2235 << op
.getGetterName(var
->name
) << "().getType())";
2236 return body
<< op
.getGetterName(var
->name
) << "().getType()";
2239 /// Generate the printer for an enum attribute.
2240 static void genEnumAttrPrinter(const NamedAttribute
*var
, const Operator
&op
,
2242 Attribute baseAttr
= var
->attr
.getBaseAttr();
2243 const EnumAttr
&enumAttr
= cast
<EnumAttr
>(baseAttr
);
2244 std::vector
<EnumAttrCase
> cases
= enumAttr
.getAllCases();
2246 body
<< formatv(enumAttrBeginPrinterCode
,
2247 (var
->attr
.isOptional() ? "*" : "") +
2248 op
.getGetterName(var
->name
),
2249 enumAttr
.getSymbolToStringFnName());
2251 // Get a string containing all of the cases that can't be represented with a
2253 BitVector
nonKeywordCases(cases
.size());
2254 for (auto it
: llvm::enumerate(cases
)) {
2255 if (!canFormatStringAsKeyword(it
.value().getStr()))
2256 nonKeywordCases
.set(it
.index());
2259 // Otherwise if this is a bit enum attribute, don't allow cases that may
2260 // overlap with other cases. For simplicity sake, only allow cases with a
2261 // single bit value.
2262 if (enumAttr
.isBitEnum()) {
2263 for (auto it
: llvm::enumerate(cases
)) {
2264 int64_t value
= it
.value().getValue();
2265 if (value
< 0 || !llvm::isPowerOf2_64(value
))
2266 nonKeywordCases
.set(it
.index());
2270 // If there are any cases that can't be used with a keyword, switch on the
2271 // case value to determine when to print in the string form.
2272 if (nonKeywordCases
.any()) {
2273 body
<< " switch (caseValue) {\n";
2274 StringRef cppNamespace
= enumAttr
.getCppNamespace();
2275 StringRef enumName
= enumAttr
.getEnumClassName();
2276 for (auto it
: llvm::enumerate(cases
)) {
2277 if (nonKeywordCases
.test(it
.index()))
2279 StringRef symbol
= it
.value().getSymbol();
2280 body
<< formatv(" case {0}::{1}::{2}:\n", cppNamespace
, enumName
,
2281 llvm::isDigit(symbol
.front()) ? ("_" + symbol
) : symbol
);
2283 body
<< " _odsPrinter << caseValueStr;\n"
2286 " _odsPrinter << '\"' << caseValueStr << '\"';\n"
2293 body
<< " _odsPrinter << caseValueStr;\n"
2297 /// Generate the check for the anchor of an optional group.
2298 static void genOptionalGroupPrinterAnchor(FormatElement
*anchor
,
2301 TypeSwitch
<FormatElement
*>(anchor
)
2302 .Case
<OperandVariable
, ResultVariable
>([&](auto *element
) {
2303 const NamedTypeConstraint
*var
= element
->getVar();
2304 std::string name
= op
.getGetterName(var
->name
);
2305 if (var
->isOptional())
2306 body
<< name
<< "()";
2307 else if (var
->isVariadic())
2308 body
<< "!" << name
<< "().empty()";
2310 .Case([&](RegionVariable
*element
) {
2311 const NamedRegion
*var
= element
->getVar();
2312 std::string name
= op
.getGetterName(var
->name
);
2313 // TODO: Add a check for optional regions here when ODS supports it.
2314 body
<< "!" << name
<< "().empty()";
2316 .Case([&](TypeDirective
*element
) {
2317 genOptionalGroupPrinterAnchor(element
->getArg(), op
, body
);
2319 .Case([&](FunctionalTypeDirective
*element
) {
2320 genOptionalGroupPrinterAnchor(element
->getInputs(), op
, body
);
2322 .Case([&](AttributeVariable
*element
) {
2323 // Consider a default-valued attribute as present if it's not the
2324 // default value and an optional one present if it is set.
2325 genNonDefaultValueCheck(body
, op
, *element
);
2327 .Case([&](PropertyVariable
*element
) {
2328 genNonDefaultValueCheck(body
, op
, *element
);
2330 .Case([&](CustomDirective
*ele
) {
2333 ele
->getArguments(), body
,
2334 [&](FormatElement
*child
) {
2336 genOptionalGroupPrinterAnchor(child
, op
, body
);
2344 void collect(FormatElement
*element
,
2345 SmallVectorImpl
<VariableElement
*> &variables
) {
2346 TypeSwitch
<FormatElement
*>(element
)
2347 .Case([&](VariableElement
*var
) { variables
.emplace_back(var
); })
2348 .Case([&](CustomDirective
*ele
) {
2349 for (FormatElement
*arg
: ele
->getArguments())
2350 collect(arg
, variables
);
2352 .Case([&](OptionalElement
*ele
) {
2353 for (FormatElement
*arg
: ele
->getThenElements())
2354 collect(arg
, variables
);
2355 for (FormatElement
*arg
: ele
->getElseElements())
2356 collect(arg
, variables
);
2358 .Case([&](FunctionalTypeDirective
*funcType
) {
2359 collect(funcType
->getInputs(), variables
);
2360 collect(funcType
->getResults(), variables
);
2362 .Case([&](OIListElement
*oilist
) {
2363 for (ArrayRef
<FormatElement
*> arg
: oilist
->getParsingElements())
2364 for (FormatElement
*arg
: arg
)
2365 collect(arg
, variables
);
2369 void OperationFormat::genElementPrinter(FormatElement
*element
,
2370 MethodBody
&body
, Operator
&op
,
2371 bool &shouldEmitSpace
,
2372 bool &lastWasPunctuation
) {
2373 if (LiteralElement
*literal
= dyn_cast
<LiteralElement
>(element
))
2374 return genLiteralPrinter(literal
->getSpelling(), body
, shouldEmitSpace
,
2375 lastWasPunctuation
);
2377 // Emit a whitespace element.
2378 if (auto *space
= dyn_cast
<WhitespaceElement
>(element
)) {
2379 if (space
->getValue() == "\\n") {
2380 body
<< " _odsPrinter.printNewline();\n";
2382 genSpacePrinter(!space
->getValue().empty(), body
, shouldEmitSpace
,
2383 lastWasPunctuation
);
2388 // Emit an optional group.
2389 if (OptionalElement
*optional
= dyn_cast
<OptionalElement
>(element
)) {
2390 // Emit the check for the presence of the anchor element.
2391 FormatElement
*anchor
= optional
->getAnchor();
2393 if (optional
->isInverted())
2395 genOptionalGroupPrinterAnchor(anchor
, op
, body
);
2399 // If the anchor is a unit attribute, we don't need to print it. When
2400 // parsing, we will add this attribute if this group is present.
2401 ArrayRef
<FormatElement
*> thenElements
= optional
->getThenElements();
2402 ArrayRef
<FormatElement
*> elseElements
= optional
->getElseElements();
2403 FormatElement
*elidedAnchorElement
= nullptr;
2404 auto *anchorAttr
= dyn_cast
<AttributeLikeVariable
>(anchor
);
2405 if (anchorAttr
&& anchorAttr
!= thenElements
.front() &&
2406 (elseElements
.empty() || anchorAttr
!= elseElements
.front()) &&
2407 anchorAttr
->isUnit()) {
2408 elidedAnchorElement
= anchorAttr
;
2410 auto genElementPrinters
= [&](ArrayRef
<FormatElement
*> elements
) {
2411 for (FormatElement
*childElement
: elements
) {
2412 if (childElement
!= elidedAnchorElement
) {
2413 genElementPrinter(childElement
, body
, op
, shouldEmitSpace
,
2414 lastWasPunctuation
);
2419 // Emit each of the elements.
2420 genElementPrinters(thenElements
);
2423 // Emit each of the else elements.
2424 if (!elseElements
.empty()) {
2425 body
<< " else {\n";
2426 genElementPrinters(elseElements
);
2430 body
.unindent() << "\n";
2435 if (auto *oilist
= dyn_cast
<OIListElement
>(element
)) {
2436 for (auto clause
: oilist
->getClauses()) {
2437 LiteralElement
*lelement
= std::get
<0>(clause
);
2438 ArrayRef
<FormatElement
*> pelement
= std::get
<1>(clause
);
2440 SmallVector
<VariableElement
*> vars
;
2441 for (FormatElement
*el
: pelement
)
2443 body
<< " if (false";
2444 for (VariableElement
*var
: vars
) {
2445 TypeSwitch
<FormatElement
*>(var
)
2446 .Case([&](AttributeVariable
*attrEle
) {
2448 genNonDefaultValueCheck(body
, op
, *attrEle
);
2451 .Case([&](PropertyVariable
*propEle
) {
2453 genNonDefaultValueCheck(body
, op
, *propEle
);
2456 .Case([&](OperandVariable
*ele
) {
2457 if (ele
->getVar()->isVariadic()) {
2458 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2461 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2464 .Case([&](ResultVariable
*ele
) {
2465 if (ele
->getVar()->isVariadic()) {
2466 body
<< " || " << op
.getGetterName(ele
->getVar()->name
)
2469 body
<< " || " << op
.getGetterName(ele
->getVar()->name
) << "()";
2472 .Case([&](RegionVariable
*reg
) {
2473 body
<< " || " << op
.getGetterName(reg
->getVar()->name
) << "()";
2478 genLiteralPrinter(lelement
->getSpelling(), body
, shouldEmitSpace
,
2479 lastWasPunctuation
);
2480 if (oilist
->getUnitVariableParsingElement(pelement
) == nullptr) {
2481 for (FormatElement
*element
: pelement
)
2482 genElementPrinter(element
, body
, op
, shouldEmitSpace
,
2483 lastWasPunctuation
);
2490 // Emit the attribute dictionary.
2491 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(element
)) {
2492 genAttrDictPrinter(*this, op
, body
, attrDict
->isWithKeyword());
2493 lastWasPunctuation
= false;
2497 // Emit the property dictionary.
2498 if (isa
<PropDictDirective
>(element
)) {
2499 genPropDictPrinter(*this, op
, body
);
2500 lastWasPunctuation
= false;
2504 // Optionally insert a space before the next element. The AttrDict printer
2505 // already adds a space as necessary.
2506 if (shouldEmitSpace
|| !lastWasPunctuation
)
2507 body
<< " _odsPrinter << ' ';\n";
2508 lastWasPunctuation
= false;
2509 shouldEmitSpace
= true;
2511 if (auto *attr
= dyn_cast
<AttributeVariable
>(element
)) {
2512 const NamedAttribute
*var
= attr
->getVar();
2514 // If we are formatting as an enum, symbolize the attribute as a string.
2515 if (canFormatEnumAttr(var
))
2516 return genEnumAttrPrinter(var
, op
, body
);
2518 // If we are formatting as a symbol name, handle it as a symbol name.
2519 if (shouldFormatSymbolNameAttr(var
)) {
2520 body
<< " _odsPrinter.printSymbolName(" << op
.getGetterName(var
->name
)
2521 << "Attr().getValue());\n";
2525 // Elide the attribute type if it is buildable.
2526 if (attr
->getTypeBuilder())
2527 body
<< " _odsPrinter.printAttributeWithoutType("
2528 << op
.getGetterName(var
->name
) << "Attr());\n";
2529 else if (attr
->shouldBeQualified() ||
2530 var
->attr
.getStorageType() == "::mlir::Attribute")
2531 body
<< " _odsPrinter.printAttribute(" << op
.getGetterName(var
->name
)
2534 body
<< "_odsPrinter.printStrippedAttrOrType("
2535 << op
.getGetterName(var
->name
) << "Attr());\n";
2536 } else if (auto *property
= dyn_cast
<PropertyVariable
>(element
)) {
2537 const NamedProperty
*var
= property
->getVar();
2538 FmtContext fmtContext
;
2539 fmtContext
.addSubst("_printer", "_odsPrinter");
2540 fmtContext
.addSubst("_ctxt", "getContext()");
2541 fmtContext
.addSubst("_storage", "getProperties()." + var
->name
);
2542 body
<< tgfmt(var
->prop
.getPrinterCall(), &fmtContext
) << ";\n";
2543 } else if (auto *operand
= dyn_cast
<OperandVariable
>(element
)) {
2544 if (operand
->getVar()->isVariadicOfVariadic()) {
2545 body
<< " ::llvm::interleaveComma("
2546 << op
.getGetterName(operand
->getVar()->name
)
2547 << "(), _odsPrinter, [&](const auto &operands) { _odsPrinter << "
2548 "\"(\" << operands << "
2551 } else if (operand
->getVar()->isOptional()) {
2552 body
<< " if (::mlir::Value value = "
2553 << op
.getGetterName(operand
->getVar()->name
) << "())\n"
2554 << " _odsPrinter << value;\n";
2556 body
<< " _odsPrinter << " << op
.getGetterName(operand
->getVar()->name
)
2559 } else if (auto *region
= dyn_cast
<RegionVariable
>(element
)) {
2560 const NamedRegion
*var
= region
->getVar();
2561 std::string name
= op
.getGetterName(var
->name
);
2562 if (var
->isVariadic()) {
2563 genVariadicRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2565 genRegionPrinter(name
+ "()", body
, hasImplicitTermTrait
);
2567 } else if (auto *successor
= dyn_cast
<SuccessorVariable
>(element
)) {
2568 const NamedSuccessor
*var
= successor
->getVar();
2569 std::string name
= op
.getGetterName(var
->name
);
2570 if (var
->isVariadic())
2571 body
<< " ::llvm::interleaveComma(" << name
<< "(), _odsPrinter);\n";
2573 body
<< " _odsPrinter << " << name
<< "();\n";
2574 } else if (auto *dir
= dyn_cast
<CustomDirective
>(element
)) {
2575 genCustomDirectivePrinter(dir
, op
, body
);
2576 } else if (isa
<OperandsDirective
>(element
)) {
2577 body
<< " _odsPrinter << getOperation()->getOperands();\n";
2578 } else if (isa
<RegionsDirective
>(element
)) {
2579 genVariadicRegionPrinter("getOperation()->getRegions()", body
,
2580 hasImplicitTermTrait
);
2581 } else if (isa
<SuccessorsDirective
>(element
)) {
2582 body
<< " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
2584 } else if (auto *dir
= dyn_cast
<TypeDirective
>(element
)) {
2585 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg())) {
2586 if (operand
->getVar()->isVariadicOfVariadic()) {
2588 " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
2589 "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
2590 "types << \")\"; });\n",
2591 op
.getGetterName(operand
->getVar()->name
));
2595 const NamedTypeConstraint
*var
= nullptr;
2597 if (auto *operand
= dyn_cast
<OperandVariable
>(dir
->getArg()))
2598 var
= operand
->getVar();
2599 else if (auto *operand
= dyn_cast
<ResultVariable
>(dir
->getArg()))
2600 var
= operand
->getVar();
2602 if (var
&& !var
->isVariadicOfVariadic() && !var
->isVariadic() &&
2603 !var
->isOptional()) {
2604 StringRef cppType
= var
->constraint
.getCppType();
2605 if (dir
->shouldBeQualified()) {
2606 body
<< " _odsPrinter << " << op
.getGetterName(var
->name
)
2607 << "().getType();\n";
2611 << " auto type = " << op
.getGetterName(var
->name
)
2612 << "().getType();\n"
2613 << " if (auto validType = ::llvm::dyn_cast<" << cppType
2615 << " _odsPrinter.printStrippedAttrOrType(validType);\n"
2617 << " _odsPrinter << type;\n"
2621 body
<< " _odsPrinter << ";
2622 genTypeOperandPrinter(dir
->getArg(), op
, body
, /*useArrayRef=*/false)
2624 } else if (auto *dir
= dyn_cast
<FunctionalTypeDirective
>(element
)) {
2625 body
<< " _odsPrinter.printFunctionalType(";
2626 genTypeOperandPrinter(dir
->getInputs(), op
, body
) << ", ";
2627 genTypeOperandPrinter(dir
->getResults(), op
, body
) << ");\n";
2629 llvm_unreachable("unknown format element");
2633 void OperationFormat::genPrinter(Operator
&op
, OpClass
&opClass
) {
2634 auto *method
= opClass
.addMethod(
2636 MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter"));
2637 auto &body
= method
->body();
2639 // Flags for if we should emit a space, and if the last element was
2641 bool shouldEmitSpace
= true, lastWasPunctuation
= false;
2642 for (FormatElement
*element
: elements
)
2643 genElementPrinter(element
, body
, op
, shouldEmitSpace
, lastWasPunctuation
);
2646 //===----------------------------------------------------------------------===//
2648 //===----------------------------------------------------------------------===//
2650 /// Function to find an element within the given range that has the same name as
2652 template <typename RangeT
>
2653 static auto findArg(RangeT
&&range
, StringRef name
) {
2654 auto it
= llvm::find_if(range
, [=](auto &arg
) { return arg
.name
== name
; });
2655 return it
!= range
.end() ? &*it
: nullptr;
2659 /// This class implements a parser for an instance of an operation assembly
2661 class OpFormatParser
: public FormatParser
{
2663 OpFormatParser(llvm::SourceMgr
&mgr
, OperationFormat
&format
, Operator
&op
)
2664 : FormatParser(mgr
, op
.getLoc()[0]), fmt(format
), op(op
),
2665 seenOperandTypes(op
.getNumOperands()),
2666 seenResultTypes(op
.getNumResults()) {}
2669 /// Verify the format elements.
2670 LogicalResult
verify(SMLoc loc
, ArrayRef
<FormatElement
*> elements
) override
;
2671 /// Verify the arguments to a custom directive.
2673 verifyCustomDirectiveArguments(SMLoc loc
,
2674 ArrayRef
<FormatElement
*> arguments
) override
;
2675 /// Verify the elements of an optional group.
2676 LogicalResult
verifyOptionalGroupElements(SMLoc loc
,
2677 ArrayRef
<FormatElement
*> elements
,
2678 FormatElement
*anchor
) override
;
2679 LogicalResult
verifyOptionalGroupElement(SMLoc loc
, FormatElement
*element
,
2682 LogicalResult
markQualified(SMLoc loc
, FormatElement
*element
) override
;
2684 /// Parse an operation variable.
2685 FailureOr
<FormatElement
*> parseVariableImpl(SMLoc loc
, StringRef name
,
2686 Context ctx
) override
;
2687 /// Parse an operation format directive.
2688 FailureOr
<FormatElement
*>
2689 parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
, Context ctx
) override
;
2692 /// This struct represents a type resolution instance. It includes a specific
2693 /// type as well as an optional transformer to apply to that type in order to
2694 /// properly resolve the type of a variable.
2695 struct TypeResolutionInstance
{
2696 ConstArgument resolver
;
2697 std::optional
<StringRef
> transformer
;
2700 /// Verify the state of operation attributes within the format.
2701 LogicalResult
verifyAttributes(SMLoc loc
, ArrayRef
<FormatElement
*> elements
);
2703 /// Verify that attributes elements aren't followed by colon literals.
2704 LogicalResult
verifyAttributeColonType(SMLoc loc
,
2705 ArrayRef
<FormatElement
*> elements
);
2706 /// Verify that the attribute dictionary directive isn't followed by a region.
2707 LogicalResult
verifyAttrDictRegion(SMLoc loc
,
2708 ArrayRef
<FormatElement
*> elements
);
2710 /// Verify the state of operation operands within the format.
2712 verifyOperands(SMLoc loc
,
2713 StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2715 /// Verify the state of operation regions within the format.
2716 LogicalResult
verifyRegions(SMLoc loc
);
2718 /// Verify the state of operation results within the format.
2720 verifyResults(SMLoc loc
,
2721 StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2723 /// Verify the state of operation successors within the format.
2724 LogicalResult
verifySuccessors(SMLoc loc
);
2726 LogicalResult
verifyOIListElements(SMLoc loc
,
2727 ArrayRef
<FormatElement
*> elements
);
2729 /// Given the values of an `AllTypesMatch` trait, check for inferable type
2731 void handleAllTypesMatchConstraint(
2732 ArrayRef
<StringRef
> values
,
2733 StringMap
<TypeResolutionInstance
> &variableTyResolver
);
2734 /// Check for inferable type resolution given all operands, and or results,
2735 /// have the same type. If 'includeResults' is true, the results also have the
2736 /// same type as all of the operands.
2737 void handleSameTypesConstraint(
2738 StringMap
<TypeResolutionInstance
> &variableTyResolver
,
2739 bool includeResults
);
2740 /// Check for inferable type resolution based on another operand, result, or
2742 void handleTypesMatchConstraint(
2743 StringMap
<TypeResolutionInstance
> &variableTyResolver
, const Record
&def
);
2745 /// Returns an argument or attribute with the given name that has been seen
2746 /// within the format.
2747 ConstArgument
findSeenArg(StringRef name
);
2749 /// Parse the various different directives.
2750 FailureOr
<FormatElement
*> parsePropDictDirective(SMLoc loc
, Context context
);
2751 FailureOr
<FormatElement
*> parseAttrDictDirective(SMLoc loc
, Context context
,
2753 FailureOr
<FormatElement
*> parseFunctionalTypeDirective(SMLoc loc
,
2755 FailureOr
<FormatElement
*> parseOIListDirective(SMLoc loc
, Context context
);
2756 LogicalResult
verifyOIListParsingElement(FormatElement
*element
, SMLoc loc
);
2757 FailureOr
<FormatElement
*> parseOperandsDirective(SMLoc loc
, Context context
);
2758 FailureOr
<FormatElement
*> parseRegionsDirective(SMLoc loc
, Context context
);
2759 FailureOr
<FormatElement
*> parseResultsDirective(SMLoc loc
, Context context
);
2760 FailureOr
<FormatElement
*> parseSuccessorsDirective(SMLoc loc
,
2762 FailureOr
<FormatElement
*> parseTypeDirective(SMLoc loc
, Context context
);
2763 FailureOr
<FormatElement
*> parseTypeDirectiveOperand(SMLoc loc
,
2764 bool isRefChild
= false);
2766 //===--------------------------------------------------------------------===//
2768 //===--------------------------------------------------------------------===//
2770 OperationFormat
&fmt
;
2773 // The following are various bits of format state used for verification
2775 bool hasAttrDict
= false;
2776 bool hasPropDict
= false;
2777 bool hasAllRegions
= false, hasAllSuccessors
= false;
2778 bool canInferResultTypes
= false;
2779 llvm::SmallBitVector seenOperandTypes
, seenResultTypes
;
2780 llvm::SmallSetVector
<const NamedAttribute
*, 8> seenAttrs
;
2781 llvm::DenseSet
<const NamedTypeConstraint
*> seenOperands
;
2782 llvm::DenseSet
<const NamedRegion
*> seenRegions
;
2783 llvm::DenseSet
<const NamedSuccessor
*> seenSuccessors
;
2784 llvm::SmallSetVector
<const NamedProperty
*, 8> seenProperties
;
2788 LogicalResult
OpFormatParser::verify(SMLoc loc
,
2789 ArrayRef
<FormatElement
*> elements
) {
2790 // Check that the attribute dictionary is in the format.
2792 return emitError(loc
, "'attr-dict' directive not found in "
2793 "custom assembly format");
2795 // Check for any type traits that we can use for inferring types.
2796 StringMap
<TypeResolutionInstance
> variableTyResolver
;
2797 for (const Trait
&trait
: op
.getTraits()) {
2798 const Record
&def
= trait
.getDef();
2799 if (def
.isSubClassOf("AllTypesMatch")) {
2800 handleAllTypesMatchConstraint(def
.getValueAsListOfStrings("values"),
2801 variableTyResolver
);
2802 } else if (def
.getName() == "SameTypeOperands") {
2803 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/false);
2804 } else if (def
.getName() == "SameOperandsAndResultType") {
2805 handleSameTypesConstraint(variableTyResolver
, /*includeResults=*/true);
2806 } else if (def
.isSubClassOf("TypesMatchWith")) {
2807 handleTypesMatchConstraint(variableTyResolver
, def
);
2808 } else if (!op
.allResultTypesKnown()) {
2809 // This doesn't check the name directly to handle
2810 // DeclareOpInterfaceMethods<InferTypeOpInterface>
2812 // TODO: Add hasCppInterface check.
2813 if (auto name
= def
.getValueAsOptionalString("cppInterfaceName")) {
2814 if (*name
== "InferTypeOpInterface" &&
2815 def
.getValueAsString("cppNamespace") == "::mlir")
2816 canInferResultTypes
= true;
2821 // Verify the state of the various operation components.
2822 if (failed(verifyAttributes(loc
, elements
)) ||
2823 failed(verifyResults(loc
, variableTyResolver
)) ||
2824 failed(verifyOperands(loc
, variableTyResolver
)) ||
2825 failed(verifyRegions(loc
)) || failed(verifySuccessors(loc
)) ||
2826 failed(verifyOIListElements(loc
, elements
)))
2829 // Collect the set of used attributes in the format.
2830 fmt
.usedAttributes
= std::move(seenAttrs
);
2831 fmt
.usedProperties
= std::move(seenProperties
);
2833 // Set whether prop-dict is used in the format
2834 fmt
.hasPropDict
= hasPropDict
;
2839 OpFormatParser::verifyAttributes(SMLoc loc
,
2840 ArrayRef
<FormatElement
*> elements
) {
2841 // Check that there are no `:` literals after an attribute without a constant
2842 // type. The attribute grammar contains an optional trailing colon type, which
2843 // can lead to unexpected and generally unintended behavior. Given that, it is
2844 // better to just error out here instead.
2845 if (failed(verifyAttributeColonType(loc
, elements
)))
2847 // Check that there are no region variables following an attribute dicitonary.
2848 // Both start with `{` and so the optional attribute dictionary can cause
2849 // format ambiguities.
2850 if (failed(verifyAttrDictRegion(loc
, elements
)))
2853 // Check for VariadicOfVariadic variables. The segment attribute of those
2854 // variables will be infered.
2855 for (const NamedTypeConstraint
*var
: seenOperands
) {
2856 if (var
->constraint
.isVariadicOfVariadic()) {
2857 fmt
.inferredAttributes
.insert(
2858 var
->constraint
.getVariadicOfVariadicSegmentSizeAttr());
2865 /// Returns whether the single format element is optionally parsed.
2866 static bool isOptionallyParsed(FormatElement
*el
) {
2867 if (auto *attrVar
= dyn_cast
<AttributeVariable
>(el
)) {
2868 Attribute attr
= attrVar
->getVar()->attr
;
2869 return attr
.isOptional() || attr
.hasDefaultValue();
2871 if (auto *propVar
= dyn_cast
<PropertyVariable
>(el
)) {
2872 const Property
&prop
= propVar
->getVar()->prop
;
2873 return prop
.hasDefaultValue() && prop
.hasOptionalParser();
2875 if (auto *operandVar
= dyn_cast
<OperandVariable
>(el
)) {
2876 const NamedTypeConstraint
*operand
= operandVar
->getVar();
2877 return operand
->isOptional() || operand
->isVariadic() ||
2878 operand
->isVariadicOfVariadic();
2880 if (auto *successorVar
= dyn_cast
<SuccessorVariable
>(el
))
2881 return successorVar
->getVar()->isVariadic();
2882 if (auto *regionVar
= dyn_cast
<RegionVariable
>(el
))
2883 return regionVar
->getVar()->isVariadic();
2884 return isa
<WhitespaceElement
, AttrDictDirective
>(el
);
2887 /// Scan the given range of elements from the start for an invalid format
2888 /// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
2889 /// If an optional group is encountered, this function recurses into the 'then'
2890 /// and 'else' elements to check if they are invalid. Returns `success` if the
2891 /// range is known to be valid or `std::nullopt` if scanning reached the end.
2893 /// Since the guard element of an optional group is required, this function
2894 /// accepts an optional element pointer to mark it as required.
2895 static std::optional
<LogicalResult
> checkRangeForElement(
2896 FormatElement
*base
,
2897 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2898 iterator_range
<ArrayRef
<FormatElement
*>::iterator
> elementRange
,
2899 FormatElement
*optionalGuard
= nullptr) {
2900 for (FormatElement
*element
: elementRange
) {
2901 // If we encounter an invalid element, return an error.
2902 if (isInvalid(base
, element
))
2905 // Recurse on optional groups.
2906 if (auto *optional
= dyn_cast
<OptionalElement
>(element
)) {
2907 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2908 base
, isInvalid
, optional
->getThenElements(),
2909 // The optional group guard is required for the group.
2910 optional
->getThenElements().front()))
2911 if (failed(*result
))
2913 if (std::optional
<LogicalResult
> result
= checkRangeForElement(
2914 base
, isInvalid
, optional
->getElseElements()))
2915 if (failed(*result
))
2917 // Skip the optional group.
2921 // Skip optionally parsed elements.
2922 if (element
!= optionalGuard
&& isOptionallyParsed(element
))
2925 // We found a closing element that is valid.
2928 // Return std::nullopt to indicate that we reached the end.
2929 return std::nullopt
;
2932 /// For the given elements, check whether any attributes are followed by a colon
2933 /// literal, resulting in an ambiguous assembly format. Returns a non-null
2934 /// attribute if verification of said attribute reached the end of the range.
2935 /// Returns null if all attribute elements are verified.
2936 static FailureOr
<FormatElement
*> verifyAdjacentElements(
2937 function_ref
<bool(FormatElement
*)> isBase
,
2938 function_ref
<bool(FormatElement
*, FormatElement
*)> isInvalid
,
2939 ArrayRef
<FormatElement
*> elements
) {
2940 for (auto *it
= elements
.begin(), *e
= elements
.end(); it
!= e
; ++it
) {
2941 // The current attribute being verified.
2942 FormatElement
*base
;
2946 } else if (auto *optional
= dyn_cast
<OptionalElement
>(*it
)) {
2947 // Recurse on optional groups.
2948 FailureOr
<FormatElement
*> thenResult
= verifyAdjacentElements(
2949 isBase
, isInvalid
, optional
->getThenElements());
2950 if (failed(thenResult
))
2952 FailureOr
<FormatElement
*> elseResult
= verifyAdjacentElements(
2953 isBase
, isInvalid
, optional
->getElseElements());
2954 if (failed(elseResult
))
2956 // If either optional group has an unverified attribute, save it.
2957 // Otherwise, move on to the next element.
2958 if (!(base
= *thenResult
) && !(base
= *elseResult
))
2964 // Verify subsequent elements for potential ambiguities.
2965 if (std::optional
<LogicalResult
> result
=
2966 checkRangeForElement(base
, isInvalid
, {std::next(it
), e
})) {
2967 if (failed(*result
))
2970 // Since we reached the end, return the attribute as unverified.
2974 // All attribute elements are known to be verified.
2979 OpFormatParser::verifyAttributeColonType(SMLoc loc
,
2980 ArrayRef
<FormatElement
*> elements
) {
2981 auto isBase
= [](FormatElement
*el
) {
2982 auto *attr
= dyn_cast
<AttributeVariable
>(el
);
2985 // Check only attributes without type builders or that are known to call
2986 // the generic attribute parser.
2987 return !attr
->getTypeBuilder() &&
2988 (attr
->shouldBeQualified() ||
2989 attr
->getVar()->attr
.getStorageType() == "::mlir::Attribute");
2991 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
2992 auto *literal
= dyn_cast
<LiteralElement
>(el
);
2993 if (!literal
|| literal
->getSpelling() != ":")
2995 // If we encounter `:`, the range is known to be invalid.
2997 loc
, formatv("format ambiguity caused by `:` literal found after "
2998 "attribute `{0}` which does not have a buildable type",
2999 cast
<AttributeVariable
>(base
)->getVar()->name
));
3002 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
3006 OpFormatParser::verifyAttrDictRegion(SMLoc loc
,
3007 ArrayRef
<FormatElement
*> elements
) {
3008 auto isBase
= [](FormatElement
*el
) {
3009 if (auto *attrDict
= dyn_cast
<AttrDictDirective
>(el
))
3010 return !attrDict
->isWithKeyword();
3013 auto isInvalid
= [&](FormatElement
*base
, FormatElement
*el
) {
3014 auto *region
= dyn_cast
<RegionVariable
>(el
);
3017 (void)emitErrorAndNote(
3019 formatv("format ambiguity caused by `attr-dict` directive "
3020 "followed by region `{0}`",
3021 region
->getVar()->name
),
3022 "try using `attr-dict-with-keyword` instead");
3025 return verifyAdjacentElements(isBase
, isInvalid
, elements
);
3028 LogicalResult
OpFormatParser::verifyOperands(
3029 SMLoc loc
, StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
3030 // Check that all of the operands are within the format, and their types can
3032 auto &buildableTypes
= fmt
.buildableTypes
;
3033 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
3034 NamedTypeConstraint
&operand
= op
.getOperand(i
);
3036 // Check that the operand itself is in the format.
3037 if (!fmt
.allOperands
&& !seenOperands
.count(&operand
)) {
3038 return emitErrorAndNote(loc
,
3039 "operand #" + Twine(i
) + ", named '" +
3040 operand
.name
+ "', not found",
3041 "suggest adding a '$" + operand
.name
+
3042 "' directive to the custom assembly format");
3045 // Check that the operand type is in the format, or that it can be inferred.
3046 if (fmt
.allOperandTypes
|| seenOperandTypes
.test(i
))
3049 // Check to see if we can infer this type from another variable.
3050 auto varResolverIt
= variableTyResolver
.find(op
.getOperand(i
).name
);
3051 if (varResolverIt
!= variableTyResolver
.end()) {
3052 TypeResolutionInstance
&resolver
= varResolverIt
->second
;
3053 fmt
.operandTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
3057 // Similarly to results, allow a custom builder for resolving the type if
3058 // we aren't using the 'operands' directive.
3059 std::optional
<StringRef
> builder
= operand
.constraint
.getBuilderCall();
3060 if (!builder
|| (fmt
.allOperands
&& operand
.isVariableLength())) {
3061 return emitErrorAndNote(
3063 "type of operand #" + Twine(i
) + ", named '" + operand
.name
+
3064 "', is not buildable and a buildable type cannot be inferred",
3065 "suggest adding a type constraint to the operation or adding a "
3067 operand
.name
+ ")' directive to the " + "custom assembly format");
3069 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
3070 fmt
.operandTypes
[i
].setBuilderIdx(it
.first
->second
);
3075 LogicalResult
OpFormatParser::verifyRegions(SMLoc loc
) {
3076 // Check that all of the regions are within the format.
3080 for (unsigned i
= 0, e
= op
.getNumRegions(); i
!= e
; ++i
) {
3081 const NamedRegion
®ion
= op
.getRegion(i
);
3082 if (!seenRegions
.count(®ion
)) {
3083 return emitErrorAndNote(loc
,
3084 "region #" + Twine(i
) + ", named '" +
3085 region
.name
+ "', not found",
3086 "suggest adding a '$" + region
.name
+
3087 "' directive to the custom assembly format");
3093 LogicalResult
OpFormatParser::verifyResults(
3094 SMLoc loc
, StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
3095 // If we format all of the types together, there is nothing to check.
3096 if (fmt
.allResultTypes
)
3099 // If no result types are specified and we can infer them, infer all result
3101 if (op
.getNumResults() > 0 && seenResultTypes
.count() == 0 &&
3102 canInferResultTypes
) {
3103 fmt
.infersResultTypes
= true;
3107 // Check that all of the result types can be inferred.
3108 auto &buildableTypes
= fmt
.buildableTypes
;
3109 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
3110 if (seenResultTypes
.test(i
))
3113 // Check to see if we can infer this type from another variable.
3114 auto varResolverIt
= variableTyResolver
.find(op
.getResultName(i
));
3115 if (varResolverIt
!= variableTyResolver
.end()) {
3116 TypeResolutionInstance resolver
= varResolverIt
->second
;
3117 fmt
.resultTypes
[i
].setResolver(resolver
.resolver
, resolver
.transformer
);
3121 // If the result is not variable length, allow for the case where the type
3122 // has a builder that we can use.
3123 NamedTypeConstraint
&result
= op
.getResult(i
);
3124 std::optional
<StringRef
> builder
= result
.constraint
.getBuilderCall();
3125 if (!builder
|| result
.isVariableLength()) {
3126 return emitErrorAndNote(
3128 "type of result #" + Twine(i
) + ", named '" + result
.name
+
3129 "', is not buildable and a buildable type cannot be inferred",
3130 "suggest adding a type constraint to the operation or adding a "
3132 result
.name
+ ")' directive to the " + "custom assembly format");
3134 // Note in the format that this result uses the custom builder.
3135 auto it
= buildableTypes
.insert({*builder
, buildableTypes
.size()});
3136 fmt
.resultTypes
[i
].setBuilderIdx(it
.first
->second
);
3141 LogicalResult
OpFormatParser::verifySuccessors(SMLoc loc
) {
3142 // Check that all of the successors are within the format.
3143 if (hasAllSuccessors
)
3146 for (unsigned i
= 0, e
= op
.getNumSuccessors(); i
!= e
; ++i
) {
3147 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
3148 if (!seenSuccessors
.count(&successor
)) {
3149 return emitErrorAndNote(loc
,
3150 "successor #" + Twine(i
) + ", named '" +
3151 successor
.name
+ "', not found",
3152 "suggest adding a '$" + successor
.name
+
3153 "' directive to the custom assembly format");
3160 OpFormatParser::verifyOIListElements(SMLoc loc
,
3161 ArrayRef
<FormatElement
*> elements
) {
3162 // Check that all of the successors are within the format.
3163 SmallVector
<StringRef
> prohibitedLiterals
;
3164 for (FormatElement
*it
: elements
) {
3165 if (auto *oilist
= dyn_cast
<OIListElement
>(it
)) {
3166 if (!prohibitedLiterals
.empty()) {
3167 // We just saw an oilist element in last iteration. Literals should not
3169 for (LiteralElement
*literal
: oilist
->getLiteralElements()) {
3170 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
3171 prohibitedLiterals
.end()) {
3173 loc
, "format ambiguity because " + literal
->getSpelling() +
3174 " is used in two adjacent oilist elements.");
3178 for (LiteralElement
*literal
: oilist
->getLiteralElements())
3179 prohibitedLiterals
.push_back(literal
->getSpelling());
3180 } else if (auto *literal
= dyn_cast
<LiteralElement
>(it
)) {
3181 if (find(prohibitedLiterals
, literal
->getSpelling()) !=
3182 prohibitedLiterals
.end()) {
3185 "format ambiguity because " + literal
->getSpelling() +
3186 " is used both in oilist element and the adjacent literal.");
3188 prohibitedLiterals
.clear();
3190 prohibitedLiterals
.clear();
3196 void OpFormatParser::handleAllTypesMatchConstraint(
3197 ArrayRef
<StringRef
> values
,
3198 StringMap
<TypeResolutionInstance
> &variableTyResolver
) {
3199 for (unsigned i
= 0, e
= values
.size(); i
!= e
; ++i
) {
3200 // Check to see if this value matches a resolved operand or result type.
3201 ConstArgument arg
= findSeenArg(values
[i
]);
3205 // Mark this value as the type resolver for the other variables.
3206 for (unsigned j
= 0; j
!= i
; ++j
)
3207 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
3208 for (unsigned j
= i
+ 1; j
!= e
; ++j
)
3209 variableTyResolver
[values
[j
]] = {arg
, std::nullopt
};
3213 void OpFormatParser::handleSameTypesConstraint(
3214 StringMap
<TypeResolutionInstance
> &variableTyResolver
,
3215 bool includeResults
) {
3216 const NamedTypeConstraint
*resolver
= nullptr;
3217 int resolvedIt
= -1;
3219 // Check to see if there is an operand or result to use for the resolution.
3220 if ((resolvedIt
= seenOperandTypes
.find_first()) != -1)
3221 resolver
= &op
.getOperand(resolvedIt
);
3222 else if (includeResults
&& (resolvedIt
= seenResultTypes
.find_first()) != -1)
3223 resolver
= &op
.getResult(resolvedIt
);
3227 // Set the resolvers for each operand and result.
3228 for (unsigned i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
)
3229 if (!seenOperandTypes
.test(i
))
3230 variableTyResolver
[op
.getOperand(i
).name
] = {resolver
, std::nullopt
};
3231 if (includeResults
) {
3232 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
)
3233 if (!seenResultTypes
.test(i
))
3234 variableTyResolver
[op
.getResultName(i
)] = {resolver
, std::nullopt
};
3238 void OpFormatParser::handleTypesMatchConstraint(
3239 StringMap
<TypeResolutionInstance
> &variableTyResolver
, const Record
&def
) {
3240 StringRef lhsName
= def
.getValueAsString("lhs");
3241 StringRef rhsName
= def
.getValueAsString("rhs");
3242 StringRef transformer
= def
.getValueAsString("transformer");
3243 if (ConstArgument arg
= findSeenArg(lhsName
))
3244 variableTyResolver
[rhsName
] = {arg
, transformer
};
3247 ConstArgument
OpFormatParser::findSeenArg(StringRef name
) {
3248 if (const NamedTypeConstraint
*arg
= findArg(op
.getOperands(), name
))
3249 return seenOperandTypes
.test(arg
- op
.operand_begin()) ? arg
: nullptr;
3250 if (const NamedTypeConstraint
*arg
= findArg(op
.getResults(), name
))
3251 return seenResultTypes
.test(arg
- op
.result_begin()) ? arg
: nullptr;
3252 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
))
3253 return seenAttrs
.count(attr
) ? attr
: nullptr;
3257 FailureOr
<FormatElement
*>
3258 OpFormatParser::parseVariableImpl(SMLoc loc
, StringRef name
, Context ctx
) {
3259 // Check that the parsed argument is something actually registered on the op.
3261 if (const NamedAttribute
*attr
= findArg(op
.getAttributes(), name
)) {
3262 if (ctx
== TypeDirectiveContext
)
3264 loc
, "attributes cannot be used as children to a `type` directive");
3265 if (ctx
== RefDirectiveContext
) {
3266 if (!seenAttrs
.count(attr
))
3267 return emitError(loc
, "attribute '" + name
+
3268 "' must be bound before it is referenced");
3269 } else if (!seenAttrs
.insert(attr
)) {
3270 return emitError(loc
, "attribute '" + name
+ "' is already bound");
3273 return create
<AttributeVariable
>(attr
);
3276 if (const NamedProperty
*property
= findArg(op
.getProperties(), name
)) {
3277 if (ctx
== TypeDirectiveContext
)
3279 loc
, "properties cannot be used as children to a `type` directive");
3280 if (ctx
== RefDirectiveContext
) {
3281 if (!seenProperties
.count(property
))
3282 return emitError(loc
, "property '" + name
+
3283 "' must be bound before it is referenced");
3285 if (!seenProperties
.insert(property
))
3286 return emitError(loc
, "property '" + name
+ "' is already bound");
3289 return create
<PropertyVariable
>(property
);
3293 if (const NamedTypeConstraint
*operand
= findArg(op
.getOperands(), name
)) {
3294 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3295 if (fmt
.allOperands
|| !seenOperands
.insert(operand
).second
)
3296 return emitError(loc
, "operand '" + name
+ "' is already bound");
3297 } else if (ctx
== RefDirectiveContext
&& !seenOperands
.count(operand
)) {
3298 return emitError(loc
, "operand '" + name
+
3299 "' must be bound before it is referenced");
3301 return create
<OperandVariable
>(operand
);
3304 if (const NamedRegion
*region
= findArg(op
.getRegions(), name
)) {
3305 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3306 if (hasAllRegions
|| !seenRegions
.insert(region
).second
)
3307 return emitError(loc
, "region '" + name
+ "' is already bound");
3308 } else if (ctx
== RefDirectiveContext
&& !seenRegions
.count(region
)) {
3309 return emitError(loc
, "region '" + name
+
3310 "' must be bound before it is referenced");
3312 return emitError(loc
, "regions can only be used at the top level");
3314 return create
<RegionVariable
>(region
);
3317 if (const auto *result
= findArg(op
.getResults(), name
)) {
3318 if (ctx
!= TypeDirectiveContext
)
3319 return emitError(loc
, "result variables can can only be used as a child "
3320 "to a 'type' directive");
3321 return create
<ResultVariable
>(result
);
3324 if (const auto *successor
= findArg(op
.getSuccessors(), name
)) {
3325 if (ctx
== TopLevelContext
|| ctx
== CustomDirectiveContext
) {
3326 if (hasAllSuccessors
|| !seenSuccessors
.insert(successor
).second
)
3327 return emitError(loc
, "successor '" + name
+ "' is already bound");
3328 } else if (ctx
== RefDirectiveContext
&& !seenSuccessors
.count(successor
)) {
3329 return emitError(loc
, "successor '" + name
+
3330 "' must be bound before it is referenced");
3332 return emitError(loc
, "successors can only be used at the top level");
3335 return create
<SuccessorVariable
>(successor
);
3337 return emitError(loc
, "expected variable to refer to an argument, region, "
3338 "result, or successor");
3341 FailureOr
<FormatElement
*>
3342 OpFormatParser::parseDirectiveImpl(SMLoc loc
, FormatToken::Kind kind
,
3345 case FormatToken::kw_prop_dict
:
3346 return parsePropDictDirective(loc
, ctx
);
3347 case FormatToken::kw_attr_dict
:
3348 return parseAttrDictDirective(loc
, ctx
,
3349 /*withKeyword=*/false);
3350 case FormatToken::kw_attr_dict_w_keyword
:
3351 return parseAttrDictDirective(loc
, ctx
,
3352 /*withKeyword=*/true);
3353 case FormatToken::kw_functional_type
:
3354 return parseFunctionalTypeDirective(loc
, ctx
);
3355 case FormatToken::kw_operands
:
3356 return parseOperandsDirective(loc
, ctx
);
3357 case FormatToken::kw_regions
:
3358 return parseRegionsDirective(loc
, ctx
);
3359 case FormatToken::kw_results
:
3360 return parseResultsDirective(loc
, ctx
);
3361 case FormatToken::kw_successors
:
3362 return parseSuccessorsDirective(loc
, ctx
);
3363 case FormatToken::kw_type
:
3364 return parseTypeDirective(loc
, ctx
);
3365 case FormatToken::kw_oilist
:
3366 return parseOIListDirective(loc
, ctx
);
3369 return emitError(loc
, "unsupported directive kind");
3373 FailureOr
<FormatElement
*>
3374 OpFormatParser::parseAttrDictDirective(SMLoc loc
, Context context
,
3376 if (context
== TypeDirectiveContext
)
3377 return emitError(loc
, "'attr-dict' directive can only be used as a "
3378 "top-level directive");
3380 if (context
== RefDirectiveContext
) {
3382 return emitError(loc
, "'ref' of 'attr-dict' is not bound by a prior "
3383 "'attr-dict' directive");
3385 // Otherwise, this is a top-level context.
3388 return emitError(loc
, "'attr-dict' directive has already been seen");
3392 return create
<AttrDictDirective
>(withKeyword
);
3395 FailureOr
<FormatElement
*>
3396 OpFormatParser::parsePropDictDirective(SMLoc loc
, Context context
) {
3397 if (context
== TypeDirectiveContext
)
3398 return emitError(loc
, "'prop-dict' directive can only be used as a "
3399 "top-level directive");
3401 if (context
== RefDirectiveContext
)
3402 llvm::report_fatal_error("'ref' of 'prop-dict' unsupported");
3403 // Otherwise, this is a top-level context.
3406 return emitError(loc
, "'prop-dict' directive has already been seen");
3409 return create
<PropDictDirective
>();
3412 LogicalResult
OpFormatParser::verifyCustomDirectiveArguments(
3413 SMLoc loc
, ArrayRef
<FormatElement
*> arguments
) {
3414 for (FormatElement
*argument
: arguments
) {
3415 if (!isa
<AttrDictDirective
, PropDictDirective
, AttributeVariable
,
3416 OperandVariable
, PropertyVariable
, RefDirective
, RegionVariable
,
3417 SuccessorVariable
, StringElement
, TypeDirective
>(argument
)) {
3418 // TODO: FormatElement should have location info attached.
3419 return emitError(loc
, "only variables and types may be used as "
3420 "parameters to a custom directive");
3422 if (auto *type
= dyn_cast
<TypeDirective
>(argument
)) {
3423 if (!isa
<OperandVariable
, ResultVariable
>(type
->getArg())) {
3424 return emitError(loc
, "type directives within a custom directive may "
3425 "only refer to variables");
3432 FailureOr
<FormatElement
*>
3433 OpFormatParser::parseFunctionalTypeDirective(SMLoc loc
, Context context
) {
3434 if (context
!= TopLevelContext
)
3436 loc
, "'functional-type' is only valid as a top-level directive");
3438 // Parse the main operand.
3439 FailureOr
<FormatElement
*> inputs
, results
;
3440 if (failed(parseToken(FormatToken::l_paren
,
3441 "expected '(' before argument list")) ||
3442 failed(inputs
= parseTypeDirectiveOperand(loc
)) ||
3443 failed(parseToken(FormatToken::comma
,
3444 "expected ',' after inputs argument")) ||
3445 failed(results
= parseTypeDirectiveOperand(loc
)) ||
3447 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3449 return create
<FunctionalTypeDirective
>(*inputs
, *results
);
3452 FailureOr
<FormatElement
*>
3453 OpFormatParser::parseOperandsDirective(SMLoc loc
, Context context
) {
3454 if (context
== RefDirectiveContext
) {
3455 if (!fmt
.allOperands
)
3456 return emitError(loc
, "'ref' of 'operands' is not bound by a prior "
3457 "'operands' directive");
3459 } else if (context
== TopLevelContext
|| context
== CustomDirectiveContext
) {
3460 if (fmt
.allOperands
|| !seenOperands
.empty())
3461 return emitError(loc
, "'operands' directive creates overlap in format");
3462 fmt
.allOperands
= true;
3464 return create
<OperandsDirective
>();
3467 FailureOr
<FormatElement
*>
3468 OpFormatParser::parseRegionsDirective(SMLoc loc
, Context context
) {
3469 if (context
== TypeDirectiveContext
)
3470 return emitError(loc
, "'regions' is only valid as a top-level directive");
3471 if (context
== RefDirectiveContext
) {
3473 return emitError(loc
, "'ref' of 'regions' is not bound by a prior "
3474 "'regions' directive");
3476 // Otherwise, this is a TopLevel directive.
3478 if (hasAllRegions
|| !seenRegions
.empty())
3479 return emitError(loc
, "'regions' directive creates overlap in format");
3480 hasAllRegions
= true;
3482 return create
<RegionsDirective
>();
3485 FailureOr
<FormatElement
*>
3486 OpFormatParser::parseResultsDirective(SMLoc loc
, Context context
) {
3487 if (context
!= TypeDirectiveContext
)
3488 return emitError(loc
, "'results' directive can can only be used as a child "
3489 "to a 'type' directive");
3490 return create
<ResultsDirective
>();
3493 FailureOr
<FormatElement
*>
3494 OpFormatParser::parseSuccessorsDirective(SMLoc loc
, Context context
) {
3495 if (context
== TypeDirectiveContext
)
3496 return emitError(loc
,
3497 "'successors' is only valid as a top-level directive");
3498 if (context
== RefDirectiveContext
) {
3499 if (!hasAllSuccessors
)
3500 return emitError(loc
, "'ref' of 'successors' is not bound by a prior "
3501 "'successors' directive");
3503 // Otherwise, this is a TopLevel directive.
3505 if (hasAllSuccessors
|| !seenSuccessors
.empty())
3506 return emitError(loc
, "'successors' directive creates overlap in format");
3507 hasAllSuccessors
= true;
3509 return create
<SuccessorsDirective
>();
/// Parse an `oilist` directive.
///
/// After the opening `(`, the visible code parses one group:
/// - a literal (via parseLiteral), appended to `literalElements`;
/// - then format elements (via parseElement) up to the next `|` or `)` token,
///   appended to a new entry of `parsingElements`.
/// Every such element must pass verifyOIListParsingElement. The result is an
/// OIListElement constructed from both lists (moved in).
///
/// NOTE(review): the embedded original line numbers skip 3516, 3519, 3522,
/// 3531, 3533, 3535-3537 and 3539-3543. This copy is therefore missing:
/// - the bodies of the failure checks;
/// - the bodies of the `|` / `)` branches;
/// - presumably the loop that repeats the group parsing after each `|`.
/// Presumably the checks return failure() and the branches consume the
/// token. Restore from upstream.
3512 FailureOr
<FormatElement
*>
3513 OpFormatParser::parseOIListDirective(SMLoc loc
, Context context
) {
3514 if (failed(parseToken(FormatToken::l_paren
,
3515 "expected '(' before oilist argument list")))
// Parallel lists: parsingElements[i] holds the elements that follow
// literalElements[i].
3517 std::vector
<FormatElement
*> literalElements
;
3518 std::vector
<std::vector
<FormatElement
*>> parsingElements
;
3520 FailureOr
<FormatElement
*> lelement
= parseLiteral(context
);
3521 if (failed(lelement
))
3523 literalElements
.push_back(*lelement
);
3524 parsingElements
.emplace_back();
3525 std::vector
<FormatElement
*> &currParsingElements
= parsingElements
.back();
// Collect this group's elements up to the next `|` separator or the
// closing `)`; each one must be valid inside an oilist group.
3526 while (peekToken().getKind() != FormatToken::pipe
&&
3527 peekToken().getKind() != FormatToken::r_paren
) {
3528 FailureOr
<FormatElement
*> pelement
= parseElement(context
);
3529 if (failed(pelement
) ||
3530 failed(verifyOIListParsingElement(*pelement
, loc
)))
3532 currParsingElements
.push_back(*pelement
);
// NOTE(review): the bodies of the two branches below are missing from this
// copy (original lines 3535-3537 and 3539-3543).
3534 if (peekToken().getKind() == FormatToken::pipe
) {
3538 if (peekToken().getKind() == FormatToken::r_paren
) {
3544 return create
<OIListElement
>(std::move(literalElements
),
3545 std::move(parsingElements
));
/// Verify that `element` may appear inside an `oilist` parsing group.
///
/// Every variable nested in `element` (gathered with collect()) is checked:
///   - attributes must be optional or have a default value;
///   - properties must have a default value;
///   - operands and results must be variable length;
///   - regions are always accepted;
///   - any other variable kind is rejected.
/// Rejections are reported with emitError at `loc`.
///
/// NOTE(review): this copy is missing several lines:
///   - the rest of the signature (presumably the `SMLoc loc` parameter);
///   - original line 3553 between the loop header and the TypeSwitch, and
///     everything from 3593 onward (how the TypeSwitch result is used);
///   - the opening of the property diagnostic (3566-3567);
///   - the `return success();` / `})` endings of the cases.
/// As written here the TypeSwitch result is not checked. Restore from
/// upstream.
3548 LogicalResult
OpFormatParser::verifyOIListParsingElement(FormatElement
*element
,
3550 SmallVector
<VariableElement
*> vars
;
3551 collect(element
, vars
);
3552 for (VariableElement
*elem
: vars
) {
3554 TypeSwitch
<FormatElement
*, LogicalResult
>(elem
)
3555 // Only optional or default-valued attributes can be within an oilist
// parsing group.
// NOTE(review): the diagnostic below says "only optional attributes" even
// though default-valued attributes are accepted as well.
3556 .Case([&](AttributeVariable
*attrEle
) {
3557 if (!attrEle
->getVar()->attr
.isOptional() &&
3558 !attrEle
->getVar()->attr
.hasDefaultValue())
3559 return emitError(loc
, "only optional attributes can be used in "
3560 "an oilist parsing group");
3563 // Only default-valued properties can be within an oilist parsing group.
// NOTE(review): only hasDefaultValue() is checked, although the diagnostic
// mentions "optional" properties. The diagnostic also says "olist", likely a
// typo for "oilist".
3564 .Case([&](PropertyVariable
*propEle
) {
3565 if (!propEle
->getVar()->prop
.hasDefaultValue())
3568 "only default-valued or optional properties can be used in "
3569 "an olist parsing group");
3572 // Only optional-like(i.e. variadic) operands can be within an
3573 // oilist parsing group.
3574 .Case([&](OperandVariable
*ele
) {
3575 if (!ele
->getVar()->isVariableLength())
3576 return emitError(loc
, "only variable length operands can be "
3577 "used within an oilist parsing group");
3580 // Only optional-like(i.e. variadic) results can be within an oilist
// parsing group.
3582 .Case([&](ResultVariable
*ele
) {
3583 if (!ele
->getVar()->isVariableLength())
3584 return emitError(loc
, "only variable length results can be "
3585 "used within an oilist parsing group");
// Regions are accepted unconditionally.
3588 .Case([&](RegionVariable
*) { return success(); })
3589 .Default([&](FormatElement
*) {
3590 return emitError(loc
,
3591 "only literals, types, and variables can be "
3592 "used within an oilist group");
3600 FailureOr
<FormatElement
*> OpFormatParser::parseTypeDirective(SMLoc loc
,
3602 if (context
== TypeDirectiveContext
)
3603 return emitError(loc
, "'type' cannot be used as a child of another `type`");
3605 bool isRefChild
= context
== RefDirectiveContext
;
3606 FailureOr
<FormatElement
*> operand
;
3607 if (failed(parseToken(FormatToken::l_paren
,
3608 "expected '(' before argument list")) ||
3609 failed(operand
= parseTypeDirectiveOperand(loc
, isRefChild
)) ||
3611 parseToken(FormatToken::r_paren
, "expected ')' after argument list")))
3614 return create
<TypeDirective
>(*operand
);
3617 LogicalResult
OpFormatParser::markQualified(SMLoc loc
, FormatElement
*element
) {
3618 return TypeSwitch
<FormatElement
*, LogicalResult
>(element
)
3619 .Case
<AttributeVariable
, TypeDirective
>([](auto *element
) {
3620 element
->setShouldBeQualified();
3623 .Default([&](auto *element
) {
3624 return this->emitError(
3626 "'qualified' directive expects an attribute or a `type` directive");
/// Parse and validate the single operand of a `type` directive.
///
/// The operand is parsed in TypeDirectiveContext and must not be a literal.
/// The visible branches accept:
///   - an operand variable or a result variable, bound per index in
///     `seenOperandTypes` / `seenResultTypes`;
///   - the `operands` or `results` directive, bound for the whole list via
///     `fmt.allOperandTypes` / `fmt.allResultTypes`.
///
/// When `isRefChild` is false, the type must not already be bound. When it
/// is true (the `type` directive sits under `ref`), the type must already
/// have been bound by an earlier `type` directive.
///
/// NOTE(review): for `ref` of `type(operands)` / `type(results)` only the
/// fmt.allOperandTypes / fmt.allResultTypes flags are consulted. Types bound
/// one operand/result at a time do not satisfy the check.
/// NOTE(review): this copy is missing original lines 3633-3635 (presumably
/// the failure check on `result`), 3638 (the opener of the literal
/// diagnostic), 3673 and 3675 onward (the final else and return). Restore
/// from upstream.
3630 FailureOr
<FormatElement
*>
3631 OpFormatParser::parseTypeDirectiveOperand(SMLoc loc
, bool isRefChild
) {
3632 FailureOr
<FormatElement
*> result
= parseElement(TypeDirectiveContext
);
3636 FormatElement
*element
= *result
;
3637 if (isa
<LiteralElement
>(element
))
3639 loc
, "'type' directive operand expects variable or directive operand");
// `type($operand)`: bind or reference the type of one operand by its index.
3641 if (auto *var
= dyn_cast
<OperandVariable
>(element
)) {
3642 unsigned opIdx
= var
->getVar() - op
.operand_begin();
3643 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3644 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3645 "' is already bound");
3646 if (isRefChild
&& !(fmt
.allOperandTypes
|| seenOperandTypes
.test(opIdx
)))
3647 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3648 ")' is not bound by a prior 'type' directive");
3649 seenOperandTypes
.set(opIdx
);
// `type($result)`: bind or reference the type of one result by its index.
3650 } else if (auto *var
= dyn_cast
<ResultVariable
>(element
)) {
3651 unsigned resIdx
= var
->getVar() - op
.result_begin();
3652 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3653 return emitError(loc
, "'type' of '" + var
->getVar()->name
+
3654 "' is already bound");
3655 if (isRefChild
&& !(fmt
.allResultTypes
|| seenResultTypes
.test(resIdx
)))
3656 return emitError(loc
, "'ref' of 'type($" + var
->getVar()->name
+
3657 ")' is not bound by a prior 'type' directive");
3658 seenResultTypes
.set(resIdx
);
// `type(operands)`: bind or reference the types of all operands.
3659 } else if (isa
<OperandsDirective
>(&*element
)) {
3660 if (!isRefChild
&& (fmt
.allOperandTypes
|| seenOperandTypes
.any()))
3661 return emitError(loc
, "'operands' 'type' is already bound");
3662 if (isRefChild
&& !fmt
.allOperandTypes
)
3663 return emitError(loc
, "'ref' of 'type(operands)' is not bound by a prior "
3664 "'type' directive");
3665 fmt
.allOperandTypes
= true;
// `type(results)`: bind or reference the types of all results.
3666 } else if (isa
<ResultsDirective
>(&*element
)) {
3667 if (!isRefChild
&& (fmt
.allResultTypes
|| seenResultTypes
.any()))
3668 return emitError(loc
, "'results' 'type' is already bound");
3669 if (isRefChild
&& !fmt
.allResultTypes
)
3670 return emitError(loc
, "'ref' of 'type(results)' is not bound by a prior "
3671 "'type' directive");
3672 fmt
.allResultTypes
= true;
3674 return emitError(loc
, "invalid argument to 'type' directive");
3679 LogicalResult
OpFormatParser::verifyOptionalGroupElements(
3680 SMLoc loc
, ArrayRef
<FormatElement
*> elements
, FormatElement
*anchor
) {
3681 for (FormatElement
*element
: elements
) {
3682 if (failed(verifyOptionalGroupElement(loc
, element
, element
== anchor
)))
/// Verify that `element` may appear inside an optional group and, when
/// `isAnchor` is set, that it may act as the group's anchor.
///
/// Visible rules:
///   - attributes may appear; an anchor must be optional or default-valued;
///   - properties may appear; an anchor must have a default value and an
///     optional parser;
///   - operands and results must be variable length;
///   - a TypeDirective verifies its argument, and a FunctionalTypeDirective
///     verifies its inputs and then its results, never as anchors;
///   - a CustomDirective verifies its arguments with `isAnchor=true`,
///     presumably skipping `ref` arguments;
///   - literals, whitespace and OptionalElements share a case that emits the
///     "only variables and types can be used to anchor" diagnostic;
///   - any other element is rejected.
///
/// NOTE(review): this copy is missing several lines:
///   - the `isAnchor` parameter declaration (original line 3690);
///   - the `return success();` / `})` endings of most cases;
///   - the region case body (3729-3731);
///   - parts of the custom-directive case (3744-3745, 3750, 3752-3755);
///   - the line before the literal/whitespace/optional diagnostic (3760).
/// Presumably some of these guard the diagnostics with `isAnchor`. Restore
/// from upstream.
3688 LogicalResult
OpFormatParser::verifyOptionalGroupElement(SMLoc loc
,
3689 FormatElement
*element
,
3691 return TypeSwitch
<FormatElement
*, LogicalResult
>(element
)
3692 // All attributes can be within the optional group, but only optional or
3693 // default-valued attributes can be the anchor.
3694 .Case([&](AttributeVariable
*attrEle
) {
3695 Attribute attr
= attrEle
->getVar()->attr
;
3696 if (isAnchor
&& !(attr
.isOptional() || attr
.hasDefaultValue()))
3697 return emitError(loc
, "only optional or default-valued attributes "
3698 "can be used to anchor an optional group");
3701 // All properties can be within the optional group, but only properties
3702 // with a default value and an optional parser can be the anchor.
3703 .Case([&](PropertyVariable
*propEle
) {
3704 Property prop
= propEle
->getVar()->prop
;
3705 if (isAnchor
&& !(prop
.hasDefaultValue() && prop
.hasOptionalParser()))
3706 return emitError(loc
, "only properties with default values "
3707 "that can be optionally parsed "
3708 "can be used to anchor an optional group");
3711 // Only optional-like(i.e. variadic) operands can be within an optional
// group.
3713 .Case([&](OperandVariable
*ele
) {
3714 if (!ele
->getVar()->isVariableLength())
3715 return emitError(loc
, "only variable length operands can be used "
3716 "within an optional group");
3719 // Only optional-like(i.e. variadic) results can be within an optional
// group.
3721 .Case([&](ResultVariable
*ele
) {
3722 if (!ele
->getVar()->isVariableLength())
3723 return emitError(loc
, "only variable length results can be used "
3724 "within an optional group");
// NOTE(review): the rest of the TODO below and this case's body are missing
// from this copy.
3727 .Case([&](RegionVariable
*) {
3728 // TODO: When ODS has proper support for marking "optional" regions, add
// The argument of a `type` directive is verified but never acts as anchor.
3732 .Case([&](TypeDirective
*ele
) {
3733 return verifyOptionalGroupElement(loc
, ele
->getArg(),
3734 /*isAnchor=*/false);
// Verify the inputs first, then the results; neither acts as anchor.
3736 .Case([&](FunctionalTypeDirective
*ele
) {
3737 if (failed(verifyOptionalGroupElement(loc
, ele
->getInputs(),
3738 /*isAnchor=*/false)))
3740 return verifyOptionalGroupElement(loc
, ele
->getResults(),
3741 /*isAnchor=*/false);
3743 .Case([&](CustomDirective
*ele
) {
3746 // Verify each child as being valid in an optional group. They are all
3747 // potential anchors if the custom directive was marked as one.
3748 for (FormatElement
*child
: ele
->getArguments()) {
3749 if (isa
<RefDirective
>(child
))
3751 if (failed(verifyOptionalGroupElement(loc
, child
, /*isAnchor=*/true)))
3756 // Literals, whitespace, and optional elements may be used, but they can't
3757 // anchor the group.
3758 .Case
<LiteralElement
, WhitespaceElement
, OptionalElement
>(
3759 [&](FormatElement
*) {
3761 return emitError(loc
, "only variables and types can be used "
3762 "to anchor an optional group");
3765 .Default([&](FormatElement
*) {
3766 return emitError(loc
, "only literals, types, and variables can be "
3767 "used within an optional group");
3771 //===----------------------------------------------------------------------===//
3773 //===----------------------------------------------------------------------===//
/// Generate the parser and printer for an operation's declarative assembly
/// format into `opClass`.
///
/// Steps:
///   1. Load the string returned by getAssemblyFormat() into a SourceMgr.
///   2. Parse it with OpFormatParser into an OperationFormat (constructed
///      with `hasProperties`).
///   3. On success, store the parsed elements in the format and invoke
///      genParser and genPrinter with `opClass`.
///
/// If parsing fails and `formatErrorIsFatal` is set, the interrupt handlers
/// are run. Per the existing comment, the process then exits.
///
/// NOTE(review): this copy is missing:
///   - the statement after the hasAssemblyFormat() check (original line
///     3781; presumably an early return for ops without a format);
///   - the rest of the parse-failure path (3795-3798).
/// Restore from upstream.
3775 void mlir::tblgen::generateOpFormat(const Operator
&constOp
, OpClass
&opClass
,
3776 bool hasProperties
) {
3777 // TODO: Operator doesn't expose all necessary functionality via
3778 // the const interface.
3779 Operator
&op
= const_cast<Operator
&>(constOp
);
3780 if (!op
.hasAssemblyFormat())
3783 // Parse the format description.
3784 llvm::SourceMgr mgr
;
3785 mgr
.AddNewSourceBuffer(
3786 llvm::MemoryBuffer::getMemBuffer(op
.getAssemblyFormat()), SMLoc());
3787 OperationFormat
format(op
, hasProperties
);
3788 OpFormatParser
parser(mgr
, format
, op
);
3789 FailureOr
<std::vector
<FormatElement
*>> elements
= parser
.parse();
3790 if (failed(elements
)) {
3791 // Exit the process if format errors are treated as fatal.
3792 if (formatErrorIsFatal
) {
3793 // Invoke the interrupt handlers to run the file cleanup handlers.
3794 llvm::sys::RunInterruptHandlers();
3799 format
.elements
= std::move(*elements
);
3801 // Generate the printer and parser based on the parsed format.
3802 format
.genParser(op
, opClass
);
3803 format
.genPrinter(op
, opClass
);