//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//
#include "OpGenHelpers.h"

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

#include <algorithm>

using namespace mlir;
using namespace mlir::tblgen;
/// File header and imports placed at the top of each generated Python module.
/// The text contains no `{N}` placeholders. `builtins` is imported because
/// the accessor templates decorate with `@builtins.property`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_result_or_op_results as _get_op_result_or_op_results,
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir

import builtins
from typing import Sequence as _Sequence, Union as _Union

)Py";
/// Template for dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template for a dialect-extension module, which imports the `_Dialect`
/// class from the module generated for the extended dialect:
///   {0} is the dialect namespace (it forms the `_{0}_ops_gen` module name).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it is the
/// optional operand/result): we can deduce it is absent if
/// `len(operation.{1}s)` is smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The generated method binds `start` and `pg`; it is completed by one of the
/// two templates below. The element list must be read through `self`: a bare
/// `operation` is an undefined name inside the generated method.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";
/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// Group sizes are read from the `{1}SegmentSizes` attribute of the operation.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter, returning None when
/// the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; assigning None raises:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index of the region.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)Py";
/// Template for a free-function value builder wrapping an op constructor:
///   {0} is the name of the generated Python function;
///   {1} is the callable invoked to build the op (presumably the op class);
///   {2} is the parameter list of the generated function;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type annotation.
/// The created op is passed through `_get_op_result_or_op_results`.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return _get_op_result_or_op_results({1}({3}))
)Py";
274 static llvm::cl::OptionCategory
275 clOpPythonBindingCat("Options for -gen-python-op-bindings");
277 static llvm::cl::opt
<std::string
>
278 clDialectName("bind-dialect",
279 llvm::cl::desc("The dialect to run the generator for"),
280 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat
));
282 static llvm::cl::opt
<std::string
> clDialectExtensionName(
283 "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
284 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat
));
286 using AttributeClasses
= DenseMap
<StringRef
, StringRef
>;
288 /// Checks whether `str` would shadow a generated variable or attribute
289 /// part of the OpView API.
290 static bool isODSReserved(StringRef str
) {
291 static llvm::StringSet
<> reserved(
292 {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
293 "loc", "verify", "regions", "results", "self", "operation",
294 "DIALECT_NAMESPACE", "OPERATION_NAME"});
295 return str
.starts_with("_ods_") || str
.ends_with("_ods") ||
296 reserved
.contains(str
);
299 /// Modifies the `name` in a way that it becomes suitable for Python bindings
300 /// (does not change the `name` if it already is suitable) and returns the
301 /// modified version.
302 static std::string
sanitizeName(StringRef name
) {
303 std::string processedStr
= name
.str();
305 processedStr
.begin(), processedStr
.end(),
306 [](char c
) { return !llvm::isAlnum(c
); }, '_');
308 if (llvm::isDigit(*processedStr
.begin()))
309 return "_" + processedStr
;
311 if (isPythonReserved(processedStr
) || isODSReserved(processedStr
))
312 return processedStr
+ "_";
316 static std::string
attrSizedTraitForKind(const char *kind
) {
317 return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
318 llvm::StringRef(kind
).take_front().upper(),
319 llvm::StringRef(kind
).drop_front());
322 /// Emits accessors to "elements" of an Op definition. Currently, the supported
323 /// elements are operands and results, indicated by `kind`, which must be either
324 /// `operand` or `result` and is used verbatim in the emitted code.
325 static void emitElementAccessors(
326 const Operator
&op
, raw_ostream
&os
, const char *kind
,
327 llvm::function_ref
<unsigned(const Operator
&)> getNumVariableLength
,
328 llvm::function_ref
<int(const Operator
&)> getNumElements
,
329 llvm::function_ref
<const NamedTypeConstraint
&(const Operator
&, int)>
331 assert(llvm::is_contained(
332 llvm::SmallVector
<StringRef
, 2>{"operand", "result"}, kind
) &&
335 // Traits indicating how to process variadic elements.
336 std::string sameSizeTrait
=
337 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
338 llvm::StringRef(kind
).take_front().upper(),
339 llvm::StringRef(kind
).drop_front());
340 std::string attrSizedTrait
= attrSizedTraitForKind(kind
);
342 unsigned numVariableLength
= getNumVariableLength(op
);
344 // If there is only one variable-length element group, its size can be
345 // inferred from the total number of elements. If there are none, the
346 // generation is straightforward.
347 if (numVariableLength
<= 1) {
348 bool seenVariableLength
= false;
349 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
350 const NamedTypeConstraint
&element
= getElement(op
, i
);
351 if (element
.isVariableLength())
352 seenVariableLength
= true;
353 if (element
.name
.empty())
355 if (element
.isVariableLength()) {
356 os
<< llvm::formatv(element
.isOptional() ? opOneOptionalTemplate
357 : opOneVariadicTemplate
,
358 sanitizeName(element
.name
), kind
,
359 getNumElements(op
), i
);
360 } else if (seenVariableLength
) {
361 os
<< llvm::formatv(opSingleAfterVariableTemplate
,
362 sanitizeName(element
.name
), kind
,
363 getNumElements(op
), i
);
365 os
<< llvm::formatv(opSingleTemplate
, sanitizeName(element
.name
), kind
,
372 // Handle the operations where variadic groups have the same size.
373 if (op
.getTrait(sameSizeTrait
)) {
374 int numPrecedingSimple
= 0;
375 int numPrecedingVariadic
= 0;
376 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
377 const NamedTypeConstraint
&element
= getElement(op
, i
);
378 if (!element
.name
.empty()) {
379 os
<< llvm::formatv(opVariadicEqualPrefixTemplate
,
380 sanitizeName(element
.name
), kind
, numVariableLength
,
381 numPrecedingSimple
, numPrecedingVariadic
);
382 os
<< llvm::formatv(element
.isVariableLength()
383 ? opVariadicEqualVariadicTemplate
384 : opVariadicEqualSimpleTemplate
,
387 if (element
.isVariableLength())
388 ++numPrecedingVariadic
;
390 ++numPrecedingSimple
;
395 // Handle the operations where the size of groups (variadic or not) is
396 // provided as an attribute. For non-variadic elements, make sure to return
397 // an element rather than a singleton container.
398 if (op
.getTrait(attrSizedTrait
)) {
399 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
400 const NamedTypeConstraint
&element
= getElement(op
, i
);
401 if (element
.name
.empty())
403 std::string trailing
;
404 if (!element
.isVariableLength())
406 else if (element
.isOptional())
407 trailing
= std::string(
408 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate
, kind
));
409 os
<< llvm::formatv(opVariadicSegmentTemplate
, sanitizeName(element
.name
),
415 llvm::PrintFatalError("unsupported " + llvm::Twine(kind
) + " structure");
418 /// Free function helpers accessing Operator components.
419 static int getNumOperands(const Operator
&op
) { return op
.getNumOperands(); }
420 static const NamedTypeConstraint
&getOperand(const Operator
&op
, int i
) {
421 return op
.getOperand(i
);
423 static int getNumResults(const Operator
&op
) { return op
.getNumResults(); }
424 static const NamedTypeConstraint
&getResult(const Operator
&op
, int i
) {
425 return op
.getResult(i
);
428 /// Emits accessors to Op operands.
429 static void emitOperandAccessors(const Operator
&op
, raw_ostream
&os
) {
430 auto getNumVariableLengthOperands
= [](const Operator
&oper
) {
431 return oper
.getNumVariableLengthOperands();
433 emitElementAccessors(op
, os
, "operand", getNumVariableLengthOperands
,
434 getNumOperands
, getOperand
);
437 /// Emits accessors Op results.
438 static void emitResultAccessors(const Operator
&op
, raw_ostream
&os
) {
439 auto getNumVariableLengthResults
= [](const Operator
&oper
) {
440 return oper
.getNumVariableLengthResults();
442 emitElementAccessors(op
, os
, "result", getNumVariableLengthResults
,
443 getNumResults
, getResult
);
446 /// Emits accessors to Op attributes.
447 static void emitAttributeAccessors(const Operator
&op
, raw_ostream
&os
) {
448 for (const auto &namedAttr
: op
.getAttributes()) {
449 // Skip "derived" attributes because they are just C++ functions that we
450 // don't currently expose.
451 if (namedAttr
.attr
.isDerivedAttr())
454 if (namedAttr
.name
.empty())
457 std::string sanitizedName
= sanitizeName(namedAttr
.name
);
459 // Unit attributes are handled specially.
460 if (namedAttr
.attr
.getStorageType().trim() == "::mlir::UnitAttr") {
461 os
<< llvm::formatv(unitAttributeGetterTemplate
, sanitizedName
,
463 os
<< llvm::formatv(unitAttributeSetterTemplate
, sanitizedName
,
465 os
<< llvm::formatv(attributeDeleterTemplate
, sanitizedName
,
470 if (namedAttr
.attr
.isOptional()) {
471 os
<< llvm::formatv(optionalAttributeGetterTemplate
, sanitizedName
,
473 os
<< llvm::formatv(optionalAttributeSetterTemplate
, sanitizedName
,
475 os
<< llvm::formatv(attributeDeleterTemplate
, sanitizedName
,
478 os
<< llvm::formatv(attributeGetterTemplate
, sanitizedName
,
480 os
<< llvm::formatv(attributeSetterTemplate
, sanitizedName
,
482 // Non-optional attributes cannot be deleted.
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields;
///   {2} is the argument list passed on to `build_generic`.
/// `{{}` is formatv's escape for a literal brace, so the emitted line reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic({2}))
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The attr-sized-segments operand variant always appends (None when the
/// operand is absent) instead of skipping the append.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The `Pack` variant appends the whole list as a single entry instead of
/// extending the operand list with its elements.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
526 /// Template for attribute builder from raw input in the operation builder.
527 /// {0} is the builder argument name;
528 /// {1} is the attribute builder from raw;
529 /// {2} is the attribute builder from raw.
530 /// Use the value the user passed in if either it is already an Attribute or
531 /// there is no method registered to make it an Attribute.
532 constexpr const char *initAttributeWithBuilderTemplate
=
533 R
"Py(attributes["{1}"] = ({0} if (
534 isinstance({0}, _ods_ir.Attribute) or
535 not _ods_ir.AttrBuilder.contains('{2}')) else
536 _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
538 /// Template for attribute builder from raw input for optional attribute in the
539 /// operation builder.
540 /// {0} is the builder argument name;
541 /// {1} is the attribute builder from raw;
542 /// {2} is the attribute builder from raw.
543 /// Use the value the user passed in if either it is already an Attribute or
544 /// there is no method registered to make it an Attribute.
545 constexpr const char *initOptionalAttributeWithBuilderTemplate
=
546 R
"Py(if {0} is not None: attributes["{1}"] = ({0} if (
547 isinstance({0}, _ods_ir.Attribute) or
548 not _ods_ir.AttrBuilder.contains('{2}')) else
549 _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
551 constexpr const char *initUnitAttributeTemplate
=
552 R
"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
553 _ods_get_default_loc_context(loc)))Py";
555 /// Template to initialize the successors list in the builder if there are any
557 /// {0} is the value to initialize the successors list to.
558 constexpr const char *initSuccessorsTemplate
= R
"Py(_ods_successors = {0})Py";
560 /// Template to append or extend the list of successors in the builder.
561 /// {0} is the list method ('append' or 'extend');
562 /// {1} is the value to add.
563 constexpr const char *addSuccessorTemplate
= R
"Py(_ods_successors.{0}({1}))Py";
565 /// Returns true if the SameArgumentAndResultTypes trait can be used to infer
566 /// result types of the given operation.
567 static bool hasSameArgumentAndResultTypes(const Operator
&op
) {
568 return op
.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
569 op
.getNumVariableLengthResults() == 0;
572 /// Returns true if the FirstAttrDerivedResultType trait can be used to infer
573 /// result types of the given operation.
574 static bool hasFirstAttrDerivedResultTypes(const Operator
&op
) {
575 return op
.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
576 op
.getNumVariableLengthResults() == 0;
579 /// Returns true if the InferTypeOpInterface can be used to infer result types
580 /// of the given operation.
581 static bool hasInferTypeInterface(const Operator
&op
) {
582 return op
.getTrait("::mlir::InferTypeOpInterface::Trait") &&
583 op
.getNumRegions() == 0;
586 /// Returns true if there is a trait or interface that can be used to infer
587 /// result types of the given operation.
588 static bool canInferType(const Operator
&op
) {
589 return hasSameArgumentAndResultTypes(op
) ||
590 hasFirstAttrDerivedResultTypes(op
) || hasInferTypeInterface(op
);
593 /// Populates `builderArgs` with result names if the builder is expected to
594 /// accept them as arguments.
596 populateBuilderArgsResults(const Operator
&op
,
597 llvm::SmallVectorImpl
<std::string
> &builderArgs
) {
598 if (canInferType(op
))
601 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
602 std::string name
= op
.getResultName(i
).str();
604 if (op
.getNumResults() == 1) {
605 // Special case for one result, make the default name be 'result'
606 // to properly match the built-in result accessor.
609 name
= llvm::formatv("_gen_res_{0}", i
);
612 name
= sanitizeName(name
);
613 builderArgs
.push_back(name
);
617 /// Populates `builderArgs` with the Python-compatible names of builder function
618 /// arguments using intermixed attributes and operands in the same order as they
619 /// appear in the `arguments` field of the op definition. Additionally,
620 /// `operandNames` is populated with names of operands in their order of
623 populateBuilderArgs(const Operator
&op
,
624 llvm::SmallVectorImpl
<std::string
> &builderArgs
,
625 llvm::SmallVectorImpl
<std::string
> &operandNames
) {
626 for (int i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
627 std::string name
= op
.getArgName(i
).str();
629 name
= llvm::formatv("_gen_arg_{0}", i
);
630 name
= sanitizeName(name
);
631 builderArgs
.push_back(name
);
632 if (!op
.getArg(i
).is
<NamedAttribute
*>())
633 operandNames
.push_back(name
);
637 /// Populates `builderArgs` with the Python-compatible names of builder function
638 /// successor arguments. Additionally, `successorArgNames` is also populated.
639 static void populateBuilderArgsSuccessors(
640 const Operator
&op
, llvm::SmallVectorImpl
<std::string
> &builderArgs
,
641 llvm::SmallVectorImpl
<std::string
> &successorArgNames
) {
643 for (int i
= 0, e
= op
.getNumSuccessors(); i
< e
; ++i
) {
644 NamedSuccessor successor
= op
.getSuccessor(i
);
645 std::string name
= std::string(successor
.name
);
647 name
= llvm::formatv("_gen_successor_{0}", i
);
648 name
= sanitizeName(name
);
649 builderArgs
.push_back(name
);
650 successorArgNames
.push_back(name
);
654 /// Populates `builderLines` with additional lines that are required in the
655 /// builder to set up operation attributes. `argNames` is expected to contain
656 /// the names of builder arguments that correspond to op arguments, i.e. to the
657 /// operands and attributes in the same order as they appear in the `arguments`
660 populateBuilderLinesAttr(const Operator
&op
,
661 llvm::ArrayRef
<std::string
> argNames
,
662 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
663 builderLines
.push_back("_ods_context = _ods_get_default_loc_context(loc)");
664 for (int i
= 0, e
= op
.getNumArgs(); i
< e
; ++i
) {
665 Argument arg
= op
.getArg(i
);
666 auto *attribute
= llvm::dyn_cast_if_present
<NamedAttribute
*>(arg
);
670 // Unit attributes are handled specially.
671 if (attribute
->attr
.getStorageType().trim() == "::mlir::UnitAttr") {
672 builderLines
.push_back(llvm::formatv(initUnitAttributeTemplate
,
673 attribute
->name
, argNames
[i
]));
677 builderLines
.push_back(llvm::formatv(
678 attribute
->attr
.isOptional() || attribute
->attr
.hasDefaultValue()
679 ? initOptionalAttributeWithBuilderTemplate
680 : initAttributeWithBuilderTemplate
,
681 argNames
[i
], attribute
->name
, attribute
->attr
.getAttrDefName()));
685 /// Populates `builderLines` with additional lines that are required in the
686 /// builder to set up successors. successorArgNames is expected to correspond
687 /// to the Python argument name for each successor on the op.
688 static void populateBuilderLinesSuccessors(
689 const Operator
&op
, llvm::ArrayRef
<std::string
> successorArgNames
,
690 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
691 if (successorArgNames
.empty()) {
692 builderLines
.push_back(llvm::formatv(initSuccessorsTemplate
, "None"));
696 builderLines
.push_back(llvm::formatv(initSuccessorsTemplate
, "[]"));
697 for (int i
= 0, e
= successorArgNames
.size(); i
< e
; ++i
) {
698 auto &argName
= successorArgNames
[i
];
699 const NamedSuccessor
&successor
= op
.getSuccessor(i
);
700 builderLines
.push_back(
701 llvm::formatv(addSuccessorTemplate
,
702 successor
.isVariadic() ? "extend" : "append", argName
));
706 /// Populates `builderLines` with additional lines that are required in the
707 /// builder to set up op operands.
709 populateBuilderLinesOperand(const Operator
&op
,
710 llvm::ArrayRef
<std::string
> names
,
711 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
712 bool sizedSegments
= op
.getTrait(attrSizedTraitForKind("operand")) != nullptr;
714 // For each element, find or generate a name.
715 for (int i
= 0, e
= op
.getNumOperands(); i
< e
; ++i
) {
716 const NamedTypeConstraint
&element
= op
.getOperand(i
);
717 std::string name
= names
[i
];
719 // Choose the formatting string based on the element kind.
720 llvm::StringRef formatString
;
721 if (!element
.isVariableLength()) {
722 formatString
= singleOperandAppendTemplate
;
723 } else if (element
.isOptional()) {
725 formatString
= optionalAppendAttrSizedOperandsTemplate
;
727 formatString
= optionalAppendOperandTemplate
;
730 assert(element
.isVariadic() && "unhandled element group type");
731 // If emitting with sizedSegments, then we add the actual list-typed
732 // element. Otherwise, we extend the actual operands.
734 formatString
= multiOperandAppendPackTemplate
;
736 formatString
= multiOperandAppendTemplate
;
740 builderLines
.push_back(llvm::formatv(formatString
.data(), name
));
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The value comes from a TypeAttr when the attribute is one, otherwise from
/// the attribute's `.type`. Continuation lines are parenthesized and start
/// at column 0 so that every line can be re-indented uniformly.
constexpr const char *deriveTypeFromAttrTemplate =
    R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))Py";

/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
757 /// Appends the given multiline string as individual strings into
759 static void appendLineByLine(StringRef string
,
760 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
762 std::pair
<StringRef
, StringRef
> split
= std::make_pair(string
, string
);
764 split
= split
.second
.split('\n');
765 builderLines
.push_back(split
.first
.str());
766 } while (!split
.second
.empty());
769 /// Populates `builderLines` with additional lines that are required in the
770 /// builder to set up op results.
772 populateBuilderLinesResult(const Operator
&op
,
773 llvm::ArrayRef
<std::string
> names
,
774 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
775 bool sizedSegments
= op
.getTrait(attrSizedTraitForKind("result")) != nullptr;
777 if (hasSameArgumentAndResultTypes(op
)) {
778 builderLines
.push_back(llvm::formatv(
779 appendSameResultsTemplate
, "operands[0].type", op
.getNumResults()));
783 if (hasFirstAttrDerivedResultTypes(op
)) {
784 const NamedAttribute
&firstAttr
= op
.getAttribute(0);
785 assert(!firstAttr
.name
.empty() && "unexpected empty name for the attribute "
786 "from which the type is derived");
788 llvm::formatv(deriveTypeFromAttrTemplate
, firstAttr
.name
).str(),
790 builderLines
.push_back(llvm::formatv(appendSameResultsTemplate
,
791 "_ods_derived_result_type",
792 op
.getNumResults()));
796 if (hasInferTypeInterface(op
))
799 // For each element, find or generate a name.
800 for (int i
= 0, e
= op
.getNumResults(); i
< e
; ++i
) {
801 const NamedTypeConstraint
&element
= op
.getResult(i
);
802 std::string name
= names
[i
];
804 // Choose the formatting string based on the element kind.
805 llvm::StringRef formatString
;
806 if (!element
.isVariableLength()) {
807 formatString
= singleResultAppendTemplate
;
808 } else if (element
.isOptional()) {
809 formatString
= optionalAppendResultTemplate
;
811 assert(element
.isVariadic() && "unhandled element group type");
812 // If emitting with sizedSegments, then we add the actual list-typed
813 // element. Otherwise, we extend the actual operands.
815 formatString
= singleResultAppendTemplate
;
817 formatString
= multiResultAppendTemplate
;
821 builderLines
.push_back(llvm::formatv(formatString
.data(), name
));
825 /// If the operation has variadic regions, adds a builder argument to specify
826 /// the number of those regions and builder lines to forward it to the generic
829 populateBuilderRegions(const Operator
&op
,
830 llvm::SmallVectorImpl
<std::string
> &builderArgs
,
831 llvm::SmallVectorImpl
<std::string
> &builderLines
) {
832 if (op
.hasNoVariadicRegions())
835 // This is currently enforced when Operator is constructed.
836 assert(op
.getNumVariadicRegions() == 1 &&
837 op
.getRegion(op
.getNumRegions() - 1).isVariadic() &&
838 "expected the last region to be varidic");
840 const NamedRegion
®ion
= op
.getRegion(op
.getNumRegions() - 1);
842 ("num_" + region
.name
.take_front().lower() + region
.name
.drop_front())
844 builderArgs
.push_back(name
);
845 builderLines
.push_back(
846 llvm::formatv("regions = {0} + {1}", op
.getNumRegions() - 1, name
));
849 /// Emits a default builder constructing an operation from the list of its
850 /// result types, followed by a list of its operands. Returns vector
851 /// of fully built functionArgs for downstream users (to save having to
853 static llvm::SmallVector
<std::string
> emitDefaultOpBuilder(const Operator
&op
,
855 llvm::SmallVector
<std::string
> builderArgs
;
856 llvm::SmallVector
<std::string
> builderLines
;
857 llvm::SmallVector
<std::string
> operandArgNames
;
858 llvm::SmallVector
<std::string
> successorArgNames
;
859 builderArgs
.reserve(op
.getNumOperands() + op
.getNumResults() +
860 op
.getNumNativeAttributes() + op
.getNumSuccessors());
861 populateBuilderArgsResults(op
, builderArgs
);
862 size_t numResultArgs
= builderArgs
.size();
863 populateBuilderArgs(op
, builderArgs
, operandArgNames
);
864 size_t numOperandAttrArgs
= builderArgs
.size() - numResultArgs
;
865 populateBuilderArgsSuccessors(op
, builderArgs
, successorArgNames
);
867 populateBuilderLinesOperand(op
, operandArgNames
, builderLines
);
868 populateBuilderLinesAttr(
869 op
, llvm::ArrayRef(builderArgs
).drop_front(numResultArgs
), builderLines
);
870 populateBuilderLinesResult(
871 op
, llvm::ArrayRef(builderArgs
).take_front(numResultArgs
), builderLines
);
872 populateBuilderLinesSuccessors(op
, successorArgNames
, builderLines
);
873 populateBuilderRegions(op
, builderArgs
, builderLines
);
875 // Layout of builderArgs vector elements:
876 // [ result_args operand_attr_args successor_args regions ]
878 // Determine whether the argument corresponding to a given index into the
879 // builderArgs vector is a python keyword argument or not.
880 auto isKeywordArgFn
= [&](size_t builderArgIndex
) -> bool {
881 // All result, successor, and region arguments are positional arguments.
882 if ((builderArgIndex
< numResultArgs
) ||
883 (builderArgIndex
>= (numResultArgs
+ numOperandAttrArgs
)))
885 // Keyword arguments:
886 // - optional named attributes (including unit attributes)
887 // - default-valued named attributes
888 // - optional operands
889 Argument a
= op
.getArg(builderArgIndex
- numResultArgs
);
890 if (auto *nattr
= llvm::dyn_cast_if_present
<NamedAttribute
*>(a
))
891 return (nattr
->attr
.isOptional() || nattr
->attr
.hasDefaultValue());
892 if (auto *ntype
= llvm::dyn_cast_if_present
<NamedTypeConstraint
*>(a
))
893 return ntype
->isOptional();
897 // StringRefs in functionArgs refer to strings allocated by builderArgs.
898 llvm::SmallVector
<llvm::StringRef
> functionArgs
;
900 // Add positional arguments.
901 for (size_t i
= 0, cnt
= builderArgs
.size(); i
< cnt
; ++i
) {
902 if (!isKeywordArgFn(i
))
903 functionArgs
.push_back(builderArgs
[i
]);
906 // Add a bare '*' to indicate that all following arguments must be keyword
908 functionArgs
.push_back("*");
910 // Add a default 'None' value to each keyword arg string, and then add to the
911 // function args list.
912 for (size_t i
= 0, cnt
= builderArgs
.size(); i
< cnt
; ++i
) {
913 if (isKeywordArgFn(i
)) {
914 builderArgs
[i
].append("=None");
915 functionArgs
.push_back(builderArgs
[i
]);
918 functionArgs
.push_back("loc=None");
919 functionArgs
.push_back("ip=None");
921 SmallVector
<std::string
> initArgs
;
922 initArgs
.push_back("attributes=attributes");
923 if (!hasInferTypeInterface(op
))
924 initArgs
.push_back("results=results");
925 initArgs
.push_back("operands=operands");
926 initArgs
.push_back("successors=_ods_successors");
927 initArgs
.push_back("regions=regions");
928 initArgs
.push_back("loc=loc");
929 initArgs
.push_back("ip=ip");
931 os
<< llvm::formatv(initTemplate
, llvm::join(functionArgs
, ", "),
932 llvm::join(builderLines
, "\n "),
933 llvm::join(initArgs
, ", "));
934 return llvm::to_vector
<8>(
935 llvm::map_range(functionArgs
, [](llvm::StringRef s
) { return s
.str(); }));
938 static void emitSegmentSpec(
939 const Operator
&op
, const char *kind
,
940 llvm::function_ref
<int(const Operator
&)> getNumElements
,
941 llvm::function_ref
<const NamedTypeConstraint
&(const Operator
&, int)>
944 std::string
segmentSpec("[");
945 for (int i
= 0, e
= getNumElements(op
); i
< e
; ++i
) {
946 const NamedTypeConstraint
&element
= getElement(op
, i
);
947 if (element
.isOptional()) {
948 segmentSpec
.append("0,");
949 } else if (element
.isVariadic()) {
950 segmentSpec
.append("-1,");
952 segmentSpec
.append("1,");
955 segmentSpec
.append("]");
957 os
<< llvm::formatv(opClassSizedSegmentsTemplate
, kind
, segmentSpec
);
960 static void emitRegionAttributes(const Operator
&op
, raw_ostream
&os
) {
961 // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
962 // Note that the base OpView class defines this as (0, True).
963 unsigned minRegionCount
= op
.getNumRegions() - op
.getNumVariadicRegions();
964 os
<< llvm::formatv(opClassRegionSpecTemplate
, minRegionCount
,
965 op
.hasNoVariadicRegions() ? "True" : "False");
968 /// Emits named accessors to regions.
969 static void emitRegionAccessors(const Operator
&op
, raw_ostream
&os
) {
970 for (const auto &en
: llvm::enumerate(op
.getRegions())) {
971 const NamedRegion
®ion
= en
.value();
972 if (region
.name
.empty())
975 assert((!region
.isVariadic() || en
.index() == op
.getNumRegions() - 1) &&
976 "expected only the last region to be variadic");
977 os
<< llvm::formatv(regionAccessorTemplate
, sanitizeName(region
.name
),
978 std::to_string(en
.index()) +
979 (region
.isVariadic() ? ":" : ""));
983 /// Emits builder that extracts results from op
984 static void emitValueBuilder(const Operator
&op
,
985 llvm::SmallVector
<std::string
> functionArgs
,
987 // Params with (possibly) default args.
988 auto valueBuilderParams
=
989 llvm::map_range(functionArgs
, [](const std::string
&argAndMaybeDefault
) {
990 llvm::SmallVector
<llvm::StringRef
> argMaybeDefault
=
991 llvm::to_vector
<2>(llvm::split(argAndMaybeDefault
, "="));
992 auto arg
= llvm::convertToSnakeFromCamelCase(argMaybeDefault
[0]);
993 if (argMaybeDefault
.size() == 2)
994 return arg
+ "=" + argMaybeDefault
[1].str();
997 // Actual args passed to op builder (e.g., opParam=op_param).
998 auto opBuilderArgs
= llvm::map_range(
999 llvm::make_filter_range(functionArgs
,
1000 [](const std::string
&s
) { return s
!= "*"; }),
1001 [](const std::string
&arg
) {
1002 auto lhs
= *llvm::split(arg
, "=").begin();
1003 return (lhs
+ "=" + llvm::convertToSnakeFromCamelCase(lhs
)).str();
1005 std::string nameWithoutDialect
=
1006 op
.getOperationName().substr(op
.getOperationName().find('.') + 1);
1007 os
<< llvm::formatv(
1008 valueBuilderTemplate
, sanitizeName(nameWithoutDialect
),
1009 op
.getCppClassName(), llvm::join(valueBuilderParams
, ", "),
1010 llvm::join(opBuilderArgs
, ", "),
1011 (op
.getNumResults() > 1
1012 ? "_Sequence[_ods_ir.Value]"
1013 : (op
.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
1016 /// Emits bindings for a specific Op to the given output stream.
1017 static void emitOpBindings(const Operator
&op
, raw_ostream
&os
) {
1018 os
<< llvm::formatv(opClassTemplate
, op
.getCppClassName(),
1019 op
.getOperationName());
1022 if (op
.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
1023 emitSegmentSpec(op
, "OPERAND", getNumOperands
, getOperand
, os
);
1025 if (op
.getTrait(attrSizedTraitForKind("result")) != nullptr) {
1026 emitSegmentSpec(op
, "RESULT", getNumResults
, getResult
, os
);
1029 emitRegionAttributes(op
, os
);
1030 llvm::SmallVector
<std::string
> functionArgs
= emitDefaultOpBuilder(op
, os
);
1031 emitOperandAccessors(op
, os
);
1032 emitAttributeAccessors(op
, os
);
1033 emitResultAccessors(op
, os
);
1034 emitRegionAccessors(op
, os
);
1035 emitValueBuilder(op
, functionArgs
, os
);
1038 /// Emits bindings for the dialect specified in the command line, including file
1039 /// headers and utilities. Returns `false` on success to comply with Tablegen
1040 /// registration requirements.
1041 static bool emitAllOps(const llvm::RecordKeeper
&records
, raw_ostream
&os
) {
1042 if (clDialectName
.empty())
1043 llvm::PrintFatalError("dialect name not provided");
1046 if (!clDialectExtensionName
.empty())
1047 os
<< llvm::formatv(dialectExtensionTemplate
, clDialectName
.getValue());
1049 os
<< llvm::formatv(dialectClassTemplate
, clDialectName
.getValue());
1051 for (const llvm::Record
*rec
: records
.getAllDerivedDefinitions("Op")) {
1053 if (op
.getDialectName() == clDialectName
.getValue())
1054 emitOpBindings(op
, os
);
// Registers emitAllOps as the TableGen generator selected by
// `gen-python-op-bindings`.
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
);