1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
12 //===----------------------------------------------------------------------===//
14 #include "mlir/TableGen/GenInfo.h"
15 #include "mlir/TableGen/Operator.h"
16 #include "llvm/ADT/StringSet.h"
17 #include "llvm/Support/CommandLine.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/TableGen/Error.h"
20 #include "llvm/TableGen/Record.h"
23 using namespace mlir::tblgen
;
// NOTE(review): garbled copy — the raw-string prefix `R` is split from its
// `"Py(` delimiter, lines carry stray numeric prefixes, and the closing
// `)Py";` plus several template lines are missing. The visible lines import
// `_{0}_ops_ext` and then unconditionally set `_ods_ext_module = None`;
// presumably an omitted try/except wraps these — TODO restore from upstream.
25 /// File header and includes.
26 /// {0} is the dialect namespace.
27 constexpr const char *fileHeader
= R
"Py(
28 # Autogenerated by mlir-tblgen; don't manually edit.
30 from ._ods_common import _cext as _ods_cext
31 from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
32 _ods_ir = _ods_cext.ir
35 from . import _{0}_ops_ext as _ods_ext_module
37 _ods_ext_module = None
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the closing `)Py";` of this raw string is missing. TODO restore.
43 /// Template for dialect class:
44 /// {0} is the dialect namespace.
45 constexpr const char *dialectClassTemplate
= R
"Py(
46 @_ods_cext.register_dialect
47 class _Dialect(_ods_ir.Dialect):
48 DIALECT_NAMESPACE = "{0}"
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the closing `)Py";` is missing. TODO restore.
/// Template for the header of a dialect-extension bindings module, which
/// re-imports the `_Dialect` class from the generated `_{0}_ops_gen` module:
/// {0} is presumably the dialect namespace — TODO confirm against the caller.
53 constexpr const char *dialectExtensionTemplate
= R
"Py(
54 from ._{0}_ops_gen import _Dialect
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the closing `)Py";` is missing. TODO restore.
57 /// Template for operation class:
58 /// {0} is the Python class name;
59 /// {1} is the operation name.
60 constexpr const char *opClassTemplate
= R
"Py(
61 @_ods_cext.register_operation(_Dialect)
62 @_ods_extend_opview_class(_ods_ext_module)
63 class {0}(_ods_ir.OpView):
64 OPERATION_NAME = "{1}"
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the closing `)Py";` is missing. TODO restore.
67 /// Template for class level declarations of operand and result segment specs:
69 /// {0} is either "OPERAND" or "RESULT"
70 /// {1} is the segment spec
71 /// Each segment spec is either None (default) or an array of integers where:
73 /// 1 = single element (expect non sequence operand/result)
74 /// 0 = optional element (expect a value or None)
75 /// -1 = operand/result is a sequence corresponding to a variadic group
76 constexpr const char *opClassSizedSegmentsTemplate
= R
"Py(
77 _ODS_{0}_SEGMENTS = {1}
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the closing `)Py";` is missing. TODO restore.
80 /// Template for class level declarations of the _ODS_REGIONS spec:
81 /// {0} is the minimum number of regions
82 /// {1} is the Python bool literal for hasNoVariadicRegions
83 constexpr const char *opClassRegionSpecTemplate
= R
"Py(
84 _ODS_REGIONS = ({0}, {1})
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the accessor header that would consume {0} (e.g. `def {0}(self):`) and the
// closing `)Py";` are not visible. TODO restore from upstream.
87 /// Template for single-element accessor:
88 /// {0} is the name of the accessor;
89 /// {1} is either 'operand' or 'result';
90 /// {2} is the position in the element list.
91 constexpr const char *opSingleTemplate
= R
"Py(
94 return self.operation.{1}s[{2}]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the accessor header using {0} and the closing `)Py";` are not visible.
// TODO restore from upstream.
97 /// Template for single-element accessor after a variable-length group:
98 /// {0} is the name of the accessor;
99 /// {1} is either 'operand' or 'result';
100 /// {2} is the total number of element groups;
101 /// {3} is the position of the current group in the group list.
102 /// This works for both a single variadic group (non-negative length) and a
103 /// single optional element (zero length if the element is absent).
104 constexpr const char *opSingleAfterVariableTemplate
= R
"Py(
107 _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
108 return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the accessor header using {0} and the closing `)Py";` are not visible.
// TODO restore from upstream.
111 /// Template for an optional element accessor:
112 /// {0} is the name of the accessor;
113 /// {1} is either 'operand' or 'result';
114 /// {2} is the total number of element groups;
115 /// {3} is the position of the current group in the group list.
116 /// This works if we have only one variable-length group (and it's the optional
117 /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
118 /// smaller than the total number of groups.
119 constexpr const char *opOneOptionalTemplate
= R
"Py(
122 return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the accessor header using {0} and the closing `)Py";` are not visible.
// TODO restore from upstream.
125 /// Template for the variadic group accessor in the single variadic group case:
126 /// {0} is the name of the accessor;
127 /// {1} is either 'operand' or 'result';
128 /// {2} is the total number of element groups;
129 /// {3} is the position of the current group in the group list.
130 constexpr const char *opOneVariadicTemplate
= R
"Py(
133 _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
134 return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the accessor header using {0} is not visible. TODO restore.
// NOTE(review): the visible template passes bare `operation.{1}s`, whereas
// every other accessor template uses `self.operation`; unless an omitted line
// binds `operation`, the generated accessor would raise NameError — verify.
137 /// First part of the template for equally-sized variadic group accessor:
138 /// {0} is the name of the accessor;
139 /// {1} is either 'operand' or 'result';
140 /// {2} is the total number of variadic groups;
141 /// {3} is the number of non-variadic groups preceding the current group;
142 /// {4} is the number of variadic groups preceding the current group.
143 constexpr const char *opVariadicEqualPrefixTemplate
= R
"Py(
146 start, pg = _ods_equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the closing `)Py";` is missing. TODO restore.
148 /// Second part of the template for equally-sized case, accessing a single element:
150 /// {0} is either 'operand' or 'result'.
151 constexpr const char *opVariadicEqualSimpleTemplate
= R
"Py(
152 return self.operation.{0}s[start]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// and the closing `)Py";` is missing. TODO restore.
155 /// Second part of the template for equally-sized case, accessing a variadic group:
157 /// {0} is either 'operand' or 'result'.
158 constexpr const char *opVariadicEqualVariadicTemplate
= R
"Py(
159 return self.operation.{0}s[start:start + pg]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes,
// the `"{1}_segment_sizes"` key is broken across lines, and the accessor
// header, the `return ...{3}` line and the closing `)Py";` are not visible.
// TODO restore from upstream.
162 /// Template for an attribute-sized group accessor:
163 /// {0} is the name of the accessor;
164 /// {1} is either 'operand' or 'result';
165 /// {2} is the position of the group in the group list;
166 /// {3} is a return suffix (expected [0] for single-element, empty for
167 /// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
168 constexpr const char *opVariadicSegmentTemplate
= R
"Py(
171 {1}_range = _ods_segmented_accessor(
173 self.operation.attributes["{1}_segment_sizes
"], {2})
/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
/// {0} is either 'operand' or 'result'.
/// Applied to the segment range bound to `{0}_range`, the suffix yields the
/// single element of the segment, or None when the segment is empty (i.e. the
/// optional element is absent).
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the accessor header using {0} and the closing `)Py";` are not visible.
// TODO restore from upstream.
183 /// Template for an operation attribute getter:
184 /// {0} is the name of the attribute sanitized for Python;
185 /// {1} is the Python type of the attribute;
186 /// {2} is the original name of the attribute.
187 constexpr const char *attributeGetterTemplate
= R
"Py(
190 return {1}(self.operation.attributes["{2}"])
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the accessor header, the body of the `if ... not in` branch (presumably
// `return None`) and the closing `)Py";` are not visible. TODO restore.
193 /// Template for an optional operation attribute getter:
194 /// {0} is the name of the attribute sanitized for Python;
195 /// {1} is the Python type of the attribute;
196 /// {2} is the original name of the attribute.
197 constexpr const char *optionalAttributeGetterTemplate
= R
"Py(
200 if "{2}" not in self.operation.attributes:
202 return {1}(self.operation.attributes["{2}"])
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the accessor header using {0} and the closing `)Py";` are not visible.
// TODO restore from upstream.
205 /// Template for a getter of a unit operation attribute, returns True if the
206 /// unit attribute is present, False otherwise (unit attributes have meaning
207 /// by mere presence):
208 /// {0} is the name of the attribute sanitized for Python,
209 /// {1} is the original name of the attribute.
210 constexpr const char *unitAttributeGetterTemplate
= R
"Py(
213 return "{1}" in self.operation.attributes
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the line guarding the `raise` (presumably `if value is None:`) and the
// closing `)Py";` are not visible, and the ValueError message literal is
// broken across lines, which would be invalid Python. TODO restore.
216 /// Template for an operation attribute setter:
217 /// {0} is the name of the attribute sanitized for Python;
218 /// {1} is the original name of the attribute.
219 constexpr const char *attributeSetterTemplate
= R
"Py(
221 def {0}(self, value):
223 raise ValueError("'None' not allowed as value
for mandatory attributes
")
224 self.operation.attributes["{1}"] = value
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// a line preceding `def` and the closing `)Py";` appear to be missing.
// TODO restore from upstream.
227 /// Template for a setter of an optional operation attribute, setting to None
228 /// removes the attribute:
229 /// {0} is the name of the attribute sanitized for Python;
230 /// {1} is the original name of the attribute.
231 constexpr const char *optionalAttributeSetterTemplate
= R
"Py(
233 def {0}(self, value):
234 if value is not None:
235 self.operation.attributes["{1}"] = value
236 elif "{1}" in self.operation.attributes:
237 del self.operation.attributes["{1}"]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the `if` matching the visible `elif` (presumably a truthiness test of
// `value`) and the closing `)Py";` are not visible. TODO restore.
240 /// Template for a setter of a unit operation attribute, setting to None or
241 /// False removes the attribute:
242 /// {0} is the name of the attribute sanitized for Python;
243 /// {1} is the original name of the attribute.
244 constexpr const char *unitAttributeSetterTemplate
= R
"Py(
246 def {0}(self, value):
248 self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
249 elif "{1}" in self.operation.attributes:
250 del self.operation.attributes["{1}"]
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the deleter header using {0} and the closing `)Py";` are not visible.
// TODO restore from upstream.
253 /// Template for a deleter of an optional or a unit operation attribute, removes
254 /// the attribute from the operation:
255 /// {0} is the name of the attribute sanitized for Python;
256 /// {1} is the original name of the attribute.
257 constexpr const char *attributeDeleterTemplate
= R
"Py(
260 del self.operation.attributes["{1}"]
// NOTE(review): garbled copy — `R` split from `"PY(`, stray numeric prefixes;
// the accessor header and the closing `)PY";` are not visible. TODO restore.
/// Template for a region accessor returning `self.regions[{1}]`:
/// {1} is the index of the region; {0} is presumably the accessor name
/// (not visible here) — TODO confirm.
263 constexpr const char *regionAccessorTemplate
= R
"PY(
266 return self.regions[{1}]
269 static llvm::cl::OptionCategory
270 clOpPythonBindingCat("Options for -gen-python-op-bindings");
272 static llvm::cl::opt
<std::string
>
273 clDialectName("bind-dialect",
274 llvm::cl::desc("The dialect to run the generator for"),
275 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat
));
277 static llvm::cl::opt
<std::string
> clDialectExtensionName(
278 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
279 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat
));
281 using AttributeClasses
= DenseMap
<StringRef
, StringRef
>;
283 /// Checks whether `str` is a Python keyword or would shadow builtin function.
284 static bool isPythonReserved(StringRef str
) {
285 static llvm::StringSet
<> reserved(
286 {"and", "as", "assert", "break", "callable", "class",
287 "continue", "def", "del", "elif", "else", "except",
288 "finally", "for", "from", "global", "if", "import",
289 "in", "is", "lambda", "nonlocal", "not", "or",
290 "pass", "raise", "return", "issubclass", "try", "type",
291 "while", "with", "yield"});
292 return reserved
.contains(str
);
295 /// Checks whether `str` would shadow a generated variable or attribute
296 /// part of the OpView API.
297 static bool isODSReserved(StringRef str
) {
298 static llvm::StringSet
<> reserved(
299 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
300 "loc", "verify", "regions", "results", "self", "operation",
301 "DIALECT_NAMESPACE", "OPERATION_NAME"});
302 return str
.startswith("_ods_") || str
.endswith("_ods") ||
303 reserved
.contains(str
);
306 /// Modifies the `name` in a way that it becomes suitable for Python bindings
307 /// (does not change the `name` if it already is suitable) and returns the
308 /// modified version.
309 static std::string
sanitizeName(StringRef name
) {
310 if (isPythonReserved(name
) || isODSReserved(name
))
311 return (name
+ "_").str();
315 static std::string
attrSizedTraitForKind(const char *kind
) {
316 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
317 llvm::StringRef(kind
).take_front().upper(),
318 llvm::StringRef(kind
).drop_front());
// NOTE(review): this definition is garbled in this copy — tokens are split
// across lines, lines carry stray numeric prefixes, and several lines are
// missing, including:
//  - the name/opening brace of the last parameter (used below as
//    `getElement(op, i)`);
//  - the assert message;
//  - whatever each `if (element.name.empty())` guarded (presumably
//    `continue;`; as written, the next statement becomes its body);
//  - the `else` before the `opSingleTemplate` case and between
//    `++numPrecedingVariadic` / `++numPrecedingSimple`;
//  - the `trailing` value for non-variadic elements;
//  - some trailing format arguments, the early `return`s and closing braces.
// TODO restore from upstream before building.
//
// Intended contract, from the visible code and comments:
//  - at most one variable-length group: each named element gets a positional
//    accessor (opOneOptionalTemplate / opOneVariadicTemplate for the
//    variable-length group, opSingleAfterVariableTemplate after it,
//    opSingleTemplate otherwise);
//  - `SameVariadic{Operand,Result}Size` trait: equally-sized accessors from
//    opVariadicEqualPrefixTemplate plus the simple/variadic suffix template;
//  - `AttrSized{Operand,Result}Segments` trait: segmented accessors via
//    opVariadicSegmentTemplate with a suffix that unwraps single/optional
//    elements;
//  - otherwise: PrintFatalError("unsupported <kind> structure") (presumably
//    only reached when no branch above returned).
321 /// Emits accessors to "elements" of an Op definition. Currently, the supported
322 /// elements are operands and results, indicated by `kind`, which must be either
323 /// `operand` or `result` and is used verbatim in the emitted code.
324 static void emitElementAccessors(
325 const Operator
&op
, raw_ostream
&os
, const char *kind
,
326 llvm::function_ref
<unsigned(const Operator
&)> getNumVariableLength
,
327 llvm::function_ref
<int(const Operator
&)> getNumElements
,
328 llvm::function_ref
<const NamedTypeConstraint
&(const Operator
&, int)>
330 assert(llvm::is_contained(
331 llvm::SmallVector
<StringRef
, 2>{"operand", "result"}, kind
) &&
334 // Traits indicating how to process variadic elements.
335 std::string sameSizeTrait
=
336 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
337 llvm::StringRef(kind
).take_front().upper(),
338 llvm::StringRef(kind
).drop_front());
339 std::string attrSizedTrait
= attrSizedTraitForKind(kind
);
341 unsigned numVariableLength
= getNumVariableLength(op
);
343 // If there is only one variable-length element group, its size can be
344 // inferred from the total number of elements. If there are none, the
345 // generation is straightforward.
346 if (numVariableLength
<= 1) {
347 bool seenVariableLength
= false;
348 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
349 const NamedTypeConstraint
&element
= getElement(op
, i
);
350 if (element
.isVariableLength())
351 seenVariableLength
= true;
352 if (element
.name
.empty())
354 if (element
.isVariableLength()) {
355 os
<< llvm::formatv(element
.isOptional() ? opOneOptionalTemplate
356 : opOneVariadicTemplate
,
357 sanitizeName(element
.name
), kind
,
358 getNumElements(op
), i
);
359 } else if (seenVariableLength
) {
360 os
<< llvm::formatv(opSingleAfterVariableTemplate
,
361 sanitizeName(element
.name
), kind
,
362 getNumElements(op
), i
);
364 os
<< llvm::formatv(opSingleTemplate
, sanitizeName(element
.name
), kind
,
371 // Handle the operations where variadic groups have the same size.
372 if (op
.getTrait(sameSizeTrait
)) {
373 int numPrecedingSimple
= 0;
374 int numPrecedingVariadic
= 0;
375 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
376 const NamedTypeConstraint
&element
= getElement(op
, i
);
377 if (!element
.name
.empty()) {
378 os
<< llvm::formatv(opVariadicEqualPrefixTemplate
,
379 sanitizeName(element
.name
), kind
, numVariableLength
,
380 numPrecedingSimple
, numPrecedingVariadic
);
381 os
<< llvm::formatv(element
.isVariableLength()
382 ? opVariadicEqualVariadicTemplate
383 : opVariadicEqualSimpleTemplate
,
386 if (element
.isVariableLength())
387 ++numPrecedingVariadic
;
389 ++numPrecedingSimple
;
394 // Handle the operations where the size of groups (variadic or not) is
395 // provided as an attribute. For non-variadic elements, make sure to return
396 // an element rather than a singleton container.
397 if (op
.getTrait(attrSizedTrait
)) {
398 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
399 const NamedTypeConstraint
&element
= getElement(op
, i
);
400 if (element
.name
.empty())
402 std::string trailing
;
403 if (!element
.isVariableLength())
405 else if (element
.isOptional())
406 trailing
= std::string(
407 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate
, kind
));
408 os
<< llvm::formatv(opVariadicSegmentTemplate
, sanitizeName(element
.name
),
414 llvm::PrintFatalError("unsupported " + llvm::Twine(kind
) + " structure");
417 /// Free function helpers accessing Operator components.
418 static int getNumOperands(const Operator
&op
) { return op
.getNumOperands(); }
419 static const NamedTypeConstraint
&getOperand(const Operator
&op
, int i
) {
420 return op
.getOperand(i
);
422 static int getNumResults(const Operator
&op
) { return op
.getNumResults(); }
423 static const NamedTypeConstraint
&getResult(const Operator
&op
, int i
) {
424 return op
.getResult(i
);
427 /// Emits accessors to Op operands.
428 static void emitOperandAccessors(const Operator
&op
, raw_ostream
&os
) {
429 auto getNumVariableLengthOperands
= [](const Operator
&oper
) {
430 return oper
.getNumVariableLengthOperands();
432 emitElementAccessors(op
, os
, "operand", getNumVariableLengthOperands
,
433 getNumOperands
, getOperand
);
436 /// Emits accessors Op results.
437 static void emitResultAccessors(const Operator
&op
, raw_ostream
&os
) {
438 auto getNumVariableLengthResults
= [](const Operator
&oper
) {
439 return oper
.getNumVariableLengthResults();
441 emitElementAccessors(op
, os
, "result", getNumVariableLengthResults
,
442 getNumResults
, getResult
);
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the rest of the parameter list (`os` is used
// below), the statements guarded by the skip checks (presumably `continue;`),
// trailing format arguments, the `else` branch structure and closing braces.
// TODO restore from upstream before building.
//
// For every non-derived, named attribute of `op`:
//  - UnitAttr storage: emits getter/setter/deleter (presence semantics);
//  - storage types with an entry in `attributeClasses`: emits a typed getter
//    and setter, plus a deleter when the attribute is optional;
//  - other attributes get no accessors.
// NOTE(review): `count` is queried with `getStorageType().trim()` but `lookup`
// with the untrimmed storage type. If a storage type carried surrounding
// whitespace, `pythonType` would be empty — consider trimming in both places.
445 /// Emits accessors to Op attributes.
446 static void emitAttributeAccessors(const Operator
&op
,
447 const AttributeClasses
&attributeClasses
,
449 for (const auto &namedAttr
: op
.getAttributes()) {
450 // Skip "derived" attributes because they are just C++ functions that we
451 // don't currently expose.
452 if (namedAttr
.attr
.isDerivedAttr())
455 if (namedAttr
.name
.empty())
458 std::string sanitizedName
= sanitizeName(namedAttr
.name
);
460 // Unit attributes are handled specially.
461 if (namedAttr
.attr
.getStorageType().trim().equals("::mlir::UnitAttr")) {
462 os
<< llvm::formatv(unitAttributeGetterTemplate
, sanitizedName
,
464 os
<< llvm::formatv(unitAttributeSetterTemplate
, sanitizedName
,
466 os
<< llvm::formatv(attributeDeleterTemplate
, sanitizedName
,
471 // Other kinds of attributes need a mapping to a Python type.
472 if (!attributeClasses
.count(namedAttr
.attr
.getStorageType().trim()))
475 StringRef pythonType
=
476 attributeClasses
.lookup(namedAttr
.attr
.getStorageType());
477 if (namedAttr
.attr
.isOptional()) {
478 os
<< llvm::formatv(optionalAttributeGetterTemplate
, sanitizedName
,
479 pythonType
, namedAttr
.name
);
480 os
<< llvm::formatv(optionalAttributeSetterTemplate
, sanitizedName
,
482 os
<< llvm::formatv(attributeDeleterTemplate
, sanitizedName
,
485 os
<< llvm::formatv(attributeGetterTemplate
, sanitizedName
, pythonType
,
487 os
<< llvm::formatv(attributeSetterTemplate
, sanitizedName
,
489 // Non-optional attributes cannot be deleted.
// NOTE(review): garbled copy — `R` split from `"Py(`, stray numeric prefixes;
// the body lines between the `def` and `super().__init__` are not visible.
// Presumably these initialize `operands`, `results`, `attributes`, `regions`
// and hold the {1} placeholder, since the call below uses those names. The
// closing `)Py";` is also missing. TODO restore from upstream.
494 /// Template for the default auto-generated builder.
495 /// {0} is a comma-separated list of builder arguments, including the trailing keyword arguments (presumably `loc` and `ip`, forwarded below);
497 /// {1} is the code populating the `operands`, `results`, `attributes` and
498 /// `successors` fields.
499 constexpr const char *initTemplate
= R
"Py(
500 def __init__(self, {0}):
506 super().__init__(self.build_generic(
507 attributes=attributes, results=results, operands=operands,
508 successors=_ods_successors, regions=regions, loc=loc, ip=ip))
/// Template for appending a single element to the operand/result list.
/// {0} is the field name.
/// Operands go through `_get_op_result_or_value` (imported from `_ods_common`
/// by the file header); results are appended verbatim.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Template for appending an optional element to the operand list; nothing is
/// appended when the argument is None.
/// {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
// NOTE(review): garbled copy — stray numeric prefix, and the continuation of
// this string literal is missing: it ends mid-expression ("... else "),
// presumably followed by `"None)"`. TODO restore from upstream.
// Template for appending an optional operand when operand segment sizes are
// attribute-driven: an entry is appended even when the argument is None, so
// segment positions stay aligned.
// {0} is the field name.
constexpr const char *optionalAppendAttrSizedOperandsTemplate
=
522 "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
/// Template for appending an optional element to the result list; nothing is
/// appended when the argument is None.
/// {0} is the field name.
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";
/// Template for appending a list of elements to the operand/result list.
/// {0} is the field name.
/// The plain operand variant splices the elements into `operands`. The "Pack"
/// variant appends the whole list as a single entry, which is what the
/// attribute-sized-segments builders use.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute name (the key stored in `attributes`);
/// {2} is the attribute definition name used to look up a converter in
///     `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
/// Every line after the first is inside the open parenthesis, so its
/// indentation is irrelevant to Python.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    issubclass(type({0}), _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder (also used for attributes with a default value). A None
/// argument leaves the attribute unset.
/// {0} is the builder argument name;
/// {1} is the attribute name (the key stored in `attributes`);
/// {2} is the attribute definition name used to look up a converter in
///     `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
    issubclass(type({0}), _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template setting a unit attribute in the operation builder when the
/// corresponding argument is truthy:
/// {0} is the attribute name (the key stored in `attributes`);
/// {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
    _ods_get_default_loc_context(loc)))Py";
/// Template to initialize the successors list in the builder.
/// {0} is the value to initialize the successors list to (populated by
/// populateBuilderLinesSuccessors with "None" when the op has no successors
/// and "[]" otherwise).
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
/// Template to append or extend the list of successors in the builder.
/// {0} is the list method ('append' for a single successor or 'extend' for a
///     variadic one);
/// {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
574 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
575 /// result types of the given operation.
576 static bool hasSameArgumentAndResultTypes(const Operator
&op
) {
577 return op
.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
578 op
.getNumVariableLengthResults() == 0;
581 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
582 /// result types of the given operation.
583 static bool hasFirstAttrDerivedResultTypes(const Operator
&op
) {
584 return op
.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
585 op
.getNumVariableLengthResults() == 0;
588 /// Returns true if the InferTypeOpInterface can be used to infer result types
589 /// of the given operation.
590 static bool hasInferTypeInterface(const Operator
&op
) {
591 return op
.getTrait("::mlir::InferTypeOpInterface::Trait") &&
592 op
.getNumRegions() == 0;
595 /// Returns true if there is a trait or interface that can be used to infer
596 /// result types of the given operation.
597 static bool canInferType(const Operator
&op
) {
598 return hasSameArgumentAndResultTypes(op
) ||
599 hasFirstAttrDerivedResultTypes(op
) || hasInferTypeInterface(op
);
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing:
//  - the return type before the function name;
//  - the statement guarded by `if (canInferType(op))` (presumably an early
//    return, since inferable result types need no builder arguments);
//  - the empty-name checks around the 'result' / "_gen_res_{0}" fallbacks;
//  - the closing braces.
// TODO restore from upstream before building.
//
// For each result of `op`, takes the ODS result name (or a generated fallback),
// sanitizes it with sanitizeName and appends it to `builderArgs`.
602 /// Populates `builderArgs` with result names if the builder is expected to
603 /// accept them as arguments.
605 populateBuilderArgsResults(const Operator
&op
,
606 llvm::SmallVectorImpl
<std::string
> &builderArgs
) {
607 if (canInferType(op
))
610 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
611 std::string name
= op
.getResultName(i
).str();
613 if (op
.getNumResults() == 1) {
614 // Special case for one result, make the default name be 'result'
615 // to properly match the built-in result accessor.
618 name
= llvm::formatv("_gen_res_{0}", i
);
621 name
= sanitizeName(name
);
622 builderArgs
.push_back(name
);
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the end of the doc comment, the return type, the
// empty-name check guarding the "_gen_arg_{0}" fallback, and the closing
// braces. `successorArgNames` is not used in the visible body.
// TODO restore from upstream before building.
626 /// Populates `builderArgs` with the Python-compatible names of builder function
627 /// arguments using intermixed attributes and operands in the same order as they
628 /// appear in the `arguments` field of the op definition. Additionally,
629 /// `operandNames` is populated with names of operands in their order of appearance.
632 populateBuilderArgs(const Operator
&op
,
633 llvm::SmallVectorImpl
<std::string
> &builderArgs
,
634 llvm::SmallVectorImpl
<std::string
> &operandNames
,
635 llvm::SmallVectorImpl
<std::string
> &successorArgNames
) {
637 for (int i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
638 std::string name
= op
.getArgName(i
).str();
640 name
= llvm::formatv("_gen_arg_{0}", i
);
641 name
= sanitizeName(name
);
642 builderArgs
.push_back(name
);
// Arguments that are not attributes are operands.
643 if (!op
.getArg(i
).is
<NamedAttribute
*>())
644 operandNames
.push_back(name
);
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the empty-name check guarding the
// "_gen_successor_{0}" fallback (as written, every successor would be renamed)
// and the closing braces. TODO restore from upstream before building.
// NOTE(review): `NamedSuccessor successor = ...` copies the successor, while
// populateBuilderLinesSuccessors binds `const NamedSuccessor &`; a const
// reference would suffice here too.
648 /// Populates `builderArgs` with the Python-compatible names of builder function
649 /// successor arguments. Additionally, `successorArgNames` is also populated.
650 static void populateBuilderArgsSuccessors(
651 const Operator
&op
, llvm::SmallVectorImpl
<std::string
> &builderArgs
,
652 llvm::SmallVectorImpl
<std::string
> &successorArgNames
) {
654 for (int i
= 0, e
= op
.getNumSuccessors(); i
< e
; ++i
) {
655 NamedSuccessor successor
= op
.getSuccessor(i
);
656 std::string name
= std::string(successor
.name
);
658 name
= llvm::formatv("_gen_successor_{0}", i
);
659 name
= sanitizeName(name
);
660 builderArgs
.push_back(name
);
661 successorArgNames
.push_back(name
);
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the end of the doc comment, the return type, the
// handling of a null `attribute` (dyn_cast yields null for operand arguments;
// presumably `if (!attribute) continue;` — as written it would be
// dereferenced), the `continue;` after the unit-attribute case, and closing
// braces. TODO restore from upstream before building.
//
// Emits `_ods_context = ...` first (the attribute-builder templates reference
// `_ods_context`). Then, for each attribute argument, emits:
//  - a UnitAttr line for unit attributes;
//  - an optional builder line for optional or defaulted attributes;
//  - a mandatory builder line otherwise.
665 /// Populates `builderLines` with additional lines that are required in the
666 /// builder to set up operation attributes. `argNames` is expected to contain
667 /// the names of builder arguments that correspond to op arguments, i.e. to the
668 /// operands and attributes in the same order as they appear in the `arguments`
671 populateBuilderLinesAttr(const Operator
&op
,
672 llvm::ArrayRef
<std::string
> argNames
,
673 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
674 builderLines
.push_back("_ods_context = _ods_get_default_loc_context(loc)");
675 for (int i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
676 Argument arg
= op
.getArg(i
);
677 auto *attribute
= arg
.dyn_cast
<NamedAttribute
*>();
681 // Unit attributes are handled specially.
682 if (attribute
->attr
.getStorageType().trim().equals("::mlir::UnitAttr")) {
683 builderLines
.push_back(llvm::formatv(initUnitAttributeTemplate
,
684 attribute
->name
, argNames
[i
]));
688 builderLines
.push_back(llvm::formatv(
689 attribute
->attr
.isOptional() || attribute
->attr
.hasDefaultValue()
690 ? initOptionalAttributeWithBuilderTemplate
691 : initAttributeWithBuilderTemplate
,
692 argNames
[i
], attribute
->name
, attribute
->attr
.getAttrDefName()));
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the early exit after the `None` initialization
// (presumably `return;` — otherwise both initializations would be emitted)
// and the closing braces. TODO restore from upstream before building.
//
// With no successors, emits `_ods_successors = None`. Otherwise emits
// `_ods_successors = []` plus one `append` (single successor) or `extend`
// (variadic successor) line per successor argument.
696 /// Populates `builderLines` with additional lines that are required in the
697 /// builder to set up successors. successorArgNames is expected to correspond
698 /// to the Python argument name for each successor on the op.
699 static void populateBuilderLinesSuccessors(
700 const Operator
&op
, llvm::ArrayRef
<std::string
> successorArgNames
,
701 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
702 if (successorArgNames
.empty()) {
703 builderLines
.push_back(llvm::formatv(initSuccessorsTemplate
, "None"));
707 builderLines
.push_back(llvm::formatv(initSuccessorsTemplate
, "[]"));
708 for (int i
= 0, e
= successorArgNames
.size(); i
< e
; ++i
) {
709 auto &argName
= successorArgNames
[i
];
710 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
711 builderLines
.push_back(
712 llvm::formatv(addSuccessorTemplate
,
713 successor
.isVariadic() ? "extend" : "append", argName
));
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the return type, the `if (sizedSegments) { ... }
// else { ... }` structure around the two optional and the two variadic
// template choices, and the closing braces.
// TODO restore from upstream before building.
//
// For each operand (names[i] is its builder-argument name), picks an append
// template:
//  - non-variable-length: single append;
//  - optional: the attr-sized variant (presumably when `sizedSegments`) or the
//    plain optional append;
//  - variadic: the "Pack" variant when `sizedSegments`, extend otherwise.
717 /// Populates `builderLines` with additional lines that are required in the
718 /// builder to set up op operands.
720 populateBuilderLinesOperand(const Operator
&op
,
721 llvm::ArrayRef
<std::string
> names
,
722 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
723 bool sizedSegments
= op
.getTrait(attrSizedTraitForKind("operand")) != nullptr;
725 // For each element, find or generate a name.
726 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
) {
727 const NamedTypeConstraint
&element
= op
.getOperand(i
);
728 std::string name
= names
[i
];
730 // Choose the formatting string based on the element kind.
731 llvm::StringRef formatString
;
732 if (!element
.isVariableLength()) {
733 formatString
= singleOperandAppendTemplate
;
734 } else if (element
.isOptional()) {
736 formatString
= optionalAppendAttrSizedOperandsTemplate
;
738 formatString
= optionalAppendOperandTemplate
;
741 assert(element
.isVariadic() && "unhandled element group type");
742 // If emitting with sizedSegments, then we add the actual list-typed
743 // element. Otherwise, we extend the actual operands.
745 formatString
= multiOperandAppendPackTemplate
;
747 formatString
= multiOperandAppendTemplate
;
// `.data()` is safe: formatString always points at a NUL-terminated
// constexpr template.
751 builderLines
.push_back(llvm::formatv(formatString
.data(), name
));
/// Python code template for deriving the operation result types from one of
/// its attributes:
/// - {0} is the name of the attribute from which to derive the types.
/// If the attribute is a TypeAttr, its type value is used; otherwise the type
/// of the attribute itself is used. The second statement deliberately starts
/// at column 0: it is a separate Python statement and must share the first
/// line's indentation once emitted. The remaining lines are inside the
/// parenthesis and may be indented freely.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";
/// Python code template appending {0} type {1} times to the results list.
/// {0} is a Python expression evaluating to a type (e.g. "operands[0].type");
/// {1} is the number of results.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
// NOTE(review): garbled copy — `R` split from `"PY(`, stray numeric prefixes;
// several argument lines of the `inferReturnTypes(` call, its closing
// parenthesis and the terminating `)PY";` are missing.
// TODO restore from upstream before building.
768 /// Python code template for inferring the operation results using the
769 /// corresponding interface:
770 /// - {0} is the name of the class for which the types are inferred.
771 constexpr const char *inferTypeInterfaceTemplate
=
772 R
"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
774 attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
775 context=_ods_context,
779 /// Appends the given multiline string as individual strings into
781 static void appendLineByLine(StringRef string
,
782 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
784 std::pair
<StringRef
, StringRef
> split
= std::make_pair(string
, string
);
786 split
= split
.second
.split('\n');
787 builderLines
.push_back(split
.first
.str());
788 } while (!split
.second
.empty());
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the return type, the `return;` after each inference
// path, the calls wrapping the `formatv(...).str(),` expressions (presumably
// appendLineByLine), the `if (sizedSegments)` / `else` around the variadic
// template choice, and closing braces.
// TODO restore from upstream before building.
//
// Result types come from one of three inference paths:
//  - SameOperandsAndResultType: repeat `operands[0].type`;
//  - first-attribute-derived: derive a type from the first attribute and
//    repeat it;
//  - InferTypeOpInterface: call `inferReturnTypes`.
// Without inference, result-type arguments (names[i]) are appended per result
// kind.
791 /// Populates `builderLines` with additional lines that are required in the
792 /// builder to set up op results.
794 populateBuilderLinesResult(const Operator
&op
,
795 llvm::ArrayRef
<std::string
> names
,
796 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
797 bool sizedSegments
= op
.getTrait(attrSizedTraitForKind("result")) != nullptr;
799 if (hasSameArgumentAndResultTypes(op
)) {
800 builderLines
.push_back(llvm::formatv(
801 appendSameResultsTemplate
, "operands[0].type", op
.getNumResults()));
805 if (hasFirstAttrDerivedResultTypes(op
)) {
806 const NamedAttribute
&firstAttr
= op
.getAttribute(0);
807 assert(!firstAttr
.name
.empty() && "unexpected empty name for the attribute "
808 "from which the type is derived");
810 llvm::formatv(deriveTypeFromAttrTemplate
, firstAttr
.name
).str(),
812 builderLines
.push_back(llvm::formatv(appendSameResultsTemplate
,
813 "_ods_derived_result_type",
814 op
.getNumResults()));
818 if (hasInferTypeInterface(op
)) {
820 llvm::formatv(inferTypeInterfaceTemplate
, op
.getCppClassName()).str(),
825 // For each element, find or generate a name.
826 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
827 const NamedTypeConstraint
&element
= op
.getResult(i
);
828 std::string name
= names
[i
];
830 // Choose the formatting string based on the element kind.
831 llvm::StringRef formatString
;
832 if (!element
.isVariableLength()) {
833 formatString
= singleResultAppendTemplate
;
834 } else if (element
.isOptional()) {
835 formatString
= optionalAppendResultTemplate
;
837 assert(element
.isVariadic() && "unhandled element group type");
838 // If emitting with sizedSegments, then we add the actual list-typed
839 // element. Otherwise, we extend the actual results.
841 formatString
= singleResultAppendTemplate
;
843 formatString
= multiResultAppendTemplate
;
847 builderLines
.push_back(llvm::formatv(formatString
.data(), name
));
// NOTE(review): garbled copy — tokens split across lines and stray numeric
// prefixes. Also missing: the return type, the early `return;` for ops without
// variadic regions, the `std::string name =` / `.str();` around the
// "num_<region>" name expression, and the closing brace.
// The `®ion` token below is a mangled `&region` (an `&reg` entity was
// decoded). TODO restore from upstream before building.
// NOTE(review): the assert message spells "varidic"; fix the literal when
// the block is restored.
//
// Adds a builder argument named "num_" + the last region's name (first letter
// lower-cased). It also emits `regions = <fixed region count> + <that arg>`.
851 /// If the operation has variadic regions, adds a builder argument to specify
852 /// the number of those regions and builder lines to forward it to the generic builder.
855 populateBuilderRegions(const Operator
&op
,
856 llvm::SmallVectorImpl
<std::string
> &builderArgs
,
857 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
858 if (op
.hasNoVariadicRegions())
861 // This is currently enforced when Operator is constructed.
862 assert(op
.getNumVariadicRegions() == 1 &&
863 op
.getRegion(op
.getNumRegions() - 1).isVariadic() &&
864 "expected the last region to be varidic");
866 const NamedRegion
®ion
= op
.getRegion(op
.getNumRegions() - 1);
868 ("num_" + region
.name
.take_front().lower() + region
.name
.drop_front())
870 builderArgs
.push_back(name
);
871 builderLines
.push_back(
872 llvm::formatv("regions = {0} + {1}", op
.getNumRegions() - 1, name
));
875 /// Emits a default builder constructing an operation from the list of its
876 /// result types, followed by a list of its operands.
877 static void emitDefaultOpBuilder(const Operator
&op
, raw_ostream
&os
) {
878 // If we are asked to skip default builders, comply.
879 if (op
.skipDefaultBuilders())
882 llvm::SmallVector
<std::string
> builderArgs
;
883 llvm::SmallVector
<std::string
> builderLines
;
884 llvm::SmallVector
<std::string
> operandArgNames
;
885 llvm::SmallVector
<std::string
> successorArgNames
;
886 builderArgs
.reserve(op
.getNumOperands() + op
.getNumResults() +
887 op
.getNumNativeAttributes() + op
.getNumSuccessors());
888 populateBuilderArgsResults(op
, builderArgs
);
889 size_t numResultArgs
= builderArgs
.size();
890 populateBuilderArgs(op
, builderArgs
, operandArgNames
, successorArgNames
);
891 size_t numOperandAttrArgs
= builderArgs
.size() - numResultArgs
;
892 populateBuilderArgsSuccessors(op
, builderArgs
, successorArgNames
);
894 populateBuilderLinesOperand(op
, operandArgNames
, builderLines
);
895 populateBuilderLinesAttr(
896 op
, llvm::ArrayRef(builderArgs
).drop_front(numResultArgs
), builderLines
);
897 populateBuilderLinesResult(
898 op
, llvm::ArrayRef(builderArgs
).take_front(numResultArgs
), builderLines
);
899 populateBuilderLinesSuccessors(op
, successorArgNames
, builderLines
);
900 populateBuilderRegions(op
, builderArgs
, builderLines
);
902 // Layout of builderArgs vector elements:
903 // [ result_args operand_attr_args successor_args regions ]
905 // Determine whether the argument corresponding to a given index into the
906 // builderArgs vector is a python keyword argument or not.
907 auto isKeywordArgFn
= [&](size_t builderArgIndex
) -> bool {
908 // All result, successor, and region arguments are positional arguments.
909 if ((builderArgIndex
< numResultArgs
) ||
910 (builderArgIndex
>= (numResultArgs
+ numOperandAttrArgs
)))
912 // Keyword arguments:
913 // - optional named attributes (including unit attributes)
914 // - default-valued named attributes
915 // - optional operands
916 Argument a
= op
.getArg(builderArgIndex
- numResultArgs
);
917 if (auto *nattr
= a
.dyn_cast
<NamedAttribute
*>())
918 return (nattr
->attr
.isOptional() || nattr
->attr
.hasDefaultValue());
919 if (auto *ntype
= a
.dyn_cast
<NamedTypeConstraint
*>())
920 return ntype
->isOptional();
924 // StringRefs in functionArgs refer to strings allocated by builderArgs.
925 llvm::SmallVector
<llvm::StringRef
> functionArgs
;
927 // Add positional arguments.
928 for (size_t i
= 0, cnt
= builderArgs
.size(); i
< cnt
; ++i
) {
929 if (!isKeywordArgFn(i
))
930 functionArgs
.push_back(builderArgs
[i
]);
933 // Add a bare '*' to indicate that all following arguments must be keyword
935 functionArgs
.push_back("*");
937 // Add a default 'None' value to each keyword arg string, and then add to the
938 // function args list.
939 for (size_t i
= 0, cnt
= builderArgs
.size(); i
< cnt
; ++i
) {
940 if (isKeywordArgFn(i
)) {
941 builderArgs
[i
].append("=None");
942 functionArgs
.push_back(builderArgs
[i
]);
945 functionArgs
.push_back("loc=None");
946 functionArgs
.push_back("ip=None");
947 os
<< llvm::formatv(initTemplate
, llvm::join(functionArgs
, ", "),
948 llvm::join(builderLines
, "\n "));
951 static void constructAttributeMapping(const llvm::RecordKeeper
&records
,
952 AttributeClasses
&attributeClasses
) {
953 for (const llvm::Record
*rec
:
954 records
.getAllDerivedDefinitions("PythonAttr")) {
955 attributeClasses
.try_emplace(rec
->getValueAsString("cppStorageType").trim(),
956 rec
->getValueAsString("pythonType").trim());
960 static void emitSegmentSpec(
961 const Operator
&op
, const char *kind
,
962 llvm::function_ref
<int(const Operator
&)> getNumElements
,
963 llvm::function_ref
<const NamedTypeConstraint
&(const Operator
&, int)>
966 std::string
segmentSpec("[");
967 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
968 const NamedTypeConstraint
&element
= getElement(op
, i
);
969 if (element
.isOptional()) {
970 segmentSpec
.append("0,");
971 } else if (element
.isVariadic()) {
972 segmentSpec
.append("-1,");
974 segmentSpec
.append("1,");
977 segmentSpec
.append("]");
979 os
<< llvm::formatv(opClassSizedSegmentsTemplate
, kind
, segmentSpec
);
982 static void emitRegionAttributes(const Operator
&op
, raw_ostream
&os
) {
983 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
984 // Note that the base OpView class defines this as (0, True).
985 unsigned minRegionCount
= op
.getNumRegions() - op
.getNumVariadicRegions();
986 os
<< llvm::formatv(opClassRegionSpecTemplate
, minRegionCount
,
987 op
.hasNoVariadicRegions() ? "True" : "False");
990 /// Emits named accessors to regions.
991 static void emitRegionAccessors(const Operator
&op
, raw_ostream
&os
) {
992 for (const auto &en
: llvm::enumerate(op
.getRegions())) {
993 const NamedRegion
®ion
= en
.value();
994 if (region
.name
.empty())
997 assert((!region
.isVariadic() || en
.index() == op
.getNumRegions() - 1) &&
998 "expected only the last region to be variadic");
999 os
<< llvm::formatv(regionAccessorTemplate
, sanitizeName(region
.name
),
1000 std::to_string(en
.index()) +
1001 (region
.isVariadic() ? ":" : ""));
1005 /// Emits bindings for a specific Op to the given output stream.
1006 static void emitOpBindings(const Operator
&op
,
1007 const AttributeClasses
&attributeClasses
,
1009 os
<< llvm::formatv(opClassTemplate
, op
.getCppClassName(),
1010 op
.getOperationName());
1013 if (op
.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1014 emitSegmentSpec(op
, "OPERAND", getNumOperands
, getOperand
, os
);
1016 if (op
.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1017 emitSegmentSpec(op
, "RESULT", getNumResults
, getResult
, os
);
1020 emitRegionAttributes(op
, os
);
1021 emitDefaultOpBuilder(op
, os
);
1022 emitOperandAccessors(op
, os
);
1023 emitAttributeAccessors(op
, attributeClasses
, os
);
1024 emitResultAccessors(op
, os
);
1025 emitRegionAccessors(op
, os
);
1028 /// Emits bindings for the dialect specified in the command line, including file
1029 /// headers and utilities. Returns `false` on success to comply with Tablegen
1030 /// registration requirements.
1031 static bool emitAllOps(const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
1032 if (clDialectName
.empty())
1033 llvm::PrintFatalError("dialect name not provided");
1035 AttributeClasses attributeClasses
;
1036 constructAttributeMapping(records
, attributeClasses
);
1038 bool isExtension
= !clDialectExtensionName
.empty();
1039 os
<< llvm::formatv(fileHeader
, isExtension
1040 ? clDialectExtensionName
.getValue()
1041 : clDialectName
.getValue());
1043 os
<< llvm::formatv(dialectExtensionTemplate
, clDialectName
.getValue());
1045 os
<< llvm::formatv(dialectClassTemplate
, clDialectName
.getValue());
1047 for (const llvm::Record
*rec
: records
.getAllDerivedDefinitions("Op")) {
1049 if (op
.getDialectName() == clDialectName
.getValue())
1050 emitOpBindings(op
, attributeClasses
, os
);
1055 static GenRegistration
1056 genPythonBindings("gen-python-op-bindings",
1057 "Generate Python bindings for MLIR Ops", &emitAllOps
);