1 //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
10 // binding classes wrapping a generic operation API.
12 //===----------------------------------------------------------------------===//
14 #include "OpGenHelpers.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 #include "llvm/ADT/StringSet.h"
19 #include "llvm/Support/CommandLine.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/TableGen/Error.h"
22 #include "llvm/TableGen/Record.h"
25 using namespace mlir::tblgen
;
27 /// File header and includes.
28 /// {0} is the dialect namespace.
29 constexpr const char *fileHeader
= R
"Py(
30 # Autogenerated by mlir-tblgen; don't manually edit.
32 from ._ods_common import _cext as _ods_cext
33 from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results
34 _ods_ir = _ods_cext.ir
37 from typing import Sequence as _Sequence, Union as _Union
41 /// Template for dialect class:
42 /// {0} is the dialect namespace.
43 constexpr const char *dialectClassTemplate
= R
"Py(
44 @_ods_cext.register_dialect
45 class _Dialect(_ods_ir.Dialect):
46 DIALECT_NAMESPACE = "{0}"
51 constexpr const char *dialectExtensionTemplate
= R
"Py(
52 from ._{0}_ops_gen import _Dialect
55 /// Template for operation class:
56 /// {0} is the Python class name;
57 /// {1} is the operation name.
58 constexpr const char *opClassTemplate
= R
"Py(
59 @_ods_cext.register_operation(_Dialect)
60 class {0}(_ods_ir.OpView):
61 OPERATION_NAME = "{1}"
64 /// Template for class level declarations of operand and result
66 /// {0} is either "OPERAND" or "RESULT"
67 /// {1} is the segment spec
68 /// Each segment spec is either None (default) or an array of integers
70 /// 1 = single element (expect non sequence operand/result)
71 /// 0 = optional element (expect a value or std::nullopt)
72 /// -1 = operand/result is a sequence corresponding to a variadic
73 constexpr const char *opClassSizedSegmentsTemplate
= R
"Py(
74 _ODS_{0}_SEGMENTS = {1}
77 /// Template for class level declarations of the _ODS_REGIONS spec:
78 /// {0} is the minimum number of regions
79 /// {1} is the Python bool literal for hasNoVariadicRegions
80 constexpr const char *opClassRegionSpecTemplate
= R
"Py(
81 _ODS_REGIONS = ({0}, {1})
84 /// Template for single-element accessor:
85 /// {0} is the name of the accessor;
86 /// {1} is either 'operand' or 'result';
87 /// {2} is the position in the element list.
88 constexpr const char *opSingleTemplate
= R
"Py(
91 return self.operation.{1}s[{2}]
94 /// Template for single-element accessor after a variable-length group:
95 /// {0} is the name of the accessor;
96 /// {1} is either 'operand' or 'result';
97 /// {2} is the total number of element groups;
98 /// {3} is the position of the current group in the group list.
99 /// This works for both a single variadic group (non-negative length) and a
100 /// single optional element (zero length if the element is absent).
101 constexpr const char *opSingleAfterVariableTemplate
= R
"Py(
104 _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
105 return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
108 /// Template for an optional element accessor:
109 /// {0} is the name of the accessor;
110 /// {1} is either 'operand' or 'result';
111 /// {2} is the total number of element groups;
112 /// {3} is the position of the current group in the group list.
113 /// This works if we have only one variable-length group (and it's the optional
114 /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
115 /// smaller than the total number of groups.
116 constexpr const char *opOneOptionalTemplate
= R
"Py(
119 return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
122 /// Template for the variadic group accessor in the single variadic group case:
123 /// {0} is the name of the accessor;
124 /// {1} is either 'operand' or 'result';
125 /// {2} is the total number of element groups;
126 /// {3} is the position of the current group in the group list.
127 constexpr const char *opOneVariadicTemplate
= R
"Py(
130 _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
131 return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
134 /// First part of the template for equally-sized variadic group accessor:
135 /// {0} is the name of the accessor;
136 /// {1} is either 'operand' or 'result';
137 /// {2} is the total number of variadic groups;
138 /// {3} is the number of non-variadic groups preceding the current group;
139 /// {4} is the number of variadic groups preceding the current group.
140 constexpr const char *opVariadicEqualPrefixTemplate
= R
"Py(
143 start, pg = _ods_equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
145 /// Second part of the template for equally-sized case, accessing a single
147 /// {0} is either 'operand' or 'result'.
148 constexpr const char *opVariadicEqualSimpleTemplate
= R
"Py(
149 return self.operation.{0}s[start]
152 /// Second part of the template for equally-sized case, accessing a variadic
154 /// {0} is either 'operand' or 'result'.
155 constexpr const char *opVariadicEqualVariadicTemplate
= R
"Py(
156 return self.operation.{0}s[start:start + pg]
159 /// Template for an attribute-sized group accessor:
160 /// {0} is the name of the accessor;
161 /// {1} is either 'operand' or 'result';
162 /// {2} is the position of the group in the group list;
163 /// {3} is a return suffix (expected [0] for single-element, empty for
164 /// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
165 constexpr const char *opVariadicSegmentTemplate
= R
"Py(
168 {1}_range = _ods_segmented_accessor(
170 self.operation.attributes["{1}SegmentSizes
"], {2})
/// Return suffix used by the attribute-sized accessor when the accessed
/// element is optional: picks the first element of the computed range, or
/// evaluates to None when that range is empty.
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    "[0] if len({0}_range) > 0 else None";
180 /// Template for an operation attribute getter:
181 /// {0} is the name of the attribute sanitized for Python;
182 /// {1} is the original name of the attribute.
183 constexpr const char *attributeGetterTemplate
= R
"Py(
186 return self.operation.attributes["{1}"]
189 /// Template for an optional operation attribute getter:
190 /// {0} is the name of the attribute sanitized for Python;
191 /// {1} is the original name of the attribute.
192 constexpr const char *optionalAttributeGetterTemplate
= R
"Py(
195 if "{1}" not in self.operation.attributes:
197 return self.operation.attributes["{1}"]
200 /// Template for a getter of a unit operation attribute, returns True if the
201 /// unit attribute is present, False otherwise (unit attributes have meaning
202 /// by mere presence):
203 /// {0} is the name of the attribute sanitized for Python,
204 /// {1} is the original name of the attribute.
205 constexpr const char *unitAttributeGetterTemplate
= R
"Py(
208 return "{1}" in self.operation.attributes
211 /// Template for an operation attribute setter:
212 /// {0} is the name of the attribute sanitized for Python;
213 /// {1} is the original name of the attribute.
214 constexpr const char *attributeSetterTemplate
= R
"Py(
216 def {0}(self, value):
218 raise ValueError("'None' not allowed as value
for mandatory attributes
")
219 self.operation.attributes["{1}"] = value
222 /// Template for a setter of an optional operation attribute, setting to None
223 /// removes the attribute:
224 /// {0} is the name of the attribute sanitized for Python;
225 /// {1} is the original name of the attribute.
226 constexpr const char *optionalAttributeSetterTemplate
= R
"Py(
228 def {0}(self, value):
229 if value is not None:
230 self.operation.attributes["{1}"] = value
231 elif "{1}" in self.operation.attributes:
232 del self.operation.attributes["{1}"]
235 /// Template for a setter of a unit operation attribute, setting to None or
236 /// False removes the attribute:
237 /// {0} is the name of the attribute sanitized for Python;
238 /// {1} is the original name of the attribute.
239 constexpr const char *unitAttributeSetterTemplate
= R
"Py(
241 def {0}(self, value):
243 self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
244 elif "{1}" in self.operation.attributes:
245 del self.operation.attributes["{1}"]
248 /// Template for a deleter of an optional or a unit operation attribute, removes
249 /// the attribute from the operation:
250 /// {0} is the name of the attribute sanitized for Python;
251 /// {1} is the original name of the attribute.
252 constexpr const char *attributeDeleterTemplate
= R
"Py(
255 del self.operation.attributes["{1}"]
258 constexpr const char *regionAccessorTemplate
= R
"Py(
261 return self.regions[{1}]
264 constexpr const char *valueBuilderTemplate
= R
"Py(
266 return _get_op_result_or_op_results({1}({3}))
/// Command-line option category, titled "Options for -gen-python-op-bindings",
/// that the generator's options below are registered under.
269 static llvm::cl::OptionCategory
270 clOpPythonBindingCat("Options for -gen-python-op-bindings");
/// String option `bind-dialect`: the dialect to run the generator for.
/// Initialized to the empty string and listed under clOpPythonBindingCat.
272 static llvm::cl::opt
<std::string
>
273 clDialectName("bind-dialect",
274 llvm::cl::desc("The dialect to run the generator for"),
275 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat
));
/// String option `dialect-extension`: the prefix of the dialect extension.
/// Initialized to the empty string and listed under clOpPythonBindingCat.
277 static llvm::cl::opt
<std::string
> clDialectExtensionName(
278 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
279 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat
));
/// Map with StringRef keys and StringRef values.
/// NOTE(review): no use of this alias is visible in this part of the file;
/// presumably it maps attribute names to class names -- confirm at its users.
281 using AttributeClasses
= DenseMap
<StringRef
, StringRef
>;
283 /// Checks whether `str` would shadow a generated variable or attribute
284 /// part of the OpView API.
285 static bool isODSReserved(StringRef str
) {
286 static llvm::StringSet
<> reserved(
287 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
288 "loc", "verify", "regions", "results", "self", "operation",
289 "DIALECT_NAMESPACE", "OPERATION_NAME"});
290 return str
.startswith("_ods_") || str
.endswith("_ods") ||
291 reserved
.contains(str
);
294 /// Modifies the `name` in a way that it becomes suitable for Python bindings
295 /// (does not change the `name` if it already is suitable) and returns the
296 /// modified version.
297 static std::string
sanitizeName(StringRef name
) {
298 std::string processedStr
= name
.str();
300 processedStr
.begin(), processedStr
.end(),
301 [](char c
) { return !llvm::isAlnum(c
); }, '_');
303 if (llvm::isDigit(*processedStr
.begin()))
304 return "_" + processedStr
;
306 if (isPythonReserved(processedStr
) || isODSReserved(processedStr
))
307 return processedStr
+ "_";
311 static std::string
attrSizedTraitForKind(const char *kind
) {
312 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
313 llvm::StringRef(kind
).take_front().upper(),
314 llvm::StringRef(kind
).drop_front());
317 /// Emits accessors to "elements" of an Op definition. Currently, the supported
318 /// elements are operands and results, indicated by `kind`, which must be either
319 /// `operand` or `result` and is used verbatim in the emitted code.
320 static void emitElementAccessors(
321 const Operator
&op
, raw_ostream
&os
, const char *kind
,
322 llvm::function_ref
<unsigned(const Operator
&)> getNumVariableLength
,
323 llvm::function_ref
<int(const Operator
&)> getNumElements
,
324 llvm::function_ref
<const NamedTypeConstraint
&(const Operator
&, int)>
326 assert(llvm::is_contained(
327 llvm::SmallVector
<StringRef
, 2>{"operand", "result"}, kind
) &&
330 // Traits indicating how to process variadic elements.
331 std::string sameSizeTrait
=
332 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
333 llvm::StringRef(kind
).take_front().upper(),
334 llvm::StringRef(kind
).drop_front());
335 std::string attrSizedTrait
= attrSizedTraitForKind(kind
);
337 unsigned numVariableLength
= getNumVariableLength(op
);
339 // If there is only one variable-length element group, its size can be
340 // inferred from the total number of elements. If there are none, the
341 // generation is straightforward.
342 if (numVariableLength
<= 1) {
343 bool seenVariableLength
= false;
344 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
345 const NamedTypeConstraint
&element
= getElement(op
, i
);
346 if (element
.isVariableLength())
347 seenVariableLength
= true;
348 if (element
.name
.empty())
350 if (element
.isVariableLength()) {
351 os
<< llvm::formatv(element
.isOptional() ? opOneOptionalTemplate
352 : opOneVariadicTemplate
,
353 sanitizeName(element
.name
), kind
,
354 getNumElements(op
), i
);
355 } else if (seenVariableLength
) {
356 os
<< llvm::formatv(opSingleAfterVariableTemplate
,
357 sanitizeName(element
.name
), kind
,
358 getNumElements(op
), i
);
360 os
<< llvm::formatv(opSingleTemplate
, sanitizeName(element
.name
), kind
,
367 // Handle the operations where variadic groups have the same size.
368 if (op
.getTrait(sameSizeTrait
)) {
369 int numPrecedingSimple
= 0;
370 int numPrecedingVariadic
= 0;
371 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
372 const NamedTypeConstraint
&element
= getElement(op
, i
);
373 if (!element
.name
.empty()) {
374 os
<< llvm::formatv(opVariadicEqualPrefixTemplate
,
375 sanitizeName(element
.name
), kind
, numVariableLength
,
376 numPrecedingSimple
, numPrecedingVariadic
);
377 os
<< llvm::formatv(element
.isVariableLength()
378 ? opVariadicEqualVariadicTemplate
379 : opVariadicEqualSimpleTemplate
,
382 if (element
.isVariableLength())
383 ++numPrecedingVariadic
;
385 ++numPrecedingSimple
;
390 // Handle the operations where the size of groups (variadic or not) is
391 // provided as an attribute. For non-variadic elements, make sure to return
392 // an element rather than a singleton container.
393 if (op
.getTrait(attrSizedTrait
)) {
394 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
395 const NamedTypeConstraint
&element
= getElement(op
, i
);
396 if (element
.name
.empty())
398 std::string trailing
;
399 if (!element
.isVariableLength())
401 else if (element
.isOptional())
402 trailing
= std::string(
403 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate
, kind
));
404 os
<< llvm::formatv(opVariadicSegmentTemplate
, sanitizeName(element
.name
),
410 llvm::PrintFatalError("unsupported " + llvm::Twine(kind
) + " structure");
413 /// Free function helpers accessing Operator components.
414 static int getNumOperands(const Operator
&op
) { return op
.getNumOperands(); }
415 static const NamedTypeConstraint
&getOperand(const Operator
&op
, int i
) {
416 return op
.getOperand(i
);
418 static int getNumResults(const Operator
&op
) { return op
.getNumResults(); }
419 static const NamedTypeConstraint
&getResult(const Operator
&op
, int i
) {
420 return op
.getResult(i
);
423 /// Emits accessors to Op operands.
424 static void emitOperandAccessors(const Operator
&op
, raw_ostream
&os
) {
425 auto getNumVariableLengthOperands
= [](const Operator
&oper
) {
426 return oper
.getNumVariableLengthOperands();
428 emitElementAccessors(op
, os
, "operand", getNumVariableLengthOperands
,
429 getNumOperands
, getOperand
);
432 /// Emits accessors to Op results.
433 static void emitResultAccessors(const Operator
&op
, raw_ostream
&os
) {
434 auto getNumVariableLengthResults
= [](const Operator
&oper
) {
435 return oper
.getNumVariableLengthResults();
437 emitElementAccessors(op
, os
, "result", getNumVariableLengthResults
,
438 getNumResults
, getResult
);
441 /// Emits accessors to Op attributes.
442 static void emitAttributeAccessors(const Operator
&op
, raw_ostream
&os
) {
443 for (const auto &namedAttr
: op
.getAttributes()) {
444 // Skip "derived" attributes because they are just C++ functions that we
445 // don't currently expose.
446 if (namedAttr
.attr
.isDerivedAttr())
449 if (namedAttr
.name
.empty())
452 std::string sanitizedName
= sanitizeName(namedAttr
.name
);
454 // Unit attributes are handled specially.
455 if (namedAttr
.attr
.getStorageType().trim().equals("::mlir::UnitAttr")) {
456 os
<< llvm::formatv(unitAttributeGetterTemplate
, sanitizedName
,
458 os
<< llvm::formatv(unitAttributeSetterTemplate
, sanitizedName
,
460 os
<< llvm::formatv(attributeDeleterTemplate
, sanitizedName
,
465 if (namedAttr
.attr
.isOptional()) {
466 os
<< llvm::formatv(optionalAttributeGetterTemplate
, sanitizedName
,
468 os
<< llvm::formatv(optionalAttributeSetterTemplate
, sanitizedName
,
470 os
<< llvm::formatv(attributeDeleterTemplate
, sanitizedName
,
473 os
<< llvm::formatv(attributeGetterTemplate
, sanitizedName
,
475 os
<< llvm::formatv(attributeSetterTemplate
, sanitizedName
,
477 // Non-optional attributes cannot be deleted.
482 /// Template for the default auto-generated builder.
483 /// {0} is a comma-separated list of builder arguments, including the trailing
485 /// {1} is the code populating `operands`, `results` and `attributes`,
486 /// `successors` fields.
487 constexpr const char *initTemplate
= R
"Py(
488 def __init__(self, {0}):
494 super().__init__(self.build_generic({2}))
/// Python statement templates appending exactly one element to the builder's
/// `operands` / `results` list. The operand form first passes the value
/// through `_get_op_result_or_value`; the result form appends it as is.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    R"Py(operands.append(_get_op_result_or_value({0})))Py";
constexpr const char *singleResultAppendTemplate =
    R"Py(results.append({0}))Py";
503 /// Template for appending an optional element to the operand/result list.
504 /// {0} is the field name.
// Appends the operand {0} (via `_get_op_result_or_value`) only when it is not
// None; nothing is appended otherwise.
constexpr const char *optionalAppendOperandTemplate =
    R"Py(if {0} is not None: operands.append(_get_op_result_or_value({0})))Py";
507 constexpr const char *optionalAppendAttrSizedOperandsTemplate
=
508 "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
// Appends the result {0} only when it is not None; nothing is appended
// otherwise.
constexpr const char *optionalAppendResultTemplate =
    R"Py(if {0} is not None: results.append({0}))Py";
/// Python statement templates adding a list of elements to the builder's
/// `operands` / `results` list.
///   {0} is the field name.
/// The plain operand form extends `operands` with the converted values, the
/// "Pack" form appends the converted list as a single entry, and the result
/// form extends `results` with the given list.
constexpr const char *multiOperandAppendTemplate =
    R"Py(operands.extend(_get_op_results_or_values({0})))Py";
constexpr const char *multiOperandAppendPackTemplate =
    R"Py(operands.append(_get_op_results_or_values({0})))Py";
constexpr const char *multiResultAppendTemplate =
    R"Py(results.extend({0}))Py";
521 /// Template for attribute builder from raw input in the operation builder.
522 /// {0} is the builder argument name;
523 /// {1} is the attribute name;
524 /// {2} is the attribute builder from raw.
525 /// Use the value the user passed in if either it is already an Attribute or
526 /// there is no method registered to make it an Attribute.
527 constexpr const char *initAttributeWithBuilderTemplate
=
528 R
"Py(attributes["{1}"] = ({0} if (
529 issubclass(type({0}), _ods_ir.Attribute) or
530 not _ods_ir.AttrBuilder.contains('{2}')) else
531 _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
533 /// Template for attribute builder from raw input for optional attribute in the
534 /// operation builder.
535 /// {0} is the builder argument name;
536 /// {1} is the attribute name;
537 /// {2} is the attribute builder from raw.
538 /// Use the value the user passed in if either it is already an Attribute or
539 /// there is no method registered to make it an Attribute.
540 constexpr const char *initOptionalAttributeWithBuilderTemplate
=
541 R
"Py(if {0} is not None: attributes["{1}"] = ({0} if (
542 issubclass(type({0}), _ods_ir.Attribute) or
543 not _ods_ir.AttrBuilder.contains('{2}')) else
544 _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
546 constexpr const char *initUnitAttributeTemplate
=
547 R
"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
548 _ods_get_default_loc_context(loc)))Py";
/// Python statement template assigning the initial value of the
/// `_ods_successors` variable in the builder.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = "_ods_successors = {0}";
/// Python statement template adding successors to `_ods_successors` in the
/// builder.
///   {0} is the list method to call ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = "_ods_successors.{0}({1})";
560 /// Returns true if the SameOperandsAndResultType trait can be used to infer
561 /// result types of the given operation.
562 static bool hasSameArgumentAndResultTypes(const Operator
&op
) {
563 return op
.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
564 op
.getNumVariableLengthResults() == 0;
567 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
568 /// result types of the given operation.
569 static bool hasFirstAttrDerivedResultTypes(const Operator
&op
) {
570 return op
.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
571 op
.getNumVariableLengthResults() == 0;
574 /// Returns true if the InferTypeOpInterface can be used to infer result types
575 /// of the given operation.
576 static bool hasInferTypeInterface(const Operator
&op
) {
577 return op
.getTrait("::mlir::InferTypeOpInterface::Trait") &&
578 op
.getNumRegions() == 0;
581 /// Returns true if there is a trait or interface that can be used to infer
582 /// result types of the given operation.
583 static bool canInferType(const Operator
&op
) {
584 return hasSameArgumentAndResultTypes(op
) ||
585 hasFirstAttrDerivedResultTypes(op
) || hasInferTypeInterface(op
);
588 /// Populates `builderArgs` with result names if the builder is expected to
589 /// accept them as arguments.
591 populateBuilderArgsResults(const Operator
&op
,
592 llvm::SmallVectorImpl
<std::string
> &builderArgs
) {
593 if (canInferType(op
))
596 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
597 std::string name
= op
.getResultName(i
).str();
599 if (op
.getNumResults() == 1) {
600 // Special case for one result, make the default name be 'result'
601 // to properly match the built-in result accessor.
604 name
= llvm::formatv("_gen_res_{0}", i
);
607 name
= sanitizeName(name
);
608 builderArgs
.push_back(name
);
612 /// Populates `builderArgs` with the Python-compatible names of builder function
613 /// arguments using intermixed attributes and operands in the same order as they
614 /// appear in the `arguments` field of the op definition. Additionally,
615 /// `operandNames` is populated with names of operands in their order of
618 populateBuilderArgs(const Operator
&op
,
619 llvm::SmallVectorImpl
<std::string
> &builderArgs
,
620 llvm::SmallVectorImpl
<std::string
> &operandNames
) {
621 for (int i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
622 std::string name
= op
.getArgName(i
).str();
624 name
= llvm::formatv("_gen_arg_{0}", i
);
625 name
= sanitizeName(name
);
626 builderArgs
.push_back(name
);
627 if (!op
.getArg(i
).is
<NamedAttribute
*>())
628 operandNames
.push_back(name
);
632 /// Populates `builderArgs` with the Python-compatible names of builder function
633 /// successor arguments. Additionally, `successorArgNames` is also populated.
634 static void populateBuilderArgsSuccessors(
635 const Operator
&op
, llvm::SmallVectorImpl
<std::string
> &builderArgs
,
636 llvm::SmallVectorImpl
<std::string
> &successorArgNames
) {
638 for (int i
= 0, e
= op
.getNumSuccessors(); i
< e
; ++i
) {
639 NamedSuccessor successor
= op
.getSuccessor(i
);
640 std::string name
= std::string(successor
.name
);
642 name
= llvm::formatv("_gen_successor_{0}", i
);
643 name
= sanitizeName(name
);
644 builderArgs
.push_back(name
);
645 successorArgNames
.push_back(name
);
649 /// Populates `builderLines` with additional lines that are required in the
650 /// builder to set up operation attributes. `argNames` is expected to contain
651 /// the names of builder arguments that correspond to op arguments, i.e. to the
652 /// operands and attributes in the same order as they appear in the `arguments`
655 populateBuilderLinesAttr(const Operator
&op
,
656 llvm::ArrayRef
<std::string
> argNames
,
657 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
658 builderLines
.push_back("_ods_context = _ods_get_default_loc_context(loc)");
659 for (int i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
660 Argument arg
= op
.getArg(i
);
661 auto *attribute
= llvm::dyn_cast_if_present
<NamedAttribute
*>(arg
);
665 // Unit attributes are handled specially.
666 if (attribute
->attr
.getStorageType().trim().equals("::mlir::UnitAttr")) {
667 builderLines
.push_back(llvm::formatv(initUnitAttributeTemplate
,
668 attribute
->name
, argNames
[i
]));
672 builderLines
.push_back(llvm::formatv(
673 attribute
->attr
.isOptional() || attribute
->attr
.hasDefaultValue()
674 ? initOptionalAttributeWithBuilderTemplate
675 : initAttributeWithBuilderTemplate
,
676 argNames
[i
], attribute
->name
, attribute
->attr
.getAttrDefName()));
680 /// Populates `builderLines` with additional lines that are required in the
681 /// builder to set up successors. successorArgNames is expected to correspond
682 /// to the Python argument name for each successor on the op.
683 static void populateBuilderLinesSuccessors(
684 const Operator
&op
, llvm::ArrayRef
<std::string
> successorArgNames
,
685 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
686 if (successorArgNames
.empty()) {
687 builderLines
.push_back(llvm::formatv(initSuccessorsTemplate
, "None"));
691 builderLines
.push_back(llvm::formatv(initSuccessorsTemplate
, "[]"));
692 for (int i
= 0, e
= successorArgNames
.size(); i
< e
; ++i
) {
693 auto &argName
= successorArgNames
[i
];
694 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
695 builderLines
.push_back(
696 llvm::formatv(addSuccessorTemplate
,
697 successor
.isVariadic() ? "extend" : "append", argName
));
701 /// Populates `builderLines` with additional lines that are required in the
702 /// builder to set up op operands.
704 populateBuilderLinesOperand(const Operator
&op
,
705 llvm::ArrayRef
<std::string
> names
,
706 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
707 bool sizedSegments
= op
.getTrait(attrSizedTraitForKind("operand")) != nullptr;
709 // For each element, find or generate a name.
710 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
) {
711 const NamedTypeConstraint
&element
= op
.getOperand(i
);
712 std::string name
= names
[i
];
714 // Choose the formatting string based on the element kind.
715 llvm::StringRef formatString
;
716 if (!element
.isVariableLength()) {
717 formatString
= singleOperandAppendTemplate
;
718 } else if (element
.isOptional()) {
720 formatString
= optionalAppendAttrSizedOperandsTemplate
;
722 formatString
= optionalAppendOperandTemplate
;
725 assert(element
.isVariadic() && "unhandled element group type");
726 // If emitting with sizedSegments, then we add the actual list-typed
727 // element. Otherwise, we extend the actual operands.
729 formatString
= multiOperandAppendPackTemplate
;
731 formatString
= multiOperandAppendTemplate
;
735 builderLines
.push_back(llvm::formatv(formatString
.data(), name
));
739 /// Python code template for deriving the operation result types from its
741 /// - {0} is the name of the attribute from which to derive the types.
742 constexpr const char *deriveTypeFromAttrTemplate
=
743 R
"Py(_ods_result_type_source_attr = attributes["{0}"]
744 _ods_derived_result_type = (
745 _ods_ir.TypeAttr(_ods_result_type_source_attr).value
746 if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
747 _ods_result_type_source_attr.type))Py";
/// Python statement template extending the `results` list with {1} copies of
/// the type expression {0}.
constexpr const char *appendSameResultsTemplate =
    R"Py(results.extend([{0}] * {1}))Py";
752 /// Appends the given multiline string as individual strings into
754 static void appendLineByLine(StringRef string
,
755 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
757 std::pair
<StringRef
, StringRef
> split
= std::make_pair(string
, string
);
759 split
= split
.second
.split('\n');
760 builderLines
.push_back(split
.first
.str());
761 } while (!split
.second
.empty());
764 /// Populates `builderLines` with additional lines that are required in the
765 /// builder to set up op results.
767 populateBuilderLinesResult(const Operator
&op
,
768 llvm::ArrayRef
<std::string
> names
,
769 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
770 bool sizedSegments
= op
.getTrait(attrSizedTraitForKind("result")) != nullptr;
772 if (hasSameArgumentAndResultTypes(op
)) {
773 builderLines
.push_back(llvm::formatv(
774 appendSameResultsTemplate
, "operands[0].type", op
.getNumResults()));
778 if (hasFirstAttrDerivedResultTypes(op
)) {
779 const NamedAttribute
&firstAttr
= op
.getAttribute(0);
780 assert(!firstAttr
.name
.empty() && "unexpected empty name for the attribute "
781 "from which the type is derived");
783 llvm::formatv(deriveTypeFromAttrTemplate
, firstAttr
.name
).str(),
785 builderLines
.push_back(llvm::formatv(appendSameResultsTemplate
,
786 "_ods_derived_result_type",
787 op
.getNumResults()));
791 if (hasInferTypeInterface(op
))
794 // For each element, find or generate a name.
795 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
796 const NamedTypeConstraint
&element
= op
.getResult(i
);
797 std::string name
= names
[i
];
799 // Choose the formatting string based on the element kind.
800 llvm::StringRef formatString
;
801 if (!element
.isVariableLength()) {
802 formatString
= singleResultAppendTemplate
;
803 } else if (element
.isOptional()) {
804 formatString
= optionalAppendResultTemplate
;
806 assert(element
.isVariadic() && "unhandled element group type");
807 // If emitting with sizedSegments, then we add the actual list-typed
808 // element. Otherwise, we extend the actual results.
810 formatString
= singleResultAppendTemplate
;
812 formatString
= multiResultAppendTemplate
;
816 builderLines
.push_back(llvm::formatv(formatString
.data(), name
));
820 /// If the operation has variadic regions, adds a builder argument to specify
821 /// the number of those regions and builder lines to forward it to the generic
824 populateBuilderRegions(const Operator
&op
,
825 llvm::SmallVectorImpl
<std::string
> &builderArgs
,
826 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
827 if (op
.hasNoVariadicRegions())
830 // This is currently enforced when Operator is constructed.
831 assert(op
.getNumVariadicRegions() == 1 &&
832 op
.getRegion(op
.getNumRegions() - 1).isVariadic() &&
833 "expected the last region to be varidic");
835 const NamedRegion
®ion
= op
.getRegion(op
.getNumRegions() - 1);
837 ("num_" + region
.name
.take_front().lower() + region
.name
.drop_front())
839 builderArgs
.push_back(name
);
840 builderLines
.push_back(
841 llvm::formatv("regions = {0} + {1}", op
.getNumRegions() - 1, name
));
844 /// Emits a default builder constructing an operation from the list of its
845 /// result types, followed by a list of its operands. Returns vector
846 /// of fully built functionArgs for downstream users (to save having to
848 static llvm::SmallVector
<std::string
> emitDefaultOpBuilder(const Operator
&op
,
850 llvm::SmallVector
<std::string
> builderArgs
;
851 llvm::SmallVector
<std::string
> builderLines
;
852 llvm::SmallVector
<std::string
> operandArgNames
;
853 llvm::SmallVector
<std::string
> successorArgNames
;
854 builderArgs
.reserve(op
.getNumOperands() + op
.getNumResults() +
855 op
.getNumNativeAttributes() + op
.getNumSuccessors());
856 populateBuilderArgsResults(op
, builderArgs
);
857 size_t numResultArgs
= builderArgs
.size();
858 populateBuilderArgs(op
, builderArgs
, operandArgNames
);
859 size_t numOperandAttrArgs
= builderArgs
.size() - numResultArgs
;
860 populateBuilderArgsSuccessors(op
, builderArgs
, successorArgNames
);
862 populateBuilderLinesOperand(op
, operandArgNames
, builderLines
);
863 populateBuilderLinesAttr(
864 op
, llvm::ArrayRef(builderArgs
).drop_front(numResultArgs
), builderLines
);
865 populateBuilderLinesResult(
866 op
, llvm::ArrayRef(builderArgs
).take_front(numResultArgs
), builderLines
);
867 populateBuilderLinesSuccessors(op
, successorArgNames
, builderLines
);
868 populateBuilderRegions(op
, builderArgs
, builderLines
);
870 // Layout of builderArgs vector elements:
871 // [ result_args operand_attr_args successor_args regions ]
873 // Determine whether the argument corresponding to a given index into the
874 // builderArgs vector is a python keyword argument or not.
875 auto isKeywordArgFn
= [&](size_t builderArgIndex
) -> bool {
876 // All result, successor, and region arguments are positional arguments.
877 if ((builderArgIndex
< numResultArgs
) ||
878 (builderArgIndex
>= (numResultArgs
+ numOperandAttrArgs
)))
880 // Keyword arguments:
881 // - optional named attributes (including unit attributes)
882 // - default-valued named attributes
883 // - optional operands
884 Argument a
= op
.getArg(builderArgIndex
- numResultArgs
);
885 if (auto *nattr
= llvm::dyn_cast_if_present
<NamedAttribute
*>(a
))
886 return (nattr
->attr
.isOptional() || nattr
->attr
.hasDefaultValue());
887 if (auto *ntype
= llvm::dyn_cast_if_present
<NamedTypeConstraint
*>(a
))
888 return ntype
->isOptional();
892 // StringRefs in functionArgs refer to strings allocated by builderArgs.
893 llvm::SmallVector
<llvm::StringRef
> functionArgs
;
895 // Add positional arguments.
896 for (size_t i
= 0, cnt
= builderArgs
.size(); i
< cnt
; ++i
) {
897 if (!isKeywordArgFn(i
))
898 functionArgs
.push_back(builderArgs
[i
]);
901 // Add a bare '*' to indicate that all following arguments must be keyword
903 functionArgs
.push_back("*");
905 // Add a default 'None' value to each keyword arg string, and then add to the
906 // function args list.
907 for (size_t i
= 0, cnt
= builderArgs
.size(); i
< cnt
; ++i
) {
908 if (isKeywordArgFn(i
)) {
909 builderArgs
[i
].append("=None");
910 functionArgs
.push_back(builderArgs
[i
]);
913 functionArgs
.push_back("loc=None");
914 functionArgs
.push_back("ip=None");
916 SmallVector
<std::string
> initArgs
;
917 initArgs
.push_back("attributes=attributes");
918 if (!hasInferTypeInterface(op
))
919 initArgs
.push_back("results=results");
920 initArgs
.push_back("operands=operands");
921 initArgs
.push_back("successors=_ods_successors");
922 initArgs
.push_back("regions=regions");
923 initArgs
.push_back("loc=loc");
924 initArgs
.push_back("ip=ip");
926 os
<< llvm::formatv(initTemplate
, llvm::join(functionArgs
, ", "),
927 llvm::join(builderLines
, "\n "),
928 llvm::join(initArgs
, ", "));
929 return llvm::to_vector
<8>(
930 llvm::map_range(functionArgs
, [](llvm::StringRef s
) { return s
.str(); }));
933 static void emitSegmentSpec(
934 const Operator
&op
, const char *kind
,
935 llvm::function_ref
<int(const Operator
&)> getNumElements
,
936 llvm::function_ref
<const NamedTypeConstraint
&(const Operator
&, int)>
939 std::string
segmentSpec("[");
940 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
941 const NamedTypeConstraint
&element
= getElement(op
, i
);
942 if (element
.isOptional()) {
943 segmentSpec
.append("0,");
944 } else if (element
.isVariadic()) {
945 segmentSpec
.append("-1,");
947 segmentSpec
.append("1,");
950 segmentSpec
.append("]");
952 os
<< llvm::formatv(opClassSizedSegmentsTemplate
, kind
, segmentSpec
);
955 static void emitRegionAttributes(const Operator
&op
, raw_ostream
&os
) {
956 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
957 // Note that the base OpView class defines this as (0, True).
958 unsigned minRegionCount
= op
.getNumRegions() - op
.getNumVariadicRegions();
959 os
<< llvm::formatv(opClassRegionSpecTemplate
, minRegionCount
,
960 op
.hasNoVariadicRegions() ? "True" : "False");
963 /// Emits named accessors to regions.
964 static void emitRegionAccessors(const Operator
&op
, raw_ostream
&os
) {
965 for (const auto &en
: llvm::enumerate(op
.getRegions())) {
966 const NamedRegion
®ion
= en
.value();
967 if (region
.name
.empty())
970 assert((!region
.isVariadic() || en
.index() == op
.getNumRegions() - 1) &&
971 "expected only the last region to be variadic");
972 os
<< llvm::formatv(regionAccessorTemplate
, sanitizeName(region
.name
),
973 std::to_string(en
.index()) +
974 (region
.isVariadic() ? ":" : ""));
978 /// Emits builder that extracts results from op
979 static void emitValueBuilder(const Operator
&op
,
980 llvm::SmallVector
<std::string
> functionArgs
,
982 // Params with (possibly) default args.
983 auto valueBuilderParams
=
984 llvm::map_range(functionArgs
, [](const std::string
&argAndMaybeDefault
) {
985 llvm::SmallVector
<llvm::StringRef
> argMaybeDefault
=
986 llvm::to_vector
<2>(llvm::split(argAndMaybeDefault
, "="));
987 auto arg
= llvm::convertToSnakeFromCamelCase(argMaybeDefault
[0]);
988 if (argMaybeDefault
.size() == 2)
989 return arg
+ "=" + argMaybeDefault
[1].str();
992 // Actual args passed to op builder (e.g., opParam=op_param).
993 auto opBuilderArgs
= llvm::map_range(
994 llvm::make_filter_range(functionArgs
,
995 [](const std::string
&s
) { return s
!= "*"; }),
996 [](const std::string
&arg
) {
997 auto lhs
= *llvm::split(arg
, "=").begin();
998 return (lhs
+ "=" + llvm::convertToSnakeFromCamelCase(lhs
)).str();
1000 std::string nameWithoutDialect
=
1001 op
.getOperationName().substr(op
.getOperationName().find('.') + 1);
1002 os
<< llvm::formatv(valueBuilderTemplate
, sanitizeName(nameWithoutDialect
),
1003 op
.getCppClassName(),
1004 llvm::join(valueBuilderParams
, ", "),
1005 llvm::join(opBuilderArgs
, ", "),
1006 (op
.getNumResults() > 1
1007 ? "_Sequence[_ods_ir.OpResult]"
1008 : (op
.getNumResults() > 0 ? "_ods_ir.OpResult"
1009 : "_ods_ir.Operation")));
1012 /// Emits bindings for a specific Op to the given output stream.
1013 static void emitOpBindings(const Operator
&op
, raw_ostream
&os
) {
1014 os
<< llvm::formatv(opClassTemplate
, op
.getCppClassName(),
1015 op
.getOperationName());
1018 if (op
.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1019 emitSegmentSpec(op
, "OPERAND", getNumOperands
, getOperand
, os
);
1021 if (op
.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1022 emitSegmentSpec(op
, "RESULT", getNumResults
, getResult
, os
);
1025 emitRegionAttributes(op
, os
);
1026 llvm::SmallVector
<std::string
> functionArgs
= emitDefaultOpBuilder(op
, os
);
1027 emitOperandAccessors(op
, os
);
1028 emitAttributeAccessors(op
, os
);
1029 emitResultAccessors(op
, os
);
1030 emitRegionAccessors(op
, os
);
1031 emitValueBuilder(op
, functionArgs
, os
);
1034 /// Emits bindings for the dialect specified in the command line, including file
1035 /// headers and utilities. Returns `false` on success to comply with Tablegen
1036 /// registration requirements.
1037 static bool emitAllOps(const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
1038 if (clDialectName
.empty())
1039 llvm::PrintFatalError("dialect name not provided");
1042 if (!clDialectExtensionName
.empty())
1043 os
<< llvm::formatv(dialectExtensionTemplate
, clDialectName
.getValue());
1045 os
<< llvm::formatv(dialectClassTemplate
, clDialectName
.getValue());
1047 for (const llvm::Record
*rec
: records
.getAllDerivedDefinitions("Op")) {
1049 if (op
.getDialectName() == clDialectName
.getValue())
1050 emitOpBindings(op
, os
);
1055 static GenRegistration
1056 genPythonBindings("gen-python-op-bindings",
1057 "Generate Python bindings for MLIR Ops", &emitAllOps
);