[clang][modules] Don't prevent translation of FW_Private includes when explicitly...
[llvm-project.git] / mlir / tools / mlir-tblgen / OpPythonBindingGen.cpp
blobc8ef84721090ab91567666733bb0a6180c3696c3
1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
12 //===----------------------------------------------------------------------===//
14 #include "OpGenHelpers.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "llvm/ADT/StringSet.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
24 using namespace mlir;
25 using namespace mlir::tblgen;
/// File header and imports emitted once at the top of every generated module.
/// Contains no format placeholders.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";

/// Template for dialect class:
/// {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for a dialect-extension module, which re-uses the `_Dialect` class
/// of the extended dialect:
/// {0} is the name embedded in the `_{0}_ops_gen` module to import from.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class:
/// {0} is the Python class name;
/// {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
/// Template for class level declarations of operand and result
/// segment specs.
/// {0} is either "OPERAND" or "RESULT"
/// {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///  -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
/// {0} is the minimum number of regions
/// {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
/// Template for single-element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of variadic groups;
/// {3} is the number of non-variadic groups preceding the current group;
/// {4} is the number of variadic groups preceding the current group.
/// The element list must be read through `self.operation`: a bare `operation`
/// is not defined inside the generated method.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
/// Template for an attribute-sized group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
///     variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; yields None when the segment is empty:
/// {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
/// Template for an operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter, returning None when
/// the attribute is absent:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
/// {0} is the name of the attribute sanitized for Python,
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; rejects None because the
/// attribute is mandatory:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor:
/// {0} is the name of the accessor;
/// {1} is the index into `self.regions`.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";

/// Template for a free-standing value builder function:
/// {0} is the function name;
/// {1} is the callable invoked to build the op (presumably the generated op
///     class — confirm at the call site);
/// {2} is the function's parameter list;
/// {3} is the argument list forwarded to {1};
/// {4} is the return type annotation.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
269 static llvm::cl::OptionCategory
270 clOpPythonBindingCat("Options for -gen-python-op-bindings");
272 static llvm::cl::opt<std::string>
273 clDialectName("bind-dialect",
274 llvm::cl::desc("The dialect to run the generator for"),
275 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
277 static llvm::cl::opt<std::string> clDialectExtensionName(
278 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
279 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
281 using AttributeClasses = DenseMap<StringRef, StringRef>;
283 /// Checks whether `str` would shadow a generated variable or attribute
284 /// part of the OpView API.
285 static bool isODSReserved(StringRef str) {
286 static llvm::StringSet<> reserved(
287 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
288 "loc", "verify", "regions", "results", "self", "operation",
289 "DIALECT_NAMESPACE", "OPERATION_NAME"});
290 return str.startswith("_ods_") || str.endswith("_ods") ||
291 reserved.contains(str);
294 /// Modifies the `name` in a way that it becomes suitable for Python bindings
295 /// (does not change the `name` if it already is suitable) and returns the
296 /// modified version.
297 static std::string sanitizeName(StringRef name) {
298 std::string processedStr = name.str();
299 std::replace_if(
300 processedStr.begin(), processedStr.end(),
301 [](char c) { return !llvm::isAlnum(c); }, '_');
303 if (llvm::isDigit(*processedStr.begin()))
304 return "_" + processedStr;
306 if (isPythonReserved(processedStr) || isODSReserved(processedStr))
307 return processedStr + "_";
308 return processedStr;
311 static std::string attrSizedTraitForKind(const char *kind) {
312 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
313 llvm::StringRef(kind).take_front().upper(),
314 llvm::StringRef(kind).drop_front());
317 /// Emits accessors to "elements" of an Op definition. Currently, the supported
318 /// elements are operands and results, indicated by `kind`, which must be either
319 /// `operand` or `result` and is used verbatim in the emitted code.
320 static void emitElementAccessors(
321 const Operator &op, raw_ostream &os, const char *kind,
322 llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
323 llvm::function_ref<int(const Operator &)> getNumElements,
324 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
325 getElement) {
326 assert(llvm::is_contained(
327 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
328 "unsupported kind");
330 // Traits indicating how to process variadic elements.
331 std::string sameSizeTrait =
332 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
333 llvm::StringRef(kind).take_front().upper(),
334 llvm::StringRef(kind).drop_front());
335 std::string attrSizedTrait = attrSizedTraitForKind(kind);
337 unsigned numVariableLength = getNumVariableLength(op);
339 // If there is only one variable-length element group, its size can be
340 // inferred from the total number of elements. If there are none, the
341 // generation is straightforward.
342 if (numVariableLength <= 1) {
343 bool seenVariableLength = false;
344 for (int i = 0, e = getNumElements(op); i < e; ++i) {
345 const NamedTypeConstraint &element = getElement(op, i);
346 if (element.isVariableLength())
347 seenVariableLength = true;
348 if (element.name.empty())
349 continue;
350 if (element.isVariableLength()) {
351 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
352 : opOneVariadicTemplate,
353 sanitizeName(element.name), kind,
354 getNumElements(op), i);
355 } else if (seenVariableLength) {
356 os << llvm::formatv(opSingleAfterVariableTemplate,
357 sanitizeName(element.name), kind,
358 getNumElements(op), i);
359 } else {
360 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
364 return;
367 // Handle the operations where variadic groups have the same size.
368 if (op.getTrait(sameSizeTrait)) {
369 int numPrecedingSimple = 0;
370 int numPrecedingVariadic = 0;
371 for (int i = 0, e = getNumElements(op); i < e; ++i) {
372 const NamedTypeConstraint &element = getElement(op, i);
373 if (!element.name.empty()) {
374 os << llvm::formatv(opVariadicEqualPrefixTemplate,
375 sanitizeName(element.name), kind, numVariableLength,
376 numPrecedingSimple, numPrecedingVariadic);
377 os << llvm::formatv(element.isVariableLength()
378 ? opVariadicEqualVariadicTemplate
379 : opVariadicEqualSimpleTemplate,
380 kind);
382 if (element.isVariableLength())
383 ++numPrecedingVariadic;
384 else
385 ++numPrecedingSimple;
387 return;
390 // Handle the operations where the size of groups (variadic or not) is
391 // provided as an attribute. For non-variadic elements, make sure to return
392 // an element rather than a singleton container.
393 if (op.getTrait(attrSizedTrait)) {
394 for (int i = 0, e = getNumElements(op); i < e; ++i) {
395 const NamedTypeConstraint &element = getElement(op, i);
396 if (element.name.empty())
397 continue;
398 std::string trailing;
399 if (!element.isVariableLength())
400 trailing = "[0]";
401 else if (element.isOptional())
402 trailing = std::string(
403 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
404 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
405 kind, i, trailing);
407 return;
410 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
413 /// Free function helpers accessing Operator components.
414 static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
415 static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
416 return op.getOperand(i);
418 static int getNumResults(const Operator &op) { return op.getNumResults(); }
419 static const NamedTypeConstraint &getResult(const Operator &op, int i) {
420 return op.getResult(i);
423 /// Emits accessors to Op operands.
424 static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
425 auto getNumVariableLengthOperands = [](const Operator &oper) {
426 return oper.getNumVariableLengthOperands();
428 emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
429 getNumOperands, getOperand);
432 /// Emits accessors Op results.
433 static void emitResultAccessors(const Operator &op, raw_ostream &os) {
434 auto getNumVariableLengthResults = [](const Operator &oper) {
435 return oper.getNumVariableLengthResults();
437 emitElementAccessors(op, os, "result", getNumVariableLengthResults,
438 getNumResults, getResult);
441 /// Emits accessors to Op attributes.
442 static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
443 for (const auto &namedAttr : op.getAttributes()) {
444 // Skip "derived" attributes because they are just C++ functions that we
445 // don't currently expose.
446 if (namedAttr.attr.isDerivedAttr())
447 continue;
449 if (namedAttr.name.empty())
450 continue;
452 std::string sanitizedName = sanitizeName(namedAttr.name);
454 // Unit attributes are handled specially.
455 if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
456 os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
457 namedAttr.name);
458 os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
459 namedAttr.name);
460 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
461 namedAttr.name);
462 continue;
465 if (namedAttr.attr.isOptional()) {
466 os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
467 namedAttr.name);
468 os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
469 namedAttr.name);
470 os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
471 namedAttr.name);
472 } else {
473 os << llvm::formatv(attributeGetterTemplate, sanitizedName,
474 namedAttr.name);
475 os << llvm::formatv(attributeSetterTemplate, sanitizedName,
476 namedAttr.name);
477 // Non-optional attributes cannot be deleted.
/// Template for the default auto-generated builder.
/// {0} is a comma-separated list of builder arguments, including the trailing
///     `loc` and `ip`;
/// {1} is the code populating `operands`, `results`, `attributes`, `regions`
///     and the successors; multi-line code must be indented to the method
///     body level by the caller;
/// {2} is the argument list forwarded to `build_generic`.
/// `{{}` is formatv's escape for a literal `{}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic({2}))
)Py";
/// Template for appending a single element to the operand/result list.
/// {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
/// {0} is the field name.
/// The attribute-sized operand variant always appends (None when absent) so
/// that every segment keeps its slot.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
/// {0} is the field name.
/// The "Pack" variant appends the whole list as a single entry instead of
/// splicing its elements in.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute name (the key in `attributes`);
/// {2} is the name under which an attribute builder may be registered with
///     `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    issubclass(type({0}), _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder; the attribute is left unset when the argument is None.
/// {0} is the builder argument name;
/// {1} is the attribute name (the key in `attributes`);
/// {2} is the name under which an attribute builder may be registered with
///     `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
    issubclass(type({0}), _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template setting a unit attribute when the builder argument is truthy:
/// {0} is the attribute name (the key in `attributes`);
/// {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
/// {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
/// {0} is the list method ('append' or 'extend');
/// {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
560 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
561 /// result types of the given operation.
562 static bool hasSameArgumentAndResultTypes(const Operator &op) {
563 return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
564 op.getNumVariableLengthResults() == 0;
567 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
568 /// result types of the given operation.
569 static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
570 return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
571 op.getNumVariableLengthResults() == 0;
574 /// Returns true if the InferTypeOpInterface can be used to infer result types
575 /// of the given operation.
576 static bool hasInferTypeInterface(const Operator &op) {
577 return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
578 op.getNumRegions() == 0;
581 /// Returns true if there is a trait or interface that can be used to infer
582 /// result types of the given operation.
583 static bool canInferType(const Operator &op) {
584 return hasSameArgumentAndResultTypes(op) ||
585 hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
588 /// Populates `builderArgs` with result names if the builder is expected to
589 /// accept them as arguments.
590 static void
591 populateBuilderArgsResults(const Operator &op,
592 llvm::SmallVectorImpl<std::string> &builderArgs) {
593 if (canInferType(op))
594 return;
596 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
597 std::string name = op.getResultName(i).str();
598 if (name.empty()) {
599 if (op.getNumResults() == 1) {
600 // Special case for one result, make the default name be 'result'
601 // to properly match the built-in result accessor.
602 name = "result";
603 } else {
604 name = llvm::formatv("_gen_res_{0}", i);
607 name = sanitizeName(name);
608 builderArgs.push_back(name);
612 /// Populates `builderArgs` with the Python-compatible names of builder function
613 /// arguments using intermixed attributes and operands in the same order as they
614 /// appear in the `arguments` field of the op definition. Additionally,
615 /// `operandNames` is populated with names of operands in their order of
616 /// appearance.
617 static void
618 populateBuilderArgs(const Operator &op,
619 llvm::SmallVectorImpl<std::string> &builderArgs,
620 llvm::SmallVectorImpl<std::string> &operandNames) {
621 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
622 std::string name = op.getArgName(i).str();
623 if (name.empty())
624 name = llvm::formatv("_gen_arg_{0}", i);
625 name = sanitizeName(name);
626 builderArgs.push_back(name);
627 if (!op.getArg(i).is<NamedAttribute *>())
628 operandNames.push_back(name);
632 /// Populates `builderArgs` with the Python-compatible names of builder function
633 /// successor arguments. Additionally, `successorArgNames` is also populated.
634 static void populateBuilderArgsSuccessors(
635 const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
636 llvm::SmallVectorImpl<std::string> &successorArgNames) {
638 for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
639 NamedSuccessor successor = op.getSuccessor(i);
640 std::string name = std::string(successor.name);
641 if (name.empty())
642 name = llvm::formatv("_gen_successor_{0}", i);
643 name = sanitizeName(name);
644 builderArgs.push_back(name);
645 successorArgNames.push_back(name);
649 /// Populates `builderLines` with additional lines that are required in the
650 /// builder to set up operation attributes. `argNames` is expected to contain
651 /// the names of builder arguments that correspond to op arguments, i.e. to the
652 /// operands and attributes in the same order as they appear in the `arguments`
653 /// field.
654 static void
655 populateBuilderLinesAttr(const Operator &op,
656 llvm::ArrayRef<std::string> argNames,
657 llvm::SmallVectorImpl<std::string> &builderLines) {
658 builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
659 for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
660 Argument arg = op.getArg(i);
661 auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
662 if (!attribute)
663 continue;
665 // Unit attributes are handled specially.
666 if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
667 builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
668 attribute->name, argNames[i]));
669 continue;
672 builderLines.push_back(llvm::formatv(
673 attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
674 ? initOptionalAttributeWithBuilderTemplate
675 : initAttributeWithBuilderTemplate,
676 argNames[i], attribute->name, attribute->attr.getAttrDefName()));
680 /// Populates `builderLines` with additional lines that are required in the
681 /// builder to set up successors. successorArgNames is expected to correspond
682 /// to the Python argument name for each successor on the op.
683 static void populateBuilderLinesSuccessors(
684 const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
685 llvm::SmallVectorImpl<std::string> &builderLines) {
686 if (successorArgNames.empty()) {
687 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
688 return;
691 builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
692 for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
693 auto &argName = successorArgNames[i];
694 const NamedSuccessor &successor = op.getSuccessor(i);
695 builderLines.push_back(
696 llvm::formatv(addSuccessorTemplate,
697 successor.isVariadic() ? "extend" : "append", argName));
701 /// Populates `builderLines` with additional lines that are required in the
702 /// builder to set up op operands.
703 static void
704 populateBuilderLinesOperand(const Operator &op,
705 llvm::ArrayRef<std::string> names,
706 llvm::SmallVectorImpl<std::string> &builderLines) {
707 bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
709 // For each element, find or generate a name.
710 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
711 const NamedTypeConstraint &element = op.getOperand(i);
712 std::string name = names[i];
714 // Choose the formatting string based on the element kind.
715 llvm::StringRef formatString;
716 if (!element.isVariableLength()) {
717 formatString = singleOperandAppendTemplate;
718 } else if (element.isOptional()) {
719 if (sizedSegments) {
720 formatString = optionalAppendAttrSizedOperandsTemplate;
721 } else {
722 formatString = optionalAppendOperandTemplate;
724 } else {
725 assert(element.isVariadic() && "unhandled element group type");
726 // If emitting with sizedSegments, then we add the actual list-typed
727 // element. Otherwise, we extend the actual operands.
728 if (sizedSegments) {
729 formatString = multiOperandAppendPackTemplate;
730 } else {
731 formatString = multiOperandAppendTemplate;
735 builderLines.push_back(llvm::formatv(formatString.data(), name));
/// Python code template for deriving the operation result types from its
/// attribute (a TypeAttr's value, or else the attribute's own type):
/// - {0} is the name of the attribute from which to derive the types.
/// Lines after the first start at column 0 because the template is split into
/// separate builder lines.
constexpr const char *deriveTypeFromAttrTemplate =
    R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))Py";

/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
752 /// Appends the given multiline string as individual strings into
753 /// `builderLines`.
754 static void appendLineByLine(StringRef string,
755 llvm::SmallVectorImpl<std::string> &builderLines) {
757 std::pair<StringRef, StringRef> split = std::make_pair(string, string);
758 do {
759 split = split.second.split('\n');
760 builderLines.push_back(split.first.str());
761 } while (!split.second.empty());
764 /// Populates `builderLines` with additional lines that are required in the
765 /// builder to set up op results.
766 static void
767 populateBuilderLinesResult(const Operator &op,
768 llvm::ArrayRef<std::string> names,
769 llvm::SmallVectorImpl<std::string> &builderLines) {
770 bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
772 if (hasSameArgumentAndResultTypes(op)) {
773 builderLines.push_back(llvm::formatv(
774 appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
775 return;
778 if (hasFirstAttrDerivedResultTypes(op)) {
779 const NamedAttribute &firstAttr = op.getAttribute(0);
780 assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
781 "from which the type is derived");
782 appendLineByLine(
783 llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
784 builderLines);
785 builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
786 "_ods_derived_result_type",
787 op.getNumResults()));
788 return;
791 if (hasInferTypeInterface(op))
792 return;
794 // For each element, find or generate a name.
795 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
796 const NamedTypeConstraint &element = op.getResult(i);
797 std::string name = names[i];
799 // Choose the formatting string based on the element kind.
800 llvm::StringRef formatString;
801 if (!element.isVariableLength()) {
802 formatString = singleResultAppendTemplate;
803 } else if (element.isOptional()) {
804 formatString = optionalAppendResultTemplate;
805 } else {
806 assert(element.isVariadic() && "unhandled element group type");
807 // If emitting with sizedSegments, then we add the actual list-typed
808 // element. Otherwise, we extend the actual operands.
809 if (sizedSegments) {
810 formatString = singleResultAppendTemplate;
811 } else {
812 formatString = multiResultAppendTemplate;
816 builderLines.push_back(llvm::formatv(formatString.data(), name));
820 /// If the operation has variadic regions, adds a builder argument to specify
821 /// the number of those regions and builder lines to forward it to the generic
822 /// constructor.
823 static void
824 populateBuilderRegions(const Operator &op,
825 llvm::SmallVectorImpl<std::string> &builderArgs,
826 llvm::SmallVectorImpl<std::string> &builderLines) {
827 if (op.hasNoVariadicRegions())
828 return;
830 // This is currently enforced when Operator is constructed.
831 assert(op.getNumVariadicRegions() == 1 &&
832 op.getRegion(op.getNumRegions() - 1).isVariadic() &&
833 "expected the last region to be varidic");
835 const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
836 std::string name =
837 ("num_" + region.name.take_front().lower() + region.name.drop_front())
838 .str();
839 builderArgs.push_back(name);
840 builderLines.push_back(
841 llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
844 /// Emits a default builder constructing an operation from the list of its
845 /// result types, followed by a list of its operands. Returns vector
846 /// of fully built functionArgs for downstream users (to save having to
847 /// rebuild anew).
848 static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
849 raw_ostream &os) {
850 llvm::SmallVector<std::string> builderArgs;
851 llvm::SmallVector<std::string> builderLines;
852 llvm::SmallVector<std::string> operandArgNames;
853 llvm::SmallVector<std::string> successorArgNames;
854 builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
855 op.getNumNativeAttributes() + op.getNumSuccessors());
856 populateBuilderArgsResults(op, builderArgs);
857 size_t numResultArgs = builderArgs.size();
858 populateBuilderArgs(op, builderArgs, operandArgNames);
859 size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
860 populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
862 populateBuilderLinesOperand(op, operandArgNames, builderLines);
863 populateBuilderLinesAttr(
864 op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines);
865 populateBuilderLinesResult(
866 op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
867 populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
868 populateBuilderRegions(op, builderArgs, builderLines);
870 // Layout of builderArgs vector elements:
871 // [ result_args operand_attr_args successor_args regions ]
873 // Determine whether the argument corresponding to a given index into the
874 // builderArgs vector is a python keyword argument or not.
875 auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
876 // All result, successor, and region arguments are positional arguments.
877 if ((builderArgIndex < numResultArgs) ||
878 (builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
879 return false;
880 // Keyword arguments:
881 // - optional named attributes (including unit attributes)
882 // - default-valued named attributes
883 // - optional operands
884 Argument a = op.getArg(builderArgIndex - numResultArgs);
885 if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
886 return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
887 if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
888 return ntype->isOptional();
889 return false;
892 // StringRefs in functionArgs refer to strings allocated by builderArgs.
893 llvm::SmallVector<llvm::StringRef> functionArgs;
895 // Add positional arguments.
896 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
897 if (!isKeywordArgFn(i))
898 functionArgs.push_back(builderArgs[i]);
901 // Add a bare '*' to indicate that all following arguments must be keyword
902 // arguments.
903 functionArgs.push_back("*");
905 // Add a default 'None' value to each keyword arg string, and then add to the
906 // function args list.
907 for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
908 if (isKeywordArgFn(i)) {
909 builderArgs[i].append("=None");
910 functionArgs.push_back(builderArgs[i]);
913 functionArgs.push_back("loc=None");
914 functionArgs.push_back("ip=None");
916 SmallVector<std::string> initArgs;
917 initArgs.push_back("attributes=attributes");
918 if (!hasInferTypeInterface(op))
919 initArgs.push_back("results=results");
920 initArgs.push_back("operands=operands");
921 initArgs.push_back("successors=_ods_successors");
922 initArgs.push_back("regions=regions");
923 initArgs.push_back("loc=loc");
924 initArgs.push_back("ip=ip");
926 os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
927 llvm::join(builderLines, "\n "),
928 llvm::join(initArgs, ", "));
929 return llvm::to_vector<8>(
930 llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
933 static void emitSegmentSpec(
934 const Operator &op, const char *kind,
935 llvm::function_ref<int(const Operator &)> getNumElements,
936 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
937 getElement,
938 raw_ostream &os) {
939 std::string segmentSpec("[");
940 for (int i = 0, e = getNumElements(op); i < e; ++i) {
941 const NamedTypeConstraint &element = getElement(op, i);
942 if (element.isOptional()) {
943 segmentSpec.append("0,");
944 } else if (element.isVariadic()) {
945 segmentSpec.append("-1,");
946 } else {
947 segmentSpec.append("1,");
950 segmentSpec.append("]");
952 os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
955 static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
956 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
957 // Note that the base OpView class defines this as (0, True).
958 unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
959 os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
960 op.hasNoVariadicRegions() ? "True" : "False");
963 /// Emits named accessors to regions.
964 static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
965 for (const auto &en : llvm::enumerate(op.getRegions())) {
966 const NamedRegion &region = en.value();
967 if (region.name.empty())
968 continue;
970 assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
971 "expected only the last region to be variadic");
972 os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
973 std::to_string(en.index()) +
974 (region.isVariadic() ? ":" : ""));
978 /// Emits builder that extracts results from op
979 static void emitValueBuilder(const Operator &op,
980 llvm::SmallVector<std::string> functionArgs,
981 raw_ostream &os) {
982 // Params with (possibly) default args.
983 auto valueBuilderParams =
984 llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
985 llvm::SmallVector<llvm::StringRef> argMaybeDefault =
986 llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
987 auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
988 if (argMaybeDefault.size() == 2)
989 return arg + "=" + argMaybeDefault[1].str();
990 return arg;
992 // Actual args passed to op builder (e.g., opParam=op_param).
993 auto opBuilderArgs = llvm::map_range(
994 llvm::make_filter_range(functionArgs,
995 [](const std::string &s) { return s != "*"; }),
996 [](const std::string &arg) {
997 auto lhs = *llvm::split(arg, "=").begin();
998 return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
1000 std::string nameWithoutDialect =
1001 op.getOperationName().substr(op.getOperationName().find('.') + 1);
1002 os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
1003 op.getCppClassName(),
1004 llvm::join(valueBuilderParams, ", "),
1005 llvm::join(opBuilderArgs, ", "),
1006 (op.getNumResults() > 1
1007 ? "_Sequence[_ods_ir.OpResult]"
1008 : (op.getNumResults() > 0 ? "_ods_ir.OpResult"
1009 : "_ods_ir.Operation")));
1012 /// Emits bindings for a specific Op to the given output stream.
1013 static void emitOpBindings(const Operator &op, raw_ostream &os) {
1014 os << llvm::formatv(opClassTemplate, op.getCppClassName(),
1015 op.getOperationName());
1017 // Sized segments.
1018 if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1019 emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
1021 if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1022 emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
1025 emitRegionAttributes(op, os);
1026 llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
1027 emitOperandAccessors(op, os);
1028 emitAttributeAccessors(op, os);
1029 emitResultAccessors(op, os);
1030 emitRegionAccessors(op, os);
1031 emitValueBuilder(op, functionArgs, os);
1034 /// Emits bindings for the dialect specified in the command line, including file
1035 /// headers and utilities. Returns `false` on success to comply with Tablegen
1036 /// registration requirements.
1037 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
1038 if (clDialectName.empty())
1039 llvm::PrintFatalError("dialect name not provided");
1041 os << fileHeader;
1042 if (!clDialectExtensionName.empty())
1043 os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
1044 else
1045 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
1047 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
1048 Operator op(rec);
1049 if (op.getDialectName() == clDialectName.getValue())
1050 emitOpBindings(op, os);
1052 return false;
1055 static GenRegistration
1056 genPythonBindings("gen-python-op-bindings",
1057 "Generate Python bindings for MLIR Ops", &emitAllOps);