[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / tools / mlir-tblgen / OpDefinitionsGen.cpp
blob9badb7aa163a60952d6c3f22f4dbf0f1f69fb42c
1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
12 //===----------------------------------------------------------------------===//
14 #include "OpClass.h"
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Argument.h"
18 #include "mlir/TableGen/Attribute.h"
19 #include "mlir/TableGen/Class.h"
20 #include "mlir/TableGen/CodeGenHelpers.h"
21 #include "mlir/TableGen/Format.h"
22 #include "mlir/TableGen/GenInfo.h"
23 #include "mlir/TableGen/Interfaces.h"
24 #include "mlir/TableGen/Operator.h"
25 #include "mlir/TableGen/Property.h"
26 #include "mlir/TableGen/SideEffects.h"
27 #include "mlir/TableGen/Trait.h"
28 #include "llvm/ADT/BitVector.h"
29 #include "llvm/ADT/MapVector.h"
30 #include "llvm/ADT/Sequence.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSet.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/Signals.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/TableGen/Error.h"
39 #include "llvm/TableGen/Record.h"
40 #include "llvm/TableGen/TableGenBackend.h"
42 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
44 using namespace llvm;
45 using namespace mlir;
46 using namespace mlir::tblgen;
// Fixed identifiers that are spliced into the generated C++ code.
static constexpr const char *tblgenNamePrefix = "tblgen_";
static constexpr const char *generatedArgName = "odsArg";
static constexpr const char *odsBuilder = "odsBuilder";
static constexpr const char *builderOpState = "odsState";
static constexpr const char *propertyStorage = "propStorage";
static constexpr const char *propertyValue = "propValue";
static constexpr const char *propertyAttr = "propAttr";
static constexpr const char *propertyDiag = "emitError";

/// Names of the implicit attributes that hold the sizes of the variadic
/// operand and result segments.
static constexpr const char *operandSegmentAttrName = "operandSegmentSizes";
static constexpr const char *resultSegmentAttrName = "resultSegmentSizes";

/// Code for an Op to look up an attribute, using cached identifiers and a
/// search restricted to a sorted subrange of the attribute list.
///
/// {0}: Code snippet to get the attribute's name or identifier.
/// {1}: The lower bound on the sorted subrange.
/// {2}: The upper bound on the sorted subrange.
/// {3}: Code snippet to get the array of named attributes.
/// {4}: "Named" to get the named attribute.
static constexpr const char *subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange("
    "{3}.begin() + {1}, {3}.end() - {2}, {0})";
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// The emitted code reads a variable named `index` (the static index of the
/// operand/result) that must be in scope at the emission site, and returns a
/// {start, size} pair. Literal braces are written as `{{` because the
/// placeholders below use `{N}` syntax.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The emitted code expects `index` and a `sizeAttr` (indexable segment sizes)
/// to be in scope, and returns a {start, size} pair. NOTE(review): unlike the
/// template above, the braces here are not escaped, so this snippet appears to
/// be emitted verbatim rather than run through a `{N}` formatter — confirm at
/// the use site.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttr[i];
  return {start, sizeAttr[index]};
)";
/// The code snippet to initialize the sizes for the value range calculation.
/// The adaptor variant asserts that the segment size attribute is present
/// before casting it to a DenseI32ArrayAttr named `sizeAttr`.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert({0} && "missing segment size attribute for op");
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
/// Same as above, but for ops whose segment sizes are stored as a property:
/// `sizeAttr` is bound directly to an ArrayRef<int32_t>, no cast needed.
///
/// {0}: The code to get the segment sizes.
static const char *const adapterSegmentSizeAttrInitCodeProperties = R"(
  ::llvm::ArrayRef<int32_t> sizeAttr = {0};
)";

/// The code snippet to initialize the sizes for the value range calculation.
/// The op variant casts the attribute to a DenseI32ArrayAttr named `sizeAttr`
/// without the presence assertion emitted for adaptors.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0});
)";
/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
///
/// The operands of the main operand are split front-to-back into groups whose
/// sizes come from the segment attribute getter; the groups are returned as a
/// SmallVector of ranges.
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
/// {2}: The range type of adaptor.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizes = {0}();

  ::llvm::SmallVector<{2}> tblgenTmpOperandGroups;
  for (int i = 0, e = sizes.size(); i < e; ++i) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i]));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values.
/// The call in {1} must yield a pair-like value whose `first` is the start
/// offset and `second` the length of the range.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";
/// Parse operand/result segment_size property.
///
/// This template is first expanded with `llvm::formatv` (see
/// computeAttrMetadata), so the only `{N}` placeholder is the element count.
/// The remaining `$_parser` / `$_storage` placeholders are substituted later,
/// when the property parser is emitted.
///
/// Every MLIR name in the generated code is fully qualified (`::mlir::...`),
/// so the snippet also compiles for dialects declared outside the `mlir`
/// namespace.
///
/// {0}: Number of elements in the segment array
static const char *const parseTextualSegmentSizeFormat = R"(
  size_t i = 0;
  auto parseElem = [&]() -> ::mlir::ParseResult {
    if (i >= {0})
      return $_parser.emitError($_parser.getCurrentLocation(),
        "expected `]` after {0} segment sizes");
    if (::mlir::failed($_parser.parseInteger($_storage[i])))
      return ::mlir::failure();
    i += 1;
    return ::mlir::success();
  };
  if (::mlir::failed($_parser.parseCommaSeparatedList(
      ::mlir::AsmParser::Delimiter::Square, parseElem)))
    return ::mlir::failure();
  if (i < {0})
    return $_parser.emitError($_parser.getCurrentLocation(),
      "expected {0} segment sizes, found only ") << i;
  return ::mlir::success();
)";
/// Print operand/result segment_size property as a bracketed, comma-separated
/// list, e.g. `[1, 0, 2]`. The snippet is an immediately-invoked lambda so it
/// can be used where a single expression/statement is expected.
static const char *const printTextualSegmentSize = R"(
  [&]() {
    $_printer << '[';
    ::llvm::interleaveComma($_storage, $_printer);
    $_printer << ']';
  }()
)";
/// Read operand/result segment_size from bytecode.
/// Used for bytecode version 6 and newer (kNativePropertiesODSSegmentSize),
/// where the sizes are encoded natively as a sparse array.
static const char *const readBytecodeSegmentSizeNative = R"(
  if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
)";

/// Read operand/result segment_size from legacy bytecode (version < 6), where
/// the sizes were encoded as a DenseI32ArrayAttr. Fails if the attribute holds
/// more elements than the storage array can hold. NOTE(review): an attribute
/// with fewer elements is accepted and only its elements are copied into the
/// storage — confirm that leaving the remaining entries untouched is intended.
static const char *const readBytecodeSegmentSizeLegacy = R"(
  if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    ::mlir::DenseI32ArrayAttr attr;
    if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
    if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
      $_reader.emitError("size mismatch for operand/result_segment_size");
      return ::mlir::failure();
    }
    ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
  }
)";

/// Write operand/result segment_size to bytecode.
/// Used for bytecode version 6 and newer, as a native sparse array.
static const char *const writeBytecodeSegmentSizeNative = R"(
  if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
)";

/// Write operand/result segment_size to legacy bytecode (version < 6), wrapped
/// in a DenseI32ArrayAttr for backward compatibility.
static const char *const writeBytecodeSegmentSizeLegacy = R"(
  if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
  }
)";
/// A header for indicating code sections in the generated output.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
232 //===----------------------------------------------------------------------===//
233 // Utility structs and functions
234 //===----------------------------------------------------------------------===//
// Replaces all non-overlapping occurrences of `match` in `str` with
// `substitute`, scanning left to right. Text produced by a substitution is not
// rescanned, so `substitute` may itself contain `match`. An empty `match`
// leaves `str` unchanged.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  // An empty pattern is found at every position; without this guard the loop
  // below would never terminate.
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  std::string::size_type matchLoc;
  while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(matchLoc, match.size(), substitute);
    // Resume after the inserted text so it is never matched again.
    scanLoc = matchLoc + substitute.size();
  }
  return str;
}
247 // Returns whether the record has a value of the given name that can be returned
248 // via getValueAsString.
249 static inline bool hasStringAttribute(const Record &record,
250 StringRef fieldName) {
251 auto *valueInit = record.getValueInit(fieldName);
252 return isa<StringInit>(valueInit);
255 static std::string getArgumentName(const Operator &op, int index) {
256 const auto &operand = op.getOperand(index);
257 if (!operand.name.empty())
258 return std::string(operand.name);
259 return std::string(formatv("{0}_{1}", generatedArgName, index));
262 // Returns true if we can use unwrapped value for the given `attr` in builders.
263 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
264 return attr.getReturnType() != attr.getStorageType() &&
265 // We need to wrap the raw value into an attribute in the builder impl
266 // so we need to make sure that the attribute specifies how to do that.
267 !attr.getConstBuilderTemplate().empty();
270 /// Build an attribute from a parameter value using the constant builder.
271 static std::string constBuildAttrFromParam(const tblgen::Attribute &attr,
272 FmtContext &fctx,
273 StringRef paramName) {
274 std::string builderTemplate = attr.getConstBuilderTemplate().str();
276 // For StringAttr, its constant builder call will wrap the input in
277 // quotes, which is correct for normal string literals, but incorrect
278 // here given we use function arguments. So we need to strip the
279 // wrapping quotes.
280 if (StringRef(builderTemplate).contains("\"$0\""))
281 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
283 return tgfmt(builderTemplate, &fctx, paramName).str();
286 namespace {
287 /// Metadata on a registered attribute. Given that attributes are stored in
288 /// sorted order on operations, we can use information from ODS to deduce the
289 /// number of required attributes less and and greater than each attribute,
290 /// allowing us to search only a subrange of the attributes in ODS-generated
291 /// getters.
292 struct AttributeMetadata {
293 /// The attribute name.
294 StringRef attrName;
295 /// Whether the attribute is required.
296 bool isRequired;
297 /// The ODS attribute constraint. Not present for implicit attributes.
298 std::optional<Attribute> constraint;
299 /// The number of required attributes less than this attribute.
300 unsigned lowerBound = 0;
301 /// The number of required attributes greater than this attribute.
302 unsigned upperBound = 0;
305 /// Helper class to select between OpAdaptor and Op code templates.
306 class OpOrAdaptorHelper {
307 public:
308 OpOrAdaptorHelper(const Operator &op, bool emitForOp)
309 : op(op), emitForOp(emitForOp) {
310 computeAttrMetadata();
313 /// Object that wraps a functor in a stream operator for interop with
314 /// llvm::formatv.
315 class Formatter {
316 public:
317 template <typename Functor>
318 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
320 std::string str() const {
321 std::string result;
322 llvm::raw_string_ostream os(result);
323 os << *this;
324 return os.str();
327 private:
328 std::function<raw_ostream &(raw_ostream &)> func;
330 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
331 return fmt.func(os);
335 // Generate code for getting an attribute.
336 Formatter getAttr(StringRef attrName, bool isNamed = false) const {
337 assert(attrMetadata.count(attrName) && "expected attribute metadata");
338 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
339 const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
340 if (hasProperties()) {
341 assert(!isNamed);
342 return os << "getProperties()." << attrName;
344 return os << formatv(subrangeGetAttr, getAttrName(attrName),
345 attr.lowerBound, attr.upperBound, getAttrRange(),
346 isNamed ? "Named" : "");
350 // Generate code for getting the name of an attribute.
351 Formatter getAttrName(StringRef attrName) const {
352 return [this, attrName](raw_ostream &os) -> raw_ostream & {
353 if (emitForOp)
354 return os << op.getGetterName(attrName) << "AttrName()";
355 return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
356 op.getGetterName(attrName));
360 // Get the code snippet for getting the named attribute range.
361 StringRef getAttrRange() const {
362 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
365 // Get the prefix code for emitting an error.
366 Formatter emitErrorPrefix() const {
367 return [this](raw_ostream &os) -> raw_ostream & {
368 if (emitForOp)
369 return os << "emitOpError(";
370 return os << formatv("emitError(loc, \"'{0}' op \"",
371 op.getOperationName());
375 // Get the call to get an operand or segment of operands.
376 Formatter getOperand(unsigned index) const {
377 return [this, index](raw_ostream &os) -> raw_ostream & {
378 return os << formatv(op.getOperand(index).isVariadic()
379 ? "this->getODSOperands({0})"
380 : "(*this->getODSOperands({0}).begin())",
381 index);
385 // Get the call to get a result of segment of results.
386 Formatter getResult(unsigned index) const {
387 return [this, index](raw_ostream &os) -> raw_ostream & {
388 if (!emitForOp)
389 return os << "<no results should be generated>";
390 return os << formatv(op.getResult(index).isVariadic()
391 ? "this->getODSResults({0})"
392 : "(*this->getODSResults({0}).begin())",
393 index);
397 // Return whether an op instance is available.
398 bool isEmittingForOp() const { return emitForOp; }
400 // Return the ODS operation wrapper.
401 const Operator &getOp() const { return op; }
403 // Get the attribute metadata sorted by name.
404 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
405 return attrMetadata;
408 /// Returns whether to emit a `Properties` struct for this operation or not.
409 bool hasProperties() const {
410 if (!op.getProperties().empty())
411 return true;
412 if (!op.getDialect().usePropertiesForAttributes())
413 return false;
414 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") ||
415 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
416 return true;
417 return llvm::any_of(getAttrMetadata(),
418 [](const std::pair<StringRef, AttributeMetadata> &it) {
419 return !it.second.constraint ||
420 !it.second.constraint->isDerivedAttr();
424 std::optional<NamedProperty> &getOperandSegmentsSize() {
425 return operandSegmentsSize;
428 std::optional<NamedProperty> &getResultSegmentsSize() {
429 return resultSegmentsSize;
432 uint32_t getOperandSegmentSizesLegacyIndex() {
433 return operandSegmentSizesLegacyIndex;
436 uint32_t getResultSegmentSizesLegacyIndex() {
437 return resultSegmentSizesLegacyIndex;
440 private:
441 // Compute the attribute metadata.
442 void computeAttrMetadata();
444 // The operation ODS wrapper.
445 const Operator &op;
446 // True if code is being generate for an op. False for an adaptor.
447 const bool emitForOp;
449 // The attribute metadata, mapped by name.
450 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
452 // Property
453 std::optional<NamedProperty> operandSegmentsSize;
454 std::string operandSegmentsSizeStorage;
455 std::string operandSegmentsSizeParser;
456 std::optional<NamedProperty> resultSegmentsSize;
457 std::string resultSegmentsSizeStorage;
458 std::string resultSegmentsSizeParser;
460 // Indices to store the position in the emission order of the operand/result
461 // segment sizes attribute if emitted as part of the properties for legacy
462 // bytecode encodings, i.e. versions less than 6.
463 uint32_t operandSegmentSizesLegacyIndex = 0;
464 uint32_t resultSegmentSizesLegacyIndex = 0;
466 // The number of required attributes.
467 unsigned numRequired;
470 } // namespace
472 void OpOrAdaptorHelper::computeAttrMetadata() {
473 // Enumerate the attribute names of this op, ensuring the attribute names are
474 // unique in case implicit attributes are explicitly registered.
475 for (const NamedAttribute &namedAttr : op.getAttributes()) {
476 Attribute attr = namedAttr.attr;
477 bool isOptional =
478 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
479 attrMetadata.insert(
480 {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
483 auto makeProperty = [&](StringRef storageType, StringRef parserCall) {
484 return Property(
485 /*summary=*/"",
486 /*description=*/"",
487 /*storageType=*/storageType,
488 /*interfaceType=*/"::llvm::ArrayRef<int32_t>",
489 /*convertFromStorageCall=*/"$_storage",
490 /*assignToStorageCall=*/
491 "::llvm::copy($_value, $_storage.begin())",
492 /*convertToAttributeCall=*/
493 "return ::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage);",
494 /*convertFromAttributeCall=*/
495 "return convertFromAttribute($_storage, $_attr, $_diag);",
496 /*parserCall=*/parserCall,
497 /*optionalParserCall=*/"",
498 /*printerCall=*/printTextualSegmentSize,
499 /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative,
500 /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative,
501 /*hashPropertyCall=*/
502 "::llvm::hash_combine_range(std::begin($_storage), "
503 "std::end($_storage));",
504 /*StringRef defaultValue=*/"",
505 /*storageTypeValueOverride=*/"");
507 // Include key attributes from several traits as implicitly registered.
508 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
509 if (op.getDialect().usePropertiesForAttributes()) {
510 operandSegmentsSizeStorage =
511 llvm::formatv("std::array<int32_t, {0}>", op.getNumOperands());
512 operandSegmentsSizeParser =
513 llvm::formatv(parseTextualSegmentSizeFormat, op.getNumOperands());
514 operandSegmentsSize = {
515 "operandSegmentSizes",
516 makeProperty(operandSegmentsSizeStorage, operandSegmentsSizeParser)};
517 } else {
518 attrMetadata.insert(
519 {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName,
520 /*isRequired=*/true,
521 /*attr=*/std::nullopt}});
524 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
525 if (op.getDialect().usePropertiesForAttributes()) {
526 resultSegmentsSizeStorage =
527 llvm::formatv("std::array<int32_t, {0}>", op.getNumResults());
528 resultSegmentsSizeParser =
529 llvm::formatv(parseTextualSegmentSizeFormat, op.getNumResults());
530 resultSegmentsSize = {
531 "resultSegmentSizes",
532 makeProperty(resultSegmentsSizeStorage, resultSegmentsSizeParser)};
533 } else {
534 attrMetadata.insert(
535 {resultSegmentAttrName,
536 AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
537 /*attr=*/std::nullopt}});
541 // Store the metadata in sorted order.
542 SmallVector<AttributeMetadata> sortedAttrMetadata =
543 llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector()));
544 llvm::sort(sortedAttrMetadata,
545 [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
546 return lhs.attrName < rhs.attrName;
549 // Store the position of the legacy operand_segment_sizes /
550 // result_segment_sizes so we can emit a backward compatible property readers
551 // and writers.
552 StringRef legacyOperandSegmentSizeName =
553 StringLiteral("operand_segment_sizes");
554 StringRef legacyResultSegmentSizeName = StringLiteral("result_segment_sizes");
555 operandSegmentSizesLegacyIndex = 0;
556 resultSegmentSizesLegacyIndex = 0;
557 for (auto item : sortedAttrMetadata) {
558 if (item.attrName < legacyOperandSegmentSizeName)
559 ++operandSegmentSizesLegacyIndex;
560 if (item.attrName < legacyResultSegmentSizeName)
561 ++resultSegmentSizesLegacyIndex;
564 // Compute the subrange bounds for each attribute.
565 numRequired = 0;
566 for (AttributeMetadata &attr : sortedAttrMetadata) {
567 attr.lowerBound = numRequired;
568 numRequired += attr.isRequired;
570 for (AttributeMetadata &attr : sortedAttrMetadata)
571 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
573 // Store the results back into the map.
574 for (const AttributeMetadata &attr : sortedAttrMetadata)
575 attrMetadata.insert({attr.attrName, attr});
578 //===----------------------------------------------------------------------===//
579 // Op emitter
580 //===----------------------------------------------------------------------===//
582 namespace {
// Helper class to emit a record into the given output stream.
//
// The constructor is private; the static `emitDecl` / `emitDef` entry points,
// which take the op and the output stream, are the public interface. Each
// private `gen*` method generates one aspect of the op class (see its
// comment).
class OpEmitter {
  // Either an attribute's metadata or a named property; passed as a list to
  // genPropertiesSupportForBytecode.
  using ConstArgument =
      llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;

public:
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generate methods for accessing the attribute names of this operation.
  void genAttrNameGetters();

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates code to manage the properties, if any!
  void genPropertiesSupport();

  // Generates code to manage the encoding of properties to bytecode.
  void
  genPropertiesSupportForBytecode(ArrayRef<ConstArgument> attrOrProperties);

  // Generates getters for the properties.
  void genPropGetters();

  // Generates setters for the properties.
  void genPropSetters();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setters for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates the method to populate default attributes.
  void genPopulateDefaultAttributes();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all results' types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. `inferredAttributes` is populated with any attributes that are
  // elided from the build list. The given `typeParamKind` and `attrParamKind`
  // control how result types and attributes are placed in the parameter list.
  void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                      llvm::StringSet<> &inferredAttributes,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  void
  genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
                                         llvm::StringSet<> &inferredAttributes,
                                         bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates custom verify methods for the operation.
  void genCustomVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(MethodBody &body,
                                Operator::const_value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(MethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(MethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                               bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // NOTE(review): data members are initialized in declaration order. The
  // constructor (not visible here) presumably builds `emitHelper` from `op`,
  // so keep `op` declared before `emitHelper`.

  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  const Operator &op;

  // The C++ code builder for this op
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting op code.
  OpOrAdaptorHelper emitHelper;
};
788 } // namespace
790 // Populate the format context `ctx` with substitutions of attributes, operands
791 // and results.
792 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
793 FmtContext &ctx) {
794 // Populate substitutions for attributes.
795 auto &op = emitHelper.getOp();
796 for (const auto &namedAttr : op.getAttributes())
797 ctx.addSubst(namedAttr.name,
798 emitHelper.getOp().getGetterName(namedAttr.name) + "()");
800 // Populate substitutions for named operands.
801 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
802 auto &value = op.getOperand(i);
803 if (!value.name.empty())
804 ctx.addSubst(value.name, emitHelper.getOperand(i).str());
807 // Populate substitutions for results.
808 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
809 auto &value = op.getResult(i);
810 if (!value.name.empty())
811 ctx.addSubst(value.name, emitHelper.getResult(i).str());
815 /// Generate verification on native traits requiring attributes.
816 static void genNativeTraitAttrVerifier(MethodBody &body,
817 const OpOrAdaptorHelper &emitHelper) {
818 // Check that the variadic segment sizes attribute exists and contains the
819 // expected number of elements.
821 // {0}: Attribute name.
822 // {1}: Expected number of elements.
823 // {2}: "operand" or "result".
824 // {3}: Emit error prefix.
825 const char *const checkAttrSizedValueSegmentsCode = R"(
827 auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0});
828 auto numElements = sizeAttr.asArrayRef().size();
829 if (numElements != {1})
830 return {3}"'{0}' attribute for specifying {2} segments must have {1} "
831 "elements, but got ") << numElements;
835 // Verify a few traits first so that we can use getODSOperands() and
836 // getODSResults() in the rest of the verifier.
837 auto &op = emitHelper.getOp();
838 if (!op.getDialect().usePropertiesForAttributes()) {
839 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
840 body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
841 op.getNumOperands(), "operand",
842 emitHelper.emitErrorPrefix());
844 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
845 body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
846 op.getNumResults(), "result",
847 emitHelper.emitErrorPrefix());
852 // Return true if a verifier can be emitted for the attribute: it is not a
853 // derived attribute, it has a predicate, its condition is not empty, and, for
854 // adaptors, the condition does not reference the op.
855 static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) {
856 if (attr.isDerivedAttr())
857 return false;
858 Pred pred = attr.getPredicate();
859 if (pred.isNull())
860 return false;
861 std::string condition = pred.getCondition();
862 return !condition.empty() &&
863 (!StringRef(condition).contains("$_op") || isEmittingForOp);
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
// Attribute verification is performed as follows:
//
// 1. Verify that all required attributes are present in sorted order. This
// ensures that we can use subrange lookup even with potentially missing
// attributes.
// 2. Verify native trait attributes so that other attributes may call methods
// that depend on the validity of these attributes, e.g. segment size attributes
// and operand or result getters.
// 3. Verify the constraints on all present attributes.
//
// Parameters:
//   emitHelper            - op/adaptor abstraction supplying attribute metadata,
//                           accessor spellings and the error-emission prefix.
//   ctx                   - format context used to expand inline predicates.
//   body                  - method body the generated verifier is appended to.
//   staticVerifierEmitter - source of uniqued attribute-constraint functions.
//   useProperties         - read attributes from `getProperties()` instead of
//                           scanning the attribute dictionary.
static void
genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx,
                     MethodBody &body,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter,
                     bool useProperties) {
  // No registered attributes: nothing to verify.
  if (emitHelper.getAttrMetadata().empty())
    return;

  // Verify the attribute if it is present. This assumes that default values
  // are valid. This code snippet pastes the condition inline.
  //
  // TODO: verify the default value is valid (perhaps in debug mode only).
  //
  // {0}: Attribute variable name.
  // {1}: Attribute condition code.
  // {2}: Emit error prefix.
  // {3}: Attribute name.
  // {4}: Attribute/constraint description.
  const char *const verifyAttrInline = R"(
  if ({0} && !({1}))
    return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
  // Verify the attribute using a uniqued constraint. Can only be used within
  // the context of an op.
  //
  // {0}: Unique constraint name.
  // {1}: Attribute variable name.
  // {2}: Attribute name.
  const char *const verifyAttrUnique = R"(
  if (::mlir::failed({0}(*this, {1}, "{2}")))
    return ::mlir::failure();
)";

  // Traverse the array until the required attribute is found. Return an error
  // if the traversal reached the end.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The emit error prefix.
  // {2}: The name of the attribute.
  // NOTE: the generated `while` loop is left open; the caller appends the
  // optional-attribute checks and the closing "++namedAttrIt; }".
  const char *const findRequiredAttr = R"(
while (true) {{
  if (namedAttrIt == namedAttrRange.end())
    return {1}"requires attribute '{2}'");
  if (namedAttrIt->getName() == {0}) {{
    tblgen_{2} = namedAttrIt->getValue();
    break;
  })";

  // Emit a check to see if the iteration has encountered an optional attribute.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The name of the attribute.
  const char *const checkOptionalAttr = R"(
  else if (namedAttrIt->getName() == {0}) {{
    tblgen_{1} = namedAttrIt->getValue();
  })";

  // Emit the start of the loop for checking trailing attributes.
  // (Streamed directly, not through formatv, hence the single braces.)
  const char *const checkTrailingAttrs = R"(while (true) {
  if (namedAttrIt == namedAttrRange.end()) {
    break;
  })";

  // Emit the verifier for the attribute.
  // Ops use a uniqued static constraint function when one exists; otherwise
  // (and always for adaptors) the predicate is pasted inline.
  const auto emitVerifier = [&](Attribute attr, StringRef attrName,
                                StringRef varName) {
    std::string condition = attr.getPredicate().getCondition();

    std::optional<StringRef> constraintFn;
    if (emitHelper.isEmittingForOp() &&
        (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
      body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
    } else {
      body << formatv(verifyAttrInline, varName,
                      tgfmt(condition, &ctx.withSelf(varName)),
                      emitHelper.emitErrorPrefix(), attrName,
                      escapeString(attr.getSummary()));
    }
  };

  // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
  const auto getVarName = [&](StringRef attrName) {
    return (tblgenNamePrefix + attrName).str();
  };

  body.indent();
  if (useProperties) {
    // Properties: bind each non-derived attribute to a local and fail early
    // if a required one is null.
    for (const std::pair<StringRef, AttributeMetadata> &it :
         emitHelper.getAttrMetadata()) {
      const AttributeMetadata &metadata = it.second;
      if (metadata.constraint && metadata.constraint->isDerivedAttr())
        continue;
      body << formatv(
          "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n",
          it.first);
      if (metadata.isRequired)
        body << formatv(
            "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n",
            it.first, emitHelper.emitErrorPrefix());
    }
  } else {
    // Attribute dictionary: walk it with a single iterator in metadata order.
    body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange());
    body << "auto namedAttrIt = namedAttrRange.begin();\n";

    // Iterate over the attributes in sorted order. Keep track of the optional
    // attributes that may be encountered along the way.
    SmallVector<const AttributeMetadata *> optionalAttrs;

    for (const std::pair<StringRef, AttributeMetadata> &it :
         emitHelper.getAttrMetadata()) {
      const AttributeMetadata &metadata = it.second;
      if (!metadata.isRequired) {
        optionalAttrs.push_back(&metadata);
        continue;
      }

      // Declare the required attribute plus every optional attribute that
      // precedes it, then scan forward until the required one is found.
      body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv("::mlir::Attribute {0};\n",
                        getVarName(optional->attrName));
      }
      body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
                      emitHelper.emitErrorPrefix(), it.first);
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(checkOptionalAttr,
                        emitHelper.getAttrName(optional->attrName),
                        optional->attrName);
      }
      body << "\n  ++namedAttrIt;\n}\n";
      optionalAttrs.clear();
    }
    // Get trailing optional attributes.
    // These follow the last required attribute, so reaching the end of the
    // range simply terminates the scan instead of reporting an error.
    if (!optionalAttrs.empty()) {
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv("::mlir::Attribute {0};\n",
                        getVarName(optional->attrName));
      }
      body << checkTrailingAttrs;
      for (const AttributeMetadata *optional : optionalAttrs) {
        body << formatv(checkOptionalAttr,
                        emitHelper.getAttrName(optional->attrName),
                        optional->attrName);
      }
      body << "\n  ++namedAttrIt;\n}\n";
    }
  }
  body.unindent();

  // Emit the checks for segment attributes first so that the other
  // constraints can call operand and result getters.
  genNativeTraitAttrVerifier(body, emitHelper);

  bool isEmittingForOp = emitHelper.isEmittingForOp();
  for (const auto &namedAttr : emitHelper.getOp().getAttributes())
    if (canEmitAttrVerifier(namedAttr.attr, isEmittingForOp))
      emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
1036 /// Include declarations specified on NativeTrait
1037 static std::string formatExtraDeclarations(const Operator &op) {
1038 SmallVector<StringRef> extraDeclarations;
1039 // Include extra class declarations from NativeTrait
1040 for (const auto &trait : op.getTraits()) {
1041 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
1042 StringRef value = opTrait->getExtraConcreteClassDeclaration();
1043 if (value.empty())
1044 continue;
1045 extraDeclarations.push_back(value);
1048 extraDeclarations.push_back(op.getExtraClassDeclaration());
1049 return llvm::join(extraDeclarations, "\n");
1052 /// Op extra class definitions have a `$cppClass` substitution that is to be
1053 /// replaced by the C++ class name.
1054 /// Include declarations specified on NativeTrait
1055 static std::string formatExtraDefinitions(const Operator &op) {
1056 SmallVector<StringRef> extraDefinitions;
1057 // Include extra class definitions from NativeTrait
1058 for (const auto &trait : op.getTraits()) {
1059 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
1060 StringRef value = opTrait->getExtraConcreteClassDefinition();
1061 if (value.empty())
1062 continue;
1063 extraDefinitions.push_back(value);
1066 extraDefinitions.push_back(op.getExtraClassDefinition());
1067 FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
1068 return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str();
/// Construct the emitter for `op` and eagerly populate `opClass` with every
/// generated trait, method and declaration. emitDecl/emitDef only print it.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), formatExtraDeclarations(op),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/true) {
  // Substitutions for `$_op` / `$_ctxt` inside predicates of op verifiers.
  verifyCtx.addSubst("_op", "(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genPropertiesSupport();
  genPropGetters();
  genPropSetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genPopulateDefaultAttributes();
  genParser();
  genPrinter();
  genVerifier();
  genCustomVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass, emitHelper.hasProperties());
  genSideEffectInterfaceMethods();
}
1112 void OpEmitter::emitDecl(
1113 const Operator &op, raw_ostream &os,
1114 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1115 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
1118 void OpEmitter::emitDef(
1119 const Operator &op, raw_ostream &os,
1120 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
1121 OpEmitter(op, staticVerifierEmitter).emitDef(os);
1124 void OpEmitter::emitDecl(raw_ostream &os) {
1125 opClass.finalize();
1126 opClass.writeDeclTo(os);
1129 void OpEmitter::emitDef(raw_ostream &os) {
1130 opClass.finalize();
1131 opClass.writeDefTo(os);
1134 static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
1135 const Operator &op) {
1136 if (m)
1137 return;
1138 PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
1139 methodName + "` for " +
1140 op.getOperationName() + " (from line " +
1141 Twine(line) + ")");
1144 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
/// Generate the attribute-name accessors of the op class:
///  - static `getAttributeNames()` listing every registered attribute name
///    (plus the operand segment sizes name when it is a property),
///  - private `getAttributeNameForIndex` helpers (instance and static), and
///  - `<attr>AttrName()` accessors (instance and static) for each attribute.
void OpEmitter::genAttrNameGetters() {
  const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
      emitHelper.getAttrMetadata();
  // The operand segment sizes name is appended to the attribute list only when
  // the dialect uses properties and the op has AttrSizedOperandSegments.
  bool hasOperandSegmentsSize =
      op.getDialect().usePropertiesForAttributes() &&
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
  // Emit the getAttributeNames method.
  {
    auto *method = opClass.addStaticInlineMethod(
        "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
    ERROR_IF_PRUNED(method, "getAttributeNames", op);
    auto &body = method->body();
    if (!hasOperandSegmentsSize && attributes.empty()) {
      body << "  return {};";
      // Nothing else to do if there are no registered attributes. Exit early.
      return;
    }
    body << "  static ::llvm::StringRef attrNames[] = {";
    llvm::interleaveComma(llvm::make_first_range(attributes), body,
                          [&](StringRef attrName) {
                            body << "::llvm::StringRef(\"" << attrName << "\")";
                          });
    if (hasOperandSegmentsSize) {
      if (!attributes.empty())
        body << ", ";
      // Placed last; the segment-size name getters below rely on `.back()`.
      body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")";
    }
    body << "};\n  return ::llvm::ArrayRef(attrNames);";
  }

  // Emit the getAttributeNameForIndex methods.
  {
    auto *method = opClass.addInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
    method->body()
        << "  return getAttributeNameForIndex((*this)->getName(), index);";
  }
  {
    auto *method = opClass.addStaticInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("::mlir::OperationName", "name"),
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);

    if (attributes.empty()) {
      method->body() << "  return {};";
    } else {
      // {0}: number of registered attributes (bounds check for `index`).
      const char *const getAttrName = R"(
  assert(index < {0} && "invalid attribute index");
  assert(name.getStringRef() == getOperationName() && "invalid operation name");
  assert(name.isRegistered() && "Operation isn't registered, missing a "
        "dependent dialect loading?");
  return name.getAttributeNames()[index];
)";
      method->body() << formatv(getAttrName, attributes.size());
    }
  }

  // Generate the <attr>AttrName methods, that expose the attribute names to
  // users.
  const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
  for (auto [index, attr] :
       llvm::enumerate(llvm::make_first_range(attributes))) {
    std::string name = op.getGetterName(attr);
    std::string methodName = name + "AttrName";

    // Generate the non-static variant.
    {
      auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << llvm::formatv(attrNameMethodBody, index);
    }

    // Generate the static variant.
    {
      auto *method = opClass.addStaticInlineMethod(
          "::mlir::StringAttr", methodName,
          MethodParameter("::mlir::OperationName", "name"));
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << llvm::formatv(attrNameMethodBody,
                                      "name, " + Twine(index));
    }
  }
  if (hasOperandSegmentsSize) {
    std::string name = op.getGetterName(operandSegmentAttrName);
    std::string methodName = name + "AttrName";
    // Generate the non-static variant.
    {
      auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName);
      ERROR_IF_PRUNED(method, methodName, op);
      method->body()
          << "  return (*this)->getName().getAttributeNames().back();";
    }

    // Generate the static variant.
    {
      auto *method = opClass.addStaticInlineMethod(
          "::mlir::StringAttr", methodName,
          MethodParameter("::mlir::OperationName", "name"));
      ERROR_IF_PRUNED(method, methodName, op);
      method->body() << "  return name.getAttributeNames().back();";
    }
  }
}
1253 // Emit the getter for a named property.
1254 // It is templated to be shared between the Op and the adaptor class.
1255 template <typename OpClassOrAdaptor>
1256 static void emitPropGetter(OpClassOrAdaptor &opClass, const Operator &op,
1257 StringRef name, const Property &prop) {
1258 auto *method = opClass.addInlineMethod(prop.getInterfaceType(), name);
1259 ERROR_IF_PRUNED(method, name, op);
1260 method->body() << formatv(" return getProperties().{0}();", name);
1263 // Emit the getter for an attribute with the return type specified.
1264 // It is templated to be shared between the Op and the adaptor class.
1265 template <typename OpClassOrAdaptor>
1266 static void emitAttrGetterWithReturnType(FmtContext &fctx,
1267 OpClassOrAdaptor &opClass,
1268 const Operator &op, StringRef name,
1269 Attribute attr) {
1270 auto *method = opClass.addMethod(attr.getReturnType(), name);
1271 ERROR_IF_PRUNED(method, name, op);
1272 auto &body = method->body();
1273 body << " auto attr = " << name << "Attr();\n";
1274 if (attr.hasDefaultValue() && attr.isOptional()) {
1275 // Returns the default value if not set.
1276 // TODO: this is inefficient, we are recreating the attribute for every
1277 // call. This should be set instead.
1278 if (!attr.isConstBuildable()) {
1279 PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
1280 " must have a constBuilder");
1282 std::string defaultValue = std::string(
1283 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1284 body << " if (!attr)\n return "
1285 << tgfmt(attr.getConvertFromStorageCall(),
1286 &fctx.withSelf(defaultValue))
1287 << ";\n";
1289 body << " return "
1290 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
1291 << ";\n";
/// Generate the static Properties support methods of the op class:
/// setPropertiesFromAttr, getPropertiesAsAttr, computePropertiesHash,
/// getInherentAttr, setInherentAttr, populateInherentAttrs and
/// verifyInherentAttrs, then the bytecode read/write methods.
///
/// The members handled are the non-derived attributes (stored in properties),
/// the op's native properties, and the operand/result segment-size properties
/// when present. Nothing is emitted if there are none.
void OpEmitter::genPropertiesSupport() {
  if (!emitHelper.hasProperties())
    return;

  // Gather every member of the Properties struct, in emission order.
  SmallVector<ConstArgument> attrOrProperties;
  for (const std::pair<StringRef, AttributeMetadata> &it :
       emitHelper.getAttrMetadata()) {
    if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
      attrOrProperties.push_back(&it.second);
  }
  for (const NamedProperty &prop : op.getProperties())
    attrOrProperties.push_back(&prop);
  if (emitHelper.getOperandSegmentsSize())
    attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
  if (emitHelper.getResultSegmentsSize())
    attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
  if (attrOrProperties.empty())
    return;
  auto &setPropMethod =
      opClass
          .addStaticMethod(
              "::llvm::LogicalResult", "setPropertiesFromAttr",
              MethodParameter("Properties &", "prop"),
              MethodParameter("::mlir::Attribute", "attr"),
              MethodParameter(
                  "::llvm::function_ref<::mlir::InFlightDiagnostic()>",
                  "emitError"))
          ->body();
  auto &getPropMethod =
      opClass
          .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr",
                           MethodParameter("::mlir::MLIRContext *", "ctx"),
                           MethodParameter("const Properties &", "prop"))
          ->body();
  auto &hashMethod =
      opClass
          .addStaticMethod("llvm::hash_code", "computePropertiesHash",
                           MethodParameter("const Properties &", "prop"))
          ->body();
  auto &getInherentAttrMethod =
      opClass
          .addStaticMethod("std::optional<mlir::Attribute>", "getInherentAttr",
                           MethodParameter("::mlir::MLIRContext *", "ctx"),
                           MethodParameter("const Properties &", "prop"),
                           MethodParameter("llvm::StringRef", "name"))
          ->body();
  auto &setInherentAttrMethod =
      opClass
          .addStaticMethod("void", "setInherentAttr",
                           MethodParameter("Properties &", "prop"),
                           MethodParameter("llvm::StringRef", "name"),
                           MethodParameter("mlir::Attribute", "value"))
          ->body();
  auto &populateInherentAttrsMethod =
      opClass
          .addStaticMethod("void", "populateInherentAttrs",
                           MethodParameter("::mlir::MLIRContext *", "ctx"),
                           MethodParameter("const Properties &", "prop"),
                           MethodParameter("::mlir::NamedAttrList &", "attrs"))
          ->body();
  auto &verifyInherentAttrsMethod =
      opClass
          .addStaticMethod(
              "::llvm::LogicalResult", "verifyInherentAttrs",
              MethodParameter("::mlir::OperationName", "opName"),
              MethodParameter("::mlir::NamedAttrList &", "attrs"),
              MethodParameter(
                  "llvm::function_ref<::mlir::InFlightDiagnostic()>",
                  "emitError"))
          ->body();

  opClass.declare<UsingDeclaration>("Properties", "FoldAdaptor::Properties");

  // Convert the property to the attribute form.
  // (setPropertiesFromAttr: requires a DictionaryAttr and fills each member
  // from the entry of the same name.)

  setPropMethod << R"decl(
  ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr);
  if (!dict) {
    emitError() << "expected DictionaryAttr to set properties";
    return ::mlir::failure();
  }
  )decl";
  // {0}: the property's convertFromAttribute code; {1}: the dictionary lookup
  // code that defines `attr`.
  const char *propFromAttrFmt = R"decl(
  auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr,
         ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) -> ::mlir::LogicalResult {{
    {0}
  };
  {1};
)decl";
  // Used when the property has neither a default value nor a storage override:
  // a missing entry leaves the member untouched.
  const char *attrGetNoDefaultFmt = R"decl(;
    if (attr && ::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
      return ::mlir::failure();
)decl";
  // Used when a fallback value exists: a missing entry assigns `{1}`.
  const char *attrGetDefaultFmt = R"decl(;
    if (attr) {{
      if (::mlir::failed(setFromAttr(prop.{0}, attr, emitError)))
        return ::mlir::failure();
    } else {{
      prop.{0} = {1};
    }
)decl";

  for (const auto &attrOrProp : attrOrProperties) {
    if (const auto *namedProperty =
            llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
      StringRef name = namedProperty->name;
      auto &prop = namedProperty->prop;
      FmtContext fctx;

      std::string getAttr;
      llvm::raw_string_ostream os(getAttr);
      os << "   auto attr = dict.get(\"" << name << "\");";
      if (name == operandSegmentAttrName) {
        // Backward compat for now, TODO: Remove at some point.
        os << "   if (!attr) attr = dict.get(\"operand_segment_sizes\");";
      }
      if (name == resultSegmentAttrName) {
        // Backward compat for now, TODO: Remove at some point.
        os << "   if (!attr) attr = dict.get(\"result_segment_sizes\");";
      }

      setPropMethod << "{\n"
                    << formatv(propFromAttrFmt,
                               tgfmt(prop.getConvertFromAttributeCall(),
                                     &fctx.addSubst("_attr", propertyAttr)
                                          .addSubst("_storage", propertyStorage)
                                          .addSubst("_diag", propertyDiag)),
                               getAttr);
      // A storage-type override takes precedence over the declared default.
      if (prop.hasStorageTypeValueOverride()) {
        setPropMethod << formatv(attrGetDefaultFmt, name,
                                 prop.getStorageTypeValueOverride());
      } else if (prop.hasDefaultValue()) {
        setPropMethod << formatv(attrGetDefaultFmt, name,
                                 prop.getDefaultValue());
      } else {
        setPropMethod << formatv(attrGetNoDefaultFmt, name);
      }
      setPropMethod << "  }\n";
    } else {
      const auto *namedAttr =
          llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
      StringRef name = namedAttr->attrName;
      std::string getAttr;
      llvm::raw_string_ostream os(getAttr);
      os << "   auto attr = dict.get(\"" << name << "\");";
      if (name == operandSegmentAttrName) {
        // Backward compat for now
        os << "   if (!attr) attr = dict.get(\"operand_segment_sizes\");";
      }
      if (name == resultSegmentAttrName) {
        // Backward compat for now
        os << "   if (!attr) attr = dict.get(\"result_segment_sizes\");";
      }

      // Attribute members are assigned via dyn_cast to the storage type; a
      // type mismatch is reported through emitError.
      setPropMethod << formatv(R"decl(
  {{
    auto &propStorage = prop.{0};
    {1}
    if (attr) {{
      auto convertedAttr = ::llvm::dyn_cast<std::remove_reference_t<decltype(propStorage)>>(attr);
      if (convertedAttr) {{
        propStorage = convertedAttr;
      } else {{
        emitError() << "Invalid attribute `{0}` in property conversion: " << attr;
        return ::mlir::failure();
      }
    }
  }
)decl",
                               name, getAttr);
    }
  }
  setPropMethod << "  return ::mlir::success();\n";

  // Convert the attribute form to the property.
  // (getPropertiesAsAttr: builds a DictionaryAttr, or returns a null
  // attribute when no member produced an entry.)

  getPropMethod << "    ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n"
                << "    ::mlir::Builder odsBuilder{ctx};\n";
  // {0}: property name; {1}: the property's convertToAttribute code.
  const char *propToAttrFmt = R"decl(
    {
      const auto &propStorage = prop.{0};
      auto attr = [&]() -> ::mlir::Attribute {{
        {1}
      }();
      attrs.push_back(odsBuilder.getNamedAttr("{0}", attr));
    }
)decl";
  for (const auto &attrOrProp : attrOrProperties) {
    if (const auto *namedProperty =
            llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
      StringRef name = namedProperty->name;
      auto &prop = namedProperty->prop;
      FmtContext fctx;
      getPropMethod << formatv(
          propToAttrFmt, name,
          tgfmt(prop.getConvertToAttributeCall(),
                &fctx.addSubst("_ctxt", "ctx")
                     .addSubst("_storage", propertyStorage)));
      continue;
    }
    const auto *namedAttr =
        llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
    StringRef name = namedAttr->attrName;
    // Attribute members are only included when non-null.
    getPropMethod << formatv(R"decl(
    {
      const auto &propStorage = prop.{0};
      if (propStorage)
        attrs.push_back(odsBuilder.getNamedAttr("{0}",
                                       propStorage));
    }
)decl",
                             name);
  }
  getPropMethod << R"decl(
  if (!attrs.empty())
    return odsBuilder.getDictionaryAttr(attrs);
  return {};
)decl";

  // Hashing for the property
  // Properties with a custom hash call get a local `hash_<name>` lambda;
  // the others fall back to ::llvm::hash_value, and attributes hash their
  // opaque pointer.

  const char *propHashFmt = R"decl(
  auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code {
    return {1};
  };
)decl";
  for (const auto &attrOrProp : attrOrProperties) {
    if (const auto *namedProperty =
            llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
      StringRef name = namedProperty->name;
      auto &prop = namedProperty->prop;
      FmtContext fctx;
      if (!prop.getHashPropertyCall().empty()) {
        hashMethod << formatv(
            propHashFmt, name,
            tgfmt(prop.getHashPropertyCall(),
                  &fctx.addSubst("_storage", propertyStorage)));
      }
    }
  }
  hashMethod << "  return llvm::hash_combine(";
  llvm::interleaveComma(
      attrOrProperties, hashMethod, [&](const ConstArgument &attrOrProp) {
        if (const auto *namedProperty =
                llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
          if (!namedProperty->prop.getHashPropertyCall().empty()) {
            hashMethod << "\n    hash_" << namedProperty->name << "(prop."
                       << namedProperty->name << ")";
          } else {
            hashMethod << "\n    ::llvm::hash_value(prop."
                       << namedProperty->name << ")";
          }
          return;
        }
        const auto *namedAttr =
            llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
        StringRef name = namedAttr->attrName;
        hashMethod << "\n    llvm::hash_value(prop." << name
                   << ".getAsOpaquePointer())";
      });
  hashMethod << ");\n";

  // Inherent-attribute accessors. Attribute members are exposed by name;
  // native properties are not, except for the segment-size properties below.
  const char *getInherentAttrMethodFmt = R"decl(
    if (name == "{0}")
      return prop.{0};
)decl";
  const char *setInherentAttrMethodFmt = R"decl(
    if (name == "{0}") {{
       prop.{0} = ::llvm::dyn_cast_or_null<std::remove_reference_t<decltype(prop.{0})>>(value);
       return;
    }
)decl";
  const char *populateInherentAttrsMethodFmt = R"decl(
    if (prop.{0}) attrs.append("{0}", prop.{0});
)decl";
  for (const auto &attrOrProp : attrOrProperties) {
    if (const auto *namedAttr =
            llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp)) {
      StringRef name = namedAttr->attrName;
      getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name);
      setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name);
      populateInherentAttrsMethod
          << formatv(populateInherentAttrsMethodFmt, name);
      continue;
    }
    // The ODS segment size property is "special": we expose it as an attribute
    // even though it is a native property.
    const auto *namedProperty = cast<const NamedProperty *>(attrOrProp);
    StringRef name = namedProperty->name;
    if (name != operandSegmentAttrName && name != resultSegmentAttrName)
      continue;
    auto &prop = namedProperty->prop;
    FmtContext fctx;
    fctx.addSubst("_ctxt", "ctx");
    fctx.addSubst("_storage", Twine("prop.") + name);
    // The legacy snake_case spelling is accepted as well as the current name.
    if (name == operandSegmentAttrName) {
      getInherentAttrMethod
          << formatv("    if (name == \"operand_segment_sizes\" || name == "
                     "\"{0}\") return ",
                     operandSegmentAttrName);
    } else {
      getInherentAttrMethod
          << formatv("    if (name == \"result_segment_sizes\" || name == "
                     "\"{0}\") return ",
                     resultSegmentAttrName);
    }
    getInherentAttrMethod << "[&]() -> ::mlir::Attribute { "
                          << tgfmt(prop.getConvertToAttributeCall(), &fctx)
                          << " }();\n";

    if (name == operandSegmentAttrName) {
      setInherentAttrMethod
          << formatv("       if (name == \"operand_segment_sizes\" || name == "
                     "\"{0}\") {{",
                     operandSegmentAttrName);
    } else {
      setInherentAttrMethod
          << formatv("       if (name == \"result_segment_sizes\" || name == "
                     "\"{0}\") {{",
                     resultSegmentAttrName);
    }
    // Values that are not a DenseI32ArrayAttr, or whose element count differs
    // from the storage size, are ignored.
    setInherentAttrMethod << formatv(R"decl(
       auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value);
       if (!arrAttr) return;
       if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t))
         return;
       llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin());
       return;
    }
)decl",
                                     name);
    if (name == operandSegmentAttrName) {
      populateInherentAttrsMethod << formatv(
          "  attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
          operandSegmentAttrName,
          tgfmt(prop.getConvertToAttributeCall(), &fctx));
    } else {
      populateInherentAttrsMethod << formatv(
          "  attrs.append(\"{0}\", [&]() -> ::mlir::Attribute { {1} }());\n",
          resultSegmentAttrName,
          tgfmt(prop.getConvertToAttributeCall(), &fctx));
    }
  }
  getInherentAttrMethod << "  return std::nullopt;\n";

  // Emit the verifiers method for backward compatibility with the generic
  // syntax. This method verifies the constraint on the properties attributes
  // before they are set, since dyn_cast<> will silently omit failures.
  for (const auto &attrOrProp : attrOrProperties) {
    const auto *namedAttr =
        llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
    if (!namedAttr || !namedAttr->constraint)
      continue;
    Attribute attr = *namedAttr->constraint;
    std::optional<StringRef> constraintFn =
        staticVerifierEmitter.getAttrConstraintFn(attr);
    if (!constraintFn)
      continue;
    if (canEmitAttrVerifier(attr,
                            /*isEmittingForOp=*/false)) {
      std::string name = op.getGetterName(namedAttr->attrName);
      // `constraintFn` is known to hold a value here; formatv prints it via
      // raw_ostream's std::optional support.
      verifyInherentAttrsMethod
          << formatv(R"(
    {{
      ::mlir::Attribute attr = attrs.get({0}AttrName(opName));
      if (attr && ::mlir::failed({1}(attr, "{2}", emitError)))
        return ::mlir::failure();
    }
)",
                     name, constraintFn, namedAttr->attrName);
    }
  }
  verifyInherentAttrsMethod << "    return ::mlir::success();";

  // Generate methods to interact with bytecode.
  genPropertiesSupportForBytecode(attrOrProperties);
}
/// Generate `readProperties` (static) and `writeProperties` for the op's
/// Properties struct, serializing `attrOrProperties` in the given order.
/// With a custom properties encoding, the two methods are only declared and
/// the user is expected to define them.
void OpEmitter::genPropertiesSupportForBytecode(
    ArrayRef<ConstArgument> attrOrProperties) {
  if (op.useCustomPropertiesEncoding()) {
    opClass.declareStaticMethod(
        "::llvm::LogicalResult", "readProperties",
        MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
        MethodParameter("::mlir::OperationState &", "state"));
    opClass.declareMethod(
        "void", "writeProperties",
        MethodParameter("::mlir::DialectBytecodeWriter &", "writer"));
    return;
  }

  auto &readPropertiesMethod =
      opClass
          .addStaticMethod(
              "::llvm::LogicalResult", "readProperties",
              MethodParameter("::mlir::DialectBytecodeReader &", "reader"),
              MethodParameter("::mlir::OperationState &", "state"))
          ->body();

  auto &writePropertiesMethod =
      opClass
          .addMethod(
              "void", "writeProperties",
              MethodParameter("::mlir::DialectBytecodeWriter &", "writer"))
          ->body();

  // Populate bytecode serialization logic.
  readPropertiesMethod
      << "  auto &prop = state.getOrAddProperties<Properties>(); (void)prop;";
  writePropertiesMethod << "  auto &prop = getProperties(); (void)prop;\n";
  for (const auto &item : llvm::enumerate(attrOrProperties)) {
    auto &attrOrProp = item.value();
    FmtContext fctx;
    // NOTE(review): `$_ctxt` expands to `this->getContext()` in both methods,
    // but readProperties is static — any read call using `$_ctxt` would not
    // compile there. Confirm no property relies on it.
    fctx.addSubst("_reader", "reader")
        .addSubst("_writer", "writer")
        .addSubst("_storage", propertyStorage)
        .addSubst("_ctxt", "this->getContext()");
    // If the op emits operand/result segment sizes as a property, emit the
    // legacy reader/writer in the appropriate order to allow backward
    // compatibility and back deployment.
    if (emitHelper.getOperandSegmentsSize().has_value() &&
        item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) {
      FmtContext fmtCtxt(fctx);
      fmtCtxt.addSubst("_propName", operandSegmentAttrName);
      readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt);
      writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
    }
    if (emitHelper.getResultSegmentsSize().has_value() &&
        item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) {
      FmtContext fmtCtxt(fctx);
      fmtCtxt.addSubst("_propName", resultSegmentAttrName);
      readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt);
      writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt);
    }
    // Native properties use their own readFromMlirBytecode /
    // writeToMlirBytecode snippets.
    if (const auto *namedProperty =
            attrOrProp.dyn_cast<const NamedProperty *>()) {
      StringRef name = namedProperty->name;
      readPropertiesMethod << formatv(
          R"(
  {{
    auto &propStorage = prop.{0};
    auto readProp = [&]() {
      {1};
      return ::mlir::success();
    };
    if (::mlir::failed(readProp()))
      return ::mlir::failure();
  }
)",
          name,
          tgfmt(namedProperty->prop.getReadFromMlirBytecodeCall(), &fctx));
      writePropertiesMethod << formatv(
          R"(
  {{
    auto &propStorage = prop.{0};
    {1};
  }
)",
          name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx));
      continue;
    }
    // Attribute members: required ones use readAttribute/writeAttribute,
    // optional ones the *Optional* variants.
    const auto *namedAttr = attrOrProp.dyn_cast<const AttributeMetadata *>();
    StringRef name = namedAttr->attrName;
    if (namedAttr->isRequired) {
      readPropertiesMethod << formatv(R"(
  if (::mlir::failed(reader.readAttribute(prop.{0})))
    return ::mlir::failure();
)",
                                      name);
      writePropertiesMethod
          << formatv("  writer.writeAttribute(prop.{0});\n", name);
    } else {
      readPropertiesMethod << formatv(R"(
  if (::mlir::failed(reader.readOptionalAttribute(prop.{0})))
    return ::mlir::failure();
)",
                                      name);
      writePropertiesMethod << formatv(R"(
  writer.writeOptionalAttribute(prop.{0});
)",
                                       name);
    }
  }
  readPropertiesMethod << "  return ::mlir::success();";
}
1780 void OpEmitter::genPropGetters() {
1781 for (const NamedProperty &prop : op.getProperties()) {
1782 std::string name = op.getGetterName(prop.name);
1783 emitPropGetter(opClass, op, name, prop.prop);
1787 void OpEmitter::genPropSetters() {
1788 for (const NamedProperty &prop : op.getProperties()) {
1789 std::string name = op.getSetterName(prop.name);
1790 std::string argName = "new" + convertToCamelFromSnakeCase(
1791 prop.name, /*capitalizeFirst=*/true);
1792 auto *method = opClass.addInlineMethod(
1793 "void", name, MethodParameter(prop.prop.getInterfaceType(), argName));
1794 if (!method)
1795 return;
1796 method->body() << formatv(" getProperties().{0}({1});", name, argName);
1800 void OpEmitter::genAttrGetters() {
1801 FmtContext fctx;
1802 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
1804 // Emit the derived attribute body.
1805 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
1806 if (auto *method = opClass.addMethod(attr.getReturnType(), name))
1807 method->body() << " " << attr.getDerivedCodeBody() << "\n";
1810 // Generate named accessor with Attribute return type. This is a wrapper
1811 // class that allows referring to the attributes via accessors instead of
1812 // having to use the string interface for better compile time verification.
1813 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
1814 Attribute attr) {
1815 // The method body for this getter is trivial. Emit it inline.
1816 auto *method =
1817 opClass.addInlineMethod(attr.getStorageType(), name + "Attr");
1818 if (!method)
1819 return;
1820 method->body() << formatv(
1821 " return ::llvm::{1}<{2}>({0});", emitHelper.getAttr(attrName),
1822 attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
1823 : "cast",
1824 attr.getStorageType());
1827 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1828 std::string name = op.getGetterName(namedAttr.name);
1829 if (namedAttr.attr.isDerivedAttr()) {
1830 emitDerivedAttr(name, namedAttr.attr);
1831 } else {
1832 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1833 emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
1837 auto derivedAttrs = make_filter_range(op.getAttributes(),
1838 [](const NamedAttribute &namedAttr) {
1839 return namedAttr.attr.isDerivedAttr();
1841 if (derivedAttrs.empty())
1842 return;
1844 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
1845 // Generate helper method to query whether a named attribute is a derived
1846 // attribute. This enables, for example, avoiding adding an attribute that
1847 // overlaps with a derived attribute.
1849 auto *method =
1850 opClass.addStaticMethod("bool", "isDerivedAttribute",
1851 MethodParameter("::llvm::StringRef", "name"));
1852 ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1853 auto &body = method->body();
1854 for (auto namedAttr : derivedAttrs)
1855 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
1856 body << " return false;";
1858 // Generate method to materialize derived attributes as a DictionaryAttr.
1860 auto *method = opClass.addMethod("::mlir::DictionaryAttr",
1861 "materializeDerivedAttributes");
1862 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1863 auto &body = method->body();
1865 auto nonMaterializable =
1866 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
1867 return namedAttr.attr.getConvertFromStorageCall().empty();
1869 if (!nonMaterializable.empty()) {
1870 std::string attrs;
1871 llvm::raw_string_ostream os(attrs);
1872 interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
1873 os << op.getGetterName(attr.name);
1875 PrintWarning(
1876 op.getLoc(),
1877 formatv(
1878 "op has non-materializable derived attributes '{0}', skipping",
1879 os.str()));
1880 body << formatv(" emitOpError(\"op has non-materializable derived "
1881 "attributes '{0}'\");\n",
1882 attrs);
1883 body << " return nullptr;";
1884 return;
1887 body << " ::mlir::MLIRContext* ctx = getContext();\n";
1888 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1889 body << " return ::mlir::DictionaryAttr::get(";
1890 body << " ctx, {\n";
1891 interleave(
1892 derivedAttrs, body,
1893 [&](const NamedAttribute &namedAttr) {
1894 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
1895 std::string name = op.getGetterName(namedAttr.name);
1896 body << " {" << name << "AttrName(),\n"
1897 << tgfmt(tmpl, &fctx.withSelf(name + "()")
1898 .withBuilder("odsBuilder")
1899 .addSubst("_ctxt", "ctx")
1900 .addSubst("_storage", "ctx"))
1901 << "}";
1903 ",\n");
1904 body << "});";
1908 void OpEmitter::genAttrSetters() {
1909 bool useProperties = op.getDialect().usePropertiesForAttributes();
1911 // Generate the code to set an attribute.
1912 auto emitSetAttr = [&](Method *method, StringRef getterName,
1913 StringRef attrName, StringRef attrVar) {
1914 if (useProperties) {
1915 method->body() << formatv(" getProperties().{0} = {1};", attrName,
1916 attrVar);
1917 } else {
1918 method->body() << formatv(" (*this)->setAttr({0}AttrName(), {1});",
1919 getterName, attrVar);
1923 // Generate raw named setter type. This is a wrapper class that allows setting
1924 // to the attributes via setters instead of having to use the string interface
1925 // for better compile time verification.
1926 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
1927 StringRef attrName, Attribute attr) {
1928 // This method body is trivial, so emit it inline.
1929 auto *method =
1930 opClass.addInlineMethod("void", setterName + "Attr",
1931 MethodParameter(attr.getStorageType(), "attr"));
1932 if (method)
1933 emitSetAttr(method, getterName, attrName, "attr");
1936 // Generate a setter that accepts the underlying C++ type as opposed to the
1937 // attribute type.
1938 auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
1939 StringRef attrName, Attribute attr) {
1940 Attribute baseAttr = attr.getBaseAttr();
1941 if (!canUseUnwrappedRawValue(baseAttr))
1942 return;
1943 FmtContext fctx;
1944 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
1945 bool isUnitAttr = attr.getAttrDefName() == "UnitAttr";
1946 bool isOptional = attr.isOptional();
1948 auto createMethod = [&](const Twine &paramType) {
1949 return opClass.addMethod("void", setterName,
1950 MethodParameter(paramType.str(), "attrValue"));
1953 // Build the method using the correct parameter type depending on
1954 // optionality.
1955 Method *method = nullptr;
1956 if (isUnitAttr)
1957 method = createMethod("bool");
1958 else if (isOptional)
1959 method =
1960 createMethod("::std::optional<" + baseAttr.getReturnType() + ">");
1961 else
1962 method = createMethod(attr.getReturnType());
1963 if (!method)
1964 return;
1966 // If the value isn't optional, just set it directly.
1967 if (!isOptional) {
1968 emitSetAttr(method, getterName, attrName,
1969 constBuildAttrFromParam(attr, fctx, "attrValue"));
1970 return;
1973 // Otherwise, we only set if the provided value is valid. If it isn't, we
1974 // remove the attribute.
1976 // TODO: Handle unit attr parameters specially, given that it is treated as
1977 // optional but not in the same way as the others (i.e. it uses bool over
1978 // std::optional<>).
1979 StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
1980 if (!useProperties) {
1981 const char *optionalCodeBody = R"(
1982 if (attrValue)
1983 return (*this)->setAttr({0}AttrName(), {1});
1984 (*this)->removeAttr({0}AttrName());)";
1985 method->body() << formatv(
1986 optionalCodeBody, getterName,
1987 constBuildAttrFromParam(baseAttr, fctx, paramStr));
1988 } else {
1989 const char *optionalCodeBody = R"(
1990 auto &odsProp = getProperties().{0};
1991 if (attrValue)
1992 odsProp = {1};
1993 else
1994 odsProp = nullptr;)";
1995 method->body() << formatv(
1996 optionalCodeBody, attrName,
1997 constBuildAttrFromParam(baseAttr, fctx, paramStr));
2001 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2002 if (namedAttr.attr.isDerivedAttr())
2003 continue;
2004 std::string setterName = op.getSetterName(namedAttr.name);
2005 std::string getterName = op.getGetterName(namedAttr.name);
2006 emitAttrWithStorageType(setterName, getterName, namedAttr.name,
2007 namedAttr.attr);
2008 emitAttrWithReturnType(setterName, getterName, namedAttr.name,
2009 namedAttr.attr);
2013 void OpEmitter::genOptionalAttrRemovers() {
2014 // Generate methods for removing optional attributes, instead of having to
2015 // use the string interface. Enables better compile time verification.
2016 auto emitRemoveAttr = [&](StringRef name, bool useProperties) {
2017 auto upperInitial = name.take_front().upper();
2018 auto *method = opClass.addInlineMethod("::mlir::Attribute",
2019 op.getRemoverName(name) + "Attr");
2020 if (!method)
2021 return;
2022 if (useProperties) {
2023 method->body() << formatv(R"(
2024 auto &attr = getProperties().{0};
2025 attr = {{};
2026 return attr;
2028 name);
2029 return;
2031 method->body() << formatv("return (*this)->removeAttr({0}AttrName());",
2032 op.getGetterName(name));
2035 for (const NamedAttribute &namedAttr : op.getAttributes())
2036 if (namedAttr.attr.isOptional())
2037 emitRemoveAttr(namedAttr.name,
2038 op.getDialect().usePropertiesForAttributes());
2041 // Generates the code to compute the start and end index of an operand or result
2042 // range.
2043 template <typename RangeT>
2044 static void generateValueRangeStartAndEnd(
2045 Class &opClass, bool isGenericAdaptorBase, StringRef methodName,
2046 int numVariadic, int numNonVariadic, StringRef rangeSizeCall,
2047 bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) {
2049 SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")};
2050 if (isGenericAdaptorBase) {
2051 parameters.emplace_back("unsigned", "odsOperandsSize");
2052 // The range size is passed per parameter for generic adaptor bases as
2053 // using the rangeSizeCall would require the operands, which are not
2054 // accessible in the base class.
2055 rangeSizeCall = "odsOperandsSize";
2058 // The method is trivial if the operation does not have any variadic operands.
2059 // In that case, make sure to generate it in-line.
2060 auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
2061 numVariadic == 0 ? Method::Properties::Inline
2062 : Method::Properties::None,
2063 parameters);
2064 if (!method)
2065 return;
2066 auto &body = method->body();
2067 if (numVariadic == 0) {
2068 body << " return {index, 1};\n";
2069 } else if (hasAttrSegmentSize) {
2070 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
2071 } else {
2072 // Because the op can have arbitrarily interleaved variadic and non-variadic
2073 // operands, we need to embed a list in the "sink" getter method for
2074 // calculation at run-time.
2075 SmallVector<StringRef, 4> isVariadic;
2076 isVariadic.reserve(llvm::size(odsValues));
2077 for (auto &it : odsValues)
2078 isVariadic.push_back(it.isVariableLength() ? "true" : "false");
2079 std::string isVariadicList = llvm::join(isVariadic, ", ");
2080 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
2081 numNonVariadic, numVariadic, rangeSizeCall, "operand");
2085 static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
2086 return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType())
2087 .str();
2090 // Generates the named operand getter methods for the given Operator `op` and
2091 // puts them in `opClass`. Uses `rangeType` as the return type of getters that
2092 // return a range of operands (individual operands are `Value ` and each
2093 // element in the range must also be `Value `); use `rangeBeginCall` to get
2094 // an iterator to the beginning of the operand range; use `rangeSizeCall` to
2095 // obtain the number of operands. `getOperandCallPattern` contains the code
2096 // necessary to obtain a single operand whose position will be substituted
2097 // instead of
2098 // "{0}" marker in the pattern. Note that the pattern should work for any kind
2099 // of ops, in particular for one-operand ops that may not have the
2100 // `getOperand(unsigned)` method.
2101 static void
2102 generateNamedOperandGetters(const Operator &op, Class &opClass,
2103 Class *genericAdaptorBase, StringRef sizeAttrInit,
2104 StringRef rangeType, StringRef rangeElementType,
2105 StringRef rangeBeginCall, StringRef rangeSizeCall,
2106 StringRef getOperandCallPattern) {
2107 const int numOperands = op.getNumOperands();
2108 const int numVariadicOperands = op.getNumVariableLengthOperands();
2109 const int numNormalOperands = numOperands - numVariadicOperands;
2111 const auto *sameVariadicSize =
2112 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
2113 const auto *attrSizedOperands =
2114 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2116 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
2117 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
2118 "specification over their sizes");
2121 if (numVariadicOperands < 2 && attrSizedOperands) {
2122 PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
2123 "to use 'AttrSizedOperandSegments' trait");
2126 if (attrSizedOperands && sameVariadicSize) {
2127 PrintFatalError(op.getLoc(),
2128 "op cannot have both 'AttrSizedOperandSegments' and "
2129 "'SameVariadicOperandSize' traits");
2132 // First emit a few "sink" getter methods upon which we layer all nicer named
2133 // getter methods.
2134 // If generating for an adaptor, the method is put into the non-templated
2135 // generic base class, to not require being defined in the header.
2136 // Since the operand size can't be determined from the base class however,
2137 // it has to be passed as an additional argument. The trampoline below
2138 // generates the function with the same signature as the Op in the generic
2139 // adaptor.
2140 bool isGenericAdaptorBase = genericAdaptorBase != nullptr;
2141 generateValueRangeStartAndEnd(
2142 /*opClass=*/isGenericAdaptorBase ? *genericAdaptorBase : opClass,
2143 isGenericAdaptorBase,
2144 /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands,
2145 numNormalOperands, rangeSizeCall, attrSizedOperands, sizeAttrInit,
2146 const_cast<Operator &>(op).getOperands());
2147 if (isGenericAdaptorBase) {
2148 // Generate trampoline for calling 'getODSOperandIndexAndLength' with just
2149 // the index. This just calls the implementation in the base class but
2150 // passes the operand size as parameter.
2151 Method *method = opClass.addInlineMethod(
2152 "std::pair<unsigned, unsigned>", "getODSOperandIndexAndLength",
2153 MethodParameter("unsigned", "index"));
2154 ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op);
2155 MethodBody &body = method->body();
2156 body.indent() << formatv(
2157 "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall);
2160 // The implementation of this method is trivial and it is very load-bearing.
2161 // Generate it inline.
2162 auto *m = opClass.addInlineMethod(rangeType, "getODSOperands",
2163 MethodParameter("unsigned", "index"));
2164 ERROR_IF_PRUNED(m, "getODSOperands", op);
2165 auto &body = m->body();
2166 body << formatv(valueRangeReturnCode, rangeBeginCall,
2167 "getODSOperandIndexAndLength(index)");
2169 // Then we emit nicer named getter methods by redirecting to the "sink" getter
2170 // method.
2171 for (int i = 0; i != numOperands; ++i) {
2172 const auto &operand = op.getOperand(i);
2173 if (operand.name.empty())
2174 continue;
2175 std::string name = op.getGetterName(operand.name);
2176 if (operand.isOptional()) {
2177 m = opClass.addInlineMethod(isGenericAdaptorBase
2178 ? rangeElementType
2179 : generateTypeForGetter(operand),
2180 name);
2181 ERROR_IF_PRUNED(m, name, op);
2182 m->body().indent() << formatv("auto operands = getODSOperands({0});\n"
2183 "return operands.empty() ? {1}{{} : ",
2184 i, m->getReturnType());
2185 if (!isGenericAdaptorBase)
2186 m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
2187 m->body() << "(*operands.begin());";
2188 } else if (operand.isVariadicOfVariadic()) {
2189 std::string segmentAttr = op.getGetterName(
2190 operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
2191 if (genericAdaptorBase) {
2192 m = opClass.addMethod("::llvm::SmallVector<" + rangeType + ">", name);
2193 ERROR_IF_PRUNED(m, name, op);
2194 m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
2195 segmentAttr, i, rangeType);
2196 continue;
2199 m = opClass.addInlineMethod("::mlir::OperandRangeRange", name);
2200 ERROR_IF_PRUNED(m, name, op);
2201 m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
2202 << "Attr());";
2203 } else if (operand.isVariadic()) {
2204 m = opClass.addInlineMethod(rangeType, name);
2205 ERROR_IF_PRUNED(m, name, op);
2206 m->body() << " return getODSOperands(" << i << ");";
2207 } else {
2208 m = opClass.addInlineMethod(isGenericAdaptorBase
2209 ? rangeElementType
2210 : generateTypeForGetter(operand),
2211 name);
2212 ERROR_IF_PRUNED(m, name, op);
2213 m->body().indent() << "return ";
2214 if (!isGenericAdaptorBase)
2215 m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType());
2216 m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i);
2221 void OpEmitter::genNamedOperandGetters() {
2222 // Build the code snippet used for initializing the operand_segment_size)s
2223 // array.
2224 std::string attrSizeInitCode;
2225 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2226 if (op.getDialect().usePropertiesForAttributes())
2227 attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties,
2228 "getProperties().operandSegmentSizes");
2230 else
2231 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
2232 emitHelper.getAttr(operandSegmentAttrName));
2235 generateNamedOperandGetters(
2236 op, opClass,
2237 /*genericAdaptorBase=*/nullptr,
2238 /*sizeAttrInit=*/attrSizeInitCode,
2239 /*rangeType=*/"::mlir::Operation::operand_range",
2240 /*rangeElementType=*/"::mlir::Value",
2241 /*rangeBeginCall=*/"getOperation()->operand_begin()",
2242 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
2243 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
2246 void OpEmitter::genNamedOperandSetters() {
2247 auto *attrSizedOperands =
2248 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
2249 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
2250 const auto &operand = op.getOperand(i);
2251 if (operand.name.empty())
2252 continue;
2253 std::string name = op.getGetterName(operand.name);
2255 StringRef returnType;
2256 if (operand.isVariadicOfVariadic()) {
2257 returnType = "::mlir::MutableOperandRangeRange";
2258 } else if (operand.isVariableLength()) {
2259 returnType = "::mlir::MutableOperandRange";
2260 } else {
2261 returnType = "::mlir::OpOperand &";
2263 bool isVariadicOperand =
2264 operand.isVariadicOfVariadic() || operand.isVariableLength();
2265 auto *m = opClass.addMethod(returnType, name + "Mutable",
2266 isVariadicOperand ? Method::Properties::None
2267 : Method::Properties::Inline);
2268 ERROR_IF_PRUNED(m, name, op);
2269 auto &body = m->body();
2270 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
2272 if (!isVariadicOperand) {
2273 // In case of a single operand, return a single OpOperand.
2274 body << " return getOperation()->getOpOperand(range.first);\n";
2275 continue;
2278 body << " auto mutableRange = "
2279 "::mlir::MutableOperandRange(getOperation(), "
2280 "range.first, range.second";
2281 if (attrSizedOperands) {
2282 if (emitHelper.hasProperties())
2283 body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
2284 "{{getOperandSegmentSizesAttrName(), "
2285 "::mlir::DenseI32ArrayAttr::get(getContext(), "
2286 "getProperties().operandSegmentSizes)})",
2288 else
2289 body << formatv(
2290 ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
2291 emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
2293 body << ");\n";
2295 // If this operand is a nested variadic, we split the range into a
2296 // MutableOperandRangeRange that provides a range over all of the
2297 // sub-ranges.
2298 if (operand.isVariadicOfVariadic()) {
2299 body << " return "
2300 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
2301 << op.getGetterName(
2302 operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
2303 << "AttrName()));\n";
2304 } else {
2305 // Otherwise, we use the full range directly.
2306 body << " return mutableRange;\n";
2311 void OpEmitter::genNamedResultGetters() {
2312 const int numResults = op.getNumResults();
2313 const int numVariadicResults = op.getNumVariableLengthResults();
2314 const int numNormalResults = numResults - numVariadicResults;
2316 // If we have more than one variadic results, we need more complicated logic
2317 // to calculate the value range for each result.
2319 const auto *sameVariadicSize =
2320 op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
2321 const auto *attrSizedResults =
2322 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");
2324 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
2325 PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
2326 "specification over their sizes");
2329 if (numVariadicResults < 2 && attrSizedResults) {
2330 PrintFatalError(op.getLoc(), "op must have at least two variadic results "
2331 "to use 'AttrSizedResultSegments' trait");
2334 if (attrSizedResults && sameVariadicSize) {
2335 PrintFatalError(op.getLoc(),
2336 "op cannot have both 'AttrSizedResultSegments' and "
2337 "'SameVariadicResultSize' traits");
2340 // Build the initializer string for the result segment size attribute.
2341 std::string attrSizeInitCode;
2342 if (attrSizedResults) {
2343 if (op.getDialect().usePropertiesForAttributes())
2344 attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties,
2345 "getProperties().resultSegmentSizes");
2347 else
2348 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
2349 emitHelper.getAttr(resultSegmentAttrName));
2352 generateValueRangeStartAndEnd(
2353 opClass, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength",
2354 numVariadicResults, numNormalResults, "getOperation()->getNumResults()",
2355 attrSizedResults, attrSizeInitCode, op.getResults());
2357 // The implementation of this method is trivial and it is very load-bearing.
2358 // Generate it inline.
2359 auto *m = opClass.addInlineMethod("::mlir::Operation::result_range",
2360 "getODSResults",
2361 MethodParameter("unsigned", "index"));
2362 ERROR_IF_PRUNED(m, "getODSResults", op);
2363 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
2364 "getODSResultIndexAndLength(index)");
2366 for (int i = 0; i != numResults; ++i) {
2367 const auto &result = op.getResult(i);
2368 if (result.name.empty())
2369 continue;
2370 std::string name = op.getGetterName(result.name);
2371 if (result.isOptional()) {
2372 m = opClass.addInlineMethod(generateTypeForGetter(result), name);
2373 ERROR_IF_PRUNED(m, name, op);
2374 m->body() << " auto results = getODSResults(" << i << ");\n"
2375 << llvm::formatv(" return results.empty()"
2376 " ? {0}()"
2377 " : ::llvm::cast<{0}>(*results.begin());",
2378 m->getReturnType());
2379 } else if (result.isVariadic()) {
2380 m = opClass.addInlineMethod("::mlir::Operation::result_range", name);
2381 ERROR_IF_PRUNED(m, name, op);
2382 m->body() << " return getODSResults(" << i << ");";
2383 } else {
2384 m = opClass.addInlineMethod(generateTypeForGetter(result), name);
2385 ERROR_IF_PRUNED(m, name, op);
2386 m->body() << llvm::formatv(
2387 " return ::llvm::cast<{0}>(*getODSResults({1}).begin());",
2388 m->getReturnType(), i);
2393 void OpEmitter::genNamedRegionGetters() {
2394 unsigned numRegions = op.getNumRegions();
2395 for (unsigned i = 0; i < numRegions; ++i) {
2396 const auto &region = op.getRegion(i);
2397 if (region.name.empty())
2398 continue;
2399 std::string name = op.getGetterName(region.name);
2401 // Generate the accessors for a variadic region.
2402 if (region.isVariadic()) {
2403 auto *m = opClass.addInlineMethod(
2404 "::mlir::MutableArrayRef<::mlir::Region>", name);
2405 ERROR_IF_PRUNED(m, name, op);
2406 m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
2408 continue;
2411 auto *m = opClass.addInlineMethod("::mlir::Region &", name);
2412 ERROR_IF_PRUNED(m, name, op);
2413 m->body() << formatv(" return (*this)->getRegion({0});", i);
2417 void OpEmitter::genNamedSuccessorGetters() {
2418 unsigned numSuccessors = op.getNumSuccessors();
2419 for (unsigned i = 0; i < numSuccessors; ++i) {
2420 const NamedSuccessor &successor = op.getSuccessor(i);
2421 if (successor.name.empty())
2422 continue;
2423 std::string name = op.getGetterName(successor.name);
2424 // Generate the accessors for a variadic successor list.
2425 if (successor.isVariadic()) {
2426 auto *m = opClass.addInlineMethod("::mlir::SuccessorRange", name);
2427 ERROR_IF_PRUNED(m, name, op);
2428 m->body() << formatv(
2429 " return {std::next((*this)->successor_begin(), {0}), "
2430 "(*this)->successor_end()};",
2432 continue;
2435 auto *m = opClass.addInlineMethod("::mlir::Block *", name);
2436 ERROR_IF_PRUNED(m, name, op);
2437 m->body() << formatv(" return (*this)->getSuccessor({0});", i);
2441 static bool canGenerateUnwrappedBuilder(const Operator &op) {
2442 // If this op does not have native attributes at all, return directly to avoid
2443 // redefining builders.
2444 if (op.getNumNativeAttributes() == 0)
2445 return false;
2447 bool canGenerate = false;
2448 // We are generating builders that take raw values for attributes. We need to
2449 // make sure the native attributes have a meaningful "unwrapped" value type
2450 // different from the wrapped mlir::Attribute type to avoid redefining
2451 // builders. This checks for the op has at least one such native attribute.
2452 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
2453 const NamedAttribute &namedAttr = op.getAttribute(i);
2454 if (canUseUnwrappedRawValue(namedAttr.attr)) {
2455 canGenerate = true;
2456 break;
2459 return canGenerate;
2462 static bool canInferType(const Operator &op) {
2463 return op.getTrait("::mlir::InferTypeOpInterface::Trait");
2466 void OpEmitter::genSeparateArgParamBuilder() {
2467 SmallVector<AttrParamKind, 2> attrBuilderType;
2468 attrBuilderType.push_back(AttrParamKind::WrappedAttr);
2469 if (canGenerateUnwrappedBuilder(op))
2470 attrBuilderType.push_back(AttrParamKind::UnwrappedValue);
2472 // Emit with separate builders with or without unwrapped attributes and/or
2473 // inferring result type.
2474 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
2475 bool inferType) {
2476 SmallVector<MethodParameter> paramList;
2477 SmallVector<std::string, 4> resultNames;
2478 llvm::StringSet<> inferredAttributes;
2479 buildParamList(paramList, inferredAttributes, resultNames, paramKind,
2480 attrType);
2482 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2483 // If the builder is redundant, skip generating the method.
2484 if (!m)
2485 return;
2486 auto &body = m->body();
2487 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2488 /*isRawValueAttr=*/attrType ==
2489 AttrParamKind::UnwrappedValue);
2491 // Push all result types to the operation state
2493 if (inferType) {
2494 // Generate builder that infers type too.
2495 // TODO: Subsume this with general checking if type can be
2496 // inferred automatically.
2497 body << formatv(R"(
2498 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2499 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2500 {1}.location, {1}.operands,
2501 {1}.attributes.getDictionary({1}.getContext()),
2502 {1}.getRawProperties(),
2503 {1}.regions, inferredReturnTypes)))
2504 {1}.addTypes(inferredReturnTypes);
2505 else
2506 ::mlir::detail::reportFatalInferReturnTypesError({1});
2508 opClass.getClassName(), builderOpState);
2509 return;
2512 switch (paramKind) {
2513 case TypeParamKind::None:
2514 return;
2515 case TypeParamKind::Separate:
2516 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
2517 if (op.getResult(i).isOptional())
2518 body << " if (" << resultNames[i] << ")\n ";
2519 body << " " << builderOpState << ".addTypes(" << resultNames[i]
2520 << ");\n";
2523 // Automatically create the 'resultSegmentSizes' attribute using
2524 // the length of the type ranges.
2525 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
2526 if (op.getDialect().usePropertiesForAttributes()) {
2527 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
2528 } else {
2529 std::string getterName = op.getGetterName(resultSegmentAttrName);
2530 body << " " << builderOpState << ".addAttribute(" << getterName
2531 << "AttrName(" << builderOpState << ".name), "
2532 << "odsBuilder.getDenseI32ArrayAttr({";
2534 interleaveComma(
2535 llvm::seq<int>(0, op.getNumResults()), body, [&](int i) {
2536 const NamedTypeConstraint &result = op.getResult(i);
2537 if (!result.isVariableLength()) {
2538 body << "1";
2539 } else if (result.isOptional()) {
2540 body << "(" << resultNames[i] << " ? 1 : 0)";
2541 } else {
2542 // VariadicOfVariadic of results are currently unsupported in
2543 // MLIR, hence it can only be a simple variadic.
2544 // TODO: Add implementation for VariadicOfVariadic results here
2545 // once supported.
2546 assert(result.isVariadic());
2547 body << "static_cast<int32_t>(" << resultNames[i] << ".size())";
2550 if (op.getDialect().usePropertiesForAttributes()) {
2551 body << "}), " << builderOpState
2552 << ".getOrAddProperties<Properties>()."
2553 "resultSegmentSizes.begin());\n";
2554 } else {
2555 body << "}));\n";
2559 return;
2560 case TypeParamKind::Collective: {
2561 int numResults = op.getNumResults();
2562 int numVariadicResults = op.getNumVariableLengthResults();
2563 int numNonVariadicResults = numResults - numVariadicResults;
2564 bool hasVariadicResult = numVariadicResults != 0;
2566 // Avoid emitting "resultTypes.size() >= 0u" which is always true.
2567 if (!hasVariadicResult || numNonVariadicResults != 0)
2568 body << " "
2569 << "assert(resultTypes.size() "
2570 << (hasVariadicResult ? ">=" : "==") << " "
2571 << numNonVariadicResults
2572 << "u && \"mismatched number of results\");\n";
2573 body << " " << builderOpState << ".addTypes(resultTypes);\n";
2575 return;
2577 llvm_unreachable("unhandled TypeParamKind");
2580 // Some of the build methods generated here may be ambiguous, but TableGen's
2581 // ambiguous function detection will elide those ones.
2582 for (auto attrType : attrBuilderType) {
2583 emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
2584 if (canInferType(op))
2585 emit(attrType, TypeParamKind::None, /*inferType=*/true);
2586 emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
2590 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
2591 int numResults = op.getNumResults();
2593 // Signature
2594 SmallVector<MethodParameter> paramList;
2595 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2596 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2597 paramList.emplace_back("::mlir::ValueRange", "operands");
2598 // Provide default value for `attributes` when its the last parameter
2599 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2600 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2601 "attributes", attributesDefaultValue);
2602 if (op.getNumVariadicRegions())
2603 paramList.emplace_back("unsigned", "numRegions");
2605 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2606 // If the builder is redundant, skip generating the method
2607 if (!m)
2608 return;
2609 auto &body = m->body();
2611 // Operands
2612 body << " " << builderOpState << ".addOperands(operands);\n";
2614 // Attributes
2615 body << " " << builderOpState << ".addAttributes(attributes);\n";
2617 // Create the correct number of regions
2618 if (int numRegions = op.getNumRegions()) {
2619 body << llvm::formatv(
2620 " for (unsigned i = 0; i != {0}; ++i)\n",
2621 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2622 body << " (void)" << builderOpState << ".addRegion();\n";
2625 // Result types
2626 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
2627 body << " " << builderOpState << ".addTypes({"
2628 << llvm::join(resultTypes, ", ") << "});\n\n";
2631 void OpEmitter::genPopulateDefaultAttributes() {
2632 // All done if no attributes, except optional ones, have default values.
2633 if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
2634 return !named.attr.hasDefaultValue() || named.attr.isOptional();
2636 return;
2638 if (emitHelper.hasProperties()) {
2639 SmallVector<MethodParameter> paramList;
2640 paramList.emplace_back("::mlir::OperationName", "opName");
2641 paramList.emplace_back("Properties &", "properties");
2642 auto *m =
2643 opClass.addStaticMethod("void", "populateDefaultProperties", paramList);
2644 ERROR_IF_PRUNED(m, "populateDefaultProperties", op);
2645 auto &body = m->body();
2646 body.indent();
2647 body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n";
2648 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2649 auto &attr = namedAttr.attr;
2650 if (!attr.hasDefaultValue() || attr.isOptional())
2651 continue;
2652 StringRef name = namedAttr.name;
2653 FmtContext fctx;
2654 fctx.withBuilder(odsBuilder);
2655 body << "if (!properties." << name << ")\n"
2656 << " properties." << name << " = "
2657 << std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
2658 tgfmt(attr.getDefaultValue(), &fctx)))
2659 << ";\n";
2661 return;
2664 SmallVector<MethodParameter> paramList;
2665 paramList.emplace_back("const ::mlir::OperationName &", "opName");
2666 paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
2667 auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
2668 ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
2669 auto &body = m->body();
2670 body.indent();
2672 // Set default attributes that are unset.
2673 body << "auto attrNames = opName.getAttributeNames();\n";
2674 body << "::mlir::Builder " << odsBuilder
2675 << "(attrNames.front().getContext());\n";
2676 StringMap<int> attrIndex;
2677 for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
2678 attrIndex[it.value().first] = it.index();
2680 for (const NamedAttribute &namedAttr : op.getAttributes()) {
2681 auto &attr = namedAttr.attr;
2682 if (!attr.hasDefaultValue() || attr.isOptional())
2683 continue;
2684 auto index = attrIndex[namedAttr.name];
2685 body << "if (!attributes.get(attrNames[" << index << "])) {\n";
2686 FmtContext fctx;
2687 fctx.withBuilder(odsBuilder);
2689 std::string defaultValue =
2690 std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx,
2691 tgfmt(attr.getDefaultValue(), &fctx)));
2692 body.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index,
2693 defaultValue);
2694 body.unindent() << "}\n";
2698 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
2699 SmallVector<MethodParameter> paramList;
2700 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2701 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2702 paramList.emplace_back("::mlir::ValueRange", "operands");
2703 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2704 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2705 "attributes", attributesDefaultValue);
2706 if (op.getNumVariadicRegions())
2707 paramList.emplace_back("unsigned", "numRegions");
2709 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2710 // If the builder is redundant, skip generating the method
2711 if (!m)
2712 return;
2713 auto &body = m->body();
2715 int numResults = op.getNumResults();
2716 int numVariadicResults = op.getNumVariableLengthResults();
2717 int numNonVariadicResults = numResults - numVariadicResults;
2719 int numOperands = op.getNumOperands();
2720 int numVariadicOperands = op.getNumVariableLengthOperands();
2721 int numNonVariadicOperands = numOperands - numVariadicOperands;
2723 // Operands
2724 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
2725 body << " assert(operands.size()"
2726 << (numVariadicOperands != 0 ? " >= " : " == ")
2727 << numNonVariadicOperands
2728 << "u && \"mismatched number of parameters\");\n";
2729 body << " " << builderOpState << ".addOperands(operands);\n";
2730 body << " " << builderOpState << ".addAttributes(attributes);\n";
2732 // Create the correct number of regions
2733 if (int numRegions = op.getNumRegions()) {
2734 body << llvm::formatv(
2735 " for (unsigned i = 0; i != {0}; ++i)\n",
2736 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2737 body << " (void)" << builderOpState << ".addRegion();\n";
2740 // Result types
2741 if (emitHelper.hasProperties()) {
2742 // Initialize the properties from Attributes before invoking the infer
2743 // function.
2744 body << formatv(R"(
2745 if (!attributes.empty()) {
2746 ::mlir::OpaqueProperties properties =
2747 &{1}.getOrAddProperties<{0}::Properties>();
2748 std::optional<::mlir::RegisteredOperationName> info =
2749 {1}.name.getRegisteredInfo();
2750 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2751 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
2752 ::llvm::report_fatal_error("Property conversion failed.");
2753 })",
2754 opClass.getClassName(), builderOpState);
2756 body << formatv(R"(
2757 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
2758 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
2759 {1}.location, operands,
2760 {1}.attributes.getDictionary({1}.getContext()),
2761 {1}.getRawProperties(),
2762 {1}.regions, inferredReturnTypes))) {{)",
2763 opClass.getClassName(), builderOpState);
2764 if (numVariadicResults == 0 || numNonVariadicResults != 0)
2765 body << "\n assert(inferredReturnTypes.size()"
2766 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
2767 << "u && \"mismatched number of return types\");";
2768 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";
2770 body << R"(
2771 } else {
2772 ::llvm::report_fatal_error("Failed to infer result type(s).");
2773 })";
2776 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
2777 auto emit = [&](AttrParamKind attrType) {
2778 SmallVector<MethodParameter> paramList;
2779 SmallVector<std::string, 4> resultNames;
2780 llvm::StringSet<> inferredAttributes;
2781 buildParamList(paramList, inferredAttributes, resultNames,
2782 TypeParamKind::None, attrType);
2784 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2785 // If the builder is redundant, skip generating the method
2786 if (!m)
2787 return;
2788 auto &body = m->body();
2789 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
2790 /*isRawValueAttr=*/attrType ==
2791 AttrParamKind::UnwrappedValue);
2793 auto numResults = op.getNumResults();
2794 if (numResults == 0)
2795 return;
2797 // Push all result types to the operation state
2798 const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
2799 std::string resultType =
2800 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
2801 body << " " << builderOpState << ".addTypes({" << resultType;
2802 for (int i = 1; i != numResults; ++i)
2803 body << ", " << resultType;
2804 body << "});\n\n";
2807 emit(AttrParamKind::WrappedAttr);
2808 // Generate additional builder(s) if any attributes can be "unwrapped"
2809 if (canGenerateUnwrappedBuilder(op))
2810 emit(AttrParamKind::UnwrappedValue);
2813 void OpEmitter::genUseAttrAsResultTypeBuilder() {
2814 SmallVector<MethodParameter> paramList;
2815 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
2816 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2817 paramList.emplace_back("::mlir::ValueRange", "operands");
2818 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2819 "attributes", "{}");
2820 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2821 // If the builder is redundant, skip generating the method
2822 if (!m)
2823 return;
2825 auto &body = m->body();
2827 // Push all result types to the operation state
2828 std::string resultType;
2829 const auto &namedAttr = op.getAttribute(0);
2831 body << " auto attrName = " << op.getGetterName(namedAttr.name)
2832 << "AttrName(" << builderOpState
2833 << ".name);\n"
2834 " for (auto attr : attributes) {\n"
2835 " if (attr.getName() != attrName) continue;\n";
2836 if (namedAttr.attr.isTypeAttr()) {
2837 resultType = "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()";
2838 } else {
2839 resultType = "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()";
2842 // Operands
2843 body << " " << builderOpState << ".addOperands(operands);\n";
2845 // Attributes
2846 body << " " << builderOpState << ".addAttributes(attributes);\n";
2848 // Result types
2849 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
2850 body << " " << builderOpState << ".addTypes({"
2851 << llvm::join(resultTypes, ", ") << "});\n";
2852 body << " }\n";
2855 /// Returns a signature of the builder. Updates the context `fctx` to enable
2856 /// replacement of $_builder and $_state in the body.
2857 static SmallVector<MethodParameter>
2858 getBuilderSignature(const Builder &builder) {
2859 ArrayRef<Builder::Parameter> params(builder.getParameters());
2861 // Inject builder and state arguments.
2862 SmallVector<MethodParameter> arguments;
2863 arguments.reserve(params.size() + 2);
2864 arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
2865 arguments.emplace_back("::mlir::OperationState &", builderOpState);
2867 for (unsigned i = 0, e = params.size(); i < e; ++i) {
2868 // If no name is provided, generate one.
2869 std::optional<StringRef> paramName = params[i].getName();
2870 std::string name =
2871 paramName ? paramName->str() : "odsArg" + std::to_string(i);
2873 StringRef defaultValue;
2874 if (std::optional<StringRef> defaultParamValue =
2875 params[i].getDefaultValue())
2876 defaultValue = *defaultParamValue;
2878 arguments.emplace_back(params[i].getCppType(), std::move(name),
2879 defaultValue);
2882 return arguments;
2885 void OpEmitter::genBuilder() {
2886 // Handle custom builders if provided.
2887 for (const Builder &builder : op.getBuilders()) {
2888 SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
2890 std::optional<StringRef> body = builder.getBody();
2891 auto properties = body ? Method::Static : Method::StaticDeclaration;
2892 auto *method =
2893 opClass.addMethod("void", "build", properties, std::move(arguments));
2894 if (body)
2895 ERROR_IF_PRUNED(method, "build", op);
2897 if (method)
2898 method->setDeprecated(builder.getDeprecatedMessage());
2900 FmtContext fctx;
2901 fctx.withBuilder(odsBuilder);
2902 fctx.addSubst("_state", builderOpState);
2903 if (body)
2904 method->body() << tgfmt(*body, &fctx);
2907 // Generate default builders that requires all result type, operands, and
2908 // attributes as parameters.
2909 if (op.skipDefaultBuilders())
2910 return;
2912 // We generate three classes of builders here:
2913 // 1. one having a stand-alone parameter for each operand / attribute, and
2914 genSeparateArgParamBuilder();
2915 // 2. one having an aggregated parameter for all result types / operands /
2916 // attributes, and
2917 genCollectiveParamBuilder();
2918 // 3. one having a stand-alone parameter for each operand and attribute,
2919 // use the first operand or attribute's type as all result types
2920 // to facilitate different call patterns.
2921 if (op.getNumVariableLengthResults() == 0) {
2922 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
2923 genUseOperandAsResultTypeSeparateParamBuilder();
2924 genUseOperandAsResultTypeCollectiveParamBuilder();
2926 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
2927 genUseAttrAsResultTypeBuilder();
2931 void OpEmitter::genCollectiveParamBuilder() {
2932 int numResults = op.getNumResults();
2933 int numVariadicResults = op.getNumVariableLengthResults();
2934 int numNonVariadicResults = numResults - numVariadicResults;
2936 int numOperands = op.getNumOperands();
2937 int numVariadicOperands = op.getNumVariableLengthOperands();
2938 int numNonVariadicOperands = numOperands - numVariadicOperands;
2940 SmallVector<MethodParameter> paramList;
2941 paramList.emplace_back("::mlir::OpBuilder &", "");
2942 paramList.emplace_back("::mlir::OperationState &", builderOpState);
2943 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
2944 paramList.emplace_back("::mlir::ValueRange", "operands");
2945 // Provide default value for `attributes` when its the last parameter
2946 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
2947 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
2948 "attributes", attributesDefaultValue);
2949 if (op.getNumVariadicRegions())
2950 paramList.emplace_back("unsigned", "numRegions");
2952 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2953 // If the builder is redundant, skip generating the method
2954 if (!m)
2955 return;
2956 auto &body = m->body();
2958 // Operands
2959 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
2960 body << " assert(operands.size()"
2961 << (numVariadicOperands != 0 ? " >= " : " == ")
2962 << numNonVariadicOperands
2963 << "u && \"mismatched number of parameters\");\n";
2964 body << " " << builderOpState << ".addOperands(operands);\n";
2966 // Attributes
2967 body << " " << builderOpState << ".addAttributes(attributes);\n";
2969 // Create the correct number of regions
2970 if (int numRegions = op.getNumRegions()) {
2971 body << llvm::formatv(
2972 " for (unsigned i = 0; i != {0}; ++i)\n",
2973 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
2974 body << " (void)" << builderOpState << ".addRegion();\n";
2977 // Result types
2978 if (numVariadicResults == 0 || numNonVariadicResults != 0)
2979 body << " assert(resultTypes.size()"
2980 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
2981 << "u && \"mismatched number of return types\");\n";
2982 body << " " << builderOpState << ".addTypes(resultTypes);\n";
2984 if (emitHelper.hasProperties()) {
2985 // Initialize the properties from Attributes before invoking the infer
2986 // function.
2987 body << formatv(R"(
2988 if (!attributes.empty()) {
2989 ::mlir::OpaqueProperties properties =
2990 &{1}.getOrAddProperties<{0}::Properties>();
2991 std::optional<::mlir::RegisteredOperationName> info =
2992 {1}.name.getRegisteredInfo();
2993 if (failed(info->setOpPropertiesFromAttribute({1}.name, properties,
2994 {1}.attributes.getDictionary({1}.getContext()), nullptr)))
2995 ::llvm::report_fatal_error("Property conversion failed.");
2996 })",
2997 opClass.getClassName(), builderOpState);
3000 // Generate builder that infers type too.
3001 // TODO: Expand to handle successors.
3002 if (canInferType(op) && op.getNumSuccessors() == 0)
3003 genInferredTypeCollectiveParamBuilder();
3006 void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
3007 llvm::StringSet<> &inferredAttributes,
3008 SmallVectorImpl<std::string> &resultTypeNames,
3009 TypeParamKind typeParamKind,
3010 AttrParamKind attrParamKind) {
3011 resultTypeNames.clear();
3012 auto numResults = op.getNumResults();
3013 resultTypeNames.reserve(numResults);
3015 paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
3016 paramList.emplace_back("::mlir::OperationState &", builderOpState);
3018 switch (typeParamKind) {
3019 case TypeParamKind::None:
3020 break;
3021 case TypeParamKind::Separate: {
3022 // Add parameters for all return types
3023 for (int i = 0; i < numResults; ++i) {
3024 const auto &result = op.getResult(i);
3025 std::string resultName = std::string(result.name);
3026 if (resultName.empty())
3027 resultName = std::string(formatv("resultType{0}", i));
3029 StringRef type =
3030 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";
3032 paramList.emplace_back(type, resultName, result.isOptional());
3033 resultTypeNames.emplace_back(std::move(resultName));
3035 } break;
3036 case TypeParamKind::Collective: {
3037 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
3038 resultTypeNames.push_back("resultTypes");
3039 } break;
3042 // Add parameters for all arguments (operands and attributes).
3043 // Track "attr-like" (property and attribute) optional values separate from
3044 // attributes themselves so that the disambiguation code can look at the first
3045 // attribute specifically when determining where to trim the optional-value
3046 // list to avoid ambiguity while preserving the ability of all-property ops to
3047 // use default parameters.
3048 int defaultValuedAttrLikeStartIndex = op.getNumArgs();
3049 int defaultValuedAttrStartIndex = op.getNumArgs();
3050 // Successors and variadic regions go at the end of the parameter list, so no
3051 // default arguments are possible.
3052 bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
3053 if (!hasTrailingParams) {
3054 // Calculate the start index from which we can attach default values in the
3055 // builder declaration.
3056 for (int i = op.getNumArgs() - 1; i >= 0; --i) {
3057 auto *namedAttr =
3058 llvm::dyn_cast_if_present<tblgen::NamedAttribute *>(op.getArg(i));
3059 auto *namedProperty =
3060 llvm::dyn_cast_if_present<tblgen::NamedProperty *>(op.getArg(i));
3061 if (namedProperty) {
3062 Property prop = namedProperty->prop;
3063 if (!prop.hasDefaultValue())
3064 break;
3065 defaultValuedAttrLikeStartIndex = i;
3066 continue;
3068 if (!namedAttr)
3069 break;
3071 Attribute attr = namedAttr->attr;
3072 // TODO: Currently we can't differentiate between optional meaning do not
3073 // verify/not always error if missing or optional meaning need not be
3074 // specified in builder. Expand isOptional once we can differentiate.
3075 if (!attr.hasDefaultValue() && !attr.isDerivedAttr())
3076 break;
3078 // Creating an APInt requires us to provide bitwidth, value, and
3079 // signedness, which is complicated compared to others. Similarly
3080 // for APFloat.
3081 // TODO: Adjust the 'returnType' field of such attributes
3082 // to support them.
3083 StringRef retType = namedAttr->attr.getReturnType();
3084 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
3085 break;
3087 defaultValuedAttrLikeStartIndex = i;
3088 defaultValuedAttrStartIndex = i;
3091 // Avoid generating build methods that are ambiguous due to default values by
3092 // requiring at least one attribute.
3093 if (defaultValuedAttrStartIndex < op.getNumArgs()) {
3094 // TODO: This should have been possible as a cast<NamedAttribute> but
3095 // required template instantiations is not yet defined for the tblgen helper
3096 // classes.
3097 auto *namedAttr =
3098 cast<NamedAttribute *>(op.getArg(defaultValuedAttrStartIndex));
3099 Attribute attr = namedAttr->attr;
3100 if ((attrParamKind == AttrParamKind::WrappedAttr &&
3101 canUseUnwrappedRawValue(attr)) ||
3102 (attrParamKind == AttrParamKind::UnwrappedValue &&
3103 !canUseUnwrappedRawValue(attr))) {
3104 ++defaultValuedAttrStartIndex;
3105 defaultValuedAttrLikeStartIndex = defaultValuedAttrStartIndex;
3109 /// Collect any inferred attributes.
3110 for (const NamedTypeConstraint &operand : op.getOperands()) {
3111 if (operand.isVariadicOfVariadic()) {
3112 inferredAttributes.insert(
3113 operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
3117 for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
3118 Argument arg = op.getArg(i);
3119 if (const auto *operand =
3120 llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
3121 StringRef type;
3122 if (operand->isVariadicOfVariadic())
3123 type = "::llvm::ArrayRef<::mlir::ValueRange>";
3124 else if (operand->isVariadic())
3125 type = "::mlir::ValueRange";
3126 else
3127 type = "::mlir::Value";
3129 paramList.emplace_back(type, getArgumentName(op, numOperands++),
3130 operand->isOptional());
3131 continue;
3133 if (auto *propArg = llvm::dyn_cast_if_present<NamedProperty *>(arg)) {
3134 const Property &prop = propArg->prop;
3135 StringRef type = prop.getInterfaceType();
3136 std::string defaultValue;
3137 if (prop.hasDefaultValue() && i >= defaultValuedAttrLikeStartIndex) {
3138 defaultValue = prop.getDefaultValue();
3140 bool isOptional = prop.hasDefaultValue();
3141 paramList.emplace_back(type, propArg->name, StringRef(defaultValue),
3142 isOptional);
3143 continue;
3145 const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
3146 const Attribute &attr = namedAttr.attr;
3148 // Inferred attributes don't need to be added to the param list.
3149 if (inferredAttributes.contains(namedAttr.name))
3150 continue;
3152 StringRef type;
3153 switch (attrParamKind) {
3154 case AttrParamKind::WrappedAttr:
3155 type = attr.getStorageType();
3156 break;
3157 case AttrParamKind::UnwrappedValue:
3158 if (canUseUnwrappedRawValue(attr))
3159 type = attr.getReturnType();
3160 else
3161 type = attr.getStorageType();
3162 break;
3165 // Attach default value if requested and possible.
3166 std::string defaultValue;
3167 if (i >= defaultValuedAttrStartIndex) {
3168 if (attrParamKind == AttrParamKind::UnwrappedValue &&
3169 canUseUnwrappedRawValue(attr))
3170 defaultValue += attr.getDefaultValue();
3171 else
3172 defaultValue += "nullptr";
3174 paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue),
3175 attr.isOptional());
3178 /// Insert parameters for each successor.
3179 for (const NamedSuccessor &succ : op.getSuccessors()) {
3180 StringRef type =
3181 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
3182 paramList.emplace_back(type, succ.name);
3185 /// Insert parameters for variadic regions.
3186 for (const NamedRegion &region : op.getRegions())
3187 if (region.isVariadic())
3188 paramList.emplace_back("unsigned",
3189 llvm::formatv("{0}Count", region.name).str());
3192 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
3193 MethodBody &body, llvm::StringSet<> &inferredAttributes,
3194 bool isRawValueAttr) {
3195 // Push all operands to the result.
3196 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
3197 std::string argName = getArgumentName(op, i);
3198 const NamedTypeConstraint &operand = op.getOperand(i);
3199 if (operand.constraint.isVariadicOfVariadic()) {
3200 body << " for (::mlir::ValueRange range : " << argName << ")\n "
3201 << builderOpState << ".addOperands(range);\n";
3203 // Add the segment attribute.
3204 body << " {\n"
3205 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
3206 << " for (::mlir::ValueRange range : " << argName << ")\n"
3207 << " rangeSegments.push_back(range.size());\n"
3208 << " auto rangeAttr = " << odsBuilder
3209 << ".getDenseI32ArrayAttr(rangeSegments);\n";
3210 if (op.getDialect().usePropertiesForAttributes()) {
3211 body << " " << builderOpState << ".getOrAddProperties<Properties>()."
3212 << operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
3213 << " = rangeAttr;";
3214 } else {
3215 body << " " << builderOpState << ".addAttribute("
3216 << op.getGetterName(
3217 operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
3218 << "AttrName(" << builderOpState << ".name), rangeAttr);";
3220 body << " }\n";
3221 continue;
3224 if (operand.isOptional())
3225 body << " if (" << argName << ")\n ";
3226 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
3229 // If the operation has the operand segment size attribute, add it here.
3230 auto emitSegment = [&]() {
3231 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
3232 const NamedTypeConstraint &operand = op.getOperand(i);
3233 if (!operand.isVariableLength()) {
3234 body << "1";
3235 return;
3238 std::string operandName = getArgumentName(op, i);
3239 if (operand.isOptional()) {
3240 body << "(" << operandName << " ? 1 : 0)";
3241 } else if (operand.isVariadicOfVariadic()) {
3242 body << llvm::formatv(
3243 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
3244 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
3245 "static_cast<int32_t>(range.size()); }))",
3246 operandName);
3247 } else {
3248 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
3252 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
3253 std::string sizes = op.getGetterName(operandSegmentAttrName);
3254 if (op.getDialect().usePropertiesForAttributes()) {
3255 body << " ::llvm::copy(::llvm::ArrayRef<int32_t>({";
3256 emitSegment();
3257 body << "}), " << builderOpState
3258 << ".getOrAddProperties<Properties>()."
3259 "operandSegmentSizes.begin());\n";
3260 } else {
3261 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
3262 << builderOpState << ".name), "
3263 << "odsBuilder.getDenseI32ArrayAttr({";
3264 emitSegment();
3265 body << "}));\n";
3269 // Push all properties to the result.
3270 for (const auto &namedProp : op.getProperties()) {
3271 // Use the setter from the Properties struct since the conversion from the
3272 // interface type (used in the builder argument) to the storage type (used
3273 // in the state) is not necessarily trivial.
3274 std::string setterName = op.getSetterName(namedProp.name);
3275 body << formatv(" {0}.getOrAddProperties<Properties>().{1}({2});\n",
3276 builderOpState, setterName, namedProp.name);
3278 // Push all attributes to the result.
3279 for (const auto &namedAttr : op.getAttributes()) {
3280 auto &attr = namedAttr.attr;
3281 if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
3282 continue;
3284 // TODO: The wrapping of optional is different for default or not, so don't
3285 // unwrap for default ones that would fail below.
3286 bool emitNotNullCheck =
3287 (attr.isOptional() && !attr.hasDefaultValue()) ||
3288 (attr.hasDefaultValue() && !isRawValueAttr) ||
3289 // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
3290 // the constant materialization is only for true case.
3291 (isRawValueAttr && attr.getAttrDefName() == "UnitAttr");
3292 if (emitNotNullCheck)
3293 body.indent() << formatv("if ({0}) ", namedAttr.name) << "{\n";
3295 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
3296 // If this is a raw value, then we need to wrap it in an Attribute
3297 // instance.
3298 FmtContext fctx;
3299 fctx.withBuilder("odsBuilder");
3300 if (op.getDialect().usePropertiesForAttributes()) {
3301 body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {2};\n",
3302 builderOpState, namedAttr.name,
3303 constBuildAttrFromParam(attr, fctx, namedAttr.name));
3304 } else {
3305 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3306 builderOpState, op.getGetterName(namedAttr.name),
3307 constBuildAttrFromParam(attr, fctx, namedAttr.name));
3309 } else {
3310 if (op.getDialect().usePropertiesForAttributes()) {
3311 body << formatv(" {0}.getOrAddProperties<Properties>().{1} = {1};\n",
3312 builderOpState, namedAttr.name);
3313 } else {
3314 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
3315 builderOpState, op.getGetterName(namedAttr.name),
3316 namedAttr.name);
3319 if (emitNotNullCheck)
3320 body.unindent() << " }\n";
3323 // Create the correct number of regions.
3324 for (const NamedRegion &region : op.getRegions()) {
3325 if (region.isVariadic())
3326 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
3327 region.name);
3329 body << " (void)" << builderOpState << ".addRegion();\n";
3332 // Push all successors to the result.
3333 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
3334 body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
3335 namedSuccessor.name);
3339 void OpEmitter::genCanonicalizerDecls() {
3340 bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
3341 if (hasCanonicalizeMethod) {
3342 // static LogicResult FooOp::
3343 // canonicalize(FooOp op, PatternRewriter &rewriter);
3344 SmallVector<MethodParameter> paramList;
3345 paramList.emplace_back(op.getCppClassName(), "op");
3346 paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
3347 auto *m = opClass.declareStaticMethod("::llvm::LogicalResult",
3348 "canonicalize", std::move(paramList));
3349 ERROR_IF_PRUNED(m, "canonicalize", op);
3352 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
3353 // or if using a 'canonicalize' method.
3354 bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
3355 if (!hasCanonicalizeMethod && !hasCanonicalizer)
3356 return;
3358 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
3359 // method, but not implementing 'getCanonicalizationPatterns' manually.
3360 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
3362 // Add a signature for getCanonicalizationPatterns if implemented by the
3363 // dialect or if synthesized to call 'canonicalize'.
3364 SmallVector<MethodParameter> paramList;
3365 paramList.emplace_back("::mlir::RewritePatternSet &", "results");
3366 paramList.emplace_back("::mlir::MLIRContext *", "context");
3367 auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
3368 auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
3369 std::move(paramList));
3371 // If synthesizing the method, fill it.
3372 if (hasBody) {
3373 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
3374 method->body() << " results.add(canonicalize);\n";
3378 void OpEmitter::genFolderDecls() {
3379 if (!op.hasFolder())
3380 return;
3382 SmallVector<MethodParameter> paramList;
3383 paramList.emplace_back("FoldAdaptor", "adaptor");
3385 StringRef retType;
3386 bool hasSingleResult =
3387 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
3388 if (hasSingleResult) {
3389 retType = "::mlir::OpFoldResult";
3390 } else {
3391 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
3392 "results");
3393 retType = "::llvm::LogicalResult";
3396 auto *m = opClass.declareMethod(retType, "fold", std::move(paramList));
3397 ERROR_IF_PRUNED(m, "fold", op);
3400 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
3401 Interface interface = opTrait->getInterface();
3403 // Get the set of methods that should always be declared.
3404 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
3405 llvm::StringSet<> alwaysDeclaredMethods;
3406 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
3407 alwaysDeclaredMethodsVec.end());
3409 for (const InterfaceMethod &method : interface.getMethods()) {
3410 // Don't declare if the method has a body.
3411 if (method.getBody())
3412 continue;
3413 // Don't declare if the method has a default implementation and the op
3414 // didn't request that it always be declared.
3415 if (method.getDefaultImplementation() &&
3416 !alwaysDeclaredMethods.count(method.getName()))
3417 continue;
3418 // Interface methods are allowed to overlap with existing methods, so don't
3419 // check if pruned.
3420 (void)genOpInterfaceMethod(method);
3424 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
3425 bool declaration) {
3426 SmallVector<MethodParameter> paramList;
3427 for (const InterfaceMethod::Argument &arg : method.getArguments())
3428 paramList.emplace_back(arg.type, arg.name);
3430 auto props = (method.isStatic() ? Method::Static : Method::None) |
3431 (declaration ? Method::Declaration : Method::None);
3432 return opClass.addMethod(method.getReturnType(), method.getName(), props,
3433 std::move(paramList));
3436 void OpEmitter::genOpInterfaceMethods() {
3437 for (const auto &trait : op.getTraits()) {
3438 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
3439 if (opTrait->shouldDeclareMethods())
3440 genOpInterfaceMethods(opTrait);
3444 void OpEmitter::genSideEffectInterfaceMethods() {
3445 enum EffectKind { Operand, Result, Symbol, Static };
3446 struct EffectLocation {
3447 /// The effect applied.
3448 SideEffect effect;
3450 /// The index if the kind is not static.
3451 unsigned index;
3453 /// The kind of the location.
3454 unsigned kind;
3457 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
3458 auto resolveDecorators = [&](Operator::var_decorator_range decorators,
3459 unsigned index, unsigned kind) {
3460 for (auto decorator : decorators)
3461 if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
3462 opClass.addTrait(effect->getInterfaceTrait());
3463 interfaceEffects[effect->getBaseEffectName()].push_back(
3464 EffectLocation{*effect, index, kind});
3468 // Collect effects that were specified via:
3469 /// Traits.
3470 for (const auto &trait : op.getTraits()) {
3471 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
3472 if (!opTrait)
3473 continue;
3474 auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
3475 for (auto decorator : opTrait->getEffects())
3476 effects.push_back(EffectLocation{cast<SideEffect>(decorator),
3477 /*index=*/0, EffectKind::Static});
3479 /// Attributes and Operands.
3480 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
3481 Argument arg = op.getArg(i);
3482 if (arg.is<NamedTypeConstraint *>()) {
3483 resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
3484 ++operandIt;
3485 continue;
3487 if (arg.is<NamedProperty *>())
3488 continue;
3489 const NamedAttribute *attr = arg.get<NamedAttribute *>();
3490 if (attr->attr.getBaseAttr().isSymbolRefAttr())
3491 resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
3493 /// Results.
3494 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
3495 resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
3497 // The code used to add an effect instance.
3498 // {0}: The effect class.
3499 // {1}: Optional value or symbol reference.
3500 // {2}: The side effect stage.
3501 // {3}: Does this side effect act on every single value of resource.
3502 // {4}: The resource class.
3503 const char *addEffectCode =
3504 " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
3506 for (auto &it : interfaceEffects) {
3507 // Generate the 'getEffects' method.
3508 std::string type = llvm::formatv("::llvm::SmallVectorImpl<::mlir::"
3509 "SideEffects::EffectInstance<{0}>> &",
3510 it.first())
3511 .str();
3512 auto *getEffects = opClass.addMethod("void", "getEffects",
3513 MethodParameter(type, "effects"));
3514 ERROR_IF_PRUNED(getEffects, "getEffects", op);
3515 auto &body = getEffects->body();
3517 // Add effect instances for each of the locations marked on the operation.
3518 for (auto &location : it.second) {
3519 StringRef effect = location.effect.getName();
3520 StringRef resource = location.effect.getResource();
3521 int stage = (int)location.effect.getStage();
3522 bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion();
3523 if (location.kind == EffectKind::Static) {
3524 // A static instance has no attached value.
3525 body << llvm::formatv(addEffectCode, effect, "", stage,
3526 effectOnFullRegion, resource)
3527 .str();
3528 } else if (location.kind == EffectKind::Symbol) {
3529 // A symbol reference requires adding the proper attribute.
3530 const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
3531 std::string argName = op.getGetterName(attr->name);
3532 if (attr->attr.isOptional()) {
3533 body << " if (auto symbolRef = " << argName << "Attr())\n "
3534 << llvm::formatv(addEffectCode, effect, "symbolRef, ", stage,
3535 effectOnFullRegion, resource)
3536 .str();
3537 } else {
3538 body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
3539 stage, effectOnFullRegion, resource)
3540 .str();
3542 } else {
3543 // Otherwise this is an operand/result, so we need to attach the Value.
3544 body << " {\n auto valueRange = getODS"
3545 << (location.kind == EffectKind::Operand ? "Operand" : "Result")
3546 << "IndexAndLength(" << location.index << ");\n"
3547 << " for (unsigned idx = valueRange.first; idx < "
3548 "valueRange.first"
3549 << " + valueRange.second; idx++) {\n "
3550 << llvm::formatv(addEffectCode, effect,
3551 (location.kind == EffectKind::Operand
3552 ? "&getOperation()->getOpOperand(idx), "
3553 : "getOperation()->getOpResult(idx), "),
3554 stage, effectOnFullRegion, resource)
3555 << " }\n }\n";
3561 void OpEmitter::genTypeInterfaceMethods() {
3562 if (!op.allResultTypesKnown())
3563 return;
3564 // Generate 'inferReturnTypes' method declaration using the interface method
3565 // declared in 'InferTypeOpInterface' op interface.
3566 const auto *trait =
3567 cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait"));
3568 Interface interface = trait->getInterface();
3569 Method *method = [&]() -> Method * {
3570 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
3571 if (interfaceMethod.getName() == "inferReturnTypes") {
3572 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
3575 assert(0 && "unable to find inferReturnTypes interface method");
3576 return nullptr;
3577 }();
3578 ERROR_IF_PRUNED(method, "inferReturnTypes", op);
3579 auto &body = method->body();
3580 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
3582 FmtContext fctx;
3583 fctx.withBuilder("odsBuilder");
3584 fctx.addSubst("_ctxt", "context");
3585 body << " ::mlir::Builder odsBuilder(context);\n";
3587 // Preprocessing stage to verify all accesses to operands are valid.
3588 int maxAccessedIndex = -1;
3589 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3590 const InferredResultType &infer = op.getInferredResultType(i);
3591 if (!infer.isArg())
3592 continue;
3593 Operator::OperandOrAttribute arg =
3594 op.getArgToOperandOrAttribute(infer.getIndex());
3595 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3596 maxAccessedIndex =
3597 std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
3600 if (maxAccessedIndex != -1) {
3601 body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
3602 body << " return ::mlir::failure();\n";
3605 // Process the type inference graph in topological order, starting from types
3606 // that are always fully-inferred: operands and results with constructible
3607 // types. The type inference graph here will always be a DAG, so this gives
3608 // us the correct order for generating the types. -1 is a placeholder to
3609 // indicate the type for a result has not been generated.
3610 SmallVector<int> constructedIndices(op.getNumResults(), -1);
3611 int inferredTypeIdx = 0;
3612 for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
3613 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3614 if (constructedIndices[i] >= 0)
3615 continue;
3616 const InferredResultType &infer = op.getInferredResultType(i);
3617 std::string typeStr;
3618 if (infer.isArg()) {
3619 // If this is an operand, just index into operand list to access the
3620 // type.
3621 Operator::OperandOrAttribute arg =
3622 op.getArgToOperandOrAttribute(infer.getIndex());
3623 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3624 typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
3625 "].getType()")
3626 .str();
3628 // If this is an attribute, index into the attribute dictionary.
3629 } else {
3630 auto *attr =
3631 op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
3632 body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx
3633 << " = ";
3634 if (op.getDialect().usePropertiesForAttributes()) {
3635 body << "(properties ? properties.as<Properties *>()->"
3636 << attr->name
3637 << " : "
3638 "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3639 "get(\"" +
3640 attr->name + "\")));\n";
3641 } else {
3642 body << "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes."
3643 "get(\"" +
3644 attr->name + "\"));\n";
3646 body << " if (!odsInferredTypeAttr" << inferredTypeIdx
3647 << ") return ::mlir::failure();\n";
3648 typeStr =
3649 ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()")
3650 .str();
3652 } else if (std::optional<StringRef> builder =
3653 op.getResult(infer.getResultIndex())
3654 .constraint.getBuilderCall()) {
3655 typeStr = tgfmt(*builder, &fctx).str();
3656 } else if (int index = constructedIndices[infer.getResultIndex()];
3657 index >= 0) {
3658 typeStr = ("odsInferredType" + Twine(index)).str();
3659 } else {
3660 continue;
3662 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
3663 << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
3664 constructedIndices[i] = inferredTypeIdx - 1;
3667 for (auto [i, index] : llvm::enumerate(constructedIndices))
3668 body << " inferredReturnTypes[" << i << "] = odsInferredType" << index
3669 << ";\n";
3670 body << " return ::mlir::success();";
3673 void OpEmitter::genParser() {
3674 if (hasStringAttribute(def, "assemblyFormat"))
3675 return;
3677 if (!def.getValueAsBit("hasCustomAssemblyFormat"))
3678 return;
3680 SmallVector<MethodParameter> paramList;
3681 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
3682 paramList.emplace_back("::mlir::OperationState &", "result");
3684 auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse",
3685 std::move(paramList));
3686 ERROR_IF_PRUNED(method, "parse", op);
3689 void OpEmitter::genPrinter() {
3690 if (hasStringAttribute(def, "assemblyFormat"))
3691 return;
3693 // Check to see if this op uses a c++ format.
3694 if (!def.getValueAsBit("hasCustomAssemblyFormat"))
3695 return;
3696 auto *method = opClass.declareMethod(
3697 "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
3698 ERROR_IF_PRUNED(method, "print", op);
3701 void OpEmitter::genVerifier() {
3702 auto *implMethod =
3703 opClass.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl");
3704 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
3705 auto &implBody = implMethod->body();
3706 bool useProperties = emitHelper.hasProperties();
3708 populateSubstitutions(emitHelper, verifyCtx);
3709 genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter,
3710 useProperties);
3711 genOperandResultVerifier(implBody, op.getOperands(), "operand");
3712 genOperandResultVerifier(implBody, op.getResults(), "result");
3714 for (auto &trait : op.getTraits()) {
3715 if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
3716 implBody << tgfmt(" if (!($0))\n "
3717 "return emitOpError(\"failed to verify that $1\");\n",
3718 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
3719 t->getSummary());
3723 genRegionVerifier(implBody);
3724 genSuccessorVerifier(implBody);
3726 implBody << " return ::mlir::success();\n";
3728 // TODO: Some places use the `verifyInvariants` to do operation verification.
3729 // This may not act as their expectation because this doesn't call any
3730 // verifiers of native/interface traits. Needs to review those use cases and
3731 // see if we should use the mlir::verify() instead.
3732 auto *method = opClass.addMethod("::llvm::LogicalResult", "verifyInvariants");
3733 ERROR_IF_PRUNED(method, "verifyInvariants", op);
3734 auto &body = method->body();
3735 if (def.getValueAsBit("hasVerifier")) {
3736 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && "
3737 "::mlir::succeeded(verify()))\n";
3738 body << " return ::mlir::success();\n";
3739 body << " return ::mlir::failure();";
3740 } else {
3741 body << " return verifyInvariantsImpl();";
3745 void OpEmitter::genCustomVerifier() {
3746 if (def.getValueAsBit("hasVerifier")) {
3747 auto *method = opClass.declareMethod("::llvm::LogicalResult", "verify");
3748 ERROR_IF_PRUNED(method, "verify", op);
3751 if (def.getValueAsBit("hasRegionVerifier")) {
3752 auto *method =
3753 opClass.declareMethod("::llvm::LogicalResult", "verifyRegions");
3754 ERROR_IF_PRUNED(method, "verifyRegions", op);
3758 void OpEmitter::genOperandResultVerifier(MethodBody &body,
3759 Operator::const_value_range values,
3760 StringRef valueKind) {
3761 // Check that an optional value is at most 1 element.
3763 // {0}: Value index.
3764 // {1}: "operand" or "result"
3765 const char *const verifyOptional = R"(
3766 if (valueGroup{0}.size() > 1) {
3767 return emitOpError("{1} group starting at #") << index
3768 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
3771 // Check the types of a range of values.
3773 // {0}: Value index.
3774 // {1}: Type constraint function.
3775 // {2}: "operand" or "result"
3776 const char *const verifyValues = R"(
3777 for (auto v : valueGroup{0}) {
3778 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
3779 return ::mlir::failure();
3783 const auto canSkip = [](const NamedTypeConstraint &value) {
3784 return !value.hasPredicate() && !value.isOptional() &&
3785 !value.isVariadicOfVariadic();
3787 if (values.empty() || llvm::all_of(values, canSkip))
3788 return;
3790 FmtContext fctx;
3792 body << " {\n unsigned index = 0; (void)index;\n";
3794 for (const auto &staticValue : llvm::enumerate(values)) {
3795 const NamedTypeConstraint &value = staticValue.value();
3797 bool hasPredicate = value.hasPredicate();
3798 bool isOptional = value.isOptional();
3799 bool isVariadicOfVariadic = value.isVariadicOfVariadic();
3800 if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
3801 continue;
3802 body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
3803 // Capitalize the first letter to match the function name
3804 valueKind.substr(0, 1).upper(), valueKind.substr(1),
3805 staticValue.index());
3807 // If the constraint is optional check that the value group has at most 1
3808 // value.
3809 if (isOptional) {
3810 body << formatv(verifyOptional, staticValue.index(), valueKind);
3811 } else if (isVariadicOfVariadic) {
3812 body << formatv(
3813 " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
3814 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
3815 " return ::mlir::failure();\n",
3816 value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
3817 staticValue.index());
3820 // Otherwise, if there is no predicate there is nothing left to do.
3821 if (!hasPredicate)
3822 continue;
3823 // Emit a loop to check all the dynamic values in the pack.
3824 StringRef constraintFn =
3825 staticVerifierEmitter.getTypeConstraintFn(value.constraint);
3826 body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
3829 body << " }\n";
3832 void OpEmitter::genRegionVerifier(MethodBody &body) {
3833 /// Code to verify a region.
3835 /// {0}: Getter for the regions.
3836 /// {1}: The region constraint.
3837 /// {2}: The region's name.
3838 /// {3}: The region description.
3839 const char *const verifyRegion = R"(
3840 for (auto &region : {0})
3841 if (::mlir::failed({1}(*this, region, "{2}", index++)))
3842 return ::mlir::failure();
3844 /// Get a single region.
3846 /// {0}: The region's index.
3847 const char *const getSingleRegion =
3848 "::llvm::MutableArrayRef((*this)->getRegion({0}))";
3850 // If we have no regions, there is nothing more to do.
3851 const auto canSkip = [](const NamedRegion &region) {
3852 return region.constraint.getPredicate().isNull();
3854 auto regions = op.getRegions();
3855 if (regions.empty() && llvm::all_of(regions, canSkip))
3856 return;
3858 body << " {\n unsigned index = 0; (void)index;\n";
3859 for (const auto &it : llvm::enumerate(regions)) {
3860 const auto &region = it.value();
3861 if (canSkip(region))
3862 continue;
3864 auto getRegion = region.isVariadic()
3865 ? formatv("{0}()", op.getGetterName(region.name)).str()
3866 : formatv(getSingleRegion, it.index()).str();
3867 auto constraintFn =
3868 staticVerifierEmitter.getRegionConstraintFn(region.constraint);
3869 body << formatv(verifyRegion, getRegion, constraintFn, region.name);
3871 body << " }\n";
3874 void OpEmitter::genSuccessorVerifier(MethodBody &body) {
3875 const char *const verifySuccessor = R"(
3876 for (auto *successor : {0})
3877 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
3878 return ::mlir::failure();
3880 /// Get a single successor.
3882 /// {0}: The successor's name.
3883 const char *const getSingleSuccessor = "::llvm::MutableArrayRef({0}())";
3885 // If we have no successors, there is nothing more to do.
3886 const auto canSkip = [](const NamedSuccessor &successor) {
3887 return successor.constraint.getPredicate().isNull();
3889 auto successors = op.getSuccessors();
3890 if (successors.empty() && llvm::all_of(successors, canSkip))
3891 return;
3893 body << " {\n unsigned index = 0; (void)index;\n";
3895 for (auto it : llvm::enumerate(successors)) {
3896 const auto &successor = it.value();
3897 if (canSkip(successor))
3898 continue;
3900 auto getSuccessor =
3901 formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
3902 successor.name)
3903 .str();
3904 auto constraintFn =
3905 staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
3906 body << formatv(verifySuccessor, getSuccessor, constraintFn,
3907 successor.name);
3909 body << " }\n";
3912 /// Add a size count trait to the given operation class.
3913 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
3914 int numTotal, int numVariadic) {
3915 if (numVariadic != 0) {
3916 if (numTotal == numVariadic)
3917 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
3918 else
3919 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
3920 Twine(numTotal - numVariadic) + ">::Impl");
3921 return;
3923 switch (numTotal) {
3924 case 0:
3925 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind + "s");
3926 break;
3927 case 1:
3928 opClass.addTrait("::mlir::OpTrait::One" + traitKind);
3929 break;
3930 default:
3931 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
3932 ">::Impl");
3933 break;
3937 void OpEmitter::genTraits() {
3938 // Add region size trait.
3939 unsigned numRegions = op.getNumRegions();
3940 unsigned numVariadicRegions = op.getNumVariadicRegions();
3941 addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);
3943 // Add result size traits.
3944 int numResults = op.getNumResults();
3945 int numVariadicResults = op.getNumVariableLengthResults();
3946 addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);
3948 // For single result ops with a known specific type, generate a OneTypedResult
3949 // trait.
3950 if (numResults == 1 && numVariadicResults == 0) {
3951 auto cppName = op.getResults().begin()->constraint.getCppType();
3952 opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
3955 // Add successor size trait.
3956 unsigned numSuccessors = op.getNumSuccessors();
3957 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
3958 addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);
3960 // Add variadic size trait and normal op traits.
3961 int numOperands = op.getNumOperands();
3962 int numVariadicOperands = op.getNumVariableLengthOperands();
3964 // Add operand size trait.
3965 addSizeCountTrait(opClass, "Operand", numOperands, numVariadicOperands);
3967 // The op traits defined internal are ensured that they can be verified
3968 // earlier.
3969 for (const auto &trait : op.getTraits()) {
3970 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
3971 if (opTrait->isStructuralOpTrait())
3972 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
3976 // OpInvariants wrapps the verifyInvariants which needs to be run before
3977 // native/interface traits and after all the traits with `StructuralOpTrait`.
3978 opClass.addTrait("::mlir::OpTrait::OpInvariants");
3980 if (emitHelper.hasProperties())
3981 opClass.addTrait("::mlir::BytecodeOpInterface::Trait");
3983 // Add the native and interface traits.
3984 for (const auto &trait : op.getTraits()) {
3985 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
3986 if (!opTrait->isStructuralOpTrait())
3987 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
3988 } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
3989 opClass.addTrait(opTrait->getFullyQualifiedTraitName());
3994 void OpEmitter::genOpNameGetter() {
3995 auto *method = opClass.addStaticMethod<Method::Constexpr>(
3996 "::llvm::StringLiteral", "getOperationName");
3997 ERROR_IF_PRUNED(method, "getOperationName", op);
3998 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
3999 << "\");";
4002 void OpEmitter::genOpAsmInterface() {
4003 // If the user only has one results or specifically added the Asm trait,
4004 // then don't generate it for them. We specifically only handle multi result
4005 // operations, because the name of a single result in the common case is not
4006 // interesting(generally 'result'/'output'/etc.).
4007 // TODO: We could also add a flag to allow operations to opt in to this
4008 // generation, even if they only have a single operation.
4009 int numResults = op.getNumResults();
4010 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
4011 return;
4013 SmallVector<StringRef, 4> resultNames(numResults);
4014 for (int i = 0; i != numResults; ++i)
4015 resultNames[i] = op.getResultName(i);
4017 // Don't add the trait if none of the results have a valid name.
4018 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
4019 return;
4020 opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
4022 // Generate the right accessor for the number of results.
4023 auto *method = opClass.addMethod(
4024 "void", "getAsmResultNames",
4025 MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
4026 ERROR_IF_PRUNED(method, "getAsmResultNames", op);
4027 auto &body = method->body();
4028 for (int i = 0; i != numResults; ++i) {
4029 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
4030 << " if (!resultGroup" << i << ".empty())\n"
4031 << " setNameFn(*resultGroup" << i << ".begin(), \""
4032 << resultNames[i] << "\");\n";
4036 //===----------------------------------------------------------------------===//
4037 // OpOperandAdaptor emitter
4038 //===----------------------------------------------------------------------===//
4040 namespace {
4041 // Helper class to emit Op operand adaptors to an output stream. Operand
4042 // adaptors are wrappers around random access ranges that provide named operand
4043 // getters identical to those defined in the Op.
4044 // This currently generates 3 classes per Op:
4045 // * A Base class within the 'detail' namespace, which contains all logic and
4046 // members independent of the random access range that is indexed into.
4047 // In other words, it contains all the attribute and region getters.
4048 // * A templated class named '{OpName}GenericAdaptor' with a template parameter
4049 // 'RangeT' that is indexed into by the getters to access the operands.
4050 // It contains all getters to access operands and inherits from the previous
4051 // class.
4052 // * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor'
4053 // with 'mlir::ValueRange' as template parameter. It adds a constructor from
4054 // an instance of the op type and a verify function.
4055 class OpOperandAdaptorEmitter {
4056 public:
4057 static void
4058 emitDecl(const Operator &op,
4059 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4060 raw_ostream &os);
4061 static void
4062 emitDef(const Operator &op,
4063 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4064 raw_ostream &os);
4066 private:
4067 explicit OpOperandAdaptorEmitter(
4068 const Operator &op,
4069 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
4071 // Add verification function. This generates a verify method for the adaptor
4072 // which verifies all the op-independent attribute constraints.
4073 void addVerification();
4075 // The operation for which to emit an adaptor.
4076 const Operator &op;
4078 // The generated adaptor classes.
4079 Class genericAdaptorBase;
4080 Class genericAdaptor;
4081 Class adaptor;
4083 // The emitter containing all of the locally emitted verification functions.
4084 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
4086 // Helper for emitting adaptor code.
4087 OpOrAdaptorHelper emitHelper;
4089 } // namespace
4091 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
4092 const Operator &op,
4093 const StaticVerifierFunctionEmitter &staticVerifierEmitter)
4094 : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
4095 genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
4096 staticVerifierEmitter(staticVerifierEmitter),
4097 emitHelper(op, /*emitForOp=*/false) {
4099 genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Public);
4100 bool useProperties = emitHelper.hasProperties();
4101 if (useProperties) {
4102 // Define the properties struct with multiple members.
4103 using ConstArgument =
4104 llvm::PointerUnion<const AttributeMetadata *, const NamedProperty *>;
4105 SmallVector<ConstArgument> attrOrProperties;
4106 for (const std::pair<StringRef, AttributeMetadata> &it :
4107 emitHelper.getAttrMetadata()) {
4108 if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
4109 attrOrProperties.push_back(&it.second);
4111 for (const NamedProperty &prop : op.getProperties())
4112 attrOrProperties.push_back(&prop);
4113 if (emitHelper.getOperandSegmentsSize())
4114 attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
4115 if (emitHelper.getResultSegmentsSize())
4116 attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
4117 assert(!attrOrProperties.empty());
4118 std::string declarations = " struct Properties {\n";
4119 llvm::raw_string_ostream os(declarations);
4120 std::string comparator =
4121 " bool operator==(const Properties &rhs) const {\n"
4122 " return \n";
4123 llvm::raw_string_ostream comparatorOs(comparator);
4124 for (const auto &attrOrProp : attrOrProperties) {
4125 if (const auto *namedProperty =
4126 llvm::dyn_cast_if_present<const NamedProperty *>(attrOrProp)) {
4127 StringRef name = namedProperty->name;
4128 if (name.empty())
4129 report_fatal_error("missing name for property");
4130 std::string camelName =
4131 convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
4132 auto &prop = namedProperty->prop;
4133 // Generate the data member using the storage type.
4134 os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
4135 << " " << name << "Ty " << name;
4136 if (prop.hasStorageTypeValueOverride())
4137 os << " = " << prop.getStorageTypeValueOverride();
4138 else if (prop.hasDefaultValue())
4139 os << " = " << prop.getDefaultValue();
4140 comparatorOs << " rhs." << name << " == this->" << name
4141 << " &&\n";
4142 // Emit accessors using the interface type.
4143 const char *accessorFmt = R"decl(;
4144 {0} get{1}() const {
4145 auto &propStorage = this->{2};
4146 return {3};
4148 void set{1}({0} propValue) {
4149 auto &propStorage = this->{2};
4150 {4};
4152 )decl";
4153 FmtContext fctx;
4154 os << formatv(accessorFmt, prop.getInterfaceType(), camelName, name,
4155 tgfmt(prop.getConvertFromStorageCall(),
4156 &fctx.addSubst("_storage", propertyStorage)),
4157 tgfmt(prop.getAssignToStorageCall(),
4158 &fctx.addSubst("_value", propertyValue)
4159 .addSubst("_storage", propertyStorage)));
4160 continue;
4162 const auto *namedAttr =
4163 llvm::dyn_cast_if_present<const AttributeMetadata *>(attrOrProp);
4164 const Attribute *attr = nullptr;
4165 if (namedAttr->constraint)
4166 attr = &*namedAttr->constraint;
4167 StringRef name = namedAttr->attrName;
4168 if (name.empty())
4169 report_fatal_error("missing name for property attr");
4170 std::string camelName =
4171 convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
4172 // Generate the data member using the storage type.
4173 StringRef storageType;
4174 if (attr) {
4175 storageType = attr->getStorageType();
4176 } else {
4177 if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
4178 report_fatal_error("unexpected AttributeMetadata");
4180 // TODO: update to use native integers.
4181 storageType = "::mlir::DenseI32ArrayAttr";
4183 os << " using " << name << "Ty = " << storageType << ";\n"
4184 << " " << name << "Ty " << name << ";\n";
4185 comparatorOs << " rhs." << name << " == this->" << name << " &&\n";
4187 // Emit accessors using the interface type.
4188 if (attr) {
4189 const char *accessorFmt = R"decl(
4190 auto get{0}() {
4191 auto &propStorage = this->{1};
4192 return ::llvm::{2}<{3}>(propStorage);
4194 void set{0}(const {3} &propValue) {
4195 this->{1} = propValue;
4197 )decl";
4198 os << formatv(accessorFmt, camelName, name,
4199 attr->isOptional() || attr->hasDefaultValue()
4200 ? "dyn_cast_or_null"
4201 : "cast",
4202 storageType);
4205 comparatorOs << " true;\n }\n"
4206 " bool operator!=(const Properties &rhs) const {\n"
4207 " return !(*this == rhs);\n"
4208 " }\n";
4209 os << comparator;
4210 os << " };\n";
4212 genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations));
4214 genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected);
4215 genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs");
4216 genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>",
4217 "odsOpName");
4218 if (useProperties)
4219 genericAdaptorBase.declare<Field>("Properties", "properties");
4220 genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions");
4222 genericAdaptor.addTemplateParam("RangeT");
4223 genericAdaptor.addField("RangeT", "odsOperands");
4224 genericAdaptor.addParent(
4225 ParentClass("detail::" + genericAdaptorBase.getClassName()));
4226 genericAdaptor.declare<UsingDeclaration>(
4227 "ValueT", "::llvm::detail::ValueOfRange<RangeT>");
4228 genericAdaptor.declare<UsingDeclaration>(
4229 "Base", "detail::" + genericAdaptorBase.getClassName());
4231 const auto *attrSizedOperands =
4232 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
4234 SmallVector<MethodParameter> paramList;
4235 if (useProperties) {
4236 // Properties can't be given a default constructor here due to Properties
4237 // struct being defined in the enclosing class which isn't complete by
4238 // here.
4239 paramList.emplace_back("::mlir::DictionaryAttr", "attrs");
4240 paramList.emplace_back("const Properties &", "properties");
4241 } else {
4242 paramList.emplace_back("::mlir::DictionaryAttr", "attrs", "{}");
4243 paramList.emplace_back("const ::mlir::EmptyProperties &", "properties",
4244 "{}");
4246 paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
4247 auto *baseConstructor =
4248 genericAdaptorBase.addConstructor<Method::Inline>(paramList);
4249 baseConstructor->addMemberInitializer("odsAttrs", "attrs");
4250 if (useProperties)
4251 baseConstructor->addMemberInitializer("properties", "properties");
4252 baseConstructor->addMemberInitializer("odsRegions", "regions");
4254 MethodBody &body = baseConstructor->body();
4255 body.indent() << "if (odsAttrs)\n";
4256 body.indent() << formatv(
4257 "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
4258 op.getOperationName());
4260 paramList.insert(paramList.begin(), MethodParameter("RangeT", "values"));
4261 auto *constructor = genericAdaptor.addConstructor(paramList);
4262 constructor->addMemberInitializer("Base", "attrs, properties, regions");
4263 constructor->addMemberInitializer("odsOperands", "values");
4265 // Add a forwarding constructor to the previous one that accepts
4266 // OpaqueProperties instead and check for null and perform the cast to the
4267 // actual properties type.
4268 paramList[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
4269 paramList[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
4270 auto *opaquePropertiesConstructor =
4271 genericAdaptor.addConstructor(std::move(paramList));
4272 if (useProperties) {
4273 opaquePropertiesConstructor->addMemberInitializer(
4274 genericAdaptor.getClassName(),
4275 "values, "
4276 "attrs, "
4277 "(properties ? *properties.as<Properties *>() : Properties{}), "
4278 "regions");
4279 } else {
4280 opaquePropertiesConstructor->addMemberInitializer(
4281 genericAdaptor.getClassName(),
4282 "values, "
4283 "attrs, "
4284 "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
4285 "::mlir::EmptyProperties{}), "
4286 "regions");
4289 // Add forwarding constructor that constructs Properties.
4290 if (useProperties) {
4291 SmallVector<MethodParameter> paramList;
4292 paramList.emplace_back("RangeT", "values");
4293 paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
4294 attrSizedOperands ? "" : "nullptr");
4295 auto *noPropertiesConstructor =
4296 genericAdaptor.addConstructor(std::move(paramList));
4297 noPropertiesConstructor->addMemberInitializer(
4298 genericAdaptor.getClassName(), "values, "
4299 "attrs, "
4300 "Properties{}, "
4301 "{}");
4305 // Create a constructor that creates a new generic adaptor by copying
4306 // everything from another adaptor, except for the values.
4308 SmallVector<MethodParameter> paramList;
4309 paramList.emplace_back("RangeT", "values");
4310 paramList.emplace_back("const " + op.getGenericAdaptorName() + "Base &",
4311 "base");
4312 auto *constructor =
4313 genericAdaptor.addConstructor<Method::Inline>(paramList);
4314 constructor->addMemberInitializer("Base", "base");
4315 constructor->addMemberInitializer("odsOperands", "values");
4318 // Create constructors constructing the adaptor from an instance of the op.
4319 // This takes the attributes, properties and regions from the op instance
4320 // and the value range from the parameter.
4322 // Base class is in the cpp file and can simply access the members of the op
4323 // class to initialize the template independent fields. If the op doesn't
4324 // have properties, we can emit a generic constructor inline. Otherwise,
4325 // emit it out-of-line because we need the op to be defined.
4326 Constructor *constructor;
4327 if (useProperties) {
4328 constructor = genericAdaptorBase.addConstructor(
4329 MethodParameter(op.getCppClassName(), "op"));
4330 } else {
4331 constructor = genericAdaptorBase.addConstructor<Method::Inline>(
4332 MethodParameter("::mlir::Operation *", "op"));
4334 constructor->addMemberInitializer("odsAttrs",
4335 "op->getRawDictionaryAttrs()");
4336 // Retrieve the operation name from the op directly.
4337 constructor->addMemberInitializer("odsOpName", "op->getName()");
4338 if (useProperties)
4339 constructor->addMemberInitializer("properties", "op.getProperties()");
4340 constructor->addMemberInitializer("odsRegions", "op->getRegions()");
4342 // Generic adaptor is templated and therefore defined inline in the header.
4343 // We cannot use the Op class here as it is an incomplete type (we have a
4344 // circular reference between the two).
4345 // Use a template trick to make the constructor be instantiated at call site
4346 // when the op class is complete.
4347 constructor = genericAdaptor.addConstructor(
4348 MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
4349 constructor->addTemplateParam("LateInst = " + op.getCppClassName());
4350 constructor->addTemplateParam(
4351 "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
4352 ">>");
4353 constructor->addMemberInitializer("Base", "op");
4354 constructor->addMemberInitializer("odsOperands", "values");
4357 std::string sizeAttrInit;
4358 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
4359 if (op.getDialect().usePropertiesForAttributes())
4360 sizeAttrInit =
4361 formatv(adapterSegmentSizeAttrInitCodeProperties,
4362 llvm::formatv("getProperties().operandSegmentSizes"));
4363 else
4364 sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
4365 emitHelper.getAttr(operandSegmentAttrName));
4367 generateNamedOperandGetters(op, genericAdaptor,
4368 /*genericAdaptorBase=*/&genericAdaptorBase,
4369 /*sizeAttrInit=*/sizeAttrInit,
4370 /*rangeType=*/"RangeT",
4371 /*rangeElementType=*/"ValueT",
4372 /*rangeBeginCall=*/"odsOperands.begin()",
4373 /*rangeSizeCall=*/"odsOperands.size()",
4374 /*getOperandCallPattern=*/"odsOperands[{0}]");
4376 // Any invalid overlap for `getOperands` will have been diagnosed before
4377 // here already.
4378 if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
4379 m->body() << " return odsOperands;";
4381 FmtContext fctx;
4382 fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
4384 // Generate named accessor with Attribute return type.
4385 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
4386 Attribute attr) {
4387 // The method body is trivial if the attribute does not have a default
4388 // value, in which case the default value may be arbitrary code.
4389 auto *method = genericAdaptorBase.addMethod(
4390 attr.getStorageType(), emitName + "Attr",
4391 attr.hasDefaultValue() || !useProperties ? Method::Properties::None
4392 : Method::Properties::Inline);
4393 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
4394 auto &body = method->body().indent();
4395 if (!useProperties)
4396 body << "assert(odsAttrs && \"no attributes when constructing "
4397 "adapter\");\n";
4398 body << formatv(
4399 "auto attr = ::llvm::{1}<{2}>({0});\n", emitHelper.getAttr(name),
4400 attr.hasDefaultValue() || attr.isOptional() ? "dyn_cast_or_null"
4401 : "cast",
4402 attr.getStorageType());
4404 if (attr.hasDefaultValue() && attr.isOptional()) {
4405 // Use the default value if attribute is not set.
4406 // TODO: this is inefficient, we are recreating the attribute for every
4407 // call. This should be set instead.
4408 std::string defaultValue = std::string(
4409 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
4410 body << "if (!attr)\n attr = " << defaultValue << ";\n";
4412 body << "return attr;\n";
4415 if (useProperties) {
4416 auto *m = genericAdaptorBase.addInlineMethod("const Properties &",
4417 "getProperties");
4418 ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
4419 m->body() << " return properties;";
4422 auto *m = genericAdaptorBase.addInlineMethod("::mlir::DictionaryAttr",
4423 "getAttributes");
4424 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
4425 m->body() << " return odsAttrs;";
4427 for (auto &namedProp : op.getProperties()) {
4428 std::string name = op.getGetterName(namedProp.name);
4429 emitPropGetter(genericAdaptorBase, op, name, namedProp.prop);
4432 for (auto &namedAttr : op.getAttributes()) {
4433 const auto &name = namedAttr.name;
4434 const auto &attr = namedAttr.attr;
4435 if (attr.isDerivedAttr())
4436 continue;
4437 std::string emitName = op.getGetterName(name);
4438 emitAttrWithStorageType(name, emitName, attr);
4439 emitAttrGetterWithReturnType(fctx, genericAdaptorBase, op, emitName, attr);
4442 unsigned numRegions = op.getNumRegions();
4443 for (unsigned i = 0; i < numRegions; ++i) {
4444 const auto &region = op.getRegion(i);
4445 if (region.name.empty())
4446 continue;
4448 // Generate the accessors for a variadic region.
4449 std::string name = op.getGetterName(region.name);
4450 if (region.isVariadic()) {
4451 auto *m = genericAdaptorBase.addInlineMethod("::mlir::RegionRange", name);
4452 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
4453 m->body() << formatv(" return odsRegions.drop_front({0});", i);
4454 continue;
4457 auto *m = genericAdaptorBase.addInlineMethod("::mlir::Region &", name);
4458 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
4459 m->body() << formatv(" return *odsRegions[{0}];", i);
4461 if (numRegions > 0) {
4462 // Any invalid overlap for `getRegions` will have been diagnosed before
4463 // here already.
4464 if (auto *m = genericAdaptorBase.addInlineMethod("::mlir::RegionRange",
4465 "getRegions"))
4466 m->body() << " return odsRegions;";
4469 StringRef genericAdaptorClassName = genericAdaptor.getClassName();
4470 adaptor.addParent(ParentClass(genericAdaptorClassName))
4471 .addTemplateParam("::mlir::ValueRange");
4472 adaptor.declare<VisibilityDeclaration>(Visibility::Public);
4473 adaptor.declare<UsingDeclaration>(genericAdaptorClassName +
4474 "::" + genericAdaptorClassName);
4476 // Constructor taking the Op as single parameter.
4477 auto *constructor =
4478 adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
4479 constructor->addMemberInitializer(genericAdaptorClassName,
4480 "op->getOperands(), op");
4483 // Add verification function.
4484 addVerification();
4486 genericAdaptorBase.finalize();
4487 genericAdaptor.finalize();
4488 adaptor.finalize();
4491 void OpOperandAdaptorEmitter::addVerification() {
4492 auto *method = adaptor.addMethod("::llvm::LogicalResult", "verify",
4493 MethodParameter("::mlir::Location", "loc"));
4494 ERROR_IF_PRUNED(method, "verify", op);
4495 auto &body = method->body();
4496 bool useProperties = emitHelper.hasProperties();
4498 FmtContext verifyCtx;
4499 populateSubstitutions(emitHelper, verifyCtx);
4500 genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter,
4501 useProperties);
4503 body << " return ::mlir::success();";
4506 void OpOperandAdaptorEmitter::emitDecl(
4507 const Operator &op,
4508 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4509 raw_ostream &os) {
4510 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4512 NamespaceEmitter ns(os, "detail");
4513 emitter.genericAdaptorBase.writeDeclTo(os);
4515 emitter.genericAdaptor.writeDeclTo(os);
4516 emitter.adaptor.writeDeclTo(os);
4519 void OpOperandAdaptorEmitter::emitDef(
4520 const Operator &op,
4521 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4522 raw_ostream &os) {
4523 OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
4525 NamespaceEmitter ns(os, "detail");
4526 emitter.genericAdaptorBase.writeDefTo(os);
4528 emitter.genericAdaptor.writeDefTo(os);
4529 emitter.adaptor.writeDefTo(os);
4532 /// Emit the class declarations or definitions for the given op defs.
4533 static void
4534 emitOpClasses(const RecordKeeper &records,
4535 const std::vector<const Record *> &defs, raw_ostream &os,
4536 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
4537 bool emitDecl) {
4538 if (defs.empty())
4539 return;
4541 for (auto *def : defs) {
4542 Operator op(*def);
4543 if (emitDecl) {
4545 NamespaceEmitter emitter(os, op.getCppNamespace());
4546 os << formatv(opCommentHeader, op.getQualCppClassName(),
4547 "declarations");
4548 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
4549 OpEmitter::emitDecl(op, os, staticVerifierEmitter);
4551 // Emit the TypeID explicit specialization to have a single definition.
4552 if (!op.getCppNamespace().empty())
4553 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4554 << "::" << op.getCppClassName() << ")\n\n";
4555 } else {
4557 NamespaceEmitter emitter(os, op.getCppNamespace());
4558 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
4559 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
4560 OpEmitter::emitDef(op, os, staticVerifierEmitter);
4562 // Emit the TypeID explicit specialization to have a single definition.
4563 if (!op.getCppNamespace().empty())
4564 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
4565 << "::" << op.getCppClassName() << ")\n\n";
4570 /// Emit the declarations for the provided op classes.
4571 static void emitOpClassDecls(const RecordKeeper &records,
4572 const std::vector<const Record *> &defs,
4573 raw_ostream &os) {
4574 // First emit forward declaration for each class, this allows them to refer
4575 // to each others in traits for example.
4576 for (auto *def : defs) {
4577 Operator op(*def);
4578 NamespaceEmitter emitter(os, op.getCppNamespace());
4579 os << "class " << op.getCppClassName() << ";\n";
4582 // Emit the op class declarations.
4583 IfDefScope scope("GET_OP_CLASSES", os);
4584 if (defs.empty())
4585 return;
4586 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records);
4587 staticVerifierEmitter.collectOpConstraints(defs);
4588 emitOpClasses(records, defs, os, staticVerifierEmitter,
4589 /*emitDecl=*/true);
4592 /// Emit the definitions for the provided op classes.
4593 static void emitOpClassDefs(const RecordKeeper &records,
4594 ArrayRef<const Record *> defs, raw_ostream &os,
4595 StringRef constraintPrefix = "") {
4596 if (defs.empty())
4597 return;
4599 // Generate all of the locally instantiated methods first.
4600 StaticVerifierFunctionEmitter staticVerifierEmitter(os, records,
4601 constraintPrefix);
4602 os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
4603 staticVerifierEmitter.collectOpConstraints(defs);
4604 staticVerifierEmitter.emitOpConstraints(defs);
4606 // Emit the classes.
4607 emitOpClasses(records, defs, os, staticVerifierEmitter,
4608 /*emitDecl=*/false);
4611 /// Emit op declarations for all op records.
4612 static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) {
4613 emitSourceFileHeader("Op Declarations", os, records);
4615 std::vector<const Record *> defs = getRequestedOpDefinitions(records);
4616 emitOpClassDecls(records, defs, os);
4618 // If we are generating sharded op definitions, emit the sharded op
4619 // registration hooks.
4620 SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
4621 shardOpDefinitions(defs, shardedDefs);
4622 if (defs.empty() || shardedDefs.size() <= 1)
4623 return false;
4625 Dialect dialect = Operator(defs.front()).getDialect();
4626 NamespaceEmitter ns(os, dialect);
4628 const char *const opRegistrationHook =
4629 "void register{0}Operations{1}({2}::{0} *dialect);\n";
4630 os << formatv(opRegistrationHook, dialect.getCppClassName(), "",
4631 dialect.getCppNamespace());
4632 for (unsigned i = 0; i < shardedDefs.size(); ++i) {
4633 os << formatv(opRegistrationHook, dialect.getCppClassName(), i,
4634 dialect.getCppNamespace());
4637 return false;
4640 /// Generate the dialect op registration hook and the op class definitions for a
4641 /// shard of ops.
4642 static void emitOpDefShard(const RecordKeeper &records,
4643 ArrayRef<const Record *> defs,
4644 const Dialect &dialect, unsigned shardIndex,
4645 unsigned shardCount, raw_ostream &os) {
4646 std::string shardGuard = "GET_OP_DEFS_";
4647 std::string indexStr = std::to_string(shardIndex);
4648 shardGuard += indexStr;
4649 IfDefScope scope(shardGuard, os);
4651 // Emit the op registration hook in the first shard.
4652 const char *const opRegistrationHook =
4653 "void {0}::register{1}Operations{2}({0}::{1} *dialect) {{\n";
4654 if (shardIndex == 0) {
4655 os << formatv(opRegistrationHook, dialect.getCppNamespace(),
4656 dialect.getCppClassName(), "");
4657 for (unsigned i = 0; i < shardCount; ++i) {
4658 os << formatv(" {0}::register{1}Operations{2}(dialect);\n",
4659 dialect.getCppNamespace(), dialect.getCppClassName(), i);
4661 os << "}\n";
4664 // Generate the per-shard op registration hook.
4665 os << formatv(opCommentHeader, dialect.getCppClassName(),
4666 "Op Registration Hook")
4667 << formatv(opRegistrationHook, dialect.getCppNamespace(),
4668 dialect.getCppClassName(), shardIndex);
4669 for (const Record *def : defs) {
4670 os << formatv(" ::mlir::RegisteredOperationName::insert<{0}>(*dialect);\n",
4671 Operator(def).getQualCppClassName());
4673 os << "}\n";
4675 // Generate the per-shard op definitions.
4676 emitOpClassDefs(records, defs, os, indexStr);
4679 /// Emit op definitions for all op records.
4680 static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) {
4681 emitSourceFileHeader("Op Definitions", os, records);
4683 std::vector<const Record *> defs = getRequestedOpDefinitions(records);
4684 SmallVector<ArrayRef<const Record *>, 4> shardedDefs;
4685 shardOpDefinitions(defs, shardedDefs);
4687 // If no shard was requested, emit the regular op list and class definitions.
4688 if (shardedDefs.size() == 1) {
4690 IfDefScope scope("GET_OP_LIST", os);
4691 interleave(
4692 defs, os,
4693 [&](const Record *def) { os << Operator(def).getQualCppClassName(); },
4694 ",\n");
4697 IfDefScope scope("GET_OP_CLASSES", os);
4698 emitOpClassDefs(records, defs, os);
4700 return false;
4703 if (defs.empty())
4704 return false;
4705 Dialect dialect = Operator(defs.front()).getDialect();
4706 for (auto [idx, value] : llvm::enumerate(shardedDefs)) {
4707 emitOpDefShard(records, value, dialect, idx, shardedDefs.size(), os);
4709 return false;
4712 static mlir::GenRegistration
4713 genOpDecls("gen-op-decls", "Generate op declarations",
4714 [](const RecordKeeper &records, raw_ostream &os) {
4715 return emitOpDecls(records, os);
4718 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
4719 [](const RecordKeeper &records,
4720 raw_ostream &os) {
4721 return emitOpDefs(records, os);