[Infra] Fix version-check workflow (#100090)
[llvm-project.git] / mlir / tools / mlir-tblgen / OpPythonBindingGen.cpp
blob052020acdcb764d2b6f6e7a12b3e5f74114941fd
1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
12 //===----------------------------------------------------------------------===//
14 #include "OpGenHelpers.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "llvm/ADT/StringSet.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 using namespace mlir;
25 using namespace mlir::tblgen;
/// File header and imports emitted at the top of every generated Python
/// module. This template contains no format placeholders.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template importing `_Dialect` from the `_{0}_ops_gen` module instead of
/// declaring it (presumably used for dialect extensions, given the
/// `-dialect-extension` option — confirm at the call site):
///   {0} is the dialect namespace.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs (emitted inside the class body, hence the indentation).
///   {0} is either "OPERAND" or "RESULT";
///   {1} is the segment spec.
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///  -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions;
///   {1} is the Python bool literal for hasNoVariadicRegions.
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The element list is reached through `self.operation`: the generated
/// property has no other `operation` binding in scope.
/// This part ends without a newline; one of the two "second part" templates
/// below (each starting with a newline) completes the property body.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; rejects None because the
/// attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
/// Template for a region accessor property:
///   {0} is the name of the accessor;
///   {1} is the index of the region in `self.regions`.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";

/// Template for a module-level function that constructs an op and returns its
/// result(s) via `_get_op_result_or_op_results`:
///   {0} is the function name;
///   {1} is the callable invoked to build the op (presumably the op class);
///   {2} is the function's parameter list;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type annotation.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
274 static llvm::cl::OptionCategory
275 clOpPythonBindingCat("Options for -gen-python-op-bindings");
277 static llvm::cl::opt<std::string>
278 clDialectName("bind-dialect",
279 llvm::cl::desc("The dialect to run the generator for"),
280 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
282 static llvm::cl::opt<std::string> clDialectExtensionName(
283 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
284 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
286 using AttributeClasses = DenseMap<StringRef, StringRef>;
288 /// Checks whether `str` would shadow a generated variable or attribute
289 /// part of the OpView API.
290 static bool isODSReserved(StringRef str) {
291 static llvm::StringSet<> reserved(
292 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
293 "loc", "verify", "regions", "results", "self", "operation",
294 "DIALECT_NAMESPACE", "OPERATION_NAME"});
295 return str.starts_with("_ods_") || str.ends_with("_ods") ||
296 reserved.contains(str);
299 /// Modifies the `name` in a way that it becomes suitable for Python bindings
300 /// (does not change the `name` if it already is suitable) and returns the
301 /// modified version.
302 static std::string sanitizeName(StringRef name) {
303 std::string processedStr = name.str();
304 std::replace_if(
305 processedStr.begin(), processedStr.end(),
306 [](char c) { return !llvm::isAlnum(c); }, '_');
308 if (llvm::isDigit(*processedStr.begin()))
309 return "_" + processedStr;
311 if (isPythonReserved(processedStr) || isODSReserved(processedStr))
312 return processedStr + "_";
313 return processedStr;
316 static std::string attrSizedTraitForKind(const char *kind) {
317 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
318 llvm::StringRef(kind).take_front().upper(),
319 llvm::StringRef(kind).drop_front());
322 /// Emits accessors to "elements" of an Op definition. Currently, the supported
323 /// elements are operands and results, indicated by `kind`, which must be either
324 /// `operand` or `result` and is used verbatim in the emitted code.
325 static void emitElementAccessors(
326 const Operator &op, raw_ostream &os, const char *kind,
327 llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
328 llvm::function_ref<int(const Operator &)> getNumElements,
329 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
330 getElement) {
331 assert(llvm::is_contained(
332 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
333 "unsupported kind");
335 // Traits indicating how to process variadic elements.
336 std::string sameSizeTrait =
337 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
338 llvm::StringRef(kind).take_front().upper(),
339 llvm::StringRef(kind).drop_front());
340 std::string attrSizedTrait = attrSizedTraitForKind(kind);
342 unsigned numVariableLength = getNumVariableLength(op);
344 // If there is only one variable-length element group, its size can be
345 // inferred from the total number of elements. If there are none, the
346 // generation is straightforward.
347 if (numVariableLength <= 1) {
348 bool seenVariableLength = false;
349 for (int i = 0, e = getNumElements(op); i < e; ++i) {
350 const NamedTypeConstraint &element = getElement(op, i);
351 if (element.isVariableLength())
352 seenVariableLength = true;
353 if (element.name.empty())
354 continue;
355 if (element.isVariableLength()) {
356 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
357 : opOneVariadicTemplate,
358 sanitizeName(element.name), kind,
359 getNumElements(op), i);
360 } else if (seenVariableLength) {
361 os << llvm::formatv(opSingleAfterVariableTemplate,
362 sanitizeName(element.name), kind,
363 getNumElements(op), i);
364 } else {
365 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
369 return;
372 // Handle the operations where variadic groups have the same size.
373 if (op.getTrait(sameSizeTrait)) {
374 int numPrecedingSimple = 0;
375 int numPrecedingVariadic = 0;
376 for (int i = 0, e = getNumElements(op); i < e; ++i) {
377 const NamedTypeConstraint &element = getElement(op, i);
378 if (!element.name.empty()) {
379 os << llvm::formatv(opVariadicEqualPrefixTemplate,
380 sanitizeName(element.name), kind, numVariableLength,
381 numPrecedingSimple, numPrecedingVariadic);
382 os << llvm::formatv(element.isVariableLength()
383 ? opVariadicEqualVariadicTemplate
384 : opVariadicEqualSimpleTemplate,
385 kind);
387 if (element.isVariableLength())
388 ++numPrecedingVariadic;
389 else
390 ++numPrecedingSimple;
392 return;
395 // Handle the operations where the size of groups (variadic or not) is
396 // provided as an attribute. For non-variadic elements, make sure to return
397 // an element rather than a singleton container.
398 if (op.getTrait(attrSizedTrait)) {
399 for (int i = 0, e = getNumElements(op); i < e; ++i) {
400 const NamedTypeConstraint &element = getElement(op, i);
401 if (element.name.empty())
402 continue;
403 std::string trailing;
404 if (!element.isVariableLength())
405 trailing = "[0]";
406 else if (element.isOptional())
407 trailing = std::string(
408 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
409 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
410 kind, i, trailing);
412 return;
415 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
418 /// Free function helpers accessing Operator components.
419 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
420 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
421 return op.getOperand(i);
423 static int getNumResults(const Operator &op) { return op.getNumResults(); }
424 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
425 return op.getResult(i);
428 /// Emits accessors to Op operands.
429 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
430 auto getNumVariableLengthOperands = [](const Operator &oper) {
431 return oper.getNumVariableLengthOperands();
433 emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
434 getNumOperands, getOperand);
437 /// Emits accessors Op results.
438 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
439 auto getNumVariableLengthResults = [](const Operator &oper) {
440 return oper.getNumVariableLengthResults();
442 emitElementAccessors(op, os, "result", getNumVariableLengthResults,
443 getNumResults, getResult);
446 /// Emits accessors to Op attributes.
447 static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
448 for (const auto &namedAttr : op.getAttributes()) {
449 // Skip "derived" attributes because they are just C++ functions that we
450 // don't currently expose.
451 if (namedAttr.attr.isDerivedAttr())
452 continue;
454 if (namedAttr.name.empty())
455 continue;
457 std::string sanitizedName = sanitizeName(namedAttr.name);
459 // Unit attributes are handled specially.
460 if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
461 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
462 namedAttr.name);
463 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
464 namedAttr.name);
465 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
466 namedAttr.name);
467 continue;
470 if (namedAttr.attr.isOptional()) {
471 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
472 namedAttr.name);
473 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
474 namedAttr.name);
475 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
476 namedAttr.name);
477 } else {
478 os << llvm::formatv(attributeGetterTemplate, sanitizedName,
479 namedAttr.name);
480 os << llvm::formatv(attributeSetterTemplate, sanitizedName,
481 namedAttr.name);
482 // Non-optional attributes cannot be deleted.
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results`, `attributes`,
///       `regions` and the successors list;
///   {2} is the argument list passed to `build_generic`.
/// `{{}` is the formatv escape that produces the literal `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic({2}))
)Py";
/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The AttrSized operand variant always appends exactly one entry (`None` when
/// the operand is absent) instead of skipping it.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "Pack" variant appends the whole list as a single entry, whereas the
/// plain variants splice its elements into the list.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for attribute builder from raw input in the operation builder.
///   {0} is the builder argument name;
///   {1} is the attribute name (the key in `attributes`);
///   {2} is the attribute definition name used as the `AttrBuilder` key.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder; nothing is set when the argument is None.
///   {0} is the builder argument name;
///   {1} is the attribute name (the key in `attributes`);
///   {2} is the attribute definition name used as the `AttrBuilder` key.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template setting a unit attribute when the builder argument is truthy:
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
565 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
566 /// result types of the given operation.
567 static bool hasSameArgumentAndResultTypes(const Operator &op) {
568 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
569 op.getNumVariableLengthResults() == 0;
572 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
573 /// result types of the given operation.
574 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
575 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
576 op.getNumVariableLengthResults() == 0;
579 /// Returns true if the InferTypeOpInterface can be used to infer result types
580 /// of the given operation.
581 static bool hasInferTypeInterface(const Operator &op) {
582 return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
583 op.getNumRegions() == 0;
586 /// Returns true if there is a trait or interface that can be used to infer
587 /// result types of the given operation.
588 static bool canInferType(const Operator &op) {
589 return hasSameArgumentAndResultTypes(op) ||
590 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
593 /// Populates `builderArgs` with result names if the builder is expected to
594 /// accept them as arguments.
595 static void
596 populateBuilderArgsResults(const Operator &op,
597 llvm::SmallVectorImpl<std::string> &builderArgs) {
598 if (canInferType(op))
599 return;
601 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
602 std::string name = op.getResultName(i).str();
603 if (name.empty()) {
604 if (op.getNumResults() == 1) {
605 // Special case for one result, make the default name be 'result'
606 // to properly match the built-in result accessor.
607 name = "result";
608 } else {
609 name = llvm::formatv("_gen_res_{0}", i);
612 name = sanitizeName(name);
613 builderArgs.push_back(name);
617 /// Populates `builderArgs` with the Python-compatible names of builder function
618 /// arguments using intermixed attributes and operands in the same order as they
619 /// appear in the `arguments` field of the op definition. Additionally,
620 /// `operandNames` is populated with names of operands in their order of
621 /// appearance.
622 static void
623 populateBuilderArgs(const Operator &op,
624 llvm::SmallVectorImpl<std::string> &builderArgs,
625 llvm::SmallVectorImpl<std::string> &operandNames) {
626 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
627 std::string name = op.getArgName(i).str();
628 if (name.empty())
629 name = llvm::formatv("_gen_arg_{0}", i);
630 name = sanitizeName(name);
631 builderArgs.push_back(name);
632 if (!op.getArg(i).is<NamedAttribute *>())
633 operandNames.push_back(name);
637 /// Populates `builderArgs` with the Python-compatible names of builder function
638 /// successor arguments. Additionally, `successorArgNames` is also populated.
639 static void populateBuilderArgsSuccessors(
640 const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
641 llvm::SmallVectorImpl<std::string> &successorArgNames) {
643 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
644 NamedSuccessor successor = op.getSuccessor(i);
645 std::string name = std::string(successor.name);
646 if (name.empty())
647 name = llvm::formatv("_gen_successor_{0}", i);
648 name = sanitizeName(name);
649 builderArgs.push_back(name);
650 successorArgNames.push_back(name);
654 /// Populates `builderLines` with additional lines that are required in the
655 /// builder to set up operation attributes. `argNames` is expected to contain
656 /// the names of builder arguments that correspond to op arguments, i.e. to the
657 /// operands and attributes in the same order as they appear in the `arguments`
658 /// field.
659 static void
660 populateBuilderLinesAttr(const Operator &op,
661 llvm::ArrayRef<std::string> argNames,
662 llvm::SmallVectorImpl<std::string> &builderLines) {
663 builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
664 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
665 Argument arg = op.getArg(i);
666 auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
667 if (!attribute)
668 continue;
670 // Unit attributes are handled specially.
671 if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
672 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
673 attribute->name, argNames[i]));
674 continue;
677 builderLines.push_back(llvm::formatv(
678 attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
679 ? initOptionalAttributeWithBuilderTemplate
680 : initAttributeWithBuilderTemplate,
681 argNames[i], attribute->name, attribute->attr.getAttrDefName()));
685 /// Populates `builderLines` with additional lines that are required in the
686 /// builder to set up successors. successorArgNames is expected to correspond
687 /// to the Python argument name for each successor on the op.
688 static void populateBuilderLinesSuccessors(
689 const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
690 llvm::SmallVectorImpl<std::string> &builderLines) {
691 if (successorArgNames.empty()) {
692 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
693 return;
696 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
697 for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
698 auto &argName = successorArgNames[i];
699 const NamedSuccessor &successor = op.getSuccessor(i);
700 builderLines.push_back(
701 llvm::formatv(addSuccessorTemplate,
702 successor.isVariadic() ? "extend" : "append", argName));
706 /// Populates `builderLines` with additional lines that are required in the
707 /// builder to set up op operands.
708 static void
709 populateBuilderLinesOperand(const Operator &op,
710 llvm::ArrayRef<std::string> names,
711 llvm::SmallVectorImpl<std::string> &builderLines) {
712 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
714 // For each element, find or generate a name.
715 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
716 const NamedTypeConstraint &element = op.getOperand(i);
717 std::string name = names[i];
719 // Choose the formatting string based on the element kind.
720 llvm::StringRef formatString;
721 if (!element.isVariableLength()) {
722 formatString = singleOperandAppendTemplate;
723 } else if (element.isOptional()) {
724 if (sizedSegments) {
725 formatString = optionalAppendAttrSizedOperandsTemplate;
726 } else {
727 formatString = optionalAppendOperandTemplate;
729 } else {
730 assert(element.isVariadic() && "unhandled element group type");
731 // If emitting with sizedSegments, then we add the actual list-typed
732 // element. Otherwise, we extend the actual operands.
733 if (sizedSegments) {
734 formatString = multiOperandAppendPackTemplate;
735 } else {
736 formatString = multiOperandAppendTemplate;
740 builderLines.push_back(llvm::formatv(formatString.data(), name));
/// Python code template for deriving the operation result types from its
/// attribute:
///   {0} is the name of the attribute from which to derive the types.
/// The expansion spans several lines and is split into individual builder
/// lines with appendLineByLine, so the second statement starts at column 0
/// and the remaining lines are continuations inside its parentheses.
constexpr const char *deriveTypeFromAttrTemplate =
    R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))Py";

/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
757 /// Appends the given multiline string as individual strings into
758 /// `builderLines`.
759 static void appendLineByLine(StringRef string,
760 llvm::SmallVectorImpl<std::string> &builderLines) {
762 std::pair<StringRef, StringRef> split = std::make_pair(string, string);
763 do {
764 split = split.second.split('\n');
765 builderLines.push_back(split.first.str());
766 } while (!split.second.empty());
769 /// Populates `builderLines` with additional lines that are required in the
770 /// builder to set up op results.
771 static void
772 populateBuilderLinesResult(const Operator &op,
773 llvm::ArrayRef<std::string> names,
774 llvm::SmallVectorImpl<std::string> &builderLines) {
775 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
777 if (hasSameArgumentAndResultTypes(op)) {
778 builderLines.push_back(llvm::formatv(
779 appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
780 return;
783 if (hasFirstAttrDerivedResultTypes(op)) {
784 const NamedAttribute &firstAttr = op.getAttribute(0);
785 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
786 "from which the type is derived");
787 appendLineByLine(
788 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
789 builderLines);
790 builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
791 "_ods_derived_result_type",
792 op.getNumResults()));
793 return;
796 if (hasInferTypeInterface(op))
797 return;
799 // For each element, find or generate a name.
800 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
801 const NamedTypeConstraint &element = op.getResult(i);
802 std::string name = names[i];
804 // Choose the formatting string based on the element kind.
805 llvm::StringRef formatString;
806 if (!element.isVariableLength()) {
807 formatString = singleResultAppendTemplate;
808 } else if (element.isOptional()) {
809 formatString = optionalAppendResultTemplate;
810 } else {
811 assert(element.isVariadic() && "unhandled element group type");
812 // If emitting with sizedSegments, then we add the actual list-typed
813 // element. Otherwise, we extend the actual operands.
814 if (sizedSegments) {
815 formatString = singleResultAppendTemplate;
816 } else {
817 formatString = multiResultAppendTemplate;
821 builderLines.push_back(llvm::formatv(formatString.data(), name));
825 /// If the operation has variadic regions, adds a builder argument to specify
826 /// the number of those regions and builder lines to forward it to the generic
827 /// constructor.
828 static void
829 populateBuilderRegions(const Operator &op,
830 llvm::SmallVectorImpl<std::string> &builderArgs,
831 llvm::SmallVectorImpl<std::string> &builderLines) {
832 if (op.hasNoVariadicRegions())
833 return;
835 // This is currently enforced when Operator is constructed.
836 assert(op.getNumVariadicRegions() == 1 &&
837 op.getRegion(op.getNumRegions() - 1).isVariadic() &&
838 "expected the last region to be varidic");
840 const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
841 std::string name =
842 ("num_" + region.name.take_front().lower() + region.name.drop_front())
843 .str();
844 builderArgs.push_back(name);
845 builderLines.push_back(
846 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
849 /// Emits a default builder constructing an operation from the list of its
850 /// result types, followed by a list of its operands. Returns vector
851 /// of fully built functionArgs for downstream users (to save having to
852 /// rebuild anew).
853 static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
854 raw_ostream &os) {
855 llvm::SmallVector<std::string> builderArgs;
856 llvm::SmallVector<std::string> builderLines;
857 llvm::SmallVector<std::string> operandArgNames;
858 llvm::SmallVector<std::string> successorArgNames;
859 builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
860 op.getNumNativeAttributes() + op.getNumSuccessors());
861 populateBuilderArgsResults(op, builderArgs);
862 size_t numResultArgs = builderArgs.size();
863 populateBuilderArgs(op, builderArgs, operandArgNames);
864 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
865 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
867 populateBuilderLinesOperand(op, operandArgNames, builderLines);
868 populateBuilderLinesAttr(
869 op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines);
870 populateBuilderLinesResult(
871 op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
872 populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
873 populateBuilderRegions(op, builderArgs, builderLines);
875 // Layout of builderArgs vector elements:
876 // [ result_args operand_attr_args successor_args regions ]
878 // Determine whether the argument corresponding to a given index into the
879 // builderArgs vector is a python keyword argument or not.
880 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
881 // All result, successor, and region arguments are positional arguments.
882 if ((builderArgIndex < numResultArgs) ||
883 (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
884 return false;
885 // Keyword arguments:
886 // - optional named attributes (including unit attributes)
887 // - default-valued named attributes
888 // - optional operands
889 Argument a = op.getArg(builderArgIndex - numResultArgs);
890 if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
891 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
892 if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
893 return ntype->isOptional();
894 return false;
897 // StringRefs in functionArgs refer to strings allocated by builderArgs.
898 llvm::SmallVector<llvm::StringRef> functionArgs;
900 // Add positional arguments.
901 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
902 if (!isKeywordArgFn(i))
903 functionArgs.push_back(builderArgs[i]);
906 // Add a bare '*' to indicate that all following arguments must be keyword
907 // arguments.
908 functionArgs.push_back("*");
910 // Add a default 'None' value to each keyword arg string, and then add to the
911 // function args list.
912 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
913 if (isKeywordArgFn(i)) {
914 builderArgs[i].append("=None");
915 functionArgs.push_back(builderArgs[i]);
918 functionArgs.push_back("loc=None");
919 functionArgs.push_back("ip=None");
921 SmallVector<std::string> initArgs;
922 initArgs.push_back("attributes=attributes");
923 if (!hasInferTypeInterface(op))
924 initArgs.push_back("results=results");
925 initArgs.push_back("operands=operands");
926 initArgs.push_back("successors=_ods_successors");
927 initArgs.push_back("regions=regions");
928 initArgs.push_back("loc=loc");
929 initArgs.push_back("ip=ip");
931 os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
932 llvm::join(builderLines, "\n "),
933 llvm::join(initArgs, ", "));
934 return llvm::to_vector<8>(
935 llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
938 static void emitSegmentSpec(
939 const Operator &op, const char *kind,
940 llvm::function_ref<int(const Operator &)> getNumElements,
941 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
942 getElement,
943 raw_ostream &os) {
944 std::string segmentSpec("[");
945 for (int i = 0, e = getNumElements(op); i < e; ++i) {
946 const NamedTypeConstraint &element = getElement(op, i);
947 if (element.isOptional()) {
948 segmentSpec.append("0,");
949 } else if (element.isVariadic()) {
950 segmentSpec.append("-1,");
951 } else {
952 segmentSpec.append("1,");
955 segmentSpec.append("]");
957 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
960 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
961 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
962 // Note that the base OpView class defines this as (0, True).
963 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
964 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
965 op.hasNoVariadicRegions() ? "True" : "False");
968 /// Emits named accessors to regions.
969 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
970 for (const auto &en : llvm::enumerate(op.getRegions())) {
971 const NamedRegion &region = en.value();
972 if (region.name.empty())
973 continue;
975 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
976 "expected only the last region to be variadic");
977 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
978 std::to_string(en.index()) +
979 (region.isVariadic() ? ":" : ""));
983 /// Emits builder that extracts results from op
984 static void emitValueBuilder(const Operator &op,
985 llvm::SmallVector<std::string> functionArgs,
986 raw_ostream &os) {
987 // Params with (possibly) default args.
988 auto valueBuilderParams =
989 llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
990 llvm::SmallVector<llvm::StringRef> argMaybeDefault =
991 llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
992 auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
993 if (argMaybeDefault.size() == 2)
994 return arg + "=" + argMaybeDefault[1].str();
995 return arg;
997 // Actual args passed to op builder (e.g., opParam=op_param).
998 auto opBuilderArgs = llvm::map_range(
999 llvm::make_filter_range(functionArgs,
1000 [](const std::string &s) { return s != "*"; }),
1001 [](const std::string &arg) {
1002 auto lhs = *llvm::split(arg, "=").begin();
1003 return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
1005 std::string nameWithoutDialect =
1006 op.getOperationName().substr(op.getOperationName().find('.') + 1);
1007 os << llvm::formatv(
1008 valueBuilderTemplate, sanitizeName(nameWithoutDialect),
1009 op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
1010 llvm::join(opBuilderArgs, ", "),
1011 (op.getNumResults() > 1
1012 ? "_Sequence[_ods_ir.Value]"
1013 : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
1016 /// Emits bindings for a specific Op to the given output stream.
1017 static void emitOpBindings(const Operator &op, raw_ostream &os) {
1018 os << llvm::formatv(opClassTemplate, op.getCppClassName(),
1019 op.getOperationName());
1021 // Sized segments.
1022 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1023 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
1025 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1026 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
1029 emitRegionAttributes(op, os);
1030 llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
1031 emitOperandAccessors(op, os);
1032 emitAttributeAccessors(op, os);
1033 emitResultAccessors(op, os);
1034 emitRegionAccessors(op, os);
1035 emitValueBuilder(op, functionArgs, os);
1038 /// Emits bindings for the dialect specified in the command line, including file
1039 /// headers and utilities. Returns `false` on success to comply with Tablegen
1040 /// registration requirements.
1041 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1042 if (clDialectName.empty())
1043 llvm::PrintFatalError("dialect name not provided");
1045 os << fileHeader;
1046 if (!clDialectExtensionName.empty())
1047 os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
1048 else
1049 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1051 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1052 Operator op(rec);
1053 if (op.getDialectName() == clDialectName.getValue())
1054 emitOpBindings(op, os);
1056 return false;
1059 static GenRegistration
1060 genPythonBindings("gen-python-op-bindings",
1061 "Generate Python bindings for MLIR Ops", &emitAllOps);