1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
12 //===----------------------------------------------------------------------===//
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Class.h"
20 #include "mlir/TableGen/CodeGenHelpers.h"
21 #include "mlir/TableGen/Format.h"
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Interfaces.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Property.h"
26 #include "mlir/TableGen/SideEffects.h"
27 #include "mlir/TableGen/Trait.h"
28 #include "llvm/ADT/BitVector.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/Sequence.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSet.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/TableGen/Error.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
42 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
46 using namespace mlir::tblgen
;
// Identifier fragments spliced into the generated C++ code. Operands without
// an ODS name are named `<generatedArgName>_<index>`, e.g. `odsArg_0`.
static constexpr const char *tblgenNamePrefix = "tblgen_";
static constexpr const char *generatedArgName = "odsArg";
static constexpr const char *odsBuilder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";
static constexpr const char *propertyStorage = "propStorage";
static constexpr const char *propertyValue = "propValue";
static constexpr const char *propertyAttr = "propAttr";
static constexpr const char *propertyDiag = "emitError";

/// Names of the implicit attributes that hold the sizes of the variadic
/// operand and result segments.
static constexpr const char *operandSegmentAttrName = "operandSegmentSizes";
static constexpr const char *resultSegmentAttrName = "resultSegmentSizes";
62 /// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
/// lookup: the generated call only searches the slice starting at
/// `{3}.begin() + {1}` and ending `{2}` entries before `{3}.end()`. This is
/// valid because attributes are kept sorted by name on operations.
65 /// {0}: Code snippet to get the attribute's name or identifier.
66 /// {1}: The lower bound on the sorted subrange.
67 /// {2}: The upper bound on the sorted subrange.
68 /// {3}: Code snippet to get the array of named attributes.
69 /// {4}: "Named" to get the named attribute.
///      (OpOrAdaptorHelper::getAttr passes the empty string otherwise.)
70 static const char *const subrangeGetAttr
=
71 "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
74 /// The logic to calculate the actual value range for a declared operand/result
75 /// of an op with variadic operands/results. Note that this logic is not for
76 /// general use; it assumes all variadic operands/results must have the same
/// number of dynamic values.
79 /// {0}: The list of whether each declared operand/result is variadic.
80 /// {1}: The total number of non-variadic operands/results.
81 /// {2}: The total number of variadic operands/results.
82 /// {3}: The total number of actual values.
83 /// {4}: "operand" or "result".
///
/// The emitted snippet reads a variable named `index` (the static
/// operand/result index) and returns a `{start, size}` pair. `{{` in the
/// template is the formatv escape for a literal `{`.
84 static const char *const sameVariadicSizeValueRangeCalcCode
= R
"(
85 bool isVariadic[] = {{{0}};
86 int prevVariadicCount = 0;
87 for (unsigned i = 0; i < index; ++i)
88 if (isVariadic[i]) ++prevVariadicCount;
90 // Calculate how many dynamic values a static variadic {4} corresponds to.
91 // This assumes all static variadic {4}s have the same dynamic value count.
92 int variadicSize = ({3} - {1}) / {2};
93 // `index` passed in as the parameter is the static index which counts each
94 // {4} (variadic or not) as size 1. So here for each previous static variadic
95 // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
96 // value pack for this static {4} starts.
97 int start = index + (variadicSize - 1) * prevVariadicCount;
98 int size = isVariadic[index] ? variadicSize : 1;
99 return {{start, size};
102 /// The logic to calculate the actual value range for a declared operand/result
103 /// of an op with variadic operands/results. Note that this logic assumes
104 /// the op has an attribute specifying the size of each operand/result segment
105 /// (variadic or not).
///
/// The snippet expects `index` and `sizeAttr` to be in scope. It adds up the
/// sizes of all segments before `index` to find the start, and returns
/// `{start, sizeAttr[index]}`.
106 static const char *const attrSizedSegmentValueRangeCalcCode
= R
"(
108 for (unsigned i = 0; i < index; ++i)
109 start += sizeAttr[i];
110 return {start, sizeAttr[index]};
112 /// The code snippet to initialize the sizes for the value range calculation.
/// Adaptor form: asserts that the attribute is present, then binds it to
/// `sizeAttr` as a DenseI32ArrayAttr. The `...Properties` variant that follows
/// instead binds {0} directly to an `ArrayRef<int32_t>`. It is presumably used
/// when segment sizes are stored as a property.
114 /// {0}: The code to get the attribute.
115 static const char *const adapterSegmentSizeAttrInitCode
= R
"(
116 assert({0} && "missing segment size attribute
for op
");
117 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
119 static const char *const adapterSegmentSizeAttrInitCodeProperties
= R
"(
120 ::llvm::ArrayRef<int32_t> sizeAttr = {0};
123 /// The code snippet to initialize the sizes for the value range calculation.
/// Op form: binds {0} to `sizeAttr` as a DenseI32ArrayAttr. Unlike
/// adapterSegmentSizeAttrInitCode, it emits no presence assertion.
125 /// {0}: The code to get the attribute.
126 static const char *const opSegmentSizeAttrInitCode
= R
"(
127 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
130 /// The logic to calculate the actual value range for a declared operand
131 /// of an op with variadic of variadic operands within the OpAdaptor.
/// The snippet splits the operands of operand {1} into consecutive groups.
/// Each group's length is taken from `sizes`, and the list of groups is
/// returned.
133 /// {0}: The name of the segment attribute.
134 /// {1}: The index of the main operand.
135 /// {2}: The range type of adaptor.
136 static const char *const variadicOfVariadicAdaptorCalcCode
= R
"(
137 auto tblgenTmpOperands = getODSOperands({1});
140 ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
141 for (int i = 0, e = sizes.size(); i < e; ++i) {{
142 tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
143 tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
145 return tblgenTmpOperandGroups;
148 /// The logic to build a range of either operand or result values.
/// The emitted code returns the half-open iterator range
/// [{0} + start, {0} + start + length), using the (start, length) pair
/// produced by {1}.
150 /// {0}: The begin iterator of the actual values.
151 /// {1}: The call to generate the start and length of the value range.
152 static const char *const valueRangeReturnCode
= R
"(
153 auto valueRange = {1};
154 return {{std::next({0}, valueRange.first),
155 std::next({0}, valueRange.first + valueRange.second)};
158 /// Parse operand/result segment_size property.
159 /// {0}: Number of elements in the segment array
///
/// The emitted parser reads a square-bracketed, comma-separated list of
/// integers into `$_storage`. It reports an error when the number of elements
/// does not match {0}. The list delimiter is spelled through MLIR's
/// `::mlir::AsmParser::Delimiter` enum; this text is only checked when the
/// generated code is compiled, so it must match the MLIR spelling exactly.
160 static const char *const parseTextualSegmentSizeFormat
= R
"(
162 auto parseElem = [&]() -> ::mlir::ParseResult {
164 return $_parser.emitError($_parser.getCurrentLocation(),
165 "expected `
]` after
{0} segment sizes
");
166 if (failed($_parser.parseInteger($_storage[i])))
167 return ::mlir::failure();
169 return ::mlir::success();
171 if (failed($_parser.parseCommaSeparatedList(
172 ::mlir::AsmParser::Delimiter::Square, parseElem)))
175 return $_parser.emitError($_parser.getCurrentLocation(),
176 "expected
{0} segment sizes
, found only
") << i;
/// Print operand/result segment_size property. The emitted printer writes the
/// elements of `$_storage` separated by commas via `::llvm::interleaveComma`.
180 static const char *const printTextualSegmentSize
= R
"(
183 ::llvm::interleaveComma($_storage, $_printer);
188 /// Read operand/result segment_size from bytecode.
///
/// Native form (bytecode version >= 6): reads the sizes as a sparse array
/// directly into `$_storage`.
/// Legacy form (`readBytecodeSegmentSizeLegacy`, version < 6): reads a
/// DenseI32ArrayAttr, rejects it if it has more elements than `$_storage` can
/// hold, and copies it into `prop.$_propName`.
189 static const char *const readBytecodeSegmentSizeNative
= R
"(
190 if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
191 return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
194 static const char *const readBytecodeSegmentSizeLegacy
= R
"(
195 if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
196 auto &$_storage = prop.$_propName;
197 ::mlir::DenseI32ArrayAttr attr;
198 if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
199 if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
200 $_reader.emitError("size mismatch
for operand
/result_segment_size
");
201 return ::mlir::failure();
203 ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
207 /// Write operand/result segment_size to bytecode.
/// Native form, emitted only for bytecode version >= 6: writes `$_storage` as
/// a sparse array.
208 static const char *const writeBytecodeSegmentSizeNative
= R
"(
209 if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
210 $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
213 /// Write operand/result segment_size to bytecode (legacy form).
/// Emitted for bytecode versions < 6: writes `prop.$_propName` wrapped in a
/// DenseI32ArrayAttr.
214 static const char *const writeBytecodeSegmentSizeLegacy
= R
"(
215 if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
216 auto &$_storage = prop.$_propName;
217 $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
221 /// A header for indicating code sections.
/// The emitted text is a banner comment framed by two
/// `//===---...---===//` rule lines.
223 /// {0}: Some text, or a class name.
225 static const char *const opCommentHeader
= R
"(
226 //===----------------------------------------------------------------------===//
228 //===----------------------------------------------------------------------===//
232 //===----------------------------------------------------------------------===//
233 // Utility structs and functions
234 //===----------------------------------------------------------------------===//
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the result. Inserted text is never rescanned, so `substitute` may itself
// contain `match`. An empty `match` leaves `str` unchanged: an empty pattern
// matches at every position, and the loop below would never terminate.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;
  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    // `replace` edits `str` in place; no reassignment needed.
    str.replace(matchLoc, match.size(), substitute);
    // Resume the search after the inserted text so it is not matched again.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
247 // Returns whether the record has a value of the given name that can be returned
248 // via getValueAsString.
249 static inline bool hasStringAttribute(const Record
&record
,
250 StringRef fieldName
) {
251 auto *valueInit
= record
.getValueInit(fieldName
);
252 return isa
<StringInit
>(valueInit
);
255 static std::string
getArgumentName(const Operator
&op
, int index
) {
256 const auto &operand
= op
.getOperand(index
);
257 if (!operand
.name
.empty())
258 return std::string(operand
.name
);
259 return std::string(formatv("{0}_{1}", generatedArgName
, index
));
262 // Returns true if we can use unwrapped value for the given `attr` in builders.
263 static bool canUseUnwrappedRawValue(const tblgen::Attribute
&attr
) {
264 return attr
.getReturnType() != attr
.getStorageType() &&
265 // We need to wrap the raw value into an attribute in the builder impl
266 // so we need to make sure that the attribute specifies how to do that.
267 !attr
.getConstBuilderTemplate().empty();
270 /// Build an attribute from a parameter value using the constant builder.
/// Returns the attribute's constant-builder template with `$0` replaced by
/// `paramName`, formatted through `fctx`.
/// NOTE(review): `fctx` is used below but is not among the parameters visible
/// here. Confirm the signature declares it (presumably a FmtContext).
271 static std::string
constBuildAttrFromParam(const tblgen::Attribute
&attr
,
273 StringRef paramName
) {
274 std::string builderTemplate
= attr
.getConstBuilderTemplate().str();
276 // For StringAttr, its constant builder call will wrap the input in
277 // quotes, which is correct for normal string literals, but incorrect
278 // here given we use function arguments. So we need to strip the
// quotes (turn "\"$0\"" into "$0") before substituting the parameter name.
280 if (StringRef(builderTemplate
).contains("\"$0\""))
281 builderTemplate
= replaceAllSubstrs(builderTemplate
, "\"$0\"", "$0");
283 return tgfmt(builderTemplate
, &fctx
, paramName
).str();
287 /// Metadata on a registered attribute. Given that attributes are stored in
288 /// sorted order on operations, we can use information from ODS to deduce the
289 /// number of required attributes less than and greater than each attribute,
290 /// allowing us to search only a subrange of the attributes in ODS-generated
/// code. The bounds are filled in by OpOrAdaptorHelper::computeAttrMetadata.
292 struct AttributeMetadata
{
293 /// The attribute name.
295 /// Whether the attribute is required.
297 /// The ODS attribute constraint. Not present for implicit attributes.
298 std::optional
<Attribute
> constraint
;
299 /// The number of required attributes less than this attribute.
300 unsigned lowerBound
= 0;
301 /// The number of required attributes greater than this attribute.
302 unsigned upperBound
= 0;
305 /// Helper class to select between OpAdaptor and Op code templates.
/// The snippet getters return code whose text depends on `emitForOp`: true
/// when emitting for the op itself, false when emitting for its adaptor.
306 class OpOrAdaptorHelper
{
308 OpOrAdaptorHelper(const Operator
&op
, bool emitForOp
)
309 : op(op
), emitForOp(emitForOp
) {
310 computeAttrMetadata();
313 /// Object that wraps a functor in a stream operator for interop with
/// raw_ostream. A Formatter can be streamed with `<<` (which calls the
/// functor) or rendered to a std::string with str().
317 template <typename Functor
>
318 Formatter(Functor
&&func
) : func(std::forward
<Functor
>(func
)) {}
320 std::string
str() const {
322 llvm::raw_string_ostream
os(result
);
328 std::function
<raw_ostream
&(raw_ostream
&)> func
;
330 friend raw_ostream
&operator<<(raw_ostream
&os
, const Formatter
&fmt
) {
335 // Generate code for getting an attribute.
// When the op uses properties, the attribute is read from getProperties().
// Otherwise the code looks it up in the sorted attribute range using the
// precomputed lower/upper bounds (see subrangeGetAttr).
336 Formatter
getAttr(StringRef attrName
, bool isNamed
= false) const {
337 assert(attrMetadata
.count(attrName
) && "expected attribute metadata");
338 return [this, attrName
, isNamed
](raw_ostream
&os
) -> raw_ostream
& {
339 const AttributeMetadata
&attr
= attrMetadata
.find(attrName
)->second
;
340 if (hasProperties()) {
342 return os
<< "getProperties()." << attrName
;
344 return os
<< formatv(subrangeGetAttr
, getAttrName(attrName
),
345 attr
.lowerBound
, attr
.upperBound
, getAttrRange(),
346 isNamed
? "Named" : "");
350 // Generate code for getting the name of an attribute.
351 Formatter
getAttrName(StringRef attrName
) const {
352 return [this, attrName
](raw_ostream
&os
) -> raw_ostream
& {
354 return os
<< op
.getGetterName(attrName
) << "AttrName()";
355 return os
<< formatv("{0}::{1}AttrName(*odsOpName)", op
.getCppClassName(),
356 op
.getGetterName(attrName
));
360 // Get the code snippet for getting the named attribute range.
361 StringRef
getAttrRange() const {
362 return emitForOp
? "(*this)->getAttrs()" : "odsAttrs";
365 // Get the prefix code for emitting an error.
366 Formatter
emitErrorPrefix() const {
367 return [this](raw_ostream
&os
) -> raw_ostream
& {
369 return os
<< "emitOpError(";
370 return os
<< formatv("emitError(loc, \"'{0}' op \"",
371 op
.getOperationName());
375 // Get the call to get an operand or segment of operands.
376 Formatter
getOperand(unsigned index
) const {
377 return [this, index
](raw_ostream
&os
) -> raw_ostream
& {
378 return os
<< formatv(op
.getOperand(index
).isVariadic()
379 ? "this->getODSOperands({0})"
380 : "(*this->getODSOperands({0}).begin())",
385 // Get the call to get a result or segment of results.
386 Formatter
getResult(unsigned index
) const {
387 return [this, index
](raw_ostream
&os
) -> raw_ostream
& {
389 return os
<< "<no results should be generated>";
390 return os
<< formatv(op
.getResult(index
).isVariadic()
391 ? "this->getODSResults({0})"
392 : "(*this->getODSResults({0}).begin())",
397 // Return whether an op instance is available.
398 bool isEmittingForOp() const { return emitForOp
; }
400 // Return the ODS operation wrapper.
401 const Operator
&getOp() const { return op
; }
403 // Get the attribute metadata sorted by name.
404 const llvm::MapVector
<StringRef
, AttributeMetadata
> &getAttrMetadata() const {
408 /// Returns whether to emit a `Properties` struct for this operation or not.
409 bool hasProperties() const {
410 if (!op
.getProperties().empty())
412 if (!op
.getDialect().usePropertiesForAttributes())
414 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") ||
415 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
417 return llvm::any_of(getAttrMetadata(),
418 [](const std::pair
<StringRef
, AttributeMetadata
> &it
) {
419 return !it
.second
.constraint
||
420 !it
.second
.constraint
->isDerivedAttr();
// Accessors for the implicit segment-size properties. computeAttrMetadata
// sets them only when the dialect stores attributes as properties and the op
// has the corresponding AttrSized*Segments trait; otherwise they stay empty.
424 std::optional
<NamedProperty
> &getOperandSegmentsSize() {
425 return operandSegmentsSize
;
428 std::optional
<NamedProperty
> &getResultSegmentsSize() {
429 return resultSegmentsSize
;
// Positions of the legacy `operand_segment_sizes` / `result_segment_sizes`
// attributes in the sorted attribute order (see computeAttrMetadata).
432 uint32_t getOperandSegmentSizesLegacyIndex() {
433 return operandSegmentSizesLegacyIndex
;
436 uint32_t getResultSegmentSizesLegacyIndex() {
437 return resultSegmentSizesLegacyIndex
;
441 // Compute the attribute metadata.
442 void computeAttrMetadata();
444 // The operation ODS wrapper.
446 // True if code is being generated for an op. False for an adaptor.
447 const bool emitForOp
;
449 // The attribute metadata, mapped by name.
450 llvm::MapVector
<StringRef
, AttributeMetadata
> attrMetadata
;
// Implicit segment-size properties and the storage-type / parser strings
// passed to them as StringRefs. The strings are presumably kept as members
// so those references stay valid for the helper's lifetime.
453 std::optional
<NamedProperty
> operandSegmentsSize
;
454 std::string operandSegmentsSizeStorage
;
455 std::string operandSegmentsSizeParser
;
456 std::optional
<NamedProperty
> resultSegmentsSize
;
457 std::string resultSegmentsSizeStorage
;
458 std::string resultSegmentsSizeParser
;
460 // Indices to store the position in the emission order of the operand/result
461 // segment sizes attribute if emitted as part of the properties for legacy
462 // bytecode encodings, i.e. versions less than 6.
463 uint32_t operandSegmentSizesLegacyIndex
= 0;
464 uint32_t resultSegmentSizesLegacyIndex
= 0;
466 // The number of required attributes.
467 unsigned numRequired
;
472 void OpOrAdaptorHelper::computeAttrMetadata() {
// Fills `attrMetadata` with every ODS attribute plus the implicit
// segment-size attributes, stores it in name order, and computes each
// entry's subrange bounds.
473 // Enumerate the attribute names of this op, ensuring the attribute names are
474 // unique in case implicit attributes are explicitly registered.
475 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
// An attribute is not required when it has a default value, is optional,
// or is derived.
476 Attribute attr
= namedAttr
.attr
;
478 attr
.hasDefaultValue() || attr
.isOptional() || attr
.isDerivedAttr();
480 {namedAttr
.name
, AttributeMetadata
{namedAttr
.name
, !isOptional
, attr
}});
// Builds the property description (storage, conversion, parse/print,
// bytecode and hash hooks) for a segment-size array stored as `storageType`
// and parsed with `parserCall`.
483 auto makeProperty
= [&](StringRef storageType
, StringRef parserCall
) {
487 /*storageType=*/storageType
,
488 /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
489 /*convertFromStorageCall=*/"$_storage",
490 /*assignToStorageCall=*/
491 "::llvm::copy($_value, $_storage.begin())",
492 /*convertToAttributeCall=*/
493 "return ::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage);",
494 /*convertFromAttributeCall=*/
495 "return convertFromAttribute($_storage, $_attr, $_diag);",
496 /*parserCall=*/parserCall
,
497 /*optionalParserCall=*/"",
498 /*printerCall=*/printTextualSegmentSize
,
499 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative
,
500 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative
,
501 /*hashPropertyCall=*/
502 "::llvm::hash_combine_range(std::begin($_storage), "
503 "std::end($_storage));",
504 /*StringRef defaultValue=*/"",
505 /*storageTypeValueOverride=*/"");
507 // Include key attributes from several traits as implicitly registered.
508 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
509 if (op
.getDialect().usePropertiesForAttributes()) {
510 operandSegmentsSizeStorage
=
511 llvm::formatv("std::array<int32_t, {0}>", op
.getNumOperands());
512 operandSegmentsSizeParser
=
513 llvm::formatv(parseTextualSegmentSizeFormat
, op
.getNumOperands());
514 operandSegmentsSize
= {
515 "operandSegmentSizes",
516 makeProperty(operandSegmentsSizeStorage
, operandSegmentsSizeParser
)};
519 {operandSegmentAttrName
, AttributeMetadata
{operandSegmentAttrName
,
521 /*attr=*/std::nullopt
}});
524 if (op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
525 if (op
.getDialect().usePropertiesForAttributes()) {
526 resultSegmentsSizeStorage
=
527 llvm::formatv("std::array<int32_t, {0}>", op
.getNumResults());
528 resultSegmentsSizeParser
=
529 llvm::formatv(parseTextualSegmentSizeFormat
, op
.getNumResults());
530 resultSegmentsSize
= {
531 "resultSegmentSizes",
532 makeProperty(resultSegmentsSizeStorage
, resultSegmentsSizeParser
)};
535 {resultSegmentAttrName
,
536 AttributeMetadata
{resultSegmentAttrName
, /*isRequired=*/true,
537 /*attr=*/std::nullopt
}});
541 // Store the metadata in sorted order.
542 SmallVector
<AttributeMetadata
> sortedAttrMetadata
=
543 llvm::to_vector(llvm::make_second_range(attrMetadata
.takeVector()));
544 llvm::sort(sortedAttrMetadata
,
545 [](const AttributeMetadata
&lhs
, const AttributeMetadata
&rhs
) {
546 return lhs
.attrName
< rhs
.attrName
;
549 // Store the position of the legacy operand_segment_sizes /
550 // result_segment_sizes so we can emit backward-compatible property readers
// (bytecode versions < 6). Each index is the number of sorted attributes whose
// name compares less than the legacy name.
552 StringRef legacyOperandSegmentSizeName
=
553 StringLiteral("operand_segment_sizes");
554 StringRef legacyResultSegmentSizeName
= StringLiteral("result_segment_sizes");
555 operandSegmentSizesLegacyIndex
= 0;
556 resultSegmentSizesLegacyIndex
= 0;
// NOTE(review): `auto item` copies each AttributeMetadata; `const auto &`
// would avoid the copy.
557 for (auto item
: sortedAttrMetadata
) {
558 if (item
.attrName
< legacyOperandSegmentSizeName
)
559 ++operandSegmentSizesLegacyIndex
;
560 if (item
.attrName
< legacyResultSegmentSizeName
)
561 ++resultSegmentSizesLegacyIndex
;
564 // Compute the subrange bounds for each attribute.
// lowerBound = number of required attributes sorted before this one.
// upperBound = number of required attributes sorted after it.
566 for (AttributeMetadata
&attr
: sortedAttrMetadata
) {
567 attr
.lowerBound
= numRequired
;
568 numRequired
+= attr
.isRequired
;
570 for (AttributeMetadata
&attr
: sortedAttrMetadata
)
571 attr
.upperBound
= numRequired
- attr
.lowerBound
- attr
.isRequired
;
573 // Store the results back into the map.
574 for (const AttributeMetadata
&attr
: sortedAttrMetadata
)
575 attrMetadata
.insert({attr
.attrName
, attr
});
578 //===----------------------------------------------------------------------===//
580 //===----------------------------------------------------------------------===//
583 // Helper class to emit a record into the given output stream.
// ConstArgument is a non-owning pointer to either an attribute's metadata or a
// named property; genPropertiesSupportForBytecode takes a list of them.
585 using ConstArgument
=
586 llvm::PointerUnion
<const AttributeMetadata
*, const NamedProperty
*>;
590 emitDecl(const Operator
&op
, raw_ostream
&os
,
591 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
593 emitDef(const Operator
&op
, raw_ostream
&os
,
594 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
597 OpEmitter(const Operator
&op
,
598 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
600 void emitDecl(raw_ostream
&os
);
601 void emitDef(raw_ostream
&os
);
603 // Generate methods for accessing the attribute names of this operation.
604 void genAttrNameGetters();
606 // Generates the OpAsmOpInterface for this operation if possible.
607 void genOpAsmInterface();
609 // Generates the `getOperationName` method for this op.
610 void genOpNameGetter();
612 // Generates code to manage the properties, if any!
613 void genPropertiesSupport();
615 // Generates code to manage the encoding of properties to bytecode.
617 genPropertiesSupportForBytecode(ArrayRef
<ConstArgument
> attrOrProperties
);
619 // Generates getters for the properties.
620 void genPropGetters();
622 // Generates setters for the properties.
623 void genPropSetters();
625 // Generates getters for the attributes.
626 void genAttrGetters();
628 // Generates setters for the attributes.
629 void genAttrSetters();
631 // Generates removers for optional attributes.
632 void genOptionalAttrRemovers();
634 // Generates getters for named operands.
635 void genNamedOperandGetters();
637 // Generates setters for named operands.
638 void genNamedOperandSetters();
640 // Generates getters for named results.
641 void genNamedResultGetters();
643 // Generates getters for named regions.
644 void genNamedRegionGetters();
646 // Generates getters for named successors.
647 void genNamedSuccessorGetters();
649 // Generates the method to populate default attributes.
650 void genPopulateDefaultAttributes();
652 // Generates builder methods for the operation.
655 // Generates the build() method that takes each operand/attribute
656 // as a stand-alone parameter.
657 void genSeparateArgParamBuilder();
659 // Generates the build() method that takes each operand/attribute as a
660 // stand-alone parameter. The generated build() method uses first operand's
661 // type as all results' types.
662 void genUseOperandAsResultTypeSeparateParamBuilder();
664 // Generates the build() method that takes all operands/attributes
665 // collectively as one parameter. The generated build() method uses first
666 // operand's type as all results' types.
667 void genUseOperandAsResultTypeCollectiveParamBuilder();
669 // Generates the build() method that takes aggregate operands/attributes
670 // parameters. This build() method uses inferred types as result types.
671 // Requires: The type needs to be inferable via InferTypeOpInterface.
672 void genInferredTypeCollectiveParamBuilder();
674 // Generates the build() method that takes each operand/attribute as a
675 // stand-alone parameter. The generated build() method uses first attribute's
676 // type as all results' types.
677 void genUseAttrAsResultTypeBuilder();
679 // Generates the build() method that takes all result types collectively as
680 // one parameter. Similarly for operands and attributes.
681 void genCollectiveParamBuilder();
683 // The kind of parameter to generate for result types in builders.
684 enum class TypeParamKind
{
685 None
, // No result type in parameter list.
686 Separate
, // A separate parameter for each result type.
687 Collective
, // An ArrayRef<Type> for all result types.
690 // The kind of parameter to generate for attributes in builders.
691 enum class AttrParamKind
{
692 WrappedAttr
, // A wrapped MLIR Attribute instance.
693 UnwrappedValue
, // A raw value without MLIR Attribute wrapper.
696 // Builds the parameter list for build() method of this op. This method writes
697 // to `paramList` the comma-separated parameter list and updates
698 // `resultTypeNames` with the names for parameters for specifying result
699 // types. `inferredAttributes` is populated with any attributes that are
700 // elided from the build list. The given `typeParamKind` and `attrParamKind`
701 // control how result types and attributes are placed in the parameter list.
// NOTE(review): "¶mList" below looks like a mangled "&paramList" (an
// HTML-entity conversion of "&para"). Verify against the real declaration.
702 void buildParamList(SmallVectorImpl
<MethodParameter
> ¶mList
,
703 llvm::StringSet
<> &inferredAttributes
,
704 SmallVectorImpl
<std::string
> &resultTypeNames
,
705 TypeParamKind typeParamKind
,
706 AttrParamKind attrParamKind
= AttrParamKind::WrappedAttr
);
708 // Adds op arguments and regions into operation state for build() methods.
710 genCodeForAddingArgAndRegionForBuilder(MethodBody
&body
,
711 llvm::StringSet
<> &inferredAttributes
,
712 bool isRawValueAttr
= false);
714 // Generates canonicalizer declaration for the operation.
715 void genCanonicalizerDecls();
717 // Generates the folder declaration for the operation.
718 void genFolderDecls();
720 // Generates the parser for the operation.
723 // Generates the printer for the operation.
726 // Generates verify method for the operation.
729 // Generates custom verify methods for the operation.
730 void genCustomVerifier();
732 // Generates verify statements for operands and results in the operation.
733 // The generated code will be attached to `body`.
734 void genOperandResultVerifier(MethodBody
&body
,
735 Operator::const_value_range values
,
736 StringRef valueKind
);
738 // Generates verify statements for regions in the operation.
739 // The generated code will be attached to `body`.
740 void genRegionVerifier(MethodBody
&body
);
742 // Generates verify statements for successors in the operation.
743 // The generated code will be attached to `body`.
744 void genSuccessorVerifier(MethodBody
&body
);
746 // Generates the traits used by the object.
749 // Generate the OpInterface methods for all interfaces.
750 void genOpInterfaceMethods();
752 // Generate op interface methods for the given interface.
753 void genOpInterfaceMethods(const tblgen::InterfaceTrait
*trait
);
755 // Generate op interface method for the given interface method. If
756 // 'declaration' is true, generates a declaration, else a definition.
757 Method
*genOpInterfaceMethod(const tblgen::InterfaceMethod
&method
,
758 bool declaration
= true);
760 // Generate the side effect interface methods.
761 void genSideEffectInterfaceMethods();
763 // Generate the type inference interface methods.
764 void genTypeInterfaceMethods();
767 // The TableGen record for this op.
768 // TODO: OpEmitter should not have a Record directly,
769 // it should rather go through the Operator for better abstraction.
772 // The wrapper operator class for querying information from this op.
775 // The C++ code builder for this op
778 // The format context for verification code generation.
779 FmtContext verifyCtx
;
781 // The emitter containing all of the locally emitted verification functions.
782 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
;
784 // Helper for emitting op code.
785 OpOrAdaptorHelper emitHelper
;
790 // Populate the format context `ctx` with substitutions of attributes, operands
// and results:
// - each attribute name maps to its getter call (`<getter>()`);
// - each named operand/result maps to the snippet from
//   emitHelper.getOperand/getResult.
// Unnamed operands and results get no substitution.
// NOTE(review): `ctx` is not among the parameters visible here. Confirm the
// signature declares it (presumably `FmtContext &ctx`).
792 static void populateSubstitutions(const OpOrAdaptorHelper
&emitHelper
,
794 // Populate substitutions for attributes.
795 auto &op
= emitHelper
.getOp();
796 for (const auto &namedAttr
: op
.getAttributes())
797 ctx
.addSubst(namedAttr
.name
,
798 emitHelper
.getOp().getGetterName(namedAttr
.name
) + "()");
800 // Populate substitutions for named operands.
801 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
) {
802 auto &value
= op
.getOperand(i
);
803 if (!value
.name
.empty())
804 ctx
.addSubst(value
.name
, emitHelper
.getOperand(i
).str());
807 // Populate substitutions for results.
808 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
809 auto &value
= op
.getResult(i
);
810 if (!value
.name
.empty())
811 ctx
.addSubst(value
.name
, emitHelper
.getResult(i
).str());
815 /// Generate verification on native traits requiring attributes.
/// Code is only emitted when the dialect does not store attributes as
/// properties. For the AttrSizedOperandSegments / AttrSizedResultSegments
/// traits, the generated code checks that the segment-size attribute has
/// exactly one entry per declared operand/result, and reports an error
/// (via emitHelper.emitErrorPrefix()) otherwise.
816 static void genNativeTraitAttrVerifier(MethodBody
&body
,
817 const OpOrAdaptorHelper
&emitHelper
) {
818 // Check that the variadic segment sizes attribute exists and contains the
819 // expected number of elements.
821 // {0}: Attribute name.
822 // {1}: Expected number of elements.
823 // {2}: "operand" or "result".
824 // {3}: Emit error prefix.
825 const char *const checkAttrSizedValueSegmentsCode
= R
"(
827 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0});
828 auto numElements = sizeAttr.asArrayRef().size();
829 if (numElements != {1})
830 return {3}"'{0}' attribute
for specifying
{2} segments must have
{1} "
831 "elements
, but got
") << numElements;
835 // Verify a few traits first so that we can use getODSOperands() and
836 // getODSResults() in the rest of the verifier.
837 auto &op
= emitHelper
.getOp();
838 if (!op
.getDialect().usePropertiesForAttributes()) {
839 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
840 body
<< formatv(checkAttrSizedValueSegmentsCode
, operandSegmentAttrName
,
841 op
.getNumOperands(), "operand",
842 emitHelper
.emitErrorPrefix());
844 if (op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
845 body
<< formatv(checkAttrSizedValueSegmentsCode
, resultSegmentAttrName
,
846 op
.getNumResults(), "result",
847 emitHelper
.emitErrorPrefix());
852 // Return true if a verifier can be emitted for the attribute: it is not a
853 // derived attribute, it has a predicate, its condition is not empty, and, for
854 // adaptors, the condition does not reference the op.
// A condition "references the op" when its text contains the `$_op`
// placeholder.
// NOTE(review): the statement guarded by the derived-attribute check below
// (and any null-predicate check) is not visible here. Confirm that it returns
// early; otherwise the following declaration would become the `if` body.
855 static bool canEmitAttrVerifier(Attribute attr
, bool isEmittingForOp
) {
856 if (attr
.isDerivedAttr())
858 Pred pred
= attr
.getPredicate();
861 std::string condition
= pred
.getCondition();
862 return !condition
.empty() &&
863 (!StringRef(condition
).contains("$_op") || isEmittingForOp
);
866 // Generate attribute verification. If an op instance is not available, then
867 // attribute checks that require one will not be emitted.
869 // Attribute verification is performed as follows:
871 // 1. Verify that all required attributes are present in sorted order. This
872 // ensures that we can use subrange lookup even with potentially missing
874 // 2. Verify native trait attributes so that other attributes may call methods
875 // that depend on the validity of these attributes, e.g. segment size attributes
876 // and operand or result getters.
877 // 3. Verify the constraints on all present attributes.
879 genAttributeVerifier(const OpOrAdaptorHelper
&emitHelper
, FmtContext
&ctx
,
881 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
882 bool useProperties
) {
883 if (emitHelper
.getAttrMetadata().empty())
886 // Verify the attribute if it is present. This assumes that default values
887 // are valid. This code snippet pastes the condition inline.
889 // TODO: verify the default value is valid (perhaps in debug mode only).
891 // {0}: Attribute variable name.
892 // {1}: Attribute condition code.
893 // {2}: Emit error prefix.
894 // {3}: Attribute name.
895 // {4}: Attribute/constraint description.
896 const char *const verifyAttrInline
= R
"(
898 return {2}"attribute
'{3}' failed to satisfy constraint
: {4}");
900 // Verify the attribute using a uniqued constraint. Can only be used within
901 // the context of an op.
903 // {0}: Unique constraint name.
904 // {1}: Attribute variable name.
905 // {2}: Attribute name.
906 const char *const verifyAttrUnique
= R
"(
907 if (::mlir::failed({0}(*this, {1}, "{2}")))
908 return ::mlir::failure();
911 // Traverse the array until the required attribute is found. Return an error
912 // if the traversal reached the end.
914 // {0}: Code to get the name of the attribute.
915 // {1}: The emit error prefix.
916 // {2}: The name of the attribute.
917 const char *const findRequiredAttr
= R
"(
919 if (namedAttrIt == namedAttrRange.end())
920 return {1}"requires attribute
'{2}'");
921 if (namedAttrIt->getName() == {0}) {{
922 tblgen_{2} = namedAttrIt->getValue();
926 // Emit a check to see if the iteration has encountered an optional attribute.
928 // {0}: Code to get the name of the attribute.
929 // {1}: The name of the attribute.
930 const char *const checkOptionalAttr
= R
"(
931 else if (namedAttrIt->getName() == {0}) {{
932 tblgen_{1} = namedAttrIt->getValue();
935 // Emit the start of the loop for checking trailing attributes.
936 const char *const checkTrailingAttrs
= R
"(while (true) {
937 if (namedAttrIt == namedAttrRange.end()) {
941 // Emit the verifier for the attribute.
942 const auto emitVerifier
= [&](Attribute attr
, StringRef attrName
,
944 std::string condition
= attr
.getPredicate().getCondition();
946 std::optional
<StringRef
> constraintFn
;
947 if (emitHelper
.isEmittingForOp() &&
948 (constraintFn
= staticVerifierEmitter
.getAttrConstraintFn(attr
))) {
949 body
<< formatv(verifyAttrUnique
, *constraintFn
, varName
, attrName
);
951 body
<< formatv(verifyAttrInline
, varName
,
952 tgfmt(condition
, &ctx
.withSelf(varName
)),
953 emitHelper
.emitErrorPrefix(), attrName
,
954 escapeString(attr
.getSummary()));
958 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
959 const auto getVarName
= [&](StringRef attrName
) {
960 return (tblgenNamePrefix
+ attrName
).str();
965 for (const std::pair
<StringRef
, AttributeMetadata
> &it
:
966 emitHelper
.getAttrMetadata()) {
967 const AttributeMetadata
&metadata
= it
.second
;
968 if (metadata
.constraint
&& metadata
.constraint
->isDerivedAttr())
971 "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
973 if (metadata
.isRequired
)
975 "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
976 it
.first
, emitHelper
.emitErrorPrefix());
979 body
<< formatv("auto namedAttrRange = {0};\n", emitHelper
.getAttrRange());
980 body
<< "auto namedAttrIt = namedAttrRange.begin();\n";
982 // Iterate over the attributes in sorted order. Keep track of the optional
983 // attributes that may be encountered along the way.
984 SmallVector
<const AttributeMetadata
*> optionalAttrs
;
986 for (const std::pair
<StringRef
, AttributeMetadata
> &it
:
987 emitHelper
.getAttrMetadata()) {
988 const AttributeMetadata
&metadata
= it
.second
;
989 if (!metadata
.isRequired
) {
990 optionalAttrs
.push_back(&metadata
);
994 body
<< formatv("::mlir::Attribute {0};\n", getVarName(it
.first
));
995 for (const AttributeMetadata
*optional
: optionalAttrs
) {
996 body
<< formatv("::mlir::Attribute {0};\n",
997 getVarName(optional
->attrName
));
999 body
<< formatv(findRequiredAttr
, emitHelper
.getAttrName(it
.first
),
1000 emitHelper
.emitErrorPrefix(), it
.first
);
1001 for (const AttributeMetadata
*optional
: optionalAttrs
) {
1002 body
<< formatv(checkOptionalAttr
,
1003 emitHelper
.getAttrName(optional
->attrName
),
1004 optional
->attrName
);
1006 body
<< "\n ++namedAttrIt;\n}\n";
1007 optionalAttrs
.clear();
1009 // Get trailing optional attributes.
1010 if (!optionalAttrs
.empty()) {
1011 for (const AttributeMetadata
*optional
: optionalAttrs
) {
1012 body
<< formatv("::mlir::Attribute {0};\n",
1013 getVarName(optional
->attrName
));
1015 body
<< checkTrailingAttrs
;
1016 for (const AttributeMetadata
*optional
: optionalAttrs
) {
1017 body
<< formatv(checkOptionalAttr
,
1018 emitHelper
.getAttrName(optional
->attrName
),
1019 optional
->attrName
);
1021 body
<< "\n ++namedAttrIt;\n}\n";
1026 // Emit the checks for segment attributes first so that the other
1027 // constraints can call operand and result getters.
1028 genNativeTraitAttrVerifier(body
, emitHelper
);
1030 bool isEmittingForOp
= emitHelper
.isEmittingForOp();
1031 for (const auto &namedAttr
: emitHelper
.getOp().getAttributes())
1032 if (canEmitAttrVerifier(namedAttr
.attr
, isEmittingForOp
))
1033 emitVerifier(namedAttr
.attr
, namedAttr
.name
, getVarName(namedAttr
.name
));
1036 /// Include declarations specified on NativeTrait
1037 static std::string
formatExtraDeclarations(const Operator
&op
) {
1038 SmallVector
<StringRef
> extraDeclarations
;
1039 // Include extra class declarations from NativeTrait
1040 for (const auto &trait
: op
.getTraits()) {
1041 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
1042 StringRef value
= opTrait
->getExtraConcreteClassDeclaration();
1045 extraDeclarations
.push_back(value
);
1048 extraDeclarations
.push_back(op
.getExtraClassDeclaration());
1049 return llvm::join(extraDeclarations
, "\n");
1052 /// Op extra class definitions have a `$cppClass` substitution that is to be
1053 /// replaced by the C++ class name.
1054 /// Include declarations specified on NativeTrait
1055 static std::string
formatExtraDefinitions(const Operator
&op
) {
1056 SmallVector
<StringRef
> extraDefinitions
;
1057 // Include extra class definitions from NativeTrait
1058 for (const auto &trait
: op
.getTraits()) {
1059 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
1060 StringRef value
= opTrait
->getExtraConcreteClassDefinition();
1063 extraDefinitions
.push_back(value
);
1066 extraDefinitions
.push_back(op
.getExtraClassDefinition());
1067 FmtContext ctx
= FmtContext().addSubst("cppClass", op
.getCppClassName());
1068 return tgfmt(llvm::join(extraDefinitions
, "\n"), &ctx
).str();
1071 OpEmitter::OpEmitter(const Operator
&op
,
1072 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
)
1073 : def(op
.getDef()), op(op
),
1074 opClass(op
.getCppClassName(), formatExtraDeclarations(op
),
1075 formatExtraDefinitions(op
)),
1076 staticVerifierEmitter(staticVerifierEmitter
),
1077 emitHelper(op
, /*emitForOp=*/true) {
1078 verifyCtx
.addSubst("_op", "(*this->getOperation())");
1079 verifyCtx
.addSubst("_ctxt", "this->getOperation()->getContext()");
1083 // Generate C++ code for various op methods. The order here determines the
1084 // methods in the generated file.
1085 genAttrNameGetters();
1086 genOpAsmInterface();
1088 genNamedOperandGetters();
1089 genNamedOperandSetters();
1090 genNamedResultGetters();
1091 genNamedRegionGetters();
1092 genNamedSuccessorGetters();
1093 genPropertiesSupport();
1098 genOptionalAttrRemovers();
1100 genPopulateDefaultAttributes();
1104 genCustomVerifier();
1105 genCanonicalizerDecls();
1107 genTypeInterfaceMethods();
1108 genOpInterfaceMethods();
1109 generateOpFormat(op
, opClass
, emitHelper
.hasProperties());
1110 genSideEffectInterfaceMethods();
1112 void OpEmitter::emitDecl(
1113 const Operator
&op
, raw_ostream
&os
,
1114 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
) {
1115 OpEmitter(op
, staticVerifierEmitter
).emitDecl(os
);
1118 void OpEmitter::emitDef(
1119 const Operator
&op
, raw_ostream
&os
,
1120 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
) {
1121 OpEmitter(op
, staticVerifierEmitter
).emitDef(os
);
1124 void OpEmitter::emitDecl(raw_ostream
&os
) {
1126 opClass
.writeDeclTo(os
);
1129 void OpEmitter::emitDef(raw_ostream
&os
) {
1131 opClass
.writeDefTo(os
);
1134 static void errorIfPruned(size_t line
, Method
*m
, const Twine
&methodName
,
1135 const Operator
&op
) {
1138 PrintFatalError(op
.getLoc(), "Unexpected overlap when generating `" +
1139 methodName
+ "` for " +
1140 op
.getOperationName() + " (from line " +
1144 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
1146 void OpEmitter::genAttrNameGetters() {
1147 const llvm::MapVector
<StringRef
, AttributeMetadata
> &attributes
=
1148 emitHelper
.getAttrMetadata();
1149 bool hasOperandSegmentsSize
=
1150 op
.getDialect().usePropertiesForAttributes() &&
1151 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1152 // Emit the getAttributeNames method.
1154 auto *method
= opClass
.addStaticInlineMethod(
1155 "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
1156 ERROR_IF_PRUNED(method
, "getAttributeNames", op
);
1157 auto &body
= method
->body();
1158 if (!hasOperandSegmentsSize
&& attributes
.empty()) {
1159 body
<< " return {};";
1160 // Nothing else to do if there are no registered attributes. Exit early.
1163 body
<< " static ::llvm::StringRef attrNames[] = {";
1164 llvm::interleaveComma(llvm::make_first_range(attributes
), body
,
1165 [&](StringRef attrName
) {
1166 body
<< "::llvm::StringRef(\"" << attrName
<< "\")";
1168 if (hasOperandSegmentsSize
) {
1169 if (!attributes
.empty())
1171 body
<< "::llvm::StringRef(\"" << operandSegmentAttrName
<< "\")";
1173 body
<< "};\n return ::llvm::ArrayRef(attrNames);";
1176 // Emit the getAttributeNameForIndex methods.
1178 auto *method
= opClass
.addInlineMethod
<Method::Private
>(
1179 "::mlir::StringAttr", "getAttributeNameForIndex",
1180 MethodParameter("unsigned", "index"));
1181 ERROR_IF_PRUNED(method
, "getAttributeNameForIndex", op
);
1183 << " return getAttributeNameForIndex((*this)->getName(), index);";
1186 auto *method
= opClass
.addStaticInlineMethod
<Method::Private
>(
1187 "::mlir::StringAttr", "getAttributeNameForIndex",
1188 MethodParameter("::mlir::OperationName", "name"),
1189 MethodParameter("unsigned", "index"));
1190 ERROR_IF_PRUNED(method
, "getAttributeNameForIndex", op
);
1192 if (attributes
.empty()) {
1193 method
->body() << " return {};";
1195 const char *const getAttrName
= R
"(
1196 assert(index < {0} && "invalid attribute index
");
1197 assert(name.getStringRef() == getOperationName() && "invalid operation name
");
1198 assert(name.isRegistered() && "Operation isn
't registered, missing a "
1199 "dependent dialect loading?");
1200 return name.getAttributeNames()[index];
1202 method->body() << formatv(getAttrName, attributes.size());
1206 // Generate the <attr>AttrName methods, that expose the attribute names to
1208 const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
1209 for (auto [index, attr] :
1210 llvm::enumerate(llvm::make_first_range(attributes))) {
1211 std::string name = op.getGetterName(attr);
1212 std::string methodName = name + "AttrName";
1214 // Generate the non-static variant.
1216 auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
1217 ERROR_IF_PRUNED(method, methodName, op);
1218 method->body() << llvm::formatv(attrNameMethodBody, index);
1221 // Generate the static variant.
1223 auto *method = opClass.addStaticInlineMethod(
1224 "::mlir::StringAttr", methodName,
1225 MethodParameter("::mlir::OperationName", "name"));
1226 ERROR_IF_PRUNED(method, methodName, op);
1227 method->body() << llvm::formatv(attrNameMethodBody,
1228 "name, " + Twine(index));
1231 if (hasOperandSegmentsSize) {
1232 std::string name = op.getGetterName(operandSegmentAttrName);
1233 std::string methodName = name + "AttrName";
1234 // Generate the non-static variant.
1236 auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
1237 ERROR_IF_PRUNED(method, methodName, op);
1239 << " return (*this)->getName().getAttributeNames().back();";
1242 // Generate the static variant.
1244 auto *method = opClass.addStaticInlineMethod(
1245 "::mlir::StringAttr", methodName,
1246 MethodParameter("::mlir::OperationName", "name"));
1247 ERROR_IF_PRUNED(method, methodName, op);
1248 method->body() << " return name.getAttributeNames().back();";
1253 // Emit the getter for a named property.
1254 // It is templated to be shared between the Op and the adaptor class.
1255 template <typename OpClassOrAdaptor>
1256 static void emitPropGetter(OpClassOrAdaptor &opClass, const Operator &op,
1257 StringRef name, const Property &prop) {
1258 auto *method = opClass.addInlineMethod(prop.getInterfaceType(), name);
1259 ERROR_IF_PRUNED(method, name, op);
1260 method->body() << formatv(" return getProperties().{0}();", name);
1263 // Emit the getter for an attribute with the return type specified.
1264 // It is templated to be shared between the Op and the adaptor class.
1265 template <typename OpClassOrAdaptor>
1266 static void emitAttrGetterWithReturnType(FmtContext &fctx,
1267 OpClassOrAdaptor &opClass,
1268 const Operator &op, StringRef name,
1270 auto *method = opClass.addMethod(attr.getReturnType(), name);
1271 ERROR_IF_PRUNED(method, name, op);
1272 auto &body = method->body();
1273 body << " auto attr = " << name << "Attr();\n";
1274 if (attr.hasDefaultValue() && attr.isOptional()) {
1275 // Returns the default value if not set.
1276 // TODO: this is inefficient, we are recreating the attribute for every
1277 // call. This should be set instead.
1278 if (!attr.isConstBuildable()) {
1279 PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
1280 " must have a constBuilder");
1282 std::string defaultValue = std::string(
1283 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1284 body << " if (!attr)\n return "
1285 << tgfmt(attr.getConvertFromStorageCall(),
1286 &fctx.withSelf(defaultValue))
1290 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
1294 void OpEmitter::genPropertiesSupport() {
1295 if (!emitHelper.hasProperties())
1298 SmallVector<ConstArgument> attrOrProperties;
1299 for (const std::pair<StringRef, AttributeMetadata> &it :
1300 emitHelper.getAttrMetadata()) {
1301 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
1302 attrOrProperties.push_back(&it.second);
1304 for (const NamedProperty &prop : op.getProperties())
1305 attrOrProperties.push_back(&prop);
1306 if (emitHelper.getOperandSegmentsSize())
1307 attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
1308 if (emitHelper.getResultSegmentsSize())
1309 attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
1310 if (attrOrProperties.empty())
1312 auto &setPropMethod =
1315 "::llvm::LogicalResult", "setPropertiesFromAttr",
1316 MethodParameter("Properties &", "prop"),
1317 MethodParameter("::mlir::Attribute", "attr"),
1319 "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1322 auto &getPropMethod =
1324 .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr",
1325 MethodParameter("::mlir::MLIRContext *", "ctx"),
1326 MethodParameter("const Properties &", "prop"))
1330 .addStaticMethod("llvm::hash_code", "computePropertiesHash",
1331 MethodParameter("const Properties &", "prop"))
1333 auto &getInherentAttrMethod =
1335 .addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr",
1336 MethodParameter("::mlir::MLIRContext *", "ctx"),
1337 MethodParameter("const Properties &", "prop"),
1338 MethodParameter("llvm::StringRef", "name"))
1340 auto &setInherentAttrMethod =
1342 .addStaticMethod("void", "setInherentAttr",
1343 MethodParameter("Properties &", "prop"),
1344 MethodParameter("llvm::StringRef", "name"),
1345 MethodParameter("mlir::Attribute", "value"))
1347 auto &populateInherentAttrsMethod =
1349 .addStaticMethod("void", "populateInherentAttrs",
1350 MethodParameter("::mlir::MLIRContext *", "ctx"),
1351 MethodParameter("const Properties &", "prop"),
1352 MethodParameter("::mlir::NamedAttrList &", "attrs"))
1354 auto &verifyInherentAttrsMethod =
1357 "::llvm::LogicalResult", "verifyInherentAttrs",
1358 MethodParameter("::mlir::OperationName", "opName"),
1359 MethodParameter("::mlir::NamedAttrList &", "attrs"),
1361 "llvm::function_ref<::mlir::InFlightDiagnostic()>",
1365 opClass.declare<UsingDeclaration>("Properties", "FoldAdaptor::Properties");
1367 // Convert the property to the attribute form.
1369 setPropMethod << R"decl(
1370 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1372 emitError() << "expected DictionaryAttr to set properties";
1373 return ::mlir::failure();
1376 const char *propFromAttrFmt = R"decl(
1377 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1378 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{
1383 const char *attrGetNoDefaultFmt = R"decl(;
1384 if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1385 return ::mlir::failure();
1387 const char *attrGetDefaultFmt = R"decl(;
1389 if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1390 return ::mlir::failure();
1396 for (const auto &attrOrProp : attrOrProperties) {
1397 if (const auto *namedProperty =
1398 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1399 StringRef name = namedProperty->name;
1400 auto &prop = namedProperty->prop;
1403 std::string getAttr;
1404 llvm::raw_string_ostream os(getAttr);
1405 os << " auto attr = dict.get(\"" << name << "\");";
1406 if (name == operandSegmentAttrName) {
1407 // Backward compat for now, TODO: Remove at some point.
1408 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1410 if (name == resultSegmentAttrName) {
1411 // Backward compat for now, TODO: Remove at some point.
1412 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1415 setPropMethod << "{\n"
1416 << formatv(propFromAttrFmt,
1417 tgfmt(prop.getConvertFromAttributeCall(),
1418 &fctx.addSubst("_attr", propertyAttr)
1419 .addSubst("_storage", propertyStorage)
1420 .addSubst("_diag", propertyDiag)),
1422 if (prop.hasStorageTypeValueOverride()) {
1423 setPropMethod << formatv(attrGetDefaultFmt, name,
1424 prop.getStorageTypeValueOverride());
1425 } else if (prop.hasDefaultValue()) {
1426 setPropMethod << formatv(attrGetDefaultFmt, name,
1427 prop.getDefaultValue());
1429 setPropMethod << formatv(attrGetNoDefaultFmt, name);
1431 setPropMethod << " }\n";
1433 const auto *namedAttr =
1434 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1435 StringRef name = namedAttr->attrName;
1436 std::string getAttr;
1437 llvm::raw_string_ostream os(getAttr);
1438 os << " auto attr = dict.get(\"" << name << "\");";
1439 if (name == operandSegmentAttrName) {
1440 // Backward compat for now
1441 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1443 if (name == resultSegmentAttrName) {
1444 // Backward compat for now
1445 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1448 setPropMethod << formatv(R"decl(
1450 auto &propStorage = prop.{0};
1453 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1454 if (convertedAttr) {{
1455 propStorage = convertedAttr;
1457 emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1458 return ::mlir::failure();
1466 setPropMethod << " return ::mlir::success();\n";
1468 // Convert the attribute form to the property.
1470 getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
1471 << " ::mlir::Builder odsBuilder{ctx};\n";
1472 const char *propToAttrFmt = R"decl(
1474 const auto &propStorage = prop.{0};
1475 auto attr = [&]() -> ::mlir::Attribute {{
1478 attrs.push_back(odsBuilder.getNamedAttr("{0}", attr));
1481 for (const auto &attrOrProp : attrOrProperties) {
1482 if (const auto *namedProperty =
1483 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1484 StringRef name = namedProperty->name;
1485 auto &prop = namedProperty->prop;
1487 getPropMethod << formatv(
1488 propToAttrFmt, name,
1489 tgfmt(prop.getConvertToAttributeCall(),
1490 &fctx.addSubst("_ctxt", "ctx")
1491 .addSubst("_storage", propertyStorage)));
1494 const auto *namedAttr =
1495 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1496 StringRef name = namedAttr->attrName;
1497 getPropMethod << formatv(R"decl(
1499 const auto &propStorage = prop.{0};
1501 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1507 getPropMethod << R"decl(
1509 return odsBuilder.getDictionaryAttr(attrs);
1513 // Hashing for the property
1515 const char *propHashFmt = R"decl(
1516 auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
1520 for (const auto &attrOrProp : attrOrProperties) {
1521 if (const auto *namedProperty =
1522 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1523 StringRef name = namedProperty->name;
1524 auto &prop = namedProperty->prop;
1526 if (!prop.getHashPropertyCall().empty()) {
1527 hashMethod << formatv(
1529 tgfmt(prop.getHashPropertyCall(),
1530 &fctx.addSubst("_storage", propertyStorage)));
1534 hashMethod << " return llvm::hash_combine(";
1535 llvm::interleaveComma(
1536 attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
1537 if (const auto *namedProperty =
1538 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1539 if (!namedProperty->prop.getHashPropertyCall().empty()) {
1540 hashMethod << "\n hash_" << namedProperty->name << "(prop."
1541 << namedProperty->name << ")";
1543 hashMethod << "\n ::llvm::hash_value(prop."
1544 << namedProperty->name << ")";
1548 const auto *namedAttr =
1549 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1550 StringRef name = namedAttr->attrName;
1551 hashMethod << "\n llvm::hash_value(prop." << name
1552 << ".getAsOpaquePointer())";
1554 hashMethod << ");\n";
1556 const char *getInherentAttrMethodFmt = R"decl(
1560 const char *setInherentAttrMethodFmt = R"decl(
1561 if (name == "{0}") {{
1562 prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
1566 const char *populateInherentAttrsMethodFmt = R"decl(
1567 if (prop.{0}) attrs.append("{0}", prop.{0});
1569 for (const auto &attrOrProp : attrOrProperties) {
1570 if (const auto *namedAttr =
1571 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp)) {
1572 StringRef name = namedAttr->attrName;
1573 getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
1574 setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
1575 populateInherentAttrsMethod
1576 << formatv(populateInherentAttrsMethodFmt, name);
1579 // The ODS segment size property is "special": we expose it as an attribute
1580 // even though it is a native property.
1581 const auto *namedProperty = cast<const NamedProperty *>(attrOrProp);
1582 StringRef name = namedProperty->name;
1583 if (name != operandSegmentAttrName && name != resultSegmentAttrName)
1585 auto &prop = namedProperty->prop;
1587 fctx.addSubst("_ctxt", "ctx");
1588 fctx.addSubst("_storage", Twine("prop.") + name);
1589 if (name == operandSegmentAttrName) {
1590 getInherentAttrMethod
1591 << formatv(" if (name == \"operand_segment_sizes\" || name == "
1593 operandSegmentAttrName);
1595 getInherentAttrMethod
1596 << formatv(" if (name == \"result_segment_sizes\" || name == "
1598 resultSegmentAttrName);
1600 getInherentAttrMethod << "[&]() -> ::mlir::Attribute { "
1601 << tgfmt(prop.getConvertToAttributeCall(), &fctx)
1604 if (name == operandSegmentAttrName) {
1605 setInherentAttrMethod
1606 << formatv(" if (name == \"operand_segment_sizes\" || name == "
1608 operandSegmentAttrName);
1610 setInherentAttrMethod
1611 << formatv(" if (name == \"result_segment_sizes\" || name == "
1613 resultSegmentAttrName);
1615 setInherentAttrMethod << formatv(R"decl(
1616 auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value);
1617 if (!arrAttr) return;
1618 if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
1620 llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin());
1625 if (name == operandSegmentAttrName) {
1626 populateInherentAttrsMethod << formatv(
1627 " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
1628 operandSegmentAttrName,
1629 tgfmt(prop.getConvertToAttributeCall(), &fctx));
1631 populateInherentAttrsMethod << formatv(
1632 " attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
1633 resultSegmentAttrName,
1634 tgfmt(prop.getConvertToAttributeCall(), &fctx));
1637 getInherentAttrMethod << " return std::nullopt;\n";
1639 // Emit the verifiers method for backward compatibility with the generic
1640 // syntax. This method verifies the constraint on the properties attributes
1641 // before they are set, since dyn_cast<> will silently omit failures.
1642 for (const auto &attrOrProp : attrOrProperties) {
1643 const auto *namedAttr =
1644 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1645 if (!namedAttr || !namedAttr->constraint)
1647 Attribute attr = *namedAttr->constraint;
1648 std::optional<StringRef> constraintFn =
1649 staticVerifierEmitter.getAttrConstraintFn(attr);
1652 if (canEmitAttrVerifier(attr,
1653 /*isEmittingForOp=*/false)) {
1654 std::string name = op.getGetterName(namedAttr->attrName);
1655 verifyInherentAttrsMethod
1658 ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
1659 if (attr && ::mlir::failed({1}(attr, "{2}", emitError)))
1660 return ::mlir::failure();
1663 name, constraintFn, namedAttr->attrName);
1666 verifyInherentAttrsMethod << " return ::mlir::success();";
1668 // Generate methods to interact with bytecode.
1669 genPropertiesSupportForBytecode(attrOrProperties);
1672 void OpEmitter::genPropertiesSupportForBytecode(
1673 ArrayRef<ConstArgument> attrOrProperties) {
1674 if (op.useCustomPropertiesEncoding()) {
1675 opClass.declareStaticMethod(
1676 "::llvm::LogicalResult", "readProperties",
1677 MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1678 MethodParameter("::mlir::OperationState &", "state"));
1679 opClass.declareMethod(
1680 "void", "writeProperties",
1681 MethodParameter("::mlir::DialectBytecodeWriter &", "writer"));
1685 auto &readPropertiesMethod =
1688 "::llvm::LogicalResult", "readProperties",
1689 MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1690 MethodParameter("::mlir::OperationState &", "state"))
1693 auto &writePropertiesMethod =
1696 "void", "writeProperties",
1697 MethodParameter("::mlir::DialectBytecodeWriter &", "writer"))
1700 // Populate bytecode serialization logic.
1701 readPropertiesMethod
1702 << " auto &prop = state.getOrAddProperties<Properties>(); (void)prop;";
1703 writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n";
1704 for (const auto &item : llvm::enumerate(attrOrProperties)) {
1705 auto &attrOrProp = item.value();
1707 fctx.addSubst("_reader", "reader")
1708 .addSubst("_writer", "writer")
1709 .addSubst("_storage", propertyStorage)
1710 .addSubst("_ctxt", "this->getContext()");
1711 // If the op emits operand/result segment sizes as a property, emit the
1712 // legacy reader/writer in the appropriate order to allow backward
1713 // compatibility and back deployment.
1714 if (emitHelper.getOperandSegmentsSize().has_value() &&
1715 item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) {
1716 FmtContext fmtCtxt(fctx);
1717 fmtCtxt.addSubst("_propName", operandSegmentAttrName);
1718 readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt);
1719 writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
1721 if (emitHelper.getResultSegmentsSize().has_value() &&
1722 item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) {
1723 FmtContext fmtCtxt(fctx);
1724 fmtCtxt.addSubst("_propName", resultSegmentAttrName);
1725 readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt);
1726 writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
1728 if (const auto *namedProperty =
1729 attrOrProp.dyn_cast<const NamedProperty *>()) {
1730 StringRef name = namedProperty->name;
1731 readPropertiesMethod << formatv(
1734 auto &propStorage = prop.{0};
1735 auto readProp = [&]() {
1737 return ::mlir::success();
1739 if (::mlir::failed(readProp()))
1740 return ::mlir::failure();
1744 tgfmt(namedProperty->prop.getReadFromMlirBytecodeCall(), &fctx));
1745 writePropertiesMethod << formatv(
1748 auto &propStorage = prop.{0};
1752 name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx));
1755 const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
1756 StringRef name = namedAttr->attrName;
1757 if (namedAttr->isRequired) {
1758 readPropertiesMethod << formatv(R"(
1759 if (::mlir::failed(reader.readAttribute(prop.{0})))
1760 return ::mlir::failure();
1763 writePropertiesMethod
1764 << formatv(" writer.writeAttribute(prop.{0});\n", name);
1766 readPropertiesMethod << formatv(R"(
1767 if (::mlir::failed(reader.readOptionalAttribute(prop.{0})))
1768 return ::mlir::failure();
1771 writePropertiesMethod << formatv(R"(
1772 writer.writeOptionalAttribute(prop.{0});
1777 readPropertiesMethod << " return ::mlir::success();";
1780 void OpEmitter::genPropGetters() {
1781 for (const NamedProperty &prop : op.getProperties()) {
1782 std::string name = op.getGetterName(prop.name);
1783 emitPropGetter(opClass, op, name, prop.prop);
1787 void OpEmitter::genPropSetters() {
1788 for (const NamedProperty &prop : op.getProperties()) {
1789 std::string name = op.getSetterName(prop.name);
1790 std::string argName = "new" + convertToCamelFromSnakeCase(
1791 prop.name, /*capitalizeFirst=*/true);
1792 auto *method = opClass.addInlineMethod(
1793 "void", name, MethodParameter(prop.prop.getInterfaceType(), argName));
1796 method->body() << formatv(" getProperties().{0}({1});", name, argName);
1800 void OpEmitter::genAttrGetters() {
1802 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
1804 // Emit the derived attribute body.
1805 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
1806 if (auto *method = opClass.addMethod(attr.getReturnType(), name))
1807 method->body() << " " << attr.getDerivedCodeBody() << "\n";
1810 // Generate named accessor with Attribute return type. This is a wrapper
1811 // class that allows referring to the attributes via accessors instead of
1812 // having to use the string interface for better compile time verification.
1813 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
1815 // The method body for this getter is trivial. Emit it inline.
1817 opClass.addInlineMethod(attr.getStorageType(), name + "Attr");
1820 method->body() << formatv(
1821 " return ::llvm::{1}<{2}>({0});", emitHelper.getAttr(attrName),
1822 attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
1824 attr.getStorageType());
1827 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1828 std::string name = op.getGetterName(namedAttr.name);
1829 if (namedAttr.attr.isDerivedAttr()) {
1830 emitDerivedAttr(name, namedAttr.attr);
1832 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1833 emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
1837 auto derivedAttrs = make_filter_range(op.getAttributes(),
1838 [](const NamedAttribute &namedAttr) {
1839 return namedAttr.attr.isDerivedAttr();
1841 if (derivedAttrs.empty())
1844 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
1845 // Generate helper method to query whether a named attribute is a derived
1846 // attribute. This enables, for example, avoiding adding an attribute that
1847 // overlaps with a derived attribute.
1850 opClass.addStaticMethod("bool", "isDerivedAttribute",
1851 MethodParameter("::llvm::StringRef", "name"));
1852 ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1853 auto &body = method->body();
1854 for (auto namedAttr : derivedAttrs)
1855 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
1856 body << " return false;";
1858 // Generate method to materialize derived attributes as a DictionaryAttr.
1860 auto *method = opClass.addMethod("::mlir::DictionaryAttr",
1861 "materializeDerivedAttributes");
1862 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1863 auto &body = method->body();
1865 auto nonMaterializable =
1866 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
1867 return namedAttr.attr.getConvertFromStorageCall().empty();
1869 if (!nonMaterializable.empty()) {
1871 llvm::raw_string_ostream os(attrs);
1872 interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
1873 os << op.getGetterName(attr.name);
1878 "op has non-materializable derived attributes '{0}', skipping",
1880 body << formatv(" emitOpError(\"op has non-materializable derived "
1881 "attributes '{0}'\");\n",
1883 body << " return nullptr;";
1887 body << " ::mlir::MLIRContext* ctx = getContext();\n";
1888 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1889 body << " return ::mlir::DictionaryAttr::get(";
1890 body << " ctx, {\n";
1893 [&](const NamedAttribute &namedAttr) {
1894 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
1895 std::string name = op.getGetterName(namedAttr.name);
1896 body << " {" << name << "AttrName(),\n"
1897 << tgfmt(tmpl, &fctx.withSelf(name + "()")
1898 .withBuilder("odsBuilder")
1899 .addSubst("_ctxt", "ctx")
1900 .addSubst("_storage", "ctx"))
1908 void OpEmitter::genAttrSetters() {
1909 bool useProperties = op.getDialect().usePropertiesForAttributes();
1911 // Generate the code to set an attribute.
1912 auto emitSetAttr = [&](Method *method, StringRef getterName,
1913 StringRef attrName, StringRef attrVar) {
1914 if (useProperties) {
1915 method->body() << formatv(" getProperties().{0} = {1};", attrName,
1918 method->body() << formatv(" (*this)->setAttr({0}AttrName(), {1});",
1919 getterName, attrVar);
1923 // Generate raw named setter type. This is a wrapper class that allows setting
1924 // to the attributes via setters instead of having to use the string interface
1925 // for better compile time verification.
1926 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
1927 StringRef attrName, Attribute attr) {
1928 // This method body is trivial, so emit it inline.
1930 opClass.addInlineMethod("void", setterName + "Attr",
1931 MethodParameter(attr.getStorageType(), "attr"));
1933 emitSetAttr(method, getterName, attrName, "attr");
1936 // Generate a setter that accepts the underlying C++ type as opposed to the
1938 auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
1939 StringRef attrName, Attribute attr) {
1940 Attribute baseAttr = attr.getBaseAttr();
1941 if (!canUseUnwrappedRawValue(baseAttr))
1944 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
1945 bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
1946 bool isOptional = attr.isOptional();
1948 auto createMethod = [&](const Twine ¶mType) {
1949 return opClass.addMethod("void", setterName,
1950 MethodParameter(paramType.str(), "attrValue"));
1953 // Build the method using the correct parameter type depending on
1955 Method *method = nullptr;
1957 method = createMethod("bool");
1958 else if (isOptional)
1960 createMethod("::std::optional<" + baseAttr.getReturnType() + ">");
1962 method = createMethod(attr.getReturnType());
1966 // If the value isn't optional
, just set it directly
.
1968 emitSetAttr(method
, getterName
, attrName
,
1969 constBuildAttrFromParam(attr
, fctx
, "attrValue"));
1973 // Otherwise, we only set if the provided value is valid. If it isn't, we
1974 // remove the attribute.
1976 // TODO: Handle unit attr parameters specially, given that it is treated as
1977 // optional but not in the same way as the others (i.e. it uses bool over
1978 // std::optional<>).
1979 StringRef paramStr
= isUnitAttr
? "attrValue" : "*attrValue";
1980 if (!useProperties
) {
1981 const char *optionalCodeBody
= R
"(
1983 return (*this)->setAttr({0}AttrName(), {1});
1984 (*this)->removeAttr({0}AttrName());)";
1985 method
->body() << formatv(
1986 optionalCodeBody
, getterName
,
1987 constBuildAttrFromParam(baseAttr
, fctx
, paramStr
));
1989 const char *optionalCodeBody
= R
"(
1990 auto &odsProp = getProperties().{0};
1994 odsProp = nullptr;)";
1995 method
->body() << formatv(
1996 optionalCodeBody
, attrName
,
1997 constBuildAttrFromParam(baseAttr
, fctx
, paramStr
));
2001 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
2002 if (namedAttr
.attr
.isDerivedAttr())
2004 std::string setterName
= op
.getSetterName(namedAttr
.name
);
2005 std::string getterName
= op
.getGetterName(namedAttr
.name
);
2006 emitAttrWithStorageType(setterName
, getterName
, namedAttr
.name
,
2008 emitAttrWithReturnType(setterName
, getterName
, namedAttr
.name
,
2013 void OpEmitter::genOptionalAttrRemovers() {
2014 // Generate methods for removing optional attributes, instead of having to
2015 // use the string interface. Enables better compile time verification.
2016 auto emitRemoveAttr
= [&](StringRef name
, bool useProperties
) {
2017 auto upperInitial
= name
.take_front().upper();
2018 auto *method
= opClass
.addInlineMethod("::mlir::Attribute",
2019 op
.getRemoverName(name
) + "Attr");
2022 if (useProperties
) {
2023 method
->body() << formatv(R
"(
2024 auto &attr = getProperties().{0};
2031 method
->body() << formatv("return (*this)->removeAttr({0}AttrName());",
2032 op
.getGetterName(name
));
2035 for (const NamedAttribute
&namedAttr
: op
.getAttributes())
2036 if (namedAttr
.attr
.isOptional())
2037 emitRemoveAttr(namedAttr
.name
,
2038 op
.getDialect().usePropertiesForAttributes());
2041 // Generates the code to compute the start and end index of an operand or result
2043 template <typename RangeT
>
2044 static void generateValueRangeStartAndEnd(
2045 Class
&opClass
, bool isGenericAdaptorBase
, StringRef methodName
,
2046 int numVariadic
, int numNonVariadic
, StringRef rangeSizeCall
,
2047 bool hasAttrSegmentSize
, StringRef sizeAttrInit
, RangeT
&&odsValues
) {
2049 SmallVector
<MethodParameter
> parameters
{MethodParameter("unsigned", "index")};
2050 if (isGenericAdaptorBase
) {
2051 parameters
.emplace_back("unsigned", "odsOperandsSize");
2052 // The range size is passed per parameter for generic adaptor bases as
2053 // using the rangeSizeCall would require the operands, which are not
2054 // accessible in the base class.
2055 rangeSizeCall
= "odsOperandsSize";
2058 // The method is trivial if the operation does not have any variadic operands.
2059 // In that case, make sure to generate it in-line.
2060 auto *method
= opClass
.addMethod("std::pair<unsigned, unsigned>", methodName
,
2061 numVariadic
== 0 ? Method::Properties::Inline
2062 : Method::Properties::None
,
2066 auto &body
= method
->body();
2067 if (numVariadic
== 0) {
2068 body
<< " return {index, 1};\n";
2069 } else if (hasAttrSegmentSize
) {
2070 body
<< sizeAttrInit
<< attrSizedSegmentValueRangeCalcCode
;
2072 // Because the op can have arbitrarily interleaved variadic and non-variadic
2073 // operands, we need to embed a list in the "sink" getter method for
2074 // calculation at run-time.
2075 SmallVector
<StringRef
, 4> isVariadic
;
2076 isVariadic
.reserve(llvm::size(odsValues
));
2077 for (auto &it
: odsValues
)
2078 isVariadic
.push_back(it
.isVariableLength() ? "true" : "false");
2079 std::string isVariadicList
= llvm::join(isVariadic
, ", ");
2080 body
<< formatv(sameVariadicSizeValueRangeCalcCode
, isVariadicList
,
2081 numNonVariadic
, numVariadic
, rangeSizeCall
, "operand");
2085 static std::string
generateTypeForGetter(const NamedTypeConstraint
&value
) {
2086 return llvm::formatv("::mlir::TypedValue<{0}>", value
.constraint
.getCppType())
2090 // Generates the named operand getter methods for the given Operator `op` and
2091 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
2092 // return a range of operands (individual operands are `Value ` and each
2093 // element in the range must also be `Value `); use `rangeBeginCall` to get
2094 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
2095 // obtain the number of operands. `getOperandCallPattern` contains the code
2096 // necessary to obtain a single operand whose position will be substituted
2098 // "{0}" marker in the pattern. Note that the pattern should work for any kind
2099 // of ops, in particular for one-operand ops that may not have the
2100 // `getOperand(unsigned)` method.
2102 generateNamedOperandGetters(const Operator
&op
, Class
&opClass
,
2103 Class
*genericAdaptorBase
, StringRef sizeAttrInit
,
2104 StringRef rangeType
, StringRef rangeElementType
,
2105 StringRef rangeBeginCall
, StringRef rangeSizeCall
,
2106 StringRef getOperandCallPattern
) {
2107 const int numOperands
= op
.getNumOperands();
2108 const int numVariadicOperands
= op
.getNumVariableLengthOperands();
2109 const int numNormalOperands
= numOperands
- numVariadicOperands
;
2111 const auto *sameVariadicSize
=
2112 op
.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
2113 const auto *attrSizedOperands
=
2114 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2116 if (numVariadicOperands
> 1 && !sameVariadicSize
&& !attrSizedOperands
) {
2117 PrintFatalError(op
.getLoc(), "op has multiple variadic operands but no "
2118 "specification over their sizes");
2121 if (numVariadicOperands
< 2 && attrSizedOperands
) {
2122 PrintFatalError(op
.getLoc(), "op must have at least two variadic operands "
2123 "to use 'AttrSizedOperandSegments' trait");
2126 if (attrSizedOperands
&& sameVariadicSize
) {
2127 PrintFatalError(op
.getLoc(),
2128 "op cannot have both 'AttrSizedOperandSegments' and "
2129 "'SameVariadicOperandSize' traits");
2132 // First emit a few "sink" getter methods upon which we layer all nicer named
2134 // If generating for an adaptor, the method is put into the non-templated
2135 // generic base class, to not require being defined in the header.
2136 // Since the operand size can't be determined from the base class however,
2137 // it has to be passed as an additional argument. The trampoline below
2138 // generates the function with the same signature as the Op in the generic
2140 bool isGenericAdaptorBase
= genericAdaptorBase
!= nullptr;
2141 generateValueRangeStartAndEnd(
2142 /*opClass=*/isGenericAdaptorBase
? *genericAdaptorBase
: opClass
,
2143 isGenericAdaptorBase
,
2144 /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands
,
2145 numNormalOperands
, rangeSizeCall
, attrSizedOperands
, sizeAttrInit
,
2146 const_cast<Operator
&>(op
).getOperands());
2147 if (isGenericAdaptorBase
) {
2148 // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
2149 // the index. This just calls the implementation in the base class but
2150 // passes the operand size as parameter.
2151 Method
*method
= opClass
.addInlineMethod(
2152 "std::pair<unsigned, unsigned>", "getODSOperandIndexAndLength",
2153 MethodParameter("unsigned", "index"));
2154 ERROR_IF_PRUNED(method
, "getODSOperandIndexAndLength", op
);
2155 MethodBody
&body
= method
->body();
2156 body
.indent() << formatv(
2157 "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall
);
2160 // The implementation of this method is trivial and it is very load-bearing.
2161 // Generate it inline.
2162 auto *m
= opClass
.addInlineMethod(rangeType
, "getODSOperands",
2163 MethodParameter("unsigned", "index"));
2164 ERROR_IF_PRUNED(m
, "getODSOperands", op
);
2165 auto &body
= m
->body();
2166 body
<< formatv(valueRangeReturnCode
, rangeBeginCall
,
2167 "getODSOperandIndexAndLength(index)");
2169 // Then we emit nicer named getter methods by redirecting to the "sink" getter
2171 for (int i
= 0; i
!= numOperands
; ++i
) {
2172 const auto &operand
= op
.getOperand(i
);
2173 if (operand
.name
.empty())
2175 std::string name
= op
.getGetterName(operand
.name
);
2176 if (operand
.isOptional()) {
2177 m
= opClass
.addInlineMethod(isGenericAdaptorBase
2179 : generateTypeForGetter(operand
),
2181 ERROR_IF_PRUNED(m
, name
, op
);
2182 m
->body().indent() << formatv("auto operands = getODSOperands({0});\n"
2183 "return operands.empty() ? {1}{{} : ",
2184 i
, m
->getReturnType());
2185 if (!isGenericAdaptorBase
)
2186 m
->body() << llvm::formatv("::llvm::cast<{0}>", m
->getReturnType());
2187 m
->body() << "(*operands.begin());";
2188 } else if (operand
.isVariadicOfVariadic()) {
2189 std::string segmentAttr
= op
.getGetterName(
2190 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr());
2191 if (genericAdaptorBase
) {
2192 m
= opClass
.addMethod("::llvm::SmallVector<" + rangeType
+ ">", name
);
2193 ERROR_IF_PRUNED(m
, name
, op
);
2194 m
->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode
,
2195 segmentAttr
, i
, rangeType
);
2199 m
= opClass
.addInlineMethod("::mlir::OperandRangeRange", name
);
2200 ERROR_IF_PRUNED(m
, name
, op
);
2201 m
->body() << " return getODSOperands(" << i
<< ").split(" << segmentAttr
2203 } else if (operand
.isVariadic()) {
2204 m
= opClass
.addInlineMethod(rangeType
, name
);
2205 ERROR_IF_PRUNED(m
, name
, op
);
2206 m
->body() << " return getODSOperands(" << i
<< ");";
2208 m
= opClass
.addInlineMethod(isGenericAdaptorBase
2210 : generateTypeForGetter(operand
),
2212 ERROR_IF_PRUNED(m
, name
, op
);
2213 m
->body().indent() << "return ";
2214 if (!isGenericAdaptorBase
)
2215 m
->body() << llvm::formatv("::llvm::cast<{0}>", m
->getReturnType());
2216 m
->body() << llvm::formatv("(*getODSOperands({0}).begin());", i
);
2221 void OpEmitter::genNamedOperandGetters() {
2222 // Build the code snippet used for initializing the operand_segment_size)s
2224 std::string attrSizeInitCode
;
2225 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2226 if (op
.getDialect().usePropertiesForAttributes())
2227 attrSizeInitCode
= formatv(adapterSegmentSizeAttrInitCodeProperties
,
2228 "getProperties().operandSegmentSizes");
2231 attrSizeInitCode
= formatv(opSegmentSizeAttrInitCode
,
2232 emitHelper
.getAttr(operandSegmentAttrName
));
2235 generateNamedOperandGetters(
2237 /*genericAdaptorBase=*/nullptr,
2238 /*sizeAttrInit=*/attrSizeInitCode
,
2239 /*rangeType=*/"::mlir::Operation::operand_range",
2240 /*rangeElementType=*/"::mlir::Value",
2241 /*rangeBeginCall=*/"getOperation()->operand_begin()",
2242 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
2243 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
2246 void OpEmitter::genNamedOperandSetters() {
2247 auto *attrSizedOperands
=
2248 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2249 for (int i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
2250 const auto &operand
= op
.getOperand(i
);
2251 if (operand
.name
.empty())
2253 std::string name
= op
.getGetterName(operand
.name
);
2255 StringRef returnType
;
2256 if (operand
.isVariadicOfVariadic()) {
2257 returnType
= "::mlir::MutableOperandRangeRange";
2258 } else if (operand
.isVariableLength()) {
2259 returnType
= "::mlir::MutableOperandRange";
2261 returnType
= "::mlir::OpOperand &";
2263 bool isVariadicOperand
=
2264 operand
.isVariadicOfVariadic() || operand
.isVariableLength();
2265 auto *m
= opClass
.addMethod(returnType
, name
+ "Mutable",
2266 isVariadicOperand
? Method::Properties::None
2267 : Method::Properties::Inline
);
2268 ERROR_IF_PRUNED(m
, name
, op
);
2269 auto &body
= m
->body();
2270 body
<< " auto range = getODSOperandIndexAndLength(" << i
<< ");\n";
2272 if (!isVariadicOperand
) {
2273 // In case of a single operand, return a single OpOperand.
2274 body
<< " return getOperation()->getOpOperand(range.first);\n";
2278 body
<< " auto mutableRange = "
2279 "::mlir::MutableOperandRange(getOperation(), "
2280 "range.first, range.second";
2281 if (attrSizedOperands
) {
2282 if (emitHelper
.hasProperties())
2283 body
<< formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
2284 "{{getOperandSegmentSizesAttrName(), "
2285 "::mlir::DenseI32ArrayAttr::get(getContext(), "
2286 "getProperties().operandSegmentSizes)})",
2290 ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i
,
2291 emitHelper
.getAttr(operandSegmentAttrName
, /*isNamed=*/true));
2295 // If this operand is a nested variadic, we split the range into a
2296 // MutableOperandRangeRange that provides a range over all of the
2298 if (operand
.isVariadicOfVariadic()) {
2300 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
2301 << op
.getGetterName(
2302 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr())
2303 << "AttrName()));\n";
2305 // Otherwise, we use the full range directly.
2306 body
<< " return mutableRange;\n";
2311 void OpEmitter::genNamedResultGetters() {
2312 const int numResults
= op
.getNumResults();
2313 const int numVariadicResults
= op
.getNumVariableLengthResults();
2314 const int numNormalResults
= numResults
- numVariadicResults
;
2316 // If we have more than one variadic results, we need more complicated logic
2317 // to calculate the value range for each result.
2319 const auto *sameVariadicSize
=
2320 op
.getTrait("::mlir::OpTrait::SameVariadicResultSize");
2321 const auto *attrSizedResults
=
2322 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
2324 if (numVariadicResults
> 1 && !sameVariadicSize
&& !attrSizedResults
) {
2325 PrintFatalError(op
.getLoc(), "op has multiple variadic results but no "
2326 "specification over their sizes");
2329 if (numVariadicResults
< 2 && attrSizedResults
) {
2330 PrintFatalError(op
.getLoc(), "op must have at least two variadic results "
2331 "to use 'AttrSizedResultSegments' trait");
2334 if (attrSizedResults
&& sameVariadicSize
) {
2335 PrintFatalError(op
.getLoc(),
2336 "op cannot have both 'AttrSizedResultSegments' and "
2337 "'SameVariadicResultSize' traits");
2340 // Build the initializer string for the result segment size attribute.
2341 std::string attrSizeInitCode
;
2342 if (attrSizedResults
) {
2343 if (op
.getDialect().usePropertiesForAttributes())
2344 attrSizeInitCode
= formatv(adapterSegmentSizeAttrInitCodeProperties
,
2345 "getProperties().resultSegmentSizes");
2348 attrSizeInitCode
= formatv(opSegmentSizeAttrInitCode
,
2349 emitHelper
.getAttr(resultSegmentAttrName
));
2352 generateValueRangeStartAndEnd(
2353 opClass
, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength",
2354 numVariadicResults
, numNormalResults
, "getOperation()->getNumResults()",
2355 attrSizedResults
, attrSizeInitCode
, op
.getResults());
2357 // The implementation of this method is trivial and it is very load-bearing.
2358 // Generate it inline.
2359 auto *m
= opClass
.addInlineMethod("::mlir::Operation::result_range",
2361 MethodParameter("unsigned", "index"));
2362 ERROR_IF_PRUNED(m
, "getODSResults", op
);
2363 m
->body() << formatv(valueRangeReturnCode
, "getOperation()->result_begin()",
2364 "getODSResultIndexAndLength(index)");
2366 for (int i
= 0; i
!= numResults
; ++i
) {
2367 const auto &result
= op
.getResult(i
);
2368 if (result
.name
.empty())
2370 std::string name
= op
.getGetterName(result
.name
);
2371 if (result
.isOptional()) {
2372 m
= opClass
.addInlineMethod(generateTypeForGetter(result
), name
);
2373 ERROR_IF_PRUNED(m
, name
, op
);
2374 m
->body() << " auto results = getODSResults(" << i
<< ");\n"
2375 << llvm::formatv(" return results.empty()"
2377 " : ::llvm::cast<{0}>(*results.begin());",
2378 m
->getReturnType());
2379 } else if (result
.isVariadic()) {
2380 m
= opClass
.addInlineMethod("::mlir::Operation::result_range", name
);
2381 ERROR_IF_PRUNED(m
, name
, op
);
2382 m
->body() << " return getODSResults(" << i
<< ");";
2384 m
= opClass
.addInlineMethod(generateTypeForGetter(result
), name
);
2385 ERROR_IF_PRUNED(m
, name
, op
);
2386 m
->body() << llvm::formatv(
2387 " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
2388 m
->getReturnType(), i
);
2393 void OpEmitter::genNamedRegionGetters() {
2394 unsigned numRegions
= op
.getNumRegions();
2395 for (unsigned i
= 0; i
< numRegions
; ++i
) {
2396 const auto ®ion
= op
.getRegion(i
);
2397 if (region
.name
.empty())
2399 std::string name
= op
.getGetterName(region
.name
);
2401 // Generate the accessors for a variadic region.
2402 if (region
.isVariadic()) {
2403 auto *m
= opClass
.addInlineMethod(
2404 "::mlir::MutableArrayRef<::mlir::Region>", name
);
2405 ERROR_IF_PRUNED(m
, name
, op
);
2406 m
->body() << formatv(" return (*this)->getRegions().drop_front({0});",
2411 auto *m
= opClass
.addInlineMethod("::mlir::Region &", name
);
2412 ERROR_IF_PRUNED(m
, name
, op
);
2413 m
->body() << formatv(" return (*this)->getRegion({0});", i
);
2417 void OpEmitter::genNamedSuccessorGetters() {
2418 unsigned numSuccessors
= op
.getNumSuccessors();
2419 for (unsigned i
= 0; i
< numSuccessors
; ++i
) {
2420 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
2421 if (successor
.name
.empty())
2423 std::string name
= op
.getGetterName(successor
.name
);
2424 // Generate the accessors for a variadic successor list.
2425 if (successor
.isVariadic()) {
2426 auto *m
= opClass
.addInlineMethod("::mlir::SuccessorRange", name
);
2427 ERROR_IF_PRUNED(m
, name
, op
);
2428 m
->body() << formatv(
2429 " return {std::next((*this)->successor_begin(), {0}), "
2430 "(*this)->successor_end()};",
2435 auto *m
= opClass
.addInlineMethod("::mlir::Block *", name
);
2436 ERROR_IF_PRUNED(m
, name
, op
);
2437 m
->body() << formatv(" return (*this)->getSuccessor({0});", i
);
2441 static bool canGenerateUnwrappedBuilder(const Operator
&op
) {
2442 // If this op does not have native attributes at all, return directly to avoid
2443 // redefining builders.
2444 if (op
.getNumNativeAttributes() == 0)
2447 bool canGenerate
= false;
2448 // We are generating builders that take raw values for attributes. We need to
2449 // make sure the native attributes have a meaningful "unwrapped" value type
2450 // different from the wrapped mlir::Attribute type to avoid redefining
2451 // builders. This checks for the op has at least one such native attribute.
2452 for (int i
= 0, e
= op
.getNumNativeAttributes(); i
< e
; ++i
) {
2453 const NamedAttribute
&namedAttr
= op
.getAttribute(i
);
2454 if (canUseUnwrappedRawValue(namedAttr
.attr
)) {
2462 static bool canInferType(const Operator
&op
) {
2463 return op
.getTrait("::mlir::InferTypeOpInterface::Trait");
2466 void OpEmitter::genSeparateArgParamBuilder() {
2467 SmallVector
<AttrParamKind
, 2> attrBuilderType
;
2468 attrBuilderType
.push_back(AttrParamKind::WrappedAttr
);
2469 if (canGenerateUnwrappedBuilder(op
))
2470 attrBuilderType
.push_back(AttrParamKind::UnwrappedValue
);
2472 // Emit with separate builders with or without unwrapped attributes and/or
2473 // inferring result type.
2474 auto emit
= [&](AttrParamKind attrType
, TypeParamKind paramKind
,
2476 SmallVector
<MethodParameter
> paramList
;
2477 SmallVector
<std::string
, 4> resultNames
;
2478 llvm::StringSet
<> inferredAttributes
;
2479 buildParamList(paramList
, inferredAttributes
, resultNames
, paramKind
,
2482 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2483 // If the builder is redundant, skip generating the method.
2486 auto &body
= m
->body();
2487 genCodeForAddingArgAndRegionForBuilder(body
, inferredAttributes
,
2488 /*isRawValueAttr=*/attrType
==
2489 AttrParamKind::UnwrappedValue
);
2491 // Push all result types to the operation state
2494 // Generate builder that infers type too.
2495 // TODO: Subsume this with general checking if type can be
2496 // inferred automatically.
2498 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2499 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2500 {1}.location, {1}.operands,
2501 {1}.attributes.getDictionary({1}.getContext()),
2502 {1}.getRawProperties(),
2503 {1}.regions, inferredReturnTypes)))
2504 {1}.addTypes(inferredReturnTypes);
2506 ::mlir::detail::reportFatalInferReturnTypesError({1});
2508 opClass
.getClassName(), builderOpState
);
2512 switch (paramKind
) {
2513 case TypeParamKind::None
:
2515 case TypeParamKind::Separate
:
2516 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
2517 if (op
.getResult(i
).isOptional())
2518 body
<< " if (" << resultNames
[i
] << ")\n ";
2519 body
<< " " << builderOpState
<< ".addTypes(" << resultNames
[i
]
2523 // Automatically create the 'resultSegmentSizes' attribute using
2524 // the length of the type ranges.
2525 if (op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
2526 if (op
.getDialect().usePropertiesForAttributes()) {
2527 body
<< " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
2529 std::string getterName
= op
.getGetterName(resultSegmentAttrName
);
2530 body
<< " " << builderOpState
<< ".addAttribute(" << getterName
2531 << "AttrName(" << builderOpState
<< ".name), "
2532 << "odsBuilder.getDenseI32ArrayAttr({";
2535 llvm::seq
<int>(0, op
.getNumResults()), body
, [&](int i
) {
2536 const NamedTypeConstraint
&result
= op
.getResult(i
);
2537 if (!result
.isVariableLength()) {
2539 } else if (result
.isOptional()) {
2540 body
<< "(" << resultNames
[i
] << " ? 1 : 0)";
2542 // VariadicOfVariadic of results are currently unsupported in
2543 // MLIR, hence it can only be a simple variadic.
2544 // TODO: Add implementation for VariadicOfVariadic results here
2546 assert(result
.isVariadic());
2547 body
<< "static_cast<int32_t>(" << resultNames
[i
] << ".size())";
2550 if (op
.getDialect().usePropertiesForAttributes()) {
2551 body
<< "}), " << builderOpState
2552 << ".getOrAddProperties<Properties>()."
2553 "resultSegmentSizes.begin());\n";
2560 case TypeParamKind::Collective
: {
2561 int numResults
= op
.getNumResults();
2562 int numVariadicResults
= op
.getNumVariableLengthResults();
2563 int numNonVariadicResults
= numResults
- numVariadicResults
;
2564 bool hasVariadicResult
= numVariadicResults
!= 0;
2566 // Avoid emitting "resultTypes.size() >= 0u" which is always true.
2567 if (!hasVariadicResult
|| numNonVariadicResults
!= 0)
2569 << "assert(resultTypes.size() "
2570 << (hasVariadicResult
? ">=" : "==") << " "
2571 << numNonVariadicResults
2572 << "u && \"mismatched number of results\");\n";
2573 body
<< " " << builderOpState
<< ".addTypes(resultTypes);\n";
2577 llvm_unreachable("unhandled TypeParamKind");
2580 // Some of the build methods generated here may be ambiguous, but TableGen's
2581 // ambiguous function detection will elide those ones.
2582 for (auto attrType
: attrBuilderType
) {
2583 emit(attrType
, TypeParamKind::Separate
, /*inferType=*/false);
2584 if (canInferType(op
))
2585 emit(attrType
, TypeParamKind::None
, /*inferType=*/true);
2586 emit(attrType
, TypeParamKind::Collective
, /*inferType=*/false);
2590 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
2591 int numResults
= op
.getNumResults();
2594 SmallVector
<MethodParameter
> paramList
;
2595 paramList
.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2596 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2597 paramList
.emplace_back("::mlir::ValueRange", "operands");
2598 // Provide default value for `attributes` when its the last parameter
2599 StringRef attributesDefaultValue
= op
.getNumVariadicRegions() ? "" : "{}";
2600 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2601 "attributes", attributesDefaultValue
);
2602 if (op
.getNumVariadicRegions())
2603 paramList
.emplace_back("unsigned", "numRegions");
2605 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2606 // If the builder is redundant, skip generating the method
2609 auto &body
= m
->body();
2612 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2615 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2617 // Create the correct number of regions
2618 if (int numRegions
= op
.getNumRegions()) {
2619 body
<< llvm::formatv(
2620 " for (unsigned i = 0; i != {0}; ++i)\n",
2621 (op
.getNumVariadicRegions() ? "numRegions" : Twine(numRegions
)));
2622 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
2626 SmallVector
<std::string
, 2> resultTypes(numResults
, "operands[0].getType()");
2627 body
<< " " << builderOpState
<< ".addTypes({"
2628 << llvm::join(resultTypes
, ", ") << "});\n\n";
2631 void OpEmitter::genPopulateDefaultAttributes() {
2632 // All done if no attributes, except optional ones, have default values.
2633 if (llvm::all_of(op
.getAttributes(), [](const NamedAttribute
&named
) {
2634 return !named
.attr
.hasDefaultValue() || named
.attr
.isOptional();
2638 if (emitHelper
.hasProperties()) {
2639 SmallVector
<MethodParameter
> paramList
;
2640 paramList
.emplace_back("::mlir::OperationName", "opName");
2641 paramList
.emplace_back("Properties &", "properties");
2643 opClass
.addStaticMethod("void", "populateDefaultProperties", paramList
);
2644 ERROR_IF_PRUNED(m
, "populateDefaultProperties", op
);
2645 auto &body
= m
->body();
2647 body
<< "::mlir::Builder " << odsBuilder
<< "(opName.getContext());\n";
2648 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
2649 auto &attr
= namedAttr
.attr
;
2650 if (!attr
.hasDefaultValue() || attr
.isOptional())
2652 StringRef name
= namedAttr
.name
;
2654 fctx
.withBuilder(odsBuilder
);
2655 body
<< "if (!properties." << name
<< ")\n"
2656 << " properties." << name
<< " = "
2657 << std::string(tgfmt(attr
.getConstBuilderTemplate(), &fctx
,
2658 tgfmt(attr
.getDefaultValue(), &fctx
)))
2664 SmallVector
<MethodParameter
> paramList
;
2665 paramList
.emplace_back("const ::mlir::OperationName &", "opName");
2666 paramList
.emplace_back("::mlir::NamedAttrList &", "attributes");
2667 auto *m
= opClass
.addStaticMethod("void", "populateDefaultAttrs", paramList
);
2668 ERROR_IF_PRUNED(m
, "populateDefaultAttrs", op
);
2669 auto &body
= m
->body();
2672 // Set default attributes that are unset.
2673 body
<< "auto attrNames = opName.getAttributeNames();\n";
2674 body
<< "::mlir::Builder " << odsBuilder
2675 << "(attrNames.front().getContext());\n";
2676 StringMap
<int> attrIndex
;
2677 for (const auto &it
: llvm::enumerate(emitHelper
.getAttrMetadata())) {
2678 attrIndex
[it
.value().first
] = it
.index();
2680 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
2681 auto &attr
= namedAttr
.attr
;
2682 if (!attr
.hasDefaultValue() || attr
.isOptional())
2684 auto index
= attrIndex
[namedAttr
.name
];
2685 body
<< "if (!attributes.get(attrNames[" << index
<< "])) {\n";
2687 fctx
.withBuilder(odsBuilder
);
2689 std::string defaultValue
=
2690 std::string(tgfmt(attr
.getConstBuilderTemplate(), &fctx
,
2691 tgfmt(attr
.getDefaultValue(), &fctx
)));
2692 body
.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index
,
2694 body
.unindent() << "}\n";
2698 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
2699 SmallVector
<MethodParameter
> paramList
;
2700 paramList
.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2701 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2702 paramList
.emplace_back("::mlir::ValueRange", "operands");
2703 StringRef attributesDefaultValue
= op
.getNumVariadicRegions() ? "" : "{}";
2704 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2705 "attributes", attributesDefaultValue
);
2706 if (op
.getNumVariadicRegions())
2707 paramList
.emplace_back("unsigned", "numRegions");
2709 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2710 // If the builder is redundant, skip generating the method
2713 auto &body
= m
->body();
2715 int numResults
= op
.getNumResults();
2716 int numVariadicResults
= op
.getNumVariableLengthResults();
2717 int numNonVariadicResults
= numResults
- numVariadicResults
;
2719 int numOperands
= op
.getNumOperands();
2720 int numVariadicOperands
= op
.getNumVariableLengthOperands();
2721 int numNonVariadicOperands
= numOperands
- numVariadicOperands
;
2724 if (numVariadicOperands
== 0 || numNonVariadicOperands
!= 0)
2725 body
<< " assert(operands.size()"
2726 << (numVariadicOperands
!= 0 ? " >= " : " == ")
2727 << numNonVariadicOperands
2728 << "u && \"mismatched number of parameters\");\n";
2729 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2730 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2732 // Create the correct number of regions
2733 if (int numRegions
= op
.getNumRegions()) {
2734 body
<< llvm::formatv(
2735 " for (unsigned i = 0; i != {0}; ++i)\n",
2736 (op
.getNumVariadicRegions() ? "numRegions" : Twine(numRegions
)));
2737 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
2741 if (emitHelper
.hasProperties()) {
2742 // Initialize the properties from Attributes before invoking the infer
2745 if (!attributes.empty()) {
2746 ::mlir::OpaqueProperties properties =
2747 &{1}.getOrAddProperties<{0}::Properties>();
2748 std::optional<::mlir::RegisteredOperationName> info =
2749 {1}.name.getRegisteredInfo();
2750 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2751 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
2752 ::llvm::report_fatal_error("Property conversion failed
.");
2754 opClass
.getClassName(), builderOpState
);
2757 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2758 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2759 {1}.location, operands,
2760 {1}.attributes.getDictionary({1}.getContext()),
2761 {1}.getRawProperties(),
2762 {1}.regions, inferredReturnTypes))) {{)",
2763 opClass
.getClassName(), builderOpState
);
2764 if (numVariadicResults
== 0 || numNonVariadicResults
!= 0)
2765 body
<< "\n assert(inferredReturnTypes.size()"
2766 << (numVariadicResults
!= 0 ? " >= " : " == ") << numNonVariadicResults
2767 << "u && \"mismatched number of return types\");";
2768 body
<< "\n " << builderOpState
<< ".addTypes(inferredReturnTypes);";
2772 ::llvm::report_fatal_error("Failed to infer result
type(s
).");
2776 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2777 auto emit
= [&](AttrParamKind attrType
) {
2778 SmallVector
<MethodParameter
> paramList
;
2779 SmallVector
<std::string
, 4> resultNames
;
2780 llvm::StringSet
<> inferredAttributes
;
2781 buildParamList(paramList
, inferredAttributes
, resultNames
,
2782 TypeParamKind::None
, attrType
);
2784 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2785 // If the builder is redundant, skip generating the method
2788 auto &body
= m
->body();
2789 genCodeForAddingArgAndRegionForBuilder(body
, inferredAttributes
,
2790 /*isRawValueAttr=*/attrType
==
2791 AttrParamKind::UnwrappedValue
);
2793 auto numResults
= op
.getNumResults();
2794 if (numResults
== 0)
2797 // Push all result types to the operation state
2798 const char *index
= op
.getOperand(0).isVariadic() ? ".front()" : "";
2799 std::string resultType
=
2800 formatv("{0}{1}.getType()", getArgumentName(op
, 0), index
).str();
2801 body
<< " " << builderOpState
<< ".addTypes({" << resultType
;
2802 for (int i
= 1; i
!= numResults
; ++i
)
2803 body
<< ", " << resultType
;
2807 emit(AttrParamKind::WrappedAttr
);
2808 // Generate additional builder(s) if any attributes can be "unwrapped"
2809 if (canGenerateUnwrappedBuilder(op
))
2810 emit(AttrParamKind::UnwrappedValue
);
2813 void OpEmitter::genUseAttrAsResultTypeBuilder() {
2814 SmallVector
<MethodParameter
> paramList
;
2815 paramList
.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2816 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2817 paramList
.emplace_back("::mlir::ValueRange", "operands");
2818 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2819 "attributes", "{}");
2820 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2821 // If the builder is redundant, skip generating the method
2825 auto &body
= m
->body();
2827 // Push all result types to the operation state
2828 std::string resultType
;
2829 const auto &namedAttr
= op
.getAttribute(0);
2831 body
<< " auto attrName = " << op
.getGetterName(namedAttr
.name
)
2832 << "AttrName(" << builderOpState
2834 " for (auto attr : attributes) {\n"
2835 " if (attr.getName() != attrName) continue;\n";
2836 if (namedAttr
.attr
.isTypeAttr()) {
2837 resultType
= "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()";
2839 resultType
= "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()";
2843 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2846 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2849 SmallVector
<std::string
, 2> resultTypes(op
.getNumResults(), resultType
);
2850 body
<< " " << builderOpState
<< ".addTypes({"
2851 << llvm::join(resultTypes
, ", ") << "});\n";
2855 /// Returns a signature of the builder. Updates the context `fctx` to enable
2856 /// replacement of $_builder and $_state in the body.
2857 static SmallVector
<MethodParameter
>
2858 getBuilderSignature(const Builder
&builder
) {
2859 ArrayRef
<Builder::Parameter
> params(builder
.getParameters());
2861 // Inject builder and state arguments.
2862 SmallVector
<MethodParameter
> arguments
;
2863 arguments
.reserve(params
.size() + 2);
2864 arguments
.emplace_back("::mlir::OpBuilder &", odsBuilder
);
2865 arguments
.emplace_back("::mlir::OperationState &", builderOpState
);
2867 for (unsigned i
= 0, e
= params
.size(); i
< e
; ++i
) {
2868 // If no name is provided, generate one.
2869 std::optional
<StringRef
> paramName
= params
[i
].getName();
2871 paramName
? paramName
->str() : "odsArg" + std::to_string(i
);
2873 StringRef defaultValue
;
2874 if (std::optional
<StringRef
> defaultParamValue
=
2875 params
[i
].getDefaultValue())
2876 defaultValue
= *defaultParamValue
;
2878 arguments
.emplace_back(params
[i
].getCppType(), std::move(name
),
2885 void OpEmitter::genBuilder() {
2886 // Handle custom builders if provided.
2887 for (const Builder
&builder
: op
.getBuilders()) {
2888 SmallVector
<MethodParameter
> arguments
= getBuilderSignature(builder
);
2890 std::optional
<StringRef
> body
= builder
.getBody();
2891 auto properties
= body
? Method::Static
: Method::StaticDeclaration
;
2893 opClass
.addMethod("void", "build", properties
, std::move(arguments
));
2895 ERROR_IF_PRUNED(method
, "build", op
);
2898 method
->setDeprecated(builder
.getDeprecatedMessage());
2901 fctx
.withBuilder(odsBuilder
);
2902 fctx
.addSubst("_state", builderOpState
);
2904 method
->body() << tgfmt(*body
, &fctx
);
2907 // Generate default builders that requires all result type, operands, and
2908 // attributes as parameters.
2909 if (op
.skipDefaultBuilders())
2912 // We generate three classes of builders here:
2913 // 1. one having a stand-alone parameter for each operand / attribute, and
2914 genSeparateArgParamBuilder();
2915 // 2. one having an aggregated parameter for all result types / operands /
2917 genCollectiveParamBuilder();
2918 // 3. one having a stand-alone parameter for each operand and attribute,
2919 // use the first operand or attribute's type as all result types
2920 // to facilitate different call patterns.
2921 if (op
.getNumVariableLengthResults() == 0) {
2922 if (op
.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
2923 genUseOperandAsResultTypeSeparateParamBuilder();
2924 genUseOperandAsResultTypeCollectiveParamBuilder();
2926 if (op
.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
2927 genUseAttrAsResultTypeBuilder();
2931 void OpEmitter::genCollectiveParamBuilder() {
2932 int numResults
= op
.getNumResults();
2933 int numVariadicResults
= op
.getNumVariableLengthResults();
2934 int numNonVariadicResults
= numResults
- numVariadicResults
;
2936 int numOperands
= op
.getNumOperands();
2937 int numVariadicOperands
= op
.getNumVariableLengthOperands();
2938 int numNonVariadicOperands
= numOperands
- numVariadicOperands
;
2940 SmallVector
<MethodParameter
> paramList
;
2941 paramList
.emplace_back("::mlir::OpBuilder &", "");
2942 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2943 paramList
.emplace_back("::mlir::TypeRange", "resultTypes");
2944 paramList
.emplace_back("::mlir::ValueRange", "operands");
2945 // Provide default value for `attributes` when its the last parameter
2946 StringRef attributesDefaultValue
= op
.getNumVariadicRegions() ? "" : "{}";
2947 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2948 "attributes", attributesDefaultValue
);
2949 if (op
.getNumVariadicRegions())
2950 paramList
.emplace_back("unsigned", "numRegions");
2952 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2953 // If the builder is redundant, skip generating the method
2956 auto &body
= m
->body();
2959 if (numVariadicOperands
== 0 || numNonVariadicOperands
!= 0)
2960 body
<< " assert(operands.size()"
2961 << (numVariadicOperands
!= 0 ? " >= " : " == ")
2962 << numNonVariadicOperands
2963 << "u && \"mismatched number of parameters\");\n";
2964 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2967 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2969 // Create the correct number of regions
2970 if (int numRegions
= op
.getNumRegions()) {
2971 body
<< llvm::formatv(
2972 " for (unsigned i = 0; i != {0}; ++i)\n",
2973 (op
.getNumVariadicRegions() ? "numRegions" : Twine(numRegions
)));
2974 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
2978 if (numVariadicResults
== 0 || numNonVariadicResults
!= 0)
2979 body
<< " assert(resultTypes.size()"
2980 << (numVariadicResults
!= 0 ? " >= " : " == ") << numNonVariadicResults
2981 << "u && \"mismatched number of return types\");\n";
2982 body
<< " " << builderOpState
<< ".addTypes(resultTypes);\n";
2984 if (emitHelper
.hasProperties()) {
2985 // Initialize the properties from Attributes before invoking the infer
2988 if (!attributes.empty()) {
2989 ::mlir::OpaqueProperties properties =
2990 &{1}.getOrAddProperties<{0}::Properties>();
2991 std::optional<::mlir::RegisteredOperationName> info =
2992 {1}.name.getRegisteredInfo();
2993 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2994 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
2995 ::llvm::report_fatal_error("Property conversion failed
.");
2997 opClass
.getClassName(), builderOpState
);
3000 // Generate builder that infers type too.
3001 // TODO: Expand to handle successors.
3002 if (canInferType(op
) && op
.getNumSuccessors() == 0)
3003 genInferredTypeCollectiveParamBuilder();
3006 void OpEmitter::buildParamList(SmallVectorImpl
<MethodParameter
> ¶mList
,
3007 llvm::StringSet
<> &inferredAttributes
,
3008 SmallVectorImpl
<std::string
> &resultTypeNames
,
3009 TypeParamKind typeParamKind
,
3010 AttrParamKind attrParamKind
) {
3011 resultTypeNames
.clear();
3012 auto numResults
= op
.getNumResults();
3013 resultTypeNames
.reserve(numResults
);
3015 paramList
.emplace_back("::mlir::OpBuilder &", odsBuilder
);
3016 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
3018 switch (typeParamKind
) {
3019 case TypeParamKind::None
:
3021 case TypeParamKind::Separate
: {
3022 // Add parameters for all return types
3023 for (int i
= 0; i
< numResults
; ++i
) {
3024 const auto &result
= op
.getResult(i
);
3025 std::string resultName
= std::string(result
.name
);
3026 if (resultName
.empty())
3027 resultName
= std::string(formatv("resultType{0}", i
));
3030 result
.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
3032 paramList
.emplace_back(type
, resultName
, result
.isOptional());
3033 resultTypeNames
.emplace_back(std::move(resultName
));
3036 case TypeParamKind::Collective
: {
3037 paramList
.emplace_back("::mlir::TypeRange", "resultTypes");
3038 resultTypeNames
.push_back("resultTypes");
3042 // Add parameters for all arguments (operands and attributes).
3043 // Track "attr-like" (property and attribute) optional values separate from
3044 // attributes themselves so that the disambiguation code can look at the first
3045 // attribute specifically when determining where to trim the optional-value
3046 // list to avoid ambiguity while preserving the ability of all-property ops to
3047 // use default parameters.
3048 int defaultValuedAttrLikeStartIndex
= op
.getNumArgs();
3049 int defaultValuedAttrStartIndex
= op
.getNumArgs();
3050 // Successors and variadic regions go at the end of the parameter list, so no
3051 // default arguments are possible.
3052 bool hasTrailingParams
= op
.getNumSuccessors() || op
.getNumVariadicRegions();
3053 if (!hasTrailingParams
) {
3054 // Calculate the start index from which we can attach default values in the
3055 // builder declaration.
3056 for (int i
= op
.getNumArgs() - 1; i
>= 0; --i
) {
3058 llvm::dyn_cast_if_present
<tblgen::NamedAttribute
*>(op
.getArg(i
));
3059 auto *namedProperty
=
3060 llvm::dyn_cast_if_present
<tblgen::NamedProperty
*>(op
.getArg(i
));
3061 if (namedProperty
) {
3062 Property prop
= namedProperty
->prop
;
3063 if (!prop
.hasDefaultValue())
3065 defaultValuedAttrLikeStartIndex
= i
;
3071 Attribute attr
= namedAttr
->attr
;
3072 // TODO: Currently we can't differentiate between optional meaning do not
3073 // verify/not always error if missing or optional meaning need not be
3074 // specified in builder. Expand isOptional once we can differentiate.
3075 if (!attr
.hasDefaultValue() && !attr
.isDerivedAttr())
3078 // Creating an APInt requires us to provide bitwidth, value, and
3079 // signedness, which is complicated compared to others. Similarly
3081 // TODO: Adjust the 'returnType' field of such attributes
3083 StringRef retType
= namedAttr
->attr
.getReturnType();
3084 if (retType
== "::llvm::APInt" || retType
== "::llvm::APFloat")
3087 defaultValuedAttrLikeStartIndex
= i
;
3088 defaultValuedAttrStartIndex
= i
;
3091 // Avoid generating build methods that are ambiguous due to default values by
3092 // requiring at least one attribute.
3093 if (defaultValuedAttrStartIndex
< op
.getNumArgs()) {
3094 // TODO: This should have been possible as a cast<NamedAttribute> but
3095 // required template instantiations is not yet defined for the tblgen helper
3098 cast
<NamedAttribute
*>(op
.getArg(defaultValuedAttrStartIndex
));
3099 Attribute attr
= namedAttr
->attr
;
3100 if ((attrParamKind
== AttrParamKind::WrappedAttr
&&
3101 canUseUnwrappedRawValue(attr
)) ||
3102 (attrParamKind
== AttrParamKind::UnwrappedValue
&&
3103 !canUseUnwrappedRawValue(attr
))) {
3104 ++defaultValuedAttrStartIndex
;
3105 defaultValuedAttrLikeStartIndex
= defaultValuedAttrStartIndex
;
3109 /// Collect any inferred attributes.
3110 for (const NamedTypeConstraint
&operand
: op
.getOperands()) {
3111 if (operand
.isVariadicOfVariadic()) {
3112 inferredAttributes
.insert(
3113 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr());
3117 for (int i
= 0, e
= op
.getNumArgs(), numOperands
= 0; i
< e
; ++i
) {
3118 Argument arg
= op
.getArg(i
);
3119 if (const auto *operand
=
3120 llvm::dyn_cast_if_present
<NamedTypeConstraint
*>(arg
)) {
3122 if (operand
->isVariadicOfVariadic())
3123 type
= "::llvm::ArrayRef<::mlir::ValueRange>";
3124 else if (operand
->isVariadic())
3125 type
= "::mlir::ValueRange";
3127 type
= "::mlir::Value";
3129 paramList
.emplace_back(type
, getArgumentName(op
, numOperands
++),
3130 operand
->isOptional());
3133 if (auto *propArg
= llvm::dyn_cast_if_present
<NamedProperty
*>(arg
)) {
3134 const Property
&prop
= propArg
->prop
;
3135 StringRef type
= prop
.getInterfaceType();
3136 std::string defaultValue
;
3137 if (prop
.hasDefaultValue() && i
>= defaultValuedAttrLikeStartIndex
) {
3138 defaultValue
= prop
.getDefaultValue();
3140 bool isOptional
= prop
.hasDefaultValue();
3141 paramList
.emplace_back(type
, propArg
->name
, StringRef(defaultValue
),
3145 const NamedAttribute
&namedAttr
= *arg
.get
<NamedAttribute
*>();
3146 const Attribute
&attr
= namedAttr
.attr
;
3148 // Inferred attributes don't need to be added to the param list.
3149 if (inferredAttributes
.contains(namedAttr
.name
))
3153 switch (attrParamKind
) {
3154 case AttrParamKind::WrappedAttr
:
3155 type
= attr
.getStorageType();
3157 case AttrParamKind::UnwrappedValue
:
3158 if (canUseUnwrappedRawValue(attr
))
3159 type
= attr
.getReturnType();
3161 type
= attr
.getStorageType();
3165 // Attach default value if requested and possible.
3166 std::string defaultValue
;
3167 if (i
>= defaultValuedAttrStartIndex
) {
3168 if (attrParamKind
== AttrParamKind::UnwrappedValue
&&
3169 canUseUnwrappedRawValue(attr
))
3170 defaultValue
+= attr
.getDefaultValue();
3172 defaultValue
+= "nullptr";
3174 paramList
.emplace_back(type
, namedAttr
.name
, StringRef(defaultValue
),
3178 /// Insert parameters for each successor.
3179 for (const NamedSuccessor
&succ
: op
.getSuccessors()) {
3181 succ
.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
3182 paramList
.emplace_back(type
, succ
.name
);
3185 /// Insert parameters for variadic regions.
3186 for (const NamedRegion
®ion
: op
.getRegions())
3187 if (region
.isVariadic())
3188 paramList
.emplace_back("unsigned",
3189 llvm::formatv("{0}Count", region
.name
).str());
3192 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
3193 MethodBody
&body
, llvm::StringSet
<> &inferredAttributes
,
3194 bool isRawValueAttr
) {
3195 // Push all operands to the result.
3196 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
) {
3197 std::string argName
= getArgumentName(op
, i
);
3198 const NamedTypeConstraint
&operand
= op
.getOperand(i
);
3199 if (operand
.constraint
.isVariadicOfVariadic()) {
3200 body
<< " for (::mlir::ValueRange range : " << argName
<< ")\n "
3201 << builderOpState
<< ".addOperands(range);\n";
3203 // Add the segment attribute.
3205 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
3206 << " for (::mlir::ValueRange range : " << argName
<< ")\n"
3207 << " rangeSegments.push_back(range.size());\n"
3208 << " auto rangeAttr = " << odsBuilder
3209 << ".getDenseI32ArrayAttr(rangeSegments);\n";
3210 if (op
.getDialect().usePropertiesForAttributes()) {
3211 body
<< " " << builderOpState
<< ".getOrAddProperties<Properties>()."
3212 << operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr()
3215 body
<< " " << builderOpState
<< ".addAttribute("
3216 << op
.getGetterName(
3217 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr())
3218 << "AttrName(" << builderOpState
<< ".name), rangeAttr);";
3224 if (operand
.isOptional())
3225 body
<< " if (" << argName
<< ")\n ";
3226 body
<< " " << builderOpState
<< ".addOperands(" << argName
<< ");\n";
3229 // If the operation has the operand segment size attribute, add it here.
3230 auto emitSegment
= [&]() {
3231 interleaveComma(llvm::seq
<int>(0, op
.getNumOperands()), body
, [&](int i
) {
3232 const NamedTypeConstraint
&operand
= op
.getOperand(i
);
3233 if (!operand
.isVariableLength()) {
3238 std::string operandName
= getArgumentName(op
, i
);
3239 if (operand
.isOptional()) {
3240 body
<< "(" << operandName
<< " ? 1 : 0)";
3241 } else if (operand
.isVariadicOfVariadic()) {
3242 body
<< llvm::formatv(
3243 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
3244 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
3245 "static_cast<int32_t>(range.size()); }))",
3248 body
<< "static_cast<int32_t>(" << getArgumentName(op
, i
) << ".size())";
3252 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
3253 std::string sizes
= op
.getGetterName(operandSegmentAttrName
);
3254 if (op
.getDialect().usePropertiesForAttributes()) {
3255 body
<< " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
3257 body
<< "}), " << builderOpState
3258 << ".getOrAddProperties<Properties>()."
3259 "operandSegmentSizes.begin());\n";
3261 body
<< " " << builderOpState
<< ".addAttribute(" << sizes
<< "AttrName("
3262 << builderOpState
<< ".name), "
3263 << "odsBuilder.getDenseI32ArrayAttr({";
3269 // Push all properties to the result.
3270 for (const auto &namedProp
: op
.getProperties()) {
3271 // Use the setter from the Properties struct since the conversion from the
3272 // interface type (used in the builder argument) to the storage type (used
3273 // in the state) is not necessarily trivial.
3274 std::string setterName
= op
.getSetterName(namedProp
.name
);
3275 body
<< formatv(" {0}.getOrAddProperties<Properties>().{1}({2});\n",
3276 builderOpState
, setterName
, namedProp
.name
);
3278 // Push all attributes to the result.
3279 for (const auto &namedAttr
: op
.getAttributes()) {
3280 auto &attr
= namedAttr
.attr
;
3281 if (attr
.isDerivedAttr() || inferredAttributes
.contains(namedAttr
.name
))
3284 // TODO: The wrapping of optional is different for default or not, so don't
3285 // unwrap for default ones that would fail below.
3286 bool emitNotNullCheck
=
3287 (attr
.isOptional() && !attr
.hasDefaultValue()) ||
3288 (attr
.hasDefaultValue() && !isRawValueAttr
) ||
3289 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
3290 // the constant materialization is only for true case.
3291 (isRawValueAttr
&& attr
.getAttrDefName() == "UnitAttr");
3292 if (emitNotNullCheck
)
3293 body
.indent() << formatv("if ({0}) ", namedAttr
.name
) << "{\n";
3295 if (isRawValueAttr
&& canUseUnwrappedRawValue(attr
)) {
3296 // If this is a raw value, then we need to wrap it in an Attribute
3299 fctx
.withBuilder("odsBuilder");
3300 if (op
.getDialect().usePropertiesForAttributes()) {
3301 body
<< formatv(" {0}.getOrAddProperties<Properties>().{1} = {2};\n",
3302 builderOpState
, namedAttr
.name
,
3303 constBuildAttrFromParam(attr
, fctx
, namedAttr
.name
));
3305 body
<< formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3306 builderOpState
, op
.getGetterName(namedAttr
.name
),
3307 constBuildAttrFromParam(attr
, fctx
, namedAttr
.name
));
3310 if (op
.getDialect().usePropertiesForAttributes()) {
3311 body
<< formatv(" {0}.getOrAddProperties<Properties>().{1} = {1};\n",
3312 builderOpState
, namedAttr
.name
);
3314 body
<< formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3315 builderOpState
, op
.getGetterName(namedAttr
.name
),
3319 if (emitNotNullCheck
)
3320 body
.unindent() << " }\n";
3323 // Create the correct number of regions.
3324 for (const NamedRegion
®ion
: op
.getRegions()) {
3325 if (region
.isVariadic())
3326 body
<< formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
3329 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
3332 // Push all successors to the result.
3333 for (const NamedSuccessor
&namedSuccessor
: op
.getSuccessors()) {
3334 body
<< formatv(" {0}.addSuccessors({1});\n", builderOpState
,
3335 namedSuccessor
.name
);
3339 void OpEmitter::genCanonicalizerDecls() {
3340 bool hasCanonicalizeMethod
= def
.getValueAsBit("hasCanonicalizeMethod");
3341 if (hasCanonicalizeMethod
) {
3342 // static LogicResult FooOp::
3343 // canonicalize(FooOp op, PatternRewriter &rewriter);
3344 SmallVector
<MethodParameter
> paramList
;
3345 paramList
.emplace_back(op
.getCppClassName(), "op");
3346 paramList
.emplace_back("::mlir::PatternRewriter &", "rewriter");
3347 auto *m
= opClass
.declareStaticMethod("::llvm::LogicalResult",
3348 "canonicalize", std::move(paramList
));
3349 ERROR_IF_PRUNED(m
, "canonicalize", op
);
3352 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
3353 // or if using a 'canonicalize' method.
3354 bool hasCanonicalizer
= def
.getValueAsBit("hasCanonicalizer");
3355 if (!hasCanonicalizeMethod
&& !hasCanonicalizer
)
3358 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
3359 // method, but not implementing 'getCanonicalizationPatterns' manually.
3360 bool hasBody
= hasCanonicalizeMethod
&& !hasCanonicalizer
;
3362 // Add a signature for getCanonicalizationPatterns if implemented by the
3363 // dialect or if synthesized to call 'canonicalize'.
3364 SmallVector
<MethodParameter
> paramList
;
3365 paramList
.emplace_back("::mlir::RewritePatternSet &", "results");
3366 paramList
.emplace_back("::mlir::MLIRContext *", "context");
3367 auto kind
= hasBody
? Method::Static
: Method::StaticDeclaration
;
3368 auto *method
= opClass
.addMethod("void", "getCanonicalizationPatterns", kind
,
3369 std::move(paramList
));
3371 // If synthesizing the method, fill it.
3373 ERROR_IF_PRUNED(method
, "getCanonicalizationPatterns", op
);
3374 method
->body() << " results.add(canonicalize);\n";
3378 void OpEmitter::genFolderDecls() {
3379 if (!op
.hasFolder())
3382 SmallVector
<MethodParameter
> paramList
;
3383 paramList
.emplace_back("FoldAdaptor", "adaptor");
3386 bool hasSingleResult
=
3387 op
.getNumResults() == 1 && op
.getNumVariableLengthResults() == 0;
3388 if (hasSingleResult
) {
3389 retType
= "::mlir::OpFoldResult";
3391 paramList
.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
3393 retType
= "::llvm::LogicalResult";
3396 auto *m
= opClass
.declareMethod(retType
, "fold", std::move(paramList
));
3397 ERROR_IF_PRUNED(m
, "fold", op
);
3400 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait
*opTrait
) {
3401 Interface interface
= opTrait
->getInterface();
3403 // Get the set of methods that should always be declared.
3404 auto alwaysDeclaredMethodsVec
= opTrait
->getAlwaysDeclaredMethods();
3405 llvm::StringSet
<> alwaysDeclaredMethods
;
3406 alwaysDeclaredMethods
.insert(alwaysDeclaredMethodsVec
.begin(),
3407 alwaysDeclaredMethodsVec
.end());
3409 for (const InterfaceMethod
&method
: interface
.getMethods()) {
3410 // Don't declare if the method has a body.
3411 if (method
.getBody())
3413 // Don't declare if the method has a default implementation and the op
3414 // didn't request that it always be declared.
3415 if (method
.getDefaultImplementation() &&
3416 !alwaysDeclaredMethods
.count(method
.getName()))
3418 // Interface methods are allowed to overlap with existing methods, so don't
3420 (void)genOpInterfaceMethod(method
);
3424 Method
*OpEmitter::genOpInterfaceMethod(const InterfaceMethod
&method
,
3426 SmallVector
<MethodParameter
> paramList
;
3427 for (const InterfaceMethod::Argument
&arg
: method
.getArguments())
3428 paramList
.emplace_back(arg
.type
, arg
.name
);
3430 auto props
= (method
.isStatic() ? Method::Static
: Method::None
) |
3431 (declaration
? Method::Declaration
: Method::None
);
3432 return opClass
.addMethod(method
.getReturnType(), method
.getName(), props
,
3433 std::move(paramList
));
3436 void OpEmitter::genOpInterfaceMethods() {
3437 for (const auto &trait
: op
.getTraits()) {
3438 if (const auto *opTrait
= dyn_cast
<tblgen::InterfaceTrait
>(&trait
))
3439 if (opTrait
->shouldDeclareMethods())
3440 genOpInterfaceMethods(opTrait
);
3444 void OpEmitter::genSideEffectInterfaceMethods() {
3445 enum EffectKind
{ Operand
, Result
, Symbol
, Static
};
3446 struct EffectLocation
{
3447 /// The effect applied.
3450 /// The index if the kind is not static.
3453 /// The kind of the location.
3457 StringMap
<SmallVector
<EffectLocation
, 1>> interfaceEffects
;
3458 auto resolveDecorators
= [&](Operator::var_decorator_range decorators
,
3459 unsigned index
, unsigned kind
) {
3460 for (auto decorator
: decorators
)
3461 if (SideEffect
*effect
= dyn_cast
<SideEffect
>(&decorator
)) {
3462 opClass
.addTrait(effect
->getInterfaceTrait());
3463 interfaceEffects
[effect
->getBaseEffectName()].push_back(
3464 EffectLocation
{*effect
, index
, kind
});
3468 // Collect effects that were specified via:
3470 for (const auto &trait
: op
.getTraits()) {
3471 const auto *opTrait
= dyn_cast
<tblgen::SideEffectTrait
>(&trait
);
3474 auto &effects
= interfaceEffects
[opTrait
->getBaseEffectName()];
3475 for (auto decorator
: opTrait
->getEffects())
3476 effects
.push_back(EffectLocation
{cast
<SideEffect
>(decorator
),
3477 /*index=*/0, EffectKind::Static
});
3479 /// Attributes and Operands.
3480 for (unsigned i
= 0, operandIt
= 0, e
= op
.getNumArgs(); i
!= e
; ++i
) {
3481 Argument arg
= op
.getArg(i
);
3482 if (arg
.is
<NamedTypeConstraint
*>()) {
3483 resolveDecorators(op
.getArgDecorators(i
), operandIt
, EffectKind::Operand
);
3487 if (arg
.is
<NamedProperty
*>())
3489 const NamedAttribute
*attr
= arg
.get
<NamedAttribute
*>();
3490 if (attr
->attr
.getBaseAttr().isSymbolRefAttr())
3491 resolveDecorators(op
.getArgDecorators(i
), i
, EffectKind::Symbol
);
3494 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
)
3495 resolveDecorators(op
.getResultDecorators(i
), i
, EffectKind::Result
);
3497 // The code used to add an effect instance.
3498 // {0}: The effect class.
3499 // {1}: Optional value or symbol reference.
3500 // {2}: The side effect stage.
3501 // {3}: Does this side effect act on every single value of resource.
3502 // {4}: The resource class.
3503 const char *addEffectCode
=
3504 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
3506 for (auto &it
: interfaceEffects
) {
3507 // Generate the 'getEffects' method.
3508 std::string type
= llvm::formatv("::llvm::SmallVectorImpl<::mlir::"
3509 "SideEffects::EffectInstance<{0}>> &",
3512 auto *getEffects
= opClass
.addMethod("void", "getEffects",
3513 MethodParameter(type
, "effects"));
3514 ERROR_IF_PRUNED(getEffects
, "getEffects", op
);
3515 auto &body
= getEffects
->body();
3517 // Add effect instances for each of the locations marked on the operation.
3518 for (auto &location
: it
.second
) {
3519 StringRef effect
= location
.effect
.getName();
3520 StringRef resource
= location
.effect
.getResource();
3521 int stage
= (int)location
.effect
.getStage();
3522 bool effectOnFullRegion
= (int)location
.effect
.getEffectOnfullRegion();
3523 if (location
.kind
== EffectKind::Static
) {
3524 // A static instance has no attached value.
3525 body
<< llvm::formatv(addEffectCode
, effect
, "", stage
,
3526 effectOnFullRegion
, resource
)
3528 } else if (location
.kind
== EffectKind::Symbol
) {
3529 // A symbol reference requires adding the proper attribute.
3530 const auto *attr
= op
.getArg(location
.index
).get
<NamedAttribute
*>();
3531 std::string argName
= op
.getGetterName(attr
->name
);
3532 if (attr
->attr
.isOptional()) {
3533 body
<< " if (auto symbolRef = " << argName
<< "Attr())\n "
3534 << llvm::formatv(addEffectCode
, effect
, "symbolRef, ", stage
,
3535 effectOnFullRegion
, resource
)
3538 body
<< llvm::formatv(addEffectCode
, effect
, argName
+ "Attr(), ",
3539 stage
, effectOnFullRegion
, resource
)
3543 // Otherwise this is an operand/result, so we need to attach the Value.
3544 body
<< " {\n auto valueRange = getODS"
3545 << (location
.kind
== EffectKind::Operand
? "Operand" : "Result")
3546 << "IndexAndLength(" << location
.index
<< ");\n"
3547 << " for (unsigned idx = valueRange.first; idx < "
3549 << " + valueRange.second; idx++) {\n "
3550 << llvm::formatv(addEffectCode
, effect
,
3551 (location
.kind
== EffectKind::Operand
3552 ? "&getOperation()->getOpOperand(idx), "
3553 : "getOperation()->getOpResult(idx), "),
3554 stage
, effectOnFullRegion
, resource
)
// Emits the body of InferTypeOpInterface's 'inferReturnTypes' for this op.
// The generated code:
//  * resizes 'inferredReturnTypes' to the op's number of results;
//  * returns failure if 'operands' is too short for the largest operand index
//    that any inferred result type reads from;
//  * declares one '::mlir::Type odsInferredType<N>' local per result, derived
//    from an operand's type ('operands[...]'), a TypedAttr's type (read from
//    properties or the 'attributes' dictionary), the result constraint's
//    builder call, or an already-inferred result, each passed through the
//    result's transformer template;
//  * copies those locals into 'inferredReturnTypes' and returns success.
// NOTE(review): this copy of the definition is missing several source lines
// (early return, closing braces, the fctx declaration, parts of the
// attribute-lookup strings); restore them from upstream before editing.
3561 void OpEmitter::genTypeInterfaceMethods() {
3562 if (!op
.allResultTypesKnown())
3564 // Generate 'inferReturnTypes' method declaration using the interface method
3565 // declared in 'InferTypeOpInterface' op interface.
3567 cast
<InterfaceTrait
>(op
.getTrait("::mlir::InferTypeOpInterface::Trait"));
3568 Interface interface
= trait
->getInterface();
// Locate the 'inferReturnTypes' interface method and generate its
// definition; asserts if the interface does not declare it.
3569 Method
*method
= [&]() -> Method
* {
3570 for (const InterfaceMethod
&interfaceMethod
: interface
.getMethods()) {
3571 if (interfaceMethod
.getName() == "inferReturnTypes") {
3572 return genOpInterfaceMethod(interfaceMethod
, /*declaration=*/false);
3575 assert(0 && "unable to find inferReturnTypes interface method");
3578 ERROR_IF_PRUNED(method
, "inferReturnTypes", op
);
3579 auto &body
= method
->body();
3580 body
<< " inferredReturnTypes.resize(" << op
.getNumResults() << ");\n";
// Builder calls in constraint/transformer templates are formatted against an
// 'odsBuilder' local that the generated code creates from 'context'.
3583 fctx
.withBuilder("odsBuilder");
3584 fctx
.addSubst("_ctxt", "context");
3585 body
<< " ::mlir::Builder odsBuilder(context);\n";
3587 // Preprocessing stage to verify all accesses to operands are valid.
3588 int maxAccessedIndex
= -1;
3589 for (int i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
3590 const InferredResultType
&infer
= op
.getInferredResultType(i
);
3593 Operator::OperandOrAttribute arg
=
3594 op
.getArgToOperandOrAttribute(infer
.getIndex());
3595 if (arg
.kind() == Operator::OperandOrAttribute::Kind::Operand
) {
3597 std::max(maxAccessedIndex
, arg
.operandOrAttributeIndex());
// Guard the generated operand indexing below with a single size check.
3600 if (maxAccessedIndex
!= -1) {
3601 body
<< " if (operands.size() <= " << Twine(maxAccessedIndex
) << ")\n";
3602 body
<< " return ::mlir::failure();\n";
3605 // Process the type inference graph in topological order, starting from types
3606 // that are always fully-inferred: operands and results with constructible
3607 // types. The type inference graph here will always be a DAG, so this gives
3608 // us the correct order for generating the types. -1 is a placeholder to
3609 // indicate the type for a result has not been generated.
3610 SmallVector
<int> constructedIndices(op
.getNumResults(), -1);
3611 int inferredTypeIdx
= 0;
// Keep making passes over the results until inferredTypeIdx (the number of
// 'odsInferredType<N>' locals emitted so far) reaches the result count.
3612 for (int numResults
= op
.getNumResults(); inferredTypeIdx
!= numResults
;) {
3613 for (int i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
3614 if (constructedIndices
[i
] >= 0)
3616 const InferredResultType
&infer
= op
.getInferredResultType(i
);
3617 std::string typeStr
;
3618 if (infer
.isArg()) {
3619 // If this is an operand, just index into operand list to access the
3621 Operator::OperandOrAttribute arg
=
3622 op
.getArgToOperandOrAttribute(infer
.getIndex());
3623 if (arg
.kind() == Operator::OperandOrAttribute::Kind::Operand
) {
3624 typeStr
= ("operands[" + Twine(arg
.operandOrAttributeIndex()) +
3628 // If this is an attribute, index into the attribute dictionary.
3631 op
.getArg(arg
.operandOrAttributeIndex()).get
<NamedAttribute
*>();
3632 body
<< " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
3634 if (op
.getDialect().usePropertiesForAttributes()) {
3635 body
<< "(properties ? properties.as<Properties *>()->"
3638 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3640 attr
->name
+ "\")));\n";
3642 body
<< "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3644 attr
->name
+ "\"));\n";
3646 body
<< " if (!odsInferredTypeAttr" << inferredTypeIdx
3647 << ") return ::mlir::failure();\n";
3649 ("odsInferredTypeAttr" + Twine(inferredTypeIdx
) + ".getType()")
3652 } else if (std::optional
<StringRef
> builder
=
3653 op
.getResult(infer
.getResultIndex())
3654 .constraint
.getBuilderCall()) {
3655 typeStr
= tgfmt(*builder
, &fctx
).str();
3656 } else if (int index
= constructedIndices
[infer
.getResultIndex()];
3658 typeStr
= ("odsInferredType" + Twine(index
)).str();
// Emit the local for this result (transformer applied to typeStr) and record
// which local holds result i.
3662 body
<< " ::mlir::Type odsInferredType" << inferredTypeIdx
++ << " = "
3663 << tgfmt(infer
.getTransformer(), &fctx
.withSelf(typeStr
)) << ";\n";
3664 constructedIndices
[i
] = inferredTypeIdx
- 1;
// Copy each result's inferred local into the output vector, in result order.
3667 for (auto [i
, index
] : llvm::enumerate(constructedIndices
))
3668 body
<< " inferredReturnTypes[" << i
<< "] = odsInferredType" << index
3670 body
<< " return ::mlir::success();";
3673 void OpEmitter::genParser() {
3674 if (hasStringAttribute(def
, "assemblyFormat"))
3677 if (!def
.getValueAsBit("hasCustomAssemblyFormat"))
3680 SmallVector
<MethodParameter
> paramList
;
3681 paramList
.emplace_back("::mlir::OpAsmParser &", "parser");
3682 paramList
.emplace_back("::mlir::OperationState &", "result");
3684 auto *method
= opClass
.declareStaticMethod("::mlir::ParseResult", "parse",
3685 std::move(paramList
));
3686 ERROR_IF_PRUNED(method
, "parse", op
);
3689 void OpEmitter::genPrinter() {
3690 if (hasStringAttribute(def
, "assemblyFormat"))
3693 // Check to see if this op uses a c++ format.
3694 if (!def
.getValueAsBit("hasCustomAssemblyFormat"))
3696 auto *method
= opClass
.declareMethod(
3697 "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
3698 ERROR_IF_PRUNED(method
, "print", op
);
// Emits two methods on the op class:
//  * 'verifyInvariantsImpl', whose body runs the generated attribute checks
//    (genAttributeVerifier), operand and result checks
//    (genOperandResultVerifier), one 'emitOpError' check per PredTrait, then
//    region and successor checks, and finally returns success;
//  * 'verifyInvariants', which returns verifyInvariantsImpl()'s result, or,
//    when the op sets 'hasVerifier', succeeds only if both
//    verifyInvariantsImpl() and the user-provided verify() succeed.
// NOTE(review): this copy is missing several lines of the definition (e.g.
// the 'implMethod' declaration head, the verifyCtx declaration, the trailing
// arguments of genAttributeVerifier/tgfmt and the 'else' branch); restore
// them from upstream before editing.
3701 void OpEmitter::genVerifier() {
3703 opClass
.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl");
3704 ERROR_IF_PRUNED(implMethod
, "verifyInvariantsImpl", op
);
3705 auto &implBody
= implMethod
->body();
3706 bool useProperties
= emitHelper
.hasProperties();
3708 populateSubstitutions(emitHelper
, verifyCtx
);
3709 genAttributeVerifier(emitHelper
, verifyCtx
, implBody
, staticVerifierEmitter
,
3711 genOperandResultVerifier(implBody
, op
.getOperands(), "operand");
3712 genOperandResultVerifier(implBody
, op
.getResults(), "result");
// Each PredTrait becomes an 'if (!(<pred>)) return emitOpError(...)' check.
3714 for (auto &trait
: op
.getTraits()) {
3715 if (auto *t
= dyn_cast
<tblgen::PredTrait
>(&trait
)) {
3716 implBody
<< tgfmt(" if (!($0))\n "
3717 "return emitOpError(\"failed to verify that $1\");\n",
3718 &verifyCtx
, tgfmt(t
->getPredTemplate(), &verifyCtx
),
3723 genRegionVerifier(implBody
);
3724 genSuccessorVerifier(implBody
);
3726 implBody
<< " return ::mlir::success();\n";
3728 // TODO: Some places use the `verifyInvariants` to do operation verification.
3729 // This may not act as their expectation because this doesn't call any
3730 // verifiers of native/interface traits. Needs to review those use cases and
3731 // see if we should use the mlir::verify() instead.
3732 auto *method
= opClass
.addMethod("::llvm::LogicalResult", "verifyInvariants");
3733 ERROR_IF_PRUNED(method
, "verifyInvariants", op
);
3734 auto &body
= method
->body();
// With a user verifier, both the generated invariants and verify() must pass.
3735 if (def
.getValueAsBit("hasVerifier")) {
3736 body
<< " if(::mlir::succeeded(verifyInvariantsImpl()) && "
3737 "::mlir::succeeded(verify()))\n";
3738 body
<< " return ::mlir::success();\n";
3739 body
<< " return ::mlir::failure();";
3741 body
<< " return verifyInvariantsImpl();";
3745 void OpEmitter::genCustomVerifier() {
3746 if (def
.getValueAsBit("hasVerifier")) {
3747 auto *method
= opClass
.declareMethod("::llvm::LogicalResult", "verify");
3748 ERROR_IF_PRUNED(method
, "verify", op
);
3751 if (def
.getValueAsBit("hasRegionVerifier")) {
3753 opClass
.declareMethod("::llvm::LogicalResult", "verifyRegions");
3754 ERROR_IF_PRUNED(method
, "verifyRegions", op
);
// Emits verification code into 'body' for the given operand or result
// groups ('valueKind' is "operand" or "result"):
//  * optional groups must contain at most one value;
//  * variadic-of-variadic groups have their segment sizes checked with
//    ::mlir::OpTrait::impl::verifyValueSizeAttr;
//  * groups with a type predicate are checked value-by-value with the static
//    type-constraint function from staticVerifierEmitter.
// Nothing is emitted when there are no groups or none needs any check.
// NOTE(review): this copy is missing several lines of the definition
// (raw-string terminators, some control-flow lines and comment text);
// restore them from upstream before editing.
3758 void OpEmitter::genOperandResultVerifier(MethodBody
&body
,
3759 Operator::const_value_range values
,
3760 StringRef valueKind
) {
3761 // Check that an optional value is at most 1 element.
3763 // {0}: Value index.
3764 // {1}: "operand" or "result"
3765 const char *const verifyOptional
= R
"(
3766 if (valueGroup{0}.size() > 1) {
3767 return emitOpError("{1} group starting at
#") << index
3768 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
3771 // Check the types of a range of values.
3773 // {0}: Value index.
3774 // {1}: Type constraint function.
3775 // {2}: "operand" or "result"
3776 const char *const verifyValues
= R
"(
3777 for (auto v : valueGroup{0}) {
3778 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
3779 return ::mlir::failure();
// A group needs no generated code if it has no predicate, is not optional
// and is not variadic-of-variadic.
3783 const auto canSkip
= [](const NamedTypeConstraint
&value
) {
3784 return !value
.hasPredicate() && !value
.isOptional() &&
3785 !value
.isVariadicOfVariadic();
3787 if (values
.empty() || llvm::all_of(values
, canSkip
))
// 'index' is shared by every check in the emitted scope and is passed to the
// type-constraint functions as the running value position.
3792 body
<< " {\n unsigned index = 0; (void)index;\n";
3794 for (const auto &staticValue
: llvm::enumerate(values
)) {
3795 const NamedTypeConstraint
&value
= staticValue
.value();
3797 bool hasPredicate
= value
.hasPredicate();
3798 bool isOptional
= value
.isOptional();
3799 bool isVariadicOfVariadic
= value
.isVariadicOfVariadic();
3800 if (!hasPredicate
&& !isOptional
&& !isVariadicOfVariadic
)
3802 body
<< formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
3803 // Capitalize the first letter to match the function name
3804 valueKind
.substr(0, 1).upper(), valueKind
.substr(1),
3805 staticValue
.index());
3807 // If the constraint is optional check that the value group has at most 1
3810 body
<< formatv(verifyOptional
, staticValue
.index(), valueKind
);
3811 } else if (isVariadicOfVariadic
) {
3813 " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
3814 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
3815 " return ::mlir::failure();\n",
3816 value
.constraint
.getVariadicOfVariadicSegmentSizeAttr(), value
.name
,
3817 staticValue
.index());
3820 // Otherwise, if there is no predicate there is nothing left to do.
3823 // Emit a loop to check all the dynamic values in the pack.
3824 StringRef constraintFn
=
3825 staticVerifierEmitter
.getTypeConstraintFn(value
.constraint
);
3826 body
<< formatv(verifyValues
, staticValue
.index(), constraintFn
, valueKind
);
3832 void OpEmitter::genRegionVerifier(MethodBody
&body
) {
3833 /// Code to verify a region.
3835 /// {0}: Getter for the regions.
3836 /// {1}: The region constraint.
3837 /// {2}: The region's name.
3838 /// {3}: The region description.
3839 const char *const verifyRegion
= R
"(
3840 for (auto ®ion : {0})
3841 if (::mlir::failed({1}(*this, region, "{2}", index++)))
3842 return ::mlir::failure();
3844 /// Get a single region.
3846 /// {0}: The region's index.
3847 const char *const getSingleRegion
=
3848 "::llvm::MutableArrayRef((*this)->getRegion({0}))";
3850 // If we have no regions, there is nothing more to do.
3851 const auto canSkip
= [](const NamedRegion
®ion
) {
3852 return region
.constraint
.getPredicate().isNull();
3854 auto regions
= op
.getRegions();
3855 if (regions
.empty() && llvm::all_of(regions
, canSkip
))
3858 body
<< " {\n unsigned index = 0; (void)index;\n";
3859 for (const auto &it
: llvm::enumerate(regions
)) {
3860 const auto ®ion
= it
.value();
3861 if (canSkip(region
))
3864 auto getRegion
= region
.isVariadic()
3865 ? formatv("{0}()", op
.getGetterName(region
.name
)).str()
3866 : formatv(getSingleRegion
, it
.index()).str();
3868 staticVerifierEmitter
.getRegionConstraintFn(region
.constraint
);
3869 body
<< formatv(verifyRegion
, getRegion
, constraintFn
, region
.name
);
// Emits verification code into 'body' for every successor whose constraint
// has a predicate: each successor (or each successor of a variadic group) is
// passed to the successor constraint function from staticVerifierEmitter, and
// a failure makes the generated method return ::mlir::failure().
// NOTE(review): this copy is missing several lines of the definition
// (raw-string terminator, 'continue', the 'getSuccessor'/'constraintFn'
// declaration heads, the final arguments and the closing braces); restore
// them from upstream before editing.
3874 void OpEmitter::genSuccessorVerifier(MethodBody
&body
) {
3875 const char *const verifySuccessor
= R
"(
3876 for (auto *successor : {0})
3877 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
3878 return ::mlir::failure();
3880 /// Get a single successor.
3882 /// {0}: The successor's name.
3883 const char *const getSingleSuccessor
= "::llvm::MutableArrayRef({0}())";
3885 // If we have no successors, there is nothing more to do.
3886 const auto canSkip
= [](const NamedSuccessor
&successor
) {
3887 return successor
.constraint
.getPredicate().isNull();
3889 auto successors
= op
.getSuccessors();
// NOTE(review): because all_of over an empty range is true, this '&&'
// condition reduces to successors.empty(); genOperandResultVerifier uses
// '||' for the same early-out. Confirm whether '||' was intended.
3890 if (successors
.empty() && llvm::all_of(successors
, canSkip
))
3893 body
<< " {\n unsigned index = 0; (void)index;\n";
3895 for (auto it
: llvm::enumerate(successors
)) {
3896 const auto &successor
= it
.value();
3897 if (canSkip(successor
))
3901 formatv(successor
.isVariadic() ? "{0}()" : getSingleSuccessor
,
3905 staticVerifierEmitter
.getSuccessorConstraintFn(successor
.constraint
);
3906 body
<< formatv(verifySuccessor
, getSuccessor
, constraintFn
,
3912 /// Add a size count trait to the given operation class.
3913 static void addSizeCountTrait(OpClass
&opClass
, StringRef traitKind
,
3914 int numTotal
, int numVariadic
) {
3915 if (numVariadic
!= 0) {
3916 if (numTotal
== numVariadic
)
3917 opClass
.addTrait("::mlir::OpTrait::Variadic" + traitKind
+ "s");
3919 opClass
.addTrait("::mlir::OpTrait::AtLeastN" + traitKind
+ "s<" +
3920 Twine(numTotal
- numVariadic
) + ">::Impl");
3925 opClass
.addTrait("::mlir::OpTrait::Zero" + traitKind
+ "s");
3928 opClass
.addTrait("::mlir::OpTrait::One" + traitKind
);
3931 opClass
.addTrait("::mlir::OpTrait::N" + traitKind
+ "s<" + Twine(numTotal
) +
3937 void OpEmitter::genTraits() {
3938 // Add region size trait.
3939 unsigned numRegions
= op
.getNumRegions();
3940 unsigned numVariadicRegions
= op
.getNumVariadicRegions();
3941 addSizeCountTrait(opClass
, "Region", numRegions
, numVariadicRegions
);
3943 // Add result size traits.
3944 int numResults
= op
.getNumResults();
3945 int numVariadicResults
= op
.getNumVariableLengthResults();
3946 addSizeCountTrait(opClass
, "Result", numResults
, numVariadicResults
);
3948 // For single result ops with a known specific type, generate a OneTypedResult
3950 if (numResults
== 1 && numVariadicResults
== 0) {
3951 auto cppName
= op
.getResults().begin()->constraint
.getCppType();
3952 opClass
.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName
+ ">::Impl");
3955 // Add successor size trait.
3956 unsigned numSuccessors
= op
.getNumSuccessors();
3957 unsigned numVariadicSuccessors
= op
.getNumVariadicSuccessors();
3958 addSizeCountTrait(opClass
, "Successor", numSuccessors
, numVariadicSuccessors
);
3960 // Add variadic size trait and normal op traits.
3961 int numOperands
= op
.getNumOperands();
3962 int numVariadicOperands
= op
.getNumVariableLengthOperands();
3964 // Add operand size trait.
3965 addSizeCountTrait(opClass
, "Operand", numOperands
, numVariadicOperands
);
3967 // The op traits defined internal are ensured that they can be verified
3969 for (const auto &trait
: op
.getTraits()) {
3970 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
3971 if (opTrait
->isStructuralOpTrait())
3972 opClass
.addTrait(opTrait
->getFullyQualifiedTraitName());
3976 // OpInvariants wrapps the verifyInvariants which needs to be run before
3977 // native/interface traits and after all the traits with `StructuralOpTrait`.
3978 opClass
.addTrait("::mlir::OpTrait::OpInvariants");
3980 if (emitHelper
.hasProperties())
3981 opClass
.addTrait("::mlir::BytecodeOpInterface::Trait");
3983 // Add the native and interface traits.
3984 for (const auto &trait
: op
.getTraits()) {
3985 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
3986 if (!opTrait
->isStructuralOpTrait())
3987 opClass
.addTrait(opTrait
->getFullyQualifiedTraitName());
3988 } else if (auto *opTrait
= dyn_cast
<tblgen::InterfaceTrait
>(&trait
)) {
3989 opClass
.addTrait(opTrait
->getFullyQualifiedTraitName());
// Adds a constexpr static 'getOperationName()' returning
// '::llvm::StringLiteral' whose generated body returns the op's full
// operation name.
// NOTE(review): the final statement is incomplete in this copy (the text
// streamed after the operation name and the closing brace are missing);
// restore it from upstream before editing.
3994 void OpEmitter::genOpNameGetter() {
3995 auto *method
= opClass
.addStaticMethod
<Method::Constexpr
>(
3996 "::llvm::StringLiteral", "getOperationName");
3997 ERROR_IF_PRUNED(method
, "getOperationName", op
);
3998 method
->body() << " return ::llvm::StringLiteral(\"" << op
.getOperationName()
4002 void OpEmitter::genOpAsmInterface() {
4003 // If the user only has one results or specifically added the Asm trait,
4004 // then don't generate it for them. We specifically only handle multi result
4005 // operations, because the name of a single result in the common case is not
4006 // interesting(generally 'result'/'output'/etc.).
4007 // TODO: We could also add a flag to allow operations to opt in to this
4008 // generation, even if they only have a single operation.
4009 int numResults
= op
.getNumResults();
4010 if (numResults
<= 1 || op
.getTrait("::mlir::OpAsmOpInterface::Trait"))
4013 SmallVector
<StringRef
, 4> resultNames(numResults
);
4014 for (int i
= 0; i
!= numResults
; ++i
)
4015 resultNames
[i
] = op
.getResultName(i
);
4017 // Don't add the trait if none of the results have a valid name.
4018 if (llvm::all_of(resultNames
, [](StringRef name
) { return name
.empty(); }))
4020 opClass
.addTrait("::mlir::OpAsmOpInterface::Trait");
4022 // Generate the right accessor for the number of results.
4023 auto *method
= opClass
.addMethod(
4024 "void", "getAsmResultNames",
4025 MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
4026 ERROR_IF_PRUNED(method
, "getAsmResultNames", op
);
4027 auto &body
= method
->body();
4028 for (int i
= 0; i
!= numResults
; ++i
) {
4029 body
<< " auto resultGroup" << i
<< " = getODSResults(" << i
<< ");\n"
4030 << " if (!resultGroup" << i
<< ".empty())\n"
4031 << " setNameFn(*resultGroup" << i
<< ".begin(), \""
4032 << resultNames
[i
] << "\");\n";
4036 //===----------------------------------------------------------------------===//
4037 // OpOperandAdaptor emitter
4038 //===----------------------------------------------------------------------===//
// Helper class to emit Op operand adaptors to an output stream. Operand
// adaptors are wrappers around random access ranges that provide named operand
// getters identical to those defined in the Op.
// This currently generates 3 classes per Op:
// * A Base class within the 'detail' namespace, which contains all logic and
//   members independent of the random access range that is indexed into.
//   In other words, it contains all the attribute and region getters.
// * A templated class named '{OpName}GenericAdaptor' with a template parameter
//   'RangeT' that is indexed into by the getters to access the operands.
//   It contains all getters to access operands and inherits from the previous
//   class.
// * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
//   with 'mlir::ValueRange' as template parameter. It adds a constructor from
//   an instance of the op type and a verify function.
// Builds the adaptor classes for one op. The static emitDecl/emitDef entry
// points each construct an emitter for 'op' and write the generated classes'
// declarations or definitions to an output stream.
// NOTE(review): this copy is missing several lines of the class (access
// specifiers, the trailing 'raw_ostream &os' parameters, the constructor's
// 'const Operator &op' parameter, the 'op' and 'adaptor' members that the
// constructor initializes, and the closing '};'); restore them from upstream
// before editing.
4055 class OpOperandAdaptorEmitter
{
// Writes the declarations of the generated adaptor classes.
4058 emitDecl(const Operator
&op
,
4059 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
// Writes the definitions of the generated adaptor classes.
4062 emitDef(const Operator
&op
,
4063 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
// Builds all adaptor classes for the op up front.
4067 explicit OpOperandAdaptorEmitter(
4069 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
4071 // Add verification function. This generates a verify method for the adaptor
4072 // which verifies all the op-independent attribute constraints.
4073 void addVerification();
4075 // The operation for which to emit an adaptor.
4078 // The generated adaptor classes.
4079 Class genericAdaptorBase
;
4080 Class genericAdaptor
;
4083 // The emitter containing all of the locally emitted verification functions.
4084 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
;
4086 // Helper for emitting adaptor code.
4087 OpOrAdaptorHelper emitHelper
;
// Constructs the emitter and builds the adaptor classes for 'op':
//  * '<GenericAdaptorName>Base': fields for the attribute dictionary, an
//    optional OperationName, the properties and the regions, plus their
//    accessors; when the op uses properties, also a nested 'Properties'
//    struct with typed data members, get/set accessors and
//    operator==/operator!=;
//  * '<GenericAdaptorName>' (template over 'RangeT'): derives from the base,
//    stores the operand range 'odsOperands' and provides the operand getters;
//  * '<AdaptorName>': derives from the generic adaptor instantiated with
//    '::mlir::ValueRange' and adds a constructor from the op.
// NOTE(review): many lines of this definition are missing from this copy
// (closing braces, 'else' branches, parts of the accessor format strings and
// the trailing finalization); restore them from upstream before editing.
4091 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
4093 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
)
4094 : op(op
), genericAdaptorBase(op
.getGenericAdaptorName() + "Base"),
4095 genericAdaptor(op
.getGenericAdaptorName()), adaptor(op
.getAdaptorName()),
4096 staticVerifierEmitter(staticVerifierEmitter
),
4097 emitHelper(op
, /*emitForOp=*/false) {
4099 genericAdaptorBase
.declare
<VisibilityDeclaration
>(Visibility::Public
);
4100 bool useProperties
= emitHelper
.hasProperties();
// With properties, emit a nested 'Properties' struct into the base class,
// covering non-derived attributes, ODS properties and any operand/result
// segment-size attributes.
4101 if (useProperties
) {
4102 // Define the properties struct with multiple members.
4103 using ConstArgument
=
4104 llvm::PointerUnion
<const AttributeMetadata
*, const NamedProperty
*>;
4105 SmallVector
<ConstArgument
> attrOrProperties
;
4106 for (const std::pair
<StringRef
, AttributeMetadata
> &it
:
4107 emitHelper
.getAttrMetadata()) {
4108 if (!it
.second
.constraint
|| !it
.second
.constraint
->isDerivedAttr())
4109 attrOrProperties
.push_back(&it
.second
);
4111 for (const NamedProperty
&prop
: op
.getProperties())
4112 attrOrProperties
.push_back(&prop
);
4113 if (emitHelper
.getOperandSegmentsSize())
4114 attrOrProperties
.push_back(&emitHelper
.getOperandSegmentsSize().value());
4115 if (emitHelper
.getResultSegmentsSize())
4116 attrOrProperties
.push_back(&emitHelper
.getResultSegmentsSize().value());
4117 assert(!attrOrProperties
.empty());
// 'declarations' accumulates the struct body; 'comparator' accumulates the
// member-wise operator== that is appended to it.
4118 std::string declarations
= " struct Properties {\n";
4119 llvm::raw_string_ostream
os(declarations
);
4120 std::string comparator
=
4121 " bool operator==(const Properties &rhs) const {\n"
4123 llvm::raw_string_ostream
comparatorOs(comparator
);
4124 for (const auto &attrOrProp
: attrOrProperties
) {
4125 if (const auto *namedProperty
=
4126 llvm::dyn_cast_if_present
<const NamedProperty
*>(attrOrProp
)) {
4127 StringRef name
= namedProperty
->name
;
4129 report_fatal_error("missing name for property");
4130 std::string camelName
=
4131 convertToCamelFromSnakeCase(name
, /*capitalizeFirst=*/true);
4132 auto &prop
= namedProperty
->prop
;
4133 // Generate the data member using the storage type.
4134 os
<< " using " << name
<< "Ty = " << prop
.getStorageType() << ";\n"
4135 << " " << name
<< "Ty " << name
;
4136 if (prop
.hasStorageTypeValueOverride())
4137 os
<< " = " << prop
.getStorageTypeValueOverride();
4138 else if (prop
.hasDefaultValue())
4139 os
<< " = " << prop
.getDefaultValue();
4140 comparatorOs
<< " rhs." << name
<< " == this->" << name
4142 // Emit accessors using the interface type.
4143 const char *accessorFmt
= R
"decl(;
4144 {0} get{1}() const {
4145 auto &propStorage = this->{2};
4148 void set{1}({0} propValue) {
4149 auto &propStorage = this->{2};
4154 os
<< formatv(accessorFmt
, prop
.getInterfaceType(), camelName
, name
,
4155 tgfmt(prop
.getConvertFromStorageCall(),
4156 &fctx
.addSubst("_storage", propertyStorage
)),
4157 tgfmt(prop
.getAssignToStorageCall(),
4158 &fctx
.addSubst("_value", propertyValue
)
4159 .addSubst("_storage", propertyStorage
)));
4162 const auto *namedAttr
=
4163 llvm::dyn_cast_if_present
<const AttributeMetadata
*>(attrOrProp
);
4164 const Attribute
*attr
= nullptr;
4165 if (namedAttr
->constraint
)
4166 attr
= &*namedAttr
->constraint
;
4167 StringRef name
= namedAttr
->attrName
;
4169 report_fatal_error("missing name for property attr");
4170 std::string camelName
=
4171 convertToCamelFromSnakeCase(name
, /*capitalizeFirst=*/true);
4172 // Generate the data member using the storage type.
4173 StringRef storageType
;
4175 storageType
= attr
->getStorageType();
4177 if (name
!= operandSegmentAttrName
&& name
!= resultSegmentAttrName
) {
4178 report_fatal_error("unexpected AttributeMetadata");
4180 // TODO: update to use native integers.
4181 storageType
= "::mlir::DenseI32ArrayAttr";
4183 os
<< " using " << name
<< "Ty = " << storageType
<< ";\n"
4184 << " " << name
<< "Ty " << name
<< ";\n";
4185 comparatorOs
<< " rhs." << name
<< " == this->" << name
<< " &&\n";
4187 // Emit accessors using the interface type.
4189 const char *accessorFmt
= R
"decl(
4191 auto &propStorage = this->{1};
4192 return ::llvm::{2}<{3}>(propStorage);
4194 void set{0}(const {3} &propValue) {
4195 this->{1} = propValue;
4198 os
<< formatv(accessorFmt
, camelName
, name
,
4199 attr
->isOptional() || attr
->hasDefaultValue()
4200 ? "dyn_cast_or_null"
4205 comparatorOs
<< " true;\n }\n"
4206 " bool operator!=(const Properties &rhs) const {\n"
4207 " return !(*this == rhs);\n"
4212 genericAdaptorBase
.declare
<ExtraClassDeclaration
>(std::move(declarations
));
// Template-independent state shared by all instantiations of the adaptor.
4214 genericAdaptorBase
.declare
<VisibilityDeclaration
>(Visibility::Protected
);
4215 genericAdaptorBase
.declare
<Field
>("::mlir::DictionaryAttr", "odsAttrs");
4216 genericAdaptorBase
.declare
<Field
>("::std::optional<::mlir::OperationName>",
4219 genericAdaptorBase
.declare
<Field
>("Properties", "properties");
4220 genericAdaptorBase
.declare
<Field
>("::mlir::RegionRange", "odsRegions");
// The generic adaptor is a template over the operand range type 'RangeT' and
// derives from the base class in the 'detail' namespace.
4222 genericAdaptor
.addTemplateParam("RangeT");
4223 genericAdaptor
.addField("RangeT", "odsOperands");
4224 genericAdaptor
.addParent(
4225 ParentClass("detail::" + genericAdaptorBase
.getClassName()));
4226 genericAdaptor
.declare
<UsingDeclaration
>(
4227 "ValueT", "::llvm::detail::ValueOfRange<RangeT>");
4228 genericAdaptor
.declare
<UsingDeclaration
>(
4229 "Base", "detail::" + genericAdaptorBase
.getClassName());
4231 const auto *attrSizedOperands
=
4232 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
// Parameters of the base-class constructor; the defaults differ depending on
// whether the op uses properties.
4234 SmallVector
<MethodParameter
> paramList
;
4235 if (useProperties
) {
4236 // Properties can't be given a default constructor here due to Properties
4237 // struct being defined in the enclosing class which isn't complete by
4239 paramList
.emplace_back("::mlir::DictionaryAttr", "attrs");
4240 paramList
.emplace_back("const Properties &", "properties");
4242 paramList
.emplace_back("::mlir::DictionaryAttr", "attrs", "{}");
4243 paramList
.emplace_back("const ::mlir::EmptyProperties &", "properties",
4246 paramList
.emplace_back("::mlir::RegionRange", "regions", "{}");
4247 auto *baseConstructor
=
4248 genericAdaptorBase
.addConstructor
<Method::Inline
>(paramList
);
4249 baseConstructor
->addMemberInitializer("odsAttrs", "attrs");
4251 baseConstructor
->addMemberInitializer("properties", "properties");
4252 baseConstructor
->addMemberInitializer("odsRegions", "regions");
// The generated constructor records the op name only when an attribute
// dictionary is available to supply the MLIRContext.
4254 MethodBody
&body
= baseConstructor
->body();
4255 body
.indent() << "if (odsAttrs)\n";
4256 body
.indent() << formatv(
4257 "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
4258 op
.getOperationName());
// The generic adaptor's constructor takes the operand range first and
// forwards the remaining arguments to the base.
4260 paramList
.insert(paramList
.begin(), MethodParameter("RangeT", "values"));
4261 auto *constructor
= genericAdaptor
.addConstructor(paramList
);
4262 constructor
->addMemberInitializer("Base", "attrs, properties, regions");
4263 constructor
->addMemberInitializer("odsOperands", "values");
4265 // Add a forwarding constructor to the previous one that accepts
4266 // OpaqueProperties instead and check for null and perform the cast to the
4267 // actual properties type.
4268 paramList
[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
4269 paramList
[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
4270 auto *opaquePropertiesConstructor
=
4271 genericAdaptor
.addConstructor(std::move(paramList
));
4272 if (useProperties
) {
4273 opaquePropertiesConstructor
->addMemberInitializer(
4274 genericAdaptor
.getClassName(),
4277 "(properties ? *properties.as<Properties *>() : Properties{}), "
4280 opaquePropertiesConstructor
->addMemberInitializer(
4281 genericAdaptor
.getClassName(),
4284 "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
4285 "::mlir::EmptyProperties{}), "
4289 // Add forwarding constructor that constructs Properties.
4290 if (useProperties
) {
4291 SmallVector
<MethodParameter
> paramList
;
4292 paramList
.emplace_back("RangeT", "values");
4293 paramList
.emplace_back("::mlir::DictionaryAttr", "attrs",
4294 attrSizedOperands
? "" : "nullptr");
4295 auto *noPropertiesConstructor
=
4296 genericAdaptor
.addConstructor(std::move(paramList
));
4297 noPropertiesConstructor
->addMemberInitializer(
4298 genericAdaptor
.getClassName(), "values, "
4305 // Create a constructor that creates a new generic adaptor by copying
4306 // everything from another adaptor, except for the values.
4308 SmallVector
<MethodParameter
> paramList
;
4309 paramList
.emplace_back("RangeT", "values");
4310 paramList
.emplace_back("const " + op
.getGenericAdaptorName() + "Base &",
4313 genericAdaptor
.addConstructor
<Method::Inline
>(paramList
);
4314 constructor
->addMemberInitializer("Base", "base");
4315 constructor
->addMemberInitializer("odsOperands", "values");
4318 // Create constructors constructing the adaptor from an instance of the op.
4319 // This takes the attributes, properties and regions from the op instance
4320 // and the value range from the parameter.
4322 // Base class is in the cpp file and can simply access the members of the op
4323 // class to initialize the template independent fields. If the op doesn't
4324 // have properties, we can emit a generic constructor inline. Otherwise,
4325 // emit it out-of-line because we need the op to be defined.
4326 Constructor
*constructor
;
4327 if (useProperties
) {
4328 constructor
= genericAdaptorBase
.addConstructor(
4329 MethodParameter(op
.getCppClassName(), "op"));
4331 constructor
= genericAdaptorBase
.addConstructor
<Method::Inline
>(
4332 MethodParameter("::mlir::Operation *", "op"));
4334 constructor
->addMemberInitializer("odsAttrs",
4335 "op->getRawDictionaryAttrs()");
4336 // Retrieve the operation name from the op directly.
4337 constructor
->addMemberInitializer("odsOpName", "op->getName()");
4339 constructor
->addMemberInitializer("properties", "op.getProperties()");
4340 constructor
->addMemberInitializer("odsRegions", "op->getRegions()");
4342 // Generic adaptor is templated and therefore defined inline in the header.
4343 // We cannot use the Op class here as it is an incomplete type (we have a
4344 // circular reference between the two).
4345 // Use a template trick to make the constructor be instantiated at call site
4346 // when the op class is complete.
4347 constructor
= genericAdaptor
.addConstructor(
4348 MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
4349 constructor
->addTemplateParam("LateInst = " + op
.getCppClassName());
4350 constructor
->addTemplateParam(
4351 "= std::enable_if_t<std::is_same_v<LateInst, " + op
.getCppClassName() +
4353 constructor
->addMemberInitializer("Base", "op");
4354 constructor
->addMemberInitializer("odsOperands", "values");
// Operand getters on the generic adaptor index into 'odsOperands'. Ops with
// AttrSizedOperandSegments additionally get segment-size initialization code,
// read from properties or from the attribute, depending on the dialect.
4357 std::string sizeAttrInit
;
4358 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
4359 if (op
.getDialect().usePropertiesForAttributes())
4361 formatv(adapterSegmentSizeAttrInitCodeProperties
,
4362 llvm::formatv("getProperties().operandSegmentSizes"));
4364 sizeAttrInit
= formatv(adapterSegmentSizeAttrInitCode
,
4365 emitHelper
.getAttr(operandSegmentAttrName
));
4367 generateNamedOperandGetters(op
, genericAdaptor
,
4368 /*genericAdaptorBase=*/&genericAdaptorBase
,
4369 /*sizeAttrInit=*/sizeAttrInit
,
4370 /*rangeType=*/"RangeT",
4371 /*rangeElementType=*/"ValueT",
4372 /*rangeBeginCall=*/"odsOperands.begin()",
4373 /*rangeSizeCall=*/"odsOperands.size()",
4374 /*getOperandCallPattern=*/"odsOperands[{0}]");
4376 // Any invalid overlap for `getOperands` will have been diagnosed before
4378 if (auto *m
= genericAdaptor
.addMethod("RangeT", "getOperands"))
4379 m
->body() << " return odsOperands;";
// Attribute default-value builders in the base class obtain their Builder
// from the context of 'odsAttrs'.
4382 fctx
.withBuilder("::mlir::Builder(odsAttrs.getContext())");
4384 // Generate named accessor with Attribute return type.
4385 auto emitAttrWithStorageType
= [&](StringRef name
, StringRef emitName
,
4387 // The method body is trivial if the attribute does not have a default
4388 // value, in which case the default value may be arbitrary code.
4389 auto *method
= genericAdaptorBase
.addMethod(
4390 attr
.getStorageType(), emitName
+ "Attr",
4391 attr
.hasDefaultValue() || !useProperties
? Method::Properties::None
4392 : Method::Properties::Inline
);
4393 ERROR_IF_PRUNED(method
, "Adaptor::" + emitName
+ "Attr", op
);
4394 auto &body
= method
->body().indent();
4396 body
<< "assert(odsAttrs && \"no attributes when constructing "
4399 "auto attr = ::llvm::{1}<{2}>({0});\n", emitHelper
.getAttr(name
),
4400 attr
.hasDefaultValue() || attr
.isOptional() ? "dyn_cast_or_null"
4402 attr
.getStorageType());
4404 if (attr
.hasDefaultValue() && attr
.isOptional()) {
4405 // Use the default value if attribute is not set.
4406 // TODO: this is inefficient, we are recreating the attribute for every
4407 // call. This should be set instead.
4408 std::string defaultValue
= std::string(
4409 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
4410 body
<< "if (!attr)\n attr = " << defaultValue
<< ";\n";
4412 body
<< "return attr;\n";
// Base-class accessors: properties (when used), the raw attribute
// dictionary, then one getter per ODS property, non-derived attribute and
// named region.
4415 if (useProperties
) {
4416 auto *m
= genericAdaptorBase
.addInlineMethod("const Properties &",
4418 ERROR_IF_PRUNED(m
, "Adaptor::getProperties", op
);
4419 m
->body() << " return properties;";
4422 auto *m
= genericAdaptorBase
.addInlineMethod("::mlir::DictionaryAttr",
4424 ERROR_IF_PRUNED(m
, "Adaptor::getAttributes", op
);
4425 m
->body() << " return odsAttrs;";
4427 for (auto &namedProp
: op
.getProperties()) {
4428 std::string name
= op
.getGetterName(namedProp
.name
);
4429 emitPropGetter(genericAdaptorBase
, op
, name
, namedProp
.prop
);
4432 for (auto &namedAttr
: op
.getAttributes()) {
4433 const auto &name
= namedAttr
.name
;
4434 const auto &attr
= namedAttr
.attr
;
4435 if (attr
.isDerivedAttr())
4437 std::string emitName
= op
.getGetterName(name
);
4438 emitAttrWithStorageType(name
, emitName
, attr
);
4439 emitAttrGetterWithReturnType(fctx
, genericAdaptorBase
, op
, emitName
, attr
);
4442 unsigned numRegions
= op
.getNumRegions();
4443 for (unsigned i
= 0; i
< numRegions
; ++i
) {
// NOTE(review): '®ion' on the next line appears to be a mis-encoded
// '&region'.
4444 const auto ®ion
= op
.getRegion(i
);
4445 if (region
.name
.empty())
4448 // Generate the accessors for a variadic region.
4449 std::string name
= op
.getGetterName(region
.name
);
4450 if (region
.isVariadic()) {
4451 auto *m
= genericAdaptorBase
.addInlineMethod("::mlir::RegionRange", name
);
4452 ERROR_IF_PRUNED(m
, "Adaptor::" + name
, op
);
4453 m
->body() << formatv(" return odsRegions.drop_front({0});", i
);
4457 auto *m
= genericAdaptorBase
.addInlineMethod("::mlir::Region &", name
);
4458 ERROR_IF_PRUNED(m
, "Adaptor::" + name
, op
);
4459 m
->body() << formatv(" return *odsRegions[{0}];", i
);
4461 if (numRegions
> 0) {
4462 // Any invalid overlap for `getRegions` will have been diagnosed before
4464 if (auto *m
= genericAdaptorBase
.addInlineMethod("::mlir::RegionRange",
4466 m
->body() << " return odsRegions;";
// The concrete adaptor instantiates the generic adaptor with
// '::mlir::ValueRange' and re-exports its constructors via a using-declaration.
4469 StringRef genericAdaptorClassName
= genericAdaptor
.getClassName();
4470 adaptor
.addParent(ParentClass(genericAdaptorClassName
))
4471 .addTemplateParam("::mlir::ValueRange");
4472 adaptor
.declare
<VisibilityDeclaration
>(Visibility::Public
);
4473 adaptor
.declare
<UsingDeclaration
>(genericAdaptorClassName
+
4474 "::" + genericAdaptorClassName
);
4476 // Constructor taking the Op as single parameter.
4478 adaptor
.addConstructor(MethodParameter(op
.getCppClassName(), "op"));
4479 constructor
->addMemberInitializer(genericAdaptorClassName
,
4480 "op->getOperands(), op");
4483 // Add verification function.
4486 genericAdaptorBase
.finalize();
4487 genericAdaptor
.finalize();
4491 void OpOperandAdaptorEmitter::addVerification() {
4492 auto *method
= adaptor
.addMethod("::llvm::LogicalResult", "verify",
4493 MethodParameter("::mlir::Location", "loc"));
4494 ERROR_IF_PRUNED(method
, "verify", op
);
4495 auto &body
= method
->body();
4496 bool useProperties
= emitHelper
.hasProperties();
4498 FmtContext verifyCtx
;
4499 populateSubstitutions(emitHelper
, verifyCtx
);
4500 genAttributeVerifier(emitHelper
, verifyCtx
, body
, staticVerifierEmitter
,
4503 body
<< " return ::mlir::success();";
4506 void OpOperandAdaptorEmitter::emitDecl(
4508 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
4510 OpOperandAdaptorEmitter
emitter(op
, staticVerifierEmitter
);
4512 NamespaceEmitter
ns(os
, "detail");
4513 emitter
.genericAdaptorBase
.writeDeclTo(os
);
4515 emitter
.genericAdaptor
.writeDeclTo(os
);
4516 emitter
.adaptor
.writeDeclTo(os
);
4519 void OpOperandAdaptorEmitter::emitDef(
4521 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
4523 OpOperandAdaptorEmitter
emitter(op
, staticVerifierEmitter
);
4525 NamespaceEmitter
ns(os
, "detail");
4526 emitter
.genericAdaptorBase
.writeDefTo(os
);
4528 emitter
.genericAdaptor
.writeDefTo(os
);
4529 emitter
.adaptor
.writeDefTo(os
);
4532 /// Emit the class declarations or definitions for the given op defs.
4534 emitOpClasses(const RecordKeeper
&records
,
4535 const std::vector
<const Record
*> &defs
, raw_ostream
&os
,
4536 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
4541 for (auto *def
: defs
) {
4545 NamespaceEmitter
emitter(os
, op
.getCppNamespace());
4546 os
<< formatv(opCommentHeader
, op
.getQualCppClassName(),
4548 OpOperandAdaptorEmitter::emitDecl(op
, staticVerifierEmitter
, os
);
4549 OpEmitter::emitDecl(op
, os
, staticVerifierEmitter
);
4551 // Emit the TypeID explicit specialization to have a single definition.
4552 if (!op
.getCppNamespace().empty())
4553 os
<< "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op
.getCppNamespace()
4554 << "::" << op
.getCppClassName() << ")\n\n";
4557 NamespaceEmitter
emitter(os
, op
.getCppNamespace());
4558 os
<< formatv(opCommentHeader
, op
.getQualCppClassName(), "definitions");
4559 OpOperandAdaptorEmitter::emitDef(op
, staticVerifierEmitter
, os
);
4560 OpEmitter::emitDef(op
, os
, staticVerifierEmitter
);
4562 // Emit the TypeID explicit specialization to have a single definition.
4563 if (!op
.getCppNamespace().empty())
4564 os
<< "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op
.getCppNamespace()
4565 << "::" << op
.getCppClassName() << ")\n\n";
4570 /// Emit the declarations for the provided op classes.
4571 static void emitOpClassDecls(const RecordKeeper
&records
,
4572 const std::vector
<const Record
*> &defs
,
4574 // First emit forward declaration for each class, this allows them to refer
4575 // to each others in traits for example.
4576 for (auto *def
: defs
) {
4578 NamespaceEmitter
emitter(os
, op
.getCppNamespace());
4579 os
<< "class " << op
.getCppClassName() << ";\n";
4582 // Emit the op class declarations.
4583 IfDefScope
scope("GET_OP_CLASSES", os
);
4586 StaticVerifierFunctionEmitter
staticVerifierEmitter(os
, records
);
4587 staticVerifierEmitter
.collectOpConstraints(defs
);
4588 emitOpClasses(records
, defs
, os
, staticVerifierEmitter
,
4592 /// Emit the definitions for the provided op classes.
4593 static void emitOpClassDefs(const RecordKeeper
&records
,
4594 ArrayRef
<const Record
*> defs
, raw_ostream
&os
,
4595 StringRef constraintPrefix
= "") {
4599 // Generate all of the locally instantiated methods first.
4600 StaticVerifierFunctionEmitter
staticVerifierEmitter(os
, records
,
4602 os
<< formatv(opCommentHeader
, "Local Utility Method", "Definitions");
4603 staticVerifierEmitter
.collectOpConstraints(defs
);
4604 staticVerifierEmitter
.emitOpConstraints(defs
);
4606 // Emit the classes.
4607 emitOpClasses(records
, defs
, os
, staticVerifierEmitter
,
4608 /*emitDecl=*/false);
4611 /// Emit op declarations for all op records.
4612 static bool emitOpDecls(const RecordKeeper
&records
, raw_ostream
&os
) {
4613 emitSourceFileHeader("Op Declarations", os
, records
);
4615 std::vector
<const Record
*> defs
= getRequestedOpDefinitions(records
);
4616 emitOpClassDecls(records
, defs
, os
);
4618 // If we are generating sharded op definitions, emit the sharded op
4619 // registration hooks.
4620 SmallVector
<ArrayRef
<const Record
*>, 4> shardedDefs
;
4621 shardOpDefinitions(defs
, shardedDefs
);
4622 if (defs
.empty() || shardedDefs
.size() <= 1)
4625 Dialect dialect
= Operator(defs
.front()).getDialect();
4626 NamespaceEmitter
ns(os
, dialect
);
4628 const char *const opRegistrationHook
=
4629 "void register{0}Operations{1}({2}::{0} *dialect);\n";
4630 os
<< formatv(opRegistrationHook
, dialect
.getCppClassName(), "",
4631 dialect
.getCppNamespace());
4632 for (unsigned i
= 0; i
< shardedDefs
.size(); ++i
) {
4633 os
<< formatv(opRegistrationHook
, dialect
.getCppClassName(), i
,
4634 dialect
.getCppNamespace());
4640 /// Generate the dialect op registration hook and the op class definitions for a
4642 static void emitOpDefShard(const RecordKeeper
&records
,
4643 ArrayRef
<const Record
*> defs
,
4644 const Dialect
&dialect
, unsigned shardIndex
,
4645 unsigned shardCount
, raw_ostream
&os
) {
4646 std::string shardGuard
= "GET_OP_DEFS_";
4647 std::string indexStr
= std::to_string(shardIndex
);
4648 shardGuard
+= indexStr
;
4649 IfDefScope
scope(shardGuard
, os
);
4651 // Emit the op registration hook in the first shard.
4652 const char *const opRegistrationHook
=
4653 "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
4654 if (shardIndex
== 0) {
4655 os
<< formatv(opRegistrationHook
, dialect
.getCppNamespace(),
4656 dialect
.getCppClassName(), "");
4657 for (unsigned i
= 0; i
< shardCount
; ++i
) {
4658 os
<< formatv(" {0}::register{1}Operations{2}(dialect);\n",
4659 dialect
.getCppNamespace(), dialect
.getCppClassName(), i
);
4664 // Generate the per-shard op registration hook.
4665 os
<< formatv(opCommentHeader
, dialect
.getCppClassName(),
4666 "Op Registration Hook")
4667 << formatv(opRegistrationHook
, dialect
.getCppNamespace(),
4668 dialect
.getCppClassName(), shardIndex
);
4669 for (const Record
*def
: defs
) {
4670 os
<< formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
4671 Operator(def
).getQualCppClassName());
4675 // Generate the per-shard op definitions.
4676 emitOpClassDefs(records
, defs
, os
, indexStr
);
4679 /// Emit op definitions for all op records.
4680 static bool emitOpDefs(const RecordKeeper
&records
, raw_ostream
&os
) {
4681 emitSourceFileHeader("Op Definitions", os
, records
);
4683 std::vector
<const Record
*> defs
= getRequestedOpDefinitions(records
);
4684 SmallVector
<ArrayRef
<const Record
*>, 4> shardedDefs
;
4685 shardOpDefinitions(defs
, shardedDefs
);
4687 // If no shard was requested, emit the regular op list and class definitions.
4688 if (shardedDefs
.size() == 1) {
4690 IfDefScope
scope("GET_OP_LIST", os
);
4693 [&](const Record
*def
) { os
<< Operator(def
).getQualCppClassName(); },
4697 IfDefScope
scope("GET_OP_CLASSES", os
);
4698 emitOpClassDefs(records
, defs
, os
);
4705 Dialect dialect
= Operator(defs
.front()).getDialect();
4706 for (auto [idx
, value
] : llvm::enumerate(shardedDefs
)) {
4707 emitOpDefShard(records
, value
, dialect
, idx
, shardedDefs
.size(), os
);
4712 static mlir::GenRegistration
4713 genOpDecls("gen-op-decls", "Generate op declarations",
4714 [](const RecordKeeper
&records
, raw_ostream
&os
) {
4715 return emitOpDecls(records
, os
);
4718 static mlir::GenRegistration
genOpDefs("gen-op-defs", "Generate op definitions",
4719 [](const RecordKeeper
&records
,
4721 return emitOpDefs(records
, os
);