1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
12 //===----------------------------------------------------------------------===//
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Class.h"
20 #include "mlir/TableGen/CodeGenHelpers.h"
21 #include "mlir/TableGen/Format.h"
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Interfaces.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Property.h"
26 #include "mlir/TableGen/SideEffects.h"
27 #include "mlir/TableGen/Trait.h"
28 #include "llvm/ADT/BitVector.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/Sequence.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSet.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/TableGen/Error.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
42 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
46 using namespace mlir::tblgen
;
// Identifier fragments that this generator splices into the emitted C++ code.
// Each is a compile-time constant pointer to a string literal.
static constexpr const char *tblgenNamePrefix = "tblgen_";
static constexpr const char *generatedArgName = "odsArg";
static constexpr const char *odsBuilder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";
static constexpr const char *propertyStorage = "propStorage";
static constexpr const char *propertyValue = "propValue";
static constexpr const char *propertyAttr = "propAttr";
static constexpr const char *propertyDiag = "emitError";

// Names of the implicit attributes that hold the sizes of the variadic operand
// segments and the variadic result segments, respectively.
static constexpr const char *operandSegmentAttrName = "operandSegmentSizes";
static constexpr const char *resultSegmentAttrName = "resultSegmentSizes";
62 /// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
65 /// {0}: Code snippet to get the attribute's name or identifier.
66 /// {1}: The lower bound on the sorted subrange.
67 /// {2}: The upper bound on the sorted subrange.
68 /// {3}: Code snippet to get the array of named attributes.
69 /// {4}: "Named" to get the named attribute.
70 static const char *const subrangeGetAttr
=
71 "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
74 /// The logic to calculate the actual value range for a declared operand/result
75 /// of an op with variadic operands/results. Note that this logic is not for
76 /// general use; it assumes all variadic operands/results must have the same
79 /// {0}: The list of whether each declared operand/result is variadic.
80 /// {1}: The total number of non-variadic operands/results.
81 /// {2}: The total number of variadic operands/results.
82 /// {3}: The total number of actual values.
83 /// {4}: "operand" or "result".
84 static const char *const sameVariadicSizeValueRangeCalcCode
= R
"(
85 bool isVariadic[] = {{{0}};
86 int prevVariadicCount = 0;
87 for (unsigned i = 0; i < index; ++i)
88 if (isVariadic[i]) ++prevVariadicCount;
90 // Calculate how many dynamic values a static variadic {4} corresponds to.
91 // This assumes all static variadic {4}s have the same dynamic value count.
92 int variadicSize = ({3} - {1}) / {2};
93 // `index` passed in as the parameter is the static index which counts each
94 // {4} (variadic or not) as size 1. So here for each previous static variadic
95 // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
96 // value pack for this static {4} starts.
97 int start = index + (variadicSize - 1) * prevVariadicCount;
98 int size = isVariadic[index] ? variadicSize : 1;
99 return {{start, size};
102 /// The logic to calculate the actual value range for a declared operand/result
103 /// of an op with variadic operands/results. Note that this logic is assumes
104 /// the op has an attribute specifying the size of each operand/result segment
105 /// (variadic or not).
106 static const char *const attrSizedSegmentValueRangeCalcCode
= R
"(
108 for (unsigned i = 0; i < index; ++i)
109 start += sizeAttr[i];
110 return {start, sizeAttr[index]};
112 /// The code snippet to initialize the sizes for the value range calculation.
114 /// {0}: The code to get the attribute.
115 static const char *const adapterSegmentSizeAttrInitCode
= R
"(
116 assert({0} && "missing segment size attribute
for op
");
117 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
119 static const char *const adapterSegmentSizeAttrInitCodeProperties
= R
"(
120 ::llvm::ArrayRef<int32_t> sizeAttr = {0};
123 /// The code snippet to initialize the sizes for the value range calculation.
125 /// {0}: The code to get the attribute.
126 static const char *const opSegmentSizeAttrInitCode
= R
"(
127 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
130 /// The logic to calculate the actual value range for a declared operand
131 /// of an op with variadic of variadic operands within the OpAdaptor.
133 /// {0}: The name of the segment attribute.
134 /// {1}: The index of the main operand.
135 /// {2}: The range type of adaptor.
136 static const char *const variadicOfVariadicAdaptorCalcCode
= R
"(
137 auto tblgenTmpOperands = getODSOperands({1});
140 ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
141 for (int i = 0, e = sizes.size(); i < e; ++i) {{
142 tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
143 tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
145 return tblgenTmpOperandGroups;
148 /// The logic to build a range of either operand or result values.
150 /// {0}: The begin iterator of the actual values.
151 /// {1}: The call to generate the start and length of the value range.
152 static const char *const valueRangeReturnCode
= R
"(
153 auto valueRange = {1};
154 return {{std::next({0}, valueRange.first),
155 std::next({0}, valueRange.first + valueRange.second)};
158 /// Read operand/result segment_size from bytecode.
159 static const char *const readBytecodeSegmentSizeNative
= R
"(
160 if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
161 return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
164 static const char *const readBytecodeSegmentSizeLegacy
= R
"(
165 if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
166 auto &$_storage = prop.$_propName;
167 ::mlir::DenseI32ArrayAttr attr;
168 if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
169 if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
170 $_reader.emitError("size mismatch
for operand
/result_segment_size
");
171 return ::mlir::failure();
173 ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
177 /// Write operand/result segment_size to bytecode.
178 static const char *const writeBytecodeSegmentSizeNative
= R
"(
179 if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
180 $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
183 /// Write operand/result segment_size to bytecode.
184 static const char *const writeBytecodeSegmentSizeLegacy
= R
"(
185 if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
186 auto &$_storage = prop.$_propName;
187 $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
191 /// A header for indicating code sections.
193 /// {0}: Some text, or a class name.
195 static const char *const opCommentHeader
= R
"(
196 //===----------------------------------------------------------------------===//
198 //===----------------------------------------------------------------------===//
202 //===----------------------------------------------------------------------===//
203 // Utility structs and functions
204 //===----------------------------------------------------------------------===//
/// Replaces every non-overlapping occurrence of `match` in `str` with
/// `substitute`, scanning left to right, and returns the resulting string.
///
/// Text inserted by a substitution is never rescanned, so a `substitute` that
/// itself contains `match` does not expand repeatedly. An empty `match` leaves
/// `str` unchanged: searching for "" would match at every position and the
/// scan would never terminate.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    // replace() edits `str` in place; no reassignment is needed.
    str.replace(matchLoc, match.size(), substitute);
    // Resume right after the inserted text so it is not matched again.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
217 // Returns whether the record has a value of the given name that can be returned
218 // via getValueAsString.
219 static inline bool hasStringAttribute(const Record
&record
,
220 StringRef fieldName
) {
221 auto *valueInit
= record
.getValueInit(fieldName
);
222 return isa
<StringInit
>(valueInit
);
225 static std::string
getArgumentName(const Operator
&op
, int index
) {
226 const auto &operand
= op
.getOperand(index
);
227 if (!operand
.name
.empty())
228 return std::string(operand
.name
);
229 return std::string(formatv("{0}_{1}", generatedArgName
, index
));
232 // Returns true if we can use unwrapped value for the given `attr` in builders.
233 static bool canUseUnwrappedRawValue(const tblgen::Attribute
&attr
) {
234 return attr
.getReturnType() != attr
.getStorageType() &&
235 // We need to wrap the raw value into an attribute in the builder impl
236 // so we need to make sure that the attribute specifies how to do that.
237 !attr
.getConstBuilderTemplate().empty();
240 /// Build an attribute from a parameter value using the constant builder.
/// The attribute's constant builder template is formatted with tgfmt, with
/// `paramName` as the sole positional substitution (presumably filling `$0`).
/// The resulting C++ expression is returned as a string.
241 static std::string
constBuildAttrFromParam(const tblgen::Attribute
&attr
,
// NOTE(review): original line 242 is missing from this view. The body passes
// `&fctx` to tgfmt, so a format-context parameter named `fctx` presumably
// belongs here - confirm against the full file.
243 StringRef paramName
) {
244 std::string builderTemplate
= attr
.getConstBuilderTemplate().str();
246 // For StringAttr, its constant builder call will wrap the input in
247 // quotes, which is correct for normal string literals, but incorrect
248 // here given we use function arguments. So we need to strip the
// quotes around the `$0` placeholder so the argument is referenced directly.
250 if (StringRef(builderTemplate
).contains("\"$0\""))
251 builderTemplate
= replaceAllSubstrs(builderTemplate
, "\"$0\"", "$0");
253 return tgfmt(builderTemplate
, &fctx
, paramName
).str();
// NOTE(review): the closing brace of this function (original line 254) is not
// visible in this view.
257 /// Metadata on a registered attribute. Given that attributes are stored in
258 /// sorted order on operations, we can use information from ODS to deduce the
259 /// number of required attributes less than and greater than each attribute,
260 /// allowing us to search only a subrange of the attributes in ODS-generated
/// code.
262 struct AttributeMetadata
{
263 /// The attribute name.
// NOTE(review): this field's declaration (original line 264) is missing from
// this view. Other code reads `.attrName` and compares it with StringRef
// values, so it is presumably `StringRef attrName;` - confirm.
265 /// Whether the attribute is required.
// NOTE(review): original line 266 is missing. Other code reads `.isRequired`
// and adds it to an unsigned counter, so it is presumably
// `bool isRequired;` - confirm.
267 /// The ODS attribute constraint. Not present for implicit attributes.
268 std::optional
<Attribute
> constraint
;
269 /// The number of required attributes less than this attribute.
270 unsigned lowerBound
= 0;
271 /// The number of required attributes greater than this attribute.
272 unsigned upperBound
= 0;
// NOTE(review): the closing "};" of this struct (original line 273) is not
// visible in this view.
275 /// Helper class to select between OpAdaptor and Op code templates.
276 class OpOrAdaptorHelper
{
278 OpOrAdaptorHelper(const Operator
&op
, bool emitForOp
)
279 : op(op
), emitForOp(emitForOp
) {
280 computeAttrMetadata();
283 /// Object that wraps a functor in a stream operator for interop with
287 template <typename Functor
>
288 Formatter(Functor
&&func
) : func(std::forward
<Functor
>(func
)) {}
290 std::string
str() const {
292 llvm::raw_string_ostream
os(result
);
298 std::function
<raw_ostream
&(raw_ostream
&)> func
;
300 friend raw_ostream
&operator<<(raw_ostream
&os
, const Formatter
&fmt
) {
305 // Generate code for getting an attribute.
306 Formatter
getAttr(StringRef attrName
, bool isNamed
= false) const {
307 assert(attrMetadata
.count(attrName
) && "expected attribute metadata");
308 return [this, attrName
, isNamed
](raw_ostream
&os
) -> raw_ostream
& {
309 const AttributeMetadata
&attr
= attrMetadata
.find(attrName
)->second
;
310 if (hasProperties()) {
312 return os
<< "getProperties()." << attrName
;
314 return os
<< formatv(subrangeGetAttr
, getAttrName(attrName
),
315 attr
.lowerBound
, attr
.upperBound
, getAttrRange(),
316 isNamed
? "Named" : "");
320 // Generate code for getting the name of an attribute.
321 Formatter
getAttrName(StringRef attrName
) const {
322 return [this, attrName
](raw_ostream
&os
) -> raw_ostream
& {
324 return os
<< op
.getGetterName(attrName
) << "AttrName()";
325 return os
<< formatv("{0}::{1}AttrName(*odsOpName)", op
.getCppClassName(),
326 op
.getGetterName(attrName
));
330 // Get the code snippet for getting the named attribute range.
331 StringRef
getAttrRange() const {
332 return emitForOp
? "(*this)->getAttrs()" : "odsAttrs";
335 // Get the prefix code for emitting an error.
336 Formatter
emitErrorPrefix() const {
337 return [this](raw_ostream
&os
) -> raw_ostream
& {
339 return os
<< "emitOpError(";
340 return os
<< formatv("emitError(loc, \"'{0}' op \"",
341 op
.getOperationName());
345 // Get the call to get an operand or segment of operands.
346 Formatter
getOperand(unsigned index
) const {
347 return [this, index
](raw_ostream
&os
) -> raw_ostream
& {
348 return os
<< formatv(op
.getOperand(index
).isVariadic()
349 ? "this->getODSOperands({0})"
350 : "(*this->getODSOperands({0}).begin())",
355 // Get the call to get a result of segment of results.
356 Formatter
getResult(unsigned index
) const {
357 return [this, index
](raw_ostream
&os
) -> raw_ostream
& {
359 return os
<< "<no results should be generated>";
360 return os
<< formatv(op
.getResult(index
).isVariadic()
361 ? "this->getODSResults({0})"
362 : "(*this->getODSResults({0}).begin())",
367 // Return whether an op instance is available.
368 bool isEmittingForOp() const { return emitForOp
; }
370 // Return the ODS operation wrapper.
371 const Operator
&getOp() const { return op
; }
373 // Get the attribute metadata sorted by name.
374 const llvm::MapVector
<StringRef
, AttributeMetadata
> &getAttrMetadata() const {
378 /// Returns whether to emit a `Properties` struct for this operation or not.
379 bool hasProperties() const {
380 if (!op
.getProperties().empty())
382 if (!op
.getDialect().usePropertiesForAttributes())
384 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") ||
385 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
387 return llvm::any_of(getAttrMetadata(),
388 [](const std::pair
<StringRef
, AttributeMetadata
> &it
) {
389 return !it
.second
.constraint
||
390 !it
.second
.constraint
->isDerivedAttr();
394 std::optional
<NamedProperty
> &getOperandSegmentsSize() {
395 return operandSegmentsSize
;
398 std::optional
<NamedProperty
> &getResultSegmentsSize() {
399 return resultSegmentsSize
;
402 uint32_t getOperandSegmentSizesLegacyIndex() {
403 return operandSegmentSizesLegacyIndex
;
406 uint32_t getResultSegmentSizesLegacyIndex() {
407 return resultSegmentSizesLegacyIndex
;
411 // Compute the attribute metadata.
412 void computeAttrMetadata();
414 // The operation ODS wrapper.
416 // True if code is being generate for an op. False for an adaptor.
417 const bool emitForOp
;
419 // The attribute metadata, mapped by name.
420 llvm::MapVector
<StringRef
, AttributeMetadata
> attrMetadata
;
423 std::optional
<NamedProperty
> operandSegmentsSize
;
424 std::string operandSegmentsSizeStorage
;
425 std::optional
<NamedProperty
> resultSegmentsSize
;
426 std::string resultSegmentsSizeStorage
;
428 // Indices to store the position in the emission order of the operand/result
429 // segment sizes attribute if emitted as part of the properties for legacy
430 // bytecode encodings, i.e. versions less than 6.
431 uint32_t operandSegmentSizesLegacyIndex
= 0;
432 uint32_t resultSegmentSizesLegacyIndex
= 0;
434 // The number of required attributes.
435 unsigned numRequired
;
440 void OpOrAdaptorHelper::computeAttrMetadata() {
441 // Enumerate the attribute names of this op, ensuring the attribute names are
442 // unique in case implicit attributes are explicitly registered.
443 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
444 Attribute attr
= namedAttr
.attr
;
446 attr
.hasDefaultValue() || attr
.isOptional() || attr
.isDerivedAttr();
448 {namedAttr
.name
, AttributeMetadata
{namedAttr
.name
, !isOptional
, attr
}});
451 auto makeProperty
= [&](StringRef storageType
) {
453 /*storageType=*/storageType
,
454 /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
455 /*convertFromStorageCall=*/"$_storage",
456 /*assignToStorageCall=*/
457 "::llvm::copy($_value, $_storage.begin())",
458 /*convertToAttributeCall=*/
459 "::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage)",
460 /*convertFromAttributeCall=*/
461 "return convertFromAttribute($_storage, $_attr, $_diag);",
462 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative
,
463 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative
,
464 /*hashPropertyCall=*/
465 "::llvm::hash_combine_range(std::begin($_storage), "
466 "std::end($_storage));",
467 /*StringRef defaultValue=*/"");
469 // Include key attributes from several traits as implicitly registered.
470 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
471 if (op
.getDialect().usePropertiesForAttributes()) {
472 operandSegmentsSizeStorage
=
473 llvm::formatv("std::array<int32_t, {0}>", op
.getNumOperands());
474 operandSegmentsSize
= {"operandSegmentSizes",
475 makeProperty(operandSegmentsSizeStorage
)};
478 {operandSegmentAttrName
, AttributeMetadata
{operandSegmentAttrName
,
480 /*attr=*/std::nullopt
}});
483 if (op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
484 if (op
.getDialect().usePropertiesForAttributes()) {
485 resultSegmentsSizeStorage
=
486 llvm::formatv("std::array<int32_t, {0}>", op
.getNumResults());
487 resultSegmentsSize
= {"resultSegmentSizes",
488 makeProperty(resultSegmentsSizeStorage
)};
491 {resultSegmentAttrName
,
492 AttributeMetadata
{resultSegmentAttrName
, /*isRequired=*/true,
493 /*attr=*/std::nullopt
}});
497 // Store the metadata in sorted order.
498 SmallVector
<AttributeMetadata
> sortedAttrMetadata
=
499 llvm::to_vector(llvm::make_second_range(attrMetadata
.takeVector()));
500 llvm::sort(sortedAttrMetadata
,
501 [](const AttributeMetadata
&lhs
, const AttributeMetadata
&rhs
) {
502 return lhs
.attrName
< rhs
.attrName
;
505 // Store the position of the legacy operand_segment_sizes /
506 // result_segment_sizes so we can emit a backward compatible property readers
508 StringRef legacyOperandSegmentSizeName
=
509 StringLiteral("operand_segment_sizes");
510 StringRef legacyResultSegmentSizeName
= StringLiteral("result_segment_sizes");
511 operandSegmentSizesLegacyIndex
= 0;
512 resultSegmentSizesLegacyIndex
= 0;
513 for (auto item
: sortedAttrMetadata
) {
514 if (item
.attrName
< legacyOperandSegmentSizeName
)
515 ++operandSegmentSizesLegacyIndex
;
516 if (item
.attrName
< legacyResultSegmentSizeName
)
517 ++resultSegmentSizesLegacyIndex
;
520 // Compute the subrange bounds for each attribute.
522 for (AttributeMetadata
&attr
: sortedAttrMetadata
) {
523 attr
.lowerBound
= numRequired
;
524 numRequired
+= attr
.isRequired
;
526 for (AttributeMetadata
&attr
: sortedAttrMetadata
)
527 attr
.upperBound
= numRequired
- attr
.lowerBound
- attr
.isRequired
;
529 // Store the results back into the map.
530 for (const AttributeMetadata
&attr
: sortedAttrMetadata
)
531 attrMetadata
.insert({attr
.attrName
, attr
});
534 //===----------------------------------------------------------------------===//
536 //===----------------------------------------------------------------------===//
539 // Helper class to emit a record into the given output stream.
541 using ConstArgument
=
542 llvm::PointerUnion
<const AttributeMetadata
*, const NamedProperty
*>;
546 emitDecl(const Operator
&op
, raw_ostream
&os
,
547 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
549 emitDef(const Operator
&op
, raw_ostream
&os
,
550 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
553 OpEmitter(const Operator
&op
,
554 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
556 void emitDecl(raw_ostream
&os
);
557 void emitDef(raw_ostream
&os
);
559 // Generate methods for accessing the attribute names of this operation.
560 void genAttrNameGetters();
562 // Generates the OpAsmOpInterface for this operation if possible.
563 void genOpAsmInterface();
565 // Generates the `getOperationName` method for this op.
566 void genOpNameGetter();
568 // Generates code to manage the properties, if any!
569 void genPropertiesSupport();
571 // Generates code to manage the encoding of properties to bytecode.
573 genPropertiesSupportForBytecode(ArrayRef
<ConstArgument
> attrOrProperties
);
575 // Generates getters for the attributes.
576 void genAttrGetters();
578 // Generates setter for the attributes.
579 void genAttrSetters();
581 // Generates removers for optional attributes.
582 void genOptionalAttrRemovers();
584 // Generates getters for named operands.
585 void genNamedOperandGetters();
587 // Generates setters for named operands.
588 void genNamedOperandSetters();
590 // Generates getters for named results.
591 void genNamedResultGetters();
593 // Generates getters for named regions.
594 void genNamedRegionGetters();
596 // Generates getters for named successors.
597 void genNamedSuccessorGetters();
599 // Generates the method to populate default attributes.
600 void genPopulateDefaultAttributes();
602 // Generates builder methods for the operation.
605 // Generates the build() method that takes each operand/attribute
606 // as a stand-alone parameter.
607 void genSeparateArgParamBuilder();
609 // Generates the build() method that takes each operand/attribute as a
610 // stand-alone parameter. The generated build() method uses first operand's
611 // type as all results' types.
612 void genUseOperandAsResultTypeSeparateParamBuilder();
614 // Generates the build() method that takes all operands/attributes
615 // collectively as one parameter. The generated build() method uses first
616 // operand's type as all results' types.
617 void genUseOperandAsResultTypeCollectiveParamBuilder();
619 // Generates the build() method that takes aggregate operands/attributes
620 // parameters. This build() method uses inferred types as result types.
621 // Requires: The type needs to be inferable via InferTypeOpInterface.
622 void genInferredTypeCollectiveParamBuilder();
624 // Generates the build() method that takes each operand/attribute as a
625 // stand-alone parameter. The generated build() method uses first attribute's
626 // type as all result's types.
627 void genUseAttrAsResultTypeBuilder();
629 // Generates the build() method that takes all result types collectively as
630 // one parameter. Similarly for operands and attributes.
631 void genCollectiveParamBuilder();
633 // The kind of parameter to generate for result types in builders.
634 enum class TypeParamKind
{
635 None
, // No result type in parameter list.
636 Separate
, // A separate parameter for each result type.
637 Collective
, // An ArrayRef<Type> for all result types.
640 // The kind of parameter to generate for attributes in builders.
641 enum class AttrParamKind
{
642 WrappedAttr
, // A wrapped MLIR Attribute instance.
643 UnwrappedValue
, // A raw value without MLIR Attribute wrapper.
646 // Builds the parameter list for build() method of this op. This method writes
647 // to `paramList` the comma-separated parameter list and updates
648 // `resultTypeNames` with the names for parameters for specifying result
649 // types. `inferredAttributes` is populated with any attributes that are
650 // elided from the build list. The given `typeParamKind` and `attrParamKind`
651 // controls how result types and attributes are placed in the parameter list.
652 void buildParamList(SmallVectorImpl
<MethodParameter
> ¶mList
,
653 llvm::StringSet
<> &inferredAttributes
,
654 SmallVectorImpl
<std::string
> &resultTypeNames
,
655 TypeParamKind typeParamKind
,
656 AttrParamKind attrParamKind
= AttrParamKind::WrappedAttr
);
658 // Adds op arguments and regions into operation state for build() methods.
660 genCodeForAddingArgAndRegionForBuilder(MethodBody
&body
,
661 llvm::StringSet
<> &inferredAttributes
,
662 bool isRawValueAttr
= false);
664 // Generates canonicalizer declaration for the operation.
665 void genCanonicalizerDecls();
667 // Generates the folder declaration for the operation.
668 void genFolderDecls();
670 // Generates the parser for the operation.
673 // Generates the printer for the operation.
676 // Generates verify method for the operation.
679 // Generates custom verify methods for the operation.
680 void genCustomVerifier();
682 // Generates verify statements for operands and results in the operation.
683 // The generated code will be attached to `body`.
684 void genOperandResultVerifier(MethodBody
&body
,
685 Operator::const_value_range values
,
686 StringRef valueKind
);
688 // Generates verify statements for regions in the operation.
689 // The generated code will be attached to `body`.
690 void genRegionVerifier(MethodBody
&body
);
692 // Generates verify statements for successors in the operation.
693 // The generated code will be attached to `body`.
694 void genSuccessorVerifier(MethodBody
&body
);
696 // Generates the traits used by the object.
699 // Generate the OpInterface methods for all interfaces.
700 void genOpInterfaceMethods();
702 // Generate op interface methods for the given interface.
703 void genOpInterfaceMethods(const tblgen::InterfaceTrait
*trait
);
705 // Generate op interface method for the given interface method. If
706 // 'declaration' is true, generates a declaration, else a definition.
707 Method
*genOpInterfaceMethod(const tblgen::InterfaceMethod
&method
,
708 bool declaration
= true);
710 // Generate the side effect interface methods.
711 void genSideEffectInterfaceMethods();
713 // Generate the type inference interface methods.
714 void genTypeInterfaceMethods();
717 // The TableGen record for this op.
718 // TODO: OpEmitter should not have a Record directly,
719 // it should rather go through the Operator for better abstraction.
722 // The wrapper operator class for querying information from this op.
725 // The C++ code builder for this op
728 // The format context for verification code generation.
729 FmtContext verifyCtx
;
731 // The emitter containing all of the locally emitted verification functions.
732 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
;
734 // Helper for emitting op code.
735 OpOrAdaptorHelper emitHelper
;
740 // Populate the format context `ctx` with substitutions of attributes, operands
742 static void populateSubstitutions(const OpOrAdaptorHelper
&emitHelper
,
744 // Populate substitutions for attributes.
745 auto &op
= emitHelper
.getOp();
746 for (const auto &namedAttr
: op
.getAttributes())
747 ctx
.addSubst(namedAttr
.name
,
748 emitHelper
.getOp().getGetterName(namedAttr
.name
) + "()");
750 // Populate substitutions for named operands.
751 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
) {
752 auto &value
= op
.getOperand(i
);
753 if (!value
.name
.empty())
754 ctx
.addSubst(value
.name
, emitHelper
.getOperand(i
).str());
757 // Populate substitutions for results.
758 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
759 auto &value
= op
.getResult(i
);
760 if (!value
.name
.empty())
761 ctx
.addSubst(value
.name
, emitHelper
.getResult(i
).str());
765 /// Generate verification on native traits requiring attributes.
766 static void genNativeTraitAttrVerifier(MethodBody
&body
,
767 const OpOrAdaptorHelper
&emitHelper
) {
768 // Check that the variadic segment sizes attribute exists and contains the
769 // expected number of elements.
771 // {0}: Attribute name.
772 // {1}: Expected number of elements.
773 // {2}: "operand" or "result".
774 // {3}: Emit error prefix.
775 const char *const checkAttrSizedValueSegmentsCode
= R
"(
777 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0});
778 auto numElements = sizeAttr.asArrayRef().size();
779 if (numElements != {1})
780 return {3}"'{0}' attribute
for specifying
{2} segments must have
{1} "
781 "elements
, but got
") << numElements;
785 // Verify a few traits first so that we can use getODSOperands() and
786 // getODSResults() in the rest of the verifier.
787 auto &op
= emitHelper
.getOp();
788 if (!op
.getDialect().usePropertiesForAttributes()) {
789 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
790 body
<< formatv(checkAttrSizedValueSegmentsCode
, operandSegmentAttrName
,
791 op
.getNumOperands(), "operand",
792 emitHelper
.emitErrorPrefix());
794 if (op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
795 body
<< formatv(checkAttrSizedValueSegmentsCode
, resultSegmentAttrName
,
796 op
.getNumResults(), "result",
797 emitHelper
.emitErrorPrefix());
802 // Return true if a verifier can be emitted for the attribute: it is not a
803 // derived attribute, it has a predicate, its condition is not empty, and, for
804 // adaptors, the condition does not reference the op.
805 static bool canEmitAttrVerifier(Attribute attr
, bool isEmittingForOp
) {
806 if (attr
.isDerivedAttr())
// NOTE(review): the statement guarded by this `if` (original line 807) is
// missing from this view. Per the header comment it presumably rejects
// derived attributes (`return false;`) - confirm.
808 Pred pred
= attr
.getPredicate();
// NOTE(review): original lines 809-810 are missing. Per the header comment
// ("it has a predicate") they presumably bail out when `pred` is empty -
// confirm against the full file.
811 std::string condition
= pred
.getCondition();
// A condition that mentions `$_op` needs an op instance, which is only
// available when emitting for the op itself rather than for an adaptor.
812 return !condition
.empty() &&
813 (!StringRef(condition
).contains("$_op") || isEmittingForOp
);
// NOTE(review): the closing brace of this function (original line 814) is not
// visible in this view.
816 // Generate attribute verification. If an op instance is not available, then
817 // attribute checks that require one will not be emitted.
819 // Attribute verification is performed as follows:
821 // 1. Verify that all required attributes are present in sorted order. This
822 // ensures that we can use subrange lookup even with potentially missing
824 // 2. Verify native trait attributes so that other attributes may call methods
825 // that depend on the validity of these attributes, e.g. segment size attributes
826 // and operand or result getters.
827 // 3. Verify the constraints on all present attributes.
829 genAttributeVerifier(const OpOrAdaptorHelper
&emitHelper
, FmtContext
&ctx
,
831 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
832 bool useProperties
) {
833 if (emitHelper
.getAttrMetadata().empty())
836 // Verify the attribute if it is present. This assumes that default values
837 // are valid. This code snippet pastes the condition inline.
839 // TODO: verify the default value is valid (perhaps in debug mode only).
841 // {0}: Attribute variable name.
842 // {1}: Attribute condition code.
843 // {2}: Emit error prefix.
844 // {3}: Attribute name.
845 // {4}: Attribute/constraint description.
846 const char *const verifyAttrInline
= R
"(
848 return {2}"attribute
'{3}' failed to satisfy constraint
: {4}");
850 // Verify the attribute using a uniqued constraint. Can only be used within
851 // the context of an op.
853 // {0}: Unique constraint name.
854 // {1}: Attribute variable name.
855 // {2}: Attribute name.
856 const char *const verifyAttrUnique
= R
"(
857 if (::mlir::failed({0}(*this, {1}, "{2}")))
858 return ::mlir::failure();
861 // Traverse the array until the required attribute is found. Return an error
862 // if the traversal reached the end.
864 // {0}: Code to get the name of the attribute.
865 // {1}: The emit error prefix.
866 // {2}: The name of the attribute.
867 const char *const findRequiredAttr
= R
"(
869 if (namedAttrIt == namedAttrRange.end())
870 return {1}"requires attribute
'{2}'");
871 if (namedAttrIt->getName() == {0}) {{
872 tblgen_{2} = namedAttrIt->getValue();
876 // Emit a check to see if the iteration has encountered an optional attribute.
878 // {0}: Code to get the name of the attribute.
879 // {1}: The name of the attribute.
880 const char *const checkOptionalAttr
= R
"(
881 else if (namedAttrIt->getName() == {0}) {{
882 tblgen_{1} = namedAttrIt->getValue();
885 // Emit the start of the loop for checking trailing attributes.
886 const char *const checkTrailingAttrs
= R
"(while (true) {
887 if (namedAttrIt == namedAttrRange.end()) {
891 // Emit the verifier for the attribute.
892 const auto emitVerifier
= [&](Attribute attr
, StringRef attrName
,
894 std::string condition
= attr
.getPredicate().getCondition();
896 std::optional
<StringRef
> constraintFn
;
897 if (emitHelper
.isEmittingForOp() &&
898 (constraintFn
= staticVerifierEmitter
.getAttrConstraintFn(attr
))) {
899 body
<< formatv(verifyAttrUnique
, *constraintFn
, varName
, attrName
);
901 body
<< formatv(verifyAttrInline
, varName
,
902 tgfmt(condition
, &ctx
.withSelf(varName
)),
903 emitHelper
.emitErrorPrefix(), attrName
,
904 escapeString(attr
.getSummary()));
908 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
909 const auto getVarName
= [&](StringRef attrName
) {
910 return (tblgenNamePrefix
+ attrName
).str();
915 for (const std::pair
<StringRef
, AttributeMetadata
> &it
:
916 emitHelper
.getAttrMetadata()) {
917 const AttributeMetadata
&metadata
= it
.second
;
918 if (metadata
.constraint
&& metadata
.constraint
->isDerivedAttr())
921 "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
923 if (metadata
.isRequired
)
925 "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
926 it
.first
, emitHelper
.emitErrorPrefix());
929 body
<< formatv("auto namedAttrRange = {0};\n", emitHelper
.getAttrRange());
930 body
<< "auto namedAttrIt = namedAttrRange.begin();\n";
932 // Iterate over the attributes in sorted order. Keep track of the optional
933 // attributes that may be encountered along the way.
934 SmallVector
<const AttributeMetadata
*> optionalAttrs
;
936 for (const std::pair
<StringRef
, AttributeMetadata
> &it
:
937 emitHelper
.getAttrMetadata()) {
938 const AttributeMetadata
&metadata
= it
.second
;
939 if (!metadata
.isRequired
) {
940 optionalAttrs
.push_back(&metadata
);
944 body
<< formatv("::mlir::Attribute {0};\n", getVarName(it
.first
));
945 for (const AttributeMetadata
*optional
: optionalAttrs
) {
946 body
<< formatv("::mlir::Attribute {0};\n",
947 getVarName(optional
->attrName
));
949 body
<< formatv(findRequiredAttr
, emitHelper
.getAttrName(it
.first
),
950 emitHelper
.emitErrorPrefix(), it
.first
);
951 for (const AttributeMetadata
*optional
: optionalAttrs
) {
952 body
<< formatv(checkOptionalAttr
,
953 emitHelper
.getAttrName(optional
->attrName
),
956 body
<< "\n ++namedAttrIt;\n}\n";
957 optionalAttrs
.clear();
959 // Get trailing optional attributes.
960 if (!optionalAttrs
.empty()) {
961 for (const AttributeMetadata
*optional
: optionalAttrs
) {
962 body
<< formatv("::mlir::Attribute {0};\n",
963 getVarName(optional
->attrName
));
965 body
<< checkTrailingAttrs
;
966 for (const AttributeMetadata
*optional
: optionalAttrs
) {
967 body
<< formatv(checkOptionalAttr
,
968 emitHelper
.getAttrName(optional
->attrName
),
971 body
<< "\n ++namedAttrIt;\n}\n";
976 // Emit the checks for segment attributes first so that the other
977 // constraints can call operand and result getters.
978 genNativeTraitAttrVerifier(body
, emitHelper
);
980 bool isEmittingForOp
= emitHelper
.isEmittingForOp();
981 for (const auto &namedAttr
: emitHelper
.getOp().getAttributes())
982 if (canEmitAttrVerifier(namedAttr
.attr
, isEmittingForOp
))
983 emitVerifier(namedAttr
.attr
, namedAttr
.name
, getVarName(namedAttr
.name
));
986 /// Include declarations specified on NativeTrait
987 static std::string
formatExtraDeclarations(const Operator
&op
) {
988 SmallVector
<StringRef
> extraDeclarations
;
989 // Include extra class declarations from NativeTrait
990 for (const auto &trait
: op
.getTraits()) {
991 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
992 StringRef value
= opTrait
->getExtraConcreteClassDeclaration();
995 extraDeclarations
.push_back(value
);
998 extraDeclarations
.push_back(op
.getExtraClassDeclaration());
999 return llvm::join(extraDeclarations
, "\n");
1002 /// Op extra class definitions have a `$cppClass` substitution that is to be
1003 /// replaced by the C++ class name.
1004 /// Include declarations specified on NativeTrait
1005 static std::string
formatExtraDefinitions(const Operator
&op
) {
1006 SmallVector
<StringRef
> extraDefinitions
;
1007 // Include extra class definitions from NativeTrait
1008 for (const auto &trait
: op
.getTraits()) {
1009 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
1010 StringRef value
= opTrait
->getExtraConcreteClassDefinition();
1013 extraDefinitions
.push_back(value
);
1016 extraDefinitions
.push_back(op
.getExtraClassDefinition());
1017 FmtContext ctx
= FmtContext().addSubst("cppClass", op
.getCppClassName());
1018 return tgfmt(llvm::join(extraDefinitions
, "\n"), &ctx
).str();
1021 OpEmitter::OpEmitter(const Operator
&op
,
1022 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
)
1023 : def(op
.getDef()), op(op
),
1024 opClass(op
.getCppClassName(), formatExtraDeclarations(op
),
1025 formatExtraDefinitions(op
)),
1026 staticVerifierEmitter(staticVerifierEmitter
),
1027 emitHelper(op
, /*emitForOp=*/true) {
1028 verifyCtx
.addSubst("_op", "(*this->getOperation())");
1029 verifyCtx
.addSubst("_ctxt", "this->getOperation()->getContext()");
1033 // Generate C++ code for various op methods. The order here determines the
1034 // methods in the generated file.
1035 genAttrNameGetters();
1036 genOpAsmInterface();
1038 genNamedOperandGetters();
1039 genNamedOperandSetters();
1040 genNamedResultGetters();
1041 genNamedRegionGetters();
1042 genNamedSuccessorGetters();
1043 genPropertiesSupport();
1046 genOptionalAttrRemovers();
1048 genPopulateDefaultAttributes();
1052 genCustomVerifier();
1053 genCanonicalizerDecls();
1055 genTypeInterfaceMethods();
1056 genOpInterfaceMethods();
1057 generateOpFormat(op
, opClass
);
1058 genSideEffectInterfaceMethods();
1060 void OpEmitter::emitDecl(
1061 const Operator
&op
, raw_ostream
&os
,
1062 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
) {
1063 OpEmitter(op
, staticVerifierEmitter
).emitDecl(os
);
1066 void OpEmitter::emitDef(
1067 const Operator
&op
, raw_ostream
&os
,
1068 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
) {
1069 OpEmitter(op
, staticVerifierEmitter
).emitDef(os
);
1072 void OpEmitter::emitDecl(raw_ostream
&os
) {
1074 opClass
.writeDeclTo(os
);
1077 void OpEmitter::emitDef(raw_ostream
&os
) {
1079 opClass
.writeDefTo(os
);
1082 static void errorIfPruned(size_t line
, Method
*m
, const Twine
&methodName
,
1083 const Operator
&op
) {
1086 PrintFatalError(op
.getLoc(), "Unexpected overlap when generating `" +
1087 methodName
+ "` for " +
1088 op
.getOperationName() + " (from line " +
1092 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
1094 void OpEmitter::genAttrNameGetters() {
1095 const llvm::MapVector
<StringRef
, AttributeMetadata
> &attributes
=
1096 emitHelper
.getAttrMetadata();
1097 bool hasOperandSegmentsSize
=
1098 op
.getDialect().usePropertiesForAttributes() &&
1099 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1100 // Emit the getAttributeNames method.
1102 auto *method
= opClass
.addStaticInlineMethod(
1103 "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
1104 ERROR_IF_PRUNED(method
, "getAttributeNames", op
);
1105 auto &body
= method
->body();
1106 if (!hasOperandSegmentsSize
&& attributes
.empty()) {
1107 body
<< " return {};";
1108 // Nothing else to do if there are no registered attributes. Exit early.
1111 body
<< " static ::llvm::StringRef attrNames[] = {";
1112 llvm::interleaveComma(llvm::make_first_range(attributes
), body
,
1113 [&](StringRef attrName
) {
1114 body
<< "::llvm::StringRef(\"" << attrName
<< "\")";
1116 if (hasOperandSegmentsSize
) {
1117 if (!attributes
.empty())
1119 body
<< "::llvm::StringRef(\"" << operandSegmentAttrName
<< "\")";
1121 body
<< "};\n return ::llvm::ArrayRef(attrNames);";
1124 // Emit the getAttributeNameForIndex methods.
1126 auto *method
= opClass
.addInlineMethod
<Method::Private
>(
1127 "::mlir::StringAttr", "getAttributeNameForIndex",
1128 MethodParameter("unsigned", "index"));
1129 ERROR_IF_PRUNED(method
, "getAttributeNameForIndex", op
);
1131 << " return getAttributeNameForIndex((*this)->getName(), index);";
1134 auto *method
= opClass
.addStaticInlineMethod
<Method::Private
>(
1135 "::mlir::StringAttr", "getAttributeNameForIndex",
1136 MethodParameter("::mlir::OperationName", "name"),
1137 MethodParameter("unsigned", "index"));
1138 ERROR_IF_PRUNED(method
, "getAttributeNameForIndex", op
);
1140 if (attributes
.empty()) {
1141 method
->body() << " return {};";
1143 const char *const getAttrName
= R
"(
1144 assert(index < {0} && "invalid attribute index
");
1145 assert(name.getStringRef() == getOperationName() && "invalid operation name
");
1146 return name.getAttributeNames()[index];
1148 method
->body() << formatv(getAttrName
, attributes
.size());
1152 // Generate the <attr>AttrName methods, that expose the attribute names to
1154 const char *attrNameMethodBody
= " return getAttributeNameForIndex({0});";
1155 for (auto [index
, attr
] :
1156 llvm::enumerate(llvm::make_first_range(attributes
))) {
1157 std::string name
= op
.getGetterName(attr
);
1158 std::string methodName
= name
+ "AttrName";
1160 // Generate the non-static variant.
1162 auto *method
= opClass
.addInlineMethod("::mlir::StringAttr", methodName
);
1163 ERROR_IF_PRUNED(method
, methodName
, op
);
1164 method
->body() << llvm::formatv(attrNameMethodBody
, index
);
1167 // Generate the static variant.
1169 auto *method
= opClass
.addStaticInlineMethod(
1170 "::mlir::StringAttr", methodName
,
1171 MethodParameter("::mlir::OperationName", "name"));
1172 ERROR_IF_PRUNED(method
, methodName
, op
);
1173 method
->body() << llvm::formatv(attrNameMethodBody
,
1174 "name, " + Twine(index
));
1177 if (hasOperandSegmentsSize
) {
1178 std::string name
= op
.getGetterName(operandSegmentAttrName
);
1179 std::string methodName
= name
+ "AttrName";
1180 // Generate the non-static variant.
1182 auto *method
= opClass
.addInlineMethod("::mlir::StringAttr", methodName
);
1183 ERROR_IF_PRUNED(method
, methodName
, op
);
1185 << " return (*this)->getName().getAttributeNames().back();";
1188 // Generate the static variant.
1190 auto *method
= opClass
.addStaticInlineMethod(
1191 "::mlir::StringAttr", methodName
,
1192 MethodParameter("::mlir::OperationName", "name"));
1193 ERROR_IF_PRUNED(method
, methodName
, op
);
1194 method
->body() << " return name.getAttributeNames().back();";
1199 // Emit the getter for an attribute with the return type specified.
1200 // It is templated to be shared between the Op and the adaptor class.
1201 template <typename OpClassOrAdaptor
>
1202 static void emitAttrGetterWithReturnType(FmtContext
&fctx
,
1203 OpClassOrAdaptor
&opClass
,
1204 const Operator
&op
, StringRef name
,
1206 auto *method
= opClass
.addMethod(attr
.getReturnType(), name
);
1207 ERROR_IF_PRUNED(method
, name
, op
);
1208 auto &body
= method
->body();
1209 body
<< " auto attr = " << name
<< "Attr();\n";
1210 if (attr
.hasDefaultValue() && attr
.isOptional()) {
1211 // Returns the default value if not set.
1212 // TODO: this is inefficient, we are recreating the attribute for every
1213 // call. This should be set instead.
1214 if (!attr
.isConstBuildable()) {
1215 PrintFatalError("DefaultValuedAttr of type " + attr
.getAttrDefName() +
1216 " must have a constBuilder");
1218 std::string defaultValue
= std::string(
1219 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
1220 body
<< " if (!attr)\n return "
1221 << tgfmt(attr
.getConvertFromStorageCall(),
1222 &fctx
.withSelf(defaultValue
))
1226 << tgfmt(attr
.getConvertFromStorageCall(), &fctx
.withSelf("attr"))
1230 void OpEmitter::genPropertiesSupport() {
1231 if (!emitHelper
.hasProperties())
1234 SmallVector
<ConstArgument
> attrOrProperties
;
1235 for (const std::pair
<StringRef
, AttributeMetadata
> &it
:
1236 emitHelper
.getAttrMetadata()) {
1237 if (!it
.second
.constraint
|| !it
.second
.constraint
->isDerivedAttr())
1238 attrOrProperties
.push_back(&it
.second
);
1240 for (const NamedProperty
&prop
: op
.getProperties())
1241 attrOrProperties
.push_back(&prop
);
1242 if (emitHelper
.getOperandSegmentsSize())
1243 attrOrProperties
.push_back(&emitHelper
.getOperandSegmentsSize().value());
1244 if (emitHelper
.getResultSegmentsSize())
1245 attrOrProperties
.push_back(&emitHelper
.getResultSegmentsSize().value());
1246 if (attrOrProperties
.empty())
1248 auto &setPropMethod
=
1251 "::mlir::LogicalResult", "setPropertiesFromAttr",
1252 MethodParameter("Properties &", "prop"),
1253 MethodParameter("::mlir::Attribute", "attr"),
1255 "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1258 auto &getPropMethod
=
1260 .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr",
1261 MethodParameter("::mlir::MLIRContext *", "ctx"),
1262 MethodParameter("const Properties &", "prop"))
1266 .addStaticMethod("llvm::hash_code", "computePropertiesHash",
1267 MethodParameter("const Properties &", "prop"))
1269 auto &getInherentAttrMethod
=
1271 .addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr",
1272 MethodParameter("::mlir::MLIRContext *", "ctx"),
1273 MethodParameter("const Properties &", "prop"),
1274 MethodParameter("llvm::StringRef", "name"))
1276 auto &setInherentAttrMethod
=
1278 .addStaticMethod("void", "setInherentAttr",
1279 MethodParameter("Properties &", "prop"),
1280 MethodParameter("llvm::StringRef", "name"),
1281 MethodParameter("mlir::Attribute", "value"))
1283 auto &populateInherentAttrsMethod
=
1285 .addStaticMethod("void", "populateInherentAttrs",
1286 MethodParameter("::mlir::MLIRContext *", "ctx"),
1287 MethodParameter("const Properties &", "prop"),
1288 MethodParameter("::mlir::NamedAttrList &", "attrs"))
1290 auto &verifyInherentAttrsMethod
=
1293 "::mlir::LogicalResult", "verifyInherentAttrs",
1294 MethodParameter("::mlir::OperationName", "opName"),
1295 MethodParameter("::mlir::NamedAttrList &", "attrs"),
1297 "llvm::function_ref<::mlir::InFlightDiagnostic()>",
1301 opClass
.declare
<UsingDeclaration
>("Properties", "FoldAdaptor::Properties");
1303 // Convert the property to the attribute form.
1305 setPropMethod
<< R
"decl(
1306 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1308 emitError() << "expected DictionaryAttr to set properties
";
1309 return ::mlir::failure();
1312 // TODO: properties might be optional as well.
1313 const char *propFromAttrFmt
= R
"decl(;
1315 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1316 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
1321 emitError() << "expected key entry
for {1} in DictionaryAttr to set
"
1323 return ::mlir::failure();
1325 if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError)))
1326 return ::mlir::failure();
1330 for (const auto &attrOrProp
: attrOrProperties
) {
1331 if (const auto *namedProperty
=
1332 llvm::dyn_cast_if_present
<const NamedProperty
*>(attrOrProp
)) {
1333 StringRef name
= namedProperty
->name
;
1334 auto &prop
= namedProperty
->prop
;
1337 std::string getAttr
;
1338 llvm::raw_string_ostream
os(getAttr
);
1339 os
<< " auto attr = dict.get(\"" << name
<< "\");";
1340 if (name
== operandSegmentAttrName
) {
1341 // Backward compat for now, TODO: Remove at some point.
1342 os
<< " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1344 if (name
== resultSegmentAttrName
) {
1345 // Backward compat for now, TODO: Remove at some point.
1346 os
<< " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1350 setPropMethod
<< formatv(propFromAttrFmt
,
1351 tgfmt(prop
.getConvertFromAttributeCall(),
1352 &fctx
.addSubst("_attr", propertyAttr
)
1353 .addSubst("_storage", propertyStorage
)
1354 .addSubst("_diag", propertyDiag
)),
1358 const auto *namedAttr
=
1359 llvm::dyn_cast_if_present
<const AttributeMetadata
*>(attrOrProp
);
1360 StringRef name
= namedAttr
->attrName
;
1361 std::string getAttr
;
1362 llvm::raw_string_ostream
os(getAttr
);
1363 os
<< " auto attr = dict.get(\"" << name
<< "\");";
1364 if (name
== operandSegmentAttrName
) {
1365 // Backward compat for now
1366 os
<< " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1368 if (name
== resultSegmentAttrName
) {
1369 // Backward compat for now
1370 os
<< " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1374 setPropMethod
<< formatv(R
"decl(
1376 auto &propStorage = prop.{0};
1378 if (attr || /*isRequired=*/{1}) {{
1380 emitError() << "expected key entry
for {0} in DictionaryAttr to set
"
1382 return ::mlir::failure();
1384 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1385 if (convertedAttr) {{
1386 propStorage = convertedAttr;
1388 emitError() << "Invalid attribute `
{0}` in property conversion
: " << attr;
1389 return ::mlir::failure();
1394 name
, namedAttr
->isRequired
, getAttr
);
1397 setPropMethod
<< " return ::mlir::success();\n";
1399 // Convert the attribute form to the property.
1401 getPropMethod
<< " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
1402 << " ::mlir::Builder odsBuilder{ctx};\n";
1403 const char *propToAttrFmt
= R
"decl(
1405 const auto &propStorage = prop.{0};
1406 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1410 for (const auto &attrOrProp
: attrOrProperties
) {
1411 if (const auto *namedProperty
=
1412 llvm::dyn_cast_if_present
<const NamedProperty
*>(attrOrProp
)) {
1413 StringRef name
= namedProperty
->name
;
1414 auto &prop
= namedProperty
->prop
;
1416 getPropMethod
<< formatv(
1417 propToAttrFmt
, name
,
1418 tgfmt(prop
.getConvertToAttributeCall(),
1419 &fctx
.addSubst("_ctxt", "ctx")
1420 .addSubst("_storage", propertyStorage
)));
1423 const auto *namedAttr
=
1424 llvm::dyn_cast_if_present
<const AttributeMetadata
*>(attrOrProp
);
1425 StringRef name
= namedAttr
->attrName
;
1426 getPropMethod
<< formatv(R
"decl(
1428 const auto &propStorage = prop.{0};
1430 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1436 getPropMethod
<< R
"decl(
1438 return odsBuilder.getDictionaryAttr(attrs);
1442 // Hashing for the property
1444 const char *propHashFmt
= R
"decl(
1445 auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
1449 for (const auto &attrOrProp
: attrOrProperties
) {
1450 if (const auto *namedProperty
=
1451 llvm::dyn_cast_if_present
<const NamedProperty
*>(attrOrProp
)) {
1452 StringRef name
= namedProperty
->name
;
1453 auto &prop
= namedProperty
->prop
;
1455 hashMethod
<< formatv(propHashFmt
, name
,
1456 tgfmt(prop
.getHashPropertyCall(),
1457 &fctx
.addSubst("_storage", propertyStorage
)));
1460 hashMethod
<< " return llvm::hash_combine(";
1461 llvm::interleaveComma(
1462 attrOrProperties
, hashMethod
, [&](const ConstArgument
&attrOrProp
) {
1463 if (const auto *namedProperty
=
1464 llvm::dyn_cast_if_present
<const NamedProperty
*>(attrOrProp
)) {
1465 hashMethod
<< "\n hash_" << namedProperty
->name
<< "(prop."
1466 << namedProperty
->name
<< ")";
1469 const auto *namedAttr
=
1470 llvm::dyn_cast_if_present
<const AttributeMetadata
*>(attrOrProp
);
1471 StringRef name
= namedAttr
->attrName
;
1472 hashMethod
<< "\n llvm::hash_value(prop." << name
1473 << ".getAsOpaquePointer())";
1475 hashMethod
<< ");\n";
1477 const char *getInherentAttrMethodFmt
= R
"decl(
1481 const char *setInherentAttrMethodFmt
= R
"decl(
1482 if (name == "{0}") {{
1483 prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
1487 const char *populateInherentAttrsMethodFmt
= R
"decl(
1488 if (prop.{0}) attrs.append("{0}", prop.{0});
1490 for (const auto &attrOrProp
: attrOrProperties
) {
1491 if (const auto *namedAttr
=
1492 llvm::dyn_cast_if_present
<const AttributeMetadata
*>(attrOrProp
)) {
1493 StringRef name
= namedAttr
->attrName
;
1494 getInherentAttrMethod
<< formatv(getInherentAttrMethodFmt
, name
);
1495 setInherentAttrMethod
<< formatv(setInherentAttrMethodFmt
, name
);
1496 populateInherentAttrsMethod
1497 << formatv(populateInherentAttrsMethodFmt
, name
);
1500 // The ODS segment size property is "special": we expose it as an attribute
1501 // even though it is a native property.
1502 const auto *namedProperty
= cast
<const NamedProperty
*>(attrOrProp
);
1503 StringRef name
= namedProperty
->name
;
1504 if (name
!= operandSegmentAttrName
&& name
!= resultSegmentAttrName
)
1506 auto &prop
= namedProperty
->prop
;
1508 fctx
.addSubst("_ctxt", "ctx");
1509 fctx
.addSubst("_storage", Twine("prop.") + name
);
1510 if (name
== operandSegmentAttrName
) {
1511 getInherentAttrMethod
1512 << formatv(" if (name == \"operand_segment_sizes\" || name == "
1514 operandSegmentAttrName
);
1516 getInherentAttrMethod
1517 << formatv(" if (name == \"result_segment_sizes\" || name == "
1519 resultSegmentAttrName
);
1521 getInherentAttrMethod
<< tgfmt(prop
.getConvertToAttributeCall(), &fctx
)
1524 if (name
== operandSegmentAttrName
) {
1525 setInherentAttrMethod
1526 << formatv(" if (name == \"operand_segment_sizes\" || name == "
1528 operandSegmentAttrName
);
1530 setInherentAttrMethod
1531 << formatv(" if (name == \"result_segment_sizes\" || name == "
1533 resultSegmentAttrName
);
1535 setInherentAttrMethod
<< formatv(R
"decl(
1536 auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value);
1537 if (!arrAttr) return;
1538 if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
1540 llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin());
1545 if (name
== operandSegmentAttrName
) {
1546 populateInherentAttrsMethod
1547 << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName
,
1548 tgfmt(prop
.getConvertToAttributeCall(), &fctx
));
1550 populateInherentAttrsMethod
1551 << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName
,
1552 tgfmt(prop
.getConvertToAttributeCall(), &fctx
));
1555 getInherentAttrMethod
<< " return std::nullopt;\n";
1557 // Emit the verifiers method for backward compatibility with the generic
1558 // syntax. This method verifies the constraint on the properties attributes
1559 // before they are set, since dyn_cast<> will silently omit failures.
1560 for (const auto &attrOrProp
: attrOrProperties
) {
1561 const auto *namedAttr
=
1562 llvm::dyn_cast_if_present
<const AttributeMetadata
*>(attrOrProp
);
1563 if (!namedAttr
|| !namedAttr
->constraint
)
1565 Attribute attr
= *namedAttr
->constraint
;
1566 std::optional
<StringRef
> constraintFn
=
1567 staticVerifierEmitter
.getAttrConstraintFn(attr
);
1570 if (canEmitAttrVerifier(attr
,
1571 /*isEmittingForOp=*/false)) {
1572 std::string name
= op
.getGetterName(namedAttr
->attrName
);
1573 verifyInherentAttrsMethod
1576 ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
1577 if (attr && ::mlir::failed({1}(attr, "{2}", emitError)))
1578 return ::mlir::failure();
1581 name
, constraintFn
, namedAttr
->attrName
);
1584 verifyInherentAttrsMethod
<< " return ::mlir::success();";
1586 // Generate methods to interact with bytecode.
1587 genPropertiesSupportForBytecode(attrOrProperties
);
1590 void OpEmitter::genPropertiesSupportForBytecode(
1591 ArrayRef
<ConstArgument
> attrOrProperties
) {
1592 if (op
.useCustomPropertiesEncoding()) {
1593 opClass
.declareStaticMethod(
1594 "::mlir::LogicalResult", "readProperties",
1595 MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1596 MethodParameter("::mlir::OperationState &", "state"));
1597 opClass
.declareMethod(
1598 "void", "writeProperties",
1599 MethodParameter("::mlir::DialectBytecodeWriter &", "writer"));
1603 auto &readPropertiesMethod
=
1606 "::mlir::LogicalResult", "readProperties",
1607 MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1608 MethodParameter("::mlir::OperationState &", "state"))
1611 auto &writePropertiesMethod
=
1614 "void", "writeProperties",
1615 MethodParameter("::mlir::DialectBytecodeWriter &", "writer"))
1618 // Populate bytecode serialization logic.
1619 readPropertiesMethod
1620 << " auto &prop = state.getOrAddProperties<Properties>(); (void)prop;";
1621 writePropertiesMethod
<< " auto &prop = getProperties(); (void)prop;\n";
1622 for (const auto &item
: llvm::enumerate(attrOrProperties
)) {
1623 auto &attrOrProp
= item
.value();
1625 fctx
.addSubst("_reader", "reader")
1626 .addSubst("_writer", "writer")
1627 .addSubst("_storage", propertyStorage
)
1628 .addSubst("_ctxt", "this->getContext()");
1629 // If the op emits operand/result segment sizes as a property, emit the
1630 // legacy reader/writer in the appropriate order to allow backward
1631 // compatibility and back deployment.
1632 if (emitHelper
.getOperandSegmentsSize().has_value() &&
1633 item
.index() == emitHelper
.getOperandSegmentSizesLegacyIndex()) {
1634 FmtContext
fmtCtxt(fctx
);
1635 fmtCtxt
.addSubst("_propName", operandSegmentAttrName
);
1636 readPropertiesMethod
<< tgfmt(readBytecodeSegmentSizeLegacy
, &fmtCtxt
);
1637 writePropertiesMethod
<< tgfmt(writeBytecodeSegmentSizeLegacy
, &fmtCtxt
);
1639 if (emitHelper
.getResultSegmentsSize().has_value() &&
1640 item
.index() == emitHelper
.getResultSegmentSizesLegacyIndex()) {
1641 FmtContext
fmtCtxt(fctx
);
1642 fmtCtxt
.addSubst("_propName", resultSegmentAttrName
);
1643 readPropertiesMethod
<< tgfmt(readBytecodeSegmentSizeLegacy
, &fmtCtxt
);
1644 writePropertiesMethod
<< tgfmt(writeBytecodeSegmentSizeLegacy
, &fmtCtxt
);
1646 if (const auto *namedProperty
=
1647 attrOrProp
.dyn_cast
<const NamedProperty
*>()) {
1648 StringRef name
= namedProperty
->name
;
1649 readPropertiesMethod
<< formatv(
1652 auto &propStorage = prop.{0};
1653 auto readProp = [&]() {
1655 return ::mlir::success();
1657 if (::mlir::failed(readProp()))
1658 return ::mlir::failure();
1662 tgfmt(namedProperty
->prop
.getReadFromMlirBytecodeCall(), &fctx
));
1663 writePropertiesMethod
<< formatv(
1666 auto &propStorage = prop.{0};
1670 name
, tgfmt(namedProperty
->prop
.getWriteToMlirBytecodeCall(), &fctx
));
1673 const auto *namedAttr
= attrOrProp
.dyn_cast
<const AttributeMetadata
*>();
1674 StringRef name
= namedAttr
->attrName
;
1675 if (namedAttr
->isRequired
) {
1676 readPropertiesMethod
<< formatv(R
"(
1677 if (::mlir::failed(reader.readAttribute(prop.{0})))
1678 return ::mlir::failure();
1681 writePropertiesMethod
1682 << formatv(" writer.writeAttribute(prop.{0});\n", name
);
1684 readPropertiesMethod
<< formatv(R
"(
1685 if (::mlir::failed(reader.readOptionalAttribute(prop.{0})))
1686 return ::mlir::failure();
1689 writePropertiesMethod
<< formatv(R
"(
1690 writer.writeOptionalAttribute(prop.{0});
1695 readPropertiesMethod
<< " return ::mlir::success();";
1698 void OpEmitter::genAttrGetters() {
1700 fctx
.withBuilder("::mlir::Builder((*this)->getContext())");
1702 // Emit the derived attribute body.
1703 auto emitDerivedAttr
= [&](StringRef name
, Attribute attr
) {
1704 if (auto *method
= opClass
.addMethod(attr
.getReturnType(), name
))
1705 method
->body() << " " << attr
.getDerivedCodeBody() << "\n";
1708 // Generate named accessor with Attribute return type. This is a wrapper
1709 // class that allows referring to the attributes via accessors instead of
1710 // having to use the string interface for better compile time verification.
1711 auto emitAttrWithStorageType
= [&](StringRef name
, StringRef attrName
,
1713 auto *method
= opClass
.addMethod(attr
.getStorageType(), name
+ "Attr");
1716 method
->body() << formatv(
1717 " return ::llvm::{1}<{2}>({0});", emitHelper
.getAttr(attrName
),
1718 attr
.isOptional() || attr
.hasDefaultValue() ? "dyn_cast_or_null"
1720 attr
.getStorageType());
1723 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1724 std::string name
= op
.getGetterName(namedAttr
.name
);
1725 if (namedAttr
.attr
.isDerivedAttr()) {
1726 emitDerivedAttr(name
, namedAttr
.attr
);
1728 emitAttrWithStorageType(name
, namedAttr
.name
, namedAttr
.attr
);
1729 emitAttrGetterWithReturnType(fctx
, opClass
, op
, name
, namedAttr
.attr
);
1733 auto derivedAttrs
= make_filter_range(op
.getAttributes(),
1734 [](const NamedAttribute
&namedAttr
) {
1735 return namedAttr
.attr
.isDerivedAttr();
1737 if (derivedAttrs
.empty())
1740 opClass
.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
1741 // Generate helper method to query whether a named attribute is a derived
1742 // attribute. This enables, for example, avoiding adding an attribute that
1743 // overlaps with a derived attribute.
1746 opClass
.addStaticMethod("bool", "isDerivedAttribute",
1747 MethodParameter("::llvm::StringRef", "name"));
1748 ERROR_IF_PRUNED(method
, "isDerivedAttribute", op
);
1749 auto &body
= method
->body();
1750 for (auto namedAttr
: derivedAttrs
)
1751 body
<< " if (name == \"" << namedAttr
.name
<< "\") return true;\n";
1752 body
<< " return false;";
1754 // Generate method to materialize derived attributes as a DictionaryAttr.
1756 auto *method
= opClass
.addMethod("::mlir::DictionaryAttr",
1757 "materializeDerivedAttributes");
1758 ERROR_IF_PRUNED(method
, "materializeDerivedAttributes", op
);
1759 auto &body
= method
->body();
1761 auto nonMaterializable
=
1762 make_filter_range(derivedAttrs
, [](const NamedAttribute
&namedAttr
) {
1763 return namedAttr
.attr
.getConvertFromStorageCall().empty();
1765 if (!nonMaterializable
.empty()) {
1767 llvm::raw_string_ostream
os(attrs
);
1768 interleaveComma(nonMaterializable
, os
, [&](const NamedAttribute
&attr
) {
1769 os
<< op
.getGetterName(attr
.name
);
1774 "op has non-materializable derived attributes '{0}', skipping",
1776 body
<< formatv(" emitOpError(\"op has non-materializable derived "
1777 "attributes '{0}'\");\n",
1779 body
<< " return nullptr;";
1783 body
<< " ::mlir::MLIRContext* ctx = getContext();\n";
1784 body
<< " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1785 body
<< " return ::mlir::DictionaryAttr::get(";
1786 body
<< " ctx, {\n";
1789 [&](const NamedAttribute
&namedAttr
) {
1790 auto tmpl
= namedAttr
.attr
.getConvertFromStorageCall();
1791 std::string name
= op
.getGetterName(namedAttr
.name
);
1792 body
<< " {" << name
<< "AttrName(),\n"
1793 << tgfmt(tmpl
, &fctx
.withSelf(name
+ "()")
1794 .withBuilder("odsBuilder")
1795 .addSubst("_ctxt", "ctx")
1796 .addSubst("_storage", "ctx"))
1804 void OpEmitter::genAttrSetters() {
1805 // Generate raw named setter type. This is a wrapper class that allows setting
1806 // to the attributes via setters instead of having to use the string interface
1807 // for better compile time verification.
1808 auto emitAttrWithStorageType
= [&](StringRef setterName
, StringRef getterName
,
1811 opClass
.addMethod("void", setterName
+ "Attr",
1812 MethodParameter(attr
.getStorageType(), "attr"));
1814 method
->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);",
1818 // Generate a setter that accepts the underlying C++ type as opposed to the
1820 auto emitAttrWithReturnType
= [&](StringRef setterName
, StringRef getterName
,
1822 Attribute baseAttr
= attr
.getBaseAttr();
1823 if (!canUseUnwrappedRawValue(baseAttr
))
1826 fctx
.withBuilder("::mlir::Builder((*this)->getContext())");
1827 bool isUnitAttr
= attr
.getAttrDefName() == "UnitAttr";
1828 bool isOptional
= attr
.isOptional();
1830 auto createMethod
= [&](const Twine
¶mType
) {
1831 return opClass
.addMethod("void", setterName
,
1832 MethodParameter(paramType
.str(), "attrValue"));
1835 // Build the method using the correct parameter type depending on
1837 Method
*method
= nullptr;
1839 method
= createMethod("bool");
1840 else if (isOptional
)
1842 createMethod("::std::optional<" + baseAttr
.getReturnType() + ">");
1844 method
= createMethod(attr
.getReturnType());
1848 // If the value isn't optional, just set it directly.
1850 method
->body() << formatv(
1851 " (*this)->setAttr({0}AttrName(), {1});", getterName
,
1852 constBuildAttrFromParam(attr
, fctx
, "attrValue"));
1856 // Otherwise, we only set if the provided value is valid. If it isn't, we
1857 // remove the attribute.
1859 // TODO: Handle unit attr parameters specially, given that it is treated as
1860 // optional but not in the same way as the others (i.e. it uses bool over
1861 // std::optional<>).
1862 StringRef paramStr
= isUnitAttr
? "attrValue" : "*attrValue";
1863 const char *optionalCodeBody
= R
"(
1865 return (*this)->setAttr({0}AttrName(), {1});
1866 (*this)->removeAttr({0}AttrName());)";
1867 method
->body() << formatv(
1868 optionalCodeBody
, getterName
,
1869 constBuildAttrFromParam(baseAttr
, fctx
, paramStr
));
1872 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
1873 if (namedAttr
.attr
.isDerivedAttr())
1875 std::string setterName
= op
.getSetterName(namedAttr
.name
);
1876 std::string getterName
= op
.getGetterName(namedAttr
.name
);
1877 emitAttrWithStorageType(setterName
, getterName
, namedAttr
.attr
);
1878 emitAttrWithReturnType(setterName
, getterName
, namedAttr
.attr
);
1882 void OpEmitter::genOptionalAttrRemovers() {
1883 // Generate methods for removing optional attributes, instead of having to
1884 // use the string interface. Enables better compile time verification.
1885 auto emitRemoveAttr
= [&](StringRef name
, bool useProperties
) {
1886 auto upperInitial
= name
.take_front().upper();
1887 auto *method
= opClass
.addMethod("::mlir::Attribute",
1888 op
.getRemoverName(name
) + "Attr");
1891 if (useProperties
) {
1892 method
->body() << formatv(R
"(
1893 auto &attr = getProperties().{0};
1900 method
->body() << formatv("return (*this)->removeAttr({0}AttrName());",
1901 op
.getGetterName(name
));
1904 for (const NamedAttribute
&namedAttr
: op
.getAttributes())
1905 if (namedAttr
.attr
.isOptional())
1906 emitRemoveAttr(namedAttr
.name
,
1907 op
.getDialect().usePropertiesForAttributes());
1910 // Generates the code to compute the start and end index of an operand or result
1912 template <typename RangeT
>
1913 static void generateValueRangeStartAndEnd(
1914 Class
&opClass
, bool isGenericAdaptorBase
, StringRef methodName
,
1915 int numVariadic
, int numNonVariadic
, StringRef rangeSizeCall
,
1916 bool hasAttrSegmentSize
, StringRef sizeAttrInit
, RangeT
&&odsValues
) {
1918 SmallVector
<MethodParameter
> parameters
{MethodParameter("unsigned", "index")};
1919 if (isGenericAdaptorBase
) {
1920 parameters
.emplace_back("unsigned", "odsOperandsSize");
1921 // The range size is passed per parameter for generic adaptor bases as
1922 // using the rangeSizeCall would require the operands, which are not
1923 // accessible in the base class.
1924 rangeSizeCall
= "odsOperandsSize";
1927 auto *method
= opClass
.addMethod("std::pair<unsigned, unsigned>", methodName
,
1931 auto &body
= method
->body();
1932 if (numVariadic
== 0) {
1933 body
<< " return {index, 1};\n";
1934 } else if (hasAttrSegmentSize
) {
1935 body
<< sizeAttrInit
<< attrSizedSegmentValueRangeCalcCode
;
1937 // Because the op can have arbitrarily interleaved variadic and non-variadic
1938 // operands, we need to embed a list in the "sink" getter method for
1939 // calculation at run-time.
1940 SmallVector
<StringRef
, 4> isVariadic
;
1941 isVariadic
.reserve(llvm::size(odsValues
));
1942 for (auto &it
: odsValues
)
1943 isVariadic
.push_back(it
.isVariableLength() ? "true" : "false");
1944 std::string isVariadicList
= llvm::join(isVariadic
, ", ");
1945 body
<< formatv(sameVariadicSizeValueRangeCalcCode
, isVariadicList
,
1946 numNonVariadic
, numVariadic
, rangeSizeCall
, "operand");
1950 static std::string
generateTypeForGetter(const NamedTypeConstraint
&value
) {
1951 std::string str
= "::mlir::Value";
1952 /// If the CPPClassName is not a fully qualified type. Uses of types
1953 /// across Dialect fail because they are not in the correct namespace. So we
1954 /// dont generate TypedValue unless the type is fully qualified.
1955 /// getCPPClassName doesn't return the fully qualified path for
1956 /// `mlir::pdl::OperationType` see
1957 /// https://github.com/llvm/llvm-project/issues/57279.
1958 /// Adaptor will have values that are not from the type of their operation and
1959 /// this is expected, so we dont generate TypedValue for Adaptor
1960 if (value
.constraint
.getCPPClassName() != "::mlir::Type" &&
1961 StringRef(value
.constraint
.getCPPClassName()).startswith("::"))
1962 str
= llvm::formatv("::mlir::TypedValue<{0}>",
1963 value
.constraint
.getCPPClassName())
1968 // Generates the named operand getter methods for the given Operator `op` and
1969 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
1970 // return a range of operands (individual operands are `Value ` and each
1971 // element in the range must also be `Value `); use `rangeBeginCall` to get
1972 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
1973 // obtain the number of operands. `getOperandCallPattern` contains the code
1974 // necessary to obtain a single operand whose position will be substituted
1976 // "{0}" marker in the pattern. Note that the pattern should work for any kind
1977 // of ops, in particular for one-operand ops that may not have the
1978 // `getOperand(unsigned)` method.
1980 generateNamedOperandGetters(const Operator
&op
, Class
&opClass
,
1981 Class
*genericAdaptorBase
, StringRef sizeAttrInit
,
1982 StringRef rangeType
, StringRef rangeElementType
,
1983 StringRef rangeBeginCall
, StringRef rangeSizeCall
,
1984 StringRef getOperandCallPattern
) {
1985 const int numOperands
= op
.getNumOperands();
1986 const int numVariadicOperands
= op
.getNumVariableLengthOperands();
1987 const int numNormalOperands
= numOperands
- numVariadicOperands
;
1989 const auto *sameVariadicSize
=
1990 op
.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
1991 const auto *attrSizedOperands
=
1992 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1994 if (numVariadicOperands
> 1 && !sameVariadicSize
&& !attrSizedOperands
) {
1995 PrintFatalError(op
.getLoc(), "op has multiple variadic operands but no "
1996 "specification over their sizes");
1999 if (numVariadicOperands
< 2 && attrSizedOperands
) {
2000 PrintFatalError(op
.getLoc(), "op must have at least two variadic operands "
2001 "to use 'AttrSizedOperandSegments' trait");
2004 if (attrSizedOperands
&& sameVariadicSize
) {
2005 PrintFatalError(op
.getLoc(),
2006 "op cannot have both 'AttrSizedOperandSegments' and "
2007 "'SameVariadicOperandSize' traits");
2010 // First emit a few "sink" getter methods upon which we layer all nicer named
2012 // If generating for an adaptor, the method is put into the non-templated
2013 // generic base class, to not require being defined in the header.
2014 // Since the operand size can't be determined from the base class however,
2015 // it has to be passed as an additional argument. The trampoline below
2016 // generates the function with the same signature as the Op in the generic
2018 bool isGenericAdaptorBase
= genericAdaptorBase
!= nullptr;
2019 generateValueRangeStartAndEnd(
2020 /*opClass=*/isGenericAdaptorBase
? *genericAdaptorBase
: opClass
,
2021 isGenericAdaptorBase
,
2022 /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands
,
2023 numNormalOperands
, rangeSizeCall
, attrSizedOperands
, sizeAttrInit
,
2024 const_cast<Operator
&>(op
).getOperands());
2025 if (isGenericAdaptorBase
) {
2026 // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
2027 // the index. This just calls the implementation in the base class but
2028 // passes the operand size as parameter.
2029 Method
*method
= opClass
.addMethod("std::pair<unsigned, unsigned>",
2030 "getODSOperandIndexAndLength",
2031 MethodParameter("unsigned", "index"));
2032 ERROR_IF_PRUNED(method
, "getODSOperandIndexAndLength", op
);
2033 MethodBody
&body
= method
->body();
2034 body
.indent() << formatv(
2035 "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall
);
2038 auto *m
= opClass
.addMethod(rangeType
, "getODSOperands",
2039 MethodParameter("unsigned", "index"));
2040 ERROR_IF_PRUNED(m
, "getODSOperands", op
);
2041 auto &body
= m
->body();
2042 body
<< formatv(valueRangeReturnCode
, rangeBeginCall
,
2043 "getODSOperandIndexAndLength(index)");
2045 // Then we emit nicer named getter methods by redirecting to the "sink" getter
2047 for (int i
= 0; i
!= numOperands
; ++i
) {
2048 const auto &operand
= op
.getOperand(i
);
2049 if (operand
.name
.empty())
2051 std::string name
= op
.getGetterName(operand
.name
);
2052 if (operand
.isOptional()) {
2053 m
= opClass
.addMethod(isGenericAdaptorBase
2055 : generateTypeForGetter(operand
),
2057 ERROR_IF_PRUNED(m
, name
, op
);
2058 m
->body().indent() << formatv("auto operands = getODSOperands({0});\n"
2059 "return operands.empty() ? {1}{{} : ",
2060 i
, m
->getReturnType());
2061 if (!isGenericAdaptorBase
)
2062 m
->body() << llvm::formatv("::llvm::cast<{0}>", m
->getReturnType());
2063 m
->body() << "(*operands.begin());";
2064 } else if (operand
.isVariadicOfVariadic()) {
2065 std::string segmentAttr
= op
.getGetterName(
2066 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr());
2067 if (genericAdaptorBase
) {
2068 m
= opClass
.addMethod("::llvm::SmallVector<" + rangeType
+ ">", name
);
2069 ERROR_IF_PRUNED(m
, name
, op
);
2070 m
->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode
,
2071 segmentAttr
, i
, rangeType
);
2075 m
= opClass
.addMethod("::mlir::OperandRangeRange", name
);
2076 ERROR_IF_PRUNED(m
, name
, op
);
2077 m
->body() << " return getODSOperands(" << i
<< ").split(" << segmentAttr
2079 } else if (operand
.isVariadic()) {
2080 m
= opClass
.addMethod(rangeType
, name
);
2081 ERROR_IF_PRUNED(m
, name
, op
);
2082 m
->body() << " return getODSOperands(" << i
<< ");";
2084 m
= opClass
.addMethod(isGenericAdaptorBase
2086 : generateTypeForGetter(operand
),
2088 ERROR_IF_PRUNED(m
, name
, op
);
2089 m
->body().indent() << "return ";
2090 if (!isGenericAdaptorBase
)
2091 m
->body() << llvm::formatv("::llvm::cast<{0}>", m
->getReturnType());
2092 m
->body() << llvm::formatv("(*getODSOperands({0}).begin());", i
);
2097 void OpEmitter::genNamedOperandGetters() {
2098 // Build the code snippet used for initializing the operand_segment_size)s
2100 std::string attrSizeInitCode
;
2101 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2102 if (op
.getDialect().usePropertiesForAttributes())
2103 attrSizeInitCode
= formatv(adapterSegmentSizeAttrInitCodeProperties
,
2104 "getProperties().operandSegmentSizes");
2107 attrSizeInitCode
= formatv(opSegmentSizeAttrInitCode
,
2108 emitHelper
.getAttr(operandSegmentAttrName
));
2111 generateNamedOperandGetters(
2113 /*genericAdaptorBase=*/nullptr,
2114 /*sizeAttrInit=*/attrSizeInitCode
,
2115 /*rangeType=*/"::mlir::Operation::operand_range",
2116 /*rangeElementType=*/"::mlir::Value",
2117 /*rangeBeginCall=*/"getOperation()->operand_begin()",
2118 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
2119 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
2122 void OpEmitter::genNamedOperandSetters() {
2123 auto *attrSizedOperands
=
2124 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2125 for (int i
= 0, e
= op
.getNumOperands(); i
!= e
; ++i
) {
2126 const auto &operand
= op
.getOperand(i
);
2127 if (operand
.name
.empty())
2129 std::string name
= op
.getGetterName(operand
.name
);
2131 StringRef returnType
;
2132 if (operand
.isVariadicOfVariadic()) {
2133 returnType
= "::mlir::MutableOperandRangeRange";
2134 } else if (operand
.isVariableLength()) {
2135 returnType
= "::mlir::MutableOperandRange";
2137 returnType
= "::mlir::OpOperand &";
2139 auto *m
= opClass
.addMethod(returnType
, name
+ "Mutable");
2140 ERROR_IF_PRUNED(m
, name
, op
);
2141 auto &body
= m
->body();
2142 body
<< " auto range = getODSOperandIndexAndLength(" << i
<< ");\n";
2144 if (!operand
.isVariadicOfVariadic() && !operand
.isVariableLength()) {
2145 // In case of a single operand, return a single OpOperand.
2146 body
<< " return getOperation()->getOpOperand(range.first);\n";
2150 body
<< " auto mutableRange = "
2151 "::mlir::MutableOperandRange(getOperation(), "
2152 "range.first, range.second";
2153 if (attrSizedOperands
) {
2154 if (emitHelper
.hasProperties())
2155 body
<< formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
2156 "{{getOperandSegmentSizesAttrName(), "
2157 "::mlir::DenseI32ArrayAttr::get(getContext(), "
2158 "getProperties().operandSegmentSizes)})",
2162 ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i
,
2163 emitHelper
.getAttr(operandSegmentAttrName
, /*isNamed=*/true));
2167 // If this operand is a nested variadic, we split the range into a
2168 // MutableOperandRangeRange that provides a range over all of the
2170 if (operand
.isVariadicOfVariadic()) {
2172 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
2173 << op
.getGetterName(
2174 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr())
2175 << "AttrName()));\n";
2177 // Otherwise, we use the full range directly.
2178 body
<< " return mutableRange;\n";
2183 void OpEmitter::genNamedResultGetters() {
2184 const int numResults
= op
.getNumResults();
2185 const int numVariadicResults
= op
.getNumVariableLengthResults();
2186 const int numNormalResults
= numResults
- numVariadicResults
;
2188 // If we have more than one variadic results, we need more complicated logic
2189 // to calculate the value range for each result.
2191 const auto *sameVariadicSize
=
2192 op
.getTrait("::mlir::OpTrait::SameVariadicResultSize");
2193 const auto *attrSizedResults
=
2194 op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
2196 if (numVariadicResults
> 1 && !sameVariadicSize
&& !attrSizedResults
) {
2197 PrintFatalError(op
.getLoc(), "op has multiple variadic results but no "
2198 "specification over their sizes");
2201 if (numVariadicResults
< 2 && attrSizedResults
) {
2202 PrintFatalError(op
.getLoc(), "op must have at least two variadic results "
2203 "to use 'AttrSizedResultSegments' trait");
2206 if (attrSizedResults
&& sameVariadicSize
) {
2207 PrintFatalError(op
.getLoc(),
2208 "op cannot have both 'AttrSizedResultSegments' and "
2209 "'SameVariadicResultSize' traits");
2212 // Build the initializer string for the result segment size attribute.
2213 std::string attrSizeInitCode
;
2214 if (attrSizedResults
) {
2215 if (op
.getDialect().usePropertiesForAttributes())
2216 attrSizeInitCode
= formatv(adapterSegmentSizeAttrInitCodeProperties
,
2217 "getProperties().resultSegmentSizes");
2220 attrSizeInitCode
= formatv(opSegmentSizeAttrInitCode
,
2221 emitHelper
.getAttr(resultSegmentAttrName
));
2224 generateValueRangeStartAndEnd(
2225 opClass
, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength",
2226 numVariadicResults
, numNormalResults
, "getOperation()->getNumResults()",
2227 attrSizedResults
, attrSizeInitCode
, op
.getResults());
2230 opClass
.addMethod("::mlir::Operation::result_range", "getODSResults",
2231 MethodParameter("unsigned", "index"));
2232 ERROR_IF_PRUNED(m
, "getODSResults", op
);
2233 m
->body() << formatv(valueRangeReturnCode
, "getOperation()->result_begin()",
2234 "getODSResultIndexAndLength(index)");
2236 for (int i
= 0; i
!= numResults
; ++i
) {
2237 const auto &result
= op
.getResult(i
);
2238 if (result
.name
.empty())
2240 std::string name
= op
.getGetterName(result
.name
);
2241 if (result
.isOptional()) {
2242 m
= opClass
.addMethod(generateTypeForGetter(result
), name
);
2243 ERROR_IF_PRUNED(m
, name
, op
);
2244 m
->body() << " auto results = getODSResults(" << i
<< ");\n"
2245 << llvm::formatv(" return results.empty()"
2247 " : ::llvm::cast<{0}>(*results.begin());",
2248 m
->getReturnType());
2249 } else if (result
.isVariadic()) {
2250 m
= opClass
.addMethod("::mlir::Operation::result_range", name
);
2251 ERROR_IF_PRUNED(m
, name
, op
);
2252 m
->body() << " return getODSResults(" << i
<< ");";
2254 m
= opClass
.addMethod(generateTypeForGetter(result
), name
);
2255 ERROR_IF_PRUNED(m
, name
, op
);
2256 m
->body() << llvm::formatv(
2257 " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
2258 m
->getReturnType(), i
);
2263 void OpEmitter::genNamedRegionGetters() {
2264 unsigned numRegions
= op
.getNumRegions();
2265 for (unsigned i
= 0; i
< numRegions
; ++i
) {
2266 const auto ®ion
= op
.getRegion(i
);
2267 if (region
.name
.empty())
2269 std::string name
= op
.getGetterName(region
.name
);
2271 // Generate the accessors for a variadic region.
2272 if (region
.isVariadic()) {
2274 opClass
.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name
);
2275 ERROR_IF_PRUNED(m
, name
, op
);
2276 m
->body() << formatv(" return (*this)->getRegions().drop_front({0});",
2281 auto *m
= opClass
.addMethod("::mlir::Region &", name
);
2282 ERROR_IF_PRUNED(m
, name
, op
);
2283 m
->body() << formatv(" return (*this)->getRegion({0});", i
);
2287 void OpEmitter::genNamedSuccessorGetters() {
2288 unsigned numSuccessors
= op
.getNumSuccessors();
2289 for (unsigned i
= 0; i
< numSuccessors
; ++i
) {
2290 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
2291 if (successor
.name
.empty())
2293 std::string name
= op
.getGetterName(successor
.name
);
2294 // Generate the accessors for a variadic successor list.
2295 if (successor
.isVariadic()) {
2296 auto *m
= opClass
.addMethod("::mlir::SuccessorRange", name
);
2297 ERROR_IF_PRUNED(m
, name
, op
);
2298 m
->body() << formatv(
2299 " return {std::next((*this)->successor_begin(), {0}), "
2300 "(*this)->successor_end()};",
2305 auto *m
= opClass
.addMethod("::mlir::Block *", name
);
2306 ERROR_IF_PRUNED(m
, name
, op
);
2307 m
->body() << formatv(" return (*this)->getSuccessor({0});", i
);
2311 static bool canGenerateUnwrappedBuilder(const Operator
&op
) {
2312 // If this op does not have native attributes at all, return directly to avoid
2313 // redefining builders.
2314 if (op
.getNumNativeAttributes() == 0)
2317 bool canGenerate
= false;
2318 // We are generating builders that take raw values for attributes. We need to
2319 // make sure the native attributes have a meaningful "unwrapped" value type
2320 // different from the wrapped mlir::Attribute type to avoid redefining
2321 // builders. This checks for the op has at least one such native attribute.
2322 for (int i
= 0, e
= op
.getNumNativeAttributes(); i
< e
; ++i
) {
2323 const NamedAttribute
&namedAttr
= op
.getAttribute(i
);
2324 if (canUseUnwrappedRawValue(namedAttr
.attr
)) {
2332 static bool canInferType(const Operator
&op
) {
2333 return op
.getTrait("::mlir::InferTypeOpInterface::Trait");
2336 void OpEmitter::genSeparateArgParamBuilder() {
2337 SmallVector
<AttrParamKind
, 2> attrBuilderType
;
2338 attrBuilderType
.push_back(AttrParamKind::WrappedAttr
);
2339 if (canGenerateUnwrappedBuilder(op
))
2340 attrBuilderType
.push_back(AttrParamKind::UnwrappedValue
);
2342 // Emit with separate builders with or without unwrapped attributes and/or
2343 // inferring result type.
2344 auto emit
= [&](AttrParamKind attrType
, TypeParamKind paramKind
,
2346 SmallVector
<MethodParameter
> paramList
;
2347 SmallVector
<std::string
, 4> resultNames
;
2348 llvm::StringSet
<> inferredAttributes
;
2349 buildParamList(paramList
, inferredAttributes
, resultNames
, paramKind
,
2352 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2353 // If the builder is redundant, skip generating the method.
2356 auto &body
= m
->body();
2357 genCodeForAddingArgAndRegionForBuilder(body
, inferredAttributes
,
2358 /*isRawValueAttr=*/attrType
==
2359 AttrParamKind::UnwrappedValue
);
2361 // Push all result types to the operation state
2364 // Generate builder that infers type too.
2365 // TODO: Subsume this with general checking if type can be
2366 // inferred automatically.
2368 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2369 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2370 {1}.location, {1}.operands,
2371 {1}.attributes.getDictionary({1}.getContext()),
2372 {1}.getRawProperties(),
2373 {1}.regions, inferredReturnTypes)))
2374 {1}.addTypes(inferredReturnTypes);
2376 ::llvm::report_fatal_error("Failed to infer result
type(s
).");)",
2377 opClass
.getClassName(), builderOpState
);
2381 switch (paramKind
) {
2382 case TypeParamKind::None
:
2384 case TypeParamKind::Separate
:
2385 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
2386 if (op
.getResult(i
).isOptional())
2387 body
<< " if (" << resultNames
[i
] << ")\n ";
2388 body
<< " " << builderOpState
<< ".addTypes(" << resultNames
[i
]
2392 // Automatically create the 'resultSegmentSizes' attribute using
2393 // the length of the type ranges.
2394 if (op
.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
2395 if (op
.getDialect().usePropertiesForAttributes()) {
2396 body
<< " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
2398 std::string getterName
= op
.getGetterName(resultSegmentAttrName
);
2399 body
<< " " << builderOpState
<< ".addAttribute(" << getterName
2400 << "AttrName(" << builderOpState
<< ".name), "
2401 << "odsBuilder.getDenseI32ArrayAttr({";
2404 llvm::seq
<int>(0, op
.getNumResults()), body
, [&](int i
) {
2405 const NamedTypeConstraint
&result
= op
.getResult(i
);
2406 if (!result
.isVariableLength()) {
2408 } else if (result
.isOptional()) {
2409 body
<< "(" << resultNames
[i
] << " ? 1 : 0)";
2411 // VariadicOfVariadic of results are currently unsupported in
2412 // MLIR, hence it can only be a simple variadic.
2413 // TODO: Add implementation for VariadicOfVariadic results here
2415 assert(result
.isVariadic());
2416 body
<< "static_cast<int32_t>(" << resultNames
[i
] << ".size())";
2419 if (op
.getDialect().usePropertiesForAttributes()) {
2420 body
<< "}), " << builderOpState
2421 << ".getOrAddProperties<Properties>()."
2422 "resultSegmentSizes.begin());\n";
2429 case TypeParamKind::Collective
: {
2430 int numResults
= op
.getNumResults();
2431 int numVariadicResults
= op
.getNumVariableLengthResults();
2432 int numNonVariadicResults
= numResults
- numVariadicResults
;
2433 bool hasVariadicResult
= numVariadicResults
!= 0;
2435 // Avoid emitting "resultTypes.size() >= 0u" which is always true.
2436 if (!hasVariadicResult
|| numNonVariadicResults
!= 0)
2438 << "assert(resultTypes.size() "
2439 << (hasVariadicResult
? ">=" : "==") << " "
2440 << numNonVariadicResults
2441 << "u && \"mismatched number of results\");\n";
2442 body
<< " " << builderOpState
<< ".addTypes(resultTypes);\n";
2446 llvm_unreachable("unhandled TypeParamKind");
2449 // Some of the build methods generated here may be ambiguous, but TableGen's
2450 // ambiguous function detection will elide those ones.
2451 for (auto attrType
: attrBuilderType
) {
2452 emit(attrType
, TypeParamKind::Separate
, /*inferType=*/false);
2453 if (canInferType(op
))
2454 emit(attrType
, TypeParamKind::None
, /*inferType=*/true);
2455 emit(attrType
, TypeParamKind::Collective
, /*inferType=*/false);
2459 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
2460 int numResults
= op
.getNumResults();
2463 SmallVector
<MethodParameter
> paramList
;
2464 paramList
.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2465 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2466 paramList
.emplace_back("::mlir::ValueRange", "operands");
2467 // Provide default value for `attributes` when its the last parameter
2468 StringRef attributesDefaultValue
= op
.getNumVariadicRegions() ? "" : "{}";
2469 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2470 "attributes", attributesDefaultValue
);
2471 if (op
.getNumVariadicRegions())
2472 paramList
.emplace_back("unsigned", "numRegions");
2474 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2475 // If the builder is redundant, skip generating the method
2478 auto &body
= m
->body();
2481 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2484 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2486 // Create the correct number of regions
2487 if (int numRegions
= op
.getNumRegions()) {
2488 body
<< llvm::formatv(
2489 " for (unsigned i = 0; i != {0}; ++i)\n",
2490 (op
.getNumVariadicRegions() ? "numRegions" : Twine(numRegions
)));
2491 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
2495 SmallVector
<std::string
, 2> resultTypes(numResults
, "operands[0].getType()");
2496 body
<< " " << builderOpState
<< ".addTypes({"
2497 << llvm::join(resultTypes
, ", ") << "});\n\n";
2500 void OpEmitter::genPopulateDefaultAttributes() {
2501 // All done if no attributes, except optional ones, have default values.
2502 if (llvm::all_of(op
.getAttributes(), [](const NamedAttribute
&named
) {
2503 return !named
.attr
.hasDefaultValue() || named
.attr
.isOptional();
2507 if (op
.getDialect().usePropertiesForAttributes()) {
2508 SmallVector
<MethodParameter
> paramList
;
2509 paramList
.emplace_back("::mlir::OperationName", "opName");
2510 paramList
.emplace_back("Properties &", "properties");
2512 opClass
.addStaticMethod("void", "populateDefaultProperties", paramList
);
2513 ERROR_IF_PRUNED(m
, "populateDefaultProperties", op
);
2514 auto &body
= m
->body();
2516 body
<< "::mlir::Builder " << odsBuilder
<< "(opName.getContext());\n";
2517 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
2518 auto &attr
= namedAttr
.attr
;
2519 if (!attr
.hasDefaultValue() || attr
.isOptional())
2521 StringRef name
= namedAttr
.name
;
2523 fctx
.withBuilder(odsBuilder
);
2524 body
<< "if (!properties." << name
<< ")\n"
2525 << " properties." << name
<< " = "
2526 << std::string(tgfmt(attr
.getConstBuilderTemplate(), &fctx
,
2527 tgfmt(attr
.getDefaultValue(), &fctx
)))
2533 SmallVector
<MethodParameter
> paramList
;
2534 paramList
.emplace_back("const ::mlir::OperationName &", "opName");
2535 paramList
.emplace_back("::mlir::NamedAttrList &", "attributes");
2536 auto *m
= opClass
.addStaticMethod("void", "populateDefaultAttrs", paramList
);
2537 ERROR_IF_PRUNED(m
, "populateDefaultAttrs", op
);
2538 auto &body
= m
->body();
2541 // Set default attributes that are unset.
2542 body
<< "auto attrNames = opName.getAttributeNames();\n";
2543 body
<< "::mlir::Builder " << odsBuilder
2544 << "(attrNames.front().getContext());\n";
2545 StringMap
<int> attrIndex
;
2546 for (const auto &it
: llvm::enumerate(emitHelper
.getAttrMetadata())) {
2547 attrIndex
[it
.value().first
] = it
.index();
2549 for (const NamedAttribute
&namedAttr
: op
.getAttributes()) {
2550 auto &attr
= namedAttr
.attr
;
2551 if (!attr
.hasDefaultValue() || attr
.isOptional())
2553 auto index
= attrIndex
[namedAttr
.name
];
2554 body
<< "if (!attributes.get(attrNames[" << index
<< "])) {\n";
2556 fctx
.withBuilder(odsBuilder
);
2558 std::string defaultValue
=
2559 std::string(tgfmt(attr
.getConstBuilderTemplate(), &fctx
,
2560 tgfmt(attr
.getDefaultValue(), &fctx
)));
2561 body
.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index
,
2563 body
.unindent() << "}\n";
2567 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
2568 SmallVector
<MethodParameter
> paramList
;
2569 paramList
.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2570 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2571 paramList
.emplace_back("::mlir::ValueRange", "operands");
2572 StringRef attributesDefaultValue
= op
.getNumVariadicRegions() ? "" : "{}";
2573 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2574 "attributes", attributesDefaultValue
);
2575 if (op
.getNumVariadicRegions())
2576 paramList
.emplace_back("unsigned", "numRegions");
2578 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2579 // If the builder is redundant, skip generating the method
2582 auto &body
= m
->body();
2584 int numResults
= op
.getNumResults();
2585 int numVariadicResults
= op
.getNumVariableLengthResults();
2586 int numNonVariadicResults
= numResults
- numVariadicResults
;
2588 int numOperands
= op
.getNumOperands();
2589 int numVariadicOperands
= op
.getNumVariableLengthOperands();
2590 int numNonVariadicOperands
= numOperands
- numVariadicOperands
;
2593 if (numVariadicOperands
== 0 || numNonVariadicOperands
!= 0)
2594 body
<< " assert(operands.size()"
2595 << (numVariadicOperands
!= 0 ? " >= " : " == ")
2596 << numNonVariadicOperands
2597 << "u && \"mismatched number of parameters\");\n";
2598 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2599 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2601 // Create the correct number of regions
2602 if (int numRegions
= op
.getNumRegions()) {
2603 body
<< llvm::formatv(
2604 " for (unsigned i = 0; i != {0}; ++i)\n",
2605 (op
.getNumVariadicRegions() ? "numRegions" : Twine(numRegions
)));
2606 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
2611 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2612 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2613 {1}.location, operands,
2614 {1}.attributes.getDictionary({1}.getContext()),
2615 {1}.getRawProperties(),
2616 {1}.regions, inferredReturnTypes))) {{)",
2617 opClass
.getClassName(), builderOpState
);
2618 if (numVariadicResults
== 0 || numNonVariadicResults
!= 0)
2619 body
<< "\n assert(inferredReturnTypes.size()"
2620 << (numVariadicResults
!= 0 ? " >= " : " == ") << numNonVariadicResults
2621 << "u && \"mismatched number of return types\");";
2622 body
<< "\n " << builderOpState
<< ".addTypes(inferredReturnTypes);";
2626 ::llvm::report_fatal_error("Failed to infer result
type(s
).");
2628 opClass
.getClassName(), builderOpState
);
2631 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2632 auto emit
= [&](AttrParamKind attrType
) {
2633 SmallVector
<MethodParameter
> paramList
;
2634 SmallVector
<std::string
, 4> resultNames
;
2635 llvm::StringSet
<> inferredAttributes
;
2636 buildParamList(paramList
, inferredAttributes
, resultNames
,
2637 TypeParamKind::None
, attrType
);
2639 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2640 // If the builder is redundant, skip generating the method
2643 auto &body
= m
->body();
2644 genCodeForAddingArgAndRegionForBuilder(body
, inferredAttributes
,
2645 /*isRawValueAttr=*/attrType
==
2646 AttrParamKind::UnwrappedValue
);
2648 auto numResults
= op
.getNumResults();
2649 if (numResults
== 0)
2652 // Push all result types to the operation state
2653 const char *index
= op
.getOperand(0).isVariadic() ? ".front()" : "";
2654 std::string resultType
=
2655 formatv("{0}{1}.getType()", getArgumentName(op
, 0), index
).str();
2656 body
<< " " << builderOpState
<< ".addTypes({" << resultType
;
2657 for (int i
= 1; i
!= numResults
; ++i
)
2658 body
<< ", " << resultType
;
2662 emit(AttrParamKind::WrappedAttr
);
2663 // Generate additional builder(s) if any attributes can be "unwrapped"
2664 if (canGenerateUnwrappedBuilder(op
))
2665 emit(AttrParamKind::UnwrappedValue
);
2668 void OpEmitter::genUseAttrAsResultTypeBuilder() {
2669 SmallVector
<MethodParameter
> paramList
;
2670 paramList
.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2671 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2672 paramList
.emplace_back("::mlir::ValueRange", "operands");
2673 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2674 "attributes", "{}");
2675 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2676 // If the builder is redundant, skip generating the method
2680 auto &body
= m
->body();
2682 // Push all result types to the operation state
2683 std::string resultType
;
2684 const auto &namedAttr
= op
.getAttribute(0);
2686 body
<< " auto attrName = " << op
.getGetterName(namedAttr
.name
)
2687 << "AttrName(" << builderOpState
2689 " for (auto attr : attributes) {\n"
2690 " if (attr.getName() != attrName) continue;\n";
2691 if (namedAttr
.attr
.isTypeAttr()) {
2692 resultType
= "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()";
2694 resultType
= "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()";
2698 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2701 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2704 SmallVector
<std::string
, 2> resultTypes(op
.getNumResults(), resultType
);
2705 body
<< " " << builderOpState
<< ".addTypes({"
2706 << llvm::join(resultTypes
, ", ") << "});\n";
2710 /// Returns a signature of the builder. Updates the context `fctx` to enable
2711 /// replacement of $_builder and $_state in the body.
2712 static SmallVector
<MethodParameter
>
2713 getBuilderSignature(const Builder
&builder
) {
2714 ArrayRef
<Builder::Parameter
> params(builder
.getParameters());
2716 // Inject builder and state arguments.
2717 SmallVector
<MethodParameter
> arguments
;
2718 arguments
.reserve(params
.size() + 2);
2719 arguments
.emplace_back("::mlir::OpBuilder &", odsBuilder
);
2720 arguments
.emplace_back("::mlir::OperationState &", builderOpState
);
2722 for (unsigned i
= 0, e
= params
.size(); i
< e
; ++i
) {
2723 // If no name is provided, generate one.
2724 std::optional
<StringRef
> paramName
= params
[i
].getName();
2726 paramName
? paramName
->str() : "odsArg" + std::to_string(i
);
2728 StringRef defaultValue
;
2729 if (std::optional
<StringRef
> defaultParamValue
=
2730 params
[i
].getDefaultValue())
2731 defaultValue
= *defaultParamValue
;
2733 arguments
.emplace_back(params
[i
].getCppType(), std::move(name
),
2740 void OpEmitter::genBuilder() {
2741 // Handle custom builders if provided.
2742 for (const Builder
&builder
: op
.getBuilders()) {
2743 SmallVector
<MethodParameter
> arguments
= getBuilderSignature(builder
);
2745 std::optional
<StringRef
> body
= builder
.getBody();
2746 auto properties
= body
? Method::Static
: Method::StaticDeclaration
;
2748 opClass
.addMethod("void", "build", properties
, std::move(arguments
));
2750 ERROR_IF_PRUNED(method
, "build", op
);
2753 method
->setDeprecated(builder
.getDeprecatedMessage());
2756 fctx
.withBuilder(odsBuilder
);
2757 fctx
.addSubst("_state", builderOpState
);
2759 method
->body() << tgfmt(*body
, &fctx
);
2762 // Generate default builders that requires all result type, operands, and
2763 // attributes as parameters.
2764 if (op
.skipDefaultBuilders())
2767 // We generate three classes of builders here:
2768 // 1. one having a stand-alone parameter for each operand / attribute, and
2769 genSeparateArgParamBuilder();
2770 // 2. one having an aggregated parameter for all result types / operands /
2772 genCollectiveParamBuilder();
2773 // 3. one having a stand-alone parameter for each operand and attribute,
2774 // use the first operand or attribute's type as all result types
2775 // to facilitate different call patterns.
2776 if (op
.getNumVariableLengthResults() == 0) {
2777 if (op
.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
2778 genUseOperandAsResultTypeSeparateParamBuilder();
2779 genUseOperandAsResultTypeCollectiveParamBuilder();
2781 if (op
.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
2782 genUseAttrAsResultTypeBuilder();
2786 void OpEmitter::genCollectiveParamBuilder() {
2787 int numResults
= op
.getNumResults();
2788 int numVariadicResults
= op
.getNumVariableLengthResults();
2789 int numNonVariadicResults
= numResults
- numVariadicResults
;
2791 int numOperands
= op
.getNumOperands();
2792 int numVariadicOperands
= op
.getNumVariableLengthOperands();
2793 int numNonVariadicOperands
= numOperands
- numVariadicOperands
;
2795 SmallVector
<MethodParameter
> paramList
;
2796 paramList
.emplace_back("::mlir::OpBuilder &", "");
2797 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2798 paramList
.emplace_back("::mlir::TypeRange", "resultTypes");
2799 paramList
.emplace_back("::mlir::ValueRange", "operands");
2800 // Provide default value for `attributes` when its the last parameter
2801 StringRef attributesDefaultValue
= op
.getNumVariadicRegions() ? "" : "{}";
2802 paramList
.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2803 "attributes", attributesDefaultValue
);
2804 if (op
.getNumVariadicRegions())
2805 paramList
.emplace_back("unsigned", "numRegions");
2807 auto *m
= opClass
.addStaticMethod("void", "build", std::move(paramList
));
2808 // If the builder is redundant, skip generating the method
2811 auto &body
= m
->body();
2814 if (numVariadicOperands
== 0 || numNonVariadicOperands
!= 0)
2815 body
<< " assert(operands.size()"
2816 << (numVariadicOperands
!= 0 ? " >= " : " == ")
2817 << numNonVariadicOperands
2818 << "u && \"mismatched number of parameters\");\n";
2819 body
<< " " << builderOpState
<< ".addOperands(operands);\n";
2822 body
<< " " << builderOpState
<< ".addAttributes(attributes);\n";
2824 // Create the correct number of regions
2825 if (int numRegions
= op
.getNumRegions()) {
2826 body
<< llvm::formatv(
2827 " for (unsigned i = 0; i != {0}; ++i)\n",
2828 (op
.getNumVariadicRegions() ? "numRegions" : Twine(numRegions
)));
2829 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
2833 if (numVariadicResults
== 0 || numNonVariadicResults
!= 0)
2834 body
<< " assert(resultTypes.size()"
2835 << (numVariadicResults
!= 0 ? " >= " : " == ") << numNonVariadicResults
2836 << "u && \"mismatched number of return types\");\n";
2837 body
<< " " << builderOpState
<< ".addTypes(resultTypes);\n";
2839 // Generate builder that infers type too.
2840 // TODO: Expand to handle successors.
2841 if (canInferType(op
) && op
.getNumSuccessors() == 0)
2842 genInferredTypeCollectiveParamBuilder();
2845 void OpEmitter::buildParamList(SmallVectorImpl
<MethodParameter
> ¶mList
,
2846 llvm::StringSet
<> &inferredAttributes
,
2847 SmallVectorImpl
<std::string
> &resultTypeNames
,
2848 TypeParamKind typeParamKind
,
2849 AttrParamKind attrParamKind
) {
2850 resultTypeNames
.clear();
2851 auto numResults
= op
.getNumResults();
2852 resultTypeNames
.reserve(numResults
);
2854 paramList
.emplace_back("::mlir::OpBuilder &", odsBuilder
);
2855 paramList
.emplace_back("::mlir::OperationState &", builderOpState
);
2857 switch (typeParamKind
) {
2858 case TypeParamKind::None
:
2860 case TypeParamKind::Separate
: {
2861 // Add parameters for all return types
2862 for (int i
= 0; i
< numResults
; ++i
) {
2863 const auto &result
= op
.getResult(i
);
2864 std::string resultName
= std::string(result
.name
);
2865 if (resultName
.empty())
2866 resultName
= std::string(formatv("resultType{0}", i
));
2869 result
.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
2871 paramList
.emplace_back(type
, resultName
, result
.isOptional());
2872 resultTypeNames
.emplace_back(std::move(resultName
));
2875 case TypeParamKind::Collective
: {
2876 paramList
.emplace_back("::mlir::TypeRange", "resultTypes");
2877 resultTypeNames
.push_back("resultTypes");
2881 // Add parameters for all arguments (operands and attributes).
2882 int defaultValuedAttrStartIndex
= op
.getNumArgs();
2883 // Successors and variadic regions go at the end of the parameter list, so no
2884 // default arguments are possible.
2885 bool hasTrailingParams
= op
.getNumSuccessors() || op
.getNumVariadicRegions();
2886 if (attrParamKind
== AttrParamKind::UnwrappedValue
&& !hasTrailingParams
) {
2887 // Calculate the start index from which we can attach default values in the
2888 // builder declaration.
2889 for (int i
= op
.getNumArgs() - 1; i
>= 0; --i
) {
2891 llvm::dyn_cast_if_present
<tblgen::NamedAttribute
*>(op
.getArg(i
));
2892 if (!namedAttr
|| !namedAttr
->attr
.hasDefaultValue())
2895 if (!canUseUnwrappedRawValue(namedAttr
->attr
))
2898 // Creating an APInt requires us to provide bitwidth, value, and
2899 // signedness, which is complicated compared to others. Similarly
2901 // TODO: Adjust the 'returnType' field of such attributes
2903 StringRef retType
= namedAttr
->attr
.getReturnType();
2904 if (retType
== "::llvm::APInt" || retType
== "::llvm::APFloat")
2907 defaultValuedAttrStartIndex
= i
;
2911 /// Collect any inferred attributes.
2912 for (const NamedTypeConstraint
&operand
: op
.getOperands()) {
2913 if (operand
.isVariadicOfVariadic()) {
2914 inferredAttributes
.insert(
2915 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr());
2919 for (int i
= 0, e
= op
.getNumArgs(), numOperands
= 0; i
< e
; ++i
) {
2920 Argument arg
= op
.getArg(i
);
2921 if (const auto *operand
=
2922 llvm::dyn_cast_if_present
<NamedTypeConstraint
*>(arg
)) {
2924 if (operand
->isVariadicOfVariadic())
2925 type
= "::llvm::ArrayRef<::mlir::ValueRange>";
2926 else if (operand
->isVariadic())
2927 type
= "::mlir::ValueRange";
2929 type
= "::mlir::Value";
2931 paramList
.emplace_back(type
, getArgumentName(op
, numOperands
++),
2932 operand
->isOptional());
2935 if ([[maybe_unused
]] const auto *operand
=
2936 llvm::dyn_cast_if_present
<NamedProperty
*>(arg
)) {
2940 const NamedAttribute
&namedAttr
= *arg
.get
<NamedAttribute
*>();
2941 const Attribute
&attr
= namedAttr
.attr
;
2943 // Inferred attributes don't need to be added to the param list.
2944 if (inferredAttributes
.contains(namedAttr
.name
))
2948 switch (attrParamKind
) {
2949 case AttrParamKind::WrappedAttr
:
2950 type
= attr
.getStorageType();
2952 case AttrParamKind::UnwrappedValue
:
2953 if (canUseUnwrappedRawValue(attr
))
2954 type
= attr
.getReturnType();
2956 type
= attr
.getStorageType();
2960 // Attach default value if requested and possible.
2961 std::string defaultValue
;
2962 if (attrParamKind
== AttrParamKind::UnwrappedValue
&&
2963 i
>= defaultValuedAttrStartIndex
) {
2964 defaultValue
+= attr
.getDefaultValue();
2966 paramList
.emplace_back(type
, namedAttr
.name
, StringRef(defaultValue
),
2970 /// Insert parameters for each successor.
2971 for (const NamedSuccessor
&succ
: op
.getSuccessors()) {
2973 succ
.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
2974 paramList
.emplace_back(type
, succ
.name
);
2977 /// Insert parameters for variadic regions.
2978 for (const NamedRegion
®ion
: op
.getRegions())
2979 if (region
.isVariadic())
2980 paramList
.emplace_back("unsigned",
2981 llvm::formatv("{0}Count", region
.name
).str());
2984 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
2985 MethodBody
&body
, llvm::StringSet
<> &inferredAttributes
,
2986 bool isRawValueAttr
) {
2987 // Push all operands to the result.
2988 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
) {
2989 std::string argName
= getArgumentName(op
, i
);
2990 const NamedTypeConstraint
&operand
= op
.getOperand(i
);
2991 if (operand
.constraint
.isVariadicOfVariadic()) {
2992 body
<< " for (::mlir::ValueRange range : " << argName
<< ")\n "
2993 << builderOpState
<< ".addOperands(range);\n";
2995 // Add the segment attribute.
2997 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
2998 << " for (::mlir::ValueRange range : " << argName
<< ")\n"
2999 << " rangeSegments.push_back(range.size());\n"
3000 << " auto rangeAttr = " << odsBuilder
3001 << ".getDenseI32ArrayAttr(rangeSegments);\n";
3002 if (op
.getDialect().usePropertiesForAttributes()) {
3003 body
<< " " << builderOpState
<< ".getOrAddProperties<Properties>()."
3004 << operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr()
3007 body
<< " " << builderOpState
<< ".addAttribute("
3008 << op
.getGetterName(
3009 operand
.constraint
.getVariadicOfVariadicSegmentSizeAttr())
3010 << "AttrName(" << builderOpState
<< ".name), rangeAttr);";
3016 if (operand
.isOptional())
3017 body
<< " if (" << argName
<< ")\n ";
3018 body
<< " " << builderOpState
<< ".addOperands(" << argName
<< ");\n";
3021 // If the operation has the operand segment size attribute, add it here.
3022 auto emitSegment
= [&]() {
3023 interleaveComma(llvm::seq
<int>(0, op
.getNumOperands()), body
, [&](int i
) {
3024 const NamedTypeConstraint
&operand
= op
.getOperand(i
);
3025 if (!operand
.isVariableLength()) {
3030 std::string operandName
= getArgumentName(op
, i
);
3031 if (operand
.isOptional()) {
3032 body
<< "(" << operandName
<< " ? 1 : 0)";
3033 } else if (operand
.isVariadicOfVariadic()) {
3034 body
<< llvm::formatv(
3035 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
3036 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
3037 "range.size(); }))",
3040 body
<< "static_cast<int32_t>(" << getArgumentName(op
, i
) << ".size())";
3044 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
3045 std::string sizes
= op
.getGetterName(operandSegmentAttrName
);
3046 if (op
.getDialect().usePropertiesForAttributes()) {
3047 body
<< " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
3049 body
<< "}), " << builderOpState
3050 << ".getOrAddProperties<Properties>()."
3051 "operandSegmentSizes.begin());\n";
3053 body
<< " " << builderOpState
<< ".addAttribute(" << sizes
<< "AttrName("
3054 << builderOpState
<< ".name), "
3055 << "odsBuilder.getDenseI32ArrayAttr({";
3061 // Push all attributes to the result.
3062 for (const auto &namedAttr
: op
.getAttributes()) {
3063 auto &attr
= namedAttr
.attr
;
3064 if (attr
.isDerivedAttr() || inferredAttributes
.contains(namedAttr
.name
))
3067 // TODO: The wrapping of optional is different for default or not, so don't
3068 // unwrap for default ones that would fail below.
3069 bool emitNotNullCheck
=
3070 (attr
.isOptional() && !attr
.hasDefaultValue()) ||
3071 (attr
.hasDefaultValue() && !isRawValueAttr
) ||
3072 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
3073 // the constant materialization is only for true case.
3074 (isRawValueAttr
&& attr
.getAttrDefName() == "UnitAttr");
3075 if (emitNotNullCheck
)
3076 body
.indent() << formatv("if ({0}) ", namedAttr
.name
) << "{\n";
3078 if (isRawValueAttr
&& canUseUnwrappedRawValue(attr
)) {
3079 // If this is a raw value, then we need to wrap it in an Attribute
3082 fctx
.withBuilder("odsBuilder");
3083 if (op
.getDialect().usePropertiesForAttributes()) {
3084 body
<< formatv(" {0}.getOrAddProperties<Properties>().{1} = {2};\n",
3085 builderOpState
, namedAttr
.name
,
3086 constBuildAttrFromParam(attr
, fctx
, namedAttr
.name
));
3088 body
<< formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3089 builderOpState
, op
.getGetterName(namedAttr
.name
),
3090 constBuildAttrFromParam(attr
, fctx
, namedAttr
.name
));
3093 if (op
.getDialect().usePropertiesForAttributes()) {
3094 body
<< formatv(" {0}.getOrAddProperties<Properties>().{1} = {1};\n",
3095 builderOpState
, namedAttr
.name
);
3097 body
<< formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3098 builderOpState
, op
.getGetterName(namedAttr
.name
),
3102 if (emitNotNullCheck
)
3103 body
.unindent() << " }\n";
3106 // Create the correct number of regions.
3107 for (const NamedRegion
®ion
: op
.getRegions()) {
3108 if (region
.isVariadic())
3109 body
<< formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
3112 body
<< " (void)" << builderOpState
<< ".addRegion();\n";
3115 // Push all successors to the result.
3116 for (const NamedSuccessor
&namedSuccessor
: op
.getSuccessors()) {
3117 body
<< formatv(" {0}.addSuccessors({1});\n", builderOpState
,
3118 namedSuccessor
.name
);
3122 void OpEmitter::genCanonicalizerDecls() {
3123 bool hasCanonicalizeMethod
= def
.getValueAsBit("hasCanonicalizeMethod");
3124 if (hasCanonicalizeMethod
) {
3125 // static LogicResult FooOp::
3126 // canonicalize(FooOp op, PatternRewriter &rewriter);
3127 SmallVector
<MethodParameter
> paramList
;
3128 paramList
.emplace_back(op
.getCppClassName(), "op");
3129 paramList
.emplace_back("::mlir::PatternRewriter &", "rewriter");
3130 auto *m
= opClass
.declareStaticMethod("::mlir::LogicalResult",
3131 "canonicalize", std::move(paramList
));
3132 ERROR_IF_PRUNED(m
, "canonicalize", op
);
3135 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
3136 // or if using a 'canonicalize' method.
3137 bool hasCanonicalizer
= def
.getValueAsBit("hasCanonicalizer");
3138 if (!hasCanonicalizeMethod
&& !hasCanonicalizer
)
3141 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
3142 // method, but not implementing 'getCanonicalizationPatterns' manually.
3143 bool hasBody
= hasCanonicalizeMethod
&& !hasCanonicalizer
;
3145 // Add a signature for getCanonicalizationPatterns if implemented by the
3146 // dialect or if synthesized to call 'canonicalize'.
3147 SmallVector
<MethodParameter
> paramList
;
3148 paramList
.emplace_back("::mlir::RewritePatternSet &", "results");
3149 paramList
.emplace_back("::mlir::MLIRContext *", "context");
3150 auto kind
= hasBody
? Method::Static
: Method::StaticDeclaration
;
3151 auto *method
= opClass
.addMethod("void", "getCanonicalizationPatterns", kind
,
3152 std::move(paramList
));
3154 // If synthesizing the method, fill it.
3156 ERROR_IF_PRUNED(method
, "getCanonicalizationPatterns", op
);
3157 method
->body() << " results.add(canonicalize);\n";
3161 void OpEmitter::genFolderDecls() {
3162 if (!op
.hasFolder())
3165 SmallVector
<MethodParameter
> paramList
;
3166 paramList
.emplace_back("FoldAdaptor", "adaptor");
3169 bool hasSingleResult
=
3170 op
.getNumResults() == 1 && op
.getNumVariableLengthResults() == 0;
3171 if (hasSingleResult
) {
3172 retType
= "::mlir::OpFoldResult";
3174 paramList
.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
3176 retType
= "::mlir::LogicalResult";
3179 auto *m
= opClass
.declareMethod(retType
, "fold", std::move(paramList
));
3180 ERROR_IF_PRUNED(m
, "fold", op
);
3183 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait
*opTrait
) {
3184 Interface interface
= opTrait
->getInterface();
3186 // Get the set of methods that should always be declared.
3187 auto alwaysDeclaredMethodsVec
= opTrait
->getAlwaysDeclaredMethods();
3188 llvm::StringSet
<> alwaysDeclaredMethods
;
3189 alwaysDeclaredMethods
.insert(alwaysDeclaredMethodsVec
.begin(),
3190 alwaysDeclaredMethodsVec
.end());
3192 for (const InterfaceMethod
&method
: interface
.getMethods()) {
3193 // Don't declare if the method has a body.
3194 if (method
.getBody())
3196 // Don't declare if the method has a default implementation and the op
3197 // didn't request that it always be declared.
3198 if (method
.getDefaultImplementation() &&
3199 !alwaysDeclaredMethods
.count(method
.getName()))
3201 // Interface methods are allowed to overlap with existing methods, so don't
3203 (void)genOpInterfaceMethod(method
);
3207 Method
*OpEmitter::genOpInterfaceMethod(const InterfaceMethod
&method
,
3209 SmallVector
<MethodParameter
> paramList
;
3210 for (const InterfaceMethod::Argument
&arg
: method
.getArguments())
3211 paramList
.emplace_back(arg
.type
, arg
.name
);
3213 auto props
= (method
.isStatic() ? Method::Static
: Method::None
) |
3214 (declaration
? Method::Declaration
: Method::None
);
3215 return opClass
.addMethod(method
.getReturnType(), method
.getName(), props
,
3216 std::move(paramList
));
3219 void OpEmitter::genOpInterfaceMethods() {
3220 for (const auto &trait
: op
.getTraits()) {
3221 if (const auto *opTrait
= dyn_cast
<tblgen::InterfaceTrait
>(&trait
))
3222 if (opTrait
->shouldDeclareMethods())
3223 genOpInterfaceMethods(opTrait
);
3227 void OpEmitter::genSideEffectInterfaceMethods() {
3228 enum EffectKind
{ Operand
, Result
, Symbol
, Static
};
3229 struct EffectLocation
{
3230 /// The effect applied.
3233 /// The index if the kind is not static.
3236 /// The kind of the location.
3240 StringMap
<SmallVector
<EffectLocation
, 1>> interfaceEffects
;
3241 auto resolveDecorators
= [&](Operator::var_decorator_range decorators
,
3242 unsigned index
, unsigned kind
) {
3243 for (auto decorator
: decorators
)
3244 if (SideEffect
*effect
= dyn_cast
<SideEffect
>(&decorator
)) {
3245 opClass
.addTrait(effect
->getInterfaceTrait());
3246 interfaceEffects
[effect
->getBaseEffectName()].push_back(
3247 EffectLocation
{*effect
, index
, kind
});
3251 // Collect effects that were specified via:
3253 for (const auto &trait
: op
.getTraits()) {
3254 const auto *opTrait
= dyn_cast
<tblgen::SideEffectTrait
>(&trait
);
3257 auto &effects
= interfaceEffects
[opTrait
->getBaseEffectName()];
3258 for (auto decorator
: opTrait
->getEffects())
3259 effects
.push_back(EffectLocation
{cast
<SideEffect
>(decorator
),
3260 /*index=*/0, EffectKind::Static
});
3262 /// Attributes and Operands.
3263 for (unsigned i
= 0, operandIt
= 0, e
= op
.getNumArgs(); i
!= e
; ++i
) {
3264 Argument arg
= op
.getArg(i
);
3265 if (arg
.is
<NamedTypeConstraint
*>()) {
3266 resolveDecorators(op
.getArgDecorators(i
), operandIt
, EffectKind::Operand
);
3270 if (arg
.is
<NamedProperty
*>())
3272 const NamedAttribute
*attr
= arg
.get
<NamedAttribute
*>();
3273 if (attr
->attr
.getBaseAttr().isSymbolRefAttr())
3274 resolveDecorators(op
.getArgDecorators(i
), i
, EffectKind::Symbol
);
3277 for (unsigned i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
)
3278 resolveDecorators(op
.getResultDecorators(i
), i
, EffectKind::Result
);
3280 // The code used to add an effect instance.
3281 // {0}: The effect class.
3282 // {1}: Optional value or symbol reference.
3283 // {2}: The side effect stage.
3284 // {3}: Does this side effect act on every single value of resource.
3285 // {4}: The resource class.
3286 const char *addEffectCode
=
3287 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
3289 for (auto &it
: interfaceEffects
) {
3290 // Generate the 'getEffects' method.
3291 std::string type
= llvm::formatv("::llvm::SmallVectorImpl<::mlir::"
3292 "SideEffects::EffectInstance<{0}>> &",
3295 auto *getEffects
= opClass
.addMethod("void", "getEffects",
3296 MethodParameter(type
, "effects"));
3297 ERROR_IF_PRUNED(getEffects
, "getEffects", op
);
3298 auto &body
= getEffects
->body();
3300 // Add effect instances for each of the locations marked on the operation.
3301 for (auto &location
: it
.second
) {
3302 StringRef effect
= location
.effect
.getName();
3303 StringRef resource
= location
.effect
.getResource();
3304 int stage
= (int)location
.effect
.getStage();
3305 bool effectOnFullRegion
= (int)location
.effect
.getEffectOnfullRegion();
3306 if (location
.kind
== EffectKind::Static
) {
3307 // A static instance has no attached value.
3308 body
<< llvm::formatv(addEffectCode
, effect
, "", stage
,
3309 effectOnFullRegion
, resource
)
3311 } else if (location
.kind
== EffectKind::Symbol
) {
3312 // A symbol reference requires adding the proper attribute.
3313 const auto *attr
= op
.getArg(location
.index
).get
<NamedAttribute
*>();
3314 std::string argName
= op
.getGetterName(attr
->name
);
3315 if (attr
->attr
.isOptional()) {
3316 body
<< " if (auto symbolRef = " << argName
<< "Attr())\n "
3317 << llvm::formatv(addEffectCode
, effect
, "symbolRef, ", stage
,
3318 effectOnFullRegion
, resource
)
3321 body
<< llvm::formatv(addEffectCode
, effect
, argName
+ "Attr(), ",
3322 stage
, effectOnFullRegion
, resource
)
3326 // Otherwise this is an operand/result, so we need to attach the Value.
3327 body
<< " for (::mlir::Value value : getODS"
3328 << (location
.kind
== EffectKind::Operand
? "Operands" : "Results")
3329 << "(" << location
.index
<< "))\n "
3330 << llvm::formatv(addEffectCode
, effect
, "value, ", stage
,
3331 effectOnFullRegion
, resource
)
3338 void OpEmitter::genTypeInterfaceMethods() {
3339 if (!op
.allResultTypesKnown())
3341 // Generate 'inferReturnTypes' method declaration using the interface method
3342 // declared in 'InferTypeOpInterface' op interface.
3344 cast
<InterfaceTrait
>(op
.getTrait("::mlir::InferTypeOpInterface::Trait"));
3345 Interface interface
= trait
->getInterface();
3346 Method
*method
= [&]() -> Method
* {
3347 for (const InterfaceMethod
&interfaceMethod
: interface
.getMethods()) {
3348 if (interfaceMethod
.getName() == "inferReturnTypes") {
3349 return genOpInterfaceMethod(interfaceMethod
, /*declaration=*/false);
3352 assert(0 && "unable to find inferReturnTypes interface method");
3355 ERROR_IF_PRUNED(method
, "inferReturnTypes", op
);
3356 auto &body
= method
->body();
3357 body
<< " inferredReturnTypes.resize(" << op
.getNumResults() << ");\n";
3360 fctx
.withBuilder("odsBuilder");
3361 fctx
.addSubst("_ctxt", "context");
3362 body
<< " ::mlir::Builder odsBuilder(context);\n";
3364 // Process the type inference graph in topological order, starting from types
3365 // that are always fully-inferred: operands and results with constructible
3366 // types. The type inference graph here will always be a DAG, so this gives
3367 // us the correct order for generating the types. -1 is a placeholder to
3368 // indicate the type for a result has not been generated.
3369 SmallVector
<int> constructedIndices(op
.getNumResults(), -1);
3370 int inferredTypeIdx
= 0;
3371 for (int numResults
= op
.getNumResults(); inferredTypeIdx
!= numResults
;) {
3372 for (int i
= 0, e
= op
.getNumResults(); i
!= e
; ++i
) {
3373 if (constructedIndices
[i
] >= 0)
3375 const InferredResultType
&infer
= op
.getInferredResultType(i
);
3376 std::string typeStr
;
3377 if (infer
.isArg()) {
3378 // If this is an operand, just index into operand list to access the
3380 auto arg
= op
.getArgToOperandOrAttribute(infer
.getIndex());
3381 if (arg
.kind() == Operator::OperandOrAttribute::Kind::Operand
) {
3382 typeStr
= ("operands[" + Twine(arg
.operandOrAttributeIndex()) +
3386 // If this is an attribute, index into the attribute dictionary.
3389 op
.getArg(arg
.operandOrAttributeIndex()).get
<NamedAttribute
*>();
3390 body
<< " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
3392 if (op
.getDialect().usePropertiesForAttributes()) {
3393 body
<< "(properties ? properties.as<Properties *>()->"
3396 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3398 attr
->name
+ "\")));\n";
3400 body
<< "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3402 attr
->name
+ "\"));\n";
3404 body
<< " if (!odsInferredTypeAttr" << inferredTypeIdx
3405 << ") return ::mlir::failure();\n";
3407 ("odsInferredTypeAttr" + Twine(inferredTypeIdx
) + ".getType()")
3410 } else if (std::optional
<StringRef
> builder
=
3411 op
.getResult(infer
.getResultIndex())
3412 .constraint
.getBuilderCall()) {
3413 typeStr
= tgfmt(*builder
, &fctx
).str();
3414 } else if (int index
= constructedIndices
[infer
.getResultIndex()];
3416 typeStr
= ("odsInferredType" + Twine(index
)).str();
3420 body
<< " ::mlir::Type odsInferredType" << inferredTypeIdx
++ << " = "
3421 << tgfmt(infer
.getTransformer(), &fctx
.withSelf(typeStr
)) << ";\n";
3422 constructedIndices
[i
] = inferredTypeIdx
- 1;
3425 for (auto [i
, index
] : llvm::enumerate(constructedIndices
))
3426 body
<< " inferredReturnTypes[" << i
<< "] = odsInferredType" << index
3428 body
<< " return ::mlir::success();";
3431 void OpEmitter::genParser() {
3432 if (hasStringAttribute(def
, "assemblyFormat"))
3435 if (!def
.getValueAsBit("hasCustomAssemblyFormat"))
3438 SmallVector
<MethodParameter
> paramList
;
3439 paramList
.emplace_back("::mlir::OpAsmParser &", "parser");
3440 paramList
.emplace_back("::mlir::OperationState &", "result");
3442 auto *method
= opClass
.declareStaticMethod("::mlir::ParseResult", "parse",
3443 std::move(paramList
));
3444 ERROR_IF_PRUNED(method
, "parse", op
);
3447 void OpEmitter::genPrinter() {
3448 if (hasStringAttribute(def
, "assemblyFormat"))
3451 // Check to see if this op uses a c++ format.
3452 if (!def
.getValueAsBit("hasCustomAssemblyFormat"))
3454 auto *method
= opClass
.declareMethod(
3455 "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
3456 ERROR_IF_PRUNED(method
, "print", op
);
3459 void OpEmitter::genVerifier() {
3461 opClass
.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
3462 ERROR_IF_PRUNED(implMethod
, "verifyInvariantsImpl", op
);
3463 auto &implBody
= implMethod
->body();
3464 bool useProperties
= emitHelper
.hasProperties();
3466 populateSubstitutions(emitHelper
, verifyCtx
);
3467 genAttributeVerifier(emitHelper
, verifyCtx
, implBody
, staticVerifierEmitter
,
3469 genOperandResultVerifier(implBody
, op
.getOperands(), "operand");
3470 genOperandResultVerifier(implBody
, op
.getResults(), "result");
3472 for (auto &trait
: op
.getTraits()) {
3473 if (auto *t
= dyn_cast
<tblgen::PredTrait
>(&trait
)) {
3474 implBody
<< tgfmt(" if (!($0))\n "
3475 "return emitOpError(\"failed to verify that $1\");\n",
3476 &verifyCtx
, tgfmt(t
->getPredTemplate(), &verifyCtx
),
3481 genRegionVerifier(implBody
);
3482 genSuccessorVerifier(implBody
);
3484 implBody
<< " return ::mlir::success();\n";
3486 // TODO: Some places use the `verifyInvariants` to do operation verification.
3487 // This may not act as their expectation because this doesn't call any
3488 // verifiers of native/interface traits. Needs to review those use cases and
3489 // see if we should use the mlir::verify() instead.
3490 auto *method
= opClass
.addMethod("::mlir::LogicalResult", "verifyInvariants");
3491 ERROR_IF_PRUNED(method
, "verifyInvariants", op
);
3492 auto &body
= method
->body();
3493 if (def
.getValueAsBit("hasVerifier")) {
3494 body
<< " if(::mlir::succeeded(verifyInvariantsImpl()) && "
3495 "::mlir::succeeded(verify()))\n";
3496 body
<< " return ::mlir::success();\n";
3497 body
<< " return ::mlir::failure();";
3499 body
<< " return verifyInvariantsImpl();";
3503 void OpEmitter::genCustomVerifier() {
3504 if (def
.getValueAsBit("hasVerifier")) {
3505 auto *method
= opClass
.declareMethod("::mlir::LogicalResult", "verify");
3506 ERROR_IF_PRUNED(method
, "verify", op
);
3509 if (def
.getValueAsBit("hasRegionVerifier")) {
3511 opClass
.declareMethod("::mlir::LogicalResult", "verifyRegions");
3512 ERROR_IF_PRUNED(method
, "verifyRegions", op
);
3516 void OpEmitter::genOperandResultVerifier(MethodBody
&body
,
3517 Operator::const_value_range values
,
3518 StringRef valueKind
) {
3519 // Check that an optional value is at most 1 element.
3521 // {0}: Value index.
3522 // {1}: "operand" or "result"
3523 const char *const verifyOptional
= R
"(
3524 if (valueGroup{0}.size() > 1) {
3525 return emitOpError("{1} group starting at
#") << index
3526 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
3529 // Check the types of a range of values.
3531 // {0}: Value index.
3532 // {1}: Type constraint function.
3533 // {2}: "operand" or "result"
3534 const char *const verifyValues
= R
"(
3535 for (auto v : valueGroup{0}) {
3536 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
3537 return ::mlir::failure();
3541 const auto canSkip
= [](const NamedTypeConstraint
&value
) {
3542 return !value
.hasPredicate() && !value
.isOptional() &&
3543 !value
.isVariadicOfVariadic();
3545 if (values
.empty() || llvm::all_of(values
, canSkip
))
3550 body
<< " {\n unsigned index = 0; (void)index;\n";
3552 for (const auto &staticValue
: llvm::enumerate(values
)) {
3553 const NamedTypeConstraint
&value
= staticValue
.value();
3555 bool hasPredicate
= value
.hasPredicate();
3556 bool isOptional
= value
.isOptional();
3557 bool isVariadicOfVariadic
= value
.isVariadicOfVariadic();
3558 if (!hasPredicate
&& !isOptional
&& !isVariadicOfVariadic
)
3560 body
<< formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
3561 // Capitalize the first letter to match the function name
3562 valueKind
.substr(0, 1).upper(), valueKind
.substr(1),
3563 staticValue
.index());
3565 // If the constraint is optional check that the value group has at most 1
3568 body
<< formatv(verifyOptional
, staticValue
.index(), valueKind
);
3569 } else if (isVariadicOfVariadic
) {
3571 " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
3572 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
3573 " return ::mlir::failure();\n",
3574 value
.constraint
.getVariadicOfVariadicSegmentSizeAttr(), value
.name
,
3575 staticValue
.index());
3578 // Otherwise, if there is no predicate there is nothing left to do.
3581 // Emit a loop to check all the dynamic values in the pack.
3582 StringRef constraintFn
=
3583 staticVerifierEmitter
.getTypeConstraintFn(value
.constraint
);
3584 body
<< formatv(verifyValues
, staticValue
.index(), constraintFn
, valueKind
);
3590 void OpEmitter::genRegionVerifier(MethodBody
&body
) {
3591 /// Code to verify a region.
3593 /// {0}: Getter for the regions.
3594 /// {1}: The region constraint.
3595 /// {2}: The region's name.
3596 /// {3}: The region description.
3597 const char *const verifyRegion
= R
"(
3598 for (auto ®ion : {0})
3599 if (::mlir::failed({1}(*this, region, "{2}", index++)))
3600 return ::mlir::failure();
3602 /// Get a single region.
3604 /// {0}: The region's index.
3605 const char *const getSingleRegion
=
3606 "::llvm::MutableArrayRef((*this)->getRegion({0}))";
3608 // If we have no regions, there is nothing more to do.
3609 const auto canSkip
= [](const NamedRegion
®ion
) {
3610 return region
.constraint
.getPredicate().isNull();
3612 auto regions
= op
.getRegions();
3613 if (regions
.empty() && llvm::all_of(regions
, canSkip
))
3616 body
<< " {\n unsigned index = 0; (void)index;\n";
3617 for (const auto &it
: llvm::enumerate(regions
)) {
3618 const auto ®ion
= it
.value();
3619 if (canSkip(region
))
3622 auto getRegion
= region
.isVariadic()
3623 ? formatv("{0}()", op
.getGetterName(region
.name
)).str()
3624 : formatv(getSingleRegion
, it
.index()).str();
3626 staticVerifierEmitter
.getRegionConstraintFn(region
.constraint
);
3627 body
<< formatv(verifyRegion
, getRegion
, constraintFn
, region
.name
);
3632 void OpEmitter::genSuccessorVerifier(MethodBody
&body
) {
3633 const char *const verifySuccessor
= R
"(
3634 for (auto *successor : {0})
3635 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
3636 return ::mlir::failure();
3638 /// Get a single successor.
3640 /// {0}: The successor's name.
3641 const char *const getSingleSuccessor
= "::llvm::MutableArrayRef({0}())";
3643 // If we have no successors, there is nothing more to do.
3644 const auto canSkip
= [](const NamedSuccessor
&successor
) {
3645 return successor
.constraint
.getPredicate().isNull();
3647 auto successors
= op
.getSuccessors();
3648 if (successors
.empty() && llvm::all_of(successors
, canSkip
))
3651 body
<< " {\n unsigned index = 0; (void)index;\n";
3653 for (auto it
: llvm::enumerate(successors
)) {
3654 const auto &successor
= it
.value();
3655 if (canSkip(successor
))
3659 formatv(successor
.isVariadic() ? "{0}()" : getSingleSuccessor
,
3660 successor
.name
, it
.index())
3663 staticVerifierEmitter
.getSuccessorConstraintFn(successor
.constraint
);
3664 body
<< formatv(verifySuccessor
, getSuccessor
, constraintFn
,
3670 /// Add a size count trait to the given operation class.
3671 static void addSizeCountTrait(OpClass
&opClass
, StringRef traitKind
,
3672 int numTotal
, int numVariadic
) {
3673 if (numVariadic
!= 0) {
3674 if (numTotal
== numVariadic
)
3675 opClass
.addTrait("::mlir::OpTrait::Variadic" + traitKind
+ "s");
3677 opClass
.addTrait("::mlir::OpTrait::AtLeastN" + traitKind
+ "s<" +
3678 Twine(numTotal
- numVariadic
) + ">::Impl");
3683 opClass
.addTrait("::mlir::OpTrait::Zero" + traitKind
+ "s");
3686 opClass
.addTrait("::mlir::OpTrait::One" + traitKind
);
3689 opClass
.addTrait("::mlir::OpTrait::N" + traitKind
+ "s<" + Twine(numTotal
) +
3695 void OpEmitter::genTraits() {
3696 // Add region size trait.
3697 unsigned numRegions
= op
.getNumRegions();
3698 unsigned numVariadicRegions
= op
.getNumVariadicRegions();
3699 addSizeCountTrait(opClass
, "Region", numRegions
, numVariadicRegions
);
3701 // Add result size traits.
3702 int numResults
= op
.getNumResults();
3703 int numVariadicResults
= op
.getNumVariableLengthResults();
3704 addSizeCountTrait(opClass
, "Result", numResults
, numVariadicResults
);
3706 // For single result ops with a known specific type, generate a OneTypedResult
3708 if (numResults
== 1 && numVariadicResults
== 0) {
3709 auto cppName
= op
.getResults().begin()->constraint
.getCPPClassName();
3710 opClass
.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName
+ ">::Impl");
3713 // Add successor size trait.
3714 unsigned numSuccessors
= op
.getNumSuccessors();
3715 unsigned numVariadicSuccessors
= op
.getNumVariadicSuccessors();
3716 addSizeCountTrait(opClass
, "Successor", numSuccessors
, numVariadicSuccessors
);
3718 // Add variadic size trait and normal op traits.
3719 int numOperands
= op
.getNumOperands();
3720 int numVariadicOperands
= op
.getNumVariableLengthOperands();
3722 // Add operand size trait.
3723 addSizeCountTrait(opClass
, "Operand", numOperands
, numVariadicOperands
);
3725 // The op traits defined internal are ensured that they can be verified
3727 for (const auto &trait
: op
.getTraits()) {
3728 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
3729 if (opTrait
->isStructuralOpTrait())
3730 opClass
.addTrait(opTrait
->getFullyQualifiedTraitName());
3734 // OpInvariants wrapps the verifyInvariants which needs to be run before
3735 // native/interface traits and after all the traits with `StructuralOpTrait`.
3736 opClass
.addTrait("::mlir::OpTrait::OpInvariants");
3738 if (emitHelper
.hasProperties())
3739 opClass
.addTrait("::mlir::BytecodeOpInterface::Trait");
3741 // Add the native and interface traits.
3742 for (const auto &trait
: op
.getTraits()) {
3743 if (auto *opTrait
= dyn_cast
<tblgen::NativeTrait
>(&trait
)) {
3744 if (!opTrait
->isStructuralOpTrait())
3745 opClass
.addTrait(opTrait
->getFullyQualifiedTraitName());
3746 } else if (auto *opTrait
= dyn_cast
<tblgen::InterfaceTrait
>(&trait
)) {
3747 opClass
.addTrait(opTrait
->getFullyQualifiedTraitName());
3752 void OpEmitter::genOpNameGetter() {
3753 auto *method
= opClass
.addStaticMethod
<Method::Constexpr
>(
3754 "::llvm::StringLiteral", "getOperationName");
3755 ERROR_IF_PRUNED(method
, "getOperationName", op
);
3756 method
->body() << " return ::llvm::StringLiteral(\"" << op
.getOperationName()
3760 void OpEmitter::genOpAsmInterface() {
3761 // If the user only has one results or specifically added the Asm trait,
3762 // then don't generate it for them. We specifically only handle multi result
3763 // operations, because the name of a single result in the common case is not
3764 // interesting(generally 'result'/'output'/etc.).
3765 // TODO: We could also add a flag to allow operations to opt in to this
3766 // generation, even if they only have a single operation.
3767 int numResults
= op
.getNumResults();
3768 if (numResults
<= 1 || op
.getTrait("::mlir::OpAsmOpInterface::Trait"))
3771 SmallVector
<StringRef
, 4> resultNames(numResults
);
3772 for (int i
= 0; i
!= numResults
; ++i
)
3773 resultNames
[i
] = op
.getResultName(i
);
3775 // Don't add the trait if none of the results have a valid name.
3776 if (llvm::all_of(resultNames
, [](StringRef name
) { return name
.empty(); }))
3778 opClass
.addTrait("::mlir::OpAsmOpInterface::Trait");
3780 // Generate the right accessor for the number of results.
3781 auto *method
= opClass
.addMethod(
3782 "void", "getAsmResultNames",
3783 MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
3784 ERROR_IF_PRUNED(method
, "getAsmResultNames", op
);
3785 auto &body
= method
->body();
3786 for (int i
= 0; i
!= numResults
; ++i
) {
3787 body
<< " auto resultGroup" << i
<< " = getODSResults(" << i
<< ");\n"
3788 << " if (!resultGroup" << i
<< ".empty())\n"
3789 << " setNameFn(*resultGroup" << i
<< ".begin(), \""
3790 << resultNames
[i
] << "\");\n";
3794 //===----------------------------------------------------------------------===//
3795 // OpOperandAdaptor emitter
3796 //===----------------------------------------------------------------------===//
3799 // Helper class to emit Op operand adaptors to an output stream. Operand
3800 // adaptors are wrappers around random access ranges that provide named operand
3801 // getters identical to those defined in the Op.
3802 // This currently generates 3 classes per Op:
3803 // * A Base class within the 'detail' namespace, which contains all logic and
3804 // members independent of the random access range that is indexed into.
3805 // In other words, it contains all the attribute and region getters.
3806 // * A templated class named '{OpName}GenericAdaptor' with a template parameter
3807 // 'RangeT' that is indexed into by the getters to access the operands.
3808 // It contains all getters to access operands and inherits from the previous
3810 // * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
3811 // with 'mlir::ValueRange' as template parameter. It adds a constructor from
3812 // an instance of the op type and a verify function.
3813 class OpOperandAdaptorEmitter
{
3816 emitDecl(const Operator
&op
,
3817 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
3820 emitDef(const Operator
&op
,
3821 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
3825 explicit OpOperandAdaptorEmitter(
3827 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
);
3829 // Add verification function. This generates a verify method for the adaptor
3830 // which verifies all the op-independent attribute constraints.
3831 void addVerification();
3833 // The operation for which to emit an adaptor.
3836 // The generated adaptor classes.
3837 Class genericAdaptorBase
;
3838 Class genericAdaptor
;
3841 // The emitter containing all of the locally emitted verification functions.
3842 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
;
3844 // Helper for emitting adaptor code.
3845 OpOrAdaptorHelper emitHelper
;
3849 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
3851 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
)
3852 : op(op
), genericAdaptorBase(op
.getGenericAdaptorName() + "Base"),
3853 genericAdaptor(op
.getGenericAdaptorName()), adaptor(op
.getAdaptorName()),
3854 staticVerifierEmitter(staticVerifierEmitter
),
3855 emitHelper(op
, /*emitForOp=*/false) {
3857 genericAdaptorBase
.declare
<VisibilityDeclaration
>(Visibility::Public
);
3858 bool useProperties
= emitHelper
.hasProperties();
3859 if (useProperties
) {
3860 // Define the properties struct with multiple members.
3861 using ConstArgument
=
3862 llvm::PointerUnion
<const AttributeMetadata
*, const NamedProperty
*>;
3863 SmallVector
<ConstArgument
> attrOrProperties
;
3864 for (const std::pair
<StringRef
, AttributeMetadata
> &it
:
3865 emitHelper
.getAttrMetadata()) {
3866 if (!it
.second
.constraint
|| !it
.second
.constraint
->isDerivedAttr())
3867 attrOrProperties
.push_back(&it
.second
);
3869 for (const NamedProperty
&prop
: op
.getProperties())
3870 attrOrProperties
.push_back(&prop
);
3871 if (emitHelper
.getOperandSegmentsSize())
3872 attrOrProperties
.push_back(&emitHelper
.getOperandSegmentsSize().value());
3873 if (emitHelper
.getResultSegmentsSize())
3874 attrOrProperties
.push_back(&emitHelper
.getResultSegmentsSize().value());
3875 assert(!attrOrProperties
.empty());
3876 std::string declarations
= " struct Properties {\n";
3877 llvm::raw_string_ostream
os(declarations
);
3878 std::string comparator
=
3879 " bool operator==(const Properties &rhs) const {\n"
3881 llvm::raw_string_ostream
comparatorOs(comparator
);
3882 for (const auto &attrOrProp
: attrOrProperties
) {
3883 if (const auto *namedProperty
=
3884 llvm::dyn_cast_if_present
<const NamedProperty
*>(attrOrProp
)) {
3885 StringRef name
= namedProperty
->name
;
3887 report_fatal_error("missing name for property");
3888 std::string camelName
=
3889 convertToCamelFromSnakeCase(name
, /*capitalizeFirst=*/true);
3890 auto &prop
= namedProperty
->prop
;
3891 // Generate the data member using the storage type.
3892 os
<< " using " << name
<< "Ty = " << prop
.getStorageType() << ";\n"
3893 << " " << name
<< "Ty " << name
;
3894 if (prop
.hasDefaultValue())
3895 os
<< " = " << prop
.getDefaultValue();
3896 comparatorOs
<< " rhs." << name
<< " == this->" << name
3898 // Emit accessors using the interface type.
3899 const char *accessorFmt
= R
"decl(;
3901 auto &propStorage = this->{2};
3904 void set{1}(const {0} &propValue) {
3905 auto &propStorage = this->{2};
3910 os
<< formatv(accessorFmt
, prop
.getInterfaceType(), camelName
, name
,
3911 tgfmt(prop
.getConvertFromStorageCall(),
3912 &fctx
.addSubst("_storage", propertyStorage
)),
3913 tgfmt(prop
.getAssignToStorageCall(),
3914 &fctx
.addSubst("_value", propertyValue
)
3915 .addSubst("_storage", propertyStorage
)));
3918 const auto *namedAttr
=
3919 llvm::dyn_cast_if_present
<const AttributeMetadata
*>(attrOrProp
);
3920 const Attribute
*attr
= nullptr;
3921 if (namedAttr
->constraint
)
3922 attr
= &*namedAttr
->constraint
;
3923 StringRef name
= namedAttr
->attrName
;
3925 report_fatal_error("missing name for property attr");
3926 std::string camelName
=
3927 convertToCamelFromSnakeCase(name
, /*capitalizeFirst=*/true);
3928 // Generate the data member using the storage type.
3929 StringRef storageType
;
3931 storageType
= attr
->getStorageType();
3933 if (name
!= operandSegmentAttrName
&& name
!= resultSegmentAttrName
) {
3934 report_fatal_error("unexpected AttributeMetadata");
3936 // TODO: update to use native integers.
3937 storageType
= "::mlir::DenseI32ArrayAttr";
3939 os
<< " using " << name
<< "Ty = " << storageType
<< ";\n"
3940 << " " << name
<< "Ty " << name
<< ";\n";
3941 comparatorOs
<< " rhs." << name
<< " == this->" << name
<< " &&\n";
3943 // Emit accessors using the interface type.
3945 const char *accessorFmt
= R
"decl(
3947 auto &propStorage = this->{1};
3948 return ::llvm::{2}<{3}>(propStorage);
3950 void set{0}(const {3} &propValue) {
3951 this->{1} = propValue;
3954 os
<< formatv(accessorFmt
, camelName
, name
,
3955 attr
->isOptional() || attr
->hasDefaultValue()
3956 ? "dyn_cast_or_null"
3961 comparatorOs
<< " true;\n }\n"
3962 " bool operator!=(const Properties &rhs) const {\n"
3963 " return !(*this == rhs);\n"
3965 comparatorOs
.flush();
3970 genericAdaptorBase
.declare
<ExtraClassDeclaration
>(std::move(declarations
));
3972 genericAdaptorBase
.declare
<VisibilityDeclaration
>(Visibility::Protected
);
3973 genericAdaptorBase
.declare
<Field
>("::mlir::DictionaryAttr", "odsAttrs");
3974 genericAdaptorBase
.declare
<Field
>("::std::optional<::mlir::OperationName>",
3977 genericAdaptorBase
.declare
<Field
>("Properties", "properties");
3978 genericAdaptorBase
.declare
<Field
>("::mlir::RegionRange", "odsRegions");
3980 genericAdaptor
.addTemplateParam("RangeT");
3981 genericAdaptor
.addField("RangeT", "odsOperands");
3982 genericAdaptor
.addParent(
3983 ParentClass("detail::" + genericAdaptorBase
.getClassName()));
3984 genericAdaptor
.declare
<UsingDeclaration
>(
3985 "ValueT", "::llvm::detail::ValueOfRange<RangeT>");
3986 genericAdaptor
.declare
<UsingDeclaration
>(
3987 "Base", "detail::" + genericAdaptorBase
.getClassName());
3989 const auto *attrSizedOperands
=
3990 op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
3992 SmallVector
<MethodParameter
> paramList
;
3993 paramList
.emplace_back("::mlir::DictionaryAttr", "attrs",
3994 attrSizedOperands
? "" : "nullptr");
3996 paramList
.emplace_back("const Properties &", "properties", "{}");
3998 paramList
.emplace_back("const ::mlir::EmptyProperties &", "properties",
4000 paramList
.emplace_back("::mlir::RegionRange", "regions", "{}");
4001 auto *baseConstructor
= genericAdaptorBase
.addConstructor(paramList
);
4002 baseConstructor
->addMemberInitializer("odsAttrs", "attrs");
4004 baseConstructor
->addMemberInitializer("properties", "properties");
4005 baseConstructor
->addMemberInitializer("odsRegions", "regions");
4007 MethodBody
&body
= baseConstructor
->body();
4008 body
.indent() << "if (odsAttrs)\n";
4009 body
.indent() << formatv(
4010 "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
4011 op
.getOperationName());
4013 paramList
.insert(paramList
.begin(), MethodParameter("RangeT", "values"));
4014 auto *constructor
= genericAdaptor
.addConstructor(paramList
);
4015 constructor
->addMemberInitializer("Base", "attrs, properties, regions");
4016 constructor
->addMemberInitializer("odsOperands", "values");
4018 // Add a forwarding constructor to the previous one that accepts
4019 // OpaqueProperties instead and check for null and perform the cast to the
4020 // actual properties type.
4021 paramList
[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
4022 paramList
[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
4023 auto *opaquePropertiesConstructor
=
4024 genericAdaptor
.addConstructor(std::move(paramList
));
4025 if (useProperties
) {
4026 opaquePropertiesConstructor
->addMemberInitializer(
4027 genericAdaptor
.getClassName(),
4030 "(properties ? *properties.as<Properties *>() : Properties{}), "
4033 opaquePropertiesConstructor
->addMemberInitializer(
4034 genericAdaptor
.getClassName(),
4037 "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
4038 "::mlir::EmptyProperties{}), "
4043 // Create constructors constructing the adaptor from an instance of the op.
4044 // This takes the attributes, properties and regions from the op instance
4045 // and the value range from the parameter.
4047 // Base class is in the cpp file and can simply access the members of the op
4048 // class to initialize the template independent fields.
4049 auto *constructor
= genericAdaptorBase
.addConstructor(
4050 MethodParameter(op
.getCppClassName(), "op"));
4051 constructor
->addMemberInitializer(
4052 genericAdaptorBase
.getClassName(),
4053 llvm::Twine(!useProperties
? "op->getAttrDictionary()"
4054 : "op->getDiscardableAttrDictionary()") +
4055 ", op.getProperties(), op->getRegions()");
4057 // Generic adaptor is templated and therefore defined inline in the header.
4058 // We cannot use the Op class here as it is an incomplete type (we have a
4059 // circular reference between the two).
4060 // Use a template trick to make the constructor be instantiated at call site
4061 // when the op class is complete.
4062 constructor
= genericAdaptor
.addConstructor(
4063 MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
4064 constructor
->addTemplateParam("LateInst = " + op
.getCppClassName());
4065 constructor
->addTemplateParam(
4066 "= std::enable_if_t<std::is_same_v<LateInst, " + op
.getCppClassName() +
4068 constructor
->addMemberInitializer("Base", "op");
4069 constructor
->addMemberInitializer("odsOperands", "values");
4072 std::string sizeAttrInit
;
4073 if (op
.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
4074 if (op
.getDialect().usePropertiesForAttributes())
4076 formatv(adapterSegmentSizeAttrInitCodeProperties
,
4077 llvm::formatv("getProperties().operandSegmentSizes"));
4079 sizeAttrInit
= formatv(adapterSegmentSizeAttrInitCode
,
4080 emitHelper
.getAttr(operandSegmentAttrName
));
4082 generateNamedOperandGetters(op
, genericAdaptor
,
4083 /*genericAdaptorBase=*/&genericAdaptorBase
,
4084 /*sizeAttrInit=*/sizeAttrInit
,
4085 /*rangeType=*/"RangeT",
4086 /*rangeElementType=*/"ValueT",
4087 /*rangeBeginCall=*/"odsOperands.begin()",
4088 /*rangeSizeCall=*/"odsOperands.size()",
4089 /*getOperandCallPattern=*/"odsOperands[{0}]");
4091 // Any invalid overlap for `getOperands` will have been diagnosed before
4093 if (auto *m
= genericAdaptor
.addMethod("RangeT", "getOperands"))
4094 m
->body() << " return odsOperands;";
4097 fctx
.withBuilder("::mlir::Builder(odsAttrs.getContext())");
4099 // Generate named accessor with Attribute return type.
4100 auto emitAttrWithStorageType
= [&](StringRef name
, StringRef emitName
,
4103 genericAdaptorBase
.addMethod(attr
.getStorageType(), emitName
+ "Attr");
4104 ERROR_IF_PRUNED(method
, "Adaptor::" + emitName
+ "Attr", op
);
4105 auto &body
= method
->body().indent();
4107 body
<< "assert(odsAttrs && \"no attributes when constructing "
4110 "auto attr = ::llvm::{1}<{2}>({0});\n", emitHelper
.getAttr(name
),
4111 attr
.hasDefaultValue() || attr
.isOptional() ? "dyn_cast_or_null"
4113 attr
.getStorageType());
4115 if (attr
.hasDefaultValue() && attr
.isOptional()) {
4116 // Use the default value if attribute is not set.
4117 // TODO: this is inefficient, we are recreating the attribute for every
4118 // call. This should be set instead.
4119 std::string defaultValue
= std::string(
4120 tgfmt(attr
.getConstBuilderTemplate(), &fctx
, attr
.getDefaultValue()));
4121 body
<< "if (!attr)\n attr = " << defaultValue
<< ";\n";
4123 body
<< "return attr;\n";
4126 if (useProperties
) {
4127 auto *m
= genericAdaptorBase
.addInlineMethod("const Properties &",
4129 ERROR_IF_PRUNED(m
, "Adaptor::getProperties", op
);
4130 m
->body() << " return properties;";
4134 genericAdaptorBase
.addMethod("::mlir::DictionaryAttr", "getAttributes");
4135 ERROR_IF_PRUNED(m
, "Adaptor::getAttributes", op
);
4136 m
->body() << " return odsAttrs;";
4138 for (auto &namedAttr
: op
.getAttributes()) {
4139 const auto &name
= namedAttr
.name
;
4140 const auto &attr
= namedAttr
.attr
;
4141 if (attr
.isDerivedAttr())
4143 std::string emitName
= op
.getGetterName(name
);
4144 emitAttrWithStorageType(name
, emitName
, attr
);
4145 emitAttrGetterWithReturnType(fctx
, genericAdaptorBase
, op
, emitName
, attr
);
4148 unsigned numRegions
= op
.getNumRegions();
4149 for (unsigned i
= 0; i
< numRegions
; ++i
) {
4150 const auto ®ion
= op
.getRegion(i
);
4151 if (region
.name
.empty())
4154 // Generate the accessors for a variadic region.
4155 std::string name
= op
.getGetterName(region
.name
);
4156 if (region
.isVariadic()) {
4157 auto *m
= genericAdaptorBase
.addMethod("::mlir::RegionRange", name
);
4158 ERROR_IF_PRUNED(m
, "Adaptor::" + name
, op
);
4159 m
->body() << formatv(" return odsRegions.drop_front({0});", i
);
4163 auto *m
= genericAdaptorBase
.addMethod("::mlir::Region &", name
);
4164 ERROR_IF_PRUNED(m
, "Adaptor::" + name
, op
);
4165 m
->body() << formatv(" return *odsRegions[{0}];", i
);
4167 if (numRegions
> 0) {
4168 // Any invalid overlap for `getRegions` will have been diagnosed before
4171 genericAdaptorBase
.addMethod("::mlir::RegionRange", "getRegions"))
4172 m
->body() << " return odsRegions;";
4175 StringRef genericAdaptorClassName
= genericAdaptor
.getClassName();
4176 adaptor
.addParent(ParentClass(genericAdaptorClassName
))
4177 .addTemplateParam("::mlir::ValueRange");
4178 adaptor
.declare
<VisibilityDeclaration
>(Visibility::Public
);
4179 adaptor
.declare
<UsingDeclaration
>(genericAdaptorClassName
+
4180 "::" + genericAdaptorClassName
);
4182 // Constructor taking the Op as single parameter.
4184 adaptor
.addConstructor(MethodParameter(op
.getCppClassName(), "op"));
4185 constructor
->addMemberInitializer(genericAdaptorClassName
,
4186 "op->getOperands(), op");
4189 // Add verification function.
4192 genericAdaptorBase
.finalize();
4193 genericAdaptor
.finalize();
4197 void OpOperandAdaptorEmitter::addVerification() {
4198 auto *method
= adaptor
.addMethod("::mlir::LogicalResult", "verify",
4199 MethodParameter("::mlir::Location", "loc"));
4200 ERROR_IF_PRUNED(method
, "verify", op
);
4201 auto &body
= method
->body();
4202 bool useProperties
= emitHelper
.hasProperties();
4204 FmtContext verifyCtx
;
4205 populateSubstitutions(emitHelper
, verifyCtx
);
4206 genAttributeVerifier(emitHelper
, verifyCtx
, body
, staticVerifierEmitter
,
4209 body
<< " return ::mlir::success();";
4212 void OpOperandAdaptorEmitter::emitDecl(
4214 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
4216 OpOperandAdaptorEmitter
emitter(op
, staticVerifierEmitter
);
4218 NamespaceEmitter
ns(os
, "detail");
4219 emitter
.genericAdaptorBase
.writeDeclTo(os
);
4221 emitter
.genericAdaptor
.writeDeclTo(os
);
4222 emitter
.adaptor
.writeDeclTo(os
);
4225 void OpOperandAdaptorEmitter::emitDef(
4227 const StaticVerifierFunctionEmitter
&staticVerifierEmitter
,
4229 OpOperandAdaptorEmitter
emitter(op
, staticVerifierEmitter
);
4231 NamespaceEmitter
ns(os
, "detail");
4232 emitter
.genericAdaptorBase
.writeDefTo(os
);
4234 emitter
.genericAdaptor
.writeDefTo(os
);
4235 emitter
.adaptor
.writeDefTo(os
);
4238 // Emits the opcode enum and op classes.
4239 static void emitOpClasses(const RecordKeeper
&recordKeeper
,
4240 const std::vector
<Record
*> &defs
, raw_ostream
&os
,
4242 // First emit forward declaration for each class, this allows them to refer
4243 // to each others in traits for example.
4245 os
<< "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
4246 os
<< "#undef GET_OP_FWD_DEFINES\n";
4247 for (auto *def
: defs
) {
4249 NamespaceEmitter
emitter(os
, op
.getCppNamespace());
4250 os
<< "class " << op
.getCppClassName() << ";\n";
4255 IfDefScope
scope("GET_OP_CLASSES", os
);
4259 // Generate all of the locally instantiated methods first.
4260 StaticVerifierFunctionEmitter
staticVerifierEmitter(os
, recordKeeper
);
4261 os
<< formatv(opCommentHeader
, "Local Utility Method", "Definitions");
4262 staticVerifierEmitter
.emitOpConstraints(defs
, emitDecl
);
4264 for (auto *def
: defs
) {
4268 NamespaceEmitter
emitter(os
, op
.getCppNamespace());
4269 os
<< formatv(opCommentHeader
, op
.getQualCppClassName(),
4271 OpOperandAdaptorEmitter::emitDecl(op
, staticVerifierEmitter
, os
);
4272 OpEmitter::emitDecl(op
, os
, staticVerifierEmitter
);
4274 // Emit the TypeID explicit specialization to have a single definition.
4275 if (!op
.getCppNamespace().empty())
4276 os
<< "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op
.getCppNamespace()
4277 << "::" << op
.getCppClassName() << ")\n\n";
4280 NamespaceEmitter
emitter(os
, op
.getCppNamespace());
4281 os
<< formatv(opCommentHeader
, op
.getQualCppClassName(), "definitions");
4282 OpOperandAdaptorEmitter::emitDef(op
, staticVerifierEmitter
, os
);
4283 OpEmitter::emitDef(op
, os
, staticVerifierEmitter
);
4285 // Emit the TypeID explicit specialization to have a single definition.
4286 if (!op
.getCppNamespace().empty())
4287 os
<< "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op
.getCppNamespace()
4288 << "::" << op
.getCppClassName() << ")\n\n";
4293 // Emits a comma-separated list of the ops.
4294 static void emitOpList(const std::vector
<Record
*> &defs
, raw_ostream
&os
) {
4295 IfDefScope
scope("GET_OP_LIST", os
);
4298 // TODO: We are constructing the Operator wrapper instance just for
4299 // getting it's qualified class name here. Reduce the overhead by having a
4300 // lightweight version of Operator class just for that purpose.
4301 defs
, [&os
](Record
*def
) { os
<< Operator(def
).getQualCppClassName(); },
4302 [&os
]() { os
<< ",\n"; });
4305 static bool emitOpDecls(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
4306 emitSourceFileHeader("Op Declarations", os
, recordKeeper
);
4308 std::vector
<Record
*> defs
= getRequestedOpDefinitions(recordKeeper
);
4309 emitOpClasses(recordKeeper
, defs
, os
, /*emitDecl=*/true);
4314 static bool emitOpDefs(const RecordKeeper
&recordKeeper
, raw_ostream
&os
) {
4315 emitSourceFileHeader("Op Definitions", os
, recordKeeper
);
4317 std::vector
<Record
*> defs
= getRequestedOpDefinitions(recordKeeper
);
4318 emitOpList(defs
, os
);
4319 emitOpClasses(recordKeeper
, defs
, os
, /*emitDecl=*/false);
4324 static mlir::GenRegistration
4325 genOpDecls("gen-op-decls", "Generate op declarations",
4326 [](const RecordKeeper
&records
, raw_ostream
&os
) {
4327 return emitOpDecls(records
, os
);
4330 static mlir::GenRegistration
genOpDefs("gen-op-defs", "Generate op definitions",
4331 [](const RecordKeeper
&records
,
4333 return emitOpDefs(records
, os
);