[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-tblgen / OpDefinitionsGen.cpp
blob0fc750c7bbc88736247acb9ed71a5ce7daa1e4bd
1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
12 //===----------------------------------------------------------------------===//
14 #include "OpClass.h"
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Class.h"
20 #include "mlir/TableGen/CodeGenHelpers.h"
21 #include "mlir/TableGen/Format.h"
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Interfaces.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Property.h"
26 #include "mlir/TableGen/SideEffects.h"
27 #include "mlir/TableGen/Trait.h"
28 #include "llvm/ADT/BitVector.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/Sequence.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSet.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/TableGen/Error.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
42 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
44 using namespace llvm;
45 using namespace mlir;
46 using namespace mlir::tblgen;
// Fixed spellings that appear in generated C++: the `tblgen_` prefix, the
// fallback name for unnamed operands, builder/state parameter names, and the
// parameter names used by generated property helpers.
static constexpr const char *tblgenNamePrefix = "tblgen_";
static constexpr const char *generatedArgName = "odsArg";
static constexpr const char *odsBuilder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";
static constexpr const char *propertyStorage = "propStorage";
static constexpr const char *propertyValue = "propValue";
static constexpr const char *propertyAttr = "propAttr";
static constexpr const char *propertyDiag = "emitError";

/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
static constexpr const char *operandSegmentAttrName = "operandSegmentSizes";
static constexpr const char *resultSegmentAttrName = "resultSegmentSizes";
/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
/// lookup.
///
/// {0}: Code snippet to get the attribute's name or identifier.
/// {1}: The lower bound on the sorted subrange.
/// {2}: The upper bound on the sorted subrange.
/// {3}: Code snippet to get the array of named attributes.
/// {4}: "Named" to get the named attribute.
///
/// The emitted call searches `[{3}.begin() + {1}, {3}.end() - {2})`: the full
/// attribute array with {1} entries skipped at the front and {2} entries
/// skipped at the back.
static const char *const subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
///
/// The emitted code reads a variable named `index` (the static index of the
/// declared operand/result), which must be in scope where the snippet is
/// pasted. It returns a `{start, size}` pair. `{{` is formatv's escape for a
/// literal `{`.
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The emitted code expects `index` and an indexable `sizeAttr` to be in scope.
/// The start is the sum of the preceding segment sizes.
/// NOTE(review): this snippet uses unescaped `{`/`}` and has no `{N}`
/// placeholders, so it is presumably emitted verbatim rather than through
/// formatv. Confirm at the use site.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttr[i];
  return {start, sizeAttr[index]};
)";
/// The code snippet to initialize the sizes for the value range calculation.
/// Adaptor variant: asserts the attribute is present, then binds `sizeAttr`.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert({0} && "missing segment size attribute for op");
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
/// Adaptor variant used when segment sizes are stored as properties: binds
/// `sizeAttr` directly as an `ArrayRef<int32_t>` instead of casting an
/// attribute.
///
/// {0}: The code to get the segment-size storage.
static const char *const adapterSegmentSizeAttrInitCodeProperties = R"(
  ::llvm::ArrayRef<int32_t> sizeAttr = {0};
)";

/// The code snippet to initialize the sizes for the value range calculation.
/// Op variant: casts without a presence assert.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";

/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
/// The flattened operand range for operand {1} is split into consecutive
/// groups whose lengths are given by `{0}()`.
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
/// {2}: The range type of adaptor.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizes = {0}();

  ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
  for (int i = 0, e = sizes.size(); i < e; ++i) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values.
/// {1} must evaluate to a pair whose `.first` is the start offset and whose
/// `.second` is the length.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
/// Read operand/result segment_size from bytecode.
/// For bytecode version >= 6 the sizes are read as a sparse array straight into
/// the storage. `$_reader` / `$_storage` are tgfmt placeholders filled in by the
/// property code generator.
static const char *const readBytecodeSegmentSizeNative = R"(
  if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
)";

/// Legacy (bytecode version < 6) reader: the sizes were encoded as a
/// DenseI32ArrayAttr, which is decoded and copied into `prop.$_propName`.
/// NOTE(review): only arrays longer than the storage are rejected. A shorter
/// array is accepted and overwrites just the leading entries. Confirm that
/// this is intended.
static const char *const readBytecodeSegmentSizeLegacy = R"(
  if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    ::mlir::DenseI32ArrayAttr attr;
    if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
    if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
      $_reader.emitError("size mismatch for operand/result_segment_size");
      return ::mlir::failure();
    }
    ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
  }
)";

/// Write operand/result segment_size to bytecode.
/// Native encoding for bytecode version >= 6: a sparse array.
static const char *const writeBytecodeSegmentSizeNative = R"(
  if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
)";

/// Write operand/result segment_size to bytecode.
/// Legacy encoding for bytecode version < 6: a DenseI32ArrayAttr built from
/// `prop.$_propName`.
static const char *const writeBytecodeSegmentSizeLegacy = R"(
  if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
  }
)";

/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
202 //===----------------------------------------------------------------------===//
203 // Utility structs and functions
204 //===----------------------------------------------------------------------===//
// Replaces all occurrences of `match` in `str` with `substitute` and returns
// the result. Inserted text is never rescanned, so `substitute` may itself
// contain `match` without causing repeated replacement. An empty `match`
// leaves `str` unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern "matches" at every position, including right after each
  // insertion, so the scan below would never terminate.
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  while ((scanLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(scanLoc, match.size(), substitute);
    // Resume after the inserted text so it is not matched again.
    scanLoc += substitute.size();
  }
  return str;
}
217 // Returns whether the record has a value of the given name that can be returned
218 // via getValueAsString.
219 static inline bool hasStringAttribute(const Record &record,
220 StringRef fieldName) {
221 auto *valueInit = record.getValueInit(fieldName);
222 return isa<StringInit>(valueInit);
225 static std::string getArgumentName(const Operator &op, int index) {
226 const auto &operand = op.getOperand(index);
227 if (!operand.name.empty())
228 return std::string(operand.name);
229 return std::string(formatv("{0}_{1}", generatedArgName, index));
232 // Returns true if we can use unwrapped value for the given `attr` in builders.
233 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
234 return attr.getReturnType() != attr.getStorageType() &&
235 // We need to wrap the raw value into an attribute in the builder impl
236 // so we need to make sure that the attribute specifies how to do that.
237 !attr.getConstBuilderTemplate().empty();
240 /// Build an attribute from a parameter value using the constant builder.
241 static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
242 FmtContext &fctx,
243 StringRef paramName) {
244 std::string builderTemplate = attr.getConstBuilderTemplate().str();
246 // For StringAttr, its constant builder call will wrap the input in
247 // quotes, which is correct for normal string literals, but incorrect
248 // here given we use function arguments. So we need to strip the
249 // wrapping quotes.
250 if (StringRef(builderTemplate).contains("\"$0\""))
251 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
253 return tgfmt(builderTemplate, &fctx, paramName).str();
256 namespace {
257 /// Metadata on a registered attribute. Given that attributes are stored in
258 /// sorted order on operations, we can use information from ODS to deduce the
259 /// number of required attributes less and and greater than each attribute,
260 /// allowing us to search only a subrange of the attributes in ODS-generated
261 /// getters.
262 struct AttributeMetadata {
263 /// The attribute name.
264 StringRef attrName;
265 /// Whether the attribute is required.
266 bool isRequired;
267 /// The ODS attribute constraint. Not present for implicit attributes.
268 std::optional<Attribute> constraint;
269 /// The number of required attributes less than this attribute.
270 unsigned lowerBound = 0;
271 /// The number of required attributes greater than this attribute.
272 unsigned upperBound = 0;
275 /// Helper class to select between OpAdaptor and Op code templates.
276 class OpOrAdaptorHelper {
277 public:
278 OpOrAdaptorHelper(const Operator &op, bool emitForOp)
279 : op(op), emitForOp(emitForOp) {
280 computeAttrMetadata();
283 /// Object that wraps a functor in a stream operator for interop with
284 /// llvm::formatv.
285 class Formatter {
286 public:
287 template <typename Functor>
288 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
290 std::string str() const {
291 std::string result;
292 llvm::raw_string_ostream os(result);
293 os << *this;
294 return os.str();
297 private:
298 std::function<raw_ostream &(raw_ostream &)> func;
300 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
301 return fmt.func(os);
305 // Generate code for getting an attribute.
306 Formatter getAttr(StringRef attrName, bool isNamed = false) const {
307 assert(attrMetadata.count(attrName) && "expected attribute metadata");
308 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
309 const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
310 if (hasProperties()) {
311 assert(!isNamed);
312 return os << "getProperties()." << attrName;
314 return os << formatv(subrangeGetAttr, getAttrName(attrName),
315 attr.lowerBound, attr.upperBound, getAttrRange(),
316 isNamed ? "Named" : "");
320 // Generate code for getting the name of an attribute.
321 Formatter getAttrName(StringRef attrName) const {
322 return [this, attrName](raw_ostream &os) -> raw_ostream & {
323 if (emitForOp)
324 return os << op.getGetterName(attrName) << "AttrName()";
325 return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
326 op.getGetterName(attrName));
330 // Get the code snippet for getting the named attribute range.
331 StringRef getAttrRange() const {
332 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
335 // Get the prefix code for emitting an error.
336 Formatter emitErrorPrefix() const {
337 return [this](raw_ostream &os) -> raw_ostream & {
338 if (emitForOp)
339 return os << "emitOpError(";
340 return os << formatv("emitError(loc, \"'{0}' op \"",
341 op.getOperationName());
345 // Get the call to get an operand or segment of operands.
346 Formatter getOperand(unsigned index) const {
347 return [this, index](raw_ostream &os) -> raw_ostream & {
348 return os << formatv(op.getOperand(index).isVariadic()
349 ? "this->getODSOperands({0})"
350 : "(*this->getODSOperands({0}).begin())",
351 index);
355 // Get the call to get a result of segment of results.
356 Formatter getResult(unsigned index) const {
357 return [this, index](raw_ostream &os) -> raw_ostream & {
358 if (!emitForOp)
359 return os << "<no results should be generated>";
360 return os << formatv(op.getResult(index).isVariadic()
361 ? "this->getODSResults({0})"
362 : "(*this->getODSResults({0}).begin())",
363 index);
367 // Return whether an op instance is available.
368 bool isEmittingForOp() const { return emitForOp; }
370 // Return the ODS operation wrapper.
371 const Operator &getOp() const { return op; }
373 // Get the attribute metadata sorted by name.
374 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
375 return attrMetadata;
378 /// Returns whether to emit a `Properties` struct for this operation or not.
379 bool hasProperties() const {
380 if (!op.getProperties().empty())
381 return true;
382 if (!op.getDialect().usePropertiesForAttributes())
383 return false;
384 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") ||
385 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
386 return true;
387 return llvm::any_of(getAttrMetadata(),
388 [](const std::pair<StringRef, AttributeMetadata> &it) {
389 return !it.second.constraint ||
390 !it.second.constraint->isDerivedAttr();
394 std::optional<NamedProperty> &getOperandSegmentsSize() {
395 return operandSegmentsSize;
398 std::optional<NamedProperty> &getResultSegmentsSize() {
399 return resultSegmentsSize;
402 uint32_t getOperandSegmentSizesLegacyIndex() {
403 return operandSegmentSizesLegacyIndex;
406 uint32_t getResultSegmentSizesLegacyIndex() {
407 return resultSegmentSizesLegacyIndex;
410 private:
411 // Compute the attribute metadata.
412 void computeAttrMetadata();
414 // The operation ODS wrapper.
415 const Operator &op;
416 // True if code is being generate for an op. False for an adaptor.
417 const bool emitForOp;
419 // The attribute metadata, mapped by name.
420 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
422 // Property
423 std::optional<NamedProperty> operandSegmentsSize;
424 std::string operandSegmentsSizeStorage;
425 std::optional<NamedProperty> resultSegmentsSize;
426 std::string resultSegmentsSizeStorage;
428 // Indices to store the position in the emission order of the operand/result
429 // segment sizes attribute if emitted as part of the properties for legacy
430 // bytecode encodings, i.e. versions less than 6.
431 uint32_t operandSegmentSizesLegacyIndex = 0;
432 uint32_t resultSegmentSizesLegacyIndex = 0;
434 // The number of required attributes.
435 unsigned numRequired;
438 } // namespace
440 void OpOrAdaptorHelper::computeAttrMetadata() {
441 // Enumerate the attribute names of this op, ensuring the attribute names are
442 // unique in case implicit attributes are explicitly registered.
443 for (const NamedAttribute &namedAttr : op.getAttributes()) {
444 Attribute attr = namedAttr.attr;
445 bool isOptional =
446 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
447 attrMetadata.insert(
448 {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
451 auto makeProperty = [&](StringRef storageType) {
452 return Property(
453 /*storageType=*/storageType,
454 /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
455 /*convertFromStorageCall=*/"$_storage",
456 /*assignToStorageCall=*/
457 "::llvm::copy($_value, $_storage.begin())",
458 /*convertToAttributeCall=*/
459 "::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage)",
460 /*convertFromAttributeCall=*/
461 "return convertFromAttribute($_storage, $_attr, $_diag);",
462 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative,
463 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative,
464 /*hashPropertyCall=*/
465 "::llvm::hash_combine_range(std::begin($_storage), "
466 "std::end($_storage));",
467 /*StringRef defaultValue=*/"");
469 // Include key attributes from several traits as implicitly registered.
470 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
471 if (op.getDialect().usePropertiesForAttributes()) {
472 operandSegmentsSizeStorage =
473 llvm::formatv("std::array<int32_t, {0}>", op.getNumOperands());
474 operandSegmentsSize = {"operandSegmentSizes",
475 makeProperty(operandSegmentsSizeStorage)};
476 } else {
477 attrMetadata.insert(
478 {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName,
479 /*isRequired=*/true,
480 /*attr=*/std::nullopt}});
483 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
484 if (op.getDialect().usePropertiesForAttributes()) {
485 resultSegmentsSizeStorage =
486 llvm::formatv("std::array<int32_t, {0}>", op.getNumResults());
487 resultSegmentsSize = {"resultSegmentSizes",
488 makeProperty(resultSegmentsSizeStorage)};
489 } else {
490 attrMetadata.insert(
491 {resultSegmentAttrName,
492 AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
493 /*attr=*/std::nullopt}});
497 // Store the metadata in sorted order.
498 SmallVector<AttributeMetadata> sortedAttrMetadata =
499 llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector()));
500 llvm::sort(sortedAttrMetadata,
501 [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
502 return lhs.attrName < rhs.attrName;
505 // Store the position of the legacy operand_segment_sizes /
506 // result_segment_sizes so we can emit a backward compatible property readers
507 // and writers.
508 StringRef legacyOperandSegmentSizeName =
509 StringLiteral("operand_segment_sizes");
510 StringRef legacyResultSegmentSizeName = StringLiteral("result_segment_sizes");
511 operandSegmentSizesLegacyIndex = 0;
512 resultSegmentSizesLegacyIndex = 0;
513 for (auto item : sortedAttrMetadata) {
514 if (item.attrName < legacyOperandSegmentSizeName)
515 ++operandSegmentSizesLegacyIndex;
516 if (item.attrName < legacyResultSegmentSizeName)
517 ++resultSegmentSizesLegacyIndex;
520 // Compute the subrange bounds for each attribute.
521 numRequired = 0;
522 for (AttributeMetadata &attr : sortedAttrMetadata) {
523 attr.lowerBound = numRequired;
524 numRequired += attr.isRequired;
526 for (AttributeMetadata &attr : sortedAttrMetadata)
527 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
529 // Store the results back into the map.
530 for (const AttributeMetadata &attr : sortedAttrMetadata)
531 attrMetadata.insert({attr.attrName, attr});
534 //===----------------------------------------------------------------------===//
535 // Op emitter
536 //===----------------------------------------------------------------------===//
538 namespace {
// Helper class to emit a record into the given output stream.
//
// The constructor is private; the public surface is the pair of static
// `emitDecl` / `emitDef` entry points, which presumably construct a temporary
// emitter and forward to the private overloads of the same name.
class OpEmitter {
  // An op argument that is either an attribute (via its metadata) or a
  // property. Used by genPropertiesSupportForBytecode to handle both uniformly.
  using ConstArgument =
      llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;

public:
  // Emit the declaration of the C++ class for `op` into `os`.
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  // Emit the definitions of the C++ class for `op` into `os`.
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generate methods for accessing the attribute names of this operation.
  void genAttrNameGetters();

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates code to manage the properties, if any!
  void genPropertiesSupport();

  // Generates code to manage the encoding of properties to bytecode.
  void
  genPropertiesSupportForBytecode(ArrayRef<ConstArgument> attrOrProperties);

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setters for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates the method to populate default attributes.
  void genPopulateDefaultAttributes();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all results' types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. `inferredAttributes` is populated with any attributes that are
  // elided from the build list. The given `typeParamKind` and `attrParamKind`
  // controls how result types and attributes are placed in the parameter list.
  void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                      llvm::StringSet<> &inferredAttributes,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  // `inferredAttributes` names attributes elided from the parameter list (see
  // buildParamList); `isRawValueAttr` matches AttrParamKind::UnwrappedValue.
  void
  genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
                                         llvm::StringSet<> &inferredAttributes,
                                         bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates custom verify methods for the operation.
  void genCustomVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(MethodBody &body,
                                Operator::const_value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(MethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(MethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                               bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // NOTE: members are initialized in declaration order (C++ rules), regardless
  // of the order in the constructor's initializer list. Keep this order stable
  // unless the constructor definition is reviewed alongside.

  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  const Operator &op;

  // The C++ code builder for this op
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting op code.
  OpOrAdaptorHelper emitHelper;
};
738 } // namespace
740 // Populate the format context `ctx` with substitutions of attributes, operands
741 // and results.
742 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
743 FmtContext &ctx) {
744 // Populate substitutions for attributes.
745 auto &op = emitHelper.getOp();
746 for (const auto &namedAttr : op.getAttributes())
747 ctx.addSubst(namedAttr.name,
748 emitHelper.getOp().getGetterName(namedAttr.name) + "()");
750 // Populate substitutions for named operands.
751 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
752 auto &value = op.getOperand(i);
753 if (!value.name.empty())
754 ctx.addSubst(value.name, emitHelper.getOperand(i).str());
757 // Populate substitutions for results.
758 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
759 auto &value = op.getResult(i);
760 if (!value.name.empty())
761 ctx.addSubst(value.name, emitHelper.getResult(i).str());
765 /// Generate verification on native traits requiring attributes.
766 static void genNativeTraitAttrVerifier(MethodBody &body,
767 const OpOrAdaptorHelper &emitHelper) {
768 // Check that the variadic segment sizes attribute exists and contains the
769 // expected number of elements.
771 // {0}: Attribute name.
772 // {1}: Expected number of elements.
773 // {2}: "operand" or "result".
774 // {3}: Emit error prefix.
775 const char *const checkAttrSizedValueSegmentsCode = R"(
777 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0});
778 auto numElements = sizeAttr.asArrayRef().size();
779 if (numElements != {1})
780 return {3}"'{0}' attribute for specifying {2} segments must have {1} "
781 "elements, but got ") << numElements;
785 // Verify a few traits first so that we can use getODSOperands() and
786 // getODSResults() in the rest of the verifier.
787 auto &op = emitHelper.getOp();
788 if (!op.getDialect().usePropertiesForAttributes()) {
789 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
790 body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
791 op.getNumOperands(), "operand",
792 emitHelper.emitErrorPrefix());
794 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
795 body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
796 op.getNumResults(), "result",
797 emitHelper.emitErrorPrefix());
802 // Return true if a verifier can be emitted for the attribute: it is not a
803 // derived attribute, it has a predicate, its condition is not empty, and, for
804 // adaptors, the condition does not reference the op.
805 static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) {
806 if (attr.isDerivedAttr())
807 return false;
808 Pred pred = attr.getPredicate();
809 if (pred.isNull())
810 return false;
811 std::string condition = pred.getCondition();
812 return !condition.empty() &&
813 (!StringRef(condition).contains("$_op") || isEmittingForOp);
816 // Generate attribute verification. If an op instance is not available, then
817 // attribute checks that require one will not be emitted.
819 // Attribute verification is performed as follows:
821 // 1. Verify that all required attributes are present in sorted order. This
822 // ensures that we can use subrange lookup even with potentially missing
823 // attributes.
824 // 2. Verify native trait attributes so that other attributes may call methods
825 // that depend on the validity of these attributes, e.g. segment size attributes
826 // and operand or result getters.
827 // 3. Verify the constraints on all present attributes.
828 static void
829 genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
830 MethodBody &body,
831 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
832 bool useProperties) {
833 if (emitHelper.getAttrMetadata().empty())
834 return;
836 // Verify the attribute if it is present. This assumes that default values
837 // are valid. This code snippet pastes the condition inline.
839 // TODO: verify the default value is valid (perhaps in debug mode only).
841 // {0}: Attribute variable name.
842 // {1}: Attribute condition code.
843 // {2}: Emit error prefix.
844 // {3}: Attribute name.
845 // {4}: Attribute/constraint description.
846 const char *const verifyAttrInline = R"(
847 if ({0} && !({1}))
848 return {2}"attribute '{3}' failed to satisfy constraint: {4}");
850 // Verify the attribute using a uniqued constraint. Can only be used within
851 // the context of an op.
853 // {0}: Unique constraint name.
854 // {1}: Attribute variable name.
855 // {2}: Attribute name.
856 const char *const verifyAttrUnique = R"(
857 if (::mlir::failed({0}(*this, {1}, "{2}")))
858 return ::mlir::failure();
861 // Traverse the array until the required attribute is found. Return an error
862 // if the traversal reached the end.
864 // {0}: Code to get the name of the attribute.
865 // {1}: The emit error prefix.
866 // {2}: The name of the attribute.
867 const char *const findRequiredAttr = R"(
868 while (true) {{
869 if (namedAttrIt == namedAttrRange.end())
870 return {1}"requires attribute '{2}'");
871 if (namedAttrIt->getName() == {0}) {{
872 tblgen_{2} = namedAttrIt->getValue();
873 break;
874 })";
876 // Emit a check to see if the iteration has encountered an optional attribute.
878 // {0}: Code to get the name of the attribute.
879 // {1}: The name of the attribute.
880 const char *const checkOptionalAttr = R"(
881 else if (namedAttrIt->getName() == {0}) {{
882 tblgen_{1} = namedAttrIt->getValue();
883 })";
885 // Emit the start of the loop for checking trailing attributes.
886 const char *const checkTrailingAttrs = R"(while (true) {
887 if (namedAttrIt == namedAttrRange.end()) {
888 break;
889 })";
891 // Emit the verifier for the attribute.
892 const auto emitVerifier = [&](Attribute attr, StringRef attrName,
893 StringRef varName) {
894 std::string condition = attr.getPredicate().getCondition();
896 std::optional<StringRef> constraintFn;
897 if (emitHelper.isEmittingForOp() &&
898 (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
899 body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
900 } else {
901 body << formatv(verifyAttrInline, varName,
902 tgfmt(condition, &ctx.withSelf(varName)),
903 emitHelper.emitErrorPrefix(), attrName,
904 escapeString(attr.getSummary()));
908 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
909 const auto getVarName = [&](StringRef attrName) {
910 return (tblgenNamePrefix + attrName).str();
913 body.indent();
914 if (useProperties) {
915 for (const std::pair<StringRef, AttributeMetadata> &it :
916 emitHelper.getAttrMetadata()) {
917 const AttributeMetadata &metadata = it.second;
918 if (metadata.constraint && metadata.constraint->isDerivedAttr())
919 continue;
920 body << formatv(
921 "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
922 it.first);
923 if (metadata.isRequired)
924 body << formatv(
925 "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
926 it.first, emitHelper.emitErrorPrefix());
928 } else {
929 body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange());
930 body << "auto namedAttrIt = namedAttrRange.begin();\n";
932 // Iterate over the attributes in sorted order. Keep track of the optional
933 // attributes that may be encountered along the way.
934 SmallVector<const AttributeMetadata *> optionalAttrs;
936 for (const std::pair<StringRef, AttributeMetadata> &it :
937 emitHelper.getAttrMetadata()) {
938 const AttributeMetadata &metadata = it.second;
939 if (!metadata.isRequired) {
940 optionalAttrs.push_back(&metadata);
941 continue;
944 body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
945 for (const AttributeMetadata *optional : optionalAttrs) {
946 body << formatv("::mlir::Attribute {0};\n",
947 getVarName(optional->attrName));
949 body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
950 emitHelper.emitErrorPrefix(), it.first);
951 for (const AttributeMetadata *optional : optionalAttrs) {
952 body << formatv(checkOptionalAttr,
953 emitHelper.getAttrName(optional->attrName),
954 optional->attrName);
956 body << "\n ++namedAttrIt;\n}\n";
957 optionalAttrs.clear();
959 // Get trailing optional attributes.
960 if (!optionalAttrs.empty()) {
961 for (const AttributeMetadata *optional : optionalAttrs) {
962 body << formatv("::mlir::Attribute {0};\n",
963 getVarName(optional->attrName));
965 body << checkTrailingAttrs;
966 for (const AttributeMetadata *optional : optionalAttrs) {
967 body << formatv(checkOptionalAttr,
968 emitHelper.getAttrName(optional->attrName),
969 optional->attrName);
971 body << "\n ++namedAttrIt;\n}\n";
974 body.unindent();
976 // Emit the checks for segment attributes first so that the other
977 // constraints can call operand and result getters.
978 genNativeTraitAttrVerifier(body, emitHelper);
980 bool isEmittingForOp = emitHelper.isEmittingForOp();
981 for (const auto &namedAttr : emitHelper.getOp().getAttributes())
982 if (canEmitAttrVerifier(namedAttr.attr, isEmittingForOp))
983 emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
986 /// Include declarations specified on NativeTrait
987 static std::string formatExtraDeclarations(const Operator &op) {
988 SmallVector<StringRef> extraDeclarations;
989 // Include extra class declarations from NativeTrait
990 for (const auto &trait : op.getTraits()) {
991 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
992 StringRef value = opTrait->getExtraConcreteClassDeclaration();
993 if (value.empty())
994 continue;
995 extraDeclarations.push_back(value);
998 extraDeclarations.push_back(op.getExtraClassDeclaration());
999 return llvm::join(extraDeclarations, "\n");
1002 /// Op extra class definitions have a `$cppClass` substitution that is to be
1003 /// replaced by the C++ class name.
1004 /// Include declarations specified on NativeTrait
1005 static std::string formatExtraDefinitions(const Operator &op) {
1006 SmallVector<StringRef> extraDefinitions;
1007 // Include extra class definitions from NativeTrait
1008 for (const auto &trait : op.getTraits()) {
1009 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
1010 StringRef value = opTrait->getExtraConcreteClassDefinition();
1011 if (value.empty())
1012 continue;
1013 extraDefinitions.push_back(value);
1016 extraDefinitions.push_back(op.getExtraClassDefinition());
1017 FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
1018 return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
/// Builds the C++ class model for `op`. Every generated member is added to
/// `opClass` here; nothing is written out until emitDecl/emitDef is called.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), formatExtraDeclarations(op),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/true) {
  // Map the `$_op` and `$_ctxt` placeholders used by this emitter's verifier
  // context to the current operation and its MLIR context.
  verifyCtx.addSubst("_op", "(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genPropertiesSupport();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genPopulateDefaultAttributes();
  genParser();
  genPrinter();
  genVerifier();
  genCustomVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
1060 void OpEmitter::emitDecl(
1061 const Operator &op, raw_ostream &os,
1062 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1063 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
1066 void OpEmitter::emitDef(
1067 const Operator &op, raw_ostream &os,
1068 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1069 OpEmitter(op, staticVerifierEmitter).emitDef(os);
1072 void OpEmitter::emitDecl(raw_ostream &os) {
1073 opClass.finalize();
1074 opClass.writeDeclTo(os);
1077 void OpEmitter::emitDef(raw_ostream &os) {
1078 opClass.finalize();
1079 opClass.writeDefTo(os);
1082 static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
1083 const Operator &op) {
1084 if (m)
1085 return;
1086 PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
1087 methodName + "` for " +
1088 op.getOperationName() + " (from line " +
1089 Twine(line) + ")");
1092 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
1094 void OpEmitter::genAttrNameGetters() {
1095 const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
1096 emitHelper.getAttrMetadata();
1097 bool hasOperandSegmentsSize =
1098 op.getDialect().usePropertiesForAttributes() &&
1099 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1100 // Emit the getAttributeNames method.
1102 auto *method = opClass.addStaticInlineMethod(
1103 "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
1104 ERROR_IF_PRUNED(method, "getAttributeNames", op);
1105 auto &body = method->body();
1106 if (!hasOperandSegmentsSize && attributes.empty()) {
1107 body << " return {};";
1108 // Nothing else to do if there are no registered attributes. Exit early.
1109 return;
1111 body << " static ::llvm::StringRef attrNames[] = {";
1112 llvm::interleaveComma(llvm::make_first_range(attributes), body,
1113 [&](StringRef attrName) {
1114 body << "::llvm::StringRef(\"" << attrName << "\")";
1116 if (hasOperandSegmentsSize) {
1117 if (!attributes.empty())
1118 body << ", ";
1119 body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")";
1121 body << "};\n return ::llvm::ArrayRef(attrNames);";
1124 // Emit the getAttributeNameForIndex methods.
1126 auto *method = opClass.addInlineMethod<Method::Private>(
1127 "::mlir::StringAttr", "getAttributeNameForIndex",
1128 MethodParameter("unsigned", "index"));
1129 ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
1130 method->body()
1131 << " return getAttributeNameForIndex((*this)->getName(), index);";
1134 auto *method = opClass.addStaticInlineMethod<Method::Private>(
1135 "::mlir::StringAttr", "getAttributeNameForIndex",
1136 MethodParameter("::mlir::OperationName", "name"),
1137 MethodParameter("unsigned", "index"));
1138 ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
1140 if (attributes.empty()) {
1141 method->body() << " return {};";
1142 } else {
1143 const char *const getAttrName = R"(
1144 assert(index < {0} && "invalid attribute index");
1145 assert(name.getStringRef() == getOperationName() && "invalid operation name");
1146 assert(name.isRegistered() && "Operation isn't registered, missing a "
1147 "dependent dialect loading?");
1148 return name.getAttributeNames()[index];
1150 method->body() << formatv(getAttrName, attributes.size());
1154 // Generate the <attr>AttrName methods, that expose the attribute names to
1155 // users.
1156 const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
1157 for (auto [index, attr] :
1158 llvm::enumerate(llvm::make_first_range(attributes))) {
1159 std::string name = op.getGetterName(attr);
1160 std::string methodName = name + "AttrName";
1162 // Generate the non-static variant.
1164 auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
1165 ERROR_IF_PRUNED(method, methodName, op);
1166 method->body() << llvm::formatv(attrNameMethodBody, index);
1169 // Generate the static variant.
1171 auto *method = opClass.addStaticInlineMethod(
1172 "::mlir::StringAttr", methodName,
1173 MethodParameter("::mlir::OperationName", "name"));
1174 ERROR_IF_PRUNED(method, methodName, op);
1175 method->body() << llvm::formatv(attrNameMethodBody,
1176 "name, " + Twine(index));
1179 if (hasOperandSegmentsSize) {
1180 std::string name = op.getGetterName(operandSegmentAttrName);
1181 std::string methodName = name + "AttrName";
1182 // Generate the non-static variant.
1184 auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
1185 ERROR_IF_PRUNED(method, methodName, op);
1186 method->body()
1187 << " return (*this)->getName().getAttributeNames().back();";
1190 // Generate the static variant.
1192 auto *method = opClass.addStaticInlineMethod(
1193 "::mlir::StringAttr", methodName,
1194 MethodParameter("::mlir::OperationName", "name"));
1195 ERROR_IF_PRUNED(method, methodName, op);
1196 method->body() << " return name.getAttributeNames().back();";
1201 // Emit the getter for an attribute with the return type specified.
1202 // It is templated to be shared between the Op and the adaptor class.
1203 template <typename OpClassOrAdaptor>
1204 static void emitAttrGetterWithReturnType(FmtContext &fctx,
1205 OpClassOrAdaptor &opClass,
1206 const Operator &op, StringRef name,
1207 Attribute attr) {
1208 auto *method = opClass.addMethod(attr.getReturnType(), name);
1209 ERROR_IF_PRUNED(method, name, op);
1210 auto &body = method->body();
1211 body << " auto attr = " << name << "Attr();\n";
1212 if (attr.hasDefaultValue() && attr.isOptional()) {
1213 // Returns the default value if not set.
1214 // TODO: this is inefficient, we are recreating the attribute for every
1215 // call. This should be set instead.
1216 if (!attr.isConstBuildable()) {
1217 PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
1218 " must have a constBuilder");
1220 std::string defaultValue = std::string(
1221 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1222 body << " if (!attr)\n return "
1223 << tgfmt(attr.getConvertFromStorageCall(),
1224 &fctx.withSelf(defaultValue))
1225 << ";\n";
1227 body << " return "
1228 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
1229 << ";\n";
1232 void OpEmitter::genPropertiesSupport() {
1233 if (!emitHelper.hasProperties())
1234 return;
1236 SmallVector<ConstArgument> attrOrProperties;
1237 for (const std::pair<StringRef, AttributeMetadata> &it :
1238 emitHelper.getAttrMetadata()) {
1239 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
1240 attrOrProperties.push_back(&it.second);
1242 for (const NamedProperty &prop : op.getProperties())
1243 attrOrProperties.push_back(&prop);
1244 if (emitHelper.getOperandSegmentsSize())
1245 attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
1246 if (emitHelper.getResultSegmentsSize())
1247 attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
1248 if (attrOrProperties.empty())
1249 return;
1250 auto &setPropMethod =
1251 opClass
1252 .addStaticMethod(
1253 "::llvm::LogicalResult", "setPropertiesFromAttr",
1254 MethodParameter("Properties &", "prop"),
1255 MethodParameter("::mlir::Attribute", "attr"),
1256 MethodParameter(
1257 "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
1258 "emitError"))
1259 ->body();
1260 auto &getPropMethod =
1261 opClass
1262 .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr",
1263 MethodParameter("::mlir::MLIRContext *", "ctx"),
1264 MethodParameter("const Properties &", "prop"))
1265 ->body();
1266 auto &hashMethod =
1267 opClass
1268 .addStaticMethod("llvm::hash_code", "computePropertiesHash",
1269 MethodParameter("const Properties &", "prop"))
1270 ->body();
1271 auto &getInherentAttrMethod =
1272 opClass
1273 .addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr",
1274 MethodParameter("::mlir::MLIRContext *", "ctx"),
1275 MethodParameter("const Properties &", "prop"),
1276 MethodParameter("llvm::StringRef", "name"))
1277 ->body();
1278 auto &setInherentAttrMethod =
1279 opClass
1280 .addStaticMethod("void", "setInherentAttr",
1281 MethodParameter("Properties &", "prop"),
1282 MethodParameter("llvm::StringRef", "name"),
1283 MethodParameter("mlir::Attribute", "value"))
1284 ->body();
1285 auto &populateInherentAttrsMethod =
1286 opClass
1287 .addStaticMethod("void", "populateInherentAttrs",
1288 MethodParameter("::mlir::MLIRContext *", "ctx"),
1289 MethodParameter("const Properties &", "prop"),
1290 MethodParameter("::mlir::NamedAttrList &", "attrs"))
1291 ->body();
1292 auto &verifyInherentAttrsMethod =
1293 opClass
1294 .addStaticMethod(
1295 "::llvm::LogicalResult", "verifyInherentAttrs",
1296 MethodParameter("::mlir::OperationName", "opName"),
1297 MethodParameter("::mlir::NamedAttrList &", "attrs"),
1298 MethodParameter(
1299 "llvm::function_ref<::mlir::InFlightDiagnostic()>",
1300 "emitError"))
1301 ->body();
1303 opClass.declare<UsingDeclaration>("Properties", "FoldAdaptor::Properties");
1305 // Convert the property to the attribute form.
1307 setPropMethod << R"decl(
1308 ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
1309 if (!dict) {
1310 emitError() << "expected DictionaryAttr to set properties";
1311 return ::mlir::failure();
1313 )decl";
1314 const char *propFromAttrFmt = R"decl(
1315 auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
1316 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{
1319 {2};
1320 )decl";
1321 const char *attrGetNoDefaultFmt = R"decl(;
1322 if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1323 return ::mlir::failure();
1324 )decl";
1325 const char *attrGetDefaultFmt = R"decl(;
1326 if (attr) {{
1327 if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
1328 return ::mlir::failure();
1329 } else {{
1330 prop.{0} = {1};
1332 )decl";
1334 for (const auto &attrOrProp : attrOrProperties) {
1335 if (const auto *namedProperty =
1336 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1337 StringRef name = namedProperty->name;
1338 auto &prop = namedProperty->prop;
1339 FmtContext fctx;
1341 std::string getAttr;
1342 llvm::raw_string_ostream os(getAttr);
1343 os << " auto attr = dict.get(\"" << name << "\");";
1344 if (name == operandSegmentAttrName) {
1345 // Backward compat for now, TODO: Remove at some point.
1346 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1348 if (name == resultSegmentAttrName) {
1349 // Backward compat for now, TODO: Remove at some point.
1350 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1352 os.flush();
1354 setPropMethod << "{\n"
1355 << formatv(propFromAttrFmt,
1356 tgfmt(prop.getConvertFromAttributeCall(),
1357 &fctx.addSubst("_attr", propertyAttr)
1358 .addSubst("_storage", propertyStorage)
1359 .addSubst("_diag", propertyDiag)),
1360 name, getAttr);
1361 if (prop.hasDefaultValue()) {
1362 setPropMethod << formatv(attrGetDefaultFmt, name,
1363 prop.getDefaultValue());
1364 } else {
1365 setPropMethod << formatv(attrGetNoDefaultFmt, name);
1367 setPropMethod << " }\n";
1368 } else {
1369 const auto *namedAttr =
1370 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1371 StringRef name = namedAttr->attrName;
1372 std::string getAttr;
1373 llvm::raw_string_ostream os(getAttr);
1374 os << " auto attr = dict.get(\"" << name << "\");";
1375 if (name == operandSegmentAttrName) {
1376 // Backward compat for now
1377 os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");";
1379 if (name == resultSegmentAttrName) {
1380 // Backward compat for now
1381 os << " if (!attr) attr = dict.get(\"result_segment_sizes\");";
1383 os.flush();
1385 setPropMethod << formatv(R"decl(
1387 auto &propStorage = prop.{0};
1389 if (attr) {{
1390 auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
1391 if (convertedAttr) {{
1392 propStorage = convertedAttr;
1393 } else {{
1394 emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
1395 return ::mlir::failure();
1399 )decl",
1400 name, getAttr);
1403 setPropMethod << " return ::mlir::success();\n";
1405 // Convert the attribute form to the property.
1407 getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
1408 << " ::mlir::Builder odsBuilder{ctx};\n";
1409 const char *propToAttrFmt = R"decl(
1411 const auto &propStorage = prop.{0};
1412 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1413 {1}));
1415 )decl";
1416 for (const auto &attrOrProp : attrOrProperties) {
1417 if (const auto *namedProperty =
1418 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1419 StringRef name = namedProperty->name;
1420 auto &prop = namedProperty->prop;
1421 FmtContext fctx;
1422 getPropMethod << formatv(
1423 propToAttrFmt, name,
1424 tgfmt(prop.getConvertToAttributeCall(),
1425 &fctx.addSubst("_ctxt", "ctx")
1426 .addSubst("_storage", propertyStorage)));
1427 continue;
1429 const auto *namedAttr =
1430 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1431 StringRef name = namedAttr->attrName;
1432 getPropMethod << formatv(R"decl(
1434 const auto &propStorage = prop.{0};
1435 if (propStorage)
1436 attrs.push_back(odsBuilder.getNamedAttr("{0}",
1437 propStorage));
1439 )decl",
1440 name);
1442 getPropMethod << R"decl(
1443 if (!attrs.empty())
1444 return odsBuilder.getDictionaryAttr(attrs);
1445 return {};
1446 )decl";
1448 // Hashing for the property
1450 const char *propHashFmt = R"decl(
1451 auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
1452 return {1};
1454 )decl";
1455 for (const auto &attrOrProp : attrOrProperties) {
1456 if (const auto *namedProperty =
1457 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1458 StringRef name = namedProperty->name;
1459 auto &prop = namedProperty->prop;
1460 FmtContext fctx;
1461 hashMethod << formatv(propHashFmt, name,
1462 tgfmt(prop.getHashPropertyCall(),
1463 &fctx.addSubst("_storage", propertyStorage)));
1466 hashMethod << " return llvm::hash_combine(";
1467 llvm::interleaveComma(
1468 attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
1469 if (const auto *namedProperty =
1470 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
1471 hashMethod << "\n hash_" << namedProperty->name << "(prop."
1472 << namedProperty->name << ")";
1473 return;
1475 const auto *namedAttr =
1476 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1477 StringRef name = namedAttr->attrName;
1478 hashMethod << "\n llvm::hash_value(prop." << name
1479 << ".getAsOpaquePointer())";
1481 hashMethod << ");\n";
1483 const char *getInherentAttrMethodFmt = R"decl(
1484 if (name == "{0}")
1485 return prop.{0};
1486 )decl";
1487 const char *setInherentAttrMethodFmt = R"decl(
1488 if (name == "{0}") {{
1489 prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
1490 return;
1492 )decl";
1493 const char *populateInherentAttrsMethodFmt = R"decl(
1494 if (prop.{0}) attrs.append("{0}", prop.{0});
1495 )decl";
1496 for (const auto &attrOrProp : attrOrProperties) {
1497 if (const auto *namedAttr =
1498 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp)) {
1499 StringRef name = namedAttr->attrName;
1500 getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
1501 setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
1502 populateInherentAttrsMethod
1503 << formatv(populateInherentAttrsMethodFmt, name);
1504 continue;
1506 // The ODS segment size property is "special": we expose it as an attribute
1507 // even though it is a native property.
1508 const auto *namedProperty = cast<const NamedProperty *>(attrOrProp);
1509 StringRef name = namedProperty->name;
1510 if (name != operandSegmentAttrName && name != resultSegmentAttrName)
1511 continue;
1512 auto &prop = namedProperty->prop;
1513 FmtContext fctx;
1514 fctx.addSubst("_ctxt", "ctx");
1515 fctx.addSubst("_storage", Twine("prop.") + name);
1516 if (name == operandSegmentAttrName) {
1517 getInherentAttrMethod
1518 << formatv(" if (name == \"operand_segment_sizes\" || name == "
1519 "\"{0}\") return ",
1520 operandSegmentAttrName);
1521 } else {
1522 getInherentAttrMethod
1523 << formatv(" if (name == \"result_segment_sizes\" || name == "
1524 "\"{0}\") return ",
1525 resultSegmentAttrName);
1527 getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx)
1528 << ";\n";
1530 if (name == operandSegmentAttrName) {
1531 setInherentAttrMethod
1532 << formatv(" if (name == \"operand_segment_sizes\" || name == "
1533 "\"{0}\") {{",
1534 operandSegmentAttrName);
1535 } else {
1536 setInherentAttrMethod
1537 << formatv(" if (name == \"result_segment_sizes\" || name == "
1538 "\"{0}\") {{",
1539 resultSegmentAttrName);
1541 setInherentAttrMethod << formatv(R"decl(
1542 auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value);
1543 if (!arrAttr) return;
1544 if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
1545 return;
1546 llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin());
1547 return;
1549 )decl",
1550 name);
1551 if (name == operandSegmentAttrName) {
1552 populateInherentAttrsMethod
1553 << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName,
1554 tgfmt(prop.getConvertToAttributeCall(), &fctx));
1555 } else {
1556 populateInherentAttrsMethod
1557 << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName,
1558 tgfmt(prop.getConvertToAttributeCall(), &fctx));
1561 getInherentAttrMethod << " return std::nullopt;\n";
1563 // Emit the verifiers method for backward compatibility with the generic
1564 // syntax. This method verifies the constraint on the properties attributes
1565 // before they are set, since dyn_cast<> will silently omit failures.
1566 for (const auto &attrOrProp : attrOrProperties) {
1567 const auto *namedAttr =
1568 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
1569 if (!namedAttr || !namedAttr->constraint)
1570 continue;
1571 Attribute attr = *namedAttr->constraint;
1572 std::optional<StringRef> constraintFn =
1573 staticVerifierEmitter.getAttrConstraintFn(attr);
1574 if (!constraintFn)
1575 continue;
1576 if (canEmitAttrVerifier(attr,
1577 /*isEmittingForOp=*/false)) {
1578 std::string name = op.getGetterName(namedAttr->attrName);
1579 verifyInherentAttrsMethod
1580 << formatv(R"(
1582 ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
1583 if (attr && ::mlir::failed({1}(attr, "{2}", emitError)))
1584 return ::mlir::failure();
1587 name, constraintFn, namedAttr->attrName);
1590 verifyInherentAttrsMethod << " return ::mlir::success();";
1592 // Generate methods to interact with bytecode.
1593 genPropertiesSupportForBytecode(attrOrProperties);
1596 void OpEmitter::genPropertiesSupportForBytecode(
1597 ArrayRef<ConstArgument> attrOrProperties) {
1598 if (op.useCustomPropertiesEncoding()) {
1599 opClass.declareStaticMethod(
1600 "::llvm::LogicalResult", "readProperties",
1601 MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1602 MethodParameter("::mlir::OperationState &", "state"));
1603 opClass.declareMethod(
1604 "void", "writeProperties",
1605 MethodParameter("::mlir::DialectBytecodeWriter &", "writer"));
1606 return;
1609 auto &readPropertiesMethod =
1610 opClass
1611 .addStaticMethod(
1612 "::llvm::LogicalResult", "readProperties",
1613 MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
1614 MethodParameter("::mlir::OperationState &", "state"))
1615 ->body();
1617 auto &writePropertiesMethod =
1618 opClass
1619 .addMethod(
1620 "void", "writeProperties",
1621 MethodParameter("::mlir::DialectBytecodeWriter &", "writer"))
1622 ->body();
1624 // Populate bytecode serialization logic.
1625 readPropertiesMethod
1626 << " auto &prop = state.getOrAddProperties<Properties>(); (void)prop;";
1627 writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n";
1628 for (const auto &item : llvm::enumerate(attrOrProperties)) {
1629 auto &attrOrProp = item.value();
1630 FmtContext fctx;
1631 fctx.addSubst("_reader", "reader")
1632 .addSubst("_writer", "writer")
1633 .addSubst("_storage", propertyStorage)
1634 .addSubst("_ctxt", "this->getContext()");
1635 // If the op emits operand/result segment sizes as a property, emit the
1636 // legacy reader/writer in the appropriate order to allow backward
1637 // compatibility and back deployment.
1638 if (emitHelper.getOperandSegmentsSize().has_value() &&
1639 item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) {
1640 FmtContext fmtCtxt(fctx);
1641 fmtCtxt.addSubst("_propName", operandSegmentAttrName);
1642 readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt);
1643 writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
1645 if (emitHelper.getResultSegmentsSize().has_value() &&
1646 item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) {
1647 FmtContext fmtCtxt(fctx);
1648 fmtCtxt.addSubst("_propName", resultSegmentAttrName);
1649 readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt);
1650 writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
1652 if (const auto *namedProperty =
1653 attrOrProp.dyn_cast<const NamedProperty *>()) {
1654 StringRef name = namedProperty->name;
1655 readPropertiesMethod << formatv(
1658 auto &propStorage = prop.{0};
1659 auto readProp = [&]() {
1660 {1};
1661 return ::mlir::success();
1663 if (::mlir::failed(readProp()))
1664 return ::mlir::failure();
1667 name,
1668 tgfmt(namedProperty->prop.getReadFromMlirBytecodeCall(), &fctx));
1669 writePropertiesMethod << formatv(
1672 auto &propStorage = prop.{0};
1673 {1};
1676 name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx));
1677 continue;
1679 const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
1680 StringRef name = namedAttr->attrName;
1681 if (namedAttr->isRequired) {
1682 readPropertiesMethod << formatv(R"(
1683 if (::mlir::failed(reader.readAttribute(prop.{0})))
1684 return ::mlir::failure();
1686 name);
1687 writePropertiesMethod
1688 << formatv(" writer.writeAttribute(prop.{0});\n", name);
1689 } else {
1690 readPropertiesMethod << formatv(R"(
1691 if (::mlir::failed(reader.readOptionalAttribute(prop.{0})))
1692 return ::mlir::failure();
1694 name);
1695 writePropertiesMethod << formatv(R"(
1696 writer.writeOptionalAttribute(prop.{0});
1698 name);
1701 readPropertiesMethod << " return ::mlir::success();";
1704 void OpEmitter::genAttrGetters() {
1705 FmtContext fctx;
1706 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
1708 // Emit the derived attribute body.
1709 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
1710 if (auto *method = opClass.addMethod(attr.getReturnType(), name))
1711 method->body() << " " << attr.getDerivedCodeBody() << "\n";
1714 // Generate named accessor with Attribute return type. This is a wrapper
1715 // class that allows referring to the attributes via accessors instead of
1716 // having to use the string interface for better compile time verification.
1717 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
1718 Attribute attr) {
1719 // The method body for this getter is trivial. Emit it inline.
1720 auto *method =
1721 opClass.addInlineMethod(attr.getStorageType(), name + "Attr");
1722 if (!method)
1723 return;
1724 method->body() << formatv(
1725 " return ::llvm::{1}<{2}>({0});", emitHelper.getAttr(attrName),
1726 attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
1727 : "cast",
1728 attr.getStorageType());
1731 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1732 std::string name = op.getGetterName(namedAttr.name);
1733 if (namedAttr.attr.isDerivedAttr()) {
1734 emitDerivedAttr(name, namedAttr.attr);
1735 } else {
1736 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1737 emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
1741 auto derivedAttrs = make_filter_range(op.getAttributes(),
1742 [](const NamedAttribute &namedAttr) {
1743 return namedAttr.attr.isDerivedAttr();
1745 if (derivedAttrs.empty())
1746 return;
1748 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
1749 // Generate helper method to query whether a named attribute is a derived
1750 // attribute. This enables, for example, avoiding adding an attribute that
1751 // overlaps with a derived attribute.
1753 auto *method =
1754 opClass.addStaticMethod("bool", "isDerivedAttribute",
1755 MethodParameter("::llvm::StringRef", "name"));
1756 ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1757 auto &body = method->body();
1758 for (auto namedAttr : derivedAttrs)
1759 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
1760 body << " return false;";
1762 // Generate method to materialize derived attributes as a DictionaryAttr.
1764 auto *method = opClass.addMethod("::mlir::DictionaryAttr",
1765 "materializeDerivedAttributes");
1766 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1767 auto &body = method->body();
1769 auto nonMaterializable =
1770 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
1771 return namedAttr.attr.getConvertFromStorageCall().empty();
1773 if (!nonMaterializable.empty()) {
1774 std::string attrs;
1775 llvm::raw_string_ostream os(attrs);
1776 interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
1777 os << op.getGetterName(attr.name);
1779 PrintWarning(
1780 op.getLoc(),
1781 formatv(
1782 "op has non-materializable derived attributes '{0}', skipping",
1783 os.str()));
1784 body << formatv(" emitOpError(\"op has non-materializable derived "
1785 "attributes '{0}'\");\n",
1786 attrs);
1787 body << " return nullptr;";
1788 return;
1791 body << " ::mlir::MLIRContext* ctx = getContext();\n";
1792 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1793 body << " return ::mlir::DictionaryAttr::get(";
1794 body << " ctx, {\n";
1795 interleave(
1796 derivedAttrs, body,
1797 [&](const NamedAttribute &namedAttr) {
1798 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
1799 std::string name = op.getGetterName(namedAttr.name);
1800 body << " {" << name << "AttrName(),\n"
1801 << tgfmt(tmpl, &fctx.withSelf(name + "()")
1802 .withBuilder("odsBuilder")
1803 .addSubst("_ctxt", "ctx")
1804 .addSubst("_storage", "ctx"))
1805 << "}";
1807 ",\n");
1808 body << "});";
1812 void OpEmitter::genAttrSetters() {
1813 bool useProperties = op.getDialect().usePropertiesForAttributes();
1815 // Generate the code to set an attribute.
1816 auto emitSetAttr = [&](Method *method, StringRef getterName,
1817 StringRef attrName, StringRef attrVar) {
1818 if (useProperties) {
1819 method->body() << formatv(" getProperties().{0} = {1};", attrName,
1820 attrVar);
1821 } else {
1822 method->body() << formatv(" (*this)->setAttr({0}AttrName(), {1});",
1823 getterName, attrVar);
1827 // Generate raw named setter type. This is a wrapper class that allows setting
1828 // to the attributes via setters instead of having to use the string interface
1829 // for better compile time verification.
1830 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
1831 StringRef attrName, Attribute attr) {
1832 // This method body is trivial, so emit it inline.
1833 auto *method =
1834 opClass.addInlineMethod("void", setterName + "Attr",
1835 MethodParameter(attr.getStorageType(), "attr"));
1836 if (method)
1837 emitSetAttr(method, getterName, attrName, "attr");
1840 // Generate a setter that accepts the underlying C++ type as opposed to the
1841 // attribute type.
1842 auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
1843 StringRef attrName, Attribute attr) {
1844 Attribute baseAttr = attr.getBaseAttr();
1845 if (!canUseUnwrappedRawValue(baseAttr))
1846 return;
1847 FmtContext fctx;
1848 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
1849 bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
1850 bool isOptional = attr.isOptional();
1852 auto createMethod = [&](const Twine &paramType) {
1853 return opClass.addMethod("void", setterName,
1854 MethodParameter(paramType.str(), "attrValue"));
1857 // Build the method using the correct parameter type depending on
1858 // optionality.
1859 Method *method = nullptr;
1860 if (isUnitAttr)
1861 method = createMethod("bool");
1862 else if (isOptional)
1863 method =
1864 createMethod("::std::optional<" + baseAttr.getReturnType() + ">");
1865 else
1866 method = createMethod(attr.getReturnType());
1867 if (!method)
1868 return;
1870 // If the value isn't optional, just set it directly.
1871 if (!isOptional) {
1872 emitSetAttr(method, getterName, attrName,
1873 constBuildAttrFromParam(attr, fctx, "attrValue"));
1874 return;
1877 // Otherwise, we only set if the provided value is valid. If it isn't, we
1878 // remove the attribute.
1880 // TODO: Handle unit attr parameters specially, given that it is treated as
1881 // optional but not in the same way as the others (i.e. it uses bool over
1882 // std::optional<>).
1883 StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
1884 if (!useProperties) {
1885 const char *optionalCodeBody = R"(
1886 if (attrValue)
1887 return (*this)->setAttr({0}AttrName(), {1});
1888 (*this)->removeAttr({0}AttrName());)";
1889 method->body() << formatv(
1890 optionalCodeBody, getterName,
1891 constBuildAttrFromParam(baseAttr, fctx, paramStr));
1892 } else {
1893 const char *optionalCodeBody = R"(
1894 auto &odsProp = getProperties().{0};
1895 if (attrValue)
1896 odsProp = {1};
1897 else
1898 odsProp = nullptr;)";
1899 method->body() << formatv(
1900 optionalCodeBody, attrName,
1901 constBuildAttrFromParam(baseAttr, fctx, paramStr));
1905 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1906 if (namedAttr.attr.isDerivedAttr())
1907 continue;
1908 std::string setterName = op.getSetterName(namedAttr.name);
1909 std::string getterName = op.getGetterName(namedAttr.name);
1910 emitAttrWithStorageType(setterName, getterName, namedAttr.name,
1911 namedAttr.attr);
1912 emitAttrWithReturnType(setterName, getterName, namedAttr.name,
1913 namedAttr.attr);
1917 void OpEmitter::genOptionalAttrRemovers() {
1918 // Generate methods for removing optional attributes, instead of having to
1919 // use the string interface. Enables better compile time verification.
1920 auto emitRemoveAttr = [&](StringRef name, bool useProperties) {
1921 auto upperInitial = name.take_front().upper();
1922 auto *method = opClass.addInlineMethod("::mlir::Attribute",
1923 op.getRemoverName(name) + "Attr");
1924 if (!method)
1925 return;
1926 if (useProperties) {
1927 method->body() << formatv(R"(
1928 auto &attr = getProperties().{0};
1929 attr = {{};
1930 return attr;
1932 name);
1933 return;
1935 method->body() << formatv("return (*this)->removeAttr({0}AttrName());",
1936 op.getGetterName(name));
1939 for (const NamedAttribute &namedAttr : op.getAttributes())
1940 if (namedAttr.attr.isOptional())
1941 emitRemoveAttr(namedAttr.name,
1942 op.getDialect().usePropertiesForAttributes());
1945 // Generates the code to compute the start and end index of an operand or result
1946 // range.
1947 template <typename RangeT>
1948 static void generateValueRangeStartAndEnd(
1949 Class &opClass, bool isGenericAdaptorBase, StringRef methodName,
1950 int numVariadic, int numNonVariadic, StringRef rangeSizeCall,
1951 bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) {
1953 SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")};
1954 if (isGenericAdaptorBase) {
1955 parameters.emplace_back("unsigned", "odsOperandsSize");
1956 // The range size is passed per parameter for generic adaptor bases as
1957 // using the rangeSizeCall would require the operands, which are not
1958 // accessible in the base class.
1959 rangeSizeCall = "odsOperandsSize";
1962 // The method is trivial if the operation does not have any variadic operands.
1963 // In that case, make sure to generate it in-line.
1964 auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
1965 numVariadic == 0 ? Method::Properties::Inline
1966 : Method::Properties::None,
1967 parameters);
1968 if (!method)
1969 return;
1970 auto &body = method->body();
1971 if (numVariadic == 0) {
1972 body << " return {index, 1};\n";
1973 } else if (hasAttrSegmentSize) {
1974 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
1975 } else {
1976 // Because the op can have arbitrarily interleaved variadic and non-variadic
1977 // operands, we need to embed a list in the "sink" getter method for
1978 // calculation at run-time.
1979 SmallVector<StringRef, 4> isVariadic;
1980 isVariadic.reserve(llvm::size(odsValues));
1981 for (auto &it : odsValues)
1982 isVariadic.push_back(it.isVariableLength() ? "true" : "false");
1983 std::string isVariadicList = llvm::join(isVariadic, ", ");
1984 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
1985 numNonVariadic, numVariadic, rangeSizeCall, "operand");
1989 static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
1990 std::string str = "::mlir::Value";
1991 /// If the CPPClassName is not a fully qualified type. Uses of types
1992 /// across Dialect fail because they are not in the correct namespace. So we
1993 /// dont generate TypedValue unless the type is fully qualified.
1994 /// getCPPClassName doesn't return the fully qualified path for
1995 /// `mlir::pdl::OperationType` see
1996 /// https://github.com/llvm/llvm-project/issues/57279.
1997 /// Adaptor will have values that are not from the type of their operation and
1998 /// this is expected, so we dont generate TypedValue for Adaptor
1999 if (value.constraint.getCPPClassName() != "::mlir::Type" &&
2000 StringRef(value.constraint.getCPPClassName()).starts_with("::"))
2001 str = llvm::formatv("::mlir::TypedValue<{0}>",
2002 value.constraint.getCPPClassName())
2003 .str();
2004 return str;
2007 // Generates the named operand getter methods for the given Operator `op` and
2008 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
2009 // return a range of operands (individual operands are `Value ` and each
2010 // element in the range must also be `Value `); use `rangeBeginCall` to get
2011 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
2012 // obtain the number of operands. `getOperandCallPattern` contains the code
2013 // necessary to obtain a single operand whose position will be substituted
2014 // instead of
2015 // "{0}" marker in the pattern. Note that the pattern should work for any kind
2016 // of ops, in particular for one-operand ops that may not have the
2017 // `getOperand(unsigned)` method.
2018 static void
2019 generateNamedOperandGetters(const Operator &op, Class &opClass,
2020 Class *genericAdaptorBase, StringRef sizeAttrInit,
2021 StringRef rangeType, StringRef rangeElementType,
2022 StringRef rangeBeginCall, StringRef rangeSizeCall,
2023 StringRef getOperandCallPattern) {
2024 const int numOperands = op.getNumOperands();
2025 const int numVariadicOperands = op.getNumVariableLengthOperands();
2026 const int numNormalOperands = numOperands - numVariadicOperands;
2028 const auto *sameVariadicSize =
2029 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
2030 const auto *attrSizedOperands =
2031 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2033 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
2034 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
2035 "specification over their sizes");
2038 if (numVariadicOperands < 2 && attrSizedOperands) {
2039 PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
2040 "to use 'AttrSizedOperandSegments' trait");
2043 if (attrSizedOperands && sameVariadicSize) {
2044 PrintFatalError(op.getLoc(),
2045 "op cannot have both 'AttrSizedOperandSegments' and "
2046 "'SameVariadicOperandSize' traits");
2049 // First emit a few "sink" getter methods upon which we layer all nicer named
2050 // getter methods.
2051 // If generating for an adaptor, the method is put into the non-templated
2052 // generic base class, to not require being defined in the header.
2053 // Since the operand size can't be determined from the base class however,
2054 // it has to be passed as an additional argument. The trampoline below
2055 // generates the function with the same signature as the Op in the generic
2056 // adaptor.
2057 bool isGenericAdaptorBase = genericAdaptorBase != nullptr;
2058 generateValueRangeStartAndEnd(
2059 /*opClass=*/isGenericAdaptorBase ? *genericAdaptorBase : opClass,
2060 isGenericAdaptorBase,
2061 /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands,
2062 numNormalOperands, rangeSizeCall, attrSizedOperands, sizeAttrInit,
2063 const_cast<Operator &>(op).getOperands());
2064 if (isGenericAdaptorBase) {
2065 // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
2066 // the index. This just calls the implementation in the base class but
2067 // passes the operand size as parameter.
2068 Method *method = opClass.addInlineMethod(
2069 "std::pair<unsigned, unsigned>", "getODSOperandIndexAndLength",
2070 MethodParameter("unsigned", "index"));
2071 ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op);
2072 MethodBody &body = method->body();
2073 body.indent() << formatv(
2074 "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall);
2077 // The implementation of this method is trivial and it is very load-bearing.
2078 // Generate it inline.
2079 auto *m = opClass.addInlineMethod(rangeType, "getODSOperands",
2080 MethodParameter("unsigned", "index"));
2081 ERROR_IF_PRUNED(m, "getODSOperands", op);
2082 auto &body = m->body();
2083 body << formatv(valueRangeReturnCode, rangeBeginCall,
2084 "getODSOperandIndexAndLength(index)");
2086 // Then we emit nicer named getter methods by redirecting to the "sink" getter
2087 // method.
2088 for (int i = 0; i != numOperands; ++i) {
2089 const auto &operand = op.getOperand(i);
2090 if (operand.name.empty())
2091 continue;
2092 std::string name = op.getGetterName(operand.name);
2093 if (operand.isOptional()) {
2094 m = opClass.addInlineMethod(isGenericAdaptorBase
2095 ? rangeElementType
2096 : generateTypeForGetter(operand),
2097 name);
2098 ERROR_IF_PRUNED(m, name, op);
2099 m->body().indent() << formatv("auto operands = getODSOperands({0});\n"
2100 "return operands.empty() ? {1}{{} : ",
2101 i, m->getReturnType());
2102 if (!isGenericAdaptorBase)
2103 m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
2104 m->body() << "(*operands.begin());";
2105 } else if (operand.isVariadicOfVariadic()) {
2106 std::string segmentAttr = op.getGetterName(
2107 operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
2108 if (genericAdaptorBase) {
2109 m = opClass.addMethod("::llvm::SmallVector<" + rangeType + ">", name);
2110 ERROR_IF_PRUNED(m, name, op);
2111 m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
2112 segmentAttr, i, rangeType);
2113 continue;
2116 m = opClass.addInlineMethod("::mlir::OperandRangeRange", name);
2117 ERROR_IF_PRUNED(m, name, op);
2118 m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
2119 << "Attr());";
2120 } else if (operand.isVariadic()) {
2121 m = opClass.addInlineMethod(rangeType, name);
2122 ERROR_IF_PRUNED(m, name, op);
2123 m->body() << " return getODSOperands(" << i << ");";
2124 } else {
2125 m = opClass.addInlineMethod(isGenericAdaptorBase
2126 ? rangeElementType
2127 : generateTypeForGetter(operand),
2128 name);
2129 ERROR_IF_PRUNED(m, name, op);
2130 m->body().indent() << "return ";
2131 if (!isGenericAdaptorBase)
2132 m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
2133 m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i);
2138 void OpEmitter::genNamedOperandGetters() {
2139 // Build the code snippet used for initializing the operand_segment_size)s
2140 // array.
2141 std::string attrSizeInitCode;
2142 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2143 if (op.getDialect().usePropertiesForAttributes())
2144 attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties,
2145 "getProperties().operandSegmentSizes");
2147 else
2148 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
2149 emitHelper.getAttr(operandSegmentAttrName));
2152 generateNamedOperandGetters(
2153 op, opClass,
2154 /*genericAdaptorBase=*/nullptr,
2155 /*sizeAttrInit=*/attrSizeInitCode,
2156 /*rangeType=*/"::mlir::Operation::operand_range",
2157 /*rangeElementType=*/"::mlir::Value",
2158 /*rangeBeginCall=*/"getOperation()->operand_begin()",
2159 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
2160 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
2163 void OpEmitter::genNamedOperandSetters() {
2164 auto *attrSizedOperands =
2165 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2166 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
2167 const auto &operand = op.getOperand(i);
2168 if (operand.name.empty())
2169 continue;
2170 std::string name = op.getGetterName(operand.name);
2172 StringRef returnType;
2173 if (operand.isVariadicOfVariadic()) {
2174 returnType = "::mlir::MutableOperandRangeRange";
2175 } else if (operand.isVariableLength()) {
2176 returnType = "::mlir::MutableOperandRange";
2177 } else {
2178 returnType = "::mlir::OpOperand &";
2180 bool isVariadicOperand =
2181 operand.isVariadicOfVariadic() || operand.isVariableLength();
2182 auto *m = opClass.addMethod(returnType, name + "Mutable",
2183 isVariadicOperand ? Method::Properties::None
2184 : Method::Properties::Inline);
2185 ERROR_IF_PRUNED(m, name, op);
2186 auto &body = m->body();
2187 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
2189 if (!isVariadicOperand) {
2190 // In case of a single operand, return a single OpOperand.
2191 body << " return getOperation()->getOpOperand(range.first);\n";
2192 continue;
2195 body << " auto mutableRange = "
2196 "::mlir::MutableOperandRange(getOperation(), "
2197 "range.first, range.second";
2198 if (attrSizedOperands) {
2199 if (emitHelper.hasProperties())
2200 body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
2201 "{{getOperandSegmentSizesAttrName(), "
2202 "::mlir::DenseI32ArrayAttr::get(getContext(), "
2203 "getProperties().operandSegmentSizes)})",
2205 else
2206 body << formatv(
2207 ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
2208 emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
2210 body << ");\n";
2212 // If this operand is a nested variadic, we split the range into a
2213 // MutableOperandRangeRange that provides a range over all of the
2214 // sub-ranges.
2215 if (operand.isVariadicOfVariadic()) {
2216 body << " return "
2217 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
2218 << op.getGetterName(
2219 operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
2220 << "AttrName()));\n";
2221 } else {
2222 // Otherwise, we use the full range directly.
2223 body << " return mutableRange;\n";
2228 void OpEmitter::genNamedResultGetters() {
2229 const int numResults = op.getNumResults();
2230 const int numVariadicResults = op.getNumVariableLengthResults();
2231 const int numNormalResults = numResults - numVariadicResults;
2233 // If we have more than one variadic results, we need more complicated logic
2234 // to calculate the value range for each result.
2236 const auto *sameVariadicSize =
2237 op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
2238 const auto *attrSizedResults =
2239 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
2241 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
2242 PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
2243 "specification over their sizes");
2246 if (numVariadicResults < 2 && attrSizedResults) {
2247 PrintFatalError(op.getLoc(), "op must have at least two variadic results "
2248 "to use 'AttrSizedResultSegments' trait");
2251 if (attrSizedResults && sameVariadicSize) {
2252 PrintFatalError(op.getLoc(),
2253 "op cannot have both 'AttrSizedResultSegments' and "
2254 "'SameVariadicResultSize' traits");
2257 // Build the initializer string for the result segment size attribute.
2258 std::string attrSizeInitCode;
2259 if (attrSizedResults) {
2260 if (op.getDialect().usePropertiesForAttributes())
2261 attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties,
2262 "getProperties().resultSegmentSizes");
2264 else
2265 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
2266 emitHelper.getAttr(resultSegmentAttrName));
2269 generateValueRangeStartAndEnd(
2270 opClass, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength",
2271 numVariadicResults, numNormalResults, "getOperation()->getNumResults()",
2272 attrSizedResults, attrSizeInitCode, op.getResults());
2274 // The implementation of this method is trivial and it is very load-bearing.
2275 // Generate it inline.
2276 auto *m = opClass.addInlineMethod("::mlir::Operation::result_range",
2277 "getODSResults",
2278 MethodParameter("unsigned", "index"));
2279 ERROR_IF_PRUNED(m, "getODSResults", op);
2280 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
2281 "getODSResultIndexAndLength(index)");
2283 for (int i = 0; i != numResults; ++i) {
2284 const auto &result = op.getResult(i);
2285 if (result.name.empty())
2286 continue;
2287 std::string name = op.getGetterName(result.name);
2288 if (result.isOptional()) {
2289 m = opClass.addInlineMethod(generateTypeForGetter(result), name);
2290 ERROR_IF_PRUNED(m, name, op);
2291 m->body() << " auto results = getODSResults(" << i << ");\n"
2292 << llvm::formatv(" return results.empty()"
2293 " ? {0}()"
2294 " : ::llvm::cast<{0}>(*results.begin());",
2295 m->getReturnType());
2296 } else if (result.isVariadic()) {
2297 m = opClass.addInlineMethod("::mlir::Operation::result_range", name);
2298 ERROR_IF_PRUNED(m, name, op);
2299 m->body() << " return getODSResults(" << i << ");";
2300 } else {
2301 m = opClass.addInlineMethod(generateTypeForGetter(result), name);
2302 ERROR_IF_PRUNED(m, name, op);
2303 m->body() << llvm::formatv(
2304 " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
2305 m->getReturnType(), i);
2310 void OpEmitter::genNamedRegionGetters() {
2311 unsigned numRegions = op.getNumRegions();
2312 for (unsigned i = 0; i < numRegions; ++i) {
2313 const auto &region = op.getRegion(i);
2314 if (region.name.empty())
2315 continue;
2316 std::string name = op.getGetterName(region.name);
2318 // Generate the accessors for a variadic region.
2319 if (region.isVariadic()) {
2320 auto *m = opClass.addInlineMethod(
2321 "::mlir::MutableArrayRef<::mlir::Region>", name);
2322 ERROR_IF_PRUNED(m, name, op);
2323 m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
2325 continue;
2328 auto *m = opClass.addInlineMethod("::mlir::Region &", name);
2329 ERROR_IF_PRUNED(m, name, op);
2330 m->body() << formatv(" return (*this)->getRegion({0});", i);
2334 void OpEmitter::genNamedSuccessorGetters() {
2335 unsigned numSuccessors = op.getNumSuccessors();
2336 for (unsigned i = 0; i < numSuccessors; ++i) {
2337 const NamedSuccessor &successor = op.getSuccessor(i);
2338 if (successor.name.empty())
2339 continue;
2340 std::string name = op.getGetterName(successor.name);
2341 // Generate the accessors for a variadic successor list.
2342 if (successor.isVariadic()) {
2343 auto *m = opClass.addInlineMethod("::mlir::SuccessorRange", name);
2344 ERROR_IF_PRUNED(m, name, op);
2345 m->body() << formatv(
2346 " return {std::next((*this)->successor_begin(), {0}), "
2347 "(*this)->successor_end()};",
2349 continue;
2352 auto *m = opClass.addInlineMethod("::mlir::Block *", name);
2353 ERROR_IF_PRUNED(m, name, op);
2354 m->body() << formatv(" return (*this)->getSuccessor({0});", i);
2358 static bool canGenerateUnwrappedBuilder(const Operator &op) {
2359 // If this op does not have native attributes at all, return directly to avoid
2360 // redefining builders.
2361 if (op.getNumNativeAttributes() == 0)
2362 return false;
2364 bool canGenerate = false;
2365 // We are generating builders that take raw values for attributes. We need to
2366 // make sure the native attributes have a meaningful "unwrapped" value type
2367 // different from the wrapped mlir::Attribute type to avoid redefining
2368 // builders. This checks for the op has at least one such native attribute.
2369 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
2370 const NamedAttribute &namedAttr = op.getAttribute(i);
2371 if (canUseUnwrappedRawValue(namedAttr.attr)) {
2372 canGenerate = true;
2373 break;
2376 return canGenerate;
2379 static bool canInferType(const Operator &op) {
2380 return op.getTrait("::mlir::InferTypeOpInterface::Trait");
2383 void OpEmitter::genSeparateArgParamBuilder() {
2384 SmallVector<AttrParamKind, 2> attrBuilderType;
2385 attrBuilderType.push_back(AttrParamKind::WrappedAttr);
2386 if (canGenerateUnwrappedBuilder(op))
2387 attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
2389 // Emit with separate builders with or without unwrapped attributes and/or
2390 // inferring result type.
2391 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
2392 bool inferType) {
2393 SmallVector<MethodParameter> paramList;
2394 SmallVector<std::string, 4> resultNames;
2395 llvm::StringSet<> inferredAttributes;
2396 buildParamList(paramList, inferredAttributes, resultNames, paramKind,
2397 attrType);
2399 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2400 // If the builder is redundant, skip generating the method.
2401 if (!m)
2402 return;
2403 auto &body = m->body();
2404 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2405 /*isRawValueAttr=*/attrType ==
2406 AttrParamKind::UnwrappedValue);
2408 // Push all result types to the operation state
2410 if (inferType) {
2411 // Generate builder that infers type too.
2412 // TODO: Subsume this with general checking if type can be
2413 // inferred automatically.
2414 body << formatv(R"(
2415 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2416 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2417 {1}.location, {1}.operands,
2418 {1}.attributes.getDictionary({1}.getContext()),
2419 {1}.getRawProperties(),
2420 {1}.regions, inferredReturnTypes)))
2421 {1}.addTypes(inferredReturnTypes);
2422 else
2423 ::llvm::report_fatal_error("Failed to infer result type(s).");)",
2424 opClass.getClassName(), builderOpState);
2425 return;
2428 switch (paramKind) {
2429 case TypeParamKind::None:
2430 return;
2431 case TypeParamKind::Separate:
2432 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
2433 if (op.getResult(i).isOptional())
2434 body << " if (" << resultNames[i] << ")\n ";
2435 body << " " << builderOpState << ".addTypes(" << resultNames[i]
2436 << ");\n";
2439 // Automatically create the 'resultSegmentSizes' attribute using
2440 // the length of the type ranges.
2441 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
2442 if (op.getDialect().usePropertiesForAttributes()) {
2443 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
2444 } else {
2445 std::string getterName = op.getGetterName(resultSegmentAttrName);
2446 body << " " << builderOpState << ".addAttribute(" << getterName
2447 << "AttrName(" << builderOpState << ".name), "
2448 << "odsBuilder.getDenseI32ArrayAttr({";
2450 interleaveComma(
2451 llvm::seq<int>(0, op.getNumResults()), body, [&](int i) {
2452 const NamedTypeConstraint &result = op.getResult(i);
2453 if (!result.isVariableLength()) {
2454 body << "1";
2455 } else if (result.isOptional()) {
2456 body << "(" << resultNames[i] << " ? 1 : 0)";
2457 } else {
2458 // VariadicOfVariadic of results are currently unsupported in
2459 // MLIR, hence it can only be a simple variadic.
2460 // TODO: Add implementation for VariadicOfVariadic results here
2461 // once supported.
2462 assert(result.isVariadic());
2463 body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
2466 if (op.getDialect().usePropertiesForAttributes()) {
2467 body << "}), " << builderOpState
2468 << ".getOrAddProperties<Properties>()."
2469 "resultSegmentSizes.begin());\n";
2470 } else {
2471 body << "}));\n";
2475 return;
2476 case TypeParamKind::Collective: {
2477 int numResults = op.getNumResults();
2478 int numVariadicResults = op.getNumVariableLengthResults();
2479 int numNonVariadicResults = numResults - numVariadicResults;
2480 bool hasVariadicResult = numVariadicResults != 0;
2482 // Avoid emitting "resultTypes.size() >= 0u" which is always true.
2483 if (!hasVariadicResult || numNonVariadicResults != 0)
2484 body << " "
2485 << "assert(resultTypes.size() "
2486 << (hasVariadicResult ? ">=" : "==") << " "
2487 << numNonVariadicResults
2488 << "u && \"mismatched number of results\");\n";
2489 body << " " << builderOpState << ".addTypes(resultTypes);\n";
2491 return;
2493 llvm_unreachable("unhandled TypeParamKind");
2496 // Some of the build methods generated here may be ambiguous, but TableGen's
2497 // ambiguous function detection will elide those ones.
2498 for (auto attrType : attrBuilderType) {
2499 emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
2500 if (canInferType(op))
2501 emit(attrType, TypeParamKind::None, /*inferType=*/true);
2502 emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
2506 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
2507 int numResults = op.getNumResults();
2509 // Signature
2510 SmallVector<MethodParameter> paramList;
2511 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2512 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2513 paramList.emplace_back("::mlir::ValueRange", "operands");
2514 // Provide default value for `attributes` when its the last parameter
2515 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2516 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2517 "attributes", attributesDefaultValue);
2518 if (op.getNumVariadicRegions())
2519 paramList.emplace_back("unsigned", "numRegions");
2521 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2522 // If the builder is redundant, skip generating the method
2523 if (!m)
2524 return;
2525 auto &body = m->body();
2527 // Operands
2528 body << " " << builderOpState << ".addOperands(operands);\n";
2530 // Attributes
2531 body << " " << builderOpState << ".addAttributes(attributes);\n";
2533 // Create the correct number of regions
2534 if (int numRegions = op.getNumRegions()) {
2535 body << llvm::formatv(
2536 " for (unsigned i = 0; i != {0}; ++i)\n",
2537 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2538 body << " (void)" << builderOpState << ".addRegion();\n";
2541 // Result types
2542 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
2543 body << " " << builderOpState << ".addTypes({"
2544 << llvm::join(resultTypes, ", ") << "});\n\n";
2547 void OpEmitter::genPopulateDefaultAttributes() {
2548 // All done if no attributes, except optional ones, have default values.
2549 if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
2550 return !named.attr.hasDefaultValue() || named.attr.isOptional();
2552 return;
2554 if (emitHelper.hasProperties()) {
2555 SmallVector<MethodParameter> paramList;
2556 paramList.emplace_back("::mlir::OperationName", "opName");
2557 paramList.emplace_back("Properties &", "properties");
2558 auto *m =
2559 opClass.addStaticMethod("void", "populateDefaultProperties", paramList);
2560 ERROR_IF_PRUNED(m, "populateDefaultProperties", op);
2561 auto &body = m->body();
2562 body.indent();
2563 body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n";
2564 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2565 auto &attr = namedAttr.attr;
2566 if (!attr.hasDefaultValue() || attr.isOptional())
2567 continue;
2568 StringRef name = namedAttr.name;
2569 FmtContext fctx;
2570 fctx.withBuilder(odsBuilder);
2571 body << "if (!properties." << name << ")\n"
2572 << " properties." << name << " = "
2573 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
2574 tgfmt(attr.getDefaultValue(), &fctx)))
2575 << ";\n";
2577 return;
2580 SmallVector<MethodParameter> paramList;
2581 paramList.emplace_back("const ::mlir::OperationName &", "opName");
2582 paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
2583 auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
2584 ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
2585 auto &body = m->body();
2586 body.indent();
2588 // Set default attributes that are unset.
2589 body << "auto attrNames = opName.getAttributeNames();\n";
2590 body << "::mlir::Builder " << odsBuilder
2591 << "(attrNames.front().getContext());\n";
2592 StringMap<int> attrIndex;
2593 for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
2594 attrIndex[it.value().first] = it.index();
2596 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2597 auto &attr = namedAttr.attr;
2598 if (!attr.hasDefaultValue() || attr.isOptional())
2599 continue;
2600 auto index = attrIndex[namedAttr.name];
2601 body << "if (!attributes.get(attrNames[" << index << "])) {\n";
2602 FmtContext fctx;
2603 fctx.withBuilder(odsBuilder);
2605 std::string defaultValue =
2606 std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
2607 tgfmt(attr.getDefaultValue(), &fctx)));
2608 body.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index,
2609 defaultValue);
2610 body.unindent() << "}\n";
2614 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
2615 SmallVector<MethodParameter> paramList;
2616 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2617 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2618 paramList.emplace_back("::mlir::ValueRange", "operands");
2619 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2620 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2621 "attributes", attributesDefaultValue);
2622 if (op.getNumVariadicRegions())
2623 paramList.emplace_back("unsigned", "numRegions");
2625 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2626 // If the builder is redundant, skip generating the method
2627 if (!m)
2628 return;
2629 auto &body = m->body();
2631 int numResults = op.getNumResults();
2632 int numVariadicResults = op.getNumVariableLengthResults();
2633 int numNonVariadicResults = numResults - numVariadicResults;
2635 int numOperands = op.getNumOperands();
2636 int numVariadicOperands = op.getNumVariableLengthOperands();
2637 int numNonVariadicOperands = numOperands - numVariadicOperands;
2639 // Operands
2640 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
2641 body << " assert(operands.size()"
2642 << (numVariadicOperands != 0 ? " >= " : " == ")
2643 << numNonVariadicOperands
2644 << "u && \"mismatched number of parameters\");\n";
2645 body << " " << builderOpState << ".addOperands(operands);\n";
2646 body << " " << builderOpState << ".addAttributes(attributes);\n";
2648 // Create the correct number of regions
2649 if (int numRegions = op.getNumRegions()) {
2650 body << llvm::formatv(
2651 " for (unsigned i = 0; i != {0}; ++i)\n",
2652 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2653 body << " (void)" << builderOpState << ".addRegion();\n";
2656 // Result types
2657 if (emitHelper.hasProperties()) {
2658 // Initialize the properties from Attributes before invoking the infer
2659 // function.
2660 body << formatv(R"(
2661 if (!attributes.empty()) {
2662 ::mlir::OpaqueProperties properties =
2663 &{1}.getOrAddProperties<{0}::Properties>();
2664 std::optional<::mlir::RegisteredOperationName> info =
2665 {1}.name.getRegisteredInfo();
2666 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2667 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
2668 ::llvm::report_fatal_error("Property conversion failed.");
2669 })",
2670 opClass.getClassName(), builderOpState);
2672 body << formatv(R"(
2673 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2674 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2675 {1}.location, operands,
2676 {1}.attributes.getDictionary({1}.getContext()),
2677 {1}.getRawProperties(),
2678 {1}.regions, inferredReturnTypes))) {{)",
2679 opClass.getClassName(), builderOpState);
2680 if (numVariadicResults == 0 || numNonVariadicResults != 0)
2681 body << "\n assert(inferredReturnTypes.size()"
2682 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
2683 << "u && \"mismatched number of return types\");";
2684 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";
2686 body << formatv(R"(
2687 } else {{
2688 ::llvm::report_fatal_error("Failed to infer result type(s).");
2689 })",
2690 opClass.getClassName(), builderOpState);
2693 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2694 auto emit = [&](AttrParamKind attrType) {
2695 SmallVector<MethodParameter> paramList;
2696 SmallVector<std::string, 4> resultNames;
2697 llvm::StringSet<> inferredAttributes;
2698 buildParamList(paramList, inferredAttributes, resultNames,
2699 TypeParamKind::None, attrType);
2701 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2702 // If the builder is redundant, skip generating the method
2703 if (!m)
2704 return;
2705 auto &body = m->body();
2706 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2707 /*isRawValueAttr=*/attrType ==
2708 AttrParamKind::UnwrappedValue);
2710 auto numResults = op.getNumResults();
2711 if (numResults == 0)
2712 return;
2714 // Push all result types to the operation state
2715 const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
2716 std::string resultType =
2717 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
2718 body << " " << builderOpState << ".addTypes({" << resultType;
2719 for (int i = 1; i != numResults; ++i)
2720 body << ", " << resultType;
2721 body << "});\n\n";
2724 emit(AttrParamKind::WrappedAttr);
2725 // Generate additional builder(s) if any attributes can be "unwrapped"
2726 if (canGenerateUnwrappedBuilder(op))
2727 emit(AttrParamKind::UnwrappedValue);
2730 void OpEmitter::genUseAttrAsResultTypeBuilder() {
2731 SmallVector<MethodParameter> paramList;
2732 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2733 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2734 paramList.emplace_back("::mlir::ValueRange", "operands");
2735 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2736 "attributes", "{}");
2737 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2738 // If the builder is redundant, skip generating the method
2739 if (!m)
2740 return;
2742 auto &body = m->body();
2744 // Push all result types to the operation state
2745 std::string resultType;
2746 const auto &namedAttr = op.getAttribute(0);
2748 body << " auto attrName = " << op.getGetterName(namedAttr.name)
2749 << "AttrName(" << builderOpState
2750 << ".name);\n"
2751 " for (auto attr : attributes) {\n"
2752 " if (attr.getName() != attrName) continue;\n";
2753 if (namedAttr.attr.isTypeAttr()) {
2754 resultType = "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()";
2755 } else {
2756 resultType = "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()";
2759 // Operands
2760 body << " " << builderOpState << ".addOperands(operands);\n";
2762 // Attributes
2763 body << " " << builderOpState << ".addAttributes(attributes);\n";
2765 // Result types
2766 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
2767 body << " " << builderOpState << ".addTypes({"
2768 << llvm::join(resultTypes, ", ") << "});\n";
2769 body << " }\n";
2772 /// Returns a signature of the builder. Updates the context `fctx` to enable
2773 /// replacement of $_builder and $_state in the body.
2774 static SmallVector<MethodParameter>
2775 getBuilderSignature(const Builder &builder) {
2776 ArrayRef<Builder::Parameter> params(builder.getParameters());
2778 // Inject builder and state arguments.
2779 SmallVector<MethodParameter> arguments;
2780 arguments.reserve(params.size() + 2);
2781 arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
2782 arguments.emplace_back("::mlir::OperationState &", builderOpState);
2784 for (unsigned i = 0, e = params.size(); i < e; ++i) {
2785 // If no name is provided, generate one.
2786 std::optional<StringRef> paramName = params[i].getName();
2787 std::string name =
2788 paramName ? paramName->str() : "odsArg" + std::to_string(i);
2790 StringRef defaultValue;
2791 if (std::optional<StringRef> defaultParamValue =
2792 params[i].getDefaultValue())
2793 defaultValue = *defaultParamValue;
2795 arguments.emplace_back(params[i].getCppType(), std::move(name),
2796 defaultValue);
2799 return arguments;
2802 void OpEmitter::genBuilder() {
2803 // Handle custom builders if provided.
2804 for (const Builder &builder : op.getBuilders()) {
2805 SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
2807 std::optional<StringRef> body = builder.getBody();
2808 auto properties = body ? Method::Static : Method::StaticDeclaration;
2809 auto *method =
2810 opClass.addMethod("void", "build", properties, std::move(arguments));
2811 if (body)
2812 ERROR_IF_PRUNED(method, "build", op);
2814 if (method)
2815 method->setDeprecated(builder.getDeprecatedMessage());
2817 FmtContext fctx;
2818 fctx.withBuilder(odsBuilder);
2819 fctx.addSubst("_state", builderOpState);
2820 if (body)
2821 method->body() << tgfmt(*body, &fctx);
2824 // Generate default builders that requires all result type, operands, and
2825 // attributes as parameters.
2826 if (op.skipDefaultBuilders())
2827 return;
2829 // We generate three classes of builders here:
2830 // 1. one having a stand-alone parameter for each operand / attribute, and
2831 genSeparateArgParamBuilder();
2832 // 2. one having an aggregated parameter for all result types / operands /
2833 // attributes, and
2834 genCollectiveParamBuilder();
2835 // 3. one having a stand-alone parameter for each operand and attribute,
2836 // use the first operand or attribute's type as all result types
2837 // to facilitate different call patterns.
2838 if (op.getNumVariableLengthResults() == 0) {
2839 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
2840 genUseOperandAsResultTypeSeparateParamBuilder();
2841 genUseOperandAsResultTypeCollectiveParamBuilder();
2843 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
2844 genUseAttrAsResultTypeBuilder();
2848 void OpEmitter::genCollectiveParamBuilder() {
2849 int numResults = op.getNumResults();
2850 int numVariadicResults = op.getNumVariableLengthResults();
2851 int numNonVariadicResults = numResults - numVariadicResults;
2853 int numOperands = op.getNumOperands();
2854 int numVariadicOperands = op.getNumVariableLengthOperands();
2855 int numNonVariadicOperands = numOperands - numVariadicOperands;
2857 SmallVector<MethodParameter> paramList;
2858 paramList.emplace_back("::mlir::OpBuilder &", "");
2859 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2860 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
2861 paramList.emplace_back("::mlir::ValueRange", "operands");
2862 // Provide default value for `attributes` when its the last parameter
2863 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2864 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2865 "attributes", attributesDefaultValue);
2866 if (op.getNumVariadicRegions())
2867 paramList.emplace_back("unsigned", "numRegions");
2869 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2870 // If the builder is redundant, skip generating the method
2871 if (!m)
2872 return;
2873 auto &body = m->body();
2875 // Operands
2876 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
2877 body << " assert(operands.size()"
2878 << (numVariadicOperands != 0 ? " >= " : " == ")
2879 << numNonVariadicOperands
2880 << "u && \"mismatched number of parameters\");\n";
2881 body << " " << builderOpState << ".addOperands(operands);\n";
2883 // Attributes
2884 body << " " << builderOpState << ".addAttributes(attributes);\n";
2886 // Create the correct number of regions
2887 if (int numRegions = op.getNumRegions()) {
2888 body << llvm::formatv(
2889 " for (unsigned i = 0; i != {0}; ++i)\n",
2890 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2891 body << " (void)" << builderOpState << ".addRegion();\n";
2894 // Result types
2895 if (numVariadicResults == 0 || numNonVariadicResults != 0)
2896 body << " assert(resultTypes.size()"
2897 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
2898 << "u && \"mismatched number of return types\");\n";
2899 body << " " << builderOpState << ".addTypes(resultTypes);\n";
2901 if (emitHelper.hasProperties()) {
2902 // Initialize the properties from Attributes before invoking the infer
2903 // function.
2904 body << formatv(R"(
2905 if (!attributes.empty()) {
2906 ::mlir::OpaqueProperties properties =
2907 &{1}.getOrAddProperties<{0}::Properties>();
2908 std::optional<::mlir::RegisteredOperationName> info =
2909 {1}.name.getRegisteredInfo();
2910 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2911 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
2912 ::llvm::report_fatal_error("Property conversion failed.");
2913 })",
2914 opClass.getClassName(), builderOpState);
2917 // Generate builder that infers type too.
2918 // TODO: Expand to handle successors.
2919 if (canInferType(op) && op.getNumSuccessors() == 0)
2920 genInferredTypeCollectiveParamBuilder();
/// Builds the parameter list of a generated `build` method into `paramList`,
/// in this order:
///   - the `::mlir::OpBuilder &` and `::mlir::OperationState &` parameters;
///   - result-type parameters according to `typeParamKind`: none, one per
///     result, or a single `resultTypes` range;
///   - one parameter per operand and per attribute not in
///     `inferredAttributes`, typed according to `attrParamKind`. Property
///     arguments are currently skipped;
///   - one parameter per successor;
///   - one `<name>Count` parameter per variadic region.
/// `resultTypeNames` is cleared and filled with the names of the result-type
/// parameters. `inferredAttributes` receives the segment-size attributes of
/// variadic-of-variadic operands, which are computed by the builder instead of
/// being passed in.
void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                               llvm::StringSet<> &inferredAttributes,
                               SmallVectorImpl<std::string> &resultTypeNames,
                               TypeParamKind typeParamKind,
                               AttrParamKind attrParamKind) {
  resultTypeNames.clear();
  auto numResults = op.getNumResults();
  resultTypeNames.reserve(numResults);

  paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
  paramList.emplace_back("::mlir::OperationState &", builderOpState);

  switch (typeParamKind) {
  case TypeParamKind::None:
    break;
  case TypeParamKind::Separate: {
    // Add parameters for all return types
    for (int i = 0; i < numResults; ++i) {
      const auto &result = op.getResult(i);
      std::string resultName = std::string(result.name);
      // Unnamed results get a positional name: resultType0, resultType1, ...
      if (resultName.empty())
        resultName = std::string(formatv("resultType{0}", i));

      StringRef type =
          result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";

      paramList.emplace_back(type, resultName, result.isOptional());
      resultTypeNames.emplace_back(std::move(resultName));
    }
  } break;
  case TypeParamKind::Collective: {
    paramList.emplace_back("::mlir::TypeRange", "resultTypes");
    resultTypeNames.push_back("resultTypes");
  } break;
  }

  // Add parameters for all arguments (operands and attributes).
  // Arguments at or after this index get a default value; op.getNumArgs()
  // means "none".
  int defaultValuedAttrStartIndex = op.getNumArgs();
  // Successors and variadic regions go at the end of the parameter list, so no
  // default arguments are possible.
  bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
  if (!hasTrailingParams) {
    // Calculate the start index from which we can attach default values in the
    // builder declaration.
    // Walk backwards over the trailing run of default-able attributes; stop at
    // the first argument that is not an attribute or cannot be defaulted.
    for (int i = op.getNumArgs() - 1; i >= 0; --i) {
      auto *namedAttr =
          llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(op.getArg(i));
      if (!namedAttr)
        break;

      Attribute attr = namedAttr->attr;
      // TODO: Currently we can't differentiate between optional meaning do not
      // verify/not always error if missing or optional meaning need not be
      // specified in builder. Expand isOptional once we can differentiate.
      if (!attr.hasDefaultValue() && !attr.isDerivedAttr())
        break;

      // Creating an APInt requires us to provide bitwidth, value, and
      // signedness, which is complicated compared to others. Similarly
      // for APFloat.
      // TODO: Adjust the 'returnType' field of such attributes
      // to support them.
      StringRef retType = namedAttr->attr.getReturnType();
      if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
        break;

      defaultValuedAttrStartIndex = i;
    }
  }
  // Avoid generating build methods that are ambiguous due to default values by
  // requiring at least one attribute.
  // The first default-able attribute stays required when its form does not
  // match this builder's kind: it can be unwrapped but this builder takes
  // wrapped attributes, or the reverse. NOTE(review): presumably this keeps
  // the wrapped and unwrapped overloads from becoming ambiguous; confirm.
  if (defaultValuedAttrStartIndex < op.getNumArgs()) {
    // TODO: This should have been possible as a cast<NamedAttribute> but
    // required template instantiations is not yet defined for the tblgen helper
    // classes.
    auto *namedAttr =
        cast<NamedAttribute *>(op.getArg(defaultValuedAttrStartIndex));
    Attribute attr = namedAttr->attr;
    if ((attrParamKind == AttrParamKind::WrappedAttr &&
         canUseUnwrappedRawValue(attr)) ||
        (attrParamKind == AttrParamKind::UnwrappedValue &&
         !canUseUnwrappedRawValue(attr)))
      ++defaultValuedAttrStartIndex;
  }

  /// Collect any inferred attributes.
  for (const NamedTypeConstraint &operand : op.getOperands()) {
    if (operand.isVariadicOfVariadic()) {
      inferredAttributes.insert(
          operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
    }
  }

  // `numOperands` counts operands only; it is the index passed to
  // getArgumentName for operand names.
  for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
    Argument arg = op.getArg(i);
    if (const auto *operand =
            llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
      StringRef type;
      if (operand->isVariadicOfVariadic())
        type = "::llvm::ArrayRef<::mlir::ValueRange>";
      else if (operand->isVariadic())
        type = "::mlir::ValueRange";
      else
        type = "::mlir::Value";

      paramList.emplace_back(type, getArgumentName(op, numOperands++),
                             operand->isOptional());
      continue;
    }
    if (llvm::isa_and_present<NamedProperty *>(arg)) {
      // TODO
      continue;
    }
    const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
    const Attribute &attr = namedAttr.attr;

    // Inferred attributes don't need to be added to the param list.
    if (inferredAttributes.contains(namedAttr.name))
      continue;

    // Wrapped builders always take the storage type; unwrapped builders take
    // the return (raw value) type whenever the attribute supports it.
    StringRef type;
    switch (attrParamKind) {
    case AttrParamKind::WrappedAttr:
      type = attr.getStorageType();
      break;
    case AttrParamKind::UnwrappedValue:
      if (canUseUnwrappedRawValue(attr))
        type = attr.getReturnType();
      else
        type = attr.getStorageType();
      break;
    }

    // Attach default value if requested and possible.
    // Raw-value parameters default to the attribute's declared default;
    // attribute-typed parameters default to nullptr.
    std::string defaultValue;
    if (i >= defaultValuedAttrStartIndex) {
      if (attrParamKind == AttrParamKind::UnwrappedValue &&
          canUseUnwrappedRawValue(attr))
        defaultValue += attr.getDefaultValue();
      else
        defaultValue += "nullptr";
    }
    paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue),
                           attr.isOptional());
  }

  /// Insert parameters for each successor.
  for (const NamedSuccessor &succ : op.getSuccessors()) {
    StringRef type =
        succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
    paramList.emplace_back(type, succ.name);
  }

  /// Insert parameters for variadic regions.
  for (const NamedRegion &region : op.getRegions())
    if (region.isVariadic())
      paramList.emplace_back("unsigned",
                             llvm::formatv("{0}Count", region.name).str());
}
3083 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
3084 MethodBody &body, llvm::StringSet<> &inferredAttributes,
3085 bool isRawValueAttr) {
3086 // Push all operands to the result.
3087 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
3088 std::string argName = getArgumentName(op, i);
3089 const NamedTypeConstraint &operand = op.getOperand(i);
3090 if (operand.constraint.isVariadicOfVariadic()) {
3091 body << " for (::mlir::ValueRange range : " << argName << ")\n "
3092 << builderOpState << ".addOperands(range);\n";
3094 // Add the segment attribute.
3095 body << " {\n"
3096 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
3097 << " for (::mlir::ValueRange range : " << argName << ")\n"
3098 << " rangeSegments.push_back(range.size());\n"
3099 << " auto rangeAttr = " << odsBuilder
3100 << ".getDenseI32ArrayAttr(rangeSegments);\n";
3101 if (op.getDialect().usePropertiesForAttributes()) {
3102 body << " " << builderOpState << ".getOrAddProperties<Properties>()."
3103 << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
3104 << " = rangeAttr;";
3105 } else {
3106 body << " " << builderOpState << ".addAttribute("
3107 << op.getGetterName(
3108 operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
3109 << "AttrName(" << builderOpState << ".name), rangeAttr);";
3111 body << " }\n";
3112 continue;
3115 if (operand.isOptional())
3116 body << " if (" << argName << ")\n ";
3117 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
3120 // If the operation has the operand segment size attribute, add it here.
3121 auto emitSegment = [&]() {
3122 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
3123 const NamedTypeConstraint &operand = op.getOperand(i);
3124 if (!operand.isVariableLength()) {
3125 body << "1";
3126 return;
3129 std::string operandName = getArgumentName(op, i);
3130 if (operand.isOptional()) {
3131 body << "(" << operandName << " ? 1 : 0)";
3132 } else if (operand.isVariadicOfVariadic()) {
3133 body << llvm::formatv(
3134 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
3135 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
3136 "static_cast<int32_t>(range.size()); }))",
3137 operandName);
3138 } else {
3139 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
3143 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
3144 std::string sizes = op.getGetterName(operandSegmentAttrName);
3145 if (op.getDialect().usePropertiesForAttributes()) {
3146 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
3147 emitSegment();
3148 body << "}), " << builderOpState
3149 << ".getOrAddProperties<Properties>()."
3150 "operandSegmentSizes.begin());\n";
3151 } else {
3152 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
3153 << builderOpState << ".name), "
3154 << "odsBuilder.getDenseI32ArrayAttr({";
3155 emitSegment();
3156 body << "}));\n";
3160 // Push all attributes to the result.
3161 for (const auto &namedAttr : op.getAttributes()) {
3162 auto &attr = namedAttr.attr;
3163 if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
3164 continue;
3166 // TODO: The wrapping of optional is different for default or not, so don't
3167 // unwrap for default ones that would fail below.
3168 bool emitNotNullCheck =
3169 (attr.isOptional() && !attr.hasDefaultValue()) ||
3170 (attr.hasDefaultValue() && !isRawValueAttr) ||
3171 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
3172 // the constant materialization is only for true case.
3173 (isRawValueAttr && attr.getAttrDefName() == "UnitAttr");
3174 if (emitNotNullCheck)
3175 body.indent() << formatv("if ({0}) ", namedAttr.name) << "{\n";
3177 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
3178 // If this is a raw value, then we need to wrap it in an Attribute
3179 // instance.
3180 FmtContext fctx;
3181 fctx.withBuilder("odsBuilder");
3182 if (op.getDialect().usePropertiesForAttributes()) {
3183 body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {2};\n",
3184 builderOpState, namedAttr.name,
3185 constBuildAttrFromParam(attr, fctx, namedAttr.name));
3186 } else {
3187 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3188 builderOpState, op.getGetterName(namedAttr.name),
3189 constBuildAttrFromParam(attr, fctx, namedAttr.name));
3191 } else {
3192 if (op.getDialect().usePropertiesForAttributes()) {
3193 body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {1};\n",
3194 builderOpState, namedAttr.name);
3195 } else {
3196 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3197 builderOpState, op.getGetterName(namedAttr.name),
3198 namedAttr.name);
3201 if (emitNotNullCheck)
3202 body.unindent() << " }\n";
3205 // Create the correct number of regions.
3206 for (const NamedRegion &region : op.getRegions()) {
3207 if (region.isVariadic())
3208 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
3209 region.name);
3211 body << " (void)" << builderOpState << ".addRegion();\n";
3214 // Push all successors to the result.
3215 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
3216 body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
3217 namedSuccessor.name);
3221 void OpEmitter::genCanonicalizerDecls() {
3222 bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
3223 if (hasCanonicalizeMethod) {
3224 // static LogicResult FooOp::
3225 // canonicalize(FooOp op, PatternRewriter &rewriter);
3226 SmallVector<MethodParameter> paramList;
3227 paramList.emplace_back(op.getCppClassName(), "op");
3228 paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
3229 auto *m = opClass.declareStaticMethod("::llvm::LogicalResult",
3230 "canonicalize", std::move(paramList));
3231 ERROR_IF_PRUNED(m, "canonicalize", op);
3234 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
3235 // or if using a 'canonicalize' method.
3236 bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
3237 if (!hasCanonicalizeMethod && !hasCanonicalizer)
3238 return;
3240 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
3241 // method, but not implementing 'getCanonicalizationPatterns' manually.
3242 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
3244 // Add a signature for getCanonicalizationPatterns if implemented by the
3245 // dialect or if synthesized to call 'canonicalize'.
3246 SmallVector<MethodParameter> paramList;
3247 paramList.emplace_back("::mlir::RewritePatternSet &", "results");
3248 paramList.emplace_back("::mlir::MLIRContext *", "context");
3249 auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
3250 auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
3251 std::move(paramList));
3253 // If synthesizing the method, fill it.
3254 if (hasBody) {
3255 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
3256 method->body() << " results.add(canonicalize);\n";
3260 void OpEmitter::genFolderDecls() {
3261 if (!op.hasFolder())
3262 return;
3264 SmallVector<MethodParameter> paramList;
3265 paramList.emplace_back("FoldAdaptor", "adaptor");
3267 StringRef retType;
3268 bool hasSingleResult =
3269 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
3270 if (hasSingleResult) {
3271 retType = "::mlir::OpFoldResult";
3272 } else {
3273 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
3274 "results");
3275 retType = "::llvm::LogicalResult";
3278 auto *m = opClass.declareMethod(retType, "fold", std::move(paramList));
3279 ERROR_IF_PRUNED(m, "fold", op);
3282 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
3283 Interface interface = opTrait->getInterface();
3285 // Get the set of methods that should always be declared.
3286 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
3287 llvm::StringSet<> alwaysDeclaredMethods;
3288 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
3289 alwaysDeclaredMethodsVec.end());
3291 for (const InterfaceMethod &method : interface.getMethods()) {
3292 // Don't declare if the method has a body.
3293 if (method.getBody())
3294 continue;
3295 // Don't declare if the method has a default implementation and the op
3296 // didn't request that it always be declared.
3297 if (method.getDefaultImplementation() &&
3298 !alwaysDeclaredMethods.count(method.getName()))
3299 continue;
3300 // Interface methods are allowed to overlap with existing methods, so don't
3301 // check if pruned.
3302 (void)genOpInterfaceMethod(method);
3306 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
3307 bool declaration) {
3308 SmallVector<MethodParameter> paramList;
3309 for (const InterfaceMethod::Argument &arg : method.getArguments())
3310 paramList.emplace_back(arg.type, arg.name);
3312 auto props = (method.isStatic() ? Method::Static : Method::None) |
3313 (declaration ? Method::Declaration : Method::None);
3314 return opClass.addMethod(method.getReturnType(), method.getName(), props,
3315 std::move(paramList));
3318 void OpEmitter::genOpInterfaceMethods() {
3319 for (const auto &trait : op.getTraits()) {
3320 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
3321 if (opTrait->shouldDeclareMethods())
3322 genOpInterfaceMethods(opTrait);
3326 void OpEmitter::genSideEffectInterfaceMethods() {
3327 enum EffectKind { Operand, Result, Symbol, Static };
3328 struct EffectLocation {
3329 /// The effect applied.
3330 SideEffect effect;
3332 /// The index if the kind is not static.
3333 unsigned index;
3335 /// The kind of the location.
3336 unsigned kind;
3339 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
3340 auto resolveDecorators = [&](Operator::var_decorator_range decorators,
3341 unsigned index, unsigned kind) {
3342 for (auto decorator : decorators)
3343 if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
3344 opClass.addTrait(effect->getInterfaceTrait());
3345 interfaceEffects[effect->getBaseEffectName()].push_back(
3346 EffectLocation{*effect, index, kind});
3350 // Collect effects that were specified via:
3351 /// Traits.
3352 for (const auto &trait : op.getTraits()) {
3353 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
3354 if (!opTrait)
3355 continue;
3356 auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
3357 for (auto decorator : opTrait->getEffects())
3358 effects.push_back(EffectLocation{cast<SideEffect>(decorator),
3359 /*index=*/0, EffectKind::Static});
3361 /// Attributes and Operands.
3362 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
3363 Argument arg = op.getArg(i);
3364 if (arg.is<NamedTypeConstraint *>()) {
3365 resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
3366 ++operandIt;
3367 continue;
3369 if (arg.is<NamedProperty *>())
3370 continue;
3371 const NamedAttribute *attr = arg.get<NamedAttribute *>();
3372 if (attr->attr.getBaseAttr().isSymbolRefAttr())
3373 resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
3375 /// Results.
3376 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3377 resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
3379 // The code used to add an effect instance.
3380 // {0}: The effect class.
3381 // {1}: Optional value or symbol reference.
3382 // {2}: The side effect stage.
3383 // {3}: Does this side effect act on every single value of resource.
3384 // {4}: The resource class.
3385 const char *addEffectCode =
3386 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
3388 for (auto &it : interfaceEffects) {
3389 // Generate the 'getEffects' method.
3390 std::string type = llvm::formatv("::llvm::SmallVectorImpl<::mlir::"
3391 "SideEffects::EffectInstance<{0}>> &",
3392 it.first())
3393 .str();
3394 auto *getEffects = opClass.addMethod("void", "getEffects",
3395 MethodParameter(type, "effects"));
3396 ERROR_IF_PRUNED(getEffects, "getEffects", op);
3397 auto &body = getEffects->body();
3399 // Add effect instances for each of the locations marked on the operation.
3400 for (auto &location : it.second) {
3401 StringRef effect = location.effect.getName();
3402 StringRef resource = location.effect.getResource();
3403 int stage = (int)location.effect.getStage();
3404 bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion();
3405 if (location.kind == EffectKind::Static) {
3406 // A static instance has no attached value.
3407 body << llvm::formatv(addEffectCode, effect, "", stage,
3408 effectOnFullRegion, resource)
3409 .str();
3410 } else if (location.kind == EffectKind::Symbol) {
3411 // A symbol reference requires adding the proper attribute.
3412 const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
3413 std::string argName = op.getGetterName(attr->name);
3414 if (attr->attr.isOptional()) {
3415 body << " if (auto symbolRef = " << argName << "Attr())\n "
3416 << llvm::formatv(addEffectCode, effect, "symbolRef, ", stage,
3417 effectOnFullRegion, resource)
3418 .str();
3419 } else {
3420 body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
3421 stage, effectOnFullRegion, resource)
3422 .str();
3424 } else {
3425 // Otherwise this is an operand/result, so we need to attach the Value.
3426 body << " {\n auto valueRange = getODS"
3427 << (location.kind == EffectKind::Operand ? "Operand" : "Result")
3428 << "IndexAndLength(" << location.index << ");\n"
3429 << " for (unsigned idx = valueRange.first; idx < "
3430 "valueRange.first"
3431 << " + valueRange.second; idx++) {\n "
3432 << llvm::formatv(addEffectCode, effect,
3433 (location.kind == EffectKind::Operand
3434 ? "&getOperation()->getOpOperand(idx), "
3435 : "getOperation()->getOpResult(idx), "),
3436 stage, effectOnFullRegion, resource)
3437 << " }\n }\n";
3443 void OpEmitter::genTypeInterfaceMethods() {
3444 if (!op.allResultTypesKnown())
3445 return;
3446 // Generate 'inferReturnTypes' method declaration using the interface method
3447 // declared in 'InferTypeOpInterface' op interface.
3448 const auto *trait =
3449 cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait"));
3450 Interface interface = trait->getInterface();
3451 Method *method = [&]() -> Method * {
3452 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
3453 if (interfaceMethod.getName() == "inferReturnTypes") {
3454 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
3457 assert(0 && "unable to find inferReturnTypes interface method");
3458 return nullptr;
3459 }();
3460 ERROR_IF_PRUNED(method, "inferReturnTypes", op);
3461 auto &body = method->body();
3462 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
3464 FmtContext fctx;
3465 fctx.withBuilder("odsBuilder");
3466 fctx.addSubst("_ctxt", "context");
3467 body << " ::mlir::Builder odsBuilder(context);\n";
3469 // Process the type inference graph in topological order, starting from types
3470 // that are always fully-inferred: operands and results with constructible
3471 // types. The type inference graph here will always be a DAG, so this gives
3472 // us the correct order for generating the types. -1 is a placeholder to
3473 // indicate the type for a result has not been generated.
3474 SmallVector<int> constructedIndices(op.getNumResults(), -1);
3475 int inferredTypeIdx = 0;
3476 for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
3477 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3478 if (constructedIndices[i] >= 0)
3479 continue;
3480 const InferredResultType &infer = op.getInferredResultType(i);
3481 std::string typeStr;
3482 if (infer.isArg()) {
3483 // If this is an operand, just index into operand list to access the
3484 // type.
3485 auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
3486 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3487 typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
3488 "].getType()")
3489 .str();
3491 // If this is an attribute, index into the attribute dictionary.
3492 } else {
3493 auto *attr =
3494 op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
3495 body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
3496 << " = ";
3497 if (op.getDialect().usePropertiesForAttributes()) {
3498 body << "(properties ? properties.as<Properties *>()->"
3499 << attr->name
3500 << " : "
3501 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3502 "get(\"" +
3503 attr->name + "\")));\n";
3504 } else {
3505 body << "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3506 "get(\"" +
3507 attr->name + "\"));\n";
3509 body << " if (!odsInferredTypeAttr" << inferredTypeIdx
3510 << ") return ::mlir::failure();\n";
3511 typeStr =
3512 ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
3513 .str();
3515 } else if (std::optional<StringRef> builder =
3516 op.getResult(infer.getResultIndex())
3517 .constraint.getBuilderCall()) {
3518 typeStr = tgfmt(*builder, &fctx).str();
3519 } else if (int index = constructedIndices[infer.getResultIndex()];
3520 index >= 0) {
3521 typeStr = ("odsInferredType" + Twine(index)).str();
3522 } else {
3523 continue;
3525 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
3526 << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
3527 constructedIndices[i] = inferredTypeIdx - 1;
3530 for (auto [i, index] : llvm::enumerate(constructedIndices))
3531 body << " inferredReturnTypes[" << i << "] = odsInferredType" << index
3532 << ";\n";
3533 body << " return ::mlir::success();";
3536 void OpEmitter::genParser() {
3537 if (hasStringAttribute(def, "assemblyFormat"))
3538 return;
3540 if (!def.getValueAsBit("hasCustomAssemblyFormat"))
3541 return;
3543 SmallVector<MethodParameter> paramList;
3544 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
3545 paramList.emplace_back("::mlir::OperationState &", "result");
3547 auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse",
3548 std::move(paramList));
3549 ERROR_IF_PRUNED(method, "parse", op);
3552 void OpEmitter::genPrinter() {
3553 if (hasStringAttribute(def, "assemblyFormat"))
3554 return;
3556 // Check to see if this op uses a c++ format.
3557 if (!def.getValueAsBit("hasCustomAssemblyFormat"))
3558 return;
3559 auto *method = opClass.declareMethod(
3560 "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
3561 ERROR_IF_PRUNED(method, "print", op);
3564 void OpEmitter::genVerifier() {
3565 auto *implMethod =
3566 opClass.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl");
3567 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
3568 auto &implBody = implMethod->body();
3569 bool useProperties = emitHelper.hasProperties();
3571 populateSubstitutions(emitHelper, verifyCtx);
3572 genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter,
3573 useProperties);
3574 genOperandResultVerifier(implBody, op.getOperands(), "operand");
3575 genOperandResultVerifier(implBody, op.getResults(), "result");
3577 for (auto &trait : op.getTraits()) {
3578 if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
3579 implBody << tgfmt(" if (!($0))\n "
3580 "return emitOpError(\"failed to verify that $1\");\n",
3581 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
3582 t->getSummary());
3586 genRegionVerifier(implBody);
3587 genSuccessorVerifier(implBody);
3589 implBody << " return ::mlir::success();\n";
3591 // TODO: Some places use the `verifyInvariants` to do operation verification.
3592 // This may not act as their expectation because this doesn't call any
3593 // verifiers of native/interface traits. Needs to review those use cases and
3594 // see if we should use the mlir::verify() instead.
3595 auto *method = opClass.addMethod("::llvm::LogicalResult", "verifyInvariants");
3596 ERROR_IF_PRUNED(method, "verifyInvariants", op);
3597 auto &body = method->body();
3598 if (def.getValueAsBit("hasVerifier")) {
3599 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && "
3600 "::mlir::succeeded(verify()))\n";
3601 body << " return ::mlir::success();\n";
3602 body << " return ::mlir::failure();";
3603 } else {
3604 body << " return verifyInvariantsImpl();";
3608 void OpEmitter::genCustomVerifier() {
3609 if (def.getValueAsBit("hasVerifier")) {
3610 auto *method = opClass.declareMethod("::llvm::LogicalResult", "verify");
3611 ERROR_IF_PRUNED(method, "verify", op);
3614 if (def.getValueAsBit("hasRegionVerifier")) {
3615 auto *method =
3616 opClass.declareMethod("::llvm::LogicalResult", "verifyRegions");
3617 ERROR_IF_PRUNED(method, "verifyRegions", op);
3621 void OpEmitter::genOperandResultVerifier(MethodBody &body,
3622 Operator::const_value_range values,
3623 StringRef valueKind) {
3624 // Check that an optional value is at most 1 element.
3626 // {0}: Value index.
3627 // {1}: "operand" or "result"
3628 const char *const verifyOptional = R"(
3629 if (valueGroup{0}.size() > 1) {
3630 return emitOpError("{1} group starting at #") << index
3631 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
3634 // Check the types of a range of values.
3636 // {0}: Value index.
3637 // {1}: Type constraint function.
3638 // {2}: "operand" or "result"
3639 const char *const verifyValues = R"(
3640 for (auto v : valueGroup{0}) {
3641 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
3642 return ::mlir::failure();
3646 const auto canSkip = [](const NamedTypeConstraint &value) {
3647 return !value.hasPredicate() && !value.isOptional() &&
3648 !value.isVariadicOfVariadic();
3650 if (values.empty() || llvm::all_of(values, canSkip))
3651 return;
3653 FmtContext fctx;
3655 body << " {\n unsigned index = 0; (void)index;\n";
3657 for (const auto &staticValue : llvm::enumerate(values)) {
3658 const NamedTypeConstraint &value = staticValue.value();
3660 bool hasPredicate = value.hasPredicate();
3661 bool isOptional = value.isOptional();
3662 bool isVariadicOfVariadic = value.isVariadicOfVariadic();
3663 if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
3664 continue;
3665 body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
3666 // Capitalize the first letter to match the function name
3667 valueKind.substr(0, 1).upper(), valueKind.substr(1),
3668 staticValue.index());
3670 // If the constraint is optional check that the value group has at most 1
3671 // value.
3672 if (isOptional) {
3673 body << formatv(verifyOptional, staticValue.index(), valueKind);
3674 } else if (isVariadicOfVariadic) {
3675 body << formatv(
3676 " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
3677 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
3678 " return ::mlir::failure();\n",
3679 value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
3680 staticValue.index());
3683 // Otherwise, if there is no predicate there is nothing left to do.
3684 if (!hasPredicate)
3685 continue;
3686 // Emit a loop to check all the dynamic values in the pack.
3687 StringRef constraintFn =
3688 staticVerifierEmitter.getTypeConstraintFn(value.constraint);
3689 body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
3692 body << " }\n";
3695 void OpEmitter::genRegionVerifier(MethodBody &body) {
3696 /// Code to verify a region.
3698 /// {0}: Getter for the regions.
3699 /// {1}: The region constraint.
3700 /// {2}: The region's name.
3701 /// {3}: The region description.
3702 const char *const verifyRegion = R"(
3703 for (auto &region : {0})
3704 if (::mlir::failed({1}(*this, region, "{2}", index++)))
3705 return ::mlir::failure();
3707 /// Get a single region.
3709 /// {0}: The region's index.
3710 const char *const getSingleRegion =
3711 "::llvm::MutableArrayRef((*this)->getRegion({0}))";
3713 // If we have no regions, there is nothing more to do.
3714 const auto canSkip = [](const NamedRegion &region) {
3715 return region.constraint.getPredicate().isNull();
3717 auto regions = op.getRegions();
3718 if (regions.empty() && llvm::all_of(regions, canSkip))
3719 return;
3721 body << " {\n unsigned index = 0; (void)index;\n";
3722 for (const auto &it : llvm::enumerate(regions)) {
3723 const auto &region = it.value();
3724 if (canSkip(region))
3725 continue;
3727 auto getRegion = region.isVariadic()
3728 ? formatv("{0}()", op.getGetterName(region.name)).str()
3729 : formatv(getSingleRegion, it.index()).str();
3730 auto constraintFn =
3731 staticVerifierEmitter.getRegionConstraintFn(region.constraint);
3732 body << formatv(verifyRegion, getRegion, constraintFn, region.name);
3734 body << " }\n";
3737 void OpEmitter::genSuccessorVerifier(MethodBody &body) {
3738 const char *const verifySuccessor = R"(
3739 for (auto *successor : {0})
3740 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
3741 return ::mlir::failure();
3743 /// Get a single successor.
3745 /// {0}: The successor's name.
3746 const char *const getSingleSuccessor = "::llvm::MutableArrayRef({0}())";
3748 // If we have no successors, there is nothing more to do.
3749 const auto canSkip = [](const NamedSuccessor &successor) {
3750 return successor.constraint.getPredicate().isNull();
3752 auto successors = op.getSuccessors();
3753 if (successors.empty() && llvm::all_of(successors, canSkip))
3754 return;
3756 body << " {\n unsigned index = 0; (void)index;\n";
3758 for (auto it : llvm::enumerate(successors)) {
3759 const auto &successor = it.value();
3760 if (canSkip(successor))
3761 continue;
3763 auto getSuccessor =
3764 formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
3765 successor.name, it.index())
3766 .str();
3767 auto constraintFn =
3768 staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
3769 body << formatv(verifySuccessor, getSuccessor, constraintFn,
3770 successor.name);
3772 body << " }\n";
3775 /// Add a size count trait to the given operation class.
3776 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
3777 int numTotal, int numVariadic) {
3778 if (numVariadic != 0) {
3779 if (numTotal == numVariadic)
3780 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
3781 else
3782 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
3783 Twine(numTotal - numVariadic) + ">::Impl");
3784 return;
3786 switch (numTotal) {
3787 case 0:
3788 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind + "s");
3789 break;
3790 case 1:
3791 opClass.addTrait("::mlir::OpTrait::One" + traitKind);
3792 break;
3793 default:
3794 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
3795 ">::Impl");
3796 break;
3800 void OpEmitter::genTraits() {
3801 // Add region size trait.
3802 unsigned numRegions = op.getNumRegions();
3803 unsigned numVariadicRegions = op.getNumVariadicRegions();
3804 addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
3806 // Add result size traits.
3807 int numResults = op.getNumResults();
3808 int numVariadicResults = op.getNumVariableLengthResults();
3809 addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
3811 // For single result ops with a known specific type, generate a OneTypedResult
3812 // trait.
3813 if (numResults == 1 && numVariadicResults == 0) {
3814 auto cppName = op.getResults().begin()->constraint.getCPPClassName();
3815 opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
3818 // Add successor size trait.
3819 unsigned numSuccessors = op.getNumSuccessors();
3820 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
3821 addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
3823 // Add variadic size trait and normal op traits.
3824 int numOperands = op.getNumOperands();
3825 int numVariadicOperands = op.getNumVariableLengthOperands();
3827 // Add operand size trait.
3828 addSizeCountTrait(opClass, "Operand", numOperands, numVariadicOperands);
3830 // The op traits defined internal are ensured that they can be verified
3831 // earlier.
3832 for (const auto &trait : op.getTraits()) {
3833 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
3834 if (opTrait->isStructuralOpTrait())
3835 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
3839 // OpInvariants wrapps the verifyInvariants which needs to be run before
3840 // native/interface traits and after all the traits with `StructuralOpTrait`.
3841 opClass.addTrait("::mlir::OpTrait::OpInvariants");
3843 if (emitHelper.hasProperties())
3844 opClass.addTrait("::mlir::BytecodeOpInterface::Trait");
3846 // Add the native and interface traits.
3847 for (const auto &trait : op.getTraits()) {
3848 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
3849 if (!opTrait->isStructuralOpTrait())
3850 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
3851 } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
3852 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
3857 void OpEmitter::genOpNameGetter() {
3858 auto *method = opClass.addStaticMethod<Method::Constexpr>(
3859 "::llvm::StringLiteral", "getOperationName");
3860 ERROR_IF_PRUNED(method, "getOperationName", op);
3861 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
3862 << "\");";
3865 void OpEmitter::genOpAsmInterface() {
3866 // If the user only has one results or specifically added the Asm trait,
3867 // then don't generate it for them. We specifically only handle multi result
3868 // operations, because the name of a single result in the common case is not
3869 // interesting(generally 'result'/'output'/etc.).
3870 // TODO: We could also add a flag to allow operations to opt in to this
3871 // generation, even if they only have a single operation.
3872 int numResults = op.getNumResults();
3873 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
3874 return;
3876 SmallVector<StringRef, 4> resultNames(numResults);
3877 for (int i = 0; i != numResults; ++i)
3878 resultNames[i] = op.getResultName(i);
3880 // Don't add the trait if none of the results have a valid name.
3881 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
3882 return;
3883 opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
3885 // Generate the right accessor for the number of results.
3886 auto *method = opClass.addMethod(
3887 "void", "getAsmResultNames",
3888 MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
3889 ERROR_IF_PRUNED(method, "getAsmResultNames", op);
3890 auto &body = method->body();
3891 for (int i = 0; i != numResults; ++i) {
3892 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
3893 << " if (!resultGroup" << i << ".empty())\n"
3894 << " setNameFn(*resultGroup" << i << ".begin(), \""
3895 << resultNames[i] << "\");\n";
3899 //===----------------------------------------------------------------------===//
3900 // OpOperandAdaptor emitter
3901 //===----------------------------------------------------------------------===//
3903 namespace {
3904 // Helper class to emit Op operand adaptors to an output stream. Operand
3905 // adaptors are wrappers around random access ranges that provide named operand
3906 // getters identical to those defined in the Op.
3907 // This currently generates 3 classes per Op:
3908 // * A Base class within the 'detail' namespace, which contains all logic and
3909 // members independent of the random access range that is indexed into.
3910 // In other words, it contains all the attribute and region getters.
3911 // * A templated class named '{OpName}GenericAdaptor' with a template parameter
3912 // 'RangeT' that is indexed into by the getters to access the operands.
3913 // It contains all getters to access operands and inherits from the previous
3914 // class.
3915 // * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
3916 // with 'mlir::ValueRange' as template parameter. It adds a constructor from
3917 // an instance of the op type and a verify function.
3918 class OpOperandAdaptorEmitter {
3919 public:
3920 static void
3921 emitDecl(const Operator &op,
3922 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
3923 raw_ostream &os);
3924 static void
3925 emitDef(const Operator &op,
3926 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
3927 raw_ostream &os);
3929 private:
3930 explicit OpOperandAdaptorEmitter(
3931 const Operator &op,
3932 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
3934 // Add verification function. This generates a verify method for the adaptor
3935 // which verifies all the op-independent attribute constraints.
3936 void addVerification();
3938 // The operation for which to emit an adaptor.
3939 const Operator &op;
3941 // The generated adaptor classes.
3942 Class genericAdaptorBase;
3943 Class genericAdaptor;
3944 Class adaptor;
3946 // The emitter containing all of the locally emitted verification functions.
3947 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
3949 // Helper for emitting adaptor code.
3950 OpOrAdaptorHelper emitHelper;
3952 } // namespace
3954 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
3955 const Operator &op,
3956 const StaticVerifierFunctionEmitter &staticVerifierEmitter)
3957 : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
3958 genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
3959 staticVerifierEmitter(staticVerifierEmitter),
3960 emitHelper(op, /*emitForOp=*/false) {
3962 genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public);
3963 bool useProperties = emitHelper.hasProperties();
3964 if (useProperties) {
3965 // Define the properties struct with multiple members.
3966 using ConstArgument =
3967 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
3968 SmallVector<ConstArgument> attrOrProperties;
3969 for (const std::pair<StringRef, AttributeMetadata> &it :
3970 emitHelper.getAttrMetadata()) {
3971 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
3972 attrOrProperties.push_back(&it.second);
3974 for (const NamedProperty &prop : op.getProperties())
3975 attrOrProperties.push_back(&prop);
3976 if (emitHelper.getOperandSegmentsSize())
3977 attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
3978 if (emitHelper.getResultSegmentsSize())
3979 attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
3980 assert(!attrOrProperties.empty());
3981 std::string declarations = " struct Properties {\n";
3982 llvm::raw_string_ostream os(declarations);
3983 std::string comparator =
3984 " bool operator==(const Properties &rhs) const {\n"
3985 " return \n";
3986 llvm::raw_string_ostream comparatorOs(comparator);
3987 for (const auto &attrOrProp : attrOrProperties) {
3988 if (const auto *namedProperty =
3989 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
3990 StringRef name = namedProperty->name;
3991 if (name.empty())
3992 report_fatal_error("missing name for property");
3993 std::string camelName =
3994 convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
3995 auto &prop = namedProperty->prop;
3996 // Generate the data member using the storage type.
3997 os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
3998 << " " << name << "Ty " << name;
3999 if (prop.hasDefaultValue())
4000 os << " = " << prop.getDefaultValue();
4001 comparatorOs << " rhs." << name << " == this->" << name
4002 << " &&\n";
4003 // Emit accessors using the interface type.
4004 const char *accessorFmt = R"decl(;
4005 {0} get{1}() {
4006 auto &propStorage = this->{2};
4007 return {3};
4009 void set{1}(const {0} &propValue) {
4010 auto &propStorage = this->{2};
4011 {4};
4013 )decl";
4014 FmtContext fctx;
4015 os << formatv(accessorFmt, prop.getInterfaceType(), camelName, name,
4016 tgfmt(prop.getConvertFromStorageCall(),
4017 &fctx.addSubst("_storage", propertyStorage)),
4018 tgfmt(prop.getAssignToStorageCall(),
4019 &fctx.addSubst("_value", propertyValue)
4020 .addSubst("_storage", propertyStorage)));
4021 continue;
4023 const auto *namedAttr =
4024 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
4025 const Attribute *attr = nullptr;
4026 if (namedAttr->constraint)
4027 attr = &*namedAttr->constraint;
4028 StringRef name = namedAttr->attrName;
4029 if (name.empty())
4030 report_fatal_error("missing name for property attr");
4031 std::string camelName =
4032 convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
4033 // Generate the data member using the storage type.
4034 StringRef storageType;
4035 if (attr) {
4036 storageType = attr->getStorageType();
4037 } else {
4038 if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
4039 report_fatal_error("unexpected AttributeMetadata");
4041 // TODO: update to use native integers.
4042 storageType = "::mlir::DenseI32ArrayAttr";
4044 os << " using " << name << "Ty = " << storageType << ";\n"
4045 << " " << name << "Ty " << name << ";\n";
4046 comparatorOs << " rhs." << name << " == this->" << name << " &&\n";
4048 // Emit accessors using the interface type.
4049 if (attr) {
4050 const char *accessorFmt = R"decl(
4051 auto get{0}() {
4052 auto &propStorage = this->{1};
4053 return ::llvm::{2}<{3}>(propStorage);
4055 void set{0}(const {3} &propValue) {
4056 this->{1} = propValue;
4058 )decl";
4059 os << formatv(accessorFmt, camelName, name,
4060 attr->isOptional() || attr->hasDefaultValue()
4061 ? "dyn_cast_or_null"
4062 : "cast",
4063 storageType);
4066 comparatorOs << " true;\n }\n"
4067 " bool operator!=(const Properties &rhs) const {\n"
4068 " return !(*this == rhs);\n"
4069 " }\n";
4070 comparatorOs.flush();
4071 os << comparator;
4072 os << " };\n";
4073 os.flush();
4075 genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations));
4077 genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);
4078 genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs");
4079 genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>",
4080 "odsOpName");
4081 if (useProperties)
4082 genericAdaptorBase.declare<Field>("Properties", "properties");
4083 genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
4085 genericAdaptor.addTemplateParam("RangeT");
4086 genericAdaptor.addField("RangeT", "odsOperands");
4087 genericAdaptor.addParent(
4088 ParentClass("detail::" + genericAdaptorBase.getClassName()));
4089 genericAdaptor.declare<UsingDeclaration>(
4090 "ValueT", "::llvm::detail::ValueOfRange<RangeT>");
4091 genericAdaptor.declare<UsingDeclaration>(
4092 "Base", "detail::" + genericAdaptorBase.getClassName());
4094 const auto *attrSizedOperands =
4095 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
4097 SmallVector<MethodParameter> paramList;
4098 if (useProperties) {
4099 // Properties can't be given a default constructor here due to Properties
4100 // struct being defined in the enclosing class which isn't complete by
4101 // here.
4102 paramList.emplace_back("::mlir::DictionaryAttr", "attrs");
4103 paramList.emplace_back("const Properties &", "properties");
4104 } else {
4105 paramList.emplace_back("::mlir::DictionaryAttr", "attrs", "{}");
4106 paramList.emplace_back("const ::mlir::EmptyProperties &", "properties",
4107 "{}");
4109 paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
4110 auto *baseConstructor =
4111 genericAdaptorBase.addConstructor<Method::Inline>(paramList);
4112 baseConstructor->addMemberInitializer("odsAttrs", "attrs");
4113 if (useProperties)
4114 baseConstructor->addMemberInitializer("properties", "properties");
4115 baseConstructor->addMemberInitializer("odsRegions", "regions");
4117 MethodBody &body = baseConstructor->body();
4118 body.indent() << "if (odsAttrs)\n";
4119 body.indent() << formatv(
4120 "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
4121 op.getOperationName());
4123 paramList.insert(paramList.begin(), MethodParameter("RangeT", "values"));
4124 auto *constructor = genericAdaptor.addConstructor(paramList);
4125 constructor->addMemberInitializer("Base", "attrs, properties, regions");
4126 constructor->addMemberInitializer("odsOperands", "values");
4128 // Add a forwarding constructor to the previous one that accepts
4129 // OpaqueProperties instead and check for null and perform the cast to the
4130 // actual properties type.
4131 paramList[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
4132 paramList[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
4133 auto *opaquePropertiesConstructor =
4134 genericAdaptor.addConstructor(std::move(paramList));
4135 if (useProperties) {
4136 opaquePropertiesConstructor->addMemberInitializer(
4137 genericAdaptor.getClassName(),
4138 "values, "
4139 "attrs, "
4140 "(properties ? *properties.as<Properties *>() : Properties{}), "
4141 "regions");
4142 } else {
4143 opaquePropertiesConstructor->addMemberInitializer(
4144 genericAdaptor.getClassName(),
4145 "values, "
4146 "attrs, "
4147 "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
4148 "::mlir::EmptyProperties{}), "
4149 "regions");
4152 // Add forwarding constructor that constructs Properties.
4153 if (useProperties) {
4154 SmallVector<MethodParameter> paramList;
4155 paramList.emplace_back("RangeT", "values");
4156 paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
4157 attrSizedOperands ? "" : "nullptr");
4158 auto *noPropertiesConstructor =
4159 genericAdaptor.addConstructor(std::move(paramList));
4160 noPropertiesConstructor->addMemberInitializer(
4161 genericAdaptor.getClassName(), "values, "
4162 "attrs, "
4163 "Properties{}, "
4164 "{}");
4168 // Create constructors constructing the adaptor from an instance of the op.
4169 // This takes the attributes, properties and regions from the op instance
4170 // and the value range from the parameter.
4172 // Base class is in the cpp file and can simply access the members of the op
4173 // class to initialize the template independent fields. If the op doesn't
4174 // have properties, we can emit a generic constructor inline. Otherwise,
4175 // emit it out-of-line because we need the op to be defined.
4176 Constructor *constructor;
4177 if (useProperties) {
4178 constructor = genericAdaptorBase.addConstructor(
4179 MethodParameter(op.getCppClassName(), "op"));
4180 } else {
4181 constructor = genericAdaptorBase.addConstructor<Method::Inline>(
4182 MethodParameter("::mlir::Operation *", "op"));
4184 constructor->addMemberInitializer("odsAttrs",
4185 "op->getRawDictionaryAttrs()");
4186 // Retrieve the operation name from the op directly.
4187 constructor->addMemberInitializer("odsOpName", "op->getName()");
4188 if (useProperties)
4189 constructor->addMemberInitializer("properties", "op.getProperties()");
4190 constructor->addMemberInitializer("odsRegions", "op->getRegions()");
4192 // Generic adaptor is templated and therefore defined inline in the header.
4193 // We cannot use the Op class here as it is an incomplete type (we have a
4194 // circular reference between the two).
4195 // Use a template trick to make the constructor be instantiated at call site
4196 // when the op class is complete.
4197 constructor = genericAdaptor.addConstructor(
4198 MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
4199 constructor->addTemplateParam("LateInst = " + op.getCppClassName());
4200 constructor->addTemplateParam(
4201 "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
4202 ">>");
4203 constructor->addMemberInitializer("Base", "op");
4204 constructor->addMemberInitializer("odsOperands", "values");
4207 std::string sizeAttrInit;
4208 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
4209 if (op.getDialect().usePropertiesForAttributes())
4210 sizeAttrInit =
4211 formatv(adapterSegmentSizeAttrInitCodeProperties,
4212 llvm::formatv("getProperties().operandSegmentSizes"));
4213 else
4214 sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
4215 emitHelper.getAttr(operandSegmentAttrName));
4217 generateNamedOperandGetters(op, genericAdaptor,
4218 /*genericAdaptorBase=*/&genericAdaptorBase,
4219 /*sizeAttrInit=*/sizeAttrInit,
4220 /*rangeType=*/"RangeT",
4221 /*rangeElementType=*/"ValueT",
4222 /*rangeBeginCall=*/"odsOperands.begin()",
4223 /*rangeSizeCall=*/"odsOperands.size()",
4224 /*getOperandCallPattern=*/"odsOperands[{0}]");
4226 // Any invalid overlap for `getOperands` will have been diagnosed before
4227 // here already.
4228 if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
4229 m->body() << " return odsOperands;";
4231 FmtContext fctx;
4232 fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
4234 // Generate named accessor with Attribute return type.
4235 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
4236 Attribute attr) {
4237 // The method body is trivial if the attribute does not have a default
4238 // value, in which case the default value may be arbitrary code.
4239 auto *method = genericAdaptorBase.addMethod(
4240 attr.getStorageType(), emitName + "Attr",
4241 attr.hasDefaultValue() || !useProperties ? Method::Properties::None
4242 : Method::Properties::Inline);
4243 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
4244 auto &body = method->body().indent();
4245 if (!useProperties)
4246 body << "assert(odsAttrs && \"no attributes when constructing "
4247 "adapter\");\n";
4248 body << formatv(
4249 "auto attr = ::llvm::{1}<{2}>({0});\n", emitHelper.getAttr(name),
4250 attr.hasDefaultValue() || attr.isOptional() ? "dyn_cast_or_null"
4251 : "cast",
4252 attr.getStorageType());
4254 if (attr.hasDefaultValue() && attr.isOptional()) {
4255 // Use the default value if attribute is not set.
4256 // TODO: this is inefficient, we are recreating the attribute for every
4257 // call. This should be set instead.
4258 std::string defaultValue = std::string(
4259 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
4260 body << "if (!attr)\n attr = " << defaultValue << ";\n";
4262 body << "return attr;\n";
4265 if (useProperties) {
4266 auto *m = genericAdaptorBase.addInlineMethod("const Properties &",
4267 "getProperties");
4268 ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
4269 m->body() << " return properties;";
4272 auto *m = genericAdaptorBase.addInlineMethod("::mlir::DictionaryAttr",
4273 "getAttributes");
4274 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
4275 m->body() << " return odsAttrs;";
4277 for (auto &namedAttr : op.getAttributes()) {
4278 const auto &name = namedAttr.name;
4279 const auto &attr = namedAttr.attr;
4280 if (attr.isDerivedAttr())
4281 continue;
4282 std::string emitName = op.getGetterName(name);
4283 emitAttrWithStorageType(name, emitName, attr);
4284 emitAttrGetterWithReturnType(fctx, genericAdaptorBase, op, emitName, attr);
4287 unsigned numRegions = op.getNumRegions();
4288 for (unsigned i = 0; i < numRegions; ++i) {
4289 const auto &region = op.getRegion(i);
4290 if (region.name.empty())
4291 continue;
4293 // Generate the accessors for a variadic region.
4294 std::string name = op.getGetterName(region.name);
4295 if (region.isVariadic()) {
4296 auto *m = genericAdaptorBase.addInlineMethod("::mlir::RegionRange", name);
4297 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
4298 m->body() << formatv(" return odsRegions.drop_front({0});", i);
4299 continue;
4302 auto *m = genericAdaptorBase.addInlineMethod("::mlir::Region &", name);
4303 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
4304 m->body() << formatv(" return *odsRegions[{0}];", i);
4306 if (numRegions > 0) {
4307 // Any invalid overlap for `getRegions` will have been diagnosed before
4308 // here already.
4309 if (auto *m = genericAdaptorBase.addInlineMethod("::mlir::RegionRange",
4310 "getRegions"))
4311 m->body() << " return odsRegions;";
4314 StringRef genericAdaptorClassName = genericAdaptor.getClassName();
4315 adaptor.addParent(ParentClass(genericAdaptorClassName))
4316 .addTemplateParam("::mlir::ValueRange");
4317 adaptor.declare<VisibilityDeclaration>(Visibility::Public);
4318 adaptor.declare<UsingDeclaration>(genericAdaptorClassName +
4319 "::" + genericAdaptorClassName);
4321 // Constructor taking the Op as single parameter.
4322 auto *constructor =
4323 adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
4324 constructor->addMemberInitializer(genericAdaptorClassName,
4325 "op->getOperands(), op");
4328 // Add verification function.
4329 addVerification();
4331 genericAdaptorBase.finalize();
4332 genericAdaptor.finalize();
4333 adaptor.finalize();
4336 void OpOperandAdaptorEmitter::addVerification() {
4337 auto *method = adaptor.addMethod("::llvm::LogicalResult", "verify",
4338 MethodParameter("::mlir::Location", "loc"));
4339 ERROR_IF_PRUNED(method, "verify", op);
4340 auto &body = method->body();
4341 bool useProperties = emitHelper.hasProperties();
4343 FmtContext verifyCtx;
4344 populateSubstitutions(emitHelper, verifyCtx);
4345 genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter,
4346 useProperties);
4348 body << " return ::mlir::success();";
4351 void OpOperandAdaptorEmitter::emitDecl(
4352 const Operator &op,
4353 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4354 raw_ostream &os) {
4355 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4357 NamespaceEmitter ns(os, "detail");
4358 emitter.genericAdaptorBase.writeDeclTo(os);
4360 emitter.genericAdaptor.writeDeclTo(os);
4361 emitter.adaptor.writeDeclTo(os);
4364 void OpOperandAdaptorEmitter::emitDef(
4365 const Operator &op,
4366 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4367 raw_ostream &os) {
4368 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4370 NamespaceEmitter ns(os, "detail");
4371 emitter.genericAdaptorBase.writeDefTo(os);
4373 emitter.genericAdaptor.writeDefTo(os);
4374 emitter.adaptor.writeDefTo(os);
4377 /// Emit the class declarations or definitions for the given op defs.
4378 static void
4379 emitOpClasses(const RecordKeeper &recordKeeper,
4380 const std::vector<Record *> &defs, raw_ostream &os,
4381 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4382 bool emitDecl) {
4383 if (defs.empty())
4384 return;
4386 for (auto *def : defs) {
4387 Operator op(*def);
4388 if (emitDecl) {
4390 NamespaceEmitter emitter(os, op.getCppNamespace());
4391 os << formatv(opCommentHeader, op.getQualCppClassName(),
4392 "declarations");
4393 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
4394 OpEmitter::emitDecl(op, os, staticVerifierEmitter);
4396 // Emit the TypeID explicit specialization to have a single definition.
4397 if (!op.getCppNamespace().empty())
4398 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4399 << "::" << op.getCppClassName() << ")\n\n";
4400 } else {
4402 NamespaceEmitter emitter(os, op.getCppNamespace());
4403 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
4404 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
4405 OpEmitter::emitDef(op, os, staticVerifierEmitter);
4407 // Emit the TypeID explicit specialization to have a single definition.
4408 if (!op.getCppNamespace().empty())
4409 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4410 << "::" << op.getCppClassName() << ")\n\n";
4415 /// Emit the declarations for the provided op classes.
4416 static void emitOpClassDecls(const RecordKeeper &recordKeeper,
4417 const std::vector<Record *> &defs,
4418 raw_ostream &os) {
4419 // First emit forward declaration for each class, this allows them to refer
4420 // to each others in traits for example.
4421 for (auto *def : defs) {
4422 Operator op(*def);
4423 NamespaceEmitter emitter(os, op.getCppNamespace());
4424 os << "class " << op.getCppClassName() << ";\n";
4427 // Emit the op class declarations.
4428 IfDefScope scope("GET_OP_CLASSES", os);
4429 if (defs.empty())
4430 return;
4431 StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
4432 staticVerifierEmitter.collectOpConstraints(defs);
4433 emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
4434 /*emitDecl=*/true);
4437 /// Emit the definitions for the provided op classes.
4438 static void emitOpClassDefs(const RecordKeeper &recordKeeper,
4439 ArrayRef<Record *> defs, raw_ostream &os,
4440 StringRef constraintPrefix = "") {
4441 if (defs.empty())
4442 return;
4444 // Generate all of the locally instantiated methods first.
4445 StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper,
4446 constraintPrefix);
4447 os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
4448 staticVerifierEmitter.collectOpConstraints(defs);
4449 staticVerifierEmitter.emitOpConstraints(defs);
4451 // Emit the classes.
4452 emitOpClasses(recordKeeper, defs, os, staticVerifierEmitter,
4453 /*emitDecl=*/false);
4456 /// Emit op declarations for all op records.
4457 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
4458 emitSourceFileHeader("Op Declarations", os, recordKeeper);
4460 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
4461 emitOpClassDecls(recordKeeper, defs, os);
4463 // If we are generating sharded op definitions, emit the sharded op
4464 // registration hooks.
4465 SmallVector<ArrayRef<Record *>, 4> shardedDefs;
4466 shardOpDefinitions(defs, shardedDefs);
4467 if (defs.empty() || shardedDefs.size() <= 1)
4468 return false;
4470 Dialect dialect = Operator(defs.front()).getDialect();
4471 NamespaceEmitter ns(os, dialect);
4473 const char *const opRegistrationHook =
4474 "void register{0}Operations{1}({2}::{0} *dialect);\n";
4475 os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
4476 dialect.getCppNamespace());
4477 for (unsigned i = 0; i < shardedDefs.size(); ++i) {
4478 os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
4479 dialect.getCppNamespace());
4482 return false;
4485 /// Generate the dialect op registration hook and the op class definitions for a
4486 /// shard of ops.
4487 static void emitOpDefShard(const RecordKeeper &recordKeeper,
4488 ArrayRef<Record *> defs, const Dialect &dialect,
4489 unsigned shardIndex, unsigned shardCount,
4490 raw_ostream &os) {
4491 std::string shardGuard = "GET_OP_DEFS_";
4492 std::string indexStr = std::to_string(shardIndex);
4493 shardGuard += indexStr;
4494 IfDefScope scope(shardGuard, os);
4496 // Emit the op registration hook in the first shard.
4497 const char *const opRegistrationHook =
4498 "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
4499 if (shardIndex == 0) {
4500 os << formatv(opRegistrationHook, dialect.getCppNamespace(),
4501 dialect.getCppClassName(), "");
4502 for (unsigned i = 0; i < shardCount; ++i) {
4503 os << formatv(" {0}::register{1}Operations{2}(dialect);\n",
4504 dialect.getCppNamespace(), dialect.getCppClassName(), i);
4506 os << "}\n";
4509 // Generate the per-shard op registration hook.
4510 os << formatv(opCommentHeader, dialect.getCppClassName(),
4511 "Op Registration Hook")
4512 << formatv(opRegistrationHook, dialect.getCppNamespace(),
4513 dialect.getCppClassName(), shardIndex);
4514 for (Record *def : defs) {
4515 os << formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
4516 Operator(def).getQualCppClassName());
4518 os << "}\n";
4520 // Generate the per-shard op definitions.
4521 emitOpClassDefs(recordKeeper, defs, os, indexStr);
4524 /// Emit op definitions for all op records.
4525 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
4526 emitSourceFileHeader("Op Definitions", os, recordKeeper);
4528 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
4529 SmallVector<ArrayRef<Record *>, 4> shardedDefs;
4530 shardOpDefinitions(defs, shardedDefs);
4532 // If no shard was requested, emit the regular op list and class definitions.
4533 if (shardedDefs.size() == 1) {
4535 IfDefScope scope("GET_OP_LIST", os);
4536 interleave(
4537 defs, os,
4538 [&](Record *def) { os << Operator(def).getQualCppClassName(); },
4539 ",\n");
4542 IfDefScope scope("GET_OP_CLASSES", os);
4543 emitOpClassDefs(recordKeeper, defs, os);
4545 return false;
4548 if (defs.empty())
4549 return false;
4550 Dialect dialect = Operator(defs.front()).getDialect();
4551 for (auto [idx, value] : llvm::enumerate(shardedDefs)) {
4552 emitOpDefShard(recordKeeper, value, dialect, idx, shardedDefs.size(), os);
4554 return false;
4557 static mlir::GenRegistration
4558 genOpDecls("gen-op-decls", "Generate op declarations",
4559 [](const RecordKeeper &records, raw_ostream &os) {
4560 return emitOpDecls(records, os);
4563 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
4564 [](const RecordKeeper &records,
4565 raw_ostream &os) {
4566 return emitOpDefs(records, os);