1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
12 #include "PybindUtils.h"
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/Debug.h"
17 #include "mlir-c/Diagnostics.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Support.h"
20 #include "mlir/Bindings/Python/PybindAdaptors.h"
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/SmallVector.h"
27 namespace py
= pybind11
;
28 using namespace py::literals
;
30 using namespace mlir::python
;
32 using llvm::SmallVector
;
33 using llvm::StringRef
;
36 //------------------------------------------------------------------------------
37 // Docstrings (trivial, non-duplicated docstrings are included inline).
38 //------------------------------------------------------------------------------
// Python docstrings for the IR bindings. Each constant is a raw string literal
// handed verbatim to pybind11 as the `doc` of the corresponding binding.

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";

static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";

static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
191 /// Helper for creating an @classmethod.
192 template <class Func, typename... Args>
193 py::object classmethod(Func f, Args... args) {
194 py::object cf = py::cpp_function(f, args...);
195 return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
199 createCustomDialectWrapper(const std::string &dialectNamespace,
200 py::object dialectDescriptor) {
201 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
203 // Use the base class.
204 return py::cast(PyDialect(std::move(dialectDescriptor)));
207 // Create the custom implementation.
208 return (*dialectClass)(std::move(dialectDescriptor));
211 static MlirStringRef toMlirStringRef(const std::string &s) {
212 return mlirStringRefCreate(s.data(), s.size());
215 /// Create a block, using the current location context if no locations are
217 static MlirBlock createBlock(const py::sequence &pyArgTypes,
218 const std::optional<py::sequence> &pyArgLocs) {
219 SmallVector<MlirType> argTypes;
220 argTypes.reserve(pyArgTypes.size());
221 for (const auto &pyType : pyArgTypes)
222 argTypes.push_back(pyType.cast<PyType &>());
224 SmallVector<MlirLocation> argLocs;
226 argLocs.reserve(pyArgLocs->size());
227 for (const auto &pyLoc : *pyArgLocs)
228 argLocs.push_back(pyLoc.cast<PyLocation &>());
229 } else if (!argTypes.empty()) {
230 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
233 if (argTypes.size() != argLocs.size())
234 throw py::value_error(("Expected
" + Twine(argTypes.size()) +
235 " locations
, got
: " + Twine(argLocs.size()))
237 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
240 /// Wrapper for the global LLVM debugging flag.
241 struct PyGlobalDebugFlag {
242 static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
244 static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
246 static void bind(py::module &m) {
248 py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug
", py::module_local())
249 .def_property_static("flag
", &PyGlobalDebugFlag::get,
250 &PyGlobalDebugFlag::set, "LLVM
-wide debug flag
")
253 [](const std::string &type) {
254 mlirSetGlobalDebugType(type.c_str());
256 "types
"_a, "Sets specific debug types to be produced by LLVM
")
257 .def_static("set_types
", [](const std::vector<std::string> &types) {
258 std::vector<const char *> pointers;
259 pointers.reserve(types.size());
260 for (const std::string &str : types)
261 pointers.push_back(str.c_str());
262 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
267 struct PyAttrBuilderMap {
268 static bool dunderContains(const std::string &attributeKind) {
269 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
271 static py::function dundeGetItemNamed(const std::string &attributeKind) {
272 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
274 throw py::key_error(attributeKind);
277 static void dundeSetItemNamed(const std::string &attributeKind,
278 py::function func, bool replace) {
279 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
283 static void bind(py::module &m) {
284 py::class_<PyAttrBuilderMap>(m, "AttrBuilder
", py::module_local())
285 .def_static("contains
", &PyAttrBuilderMap::dunderContains)
286 .def_static("get
", &PyAttrBuilderMap::dundeGetItemNamed)
287 .def_static("insert
", &PyAttrBuilderMap::dundeSetItemNamed,
288 "attribute_kind
"_a, "attr_builder
"_a, "replace
"_a = false,
289 "Register an attribute builder
for building MLIR
"
290 "attributes from python values
.");
//------------------------------------------------------------------------------
// PyBlock
//------------------------------------------------------------------------------
298 py::object PyBlock::getCapsule() {
299 return py::reinterpret_steal<py::object>(mlirPythonBlockToCapsule(get()));
//------------------------------------------------------------------------------
// Collections.
//------------------------------------------------------------------------------
308 class PyRegionIterator {
310 PyRegionIterator(PyOperationRef operation)
311 : operation(std::move(operation)) {}
313 PyRegionIterator &dunderIter() { return *this; }
315 PyRegion dunderNext() {
316 operation->checkValid();
317 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
318 throw py::stop_iteration();
320 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
321 return PyRegion(operation, region);
324 static void bind(py::module &m) {
325 py::class_<PyRegionIterator>(m, "RegionIterator
", py::module_local())
326 .def("__iter__
", &PyRegionIterator::dunderIter)
327 .def("__next__
", &PyRegionIterator::dunderNext);
331 PyOperationRef operation;
335 /// Regions of an op are fixed length and indexed numerically so are represented
336 /// with a sequence-like container.
339 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
341 PyRegionIterator dunderIter() {
342 operation->checkValid();
343 return PyRegionIterator(operation);
346 intptr_t dunderLen() {
347 operation->checkValid();
348 return mlirOperationGetNumRegions(operation->get());
351 PyRegion dunderGetItem(intptr_t index) {
352 // dunderLen checks validity.
353 if (index < 0 || index >= dunderLen()) {
354 throw py::index_error("attempt to access out of bounds region
");
356 MlirRegion region = mlirOperationGetRegion(operation->get(), index);
357 return PyRegion(operation, region);
360 static void bind(py::module &m) {
361 py::class_<PyRegionList>(m, "RegionSequence
", py::module_local())
362 .def("__len__
", &PyRegionList::dunderLen)
363 .def("__iter__
", &PyRegionList::dunderIter)
364 .def("__getitem__
", &PyRegionList::dunderGetItem);
368 PyOperationRef operation;
371 class PyBlockIterator {
373 PyBlockIterator(PyOperationRef operation, MlirBlock next)
374 : operation(std::move(operation)), next(next) {}
376 PyBlockIterator &dunderIter() { return *this; }
378 PyBlock dunderNext() {
379 operation->checkValid();
380 if (mlirBlockIsNull(next)) {
381 throw py::stop_iteration();
384 PyBlock returnBlock(operation, next);
385 next = mlirBlockGetNextInRegion(next);
389 static void bind(py::module &m) {
390 py::class_<PyBlockIterator>(m, "BlockIterator
", py::module_local())
391 .def("__iter__
", &PyBlockIterator::dunderIter)
392 .def("__next__
", &PyBlockIterator::dunderNext);
396 PyOperationRef operation;
400 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
401 /// we present them as a more full-featured list-like container but optimize
402 /// it for forward iteration. Blocks are always owned by a region.
405 PyBlockList(PyOperationRef operation, MlirRegion region)
406 : operation(std::move(operation)), region(region) {}
408 PyBlockIterator dunderIter() {
409 operation->checkValid();
410 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
413 intptr_t dunderLen() {
414 operation->checkValid();
416 MlirBlock block = mlirRegionGetFirstBlock(region);
417 while (!mlirBlockIsNull(block)) {
419 block = mlirBlockGetNextInRegion(block);
424 PyBlock dunderGetItem(intptr_t index) {
425 operation->checkValid();
427 throw py::index_error("attempt to access out of bounds block
");
429 MlirBlock block = mlirRegionGetFirstBlock(region);
430 while (!mlirBlockIsNull(block)) {
432 return PyBlock(operation, block);
434 block = mlirBlockGetNextInRegion(block);
437 throw py::index_error("attempt to access out of bounds block
");
440 PyBlock appendBlock(const py::args &pyArgTypes,
441 const std::optional<py::sequence> &pyArgLocs) {
442 operation->checkValid();
443 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
444 mlirRegionAppendOwnedBlock(region, block);
445 return PyBlock(operation, block);
448 static void bind(py::module &m) {
449 py::class_<PyBlockList>(m, "BlockList
", py::module_local())
450 .def("__getitem__
", &PyBlockList::dunderGetItem)
451 .def("__iter__
", &PyBlockList::dunderIter)
452 .def("__len__
", &PyBlockList::dunderLen)
453 .def("append
", &PyBlockList::appendBlock, kAppendBlockDocstring,
454 py::arg("arg_locs
") = std::nullopt);
458 PyOperationRef operation;
462 class PyOperationIterator {
464 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
465 : parentOperation(std::move(parentOperation)), next(next) {}
467 PyOperationIterator &dunderIter() { return *this; }
469 py::object dunderNext() {
470 parentOperation->checkValid();
471 if (mlirOperationIsNull(next)) {
472 throw py::stop_iteration();
475 PyOperationRef returnOperation =
476 PyOperation::forOperation(parentOperation->getContext(), next);
477 next = mlirOperationGetNextInBlock(next);
478 return returnOperation->createOpView();
481 static void bind(py::module &m) {
482 py::class_<PyOperationIterator>(m, "OperationIterator
", py::module_local())
483 .def("__iter__
", &PyOperationIterator::dunderIter)
484 .def("__next__
", &PyOperationIterator::dunderNext);
488 PyOperationRef parentOperation;
492 /// Operations are exposed by the C-API as a forward-only linked list. In
493 /// Python, we present them as a more full-featured list-like container but
494 /// optimize it for forward iteration. Iterable operations are always owned
496 class PyOperationList {
498 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
499 : parentOperation(std::move(parentOperation)), block(block) {}
501 PyOperationIterator dunderIter() {
502 parentOperation->checkValid();
503 return PyOperationIterator(parentOperation,
504 mlirBlockGetFirstOperation(block));
507 intptr_t dunderLen() {
508 parentOperation->checkValid();
510 MlirOperation childOp = mlirBlockGetFirstOperation(block);
511 while (!mlirOperationIsNull(childOp)) {
513 childOp = mlirOperationGetNextInBlock(childOp);
518 py::object dunderGetItem(intptr_t index) {
519 parentOperation->checkValid();
521 throw py::index_error("attempt to access out of bounds operation
");
523 MlirOperation childOp = mlirBlockGetFirstOperation(block);
524 while (!mlirOperationIsNull(childOp)) {
526 return PyOperation::forOperation(parentOperation->getContext(), childOp)
529 childOp = mlirOperationGetNextInBlock(childOp);
532 throw py::index_error("attempt to access out of bounds operation
");
535 static void bind(py::module &m) {
536 py::class_<PyOperationList>(m, "OperationList
", py::module_local())
537 .def("__getitem__
", &PyOperationList::dunderGetItem)
538 .def("__iter__
", &PyOperationList::dunderIter)
539 .def("__len__
", &PyOperationList::dunderLen);
543 PyOperationRef parentOperation;
549 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
551 py::object getOwner() {
552 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
553 PyMlirContextRef context =
554 PyMlirContext::forContext(mlirOperationGetContext(owner));
555 return PyOperation::forOperation(context, owner)->createOpView();
558 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
560 static void bind(py::module &m) {
561 py::class_<PyOpOperand>(m, "OpOperand
", py::module_local())
562 .def_property_readonly("owner
", &PyOpOperand::getOwner)
563 .def_property_readonly("operand_number
",
564 &PyOpOperand::getOperandNumber);
568 MlirOpOperand opOperand;
571 class PyOpOperandIterator {
573 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
575 PyOpOperandIterator &dunderIter() { return *this; }
577 PyOpOperand dunderNext() {
578 if (mlirOpOperandIsNull(opOperand))
579 throw py::stop_iteration();
581 PyOpOperand returnOpOperand(opOperand);
582 opOperand = mlirOpOperandGetNextUse(opOperand);
583 return returnOpOperand;
586 static void bind(py::module &m) {
587 py::class_<PyOpOperandIterator>(m, "OpOperandIterator
", py::module_local())
588 .def("__iter__
", &PyOpOperandIterator::dunderIter)
589 .def("__next__
", &PyOpOperandIterator::dunderNext);
593 MlirOpOperand opOperand;
//------------------------------------------------------------------------------
// PyMlirContext
//------------------------------------------------------------------------------
602 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
603 py::gil_scoped_acquire acquire;
604 auto &liveContexts = getLiveContexts();
605 liveContexts[context.ptr] = this;
608 PyMlirContext::~PyMlirContext() {
609 // Note that the only public way to construct an instance is via the
610 // forContext method, which always puts the associated handle into
612 py::gil_scoped_acquire acquire;
613 getLiveContexts().erase(context.ptr);
614 mlirContextDestroy(context);
617 py::object PyMlirContext::getCapsule() {
618 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
621 py::object PyMlirContext::createFromCapsule(py::object capsule) {
622 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
623 if (mlirContextIsNull(rawContext))
624 throw py::error_already_set();
625 return forContext(rawContext).releaseObject();
628 PyMlirContext *PyMlirContext::createNewContextForInit() {
629 MlirContext context = mlirContextCreateWithThreading(false);
630 return new PyMlirContext(context);
633 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
634 py::gil_scoped_acquire acquire;
635 auto &liveContexts = getLiveContexts();
636 auto it = liveContexts.find(context.ptr);
637 if (it == liveContexts.end()) {
639 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
640 py::object pyRef = py::cast(unownedContextWrapper);
641 assert(pyRef && "cast to
py::object failed
");
642 liveContexts[context.ptr] = unownedContextWrapper;
643 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
646 py::object pyRef = py::cast(it->second);
647 return PyMlirContextRef(it->second, std::move(pyRef));
650 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
651 static LiveContextMap liveContexts;
655 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
657 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
659 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
660 std::vector<PyOperation *> liveObjects;
661 for (auto &entry : liveOperations)
662 liveObjects.push_back(entry.second.second);
666 size_t PyMlirContext::clearLiveOperations() {
667 for (auto &op : liveOperations)
668 op.second.second->setInvalid();
669 size_t numInvalidated = liveOperations.size();
670 liveOperations.clear();
671 return numInvalidated;
674 void PyMlirContext::clearOperation(MlirOperation op) {
675 auto it = liveOperations.find(op.ptr);
676 if (it != liveOperations.end()) {
677 it->second.second->setInvalid();
678 liveOperations.erase(it);
682 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
687 callBackData data{op.getOperation(), false};
688 // Mark all ops below the op that the passmanager will be rooted
689 // at (but not op itself - note the preorder) as invalid.
690 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
692 callBackData *data = static_cast<callBackData *>(userData);
693 if (LLVM_LIKELY(data->rootSeen))
694 data->rootOp.getOperation().getContext()->clearOperation(op);
696 data->rootSeen = true;
697 return MlirWalkResult::MlirWalkResultAdvance;
699 mlirOperationWalk(op.getOperation(), invalidatingCallback,
700 static_cast<void *>(&data), MlirWalkPreOrder);
702 void PyMlirContext::clearOperationsInside(MlirOperation op) {
703 PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
704 clearOperationsInside(opRef->getOperation());
707 void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
708 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
710 PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
711 contextRef->clearOperation(op);
712 return MlirWalkResult::MlirWalkResultAdvance;
714 mlirOperationWalk(op.getOperation(), invalidatingCallback,
715 &op.getOperation().getContext(), MlirWalkPreOrder);
718 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
720 pybind11::object PyMlirContext::contextEnter() {
721 return PyThreadContextEntry::pushContext(*this);
724 void PyMlirContext::contextExit(const pybind11::object &excType,
725 const pybind11::object &excVal,
726 const pybind11::object &excTb) {
727 PyThreadContextEntry::popContext(*this);
730 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
731 // Note that ownership is transferred to the delete callback below by way of
732 // an explicit inc_ref (borrow).
733 PyDiagnosticHandler *pyHandler =
734 new PyDiagnosticHandler(get(), std::move(callback));
735 py::object pyHandlerObject =
736 py::cast(pyHandler, py::return_value_policy::take_ownership);
737 pyHandlerObject.inc_ref();
739 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
740 // guaranteed to be known to pybind.
741 auto handlerCallback =
742 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
743 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
744 py::object pyDiagnosticObject =
745 py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
747 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
750 // Since this can be called from arbitrary C++ contexts, always get the
752 py::gil_scoped_acquire gil;
754 result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
755 } catch (std::exception &e) {
756 fprintf(stderr, "MLIR Python Diagnostic handler raised exception
: %s
\n",
758 pyHandler->hadError = true;
762 pyDiagnostic->invalidate();
763 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
765 auto deleteCallback = +[](void *userData) {
766 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
767 assert(pyHandler->registeredID && "handler is
not registered
");
768 pyHandler->registeredID.reset();
770 // Decrement reference, balancing the inc_ref() above.
771 py::object pyHandlerObject =
772 py::cast(pyHandler, py::return_value_policy::reference);
773 pyHandlerObject.dec_ref();
776 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
777 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
778 return pyHandlerObject;
781 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
783 auto *self = static_cast<ErrorCapture *>(userData);
784 // Check if the context requested we emit errors instead of capturing them.
785 if (self->ctx->emitErrorDiagnostics)
786 return mlirLogicalResultFailure();
788 if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
789 return mlirLogicalResultFailure();
791 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
792 return mlirLogicalResultSuccess();
795 PyMlirContext &DefaultingPyMlirContext::resolve() {
796 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
798 throw std::runtime_error(
799 "An MLIR function requires a Context but none was provided in the call
"
800 "or from the surrounding environment
. Either pass to the function with
"
801 "a
'context=' argument
or establish a
default using 'with Context():'");
806 //------------------------------------------------------------------------------
807 // PyThreadContextEntry management
808 //------------------------------------------------------------------------------
810 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
811 static thread_local std::vector<PyThreadContextEntry> stack;
815 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
816 auto &stack = getStack();
819 return &stack.back();
822 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
823 py::object insertionPoint,
824 py::object location) {
825 auto &stack = getStack();
826 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
827 std::move(location));
828 // If the new stack has more than one entry and the context of the new top
829 // entry matches the previous, copy the insertionPoint and location from the
830 // previous entry if missing from the new top entry.
831 if (stack.size() > 1) {
832 auto &prev = *(stack.rbegin() + 1);
833 auto ¤t = stack.back();
834 if (current.context.is(prev.context)) {
835 // Default non-context objects from the previous entry.
836 if (!current.insertionPoint)
837 current.insertionPoint = prev.insertionPoint;
838 if (!current.location)
839 current.location = prev.location;
844 PyMlirContext *PyThreadContextEntry::getContext() {
847 return py::cast<PyMlirContext *>(context);
850 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
853 return py::cast<PyInsertionPoint *>(insertionPoint);
856 PyLocation *PyThreadContextEntry::getLocation() {
859 return py::cast<PyLocation *>(location);
862 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
863 auto *tos = getTopOfStack();
864 return tos ? tos->getContext() : nullptr;
867 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
868 auto *tos = getTopOfStack();
869 return tos ? tos->getInsertionPoint() : nullptr;
872 PyLocation *PyThreadContextEntry::getDefaultLocation() {
873 auto *tos = getTopOfStack();
874 return tos ? tos->getLocation() : nullptr;
877 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
878 py::object contextObj = py::cast(context);
879 push(FrameKind::Context, /*context=*/contextObj,
880 /*insertionPoint=*/py::object(),
881 /*location=*/py::object());
885 void PyThreadContextEntry::popContext(PyMlirContext &context) {
886 auto &stack = getStack();
888 throw std::runtime_error("Unbalanced Context enter
/exit
");
889 auto &tos = stack.back();
890 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
891 throw std::runtime_error("Unbalanced Context enter
/exit
");
896 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
897 py::object contextObj =
898 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
899 py::object insertionPointObj = py::cast(insertionPoint);
900 push(FrameKind::InsertionPoint,
901 /*context=*/contextObj,
902 /*insertionPoint=*/insertionPointObj,
903 /*location=*/py::object());
904 return insertionPointObj;
907 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
908 auto &stack = getStack();
910 throw std::runtime_error("Unbalanced InsertionPoint enter
/exit
");
911 auto &tos = stack.back();
912 if (tos.frameKind != FrameKind::InsertionPoint &&
913 tos.getInsertionPoint() != &insertionPoint)
914 throw std::runtime_error("Unbalanced InsertionPoint enter
/exit
");
918 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
919 py::object contextObj = location.getContext().getObject();
920 py::object locationObj = py::cast(location);
921 push(FrameKind::Location, /*context=*/contextObj,
922 /*insertionPoint=*/py::object(),
923 /*location=*/locationObj);
927 void PyThreadContextEntry::popLocation(PyLocation &location) {
928 auto &stack = getStack();
930 throw std::runtime_error("Unbalanced Location enter
/exit
");
931 auto &tos = stack.back();
932 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
933 throw std::runtime_error("Unbalanced Location enter
/exit
");
//------------------------------------------------------------------------------
// PyDiagnostic*
//------------------------------------------------------------------------------
941 void PyDiagnostic::invalidate() {
943 if (materializedNotes) {
944 for (auto ¬eObject : *materializedNotes) {
945 PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
951 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
953 : context(context), callback(std::move(callback)) {}
955 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
957 void PyDiagnosticHandler::detach() {
960 MlirDiagnosticHandlerID localID = *registeredID;
961 mlirContextDetachDiagnosticHandler(context, localID);
962 assert(!registeredID && "should have unregistered
");
963 // Not strictly necessary but keeps stale pointers from being around to cause
968 void PyDiagnostic::checkValid() {
970 throw std::invalid_argument(
971 "Diagnostic is
invalid (used outside of callback
)");
975 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
977 return mlirDiagnosticGetSeverity(diagnostic);
980 PyLocation PyDiagnostic::getLocation() {
982 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
983 MlirContext context = mlirLocationGetContext(loc);
984 return PyLocation(PyMlirContext::forContext(context), loc);
987 py::str PyDiagnostic::getMessage() {
989 py::object fileObject = py::module::import("io
").attr("StringIO
")();
990 PyFileAccumulator accum(fileObject, /*binary=*/false);
991 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
992 return fileObject.attr("getvalue
")();
995 py::tuple PyDiagnostic::getNotes() {
997 if (materializedNotes)
998 return *materializedNotes;
999 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
1000 materializedNotes = py::tuple(numNotes);
1001 for (intptr_t i = 0; i < numNotes; ++i) {
1002 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1003 (*materializedNotes)[i] = PyDiagnostic(noteDiag);
1005 return *materializedNotes;
1008 PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
1009 std::vector<DiagnosticInfo> notes;
1010 for (py::handle n : getNotes())
1011 notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
1012 return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
1015 //------------------------------------------------------------------------------
1016 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1017 //------------------------------------------------------------------------------
1019 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1021 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1022 {key.data(), key.size()});
1023 if (mlirDialectIsNull(dialect)) {
1024 std::string msg = (Twine("Dialect
'") + key + "' not found
").str();
1026 throw py::attribute_error(msg);
1027 throw py::index_error(msg);
1032 py::object PyDialectRegistry::getCapsule() {
1033 return py::reinterpret_steal<py::object>(
1034 mlirPythonDialectRegistryToCapsule(*this));
1037 PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
1038 MlirDialectRegistry rawRegistry =
1039 mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1040 if (mlirDialectRegistryIsNull(rawRegistry))
1041 throw py::error_already_set();
1042 return PyDialectRegistry(rawRegistry);
1045 //------------------------------------------------------------------------------
1047 //------------------------------------------------------------------------------
1049 py::object PyLocation::getCapsule() {
1050 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
1053 PyLocation PyLocation::createFromCapsule(py::object capsule) {
1054 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1055 if (mlirLocationIsNull(rawLoc))
1056 throw py::error_already_set();
1057 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
1061 py::object PyLocation::contextEnter() {
1062 return PyThreadContextEntry::pushLocation(*this);
1065 void PyLocation::contextExit(const pybind11::object &excType,
1066 const pybind11::object &excVal,
1067 const pybind11::object &excTb) {
1068 PyThreadContextEntry::popLocation(*this);
1071 PyLocation &DefaultingPyLocation::resolve() {
1072 auto *location = PyThreadContextEntry::getDefaultLocation();
1074 throw std::runtime_error(
1075 "An MLIR function requires a Location but none was provided in the
"
1076 "call
or from the surrounding environment
. Either pass to the function
"
1077 "with a
'loc=' argument
or establish a
default using 'with loc:'");
1082 //------------------------------------------------------------------------------
1084 //------------------------------------------------------------------------------
1086 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1087 : BaseContextObject(std::move(contextRef)), module(module) {}
1089 PyModule::~PyModule() {
1090 py::gil_scoped_acquire acquire;
1091 auto &liveModules = getContext()->liveModules;
1092 assert(liveModules.count(module.ptr) == 1 &&
1093 "destroying module
not in live map
");
1094 liveModules.erase(module.ptr);
1095 mlirModuleDestroy(module);
1098 PyModuleRef PyModule::forModule(MlirModule module) {
1099 MlirContext context = mlirModuleGetContext(module);
1100 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1102 py::gil_scoped_acquire acquire;
1103 auto &liveModules = contextRef->liveModules;
1104 auto it = liveModules.find(module.ptr);
1105 if (it == liveModules.end()) {
1107 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1108 // Note that the default return value policy on cast is automatic_reference,
1109 // which does not take ownership (delete will not be called).
1110 // Just be explicit.
1112 py::cast(unownedModule, py::return_value_policy::take_ownership);
1113 unownedModule->handle = pyRef;
1114 liveModules[module.ptr] =
1115 std::make_pair(unownedModule->handle, unownedModule);
1116 return PyModuleRef(unownedModule, std::move(pyRef));
1119 PyModule *existing = it->second.second;
1120 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1121 return PyModuleRef(existing, std::move(pyRef));
1124 py::object PyModule::createFromCapsule(py::object capsule) {
1125 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1126 if (mlirModuleIsNull(rawModule))
1127 throw py::error_already_set();
1128 return forModule(rawModule).releaseObject();
1131 py::object PyModule::getCapsule() {
1132 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
1135 //------------------------------------------------------------------------------
1137 //------------------------------------------------------------------------------
1139 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1140 : BaseContextObject(std::move(contextRef)), operation(operation) {}
1142 PyOperation::~PyOperation() {
1143 // If the operation has already been invalidated there is nothing to do.
1147 // Otherwise, invalidate the operation and remove it from live map when it is
1150 getContext()->clearOperation(*this);
1152 // And destroy it when it is detached, i.e. owned by Python, in which case
1153 // all nested operations must be invalidated at removed from the live map as
1159 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1160 MlirOperation operation,
1161 py::object parentKeepAlive) {
1162 auto &liveOperations = contextRef->liveOperations;
1164 PyOperation *unownedOperation =
1165 new PyOperation(std::move(contextRef), operation);
1166 // Note that the default return value policy on cast is automatic_reference,
1167 // which does not take ownership (delete will not be called).
1168 // Just be explicit.
1170 py::cast(unownedOperation, py::return_value_policy::take_ownership);
1171 unownedOperation->handle = pyRef;
1172 if (parentKeepAlive) {
1173 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1175 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
1176 return PyOperationRef(unownedOperation, std::move(pyRef));
1179 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1180 MlirOperation operation,
1181 py::object parentKeepAlive) {
1182 auto &liveOperations = contextRef->liveOperations;
1183 auto it = liveOperations.find(operation.ptr);
1184 if (it == liveOperations.end()) {
1186 return createInstance(std::move(contextRef), operation,
1187 std::move(parentKeepAlive));
1190 PyOperation *existing = it->second.second;
1191 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
1192 return PyOperationRef(existing, std::move(pyRef));
1195 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
1196 MlirOperation operation,
1197 py::object parentKeepAlive) {
1198 auto &liveOperations = contextRef->liveOperations;
1199 assert(liveOperations.count(operation.ptr) == 0 &&
1200 "cannot create detached operation that already exists
");
1201 (void)liveOperations;
1203 PyOperationRef created = createInstance(std::move(contextRef), operation,
1204 std::move(parentKeepAlive));
1205 created->attached = false;
1209 PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
1210 const std::string &sourceStr,
1211 const std::string &sourceName) {
1212 PyMlirContext::ErrorCapture errors(contextRef);
1214 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1215 toMlirStringRef(sourceName));
1216 if (mlirOperationIsNull(op))
1217 throw MLIRError("Unable to parse operation assembly
", errors.take());
1218 return PyOperation::createDetached(std::move(contextRef), op);
1221 void PyOperation::checkValid() const {
1223 throw std::runtime_error("the operation has been invalidated
");
1227 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1228 bool enableDebugInfo, bool prettyDebugInfo,
1229 bool printGenericOpForm, bool useLocalScope,
1230 bool assumeVerified, py::object fileObject,
1231 bool binary, bool skipRegions) {
1232 PyOperation &operation = getOperation();
1233 operation.checkValid();
1234 if (fileObject.is_none())
1235 fileObject = py::module::import("sys
").attr("stdout
");
1237 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1238 if (largeElementsLimit)
1239 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1240 if (enableDebugInfo)
1241 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1242 /*prettyForm=*/prettyDebugInfo);
1243 if (printGenericOpForm)
1244 mlirOpPrintingFlagsPrintGenericOpForm(flags);
1246 mlirOpPrintingFlagsUseLocalScope(flags);
1248 mlirOpPrintingFlagsAssumeVerified(flags);
1250 mlirOpPrintingFlagsSkipRegions(flags);
1252 PyFileAccumulator accum(fileObject, binary);
1253 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1254 accum.getUserData());
1255 mlirOpPrintingFlagsDestroy(flags);
1258 void PyOperationBase::print(PyAsmState &state, py::object fileObject,
1260 PyOperation &operation = getOperation();
1261 operation.checkValid();
1262 if (fileObject.is_none())
1263 fileObject = py::module::import("sys
").attr("stdout
");
1264 PyFileAccumulator accum(fileObject, binary);
1265 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1266 accum.getUserData());
1269 void PyOperationBase::writeBytecode(const py::object &fileObject,
1270 std::optional<int64_t> bytecodeVersion) {
1271 PyOperation &operation = getOperation();
1272 operation.checkValid();
1273 PyFileAccumulator accum(fileObject, /*binary=*/true);
1275 if (!bytecodeVersion.has_value())
1276 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1277 accum.getUserData());
1279 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1280 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1281 MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
1282 operation, config, accum.getCallback(), accum.getUserData());
1283 mlirBytecodeWriterConfigDestroy(config);
1284 if (mlirLogicalResultIsFailure(res))
1285 throw py::value_error((Twine("Unable to honor desired bytecode version
") +
1286 Twine(*bytecodeVersion))
1290 void PyOperationBase::walk(
1291 std::function<MlirWalkResult(MlirOperation)> callback,
1292 MlirWalkOrder walkOrder) {
1293 PyOperation &operation = getOperation();
1294 operation.checkValid();
1296 std::function<MlirWalkResult(MlirOperation)> callback;
1298 std::string exceptionWhat;
1299 py::object exceptionType;
1301 UserData userData{callback, false, {}, {}};
1302 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1304 UserData *calleeUserData = static_cast<UserData *>(userData);
1306 return (calleeUserData->callback)(op);
1307 } catch (py::error_already_set &e) {
1308 calleeUserData->gotException = true;
1309 calleeUserData->exceptionWhat = e.what();
1310 calleeUserData->exceptionType = e.type();
1311 return MlirWalkResult::MlirWalkResultInterrupt;
1314 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1315 if (userData.gotException) {
1316 std::string message("Exception raised in callback
: ");
1317 message.append(userData.exceptionWhat);
1318 throw std::runtime_error(message);
1322 py::object PyOperationBase::getAsm(bool binary,
1323 std::optional<int64_t> largeElementsLimit,
1324 bool enableDebugInfo, bool prettyDebugInfo,
1325 bool printGenericOpForm, bool useLocalScope,
1326 bool assumeVerified, bool skipRegions) {
1327 py::object fileObject;
1329 fileObject = py::module::import("io
").attr("BytesIO
")();
1331 fileObject = py::module::import("io
").attr("StringIO
")();
1333 print(/*largeElementsLimit=*/largeElementsLimit,
1334 /*enableDebugInfo=*/enableDebugInfo,
1335 /*prettyDebugInfo=*/prettyDebugInfo,
1336 /*printGenericOpForm=*/printGenericOpForm,
1337 /*useLocalScope=*/useLocalScope,
1338 /*assumeVerified=*/assumeVerified,
1339 /*fileObject=*/fileObject,
1341 /*skipRegions=*/skipRegions);
1343 return fileObject.attr("getvalue
")();
1346 void PyOperationBase::moveAfter(PyOperationBase &other) {
1347 PyOperation &operation = getOperation();
1348 PyOperation &otherOp = other.getOperation();
1349 operation.checkValid();
1350 otherOp.checkValid();
1351 mlirOperationMoveAfter(operation, otherOp);
1352 operation.parentKeepAlive = otherOp.parentKeepAlive;
1355 void PyOperationBase::moveBefore(PyOperationBase &other) {
1356 PyOperation &operation = getOperation();
1357 PyOperation &otherOp = other.getOperation();
1358 operation.checkValid();
1359 otherOp.checkValid();
1360 mlirOperationMoveBefore(operation, otherOp);
1361 operation.parentKeepAlive = otherOp.parentKeepAlive;
1364 bool PyOperationBase::verify() {
1365 PyOperation &op = getOperation();
1366 PyMlirContext::ErrorCapture errors(op.getContext());
1367 if (!mlirOperationVerify(op.get()))
1368 throw MLIRError("Verification failed
", errors.take());
1372 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1375 throw py::value_error("Detached operations have no parent
");
1376 MlirOperation operation = mlirOperationGetParentOperation(get());
1377 if (mlirOperationIsNull(operation))
1379 return PyOperation::forOperation(getContext(), operation);
1382 PyBlock PyOperation::getBlock() {
1384 std::optional<PyOperationRef> parentOperation = getParentOperation();
1385 MlirBlock block = mlirOperationGetBlock(get());
1386 assert(!mlirBlockIsNull(block) && "Attached operation has null parent
");
1387 assert(parentOperation && "Operation has no parent
");
1388 return PyBlock{std::move(*parentOperation), block};
1391 py::object PyOperation::getCapsule() {
1393 return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1396 py::object PyOperation::createFromCapsule(py::object capsule) {
1397 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1398 if (mlirOperationIsNull(rawOperation))
1399 throw py::error_already_set();
1400 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1401 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1405 static void maybeInsertOperation(PyOperationRef &op,
1406 const py::object &maybeIp) {
1407 // InsertPoint active?
1408 if (!maybeIp.is(py::cast(false))) {
1409 PyInsertionPoint *ip;
1410 if (maybeIp.is_none()) {
1411 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1413 ip = py::cast<PyInsertionPoint *>(maybeIp);
1416 ip->insert(*op.get());
1420 py::object PyOperation::create(const std::string &name,
1421 std::optional<std::vector<PyType *>> results,
1422 std::optional<std::vector<PyValue *>> operands,
1423 std::optional<py::dict> attributes,
1424 std::optional<std::vector<PyBlock *>> successors,
1425 int regions, DefaultingPyLocation location,
1426 const py::object &maybeIp, bool inferType) {
1427 llvm::SmallVector<MlirValue, 4> mlirOperands;
1428 llvm::SmallVector<MlirType, 4> mlirResults;
1429 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1430 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
1432 // General parameter validation.
1434 throw py::value_error("number of regions must be
>= 0");
1436 // Unpack/validate operands.
1438 mlirOperands.reserve(operands->size());
1439 for (PyValue *operand : *operands) {
1441 throw py::value_error("operand value cannot be None
");
1442 mlirOperands.push_back(operand->get());
1446 // Unpack/validate results.
1448 mlirResults.reserve(results->size());
1449 for (PyType *result : *results) {
1450 // TODO: Verify result type originate from the same context.
1452 throw py::value_error("result type cannot be None
");
1453 mlirResults.push_back(*result);
1456 // Unpack/validate attributes.
1458 mlirAttributes.reserve(attributes->size());
1459 for (auto &it : *attributes) {
1462 key = it.first.cast<std::string>();
1463 } catch (py::cast_error &err) {
1464 std::string msg = "Invalid attribute
key (not a string
) when
"
1465 "attempting to create the operation
\"" +
1466 name + "\" (" + err.what() + ")";
1467 throw py::cast_error(msg);
1470 auto &attribute = it.second.cast<PyAttribute &>();
1471 // TODO: Verify attribute originates from the same context.
1472 mlirAttributes.emplace_back(std::move(key), attribute);
1473 } catch (py::reference_cast_error &) {
1474 // This exception seems thrown when the value is "None
".
1476 "Found an
invalid (`None`
?) attribute value
for the key
\"" + key +
1477 "\" when attempting to create the operation
\"" + name + "\"";
1478 throw py::cast_error(msg);
1479 } catch (py::cast_error &err) {
1480 std::string msg = "Invalid attribute value
for the key
\"" + key +
1481 "\" when attempting to create the operation
\"" +
1482 name + "\" (" + err.what() + ")";
1483 throw py::cast_error(msg);
1487 // Unpack/validate successors.
1489 mlirSuccessors.reserve(successors->size());
1490 for (auto *successor : *successors) {
1491 // TODO: Verify successor originate from the same context.
1493 throw py::value_error("successor block cannot be None
");
1494 mlirSuccessors.push_back(successor->get());
1498 // Apply unpacked/validated to the operation state. Beyond this
1499 // point, exceptions cannot be thrown or else the state will leak.
1500 MlirOperationState state =
1501 mlirOperationStateGet(toMlirStringRef(name), location);
1502 if (!mlirOperands.empty())
1503 mlirOperationStateAddOperands(&state, mlirOperands.size(),
1504 mlirOperands.data());
1505 state.enableResultTypeInference = inferType;
1506 if (!mlirResults.empty())
1507 mlirOperationStateAddResults(&state, mlirResults.size(),
1508 mlirResults.data());
1509 if (!mlirAttributes.empty()) {
1510 // Note that the attribute names directly reference bytes in
1511 // mlirAttributes, so that vector must not be changed from here
1513 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1514 mlirNamedAttributes.reserve(mlirAttributes.size());
1515 for (auto &it : mlirAttributes)
1516 mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1517 mlirIdentifierGet(mlirAttributeGetContext(it.second),
1518 toMlirStringRef(it.first)),
1520 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1521 mlirNamedAttributes.data());
1523 if (!mlirSuccessors.empty())
1524 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1525 mlirSuccessors.data());
1527 llvm::SmallVector<MlirRegion, 4> mlirRegions;
1528 mlirRegions.resize(regions);
1529 for (int i = 0; i < regions; ++i)
1530 mlirRegions[i] = mlirRegionCreate();
1531 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1532 mlirRegions.data());
1535 // Construct the operation.
1536 MlirOperation operation = mlirOperationCreate(&state);
1538 throw py::value_error("Operation creation failed
");
1539 PyOperationRef created =
1540 PyOperation::createDetached(location->getContext(), operation);
1541 maybeInsertOperation(created, maybeIp);
1543 return created.getObject();
1546 py::object PyOperation::clone(const py::object &maybeIp) {
1547 MlirOperation clonedOperation = mlirOperationClone(operation);
1548 PyOperationRef cloned =
1549 PyOperation::createDetached(getContext(), clonedOperation);
1550 maybeInsertOperation(cloned, maybeIp);
1552 return cloned->createOpView();
1555 py::object PyOperation::createOpView() {
1557 MlirIdentifier ident = mlirOperationGetName(get());
1558 MlirStringRef identStr = mlirIdentifierStr(ident);
1559 auto operationCls = PyGlobals::get().lookupOperationClass(
1560 StringRef(identStr.data, identStr.length));
1562 return PyOpView::constructDerived(*operationCls, *getRef().get());
1563 return py::cast(PyOpView(getRef().getObject()));
1566 void PyOperation::erase() {
1568 getContext()->clearOperationAndInside(*this);
1569 mlirOperationDestroy(operation);
1572 //------------------------------------------------------------------------------
1574 //------------------------------------------------------------------------------
1576 static void populateResultTypes(StringRef name, py::list resultTypeList,
1577 const py::object &resultSegmentSpecObj,
1578 std::vector<int32_t> &resultSegmentLengths,
1579 std::vector<PyType *> &resultTypes) {
1580 resultTypes.reserve(resultTypeList.size());
1581 if (resultSegmentSpecObj.is_none()) {
1582 // Non-variadic result unpacking.
1583 for (const auto &it : llvm::enumerate(resultTypeList)) {
1585 resultTypes.push_back(py::cast<PyType *>(it.value()));
1586 if (!resultTypes.back())
1587 throw py::cast_error();
1588 } catch (py::cast_error &err) {
1589 throw py::value_error((llvm::Twine("Result
") +
1590 llvm::Twine(it.index()) + " of operation
\"" +
1591 name + "\" must be a
Type (" + err.what() + ")")
1596 // Sized result unpacking.
1597 auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1598 if (resultSegmentSpec.size() != resultTypeList.size()) {
1599 throw py::value_error((llvm::Twine("Operation
\"") + name +
1601 llvm::Twine(resultSegmentSpec.size()) +
1602 " result segments but was provided
" +
1603 llvm::Twine(resultTypeList.size()))
1606 resultSegmentLengths.reserve(resultTypeList.size());
1607 for (const auto &it :
1608 llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1609 int segmentSpec = std::get<1>(it.value());
1610 if (segmentSpec == 1 || segmentSpec == 0) {
1611 // Unpack unary element.
1613 auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1615 resultTypes.push_back(resultType);
1616 resultSegmentLengths.push_back(1);
1617 } else if (segmentSpec == 0) {
1618 // Allowed to be optional.
1619 resultSegmentLengths.push_back(0);
1621 throw py::cast_error("was None
and result is
not optional
");
1623 } catch (py::cast_error &err) {
1624 throw py::value_error((llvm::Twine("Result
") +
1625 llvm::Twine(it.index()) + " of operation
\"" +
1626 name + "\" must be a
Type (" + err.what() +
1630 } else if (segmentSpec == -1) {
1631 // Unpack sequence by appending.
1633 if (std::get<0>(it.value()).is_none()) {
1634 // Treat it as an empty list.
1635 resultSegmentLengths.push_back(0);
1638 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1639 for (py::object segmentItem : segment) {
1640 resultTypes.push_back(py::cast<PyType *>(segmentItem));
1641 if (!resultTypes.back()) {
1642 throw py::cast_error("contained a None item
");
1645 resultSegmentLengths.push_back(segment.size());
1647 } catch (std::exception &err) {
1648 // NOTE: Sloppy to be using a catch-all here, but there are at least
1649 // three different unrelated exceptions that can be thrown in the
1650 // above "casts
". Just keep the scope above small and catch them all.
1651 throw py::value_error((llvm::Twine("Result
") +
1652 llvm::Twine(it.index()) + " of operation
\"" +
1653 name + "\" must be a Sequence of
Types (" +
1658 throw py::value_error("Unexpected segment spec
");
1664 py::object PyOpView::buildGeneric(
1665 const py::object &cls, std::optional<py::list> resultTypeList,
1666 py::list operandList, std::optional<py::dict> attributes,
1667 std::optional<std::vector<PyBlock *>> successors,
1668 std::optional<int> regions, DefaultingPyLocation location,
1669 const py::object &maybeIp) {
1670 PyMlirContextRef context = location->getContext();
1671 // Class level operation construction metadata.
1672 std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME
"));
1673 // Operand and result segment specs are either none, which does no
1674 // variadic unpacking, or a list of ints with segment sizes, where each
1675 // element is either a positive number (typically 1 for a scalar) or -1 to
1676 // indicate that it is derived from the length of the same-indexed operand
1677 // or result (implying that it is a list at that position).
1678 py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS
");
1679 py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS
");
1681 std::vector<int32_t> operandSegmentLengths;
1682 std::vector<int32_t> resultSegmentLengths;
1684 // Validate/determine region count.
1685 auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS
"));
1686 int opMinRegionCount = std::get<0>(opRegionSpec);
1687 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1689 regions = opMinRegionCount;
1691 if (*regions < opMinRegionCount) {
1692 throw py::value_error(
1693 (llvm::Twine("Operation
\"") + name + "\" requires a minimum of
" +
1694 llvm::Twine(opMinRegionCount) +
1695 " regions but was built with regions
=" + llvm::Twine(*regions))
1698 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1699 throw py::value_error(
1700 (llvm::Twine("Operation
\"") + name + "\" requires a maximum of
" +
1701 llvm::Twine(opMinRegionCount) +
1702 " regions but was built with regions
=" + llvm::Twine(*regions))
1707 std::vector<PyType *> resultTypes;
1708 if (resultTypeList.has_value()) {
1709 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1710 resultSegmentLengths, resultTypes);
1714 std::vector<PyValue *> operands;
1715 operands.reserve(operands.size());
1716 if (operandSegmentSpecObj.is_none()) {
1717 // Non-sized operand unpacking.
1718 for (const auto &it : llvm::enumerate(operandList)) {
1720 operands.push_back(py::cast<PyValue *>(it.value()));
1721 if (!operands.back())
1722 throw py::cast_error();
1723 } catch (py::cast_error &err) {
1724 throw py::value_error((llvm::Twine("Operand
") +
1725 llvm::Twine(it.index()) + " of operation
\"" +
1726 name + "\" must be a
Value (" + err.what() + ")")
1731 // Sized operand unpacking.
1732 auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1733 if (operandSegmentSpec.size() != operandList.size()) {
1734 throw py::value_error((llvm::Twine("Operation
\"") + name +
1736 llvm::Twine(operandSegmentSpec.size()) +
1737 "operand segments but was provided
" +
1738 llvm::Twine(operandList.size()))
1741 operandSegmentLengths.reserve(operandList.size());
1742 for (const auto &it :
1743 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1744 int segmentSpec = std::get<1>(it.value());
1745 if (segmentSpec == 1 || segmentSpec == 0) {
1746 // Unpack unary element.
1748 auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1750 operands.push_back(operandValue);
1751 operandSegmentLengths.push_back(1);
1752 } else if (segmentSpec == 0) {
1753 // Allowed to be optional.
1754 operandSegmentLengths.push_back(0);
1756 throw py::cast_error("was None
and operand is
not optional
");
1758 } catch (py::cast_error &err) {
1759 throw py::value_error((llvm::Twine("Operand
") +
1760 llvm::Twine(it.index()) + " of operation
\"" +
1761 name + "\" must be a
Value (" + err.what() +
1765 } else if (segmentSpec == -1) {
1766 // Unpack sequence by appending.
1768 if (std::get<0>(it.value()).is_none()) {
1769 // Treat it as an empty list.
1770 operandSegmentLengths.push_back(0);
1773 auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1774 for (py::object segmentItem : segment) {
1775 operands.push_back(py::cast<PyValue *>(segmentItem));
1776 if (!operands.back()) {
1777 throw py::cast_error("contained a None item
");
1780 operandSegmentLengths.push_back(segment.size());
1782 } catch (std::exception &err) {
1783 // NOTE: Sloppy to be using a catch-all here, but there are at least
1784 // three different unrelated exceptions that can be thrown in the
1785 // above "casts
". Just keep the scope above small and catch them all.
1786 throw py::value_error((llvm::Twine("Operand
") +
1787 llvm::Twine(it.index()) + " of operation
\"" +
1788 name + "\" must be a Sequence of
Values (" +
1793 throw py::value_error("Unexpected segment spec
");
1798 // Merge operand/result segment lengths into attributes if needed.
1799 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1802 attributes = py::dict(*attributes);
1804 attributes = py::dict();
1806 if (attributes->contains("resultSegmentSizes
") ||
1807 attributes->contains("operandSegmentSizes
")) {
1808 throw py::value_error("Manually setting a
'resultSegmentSizes' or "
1809 "'operandSegmentSizes' attribute is unsupported
. "
1810 "Use Operation
.create
for such low
-level access
.");
1813 // Add resultSegmentSizes attribute.
1814 if (!resultSegmentLengths.empty()) {
1815 MlirAttribute segmentLengthAttr =
1816 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
1817 resultSegmentLengths.data());
1818 (*attributes)["resultSegmentSizes
"] =
1819 PyAttribute(context, segmentLengthAttr);
1822 // Add operandSegmentSizes attribute.
1823 if (!operandSegmentLengths.empty()) {
1824 MlirAttribute segmentLengthAttr =
1825 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
1826 operandSegmentLengths.data());
1827 (*attributes)["operandSegmentSizes
"] =
1828 PyAttribute(context, segmentLengthAttr);
1832 // Delegate to create.
1833 return PyOperation::create(name,
1834 /*results=*/std::move(resultTypes),
1835 /*operands=*/std::move(operands),
1836 /*attributes=*/std::move(attributes),
1837 /*successors=*/std::move(successors),
1838 /*regions=*/*regions, location, maybeIp,
1842 pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
1843 const PyOperation &operation) {
1844 // TODO: pybind11 2.6 supports a more direct form.
1845 // Upgrade many years from now.
1846 // auto opViewType = py::type::of<PyOpView>();
1847 py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1848 py::object instance = cls.attr("__new__
")(cls);
1849 opViewType.attr("__init__
")(instance, operation);
1853 PyOpView::PyOpView(const py::object &operationObject)
1854 // Casting through the PyOperationBase base-class and then back to the
1855 // Operation lets us accept any PyOperationBase subclass.
1856 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1857 operationObject(operation.getRef().getObject()) {}
1859 //------------------------------------------------------------------------------
1860 // PyInsertionPoint.
1861 //------------------------------------------------------------------------------
1863 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1865 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1866 : refOperation(beforeOperationBase.getOperation().getRef()),
1867 block((*refOperation)->getBlock()) {}
1869 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1870 PyOperation &operation = operationBase.getOperation();
1871 if (operation.isAttached())
1872 throw py::value_error(
1873 "Attempt to insert operation that is already attached
");
1874 block.getParentOperation()->checkValid();
1875 MlirOperation beforeOp = {nullptr};
1877 // Insert before operation.
1878 (*refOperation)->checkValid();
1879 beforeOp = (*refOperation)->get();
1881 // Insert at end (before null) is only valid if the block does not
1882 // already end in a known terminator (violating this will cause assertion
1884 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1885 throw py::index_error("Cannot insert operation at the end of a block
"
1886 "that already has a terminator
. Did you mean to
"
1887 "use
'InsertionPoint.at_block_terminator(block)' "
1888 "versus
'InsertionPoint(block)'?");
1891 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1892 operation.setAttached();
1895 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1896 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1897 if (mlirOperationIsNull(firstOp)) {
1898 // Just insert at end.
1899 return PyInsertionPoint(block);
1902 // Insert before first op.
1903 PyOperationRef firstOpRef = PyOperation::forOperation(
1904 block.getParentOperation()->getContext(), firstOp);
1905 return PyInsertionPoint{block, std::move(firstOpRef)};
1908 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1909 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1910 if (mlirOperationIsNull(terminator))
1911 throw py::value_error("Block has no terminator
");
1912 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1913 block.getParentOperation()->getContext(), terminator);
1914 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1917 py::object PyInsertionPoint::contextEnter() {
1918 return PyThreadContextEntry::pushInsertionPoint(*this);
1921 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1922 const pybind11::object &excVal,
1923 const pybind11::object &excTb) {
1924 PyThreadContextEntry::popInsertionPoint(*this);
//------------------------------------------------------------------------------
// PyAttribute.
//------------------------------------------------------------------------------
1931 bool PyAttribute::operator==(const PyAttribute &other) const {
1932 return mlirAttributeEqual(attr, other.attr);
1935 py::object PyAttribute::getCapsule() {
1936 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1939 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1940 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1941 if (mlirAttributeIsNull(rawAttr))
1942 throw py::error_already_set();
1944 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1947 //------------------------------------------------------------------------------
1948 // PyNamedAttribute.
1949 //------------------------------------------------------------------------------
1951 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1952 : ownedName(new std::string(std::move(ownedName))) {
1953 namedAttr = mlirNamedAttributeGet(
1954 mlirIdentifierGet(mlirAttributeGetContext(attr),
1955 toMlirStringRef(*this->ownedName)),
//------------------------------------------------------------------------------
// PyType.
//------------------------------------------------------------------------------
1963 bool PyType::operator==(const PyType &other) const {
1964 return mlirTypeEqual(type, other.type);
1967 py::object PyType::getCapsule() {
1968 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1971 PyType PyType::createFromCapsule(py::object capsule) {
1972 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1973 if (mlirTypeIsNull(rawType))
1974 throw py::error_already_set();
1975 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
//------------------------------------------------------------------------------
// PyTypeID.
//------------------------------------------------------------------------------
1983 py::object PyTypeID::getCapsule() {
1984 return py::reinterpret_steal<py::object>(mlirPythonTypeIDToCapsule(*this));
1987 PyTypeID PyTypeID::createFromCapsule(py::object capsule) {
1988 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
1989 if (mlirTypeIDIsNull(mlirTypeID))
1990 throw py::error_already_set();
1991 return PyTypeID(mlirTypeID);
1993 bool PyTypeID::operator==(const PyTypeID &other) const {
1994 return mlirTypeIDEqual(typeID, other.typeID);
1997 //------------------------------------------------------------------------------
1998 // PyValue and subclasses.
1999 //------------------------------------------------------------------------------
2001 pybind11::object PyValue::getCapsule() {
2002 return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
2005 pybind11::object PyValue::maybeDownCast() {
2006 MlirType type = mlirValueGetType(get());
2007 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2008 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2009 "mlirTypeID was expected to be non
-null
.");
2010 std::optional<pybind11::function> valueCaster =
2011 PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2012 // py::return_value_policy::move means use std::move to move the return value
2013 // contents into a new instance that will be owned by Python.
2014 py::object thisObj = py::cast(this, py::return_value_policy::move);
2017 return valueCaster.value()(thisObj);
2020 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
2021 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2022 if (mlirValueIsNull(value))
2023 throw py::error_already_set();
2024 MlirOperation owner;
2025 if (mlirValueIsAOpResult(value))
2026 owner = mlirOpResultGetOwner(value);
2027 if (mlirValueIsABlockArgument(value))
2028 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
2029 if (mlirOperationIsNull(owner))
2030 throw py::error_already_set();
2031 MlirContext ctx = mlirOperationGetContext(owner);
2032 PyOperationRef ownerRef =
2033 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
2034 return PyValue(ownerRef, value);
//------------------------------------------------------------------------------
// PySymbolTable.
//------------------------------------------------------------------------------
2041 PySymbolTable::PySymbolTable(PyOperationBase &operation)
2042 : operation(operation.getOperation().getRef()) {
2043 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2044 if (mlirSymbolTableIsNull(symbolTable)) {
2045 throw py::cast_error("Operation is
not a Symbol Table
.");
2049 py::object PySymbolTable::dunderGetItem(const std::string &name) {
2050 operation->checkValid();
2051 MlirOperation symbol = mlirSymbolTableLookup(
2052 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2053 if (mlirOperationIsNull(symbol))
2054 throw py::key_error("Symbol
'" + name + "' not in the symbol table
.");
2056 return PyOperation::forOperation(operation->getContext(), symbol,
2057 operation.getObject())
2061 void PySymbolTable::erase(PyOperationBase &symbol) {
2062 operation->checkValid();
2063 symbol.getOperation().checkValid();
2064 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2065 // The operation is also erased, so we must invalidate it. There may be Python
2066 // references to this operation so we don't want to delete it from the list of
2067 // live operations here.
2068 symbol.getOperation().valid = false;
2071 void PySymbolTable::dunderDel(const std::string &name) {
2072 py::object operation = dunderGetItem(name);
2073 erase(py::cast<PyOperationBase &>(operation));
2076 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2077 operation->checkValid();
2078 symbol.getOperation().checkValid();
2079 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2080 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
2081 if (mlirAttributeIsNull(symbolAttr))
2082 throw py::value_error("Expected operation to have a symbol name
.");
2083 return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2086 MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
2087 // Op must already be a symbol.
2088 PyOperation &operation = symbol.getOperation();
2089 operation.checkValid();
2090 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2091 MlirAttribute existingNameAttr =
2092 mlirOperationGetAttributeByName(operation.get(), attrName);
2093 if (mlirAttributeIsNull(existingNameAttr))
2094 throw py::value_error("Expected operation to have a symbol name
.");
2095 return existingNameAttr;
2098 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
2099 const std::string &name) {
2100 // Op must already be a symbol.
2101 PyOperation &operation = symbol.getOperation();
2102 operation.checkValid();
2103 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2104 MlirAttribute existingNameAttr =
2105 mlirOperationGetAttributeByName(operation.get(), attrName);
2106 if (mlirAttributeIsNull(existingNameAttr))
2107 throw py::value_error("Expected operation to have a symbol name
.");
2108 MlirAttribute newNameAttr =
2109 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2110 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2113 MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
2114 PyOperation &operation = symbol.getOperation();
2115 operation.checkValid();
2116 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2117 MlirAttribute existingVisAttr =
2118 mlirOperationGetAttributeByName(operation.get(), attrName);
2119 if (mlirAttributeIsNull(existingVisAttr))
2120 throw py::value_error("Expected operation to have a symbol visibility
.");
2121 return existingVisAttr;
2124 void PySymbolTable::setVisibility(PyOperationBase &symbol,
2125 const std::string &visibility) {
2126 if (visibility != "public" && visibility != "private" &&
2127 visibility != "nested
")
2128 throw py::value_error(
2129 "Expected visibility to be
'public', 'private' or 'nested'");
2130 PyOperation &operation = symbol.getOperation();
2131 operation.checkValid();
2132 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2133 MlirAttribute existingVisAttr =
2134 mlirOperationGetAttributeByName(operation.get(), attrName);
2135 if (mlirAttributeIsNull(existingVisAttr))
2136 throw py::value_error("Expected operation to have a symbol visibility
.");
2137 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2138 toMlirStringRef(visibility));
2139 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2142 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2143 const std::string &newSymbol,
2144 PyOperationBase &from) {
2145 PyOperation &fromOperation = from.getOperation();
2146 fromOperation.checkValid();
2147 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
2148 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2149 from.getOperation())))
2151 throw py::value_error("Symbol rename failed
");
2154 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
2155 bool allSymUsesVisible,
2156 py::object callback) {
2157 PyOperation &fromOperation = from.getOperation();
2158 fromOperation.checkValid();
2160 PyMlirContextRef context;
2161 py::object callback;
2163 std::string exceptionWhat;
2164 py::object exceptionType;
2167 fromOperation.getContext(), std::move(callback), false, {}, {}};
2168 mlirSymbolTableWalkSymbolTables(
2169 fromOperation.get(), allSymUsesVisible,
2170 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2171 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2173 PyOperation::forOperation(calleeUserData->context, foundOp);
2174 if (calleeUserData->gotException)
2177 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2178 } catch (py::error_already_set &e) {
2179 calleeUserData->gotException = true;
2180 calleeUserData->exceptionWhat = e.what();
2181 calleeUserData->exceptionType = e.type();
2184 static_cast<void *>(&userData));
2185 if (userData.gotException) {
2186 std::string message("Exception raised in callback
: ");
2187 message.append(userData.exceptionWhat);
2188 throw std::runtime_error(message);
2193 /// CRTP base class for Python MLIR values that subclass Value and should be
2194 /// castable from it. The value hierarchy is one level deep and is not supposed
2195 /// to accommodate other levels unless core MLIR changes.
2196 template <typename DerivedTy>
2197 class PyConcreteValue : public PyValue {
2199 // Derived classes must define statics for:
2200 // IsAFunctionTy isaFunction
2201 // const char *pyClassName
2202 // and redefine bindDerived.
2203 using ClassTy = py::class_<DerivedTy, PyValue>;
2204 using IsAFunctionTy = bool (*)(MlirValue);
2206 PyConcreteValue() = default;
2207 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
2208 : PyValue(operationRef, value) {}
2209 PyConcreteValue(PyValue &orig)
2210 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
2212 /// Attempts to cast the original value to the derived type and throws on
2213 /// type mismatches.
2214 static MlirValue castFrom(PyValue &orig) {
2215 if (!DerivedTy::isaFunction(orig.get())) {
2216 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
2217 throw py::value_error((Twine("Cannot cast value to
") +
2218 DerivedTy::pyClassName + " (from
" + origRepr +
2225 /// Binds the Python module objects to functions of this class.
2226 static void bind(py::module &m) {
2227 auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
2228 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value
"));
2231 [](PyValue &otherValue) -> bool {
2232 return DerivedTy::isaFunction(otherValue);
2234 py::arg("other_value
"));
2235 cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
2236 [](DerivedTy &self) { return self.maybeDownCast(); });
2237 DerivedTy::bindDerived(cls);
2240 /// Implemented by derived classes to add methods to the Python subclass.
2241 static void bindDerived(ClassTy &m) {}
2244 /// Python wrapper for MlirBlockArgument.
2245 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2247 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2248 static constexpr const char *pyClassName = "BlockArgument
";
2249 using PyConcreteValue::PyConcreteValue;
2251 static void bindDerived(ClassTy &c) {
2252 c.def_property_readonly("owner
", [](PyBlockArgument &self) {
2253 return PyBlock(self.getParentOperation(),
2254 mlirBlockArgumentGetOwner(self.get()));
2256 c.def_property_readonly("arg_number
", [](PyBlockArgument &self) {
2257 return mlirBlockArgumentGetArgNumber(self.get());
2261 [](PyBlockArgument &self, PyType type) {
2262 return mlirBlockArgumentSetType(self.get(), type);
2268 /// Python wrapper for MlirOpResult.
2269 class PyOpResult : public PyConcreteValue<PyOpResult> {
2271 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
2272 static constexpr const char *pyClassName = "OpResult
";
2273 using PyConcreteValue::PyConcreteValue;
2275 static void bindDerived(ClassTy &c) {
2276 c.def_property_readonly("owner
", [](PyOpResult &self) {
2278 mlirOperationEqual(self.getParentOperation()->get(),
2279 mlirOpResultGetOwner(self.get())) &&
2280 "expected the owner of the value in Python to match that in the IR
");
2281 return self.getParentOperation().getObject();
2283 c.def_property_readonly("result_number
", [](PyOpResult &self) {
2284 return mlirOpResultGetResultNumber(self.get());
2289 /// Returns the list of types of the values held by container.
2290 template <typename Container>
2291 static std::vector<MlirType> getValueTypes(Container &container,
2292 PyMlirContextRef &context) {
2293 std::vector<MlirType> result;
2294 result.reserve(container.size());
2295 for (int i = 0, e = container.size(); i < e; ++i) {
2296 result.push_back(mlirValueGetType(container.getElement(i).get()));
2301 /// A list of block arguments. Internally, these are stored as consecutive
2302 /// elements, random access is cheap. The argument list is associated with the
2303 /// operation that contains the block (detached blocks are not allowed in
2304 /// Python bindings) and extends its lifetime.
2305 class PyBlockArgumentList
2306 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2308 static constexpr const char *pyClassName = "BlockArgumentList
";
2309 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2311 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2312 intptr_t startIndex = 0, intptr_t length = -1,
2314 : Sliceable(startIndex,
2315 length == -1 ? mlirBlockGetNumArguments(block) : length,
2317 operation(std::move(operation)), block(block) {}
2319 static void bindDerived(ClassTy &c) {
2320 c.def_property_readonly("types
", [](PyBlockArgumentList &self) {
2321 return getValueTypes(self, self.operation->getContext());
2326 /// Give the parent CRTP class access to hook implementations below.
2327 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2329 /// Returns the number of arguments in the list.
2330 intptr_t getRawNumElements() {
2331 operation->checkValid();
2332 return mlirBlockGetNumArguments(block);
2335 /// Returns `pos`-the element in the list.
2336 PyBlockArgument getRawElement(intptr_t pos) {
2337 MlirValue argument = mlirBlockGetArgument(block, pos);
2338 return PyBlockArgument(operation, argument);
2341 /// Returns a sublist of this list.
2342 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2344 return PyBlockArgumentList(operation, block, startIndex, length, step);
2347 PyOperationRef operation;
2351 /// A list of operation operands. Internally, these are stored as consecutive
2352 /// elements, random access is cheap. The (returned) operand list is associated
2353 /// with the operation whose operands these are, and thus extends the lifetime
2354 /// of this operation.
2355 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2357 static constexpr const char *pyClassName = "OpOperandList
";
2358 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2360 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2361 intptr_t length = -1, intptr_t step = 1)
2362 : Sliceable(startIndex,
2363 length == -1 ? mlirOperationGetNumOperands(operation->get())
2366 operation(operation) {}
2368 void dunderSetItem(intptr_t index, PyValue value) {
2369 index = wrapIndex(index);
2370 mlirOperationSetOperand(operation->get(), index, value.get());
2373 static void bindDerived(ClassTy &c) {
2374 c.def("__setitem__
", &PyOpOperandList::dunderSetItem);
2378 /// Give the parent CRTP class access to hook implementations below.
2379 friend class Sliceable<PyOpOperandList, PyValue>;
2381 intptr_t getRawNumElements() {
2382 operation->checkValid();
2383 return mlirOperationGetNumOperands(operation->get());
2386 PyValue getRawElement(intptr_t pos) {
2387 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2388 MlirOperation owner;
2389 if (mlirValueIsAOpResult(operand))
2390 owner = mlirOpResultGetOwner(operand);
2391 else if (mlirValueIsABlockArgument(operand))
2392 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2394 assert(false && "Value must be an block arg
or op result
.");
2395 PyOperationRef pyOwner =
2396 PyOperation::forOperation(operation->getContext(), owner);
2397 return PyValue(pyOwner, operand);
2400 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2401 return PyOpOperandList(operation, startIndex, length, step);
2404 PyOperationRef operation;
2407 /// A list of operation results. Internally, these are stored as consecutive
2408 /// elements, random access is cheap. The (returned) result list is associated
2409 /// with the operation whose results these are, and thus extends the lifetime of
2411 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2413 static constexpr const char *pyClassName = "OpResultList
";
2414 using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
2416 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2417 intptr_t length = -1, intptr_t step = 1)
2418 : Sliceable(startIndex,
2419 length == -1 ? mlirOperationGetNumResults(operation->get())
2422 operation(std::move(operation)) {}
2424 static void bindDerived(ClassTy &c) {
2425 c.def_property_readonly("types
", [](PyOpResultList &self) {
2426 return getValueTypes(self, self.operation->getContext());
2428 c.def_property_readonly("owner
", [](PyOpResultList &self) {
2429 return self.operation->createOpView();
2434 /// Give the parent CRTP class access to hook implementations below.
2435 friend class Sliceable<PyOpResultList, PyOpResult>;
2437 intptr_t getRawNumElements() {
2438 operation->checkValid();
2439 return mlirOperationGetNumResults(operation->get());
2442 PyOpResult getRawElement(intptr_t index) {
2443 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2444 return PyOpResult(value);
2447 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2448 return PyOpResultList(operation, startIndex, length, step);
2451 PyOperationRef operation;
2454 /// A list of operation successors. Internally, these are stored as consecutive
2455 /// elements, random access is cheap. The (returned) successor list is
2456 /// associated with the operation whose successors these are, and thus extends
2457 /// the lifetime of this operation.
2458 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2460 static constexpr const char *pyClassName = "OpSuccessors
";
2462 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2463 intptr_t length = -1, intptr_t step = 1)
2464 : Sliceable(startIndex,
2465 length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2468 operation(operation) {}
2470 void dunderSetItem(intptr_t index, PyBlock block) {
2471 index = wrapIndex(index);
2472 mlirOperationSetSuccessor(operation->get(), index, block.get());
2475 static void bindDerived(ClassTy &c) {
2476 c.def("__setitem__
", &PyOpSuccessors::dunderSetItem);
2480 /// Give the parent CRTP class access to hook implementations below.
2481 friend class Sliceable<PyOpSuccessors, PyBlock>;
2483 intptr_t getRawNumElements() {
2484 operation->checkValid();
2485 return mlirOperationGetNumSuccessors(operation->get());
2488 PyBlock getRawElement(intptr_t pos) {
2489 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2490 return PyBlock(operation, block);
2493 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2494 return PyOpSuccessors(operation, startIndex, length, step);
2497 PyOperationRef operation;
2500 /// A list of operation attributes. Can be indexed by name, producing
2501 /// attributes, or by index, producing named attributes.
2502 class PyOpAttributeMap {
2504 PyOpAttributeMap(PyOperationRef operation)
2505 : operation(std::move(operation)) {}
2507 MlirAttribute dunderGetItemNamed(const std::string &name) {
2508 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2509 toMlirStringRef(name));
2510 if (mlirAttributeIsNull(attr)) {
2511 throw py::key_error("attempt to access a non
-existent attribute
");
2516 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2517 if (index < 0 || index >= dunderLen()) {
2518 throw py::index_error("attempt to access out of bounds attribute
");
2520 MlirNamedAttribute namedAttr =
2521 mlirOperationGetAttribute(operation->get(), index);
2522 return PyNamedAttribute(
2523 namedAttr.attribute,
2524 std::string(mlirIdentifierStr(namedAttr.name).data,
2525 mlirIdentifierStr(namedAttr.name).length));
2528 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2529 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2533 void dunderDelItem(const std::string &name) {
2534 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2535 toMlirStringRef(name));
2537 throw py::key_error("attempt to
delete a non
-existent attribute
");
2540 intptr_t dunderLen() {
2541 return mlirOperationGetNumAttributes(operation->get());
2544 bool dunderContains(const std::string &name) {
2545 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2546 operation->get(), toMlirStringRef(name)));
2549 static void bind(py::module &m) {
2550 py::class_<PyOpAttributeMap>(m, "OpAttributeMap
", py::module_local())
2551 .def("__contains__
", &PyOpAttributeMap::dunderContains)
2552 .def("__len__
", &PyOpAttributeMap::dunderLen)
2553 .def("__getitem__
", &PyOpAttributeMap::dunderGetItemNamed)
2554 .def("__getitem__
", &PyOpAttributeMap::dunderGetItemIndexed)
2555 .def("__setitem__
", &PyOpAttributeMap::dunderSetItem)
2556 .def("__delitem__
", &PyOpAttributeMap::dunderDelItem);
2560 PyOperationRef operation;
2565 //------------------------------------------------------------------------------
2566 // Populates the core exports of the 'ir' submodule.
2567 //------------------------------------------------------------------------------
2569 void mlir::python::populateIRCore(py::module &m) {
2570 //----------------------------------------------------------------------------
2572 //----------------------------------------------------------------------------
2573 py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity
", py::module_local())
2574 .value("ERROR
", MlirDiagnosticError)
2575 .value("WARNING
", MlirDiagnosticWarning)
2576 .value("NOTE
", MlirDiagnosticNote)
2577 .value("REMARK
", MlirDiagnosticRemark);
2579 py::enum_<MlirWalkOrder>(m, "WalkOrder
", py::module_local())
2580 .value("PRE_ORDER
", MlirWalkPreOrder)
2581 .value("POST_ORDER
", MlirWalkPostOrder);
2583 py::enum_<MlirWalkResult>(m, "WalkResult
", py::module_local())
2584 .value("ADVANCE
", MlirWalkResultAdvance)
2585 .value("INTERRUPT
", MlirWalkResultInterrupt)
2586 .value("SKIP
", MlirWalkResultSkip);
2588 //----------------------------------------------------------------------------
2589 // Mapping of Diagnostics.
2590 //----------------------------------------------------------------------------
2591 py::class_<PyDiagnostic>(m, "Diagnostic
", py::module_local())
2592 .def_property_readonly("severity
", &PyDiagnostic::getSeverity)
2593 .def_property_readonly("location
", &PyDiagnostic::getLocation)
2594 .def_property_readonly("message
", &PyDiagnostic::getMessage)
2595 .def_property_readonly("notes
", &PyDiagnostic::getNotes)
2596 .def("__str__
", [](PyDiagnostic &self) -> py::str {
2597 if (!self.isValid())
2598 return "<Invalid Diagnostic
>";
2599 return self.getMessage();
2602 py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo
",
2604 .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
2605 .def_readonly("severity
", &PyDiagnostic::DiagnosticInfo::severity)
2606 .def_readonly("location
", &PyDiagnostic::DiagnosticInfo::location)
2607 .def_readonly("message
", &PyDiagnostic::DiagnosticInfo::message)
2608 .def_readonly("notes
", &PyDiagnostic::DiagnosticInfo::notes)
2610 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2612 py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler
", py::module_local())
2613 .def("detach
", &PyDiagnosticHandler::detach)
2614 .def_property_readonly("attached
", &PyDiagnosticHandler::isAttached)
2615 .def_property_readonly("had_error
", &PyDiagnosticHandler::getHadError)
2616 .def("__enter__
", &PyDiagnosticHandler::contextEnter)
2617 .def("__exit__
", &PyDiagnosticHandler::contextExit);
2619 //----------------------------------------------------------------------------
2620 // Mapping of MlirContext.
2621 // Note that this is exported as _BaseContext. The containing, Python level
2622 // __init__.py will subclass it with site-specific functionality and set a
2623 // "Context
" attribute on this module.
2624 //----------------------------------------------------------------------------
2625 py::class_<PyMlirContext>(m, "_BaseContext
", py::module_local())
2626 .def(py::init<>(&PyMlirContext::createNewContextForInit))
2627 .def_static("_get_live_count
", &PyMlirContext::getLiveCount)
2628 .def("_get_context_again
",
2629 [](PyMlirContext &self) {
2630 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2631 return ref.releaseObject();
2633 .def("_get_live_operation_count
", &PyMlirContext::getLiveOperationCount)
2634 .def("_get_live_operation_objects
",
2635 &PyMlirContext::getLiveOperationObjects)
2636 .def("_clear_live_operations
", &PyMlirContext::clearLiveOperations)
2637 .def("_clear_live_operations_inside
",
2638 py::overload_cast<MlirOperation>(
2639 &PyMlirContext::clearOperationsInside))
2640 .def("_get_live_module_count
", &PyMlirContext::getLiveModuleCount)
2641 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2642 &PyMlirContext::getCapsule)
2643 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2644 .def("__enter__
", &PyMlirContext::contextEnter)
2645 .def("__exit__
", &PyMlirContext::contextExit)
2646 .def_property_readonly_static(
2648 [](py::object & /*class*/) {
2649 auto *context = PyThreadContextEntry::getDefaultContext();
2651 return py::none().cast<py::object>();
2652 return py::cast(context);
2654 "Gets the Context bound to the current thread
or raises ValueError
")
2655 .def_property_readonly(
2657 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2658 "Gets a container
for accessing dialects by name
")
2659 .def_property_readonly(
2660 "d
", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2661 "Alias
for 'dialect'")
2663 "get_dialect_descriptor
",
2664 [=](PyMlirContext &self, std::string &name) {
2665 MlirDialect dialect = mlirContextGetOrLoadDialect(
2666 self.get(), {name.data(), name.size()});
2667 if (mlirDialectIsNull(dialect)) {
2668 throw py::value_error(
2669 (Twine("Dialect
'") + name + "' not found
").str());
2671 return PyDialectDescriptor(self.getRef(), dialect);
2673 py::arg("dialect_name
"),
2674 "Gets
or loads a dialect by name
, returning its descriptor object
")
2676 "allow_unregistered_dialects
",
2677 [](PyMlirContext &self) -> bool {
2678 return mlirContextGetAllowUnregisteredDialects(self.get());
2680 [](PyMlirContext &self, bool value) {
2681 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2683 .def("attach_diagnostic_handler
", &PyMlirContext::attachDiagnosticHandler,
2684 py::arg("callback
"),
2685 "Attaches a diagnostic handler that will receive callbacks
")
2687 "enable_multithreading
",
2688 [](PyMlirContext &self, bool enable) {
2689 mlirContextEnableMultithreading(self.get(), enable);
2693 "is_registered_operation
",
2694 [](PyMlirContext &self, std::string &name) {
2695 return mlirContextIsRegisteredOperation(
2696 self.get(), MlirStringRef{name.data(), name.size()});
2698 py::arg("operation_name
"))
2700 "append_dialect_registry
",
2701 [](PyMlirContext &self, PyDialectRegistry ®istry) {
2702 mlirContextAppendDialectRegistry(self.get(), registry);
2704 py::arg("registry
"))
2705 .def_property("emit_error_diagnostics
", nullptr,
2706 &PyMlirContext::setEmitErrorDiagnostics,
2707 "Emit error diagnostics to diagnostic handlers
. By
default "
2708 "error diagnostics are captured
and reported through
"
2709 "MLIRError exceptions
.")
2710 .def("load_all_available_dialects
", [](PyMlirContext &self) {
2711 mlirContextLoadAllAvailableDialects(self.get());
2714 //----------------------------------------------------------------------------
2715 // Mapping of PyDialectDescriptor
2716 //----------------------------------------------------------------------------
2717 py::class_<PyDialectDescriptor>(m, "DialectDescriptor
", py::module_local())
2718 .def_property_readonly("namespace",
2719 [](PyDialectDescriptor &self) {
2721 mlirDialectGetNamespace(self.get());
2722 return py::str(ns.data, ns.length);
2724 .def("__repr__
", [](PyDialectDescriptor &self) {
2725 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2726 std::string repr("<DialectDescriptor
");
2727 repr.append(ns.data, ns.length);
2732 //----------------------------------------------------------------------------
2733 // Mapping of PyDialects
2734 //----------------------------------------------------------------------------
2735 py::class_<PyDialects>(m, "Dialects
", py::module_local())
2737 [=](PyDialects &self, std::string keyName) {
2738 MlirDialect dialect =
2739 self.getDialectForKey(keyName, /*attrError=*/false);
2740 py::object descriptor =
2741 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2742 return createCustomDialectWrapper(keyName, std::move(descriptor));
2744 .def("__getattr__
", [=](PyDialects &self, std::string attrName) {
2745 MlirDialect dialect =
2746 self.getDialectForKey(attrName, /*attrError=*/true);
2747 py::object descriptor =
2748 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2749 return createCustomDialectWrapper(attrName, std::move(descriptor));
2752 //----------------------------------------------------------------------------
2753 // Mapping of PyDialect
2754 //----------------------------------------------------------------------------
2755 py::class_<PyDialect>(m, "Dialect
", py::module_local())
2756 .def(py::init<py::object>(), py::arg("descriptor
"))
2757 .def_property_readonly(
2758 "descriptor
", [](PyDialect &self) { return self.getDescriptor(); })
2759 .def("__repr__
", [](py::object self) {
2760 auto clazz = self.attr("__class__
");
2761 return py::str("<Dialect
") +
2762 self.attr("descriptor
").attr("namespace") + py::str(" (class ") +
2763 clazz.attr("__module__
") + py::str(".") +
2764 clazz.attr("__name__
") + py::str(")>");
2767 //----------------------------------------------------------------------------
2768 // Mapping of PyDialectRegistry
2769 //----------------------------------------------------------------------------
2770 py::class_<PyDialectRegistry>(m, "DialectRegistry
", py::module_local())
2771 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2772 &PyDialectRegistry::getCapsule)
2773 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2776 //----------------------------------------------------------------------------
2777 // Mapping of Location
2778 //----------------------------------------------------------------------------
2779 py::class_<PyLocation>(m, "Location
", py::module_local())
2780 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2781 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2782 .def("__enter__
", &PyLocation::contextEnter)
2783 .def("__exit__
", &PyLocation::contextExit)
2785 [](PyLocation &self, PyLocation &other) -> bool {
2786 return mlirLocationEqual(self, other);
2788 .def("__eq__
", [](PyLocation &self, py::object other) { return false; })
2789 .def_property_readonly_static(
2791 [](py::object & /*class*/) {
2792 auto *loc = PyThreadContextEntry::getDefaultLocation();
2794 throw py::value_error("No current Location
");
2797 "Gets the Location bound to the current thread
or raises ValueError
")
2800 [](DefaultingPyMlirContext context) {
2801 return PyLocation(context->getRef(),
2802 mlirLocationUnknownGet(context->get()));
2804 py::arg("context
") = py::none(),
2805 "Gets a Location representing an unknown location
")
2808 [](PyLocation callee, const std::vector<PyLocation> &frames,
2809 DefaultingPyMlirContext context) {
2811 throw py::value_error("No caller frames provided
");
2812 MlirLocation caller = frames.back().get();
2813 for (const PyLocation &frame :
2814 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2815 caller = mlirLocationCallSiteGet(frame.get(), caller);
2816 return PyLocation(context->getRef(),
2817 mlirLocationCallSiteGet(callee.get(), caller));
2819 py::arg("callee
"), py::arg("frames
"), py::arg("context
") = py::none(),
2820 kContextGetCallSiteLocationDocstring)
2823 [](std::string filename, int line, int col,
2824 DefaultingPyMlirContext context) {
2827 mlirLocationFileLineColGet(
2828 context->get(), toMlirStringRef(filename), line, col));
2830 py::arg("filename
"), py::arg("line
"), py::arg("col
"),
2831 py::arg("context
") = py::none(), kContextGetFileLocationDocstring)
2834 [](const std::vector<PyLocation> &pyLocations,
2835 std::optional<PyAttribute> metadata,
2836 DefaultingPyMlirContext context) {
2837 llvm::SmallVector<MlirLocation, 4> locations;
2838 locations.reserve(pyLocations.size());
2839 for (auto &pyLocation : pyLocations)
2840 locations.push_back(pyLocation.get());
2841 MlirLocation location = mlirLocationFusedGet(
2842 context->get(), locations.size(), locations.data(),
2843 metadata ? metadata->get() : MlirAttribute{0});
2844 return PyLocation(context->getRef(), location);
2846 py::arg("locations
"), py::arg("metadata
") = py::none(),
2847 py::arg("context
") = py::none(), kContextGetFusedLocationDocstring)
2850 [](std::string name, std::optional<PyLocation> childLoc,
2851 DefaultingPyMlirContext context) {
2854 mlirLocationNameGet(
2855 context->get(), toMlirStringRef(name),
2856 childLoc ? childLoc->get()
2857 : mlirLocationUnknownGet(context->get())));
2859 py::arg("name
"), py::arg("childLoc
") = py::none(),
2860 py::arg("context
") = py::none(), kContextGetNameLocationDocString)
2863 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2864 return PyLocation(context->getRef(),
2865 mlirLocationFromAttribute(attribute));
2867 py::arg("attribute
"), py::arg("context
") = py::none(),
2868 "Gets a Location from a LocationAttr
")
2869 .def_property_readonly(
2871 [](PyLocation &self) { return self.getContext().getObject(); },
2872 "Context that owns the Location
")
2873 .def_property_readonly(
2875 [](PyLocation &self) { return mlirLocationGetAttribute(self); },
2876 "Get the underlying LocationAttr
")
2879 [](PyLocation &self, std::string message) {
2880 mlirEmitError(self, message.c_str());
2882 py::arg("message
"), "Emits an error at
this location
")
2883 .def("__repr__
", [](PyLocation &self) {
2884 PyPrintAccumulator printAccum;
2885 mlirLocationPrint(self, printAccum.getCallback(),
2886 printAccum.getUserData());
2887 return printAccum.join();
2890 //----------------------------------------------------------------------------
2891 // Mapping of Module
2892 //----------------------------------------------------------------------------
2893 py::class_<PyModule>(m, "Module
", py::module_local())
2894 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2895 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2898 [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
2899 PyMlirContext::ErrorCapture errors(context->getRef());
2900 MlirModule module = mlirModuleCreateParse(
2901 context->get(), toMlirStringRef(moduleAsm));
2902 if (mlirModuleIsNull(module))
2903 throw MLIRError("Unable to parse module assembly
", errors.take());
2904 return PyModule::forModule(module).releaseObject();
2906 py::arg("asm"), py::arg("context
") = py::none(),
2907 kModuleParseDocstring)
2910 [](DefaultingPyLocation loc) {
2911 MlirModule module = mlirModuleCreateEmpty(loc);
2912 return PyModule::forModule(module).releaseObject();
2914 py::arg("loc
") = py::none(), "Creates an empty module
")
2915 .def_property_readonly(
2917 [](PyModule &self) { return self.getContext().getObject(); },
2918 "Context that created the Module
")
2919 .def_property_readonly(
2921 [](PyModule &self) {
2922 return PyOperation::forOperation(self.getContext(),
2923 mlirModuleGetOperation(self.get()),
2924 self.getRef().releaseObject())
2927 "Accesses the module as an operation
")
2928 .def_property_readonly(
2930 [](PyModule &self) {
2931 PyOperationRef moduleOp = PyOperation::forOperation(
2932 self.getContext(), mlirModuleGetOperation(self.get()),
2933 self.getRef().releaseObject());
2934 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2937 "Return the block
for this module
")
2940 [](PyModule &self) {
2941 mlirOperationDump(mlirModuleGetOperation(self.get()));
2946 [](py::object self) {
2947 // Defer to the operation's __str__.
2948 return self.attr("operation
").attr("__str__
")();
2950 kOperationStrDunderDocstring);
2952 //----------------------------------------------------------------------------
2953 // Mapping of Operation.
2954 //----------------------------------------------------------------------------
2955 py::class_<PyOperationBase>(m, "_OperationBase
", py::module_local())
2956 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2957 [](PyOperationBase &self) {
2958 return self.getOperation().getCapsule();
2961 [](PyOperationBase &self, PyOperationBase &other) {
2962 return &self.getOperation() == &other.getOperation();
2965 [](PyOperationBase &self, py::object other) { return false; })
2967 [](PyOperationBase &self) {
2968 return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2970 .def_property_readonly("attributes
",
2971 [](PyOperationBase &self) {
2972 return PyOpAttributeMap(
2973 self.getOperation().getRef());
2975 .def_property_readonly(
2977 [](PyOperationBase &self) {
2978 PyOperation &concreteOperation = self.getOperation();
2979 concreteOperation.checkValid();
2980 return concreteOperation.getContext().getObject();
2982 "Context that owns the Operation
")
2983 .def_property_readonly("name
",
2984 [](PyOperationBase &self) {
2985 auto &concreteOperation = self.getOperation();
2986 concreteOperation.checkValid();
2987 MlirOperation operation =
2988 concreteOperation.get();
2989 MlirStringRef name = mlirIdentifierStr(
2990 mlirOperationGetName(operation));
2991 return py::str(name.data, name.length);
2993 .def_property_readonly("operands
",
2994 [](PyOperationBase &self) {
2995 return PyOpOperandList(
2996 self.getOperation().getRef());
2998 .def_property_readonly("regions
",
2999 [](PyOperationBase &self) {
3000 return PyRegionList(
3001 self.getOperation().getRef());
3003 .def_property_readonly(
3005 [](PyOperationBase &self) {
3006 return PyOpResultList(self.getOperation().getRef());
3008 "Returns the list of Operation results
.")
3009 .def_property_readonly(
3011 [](PyOperationBase &self) {
3012 auto &operation = self.getOperation();
3013 auto numResults = mlirOperationGetNumResults(operation);
3014 if (numResults != 1) {
3015 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
3016 throw py::value_error(
3017 (Twine("Cannot call
.result on operation
") +
3018 StringRef(name.data, name.length) + " which has
" +
3020 " results (it is only valid
for operations with a
"
3024 return PyOpResult(operation.getRef(),
3025 mlirOperationGetResult(operation, 0))
3028 "Shortcut to get an op result
if it has only
one (throws an error
"
3030 .def_property_readonly(
3032 [](PyOperationBase &self) {
3033 PyOperation &operation = self.getOperation();
3034 return PyLocation(operation.getContext(),
3035 mlirOperationGetLocation(operation.get()));
3037 "Returns the source location the operation was defined
or derived
"
3039 .def_property_readonly("parent
",
3040 [](PyOperationBase &self) -> py::object {
3042 self.getOperation().getParentOperation();
3044 return parent->getObject();
3049 [](PyOperationBase &self) {
3050 return self.getAsm(/*binary=*/false,
3051 /*largeElementsLimit=*/std::nullopt,
3052 /*enableDebugInfo=*/false,
3053 /*prettyDebugInfo=*/false,
3054 /*printGenericOpForm=*/false,
3055 /*useLocalScope=*/false,
3056 /*assumeVerified=*/false,
3057 /*skipRegions=*/false);
3059 "Returns the assembly form of the operation
.")
3061 py::overload_cast<PyAsmState &, pybind11::object, bool>(
3062 &PyOperationBase::print),
3063 py::arg("state
"), py::arg("file
") = py::none(),
3064 py::arg("binary
") = false, kOperationPrintStateDocstring)
3066 py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3067 bool, py::object, bool, bool>(
3068 &PyOperationBase::print),
3069 // Careful: Lots of arguments must match up with print method.
3070 py::arg("large_elements_limit
") = py::none(),
3071 py::arg("enable_debug_info
") = false,
3072 py::arg("pretty_debug_info
") = false,
3073 py::arg("print_generic_op_form
") = false,
3074 py::arg("use_local_scope
") = false,
3075 py::arg("assume_verified
") = false, py::arg("file
") = py::none(),
3076 py::arg("binary
") = false, py::arg("skip_regions
") = false,
3077 kOperationPrintDocstring)
3078 .def("write_bytecode
", &PyOperationBase::writeBytecode, py::arg("file
"),
3079 py::arg("desired_version
") = py::none(),
3080 kOperationPrintBytecodeDocstring)
3081 .def("get_asm
", &PyOperationBase::getAsm,
3082 // Careful: Lots of arguments must match up with get_asm method.
3083 py::arg("binary
") = false,
3084 py::arg("large_elements_limit
") = py::none(),
3085 py::arg("enable_debug_info
") = false,
3086 py::arg("pretty_debug_info
") = false,
3087 py::arg("print_generic_op_form
") = false,
3088 py::arg("use_local_scope
") = false,
3089 py::arg("assume_verified
") = false, py::arg("skip_regions
") = false,
3090 kOperationGetAsmDocstring)
3091 .def("verify
", &PyOperationBase::verify,
3092 "Verify the operation
. Raises MLIRError
if verification fails
, and "
3093 "returns
true otherwise
.")
3094 .def("move_after
", &PyOperationBase::moveAfter, py::arg("other
"),
3095 "Puts self immediately after the other operation in its parent
"
3097 .def("move_before
", &PyOperationBase::moveBefore, py::arg("other
"),
3098 "Puts self immediately before the other operation in its parent
"
3102 [](PyOperationBase &self, py::object ip) {
3103 return self.getOperation().clone(ip);
3105 py::arg("ip
") = py::none())
3107 "detach_from_parent
",
3108 [](PyOperationBase &self) {
3109 PyOperation &operation = self.getOperation();
3110 operation.checkValid();
3111 if (!operation.isAttached())
3112 throw py::value_error("Detached operation has no parent
.");
3114 operation.detachFromParent();
3115 return operation.createOpView();
3117 "Detaches the operation from its parent block
.")
3118 .def("erase
", [](PyOperationBase &self) { self.getOperation().erase(); })
3119 .def("walk
", &PyOperationBase::walk, py::arg("callback
"),
3120 py::arg("walk_order
") = MlirWalkPostOrder);
3122 py::class_<PyOperation, PyOperationBase>(m, "Operation
", py::module_local())
3123 .def_static("create
", &PyOperation::create, py::arg("name
"),
3124 py::arg("results
") = py::none(),
3125 py::arg("operands
") = py::none(),
3126 py::arg("attributes
") = py::none(),
3127 py::arg("successors
") = py::none(), py::arg("regions
") = 0,
3128 py::arg("loc
") = py::none(), py::arg("ip
") = py::none(),
3129 py::arg("infer_type
") = false, kOperationCreateDocstring)
3132 [](const std::string &sourceStr, const std::string &sourceName,
3133 DefaultingPyMlirContext context) {
3134 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3137 py::arg("source
"), py::kw_only(), py::arg("source_name
") = "",
3138 py::arg("context
") = py::none(),
3139 "Parses an operation
. Supports both text assembly format
and binary
"
3141 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3142 &PyOperation::getCapsule)
3143 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3144 .def_property_readonly("operation
", [](py::object self) { return self; })
3145 .def_property_readonly("opview
", &PyOperation::createOpView)
3146 .def_property_readonly(
3148 [](PyOperationBase &self) {
3149 return PyOpSuccessors(self.getOperation().getRef());
3151 "Returns the list of Operation successors
.");
3154 py::class_<PyOpView, PyOperationBase>(m, "OpView
", py::module_local())
3155 .def(py::init<py::object>(), py::arg("operation
"))
3156 .def_property_readonly("operation
", &PyOpView::getOperationObject)
3157 .def_property_readonly("opview
", [](py::object self) { return self; })
3160 [](PyOpView &self) { return py::str(self.getOperationObject()); })
3161 .def_property_readonly(
3163 [](PyOperationBase &self) {
3164 return PyOpSuccessors(self.getOperation().getRef());
3166 "Returns the list of Operation successors
.");
3167 opViewClass.attr("_ODS_REGIONS
") = py::make_tuple(0, true);
3168 opViewClass.attr("_ODS_OPERAND_SEGMENTS
") = py::none();
3169 opViewClass.attr("_ODS_RESULT_SEGMENTS
") = py::none();
3170 opViewClass.attr("build_generic
") = classmethod(
3171 &PyOpView::buildGeneric, py::arg("cls
"), py::arg("results
") = py::none(),
3172 py::arg("operands
") = py::none(), py::arg("attributes
") = py::none(),
3173 py::arg("successors
") = py::none(), py::arg("regions
") = py::none(),
3174 py::arg("loc
") = py::none(), py::arg("ip
") = py::none(),
3175 "Builds a specific
, generated OpView based on
class level attributes
.");
3176 opViewClass.attr("parse
") = classmethod(
3177 [](const py::object &cls, const std::string &sourceStr,
3178 const std::string &sourceName, DefaultingPyMlirContext context) {
3179 PyOperationRef parsed =
3180 PyOperation::parse(context->getRef(), sourceStr, sourceName);
3182 // Check if the expected operation was parsed, and cast to to the
3183 // appropriate `OpView` subclass if successful.
3184 // NOTE: This accesses attributes that have been automatically added to
3185 // `OpView` subclasses, and is not intended to be used on `OpView`
3187 std::string clsOpName =
3188 py::cast<std::string>(cls.attr("OPERATION_NAME
"));
3189 MlirStringRef identifier =
3190 mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
3191 std::string_view parsedOpName(identifier.data, identifier.length);
3192 if (clsOpName != parsedOpName)
3193 throw MLIRError(Twine("Expected a
'") + clsOpName + "' op
, got
: '" +
3194 parsedOpName + "'");
3195 return PyOpView::constructDerived(cls, *parsed.get());
3197 py::arg("cls
"), py::arg("source
"), py::kw_only(),
3198 py::arg("source_name
") = "", py::arg("context
") = py::none(),
3199 "Parses a specific
, generated OpView based on
class level attributes
");
3201 //----------------------------------------------------------------------------
3202 // Mapping of PyRegion.
3203 //----------------------------------------------------------------------------
3204 py::class_<PyRegion>(m, "Region
", py::module_local())
3205 .def_property_readonly(
3207 [](PyRegion &self) {
3208 return PyBlockList(self.getParentOperation(), self.get());
3210 "Returns a forward
-optimized sequence of blocks
.")
3211 .def_property_readonly(
3213 [](PyRegion &self) {
3214 return self.getParentOperation()->createOpView();
3216 "Returns the operation owning
this region
.")
3219 [](PyRegion &self) {
3221 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3222 return PyBlockIterator(self.getParentOperation(), firstBlock);
3224 "Iterates over blocks in the region
.")
3226 [](PyRegion &self, PyRegion &other) {
3227 return self.get().ptr == other.get().ptr;
3229 .def("__eq__
", [](PyRegion &self, py::object &other) { return false; });
3231 //----------------------------------------------------------------------------
3232 // Mapping of PyBlock.
3233 //----------------------------------------------------------------------------
3234 py::class_<PyBlock>(m, "Block
", py::module_local())
3235 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3236 .def_property_readonly(
3239 return self.getParentOperation()->createOpView();
3241 "Returns the owning operation of
this block
.")
3242 .def_property_readonly(
3245 MlirRegion region = mlirBlockGetParentRegion(self.get());
3246 return PyRegion(self.getParentOperation(), region);
3248 "Returns the owning region of
this block
.")
3249 .def_property_readonly(
3252 return PyBlockArgumentList(self.getParentOperation(), self.get());
3254 "Returns a list of block arguments
.")
3257 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3258 return mlirBlockAddArgument(self.get(), type, loc);
3260 "Append an argument of the specified type to the block
and returns
"
3261 "the newly added argument
.")
3264 [](PyBlock &self, unsigned index) {
3265 return mlirBlockEraseArgument(self.get(), index);
3267 "Erase the argument at
'index' and remove it from the argument list
.")
3268 .def_property_readonly(
3271 return PyOperationList(self.getParentOperation(), self.get());
3273 "Returns a forward
-optimized sequence of operations
.")
3276 [](PyRegion &parent, const py::list &pyArgTypes,
3277 const std::optional<py::sequence> &pyArgLocs) {
3278 parent.checkValid();
3279 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3280 mlirRegionInsertOwnedBlock(parent, 0, block);
3281 return PyBlock(parent.getParentOperation(), block);
3283 py::arg("parent
"), py::arg("arg_types
") = py::list(),
3284 py::arg("arg_locs
") = std::nullopt,
3285 "Creates
and returns a
new Block at the beginning of the given
"
3286 "region (with given argument types
and locations
).")
3289 [](PyBlock &self, PyRegion ®ion) {
3290 MlirBlock b = self.get();
3291 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
3293 mlirRegionAppendOwnedBlock(region.get(), b);
3295 "Append
this block to a region
, transferring ownership
if necessary
")
3298 [](PyBlock &self, const py::args &pyArgTypes,
3299 const std::optional<py::sequence> &pyArgLocs) {
3301 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3302 MlirRegion region = mlirBlockGetParentRegion(self.get());
3303 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3304 return PyBlock(self.getParentOperation(), block);
3306 py::arg("arg_locs
") = std::nullopt,
3307 "Creates
and returns a
new Block before
this block
"
3308 "(with given argument types
and locations
).")
3311 [](PyBlock &self, const py::args &pyArgTypes,
3312 const std::optional<py::sequence> &pyArgLocs) {
3314 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3315 MlirRegion region = mlirBlockGetParentRegion(self.get());
3316 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3317 return PyBlock(self.getParentOperation(), block);
3319 py::arg("arg_locs
") = std::nullopt,
3320 "Creates
and returns a
new Block after
this block
"
3321 "(with given argument types
and locations
).")
3326 MlirOperation firstOperation =
3327 mlirBlockGetFirstOperation(self.get());
3328 return PyOperationIterator(self.getParentOperation(),
3331 "Iterates over operations in the block
.")
3333 [](PyBlock &self, PyBlock &other) {
3334 return self.get().ptr == other.get().ptr;
3336 .def("__eq__
", [](PyBlock &self, py::object &other) { return false; })
3339 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3345 PyPrintAccumulator printAccum;
3346 mlirBlockPrint(self.get(), printAccum.getCallback(),
3347 printAccum.getUserData());
3348 return printAccum.join();
3350 "Returns the assembly form of the block
.")
3353 [](PyBlock &self, PyOperationBase &operation) {
3354 if (operation.getOperation().isAttached())
3355 operation.getOperation().detachFromParent();
3357 MlirOperation mlirOperation = operation.getOperation().get();
3358 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3359 operation.getOperation().setAttached(
3360 self.getParentOperation().getObject());
3362 py::arg("operation
"),
3363 "Appends an operation to
this block
. If the operation is currently
"
3364 "in another block
, it will be moved
.");
3366 //----------------------------------------------------------------------------
3367 // Mapping of PyInsertionPoint.
3368 //----------------------------------------------------------------------------
3370 py::class_<PyInsertionPoint>(m, "InsertionPoint
", py::module_local())
3371 .def(py::init<PyBlock &>(), py::arg("block
"),
3372 "Inserts after the last operation but still inside the block
.")
3373 .def("__enter__
", &PyInsertionPoint::contextEnter)
3374 .def("__exit__
", &PyInsertionPoint::contextExit)
3375 .def_property_readonly_static(
3377 [](py::object & /*class*/) {
3378 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3380 throw py::value_error("No current InsertionPoint
");
3383 "Gets the InsertionPoint bound to the current thread
or raises
"
3384 "ValueError
if none has been set
")
3385 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation
"),
3386 "Inserts before a referenced operation
.")
3387 .def_static("at_block_begin
", &PyInsertionPoint::atBlockBegin,
3388 py::arg("block
"), "Inserts at the beginning of the block
.")
3389 .def_static("at_block_terminator
", &PyInsertionPoint::atBlockTerminator,
3390 py::arg("block
"), "Inserts before the block terminator
.")
3391 .def("insert
", &PyInsertionPoint::insert, py::arg("operation
"),
3392 "Inserts an operation
.")
3393 .def_property_readonly(
3394 "block
", [](PyInsertionPoint &self) { return self.getBlock(); },
3395 "Returns the block that
this InsertionPoint points to
.")
3396 .def_property_readonly(
3398 [](PyInsertionPoint &self) -> py::object {
3399 auto refOperation = self.getRefOperation();
3401 return refOperation->getObject();
3404 "The reference operation before which
new operations are
"
3405 "inserted
, or None
if the insertion point is at the end of
"
3408 //----------------------------------------------------------------------------
3409 // Mapping of PyAttribute.
3410 //----------------------------------------------------------------------------
3411 py::class_<PyAttribute>(m, "Attribute
", py::module_local())
3412 // Delegate to the PyAttribute copy constructor, which will also lifetime
3413 // extend the backing context which owns the MlirAttribute.
3414 .def(py::init<PyAttribute &>(), py::arg("cast_from_type
"),
3415 "Casts the passed attribute to the generic Attribute
")
3416 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
3417 &PyAttribute::getCapsule)
3418 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3421 [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3422 PyMlirContext::ErrorCapture errors(context->getRef());
3423 MlirAttribute attr = mlirAttributeParseGet(
3424 context->get(), toMlirStringRef(attrSpec));
3425 if (mlirAttributeIsNull(attr))
3426 throw MLIRError("Unable to parse attribute
", errors.take());
3429 py::arg("asm"), py::arg("context
") = py::none(),
3430 "Parses an attribute from an assembly form
. Raises an MLIRError on
"
3432 .def_property_readonly(
3434 [](PyAttribute &self) { return self.getContext().getObject(); },
3435 "Context that owns the Attribute
")
3436 .def_property_readonly(
3437 "type
", [](PyAttribute &self) { return mlirAttributeGetType(self); })
3440 [](PyAttribute &self, std::string name) {
3441 return PyNamedAttribute(self, std::move(name));
3443 py::keep_alive<0, 1>(), "Binds a name to the attribute
")
3445 [](PyAttribute &self, PyAttribute &other) { return self == other; })
3446 .def("__eq__
", [](PyAttribute &self, py::object &other) { return false; })
3448 [](PyAttribute &self) {
3449 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3452 "dump
", [](PyAttribute &self) { mlirAttributeDump(self); },
3456 [](PyAttribute &self) {
3457 PyPrintAccumulator printAccum;
3458 mlirAttributePrint(self, printAccum.getCallback(),
3459 printAccum.getUserData());
3460 return printAccum.join();
3462 "Returns the assembly form of the Attribute
.")
3464 [](PyAttribute &self) {
3465 // Generally, assembly formats are not printed for __repr__ because
3466 // this can cause exceptionally long debug output and exceptions.
3467 // However, attribute values are generally considered useful and
3468 // are printed. This may need to be re-evaluated if debug dumps end
3469 // up being excessive.
3470 PyPrintAccumulator printAccum;
3471 printAccum.parts.append("Attribute(");
3472 mlirAttributePrint(self, printAccum.getCallback(),
3473 printAccum.getUserData());
3474 printAccum.parts.append(")");
3475 return printAccum.join();
3477 .def_property_readonly(
3479 [](PyAttribute &self) -> MlirTypeID {
3480 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3481 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3482 "mlirTypeID was expected to be non
-null
.");
3485 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
3486 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3487 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3488 "mlirTypeID was expected to be non
-null
.");
3489 std::optional<pybind11::function> typeCaster =
3490 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3491 mlirAttributeGetDialect(self));
3493 return py::cast(self);
3494 return typeCaster.value()(self);
3497 //----------------------------------------------------------------------------
3498 // Mapping of PyNamedAttribute
3499 //----------------------------------------------------------------------------
3500 py::class_<PyNamedAttribute>(m, "NamedAttribute
", py::module_local())
3502 [](PyNamedAttribute &self) {
3503 PyPrintAccumulator printAccum;
3504 printAccum.parts.append("NamedAttribute(");
3505 printAccum.parts.append(
3506 py::str(mlirIdentifierStr(self.namedAttr.name).data,
3507 mlirIdentifierStr(self.namedAttr.name).length));
3508 printAccum.parts.append("=");
3509 mlirAttributePrint(self.namedAttr.attribute,
3510 printAccum.getCallback(),
3511 printAccum.getUserData());
3512 printAccum.parts.append(")");
3513 return printAccum.join();
3515 .def_property_readonly(
3517 [](PyNamedAttribute &self) {
3518 return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3519 mlirIdentifierStr(self.namedAttr.name).length);
3521 "The name of the NamedAttribute binding
")
3522 .def_property_readonly(
3524 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3525 py::keep_alive<0, 1>(),
3526 "The underlying generic attribute of the NamedAttribute binding
");
3528 //----------------------------------------------------------------------------
3529 // Mapping of PyType.
3530 //----------------------------------------------------------------------------
3531 py::class_<PyType>(m, "Type
", py::module_local())
3532 // Delegate to the PyType copy constructor, which will also lifetime
3533 // extend the backing context which owns the MlirType.
3534 .def(py::init<PyType &>(), py::arg("cast_from_type
"),
3535 "Casts the passed type to the generic Type
")
3536 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3537 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3540 [](std::string typeSpec, DefaultingPyMlirContext context) {
3541 PyMlirContext::ErrorCapture errors(context->getRef());
3543 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3544 if (mlirTypeIsNull(type))
3545 throw MLIRError("Unable to parse type
", errors.take());
3548 py::arg("asm"), py::arg("context
") = py::none(),
3549 kContextParseTypeDocstring)
3550 .def_property_readonly(
3551 "context
", [](PyType &self) { return self.getContext().getObject(); },
3552 "Context that owns the Type
")
3553 .def("__eq__
", [](PyType &self, PyType &other) { return self == other; })
3554 .def("__eq__
", [](PyType &self, py::object &other) { return false; })
3557 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3560 "dump
", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3564 PyPrintAccumulator printAccum;
3565 mlirTypePrint(self, printAccum.getCallback(),
3566 printAccum.getUserData());
3567 return printAccum.join();
3569 "Returns the assembly form of the type
.")
3572 // Generally, assembly formats are not printed for __repr__ because
3573 // this can cause exceptionally long debug output and exceptions.
3574 // However, types are an exception as they typically have compact
3575 // assembly forms and printing them is useful.
3576 PyPrintAccumulator printAccum;
3577 printAccum.parts.append("Type(");
3578 mlirTypePrint(self, printAccum.getCallback(),
3579 printAccum.getUserData());
3580 printAccum.parts.append(")");
3581 return printAccum.join();
3583 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3585 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3586 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3587 "mlirTypeID was expected to be non
-null
.");
3588 std::optional<pybind11::function> typeCaster =
3589 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3590 mlirTypeGetDialect(self));
3592 return py::cast(self);
3593 return typeCaster.value()(self);
3595 .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID {
3596 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3597 if (!mlirTypeIDIsNull(mlirTypeID))
3600 pybind11::repr(pybind11::cast(self)).cast<std::string>();
3601 throw py::value_error(
3602 (origRepr + llvm::Twine(" has no
typeid.")).str());
3605 //----------------------------------------------------------------------------
3606 // Mapping of PyTypeID.
3607 //----------------------------------------------------------------------------
3608 py::class_<PyTypeID>(m, "TypeID
", py::module_local())
3609 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3610 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3611 // Note, this tests whether the underlying TypeIDs are the same,
3612 // not whether the wrapper MlirTypeIDs are the same, nor whether
3613 // the Python objects are the same (i.e., PyTypeID is a value type).
3615 [](PyTypeID &self, PyTypeID &other) { return self == other; })
3617 [](PyTypeID &self, const py::object &other) { return false; })
3618 // Note, this gives the hash value of the underlying TypeID, not the
3619 // hash value of the Python object, nor the hash value of the
3620 // MlirTypeID wrapper.
3621 .def("__hash__
", [](PyTypeID &self) {
3622 return static_cast<size_t>(mlirTypeIDHashValue(self));
3625 //----------------------------------------------------------------------------
3626 // Mapping of Value.
3627 //----------------------------------------------------------------------------
3628 py::class_<PyValue>(m, "Value
", py::module_local())
3629 .def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value
"))
3630 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3631 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3632 .def_property_readonly(
3634 [](PyValue &self) { return self.getParentOperation()->getContext(); },
3635 "Context in which the value lives
.")
3637 "dump
", [](PyValue &self) { mlirValueDump(self.get()); },
3639 .def_property_readonly(
3641 [](PyValue &self) -> py::object {
3642 MlirValue v = self.get();
3643 if (mlirValueIsAOpResult(v)) {
3645 mlirOperationEqual(self.getParentOperation()->get(),
3646 mlirOpResultGetOwner(self.get())) &&
3647 "expected the owner of the value in Python to match that in
"
3649 return self.getParentOperation().getObject();
3652 if (mlirValueIsABlockArgument(v)) {
3653 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3654 return py::cast(PyBlock(self.getParentOperation(), block));
3657 assert(false && "Value must be a block argument
or an op result
");
3660 .def_property_readonly("uses
",
3662 return PyOpOperandIterator(
3663 mlirValueGetFirstUse(self.get()));
3666 [](PyValue &self, PyValue &other) {
3667 return self.get().ptr == other.get().ptr;
3669 .def("__eq__
", [](PyValue &self, py::object other) { return false; })
3672 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3677 PyPrintAccumulator printAccum;
3678 printAccum.parts.append("Value(");
3679 mlirValuePrint(self.get(), printAccum.getCallback(),
3680 printAccum.getUserData());
3681 printAccum.parts.append(")");
3682 return printAccum.join();
3684 kValueDunderStrDocstring)
3687 [](PyValue &self, bool useLocalScope) {
3688 PyPrintAccumulator printAccum;
3689 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3691 mlirOpPrintingFlagsUseLocalScope(flags);
3692 MlirAsmState valueState =
3693 mlirAsmStateCreateForValue(self.get(), flags);
3694 mlirValuePrintAsOperand(self.get(), valueState,
3695 printAccum.getCallback(),
3696 printAccum.getUserData());
3697 mlirOpPrintingFlagsDestroy(flags);
3698 mlirAsmStateDestroy(valueState);
3699 return printAccum.join();
3701 py::arg("use_local_scope
") = false)
3704 [](PyValue &self, std::reference_wrapper<PyAsmState> state) {
3705 PyPrintAccumulator printAccum;
3706 MlirAsmState valueState = state.get().get();
3707 mlirValuePrintAsOperand(self.get(), valueState,
3708 printAccum.getCallback(),
3709 printAccum.getUserData());
3710 return printAccum.join();
3712 py::arg("state
"), kGetNameAsOperand)
3713 .def_property_readonly(
3714 "type
", [](PyValue &self) { return mlirValueGetType(self.get()); })
3717 [](PyValue &self, const PyType &type) {
3718 return mlirValueSetType(self.get(), type);
3722 "replace_all_uses_with
",
3723 [](PyValue &self, PyValue &with) {
3724 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3726 kValueReplaceAllUsesWithDocstring)
3728 "replace_all_uses_except
",
3729 [](MlirValue self, MlirValue with, PyOperation &exception) {
3730 MlirOperation exceptedUser = exception.get();
3731 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
3733 py::arg("with
"), py::arg("exceptions
"),
3734 kValueReplaceAllUsesExceptDocstring)
3736 "replace_all_uses_except
",
3737 [](MlirValue self, MlirValue with, py::list exceptions) {
3738 // Convert Python list to a SmallVector of MlirOperations
3739 llvm::SmallVector<MlirOperation> exceptionOps;
3740 for (py::handle exception : exceptions) {
3741 exceptionOps.push_back(exception.cast<PyOperation &>().get());
3744 mlirValueReplaceAllUsesExcept(
3745 self, with, static_cast<intptr_t>(exceptionOps.size()),
3746 exceptionOps.data());
3748 py::arg("with
"), py::arg("exceptions
"),
3749 kValueReplaceAllUsesExceptDocstring)
3750 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3751 [](PyValue &self) { return self.maybeDownCast(); });
3752 PyBlockArgument::bind(m);
3753 PyOpResult::bind(m);
3754 PyOpOperand::bind(m);
3756 py::class_<PyAsmState>(m, "AsmState
", py::module_local())
3757 .def(py::init<PyValue &, bool>(), py::arg("value
"),
3758 py::arg("use_local_scope
") = false)
3759 .def(py::init<PyOperationBase &, bool>(), py::arg("op
"),
3760 py::arg("use_local_scope
") = false);
3762 //----------------------------------------------------------------------------
3763 // Mapping of SymbolTable.
3764 //----------------------------------------------------------------------------
3765 py::class_<PySymbolTable>(m, "SymbolTable
", py::module_local())
3766 .def(py::init<PyOperationBase &>())
3767 .def("__getitem__
", &PySymbolTable::dunderGetItem)
3768 .def("insert
", &PySymbolTable::insert, py::arg("operation
"))
3769 .def("erase
", &PySymbolTable::erase, py::arg("operation
"))
3770 .def("__delitem__
", &PySymbolTable::dunderDel)
3771 .def("__contains__
",
3772 [](PySymbolTable &table, const std::string &name) {
3773 return !mlirOperationIsNull(mlirSymbolTableLookup(
3774 table, mlirStringRefCreate(name.data(), name.length())));
3777 .def_static("set_symbol_name
", &PySymbolTable::setSymbolName,
3778 py::arg("symbol
"), py::arg("name
"))
3779 .def_static("get_symbol_name
", &PySymbolTable::getSymbolName,
3781 .def_static("get_visibility
", &PySymbolTable::getVisibility,
3783 .def_static("set_visibility
", &PySymbolTable::setVisibility,
3784 py::arg("symbol
"), py::arg("visibility
"))
3785 .def_static("replace_all_symbol_uses
",
3786 &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol
"),
3787 py::arg("new_symbol
"), py::arg("from_op
"))
3788 .def_static("walk_symbol_tables
", &PySymbolTable::walkSymbolTables,
3789 py::arg("from_op
"), py::arg("all_sym_uses_visible
"),
3790 py::arg("callback
"));
3792 // Container bindings.
3793 PyBlockArgumentList::bind(m);
3794 PyBlockIterator::bind(m);
3795 PyBlockList::bind(m);
3796 PyOperationIterator::bind(m);
3797 PyOperationList::bind(m);
3798 PyOpAttributeMap::bind(m);
3799 PyOpOperandIterator::bind(m);
3800 PyOpOperandList::bind(m);
3801 PyOpResultList::bind(m);
3802 PyOpSuccessors::bind(m);
3803 PyRegionIterator::bind(m);
3804 PyRegionList::bind(m);
3807 PyGlobalDebugFlag::bind(m);
3809 // Attribute builder getter.
3810 PyAttrBuilderMap::bind(m);
3812 py::register_local_exception_translator([](std::exception_ptr p) {
3813 // We can't define exceptions with custom fields through pybind, so instead
3814 // the exception class is defined in python and imported here.
3817 std::rethrow_exception(p);
3818 } catch (const MLIRError &e) {
3819 py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir
"))
3820 .attr("MLIRError
")(e.message, e.errorDiagnostics);
3821 PyErr_SetObject(PyExc_Exception, obj.ptr());