[MLIR][LLVM] Fold extract of extract (#125980)
[llvm-project.git] / mlir / lib / Bindings / Python / IRCore.cpp
blob2e4b6d1ce35c1b693a34bd13e8fe2c6f31415a12
1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include <filesystem>
10 #include <optional>
11 #include <utility>
13 #include "Globals.h"
14 #include "IRModule.h"
15 #include "NanobindUtils.h"
16 #include "mlir-c/BuiltinAttributes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/IR.h"
20 #include "mlir-c/Support.h"
21 #include "mlir/Bindings/Python/Nanobind.h"
22 #include "mlir/Bindings/Python/NanobindAdaptors.h"
23 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/SmallVector.h"
27 namespace nb = nanobind;
28 using namespace nb::literals;
29 using namespace mlir;
30 using namespace mlir::python;
32 using llvm::SmallVector;
33 using llvm::StringRef;
34 using llvm::Twine;
36 //------------------------------------------------------------------------------
37 // Docstrings (trivial, non-duplicated docstrings are included inline).
38 //------------------------------------------------------------------------------
// The constants below are Python docstrings for the bindings; their text is
// user-visible, so argument names in them must match the bound keyword names.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises an MLIRError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFileRangeDocstring[] =
    R"(Gets a Location representing a file, line and column range)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises an MLIRError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
  infer_type: Whether to infer result types.
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// Keyword names listed here should match the print() binding's argument
// names (previously this said `use_local_Scope`, which is not a valid kwarg).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
  skip_regions: Whether to skip printing regions. Defaults to False.
)";

static const char kOperationPrintStateDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  state: AsmState capturing the operation numbering and flags.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationPrintBytecodeDocstring[] =
    R"(Write the bytecode form of the operation to a file like object.

Args:
  file: The file like object to write to.
  desired_version: The version of bytecode to emit.
Returns:
  The bytecode writer status.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";

static const char kGetNameAsOperand[] =
    R"(Returns the string form of value as an operand (i.e., the ValueID).
)";

static const char kValueReplaceAllUsesWithDocstring[] =
    R"(Replace all uses of value with the new value, updating anything in
the IR that uses 'self' to use the other value instead.
)";

// The raw string previously began with a stray '"' that leaked into the
// user-visible docstring.
static const char kValueReplaceAllUsesExceptDocstring[] =
    R"(Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";
190 //------------------------------------------------------------------------------
191 // Utilities.
192 //------------------------------------------------------------------------------
194 /// Helper for creating an @classmethod.
195 template <class Func, typename... Args>
196 nb::object classmethod(Func f, Args... args) {
197 nb::object cf = nb::cpp_function(f, args...);
198 return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr())));
201 static nb::object
202 createCustomDialectWrapper(const std::string &dialectNamespace,
203 nb::object dialectDescriptor) {
204 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
205 if (!dialectClass) {
206 // Use the base class.
207 return nb::cast(PyDialect(std::move(dialectDescriptor)));
210 // Create the custom implementation.
211 return (*dialectClass)(std::move(dialectDescriptor));
214 static MlirStringRef toMlirStringRef(const std::string &s) {
215 return mlirStringRefCreate(s.data(), s.size());
218 static MlirStringRef toMlirStringRef(std::string_view s) {
219 return mlirStringRefCreate(s.data(), s.size());
222 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
223 return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
226 /// Create a block, using the current location context if no locations are
227 /// specified.
228 static MlirBlock createBlock(const nb::sequence &pyArgTypes,
229 const std::optional<nb::sequence> &pyArgLocs) {
230 SmallVector<MlirType> argTypes;
231 argTypes.reserve(nb::len(pyArgTypes));
232 for (const auto &pyType : pyArgTypes)
233 argTypes.push_back(nb::cast<PyType &>(pyType));
235 SmallVector<MlirLocation> argLocs;
236 if (pyArgLocs) {
237 argLocs.reserve(nb::len(*pyArgLocs));
238 for (const auto &pyLoc : *pyArgLocs)
239 argLocs.push_back(nb::cast<PyLocation &>(pyLoc));
240 } else if (!argTypes.empty()) {
241 argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve());
244 if (argTypes.size() != argLocs.size())
245 throw nb::value_error(("Expected " + Twine(argTypes.size()) +
246 " locations, got: " + Twine(argLocs.size()))
247 .str()
248 .c_str());
249 return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
252 /// Wrapper for the global LLVM debugging flag.
253 struct PyGlobalDebugFlag {
254 static void set(nb::object &o, bool enable) {
255 nb::ft_lock_guard lock(mutex);
256 mlirEnableGlobalDebug(enable);
259 static bool get(const nb::object &) {
260 nb::ft_lock_guard lock(mutex);
261 return mlirIsGlobalDebugEnabled();
264 static void bind(nb::module_ &m) {
265 // Debug flags.
266 nb::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
267 .def_prop_rw_static("flag", &PyGlobalDebugFlag::get,
268 &PyGlobalDebugFlag::set, "LLVM-wide debug flag")
269 .def_static(
270 "set_types",
271 [](const std::string &type) {
272 nb::ft_lock_guard lock(mutex);
273 mlirSetGlobalDebugType(type.c_str());
275 "types"_a, "Sets specific debug types to be produced by LLVM")
276 .def_static("set_types", [](const std::vector<std::string> &types) {
277 std::vector<const char *> pointers;
278 pointers.reserve(types.size());
279 for (const std::string &str : types)
280 pointers.push_back(str.c_str());
281 nb::ft_lock_guard lock(mutex);
282 mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
286 private:
287 static nb::ft_mutex mutex;
290 nb::ft_mutex PyGlobalDebugFlag::mutex;
292 struct PyAttrBuilderMap {
293 static bool dunderContains(const std::string &attributeKind) {
294 return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
296 static nb::callable dunderGetItemNamed(const std::string &attributeKind) {
297 auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind);
298 if (!builder)
299 throw nb::key_error(attributeKind.c_str());
300 return *builder;
302 static void dunderSetItemNamed(const std::string &attributeKind,
303 nb::callable func, bool replace) {
304 PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
305 replace);
308 static void bind(nb::module_ &m) {
309 nb::class_<PyAttrBuilderMap>(m, "AttrBuilder")
310 .def_static("contains", &PyAttrBuilderMap::dunderContains)
311 .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed)
312 .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
313 "attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
314 "Register an attribute builder for building MLIR "
315 "attributes from python values.");
319 //------------------------------------------------------------------------------
320 // PyBlock
321 //------------------------------------------------------------------------------
323 nb::object PyBlock::getCapsule() {
324 return nb::steal<nb::object>(mlirPythonBlockToCapsule(get()));
327 //------------------------------------------------------------------------------
328 // Collections.
329 //------------------------------------------------------------------------------
331 namespace {
333 class PyRegionIterator {
334 public:
335 PyRegionIterator(PyOperationRef operation)
336 : operation(std::move(operation)) {}
338 PyRegionIterator &dunderIter() { return *this; }
340 PyRegion dunderNext() {
341 operation->checkValid();
342 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
343 throw nb::stop_iteration();
345 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
346 return PyRegion(operation, region);
349 static void bind(nb::module_ &m) {
350 nb::class_<PyRegionIterator>(m, "RegionIterator")
351 .def("__iter__", &PyRegionIterator::dunderIter)
352 .def("__next__", &PyRegionIterator::dunderNext);
355 private:
356 PyOperationRef operation;
357 int nextIndex = 0;
360 /// Regions of an op are fixed length and indexed numerically so are represented
361 /// with a sequence-like container.
362 class PyRegionList {
363 public:
364 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
366 PyRegionIterator dunderIter() {
367 operation->checkValid();
368 return PyRegionIterator(operation);
371 intptr_t dunderLen() {
372 operation->checkValid();
373 return mlirOperationGetNumRegions(operation->get());
376 PyRegion dunderGetItem(intptr_t index) {
377 // dunderLen checks validity.
378 if (index < 0 || index >= dunderLen()) {
379 throw nb::index_error("attempt to access out of bounds region");
381 MlirRegion region = mlirOperationGetRegion(operation->get(), index);
382 return PyRegion(operation, region);
385 static void bind(nb::module_ &m) {
386 nb::class_<PyRegionList>(m, "RegionSequence")
387 .def("__len__", &PyRegionList::dunderLen)
388 .def("__iter__", &PyRegionList::dunderIter)
389 .def("__getitem__", &PyRegionList::dunderGetItem);
392 private:
393 PyOperationRef operation;
396 class PyBlockIterator {
397 public:
398 PyBlockIterator(PyOperationRef operation, MlirBlock next)
399 : operation(std::move(operation)), next(next) {}
401 PyBlockIterator &dunderIter() { return *this; }
403 PyBlock dunderNext() {
404 operation->checkValid();
405 if (mlirBlockIsNull(next)) {
406 throw nb::stop_iteration();
409 PyBlock returnBlock(operation, next);
410 next = mlirBlockGetNextInRegion(next);
411 return returnBlock;
414 static void bind(nb::module_ &m) {
415 nb::class_<PyBlockIterator>(m, "BlockIterator")
416 .def("__iter__", &PyBlockIterator::dunderIter)
417 .def("__next__", &PyBlockIterator::dunderNext);
420 private:
421 PyOperationRef operation;
422 MlirBlock next;
425 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
426 /// we present them as a more full-featured list-like container but optimize
427 /// it for forward iteration. Blocks are always owned by a region.
428 class PyBlockList {
429 public:
430 PyBlockList(PyOperationRef operation, MlirRegion region)
431 : operation(std::move(operation)), region(region) {}
433 PyBlockIterator dunderIter() {
434 operation->checkValid();
435 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
438 intptr_t dunderLen() {
439 operation->checkValid();
440 intptr_t count = 0;
441 MlirBlock block = mlirRegionGetFirstBlock(region);
442 while (!mlirBlockIsNull(block)) {
443 count += 1;
444 block = mlirBlockGetNextInRegion(block);
446 return count;
449 PyBlock dunderGetItem(intptr_t index) {
450 operation->checkValid();
451 if (index < 0) {
452 throw nb::index_error("attempt to access out of bounds block");
454 MlirBlock block = mlirRegionGetFirstBlock(region);
455 while (!mlirBlockIsNull(block)) {
456 if (index == 0) {
457 return PyBlock(operation, block);
459 block = mlirBlockGetNextInRegion(block);
460 index -= 1;
462 throw nb::index_error("attempt to access out of bounds block");
465 PyBlock appendBlock(const nb::args &pyArgTypes,
466 const std::optional<nb::sequence> &pyArgLocs) {
467 operation->checkValid();
468 MlirBlock block =
469 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
470 mlirRegionAppendOwnedBlock(region, block);
471 return PyBlock(operation, block);
474 static void bind(nb::module_ &m) {
475 nb::class_<PyBlockList>(m, "BlockList")
476 .def("__getitem__", &PyBlockList::dunderGetItem)
477 .def("__iter__", &PyBlockList::dunderIter)
478 .def("__len__", &PyBlockList::dunderLen)
479 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring,
480 nb::arg("args"), nb::kw_only(),
481 nb::arg("arg_locs") = std::nullopt);
484 private:
485 PyOperationRef operation;
486 MlirRegion region;
489 class PyOperationIterator {
490 public:
491 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
492 : parentOperation(std::move(parentOperation)), next(next) {}
494 PyOperationIterator &dunderIter() { return *this; }
496 nb::object dunderNext() {
497 parentOperation->checkValid();
498 if (mlirOperationIsNull(next)) {
499 throw nb::stop_iteration();
502 PyOperationRef returnOperation =
503 PyOperation::forOperation(parentOperation->getContext(), next);
504 next = mlirOperationGetNextInBlock(next);
505 return returnOperation->createOpView();
508 static void bind(nb::module_ &m) {
509 nb::class_<PyOperationIterator>(m, "OperationIterator")
510 .def("__iter__", &PyOperationIterator::dunderIter)
511 .def("__next__", &PyOperationIterator::dunderNext);
514 private:
515 PyOperationRef parentOperation;
516 MlirOperation next;
519 /// Operations are exposed by the C-API as a forward-only linked list. In
520 /// Python, we present them as a more full-featured list-like container but
521 /// optimize it for forward iteration. Iterable operations are always owned
522 /// by a block.
523 class PyOperationList {
524 public:
525 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
526 : parentOperation(std::move(parentOperation)), block(block) {}
528 PyOperationIterator dunderIter() {
529 parentOperation->checkValid();
530 return PyOperationIterator(parentOperation,
531 mlirBlockGetFirstOperation(block));
534 intptr_t dunderLen() {
535 parentOperation->checkValid();
536 intptr_t count = 0;
537 MlirOperation childOp = mlirBlockGetFirstOperation(block);
538 while (!mlirOperationIsNull(childOp)) {
539 count += 1;
540 childOp = mlirOperationGetNextInBlock(childOp);
542 return count;
545 nb::object dunderGetItem(intptr_t index) {
546 parentOperation->checkValid();
547 if (index < 0) {
548 throw nb::index_error("attempt to access out of bounds operation");
550 MlirOperation childOp = mlirBlockGetFirstOperation(block);
551 while (!mlirOperationIsNull(childOp)) {
552 if (index == 0) {
553 return PyOperation::forOperation(parentOperation->getContext(), childOp)
554 ->createOpView();
556 childOp = mlirOperationGetNextInBlock(childOp);
557 index -= 1;
559 throw nb::index_error("attempt to access out of bounds operation");
562 static void bind(nb::module_ &m) {
563 nb::class_<PyOperationList>(m, "OperationList")
564 .def("__getitem__", &PyOperationList::dunderGetItem)
565 .def("__iter__", &PyOperationList::dunderIter)
566 .def("__len__", &PyOperationList::dunderLen);
569 private:
570 PyOperationRef parentOperation;
571 MlirBlock block;
574 class PyOpOperand {
575 public:
576 PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
578 nb::object getOwner() {
579 MlirOperation owner = mlirOpOperandGetOwner(opOperand);
580 PyMlirContextRef context =
581 PyMlirContext::forContext(mlirOperationGetContext(owner));
582 return PyOperation::forOperation(context, owner)->createOpView();
585 size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
587 static void bind(nb::module_ &m) {
588 nb::class_<PyOpOperand>(m, "OpOperand")
589 .def_prop_ro("owner", &PyOpOperand::getOwner)
590 .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber);
593 private:
594 MlirOpOperand opOperand;
597 class PyOpOperandIterator {
598 public:
599 PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
601 PyOpOperandIterator &dunderIter() { return *this; }
603 PyOpOperand dunderNext() {
604 if (mlirOpOperandIsNull(opOperand))
605 throw nb::stop_iteration();
607 PyOpOperand returnOpOperand(opOperand);
608 opOperand = mlirOpOperandGetNextUse(opOperand);
609 return returnOpOperand;
612 static void bind(nb::module_ &m) {
613 nb::class_<PyOpOperandIterator>(m, "OpOperandIterator")
614 .def("__iter__", &PyOpOperandIterator::dunderIter)
615 .def("__next__", &PyOpOperandIterator::dunderNext);
618 private:
619 MlirOpOperand opOperand;
622 } // namespace
624 //------------------------------------------------------------------------------
625 // PyMlirContext
626 //------------------------------------------------------------------------------
628 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
629 nb::gil_scoped_acquire acquire;
630 nb::ft_lock_guard lock(live_contexts_mutex);
631 auto &liveContexts = getLiveContexts();
632 liveContexts[context.ptr] = this;
635 PyMlirContext::~PyMlirContext() {
636 // Note that the only public way to construct an instance is via the
637 // forContext method, which always puts the associated handle into
638 // liveContexts.
639 nb::gil_scoped_acquire acquire;
641 nb::ft_lock_guard lock(live_contexts_mutex);
642 getLiveContexts().erase(context.ptr);
644 mlirContextDestroy(context);
647 nb::object PyMlirContext::getCapsule() {
648 return nb::steal<nb::object>(mlirPythonContextToCapsule(get()));
651 nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
652 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
653 if (mlirContextIsNull(rawContext))
654 throw nb::python_error();
655 return forContext(rawContext).releaseObject();
658 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
659 nb::gil_scoped_acquire acquire;
660 nb::ft_lock_guard lock(live_contexts_mutex);
661 auto &liveContexts = getLiveContexts();
662 auto it = liveContexts.find(context.ptr);
663 if (it == liveContexts.end()) {
664 // Create.
665 PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
666 nb::object pyRef = nb::cast(unownedContextWrapper);
667 assert(pyRef && "cast to nb::object failed");
668 liveContexts[context.ptr] = unownedContextWrapper;
669 return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
671 // Use existing.
672 nb::object pyRef = nb::cast(it->second);
673 return PyMlirContextRef(it->second, std::move(pyRef));
676 nb::ft_mutex PyMlirContext::live_contexts_mutex;
678 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
679 static LiveContextMap liveContexts;
680 return liveContexts;
683 size_t PyMlirContext::getLiveCount() {
684 nb::ft_lock_guard lock(live_contexts_mutex);
685 return getLiveContexts().size();
688 size_t PyMlirContext::getLiveOperationCount() {
689 nb::ft_lock_guard lock(liveOperationsMutex);
690 return liveOperations.size();
693 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
694 std::vector<PyOperation *> liveObjects;
695 nb::ft_lock_guard lock(liveOperationsMutex);
696 for (auto &entry : liveOperations)
697 liveObjects.push_back(entry.second.second);
698 return liveObjects;
701 size_t PyMlirContext::clearLiveOperations() {
703 LiveOperationMap operations;
705 nb::ft_lock_guard lock(liveOperationsMutex);
706 std::swap(operations, liveOperations);
708 for (auto &op : operations)
709 op.second.second->setInvalid();
710 size_t numInvalidated = operations.size();
711 return numInvalidated;
714 void PyMlirContext::clearOperation(MlirOperation op) {
715 PyOperation *py_op;
717 nb::ft_lock_guard lock(liveOperationsMutex);
718 auto it = liveOperations.find(op.ptr);
719 if (it == liveOperations.end()) {
720 return;
722 py_op = it->second.second;
723 liveOperations.erase(it);
725 py_op->setInvalid();
728 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
729 typedef struct {
730 PyOperation &rootOp;
731 bool rootSeen;
732 } callBackData;
733 callBackData data{op.getOperation(), false};
734 // Mark all ops below the op that the passmanager will be rooted
735 // at (but not op itself - note the preorder) as invalid.
736 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
737 void *userData) {
738 callBackData *data = static_cast<callBackData *>(userData);
739 if (LLVM_LIKELY(data->rootSeen))
740 data->rootOp.getOperation().getContext()->clearOperation(op);
741 else
742 data->rootSeen = true;
743 return MlirWalkResult::MlirWalkResultAdvance;
745 mlirOperationWalk(op.getOperation(), invalidatingCallback,
746 static_cast<void *>(&data), MlirWalkPreOrder);
748 void PyMlirContext::clearOperationsInside(MlirOperation op) {
749 PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
750 clearOperationsInside(opRef->getOperation());
753 void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
754 MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
755 void *userData) {
756 PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
757 contextRef->clearOperation(op);
758 return MlirWalkResult::MlirWalkResultAdvance;
760 mlirOperationWalk(op.getOperation(), invalidatingCallback,
761 &op.getOperation().getContext(), MlirWalkPreOrder);
764 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
766 nb::object PyMlirContext::contextEnter(nb::object context) {
767 return PyThreadContextEntry::pushContext(context);
770 void PyMlirContext::contextExit(const nb::object &excType,
771 const nb::object &excVal,
772 const nb::object &excTb) {
773 PyThreadContextEntry::popContext(*this);
776 nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
777 // Note that ownership is transferred to the delete callback below by way of
778 // an explicit inc_ref (borrow).
779 PyDiagnosticHandler *pyHandler =
780 new PyDiagnosticHandler(get(), std::move(callback));
781 nb::object pyHandlerObject =
782 nb::cast(pyHandler, nb::rv_policy::take_ownership);
783 pyHandlerObject.inc_ref();
785 // In these C callbacks, the userData is a PyDiagnosticHandler* that is
786 // guaranteed to be known to pybind.
787 auto handlerCallback =
788 +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
789 PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
790 nb::object pyDiagnosticObject =
791 nb::cast(pyDiagnostic, nb::rv_policy::take_ownership);
793 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
794 bool result = false;
796 // Since this can be called from arbitrary C++ contexts, always get the
797 // gil.
798 nb::gil_scoped_acquire gil;
799 try {
800 result = nb::cast<bool>(pyHandler->callback(pyDiagnostic));
801 } catch (std::exception &e) {
802 fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
803 e.what());
804 pyHandler->hadError = true;
808 pyDiagnostic->invalidate();
809 return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
811 auto deleteCallback = +[](void *userData) {
812 auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
813 assert(pyHandler->registeredID && "handler is not registered");
814 pyHandler->registeredID.reset();
816 // Decrement reference, balancing the inc_ref() above.
817 nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference);
818 pyHandlerObject.dec_ref();
821 pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
822 get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
823 return pyHandlerObject;
826 MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
827 void *userData) {
828 auto *self = static_cast<ErrorCapture *>(userData);
829 // Check if the context requested we emit errors instead of capturing them.
830 if (self->ctx->emitErrorDiagnostics)
831 return mlirLogicalResultFailure();
833 if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
834 return mlirLogicalResultFailure();
836 self->errors.emplace_back(PyDiagnostic(diag).getInfo());
837 return mlirLogicalResultSuccess();
840 PyMlirContext &DefaultingPyMlirContext::resolve() {
841 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
842 if (!context) {
843 throw std::runtime_error(
844 "An MLIR function requires a Context but none was provided in the call "
845 "or from the surrounding environment. Either pass to the function with "
846 "a 'context=' argument or establish a default using 'with Context():'");
848 return *context;
851 //------------------------------------------------------------------------------
852 // PyThreadContextEntry management
853 //------------------------------------------------------------------------------
855 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
856 static thread_local std::vector<PyThreadContextEntry> stack;
857 return stack;
860 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
861 auto &stack = getStack();
862 if (stack.empty())
863 return nullptr;
864 return &stack.back();
867 void PyThreadContextEntry::push(FrameKind frameKind, nb::object context,
868 nb::object insertionPoint,
869 nb::object location) {
870 auto &stack = getStack();
871 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
872 std::move(location));
873 // If the new stack has more than one entry and the context of the new top
874 // entry matches the previous, copy the insertionPoint and location from the
875 // previous entry if missing from the new top entry.
876 if (stack.size() > 1) {
877 auto &prev = *(stack.rbegin() + 1);
878 auto &current = stack.back();
879 if (current.context.is(prev.context)) {
880 // Default non-context objects from the previous entry.
881 if (!current.insertionPoint)
882 current.insertionPoint = prev.insertionPoint;
883 if (!current.location)
884 current.location = prev.location;
889 PyMlirContext *PyThreadContextEntry::getContext() {
890 if (!context)
891 return nullptr;
892 return nb::cast<PyMlirContext *>(context);
895 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
896 if (!insertionPoint)
897 return nullptr;
898 return nb::cast<PyInsertionPoint *>(insertionPoint);
901 PyLocation *PyThreadContextEntry::getLocation() {
902 if (!location)
903 return nullptr;
904 return nb::cast<PyLocation *>(location);
907 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
908 auto *tos = getTopOfStack();
909 return tos ? tos->getContext() : nullptr;
912 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
913 auto *tos = getTopOfStack();
914 return tos ? tos->getInsertionPoint() : nullptr;
917 PyLocation *PyThreadContextEntry::getDefaultLocation() {
918 auto *tos = getTopOfStack();
919 return tos ? tos->getLocation() : nullptr;
922 nb::object PyThreadContextEntry::pushContext(nb::object context) {
923 push(FrameKind::Context, /*context=*/context,
924 /*insertionPoint=*/nb::object(),
925 /*location=*/nb::object());
926 return context;
929 void PyThreadContextEntry::popContext(PyMlirContext &context) {
930 auto &stack = getStack();
931 if (stack.empty())
932 throw std::runtime_error("Unbalanced Context enter/exit");
933 auto &tos = stack.back();
934 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
935 throw std::runtime_error("Unbalanced Context enter/exit");
936 stack.pop_back();
939 nb::object
940 PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) {
941 PyInsertionPoint &insertionPoint =
942 nb::cast<PyInsertionPoint &>(insertionPointObj);
943 nb::object contextObj =
944 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
945 push(FrameKind::InsertionPoint,
946 /*context=*/contextObj,
947 /*insertionPoint=*/insertionPointObj,
948 /*location=*/nb::object());
949 return insertionPointObj;
952 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
953 auto &stack = getStack();
954 if (stack.empty())
955 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
956 auto &tos = stack.back();
957 if (tos.frameKind != FrameKind::InsertionPoint &&
958 tos.getInsertionPoint() != &insertionPoint)
959 throw std::runtime_error("Unbalanced InsertionPoint enter/exit");
960 stack.pop_back();
963 nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) {
964 PyLocation &location = nb::cast<PyLocation &>(locationObj);
965 nb::object contextObj = location.getContext().getObject();
966 push(FrameKind::Location, /*context=*/contextObj,
967 /*insertionPoint=*/nb::object(),
968 /*location=*/locationObj);
969 return locationObj;
972 void PyThreadContextEntry::popLocation(PyLocation &location) {
973 auto &stack = getStack();
974 if (stack.empty())
975 throw std::runtime_error("Unbalanced Location enter/exit");
976 auto &tos = stack.back();
977 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
978 throw std::runtime_error("Unbalanced Location enter/exit");
979 stack.pop_back();
982 //------------------------------------------------------------------------------
983 // PyDiagnostic*
984 //------------------------------------------------------------------------------
986 void PyDiagnostic::invalidate() {
987 valid = false;
988 if (materializedNotes) {
989 for (nb::handle noteObject : *materializedNotes) {
990 PyDiagnostic *note = nb::cast<PyDiagnostic *>(noteObject);
991 note->invalidate();
996 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
997 nb::object callback)
998 : context(context), callback(std::move(callback)) {}
1000 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
1002 void PyDiagnosticHandler::detach() {
1003 if (!registeredID)
1004 return;
1005 MlirDiagnosticHandlerID localID = *registeredID;
1006 mlirContextDetachDiagnosticHandler(context, localID);
1007 assert(!registeredID && "should have unregistered");
1008 // Not strictly necessary but keeps stale pointers from being around to cause
1009 // issues.
1010 context = {nullptr};
1013 void PyDiagnostic::checkValid() {
1014 if (!valid) {
1015 throw std::invalid_argument(
1016 "Diagnostic is invalid (used outside of callback)");
1020 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
1021 checkValid();
1022 return mlirDiagnosticGetSeverity(diagnostic);
1025 PyLocation PyDiagnostic::getLocation() {
1026 checkValid();
1027 MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
1028 MlirContext context = mlirLocationGetContext(loc);
1029 return PyLocation(PyMlirContext::forContext(context), loc);
1032 nb::str PyDiagnostic::getMessage() {
1033 checkValid();
1034 nb::object fileObject = nb::module_::import_("io").attr("StringIO")();
1035 PyFileAccumulator accum(fileObject, /*binary=*/false);
1036 mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
1037 return nb::cast<nb::str>(fileObject.attr("getvalue")());
1040 nb::tuple PyDiagnostic::getNotes() {
1041 checkValid();
1042 if (materializedNotes)
1043 return *materializedNotes;
1044 intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
1045 nb::tuple notes = nb::steal<nb::tuple>(PyTuple_New(numNotes));
1046 for (intptr_t i = 0; i < numNotes; ++i) {
1047 MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
1048 nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag));
1049 PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr());
1051 materializedNotes = std::move(notes);
1053 return *materializedNotes;
1056 PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
1057 std::vector<DiagnosticInfo> notes;
1058 for (nb::handle n : getNotes())
1059 notes.emplace_back(nb::cast<PyDiagnostic>(n).getInfo());
1060 return {getSeverity(), getLocation(), nb::cast<std::string>(getMessage()),
1061 std::move(notes)};
1064 //------------------------------------------------------------------------------
1065 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
1066 //------------------------------------------------------------------------------
1068 MlirDialect PyDialects::getDialectForKey(const std::string &key,
1069 bool attrError) {
1070 MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
1071 {key.data(), key.size()});
1072 if (mlirDialectIsNull(dialect)) {
1073 std::string msg = (Twine("Dialect '") + key + "' not found").str();
1074 if (attrError)
1075 throw nb::attribute_error(msg.c_str());
1076 throw nb::index_error(msg.c_str());
1078 return dialect;
1081 nb::object PyDialectRegistry::getCapsule() {
1082 return nb::steal<nb::object>(mlirPythonDialectRegistryToCapsule(*this));
1085 PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) {
1086 MlirDialectRegistry rawRegistry =
1087 mlirPythonCapsuleToDialectRegistry(capsule.ptr());
1088 if (mlirDialectRegistryIsNull(rawRegistry))
1089 throw nb::python_error();
1090 return PyDialectRegistry(rawRegistry);
1093 //------------------------------------------------------------------------------
1094 // PyLocation
1095 //------------------------------------------------------------------------------
1097 nb::object PyLocation::getCapsule() {
1098 return nb::steal<nb::object>(mlirPythonLocationToCapsule(*this));
1101 PyLocation PyLocation::createFromCapsule(nb::object capsule) {
1102 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
1103 if (mlirLocationIsNull(rawLoc))
1104 throw nb::python_error();
1105 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
1106 rawLoc);
1109 nb::object PyLocation::contextEnter(nb::object locationObj) {
1110 return PyThreadContextEntry::pushLocation(locationObj);
1113 void PyLocation::contextExit(const nb::object &excType,
1114 const nb::object &excVal,
1115 const nb::object &excTb) {
1116 PyThreadContextEntry::popLocation(*this);
1119 PyLocation &DefaultingPyLocation::resolve() {
1120 auto *location = PyThreadContextEntry::getDefaultLocation();
1121 if (!location) {
1122 throw std::runtime_error(
1123 "An MLIR function requires a Location but none was provided in the "
1124 "call or from the surrounding environment. Either pass to the function "
1125 "with a 'loc=' argument or establish a default using 'with loc:'");
1127 return *location;
1130 //------------------------------------------------------------------------------
1131 // PyModule
1132 //------------------------------------------------------------------------------
1134 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
1135 : BaseContextObject(std::move(contextRef)), module(module) {}
1137 PyModule::~PyModule() {
1138 nb::gil_scoped_acquire acquire;
1139 auto &liveModules = getContext()->liveModules;
1140 assert(liveModules.count(module.ptr) == 1 &&
1141 "destroying module not in live map");
1142 liveModules.erase(module.ptr);
1143 mlirModuleDestroy(module);
1146 PyModuleRef PyModule::forModule(MlirModule module) {
1147 MlirContext context = mlirModuleGetContext(module);
1148 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
1150 nb::gil_scoped_acquire acquire;
1151 auto &liveModules = contextRef->liveModules;
1152 auto it = liveModules.find(module.ptr);
1153 if (it == liveModules.end()) {
1154 // Create.
1155 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
1156 // Note that the default return value policy on cast is automatic_reference,
1157 // which does not take ownership (delete will not be called).
1158 // Just be explicit.
1159 nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
1160 unownedModule->handle = pyRef;
1161 liveModules[module.ptr] =
1162 std::make_pair(unownedModule->handle, unownedModule);
1163 return PyModuleRef(unownedModule, std::move(pyRef));
1165 // Use existing.
1166 PyModule *existing = it->second.second;
1167 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1168 return PyModuleRef(existing, std::move(pyRef));
1171 nb::object PyModule::createFromCapsule(nb::object capsule) {
1172 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
1173 if (mlirModuleIsNull(rawModule))
1174 throw nb::python_error();
1175 return forModule(rawModule).releaseObject();
1178 nb::object PyModule::getCapsule() {
1179 return nb::steal<nb::object>(mlirPythonModuleToCapsule(get()));
1182 //------------------------------------------------------------------------------
1183 // PyOperation
1184 //------------------------------------------------------------------------------
1186 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
1187 : BaseContextObject(std::move(contextRef)), operation(operation) {}
1189 PyOperation::~PyOperation() {
1190 // If the operation has already been invalidated there is nothing to do.
1191 if (!valid)
1192 return;
1194 // Otherwise, invalidate the operation and remove it from live map when it is
1195 // attached.
1196 if (isAttached()) {
1197 getContext()->clearOperation(*this);
1198 } else {
1199 // And destroy it when it is detached, i.e. owned by Python, in which case
1200 // all nested operations must be invalidated at removed from the live map as
1201 // well.
1202 erase();
1206 namespace {
1208 // Constructs a new object of type T in-place on the Python heap, returning a
1209 // PyObjectRef to it, loosely analogous to std::make_shared<T>().
1210 template <typename T, class... Args>
1211 PyObjectRef<T> makeObjectRef(Args &&...args) {
1212 nb::handle type = nb::type<T>();
1213 nb::object instance = nb::inst_alloc(type);
1214 T *ptr = nb::inst_ptr<T>(instance);
1215 new (ptr) T(std::forward<Args>(args)...);
1216 nb::inst_mark_ready(instance);
1217 return PyObjectRef<T>(ptr, std::move(instance));
1220 } // namespace
1222 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
1223 MlirOperation operation,
1224 nb::object parentKeepAlive) {
1225 // Create.
1226 PyOperationRef unownedOperation =
1227 makeObjectRef<PyOperation>(std::move(contextRef), operation);
1228 unownedOperation->handle = unownedOperation.getObject();
1229 if (parentKeepAlive) {
1230 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
1232 return unownedOperation;
1235 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
1236 MlirOperation operation,
1237 nb::object parentKeepAlive) {
1238 nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1239 auto &liveOperations = contextRef->liveOperations;
1240 auto it = liveOperations.find(operation.ptr);
1241 if (it == liveOperations.end()) {
1242 // Create.
1243 PyOperationRef result = createInstance(std::move(contextRef), operation,
1244 std::move(parentKeepAlive));
1245 liveOperations[operation.ptr] =
1246 std::make_pair(result.getObject(), result.get());
1247 return result;
1249 // Use existing.
1250 PyOperation *existing = it->second.second;
1251 nb::object pyRef = nb::borrow<nb::object>(it->second.first);
1252 return PyOperationRef(existing, std::move(pyRef));
1255 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
1256 MlirOperation operation,
1257 nb::object parentKeepAlive) {
1258 nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
1259 auto &liveOperations = contextRef->liveOperations;
1260 assert(liveOperations.count(operation.ptr) == 0 &&
1261 "cannot create detached operation that already exists");
1262 (void)liveOperations;
1263 PyOperationRef created = createInstance(std::move(contextRef), operation,
1264 std::move(parentKeepAlive));
1265 liveOperations[operation.ptr] =
1266 std::make_pair(created.getObject(), created.get());
1267 created->attached = false;
1268 return created;
1271 PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
1272 const std::string &sourceStr,
1273 const std::string &sourceName) {
1274 PyMlirContext::ErrorCapture errors(contextRef);
1275 MlirOperation op =
1276 mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
1277 toMlirStringRef(sourceName));
1278 if (mlirOperationIsNull(op))
1279 throw MLIRError("Unable to parse operation assembly", errors.take());
1280 return PyOperation::createDetached(std::move(contextRef), op);
1283 void PyOperation::checkValid() const {
1284 if (!valid) {
1285 throw std::runtime_error("the operation has been invalidated");
1289 void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
1290 bool enableDebugInfo, bool prettyDebugInfo,
1291 bool printGenericOpForm, bool useLocalScope,
1292 bool assumeVerified, nb::object fileObject,
1293 bool binary, bool skipRegions) {
1294 PyOperation &operation = getOperation();
1295 operation.checkValid();
1296 if (fileObject.is_none())
1297 fileObject = nb::module_::import_("sys").attr("stdout");
1299 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1300 if (largeElementsLimit) {
1301 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1302 mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit);
1304 if (enableDebugInfo)
1305 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
1306 /*prettyForm=*/prettyDebugInfo);
1307 if (printGenericOpForm)
1308 mlirOpPrintingFlagsPrintGenericOpForm(flags);
1309 if (useLocalScope)
1310 mlirOpPrintingFlagsUseLocalScope(flags);
1311 if (assumeVerified)
1312 mlirOpPrintingFlagsAssumeVerified(flags);
1313 if (skipRegions)
1314 mlirOpPrintingFlagsSkipRegions(flags);
1316 PyFileAccumulator accum(fileObject, binary);
1317 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1318 accum.getUserData());
1319 mlirOpPrintingFlagsDestroy(flags);
1322 void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
1323 bool binary) {
1324 PyOperation &operation = getOperation();
1325 operation.checkValid();
1326 if (fileObject.is_none())
1327 fileObject = nb::module_::import_("sys").attr("stdout");
1328 PyFileAccumulator accum(fileObject, binary);
1329 mlirOperationPrintWithState(operation, state.get(), accum.getCallback(),
1330 accum.getUserData());
1333 void PyOperationBase::writeBytecode(const nb::object &fileObject,
1334 std::optional<int64_t> bytecodeVersion) {
1335 PyOperation &operation = getOperation();
1336 operation.checkValid();
1337 PyFileAccumulator accum(fileObject, /*binary=*/true);
1339 if (!bytecodeVersion.has_value())
1340 return mlirOperationWriteBytecode(operation, accum.getCallback(),
1341 accum.getUserData());
1343 MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
1344 mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
1345 MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig(
1346 operation, config, accum.getCallback(), accum.getUserData());
1347 mlirBytecodeWriterConfigDestroy(config);
1348 if (mlirLogicalResultIsFailure(res))
1349 throw nb::value_error((Twine("Unable to honor desired bytecode version ") +
1350 Twine(*bytecodeVersion))
1351 .str()
1352 .c_str());
1355 void PyOperationBase::walk(
1356 std::function<MlirWalkResult(MlirOperation)> callback,
1357 MlirWalkOrder walkOrder) {
1358 PyOperation &operation = getOperation();
1359 operation.checkValid();
1360 struct UserData {
1361 std::function<MlirWalkResult(MlirOperation)> callback;
1362 bool gotException;
1363 std::string exceptionWhat;
1364 nb::object exceptionType;
1366 UserData userData{callback, false, {}, {}};
1367 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
1368 void *userData) {
1369 UserData *calleeUserData = static_cast<UserData *>(userData);
1370 try {
1371 return (calleeUserData->callback)(op);
1372 } catch (nb::python_error &e) {
1373 calleeUserData->gotException = true;
1374 calleeUserData->exceptionWhat = std::string(e.what());
1375 calleeUserData->exceptionType = nb::borrow(e.type());
1376 return MlirWalkResult::MlirWalkResultInterrupt;
1379 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
1380 if (userData.gotException) {
1381 std::string message("Exception raised in callback: ");
1382 message.append(userData.exceptionWhat);
1383 throw std::runtime_error(message);
1387 nb::object PyOperationBase::getAsm(bool binary,
1388 std::optional<int64_t> largeElementsLimit,
1389 bool enableDebugInfo, bool prettyDebugInfo,
1390 bool printGenericOpForm, bool useLocalScope,
1391 bool assumeVerified, bool skipRegions) {
1392 nb::object fileObject;
1393 if (binary) {
1394 fileObject = nb::module_::import_("io").attr("BytesIO")();
1395 } else {
1396 fileObject = nb::module_::import_("io").attr("StringIO")();
1398 print(/*largeElementsLimit=*/largeElementsLimit,
1399 /*enableDebugInfo=*/enableDebugInfo,
1400 /*prettyDebugInfo=*/prettyDebugInfo,
1401 /*printGenericOpForm=*/printGenericOpForm,
1402 /*useLocalScope=*/useLocalScope,
1403 /*assumeVerified=*/assumeVerified,
1404 /*fileObject=*/fileObject,
1405 /*binary=*/binary,
1406 /*skipRegions=*/skipRegions);
1408 return fileObject.attr("getvalue")();
1411 void PyOperationBase::moveAfter(PyOperationBase &other) {
1412 PyOperation &operation = getOperation();
1413 PyOperation &otherOp = other.getOperation();
1414 operation.checkValid();
1415 otherOp.checkValid();
1416 mlirOperationMoveAfter(operation, otherOp);
1417 operation.parentKeepAlive = otherOp.parentKeepAlive;
1420 void PyOperationBase::moveBefore(PyOperationBase &other) {
1421 PyOperation &operation = getOperation();
1422 PyOperation &otherOp = other.getOperation();
1423 operation.checkValid();
1424 otherOp.checkValid();
1425 mlirOperationMoveBefore(operation, otherOp);
1426 operation.parentKeepAlive = otherOp.parentKeepAlive;
1429 bool PyOperationBase::verify() {
1430 PyOperation &op = getOperation();
1431 PyMlirContext::ErrorCapture errors(op.getContext());
1432 if (!mlirOperationVerify(op.get()))
1433 throw MLIRError("Verification failed", errors.take());
1434 return true;
1437 std::optional<PyOperationRef> PyOperation::getParentOperation() {
1438 checkValid();
1439 if (!isAttached())
1440 throw nb::value_error("Detached operations have no parent");
1441 MlirOperation operation = mlirOperationGetParentOperation(get());
1442 if (mlirOperationIsNull(operation))
1443 return {};
1444 return PyOperation::forOperation(getContext(), operation);
1447 PyBlock PyOperation::getBlock() {
1448 checkValid();
1449 std::optional<PyOperationRef> parentOperation = getParentOperation();
1450 MlirBlock block = mlirOperationGetBlock(get());
1451 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1452 assert(parentOperation && "Operation has no parent");
1453 return PyBlock{std::move(*parentOperation), block};
1456 nb::object PyOperation::getCapsule() {
1457 checkValid();
1458 return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
1461 nb::object PyOperation::createFromCapsule(nb::object capsule) {
1462 MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1463 if (mlirOperationIsNull(rawOperation))
1464 throw nb::python_error();
1465 MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1466 return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1467 .releaseObject();
1470 static void maybeInsertOperation(PyOperationRef &op,
1471 const nb::object &maybeIp) {
1472 // InsertPoint active?
1473 if (!maybeIp.is(nb::cast(false))) {
1474 PyInsertionPoint *ip;
1475 if (maybeIp.is_none()) {
1476 ip = PyThreadContextEntry::getDefaultInsertionPoint();
1477 } else {
1478 ip = nb::cast<PyInsertionPoint *>(maybeIp);
1480 if (ip)
1481 ip->insert(*op.get());
/// Creates a new operation named `name` from Python-provided components.
///
/// Everything that can fail on Python input is validated and converted first.
/// Only then is an MlirOperationState built and the operation created. The new
/// operation is registered as detached and then inserted according to
/// `maybeIp` (see maybeInsertOperation).
///
/// @param name        Operation name.
/// @param results     Optional explicit result types; None entries rejected.
/// @param operands    Operand values.
/// @param attributes  Optional dict mapping str keys to PyAttribute values.
/// @param successors  Optional successor blocks; None entries rejected.
/// @param regions     Number of fresh regions to create; must be >= 0.
/// @param location    Location of the operation; its context is also the
///                    context the new wrapper is registered with.
/// @param maybeIp     Insertion point selector forwarded to
///                    maybeInsertOperation.
/// @param inferType   Value stored into state.enableResultTypeInference.
/// @return The Python object wrapping the created operation.
/// @throws ValueError for negative `regions`, None result types or successors,
///         or if mlirOperationCreate returns null.
/// @throws TypeError for non-str attribute keys or non-Attribute values.
/// @throws std::runtime_error when casting an attribute value raises
///         std::runtime_error (seemingly for None values).
nb::object PyOperation::create(std::string_view name,
                               std::optional<std::vector<PyType *>> results,
                               llvm::ArrayRef<MlirValue> operands,
                               std::optional<nb::dict> attributes,
                               std::optional<std::vector<PyBlock *>> successors,
                               int regions, DefaultingPyLocation location,
                               const nb::object &maybeIp, bool inferType) {
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Owns the attribute name strings; the named attributes built later point
  // into these strings.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw nb::value_error("number of regions must be >= 0");

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw nb::value_error("result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (std::pair<nb::handle, nb::handle> it : *attributes) {
      std::string key;
      try {
        key = nb::cast<std::string>(it.first);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      }
      try {
        auto &attribute = nb::cast<PyAttribute &>(it.second);
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (nb::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          std::string(name) + "\" (" + err.what() + ")";
        throw nb::type_error(msg.c_str());
      } catch (std::runtime_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" +
            std::string(name) + "\"";
        throw std::runtime_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw nb::value_error("successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!operands.empty())
    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
  state.enableResultTypeInference = inferType;
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create `regions` fresh regions and add them to the state via
    // mlirOperationStateAddOwnedRegions.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  if (!operation.ptr)
    throw nb::value_error("Operation creation failed");
  // Register as detached first; maybeInsertOperation may then attach it.
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created.getObject();
}
1600 nb::object PyOperation::clone(const nb::object &maybeIp) {
1601 MlirOperation clonedOperation = mlirOperationClone(operation);
1602 PyOperationRef cloned =
1603 PyOperation::createDetached(getContext(), clonedOperation);
1604 maybeInsertOperation(cloned, maybeIp);
1606 return cloned->createOpView();
1609 nb::object PyOperation::createOpView() {
1610 checkValid();
1611 MlirIdentifier ident = mlirOperationGetName(get());
1612 MlirStringRef identStr = mlirIdentifierStr(ident);
1613 auto operationCls = PyGlobals::get().lookupOperationClass(
1614 StringRef(identStr.data, identStr.length));
1615 if (operationCls)
1616 return PyOpView::constructDerived(*operationCls, getRef().getObject());
1617 return nb::cast(PyOpView(getRef().getObject()));
1620 void PyOperation::erase() {
1621 checkValid();
1622 getContext()->clearOperationAndInside(*this);
1623 mlirOperationDestroy(operation);
1626 namespace {
1627 /// CRTP base class for Python MLIR values that subclass Value and should be
1628 /// castable from it. The value hierarchy is one level deep and is not supposed
1629 /// to accommodate other levels unless core MLIR changes.
1630 template <typename DerivedTy>
1631 class PyConcreteValue : public PyValue {
1632 public:
1633 // Derived classes must define statics for:
1634 // IsAFunctionTy isaFunction
1635 // const char *pyClassName
1636 // and redefine bindDerived.
1637 using ClassTy = nb::class_<DerivedTy, PyValue>;
1638 using IsAFunctionTy = bool (*)(MlirValue);
1640 PyConcreteValue() = default;
1641 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1642 : PyValue(operationRef, value) {}
1643 PyConcreteValue(PyValue &orig)
1644 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1646 /// Attempts to cast the original value to the derived type and throws on
1647 /// type mismatches.
1648 static MlirValue castFrom(PyValue &orig) {
1649 if (!DerivedTy::isaFunction(orig.get())) {
1650 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
1651 throw nb::value_error((Twine("Cannot cast value to ") +
1652 DerivedTy::pyClassName + " (from " + origRepr +
1653 ")")
1654 .str()
1655 .c_str());
1657 return orig.get();
1660 /// Binds the Python module objects to functions of this class.
1661 static void bind(nb::module_ &m) {
1662 auto cls = ClassTy(m, DerivedTy::pyClassName);
1663 cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
1664 cls.def_static(
1665 "isinstance",
1666 [](PyValue &otherValue) -> bool {
1667 return DerivedTy::isaFunction(otherValue);
1669 nb::arg("other_value"));
1670 cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
1671 [](DerivedTy &self) { return self.maybeDownCast(); });
1672 DerivedTy::bindDerived(cls);
1675 /// Implemented by derived classes to add methods to the Python subclass.
1676 static void bindDerived(ClassTy &m) {}
1679 } // namespace
1681 /// Python wrapper for MlirOpResult.
1682 class PyOpResult : public PyConcreteValue<PyOpResult> {
1683 public:
1684 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1685 static constexpr const char *pyClassName = "OpResult";
1686 using PyConcreteValue::PyConcreteValue;
1688 static void bindDerived(ClassTy &c) {
1689 c.def_prop_ro("owner", [](PyOpResult &self) {
1690 assert(
1691 mlirOperationEqual(self.getParentOperation()->get(),
1692 mlirOpResultGetOwner(self.get())) &&
1693 "expected the owner of the value in Python to match that in the IR");
1694 return self.getParentOperation().getObject();
1696 c.def_prop_ro("result_number", [](PyOpResult &self) {
1697 return mlirOpResultGetResultNumber(self.get());
1702 /// Returns the list of types of the values held by container.
1703 template <typename Container>
1704 static std::vector<MlirType> getValueTypes(Container &container,
1705 PyMlirContextRef &context) {
1706 std::vector<MlirType> result;
1707 result.reserve(container.size());
1708 for (int i = 0, e = container.size(); i < e; ++i) {
1709 result.push_back(mlirValueGetType(container.getElement(i).get()));
1711 return result;
1714 /// A list of operation results. Internally, these are stored as consecutive
1715 /// elements, random access is cheap. The (returned) result list is associated
1716 /// with the operation whose results these are, and thus extends the lifetime of
1717 /// this operation.
1718 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1719 public:
1720 static constexpr const char *pyClassName = "OpResultList";
1721 using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
1723 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1724 intptr_t length = -1, intptr_t step = 1)
1725 : Sliceable(startIndex,
1726 length == -1 ? mlirOperationGetNumResults(operation->get())
1727 : length,
1728 step),
1729 operation(std::move(operation)) {}
1731 static void bindDerived(ClassTy &c) {
1732 c.def_prop_ro("types", [](PyOpResultList &self) {
1733 return getValueTypes(self, self.operation->getContext());
1735 c.def_prop_ro("owner", [](PyOpResultList &self) {
1736 return self.operation->createOpView();
1740 PyOperationRef &getOperation() { return operation; }
1742 private:
1743 /// Give the parent CRTP class access to hook implementations below.
1744 friend class Sliceable<PyOpResultList, PyOpResult>;
1746 intptr_t getRawNumElements() {
1747 operation->checkValid();
1748 return mlirOperationGetNumResults(operation->get());
1751 PyOpResult getRawElement(intptr_t index) {
1752 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1753 return PyOpResult(value);
1756 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1757 return PyOpResultList(operation, startIndex, length, step);
1760 PyOperationRef operation;
1763 //------------------------------------------------------------------------------
1764 // PyOpView
1765 //------------------------------------------------------------------------------
/// Flattens the Python `resultTypeList` of operation `name` into
/// `resultTypes`.
///
/// If `resultSegmentSpecObj` is None, every element must be a (non-None) Type.
/// Otherwise it is cast to a list of ints with one entry per element:
///    1: the element must be a Type;
///    0: the element is an optional Type, where None contributes no result;
///   -1: the element is a sequence of Types, where None counts as empty.
/// In that case the number of results each element contributes is appended to
/// `resultSegmentLengths`.
///
/// @throws ValueError naming the offending result index on any mismatch, or
///         "Unexpected segment spec" for any other spec value.
static void populateResultTypes(StringRef name, nb::list resultTypeList,
                                const nb::object &resultSegmentSpecObj,
                                std::vector<int32_t> &resultSegmentLengths,
                                std::vector<PyType *> &resultTypes) {
  resultTypes.reserve(resultTypeList.size());
  if (resultSegmentSpecObj.is_none()) {
    // Non-variadic result unpacking.
    for (const auto &it : llvm::enumerate(resultTypeList)) {
      try {
        resultTypes.push_back(nb::cast<PyType *>(it.value()));
        // A null pointer (None) is funneled into the same error path as a
        // failed cast.
        if (!resultTypes.back())
          throw nb::cast_error();
      } catch (nb::cast_error &err) {
        throw nb::value_error((llvm::Twine("Result ") +
                               llvm::Twine(it.index()) + " of operation \"" +
                               name + "\" must be a Type (" + err.what() + ")")
                                  .str()
                                  .c_str());
      }
    }
  } else {
    // Sized result unpacking.
    auto resultSegmentSpec = nb::cast<std::vector<int>>(resultSegmentSpecObj);
    if (resultSegmentSpec.size() != resultTypeList.size()) {
      throw nb::value_error((llvm::Twine("Operation \"") + name +
                             "\" requires " +
                             llvm::Twine(resultSegmentSpec.size()) +
                             " result segments but was provided " +
                             llvm::Twine(resultTypeList.size()))
                                .str()
                                .c_str());
    }
    resultSegmentLengths.reserve(resultTypeList.size());
    for (const auto &it :
         llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
      int segmentSpec = std::get<1>(it.value());
      if (segmentSpec == 1 || segmentSpec == 0) {
        // Unpack unary element.
        try {
          // A null resultType is treated as "no type given" below
          // (presumably the element was None).
          auto *resultType = nb::cast<PyType *>(std::get<0>(it.value()));
          if (resultType) {
            resultTypes.push_back(resultType);
            resultSegmentLengths.push_back(1);
          } else if (segmentSpec == 0) {
            // Allowed to be optional.
            resultSegmentLengths.push_back(0);
          } else {
            throw nb::value_error(
                (llvm::Twine("Result ") + llvm::Twine(it.index()) +
                 " of operation \"" + name +
                 "\" must be a Type (was None and result is not optional)")
                    .str()
                    .c_str());
          }
        } catch (nb::cast_error &err) {
          throw nb::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Type (" + err.what() +
                                 ")")
                                    .str()
                                    .c_str());
        }
      } else if (segmentSpec == -1) {
        // Unpack sequence by appending.
        try {
          if (std::get<0>(it.value()).is_none()) {
            // Treat it as an empty list.
            resultSegmentLengths.push_back(0);
          } else {
            // Unpack the list.
            auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
            for (nb::handle segmentItem : segment) {
              resultTypes.push_back(nb::cast<PyType *>(segmentItem));
              if (!resultTypes.back()) {
                throw nb::type_error("contained a None item");
              }
            }
            resultSegmentLengths.push_back(nb::len(segment));
          }
        } catch (std::exception &err) {
          // NOTE: Sloppy to be using a catch-all here, but there are at least
          // three different unrelated exceptions that can be thrown in the
          // above "casts". Just keep the scope above small and catch them all.
          throw nb::value_error((llvm::Twine("Result ") +
                                 llvm::Twine(it.index()) + " of operation \"" +
                                 name + "\" must be a Sequence of Types (" +
                                 err.what() + ")")
                                    .str()
                                    .c_str());
        }
      } else {
        throw nb::value_error("Unexpected segment spec");
      }
    }
  }
}
1864 static MlirValue getUniqueResult(MlirOperation operation) {
1865 auto numResults = mlirOperationGetNumResults(operation);
1866 if (numResults != 1) {
1867 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1868 throw nb::value_error((Twine("Cannot call .result on operation ") +
1869 StringRef(name.data, name.length) + " which has " +
1870 Twine(numResults) +
1871 " results (it is only valid for operations with a "
1872 "single result)")
1873 .str()
1874 .c_str());
1876 return mlirOperationGetResult(operation, 0);
1879 static MlirValue getOpResultOrValue(nb::handle operand) {
1880 if (operand.is_none()) {
1881 throw nb::value_error("contained a None item");
1883 PyOperationBase *op;
1884 if (nb::try_cast<PyOperationBase *>(operand, op)) {
1885 return getUniqueResult(op->getOperation());
1887 PyOpResultList *opResultList;
1888 if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
1889 return getUniqueResult(opResultList->getOperation()->get());
1891 PyValue *value;
1892 if (nb::try_cast<PyValue *>(operand, value)) {
1893 return value->get();
1895 throw nb::value_error("is not a Value");
1898 nb::object PyOpView::buildGeneric(
1899 std::string_view name, std::tuple<int, bool> opRegionSpec,
1900 nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1901 std::optional<nb::list> resultTypeList, nb::list operandList,
1902 std::optional<nb::dict> attributes,
1903 std::optional<std::vector<PyBlock *>> successors,
1904 std::optional<int> regions, DefaultingPyLocation location,
1905 const nb::object &maybeIp) {
1906 PyMlirContextRef context = location->getContext();
1908 // Class level operation construction metadata.
1909 // Operand and result segment specs are either none, which does no
1910 // variadic unpacking, or a list of ints with segment sizes, where each
1911 // element is either a positive number (typically 1 for a scalar) or -1 to
1912 // indicate that it is derived from the length of the same-indexed operand
1913 // or result (implying that it is a list at that position).
1914 std::vector<int32_t> operandSegmentLengths;
1915 std::vector<int32_t> resultSegmentLengths;
1917 // Validate/determine region count.
1918 int opMinRegionCount = std::get<0>(opRegionSpec);
1919 bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1920 if (!regions) {
1921 regions = opMinRegionCount;
1923 if (*regions < opMinRegionCount) {
1924 throw nb::value_error(
1925 (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1926 llvm::Twine(opMinRegionCount) +
1927 " regions but was built with regions=" + llvm::Twine(*regions))
1928 .str()
1929 .c_str());
1931 if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1932 throw nb::value_error(
1933 (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1934 llvm::Twine(opMinRegionCount) +
1935 " regions but was built with regions=" + llvm::Twine(*regions))
1936 .str()
1937 .c_str());
1940 // Unpack results.
1941 std::vector<PyType *> resultTypes;
1942 if (resultTypeList.has_value()) {
1943 populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
1944 resultSegmentLengths, resultTypes);
1947 // Unpack operands.
1948 llvm::SmallVector<MlirValue, 4> operands;
1949 operands.reserve(operands.size());
1950 if (operandSegmentSpecObj.is_none()) {
1951 // Non-sized operand unpacking.
1952 for (const auto &it : llvm::enumerate(operandList)) {
1953 try {
1954 operands.push_back(getOpResultOrValue(it.value()));
1955 } catch (nb::builtin_exception &err) {
1956 throw nb::value_error((llvm::Twine("Operand ") +
1957 llvm::Twine(it.index()) + " of operation \"" +
1958 name + "\" must be a Value (" + err.what() + ")")
1959 .str()
1960 .c_str());
1963 } else {
1964 // Sized operand unpacking.
1965 auto operandSegmentSpec = nb::cast<std::vector<int>>(operandSegmentSpecObj);
1966 if (operandSegmentSpec.size() != operandList.size()) {
1967 throw nb::value_error((llvm::Twine("Operation \"") + name +
1968 "\" requires " +
1969 llvm::Twine(operandSegmentSpec.size()) +
1970 "operand segments but was provided " +
1971 llvm::Twine(operandList.size()))
1972 .str()
1973 .c_str());
1975 operandSegmentLengths.reserve(operandList.size());
1976 for (const auto &it :
1977 llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1978 int segmentSpec = std::get<1>(it.value());
1979 if (segmentSpec == 1 || segmentSpec == 0) {
1980 // Unpack unary element.
1981 auto &operand = std::get<0>(it.value());
1982 if (!operand.is_none()) {
1983 try {
1985 operands.push_back(getOpResultOrValue(operand));
1986 } catch (nb::builtin_exception &err) {
1987 throw nb::value_error((llvm::Twine("Operand ") +
1988 llvm::Twine(it.index()) +
1989 " of operation \"" + name +
1990 "\" must be a Value (" + err.what() + ")")
1991 .str()
1992 .c_str());
1995 operandSegmentLengths.push_back(1);
1996 } else if (segmentSpec == 0) {
1997 // Allowed to be optional.
1998 operandSegmentLengths.push_back(0);
1999 } else {
2000 throw nb::value_error(
2001 (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
2002 " of operation \"" + name +
2003 "\" must be a Value (was None and operand is not optional)")
2004 .str()
2005 .c_str());
2007 } else if (segmentSpec == -1) {
2008 // Unpack sequence by appending.
2009 try {
2010 if (std::get<0>(it.value()).is_none()) {
2011 // Treat it as an empty list.
2012 operandSegmentLengths.push_back(0);
2013 } else {
2014 // Unpack the list.
2015 auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
2016 for (nb::handle segmentItem : segment) {
2017 operands.push_back(getOpResultOrValue(segmentItem));
2019 operandSegmentLengths.push_back(nb::len(segment));
2021 } catch (std::exception &err) {
2022 // NOTE: Sloppy to be using a catch-all here, but there are at least
2023 // three different unrelated exceptions that can be thrown in the
2024 // above "casts". Just keep the scope above small and catch them all.
2025 throw nb::value_error((llvm::Twine("Operand ") +
2026 llvm::Twine(it.index()) + " of operation \"" +
2027 name + "\" must be a Sequence of Values (" +
2028 err.what() + ")")
2029 .str()
2030 .c_str());
2032 } else {
2033 throw nb::value_error("Unexpected segment spec");
2038 // Merge operand/result segment lengths into attributes if needed.
2039 if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
2040 // Dup.
2041 if (attributes) {
2042 attributes = nb::dict(*attributes);
2043 } else {
2044 attributes = nb::dict();
2046 if (attributes->contains("resultSegmentSizes") ||
2047 attributes->contains("operandSegmentSizes")) {
2048 throw nb::value_error("Manually setting a 'resultSegmentSizes' or "
2049 "'operandSegmentSizes' attribute is unsupported. "
2050 "Use Operation.create for such low-level access.");
2053 // Add resultSegmentSizes attribute.
2054 if (!resultSegmentLengths.empty()) {
2055 MlirAttribute segmentLengthAttr =
2056 mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(),
2057 resultSegmentLengths.data());
2058 (*attributes)["resultSegmentSizes"] =
2059 PyAttribute(context, segmentLengthAttr);
2062 // Add operandSegmentSizes attribute.
2063 if (!operandSegmentLengths.empty()) {
2064 MlirAttribute segmentLengthAttr =
2065 mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(),
2066 operandSegmentLengths.data());
2067 (*attributes)["operandSegmentSizes"] =
2068 PyAttribute(context, segmentLengthAttr);
2072 // Delegate to create.
2073 return PyOperation::create(name,
2074 /*results=*/std::move(resultTypes),
2075 /*operands=*/std::move(operands),
2076 /*attributes=*/std::move(attributes),
2077 /*successors=*/std::move(successors),
2078 /*regions=*/*regions, location, maybeIp,
2079 !resultTypeList);
2082 nb::object PyOpView::constructDerived(const nb::object &cls,
2083 const nb::object &operation) {
2084 nb::handle opViewType = nb::type<PyOpView>();
2085 nb::object instance = cls.attr("__new__")(cls);
2086 opViewType.attr("__init__")(instance, operation);
2087 return instance;
2090 PyOpView::PyOpView(const nb::object &operationObject)
2091 // Casting through the PyOperationBase base-class and then back to the
2092 // Operation lets us accept any PyOperationBase subclass.
2093 : operation(nb::cast<PyOperationBase &>(operationObject).getOperation()),
2094 operationObject(operation.getRef().getObject()) {}
2096 //------------------------------------------------------------------------------
2097 // PyInsertionPoint.
2098 //------------------------------------------------------------------------------
2100 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
2102 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
2103 : refOperation(beforeOperationBase.getOperation().getRef()),
2104 block((*refOperation)->getBlock()) {}
2106 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
2107 PyOperation &operation = operationBase.getOperation();
2108 if (operation.isAttached())
2109 throw nb::value_error(
2110 "Attempt to insert operation that is already attached");
2111 block.getParentOperation()->checkValid();
2112 MlirOperation beforeOp = {nullptr};
2113 if (refOperation) {
2114 // Insert before operation.
2115 (*refOperation)->checkValid();
2116 beforeOp = (*refOperation)->get();
2117 } else {
2118 // Insert at end (before null) is only valid if the block does not
2119 // already end in a known terminator (violating this will cause assertion
2120 // failures later).
2121 if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
2122 throw nb::index_error("Cannot insert operation at the end of a block "
2123 "that already has a terminator. Did you mean to "
2124 "use 'InsertionPoint.at_block_terminator(block)' "
2125 "versus 'InsertionPoint(block)'?");
2128 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
2129 operation.setAttached();
2132 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
2133 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
2134 if (mlirOperationIsNull(firstOp)) {
2135 // Just insert at end.
2136 return PyInsertionPoint(block);
2139 // Insert before first op.
2140 PyOperationRef firstOpRef = PyOperation::forOperation(
2141 block.getParentOperation()->getContext(), firstOp);
2142 return PyInsertionPoint{block, std::move(firstOpRef)};
2145 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
2146 MlirOperation terminator = mlirBlockGetTerminator(block.get());
2147 if (mlirOperationIsNull(terminator))
2148 throw nb::value_error("Block has no terminator");
2149 PyOperationRef terminatorOpRef = PyOperation::forOperation(
2150 block.getParentOperation()->getContext(), terminator);
2151 return PyInsertionPoint{block, std::move(terminatorOpRef)};
2154 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
2155 return PyThreadContextEntry::pushInsertionPoint(insertPoint);
2158 void PyInsertionPoint::contextExit(const nb::object &excType,
2159 const nb::object &excVal,
2160 const nb::object &excTb) {
2161 PyThreadContextEntry::popInsertionPoint(*this);
2164 //------------------------------------------------------------------------------
2165 // PyAttribute.
2166 //------------------------------------------------------------------------------
2168 bool PyAttribute::operator==(const PyAttribute &other) const {
2169 return mlirAttributeEqual(attr, other.attr);
2172 nb::object PyAttribute::getCapsule() {
2173 return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
2176 PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
2177 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
2178 if (mlirAttributeIsNull(rawAttr))
2179 throw nb::python_error();
2180 return PyAttribute(
2181 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
2184 //------------------------------------------------------------------------------
2185 // PyNamedAttribute.
2186 //------------------------------------------------------------------------------
2188 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
2189 : ownedName(new std::string(std::move(ownedName))) {
2190 namedAttr = mlirNamedAttributeGet(
2191 mlirIdentifierGet(mlirAttributeGetContext(attr),
2192 toMlirStringRef(*this->ownedName)),
2193 attr);
2196 //------------------------------------------------------------------------------
2197 // PyType.
2198 //------------------------------------------------------------------------------
2200 bool PyType::operator==(const PyType &other) const {
2201 return mlirTypeEqual(type, other.type);
2204 nb::object PyType::getCapsule() {
2205 return nb::steal<nb::object>(mlirPythonTypeToCapsule(*this));
2208 PyType PyType::createFromCapsule(nb::object capsule) {
2209 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
2210 if (mlirTypeIsNull(rawType))
2211 throw nb::python_error();
2212 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
2213 rawType);
2216 //------------------------------------------------------------------------------
2217 // PyTypeID.
2218 //------------------------------------------------------------------------------
2220 nb::object PyTypeID::getCapsule() {
2221 return nb::steal<nb::object>(mlirPythonTypeIDToCapsule(*this));
2224 PyTypeID PyTypeID::createFromCapsule(nb::object capsule) {
2225 MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr());
2226 if (mlirTypeIDIsNull(mlirTypeID))
2227 throw nb::python_error();
2228 return PyTypeID(mlirTypeID);
2230 bool PyTypeID::operator==(const PyTypeID &other) const {
2231 return mlirTypeIDEqual(typeID, other.typeID);
2234 //------------------------------------------------------------------------------
2235 // PyValue and subclasses.
2236 //------------------------------------------------------------------------------
2238 nb::object PyValue::getCapsule() {
2239 return nb::steal<nb::object>(mlirPythonValueToCapsule(get()));
2242 nb::object PyValue::maybeDownCast() {
2243 MlirType type = mlirValueGetType(get());
2244 MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
2245 assert(!mlirTypeIDIsNull(mlirTypeID) &&
2246 "mlirTypeID was expected to be non-null.");
2247 std::optional<nb::callable> valueCaster =
2248 PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
2249 // nb::rv_policy::move means use std::move to move the return value
2250 // contents into a new instance that will be owned by Python.
2251 nb::object thisObj = nb::cast(this, nb::rv_policy::move);
2252 if (!valueCaster)
2253 return thisObj;
2254 return valueCaster.value()(thisObj);
2257 PyValue PyValue::createFromCapsule(nb::object capsule) {
2258 MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
2259 if (mlirValueIsNull(value))
2260 throw nb::python_error();
2261 MlirOperation owner;
2262 if (mlirValueIsAOpResult(value))
2263 owner = mlirOpResultGetOwner(value);
2264 if (mlirValueIsABlockArgument(value))
2265 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
2266 if (mlirOperationIsNull(owner))
2267 throw nb::python_error();
2268 MlirContext ctx = mlirOperationGetContext(owner);
2269 PyOperationRef ownerRef =
2270 PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
2271 return PyValue(ownerRef, value);
2274 //------------------------------------------------------------------------------
2275 // PySymbolTable.
2276 //------------------------------------------------------------------------------
2278 PySymbolTable::PySymbolTable(PyOperationBase &operation)
2279 : operation(operation.getOperation().getRef()) {
2280 symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
2281 if (mlirSymbolTableIsNull(symbolTable)) {
2282 throw nb::type_error("Operation is not a Symbol Table.");
2286 nb::object PySymbolTable::dunderGetItem(const std::string &name) {
2287 operation->checkValid();
2288 MlirOperation symbol = mlirSymbolTableLookup(
2289 symbolTable, mlirStringRefCreate(name.data(), name.length()));
2290 if (mlirOperationIsNull(symbol))
2291 throw nb::key_error(
2292 ("Symbol '" + name + "' not in the symbol table.").c_str());
2294 return PyOperation::forOperation(operation->getContext(), symbol,
2295 operation.getObject())
2296 ->createOpView();
2299 void PySymbolTable::erase(PyOperationBase &symbol) {
2300 operation->checkValid();
2301 symbol.getOperation().checkValid();
2302 mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
2303 // The operation is also erased, so we must invalidate it. There may be Python
2304 // references to this operation so we don't want to delete it from the list of
2305 // live operations here.
2306 symbol.getOperation().valid = false;
2309 void PySymbolTable::dunderDel(const std::string &name) {
2310 nb::object operation = dunderGetItem(name);
2311 erase(nb::cast<PyOperationBase &>(operation));
2314 MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) {
2315 operation->checkValid();
2316 symbol.getOperation().checkValid();
2317 MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
2318 symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
2319 if (mlirAttributeIsNull(symbolAttr))
2320 throw nb::value_error("Expected operation to have a symbol name.");
2321 return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get());
2324 MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
2325 // Op must already be a symbol.
2326 PyOperation &operation = symbol.getOperation();
2327 operation.checkValid();
2328 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2329 MlirAttribute existingNameAttr =
2330 mlirOperationGetAttributeByName(operation.get(), attrName);
2331 if (mlirAttributeIsNull(existingNameAttr))
2332 throw nb::value_error("Expected operation to have a symbol name.");
2333 return existingNameAttr;
2336 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
2337 const std::string &name) {
2338 // Op must already be a symbol.
2339 PyOperation &operation = symbol.getOperation();
2340 operation.checkValid();
2341 MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
2342 MlirAttribute existingNameAttr =
2343 mlirOperationGetAttributeByName(operation.get(), attrName);
2344 if (mlirAttributeIsNull(existingNameAttr))
2345 throw nb::value_error("Expected operation to have a symbol name.");
2346 MlirAttribute newNameAttr =
2347 mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
2348 mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
2351 MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
2352 PyOperation &operation = symbol.getOperation();
2353 operation.checkValid();
2354 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2355 MlirAttribute existingVisAttr =
2356 mlirOperationGetAttributeByName(operation.get(), attrName);
2357 if (mlirAttributeIsNull(existingVisAttr))
2358 throw nb::value_error("Expected operation to have a symbol visibility.");
2359 return existingVisAttr;
2362 void PySymbolTable::setVisibility(PyOperationBase &symbol,
2363 const std::string &visibility) {
2364 if (visibility != "public" && visibility != "private" &&
2365 visibility != "nested")
2366 throw nb::value_error(
2367 "Expected visibility to be 'public', 'private' or 'nested'");
2368 PyOperation &operation = symbol.getOperation();
2369 operation.checkValid();
2370 MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
2371 MlirAttribute existingVisAttr =
2372 mlirOperationGetAttributeByName(operation.get(), attrName);
2373 if (mlirAttributeIsNull(existingVisAttr))
2374 throw nb::value_error("Expected operation to have a symbol visibility.");
2375 MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
2376 toMlirStringRef(visibility));
2377 mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
2380 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
2381 const std::string &newSymbol,
2382 PyOperationBase &from) {
2383 PyOperation &fromOperation = from.getOperation();
2384 fromOperation.checkValid();
2385 if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
2386 toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
2387 from.getOperation())))
2389 throw nb::value_error("Symbol rename failed");
2392 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
2393 bool allSymUsesVisible,
2394 nb::object callback) {
2395 PyOperation &fromOperation = from.getOperation();
2396 fromOperation.checkValid();
2397 struct UserData {
2398 PyMlirContextRef context;
2399 nb::object callback;
2400 bool gotException;
2401 std::string exceptionWhat;
2402 nb::object exceptionType;
2404 UserData userData{
2405 fromOperation.getContext(), std::move(callback), false, {}, {}};
2406 mlirSymbolTableWalkSymbolTables(
2407 fromOperation.get(), allSymUsesVisible,
2408 [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
2409 UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
2410 auto pyFoundOp =
2411 PyOperation::forOperation(calleeUserData->context, foundOp);
2412 if (calleeUserData->gotException)
2413 return;
2414 try {
2415 calleeUserData->callback(pyFoundOp.getObject(), isVisible);
2416 } catch (nb::python_error &e) {
2417 calleeUserData->gotException = true;
2418 calleeUserData->exceptionWhat = e.what();
2419 calleeUserData->exceptionType = nb::borrow(e.type());
2422 static_cast<void *>(&userData));
2423 if (userData.gotException) {
2424 std::string message("Exception raised in callback: ");
2425 message.append(userData.exceptionWhat);
2426 throw std::runtime_error(message);
2430 namespace {
2432 /// Python wrapper for MlirBlockArgument.
2433 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
2434 public:
2435 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
2436 static constexpr const char *pyClassName = "BlockArgument";
2437 using PyConcreteValue::PyConcreteValue;
2439 static void bindDerived(ClassTy &c) {
2440 c.def_prop_ro("owner", [](PyBlockArgument &self) {
2441 return PyBlock(self.getParentOperation(),
2442 mlirBlockArgumentGetOwner(self.get()));
2444 c.def_prop_ro("arg_number", [](PyBlockArgument &self) {
2445 return mlirBlockArgumentGetArgNumber(self.get());
2447 c.def(
2448 "set_type",
2449 [](PyBlockArgument &self, PyType type) {
2450 return mlirBlockArgumentSetType(self.get(), type);
2452 nb::arg("type"));
2456 /// A list of block arguments. Internally, these are stored as consecutive
2457 /// elements, random access is cheap. The argument list is associated with the
2458 /// operation that contains the block (detached blocks are not allowed in
2459 /// Python bindings) and extends its lifetime.
2460 class PyBlockArgumentList
2461 : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
2462 public:
2463 static constexpr const char *pyClassName = "BlockArgumentList";
2464 using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
2466 PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
2467 intptr_t startIndex = 0, intptr_t length = -1,
2468 intptr_t step = 1)
2469 : Sliceable(startIndex,
2470 length == -1 ? mlirBlockGetNumArguments(block) : length,
2471 step),
2472 operation(std::move(operation)), block(block) {}
2474 static void bindDerived(ClassTy &c) {
2475 c.def_prop_ro("types", [](PyBlockArgumentList &self) {
2476 return getValueTypes(self, self.operation->getContext());
2480 private:
2481 /// Give the parent CRTP class access to hook implementations below.
2482 friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2484 /// Returns the number of arguments in the list.
2485 intptr_t getRawNumElements() {
2486 operation->checkValid();
2487 return mlirBlockGetNumArguments(block);
2490 /// Returns `pos`-the element in the list.
2491 PyBlockArgument getRawElement(intptr_t pos) {
2492 MlirValue argument = mlirBlockGetArgument(block, pos);
2493 return PyBlockArgument(operation, argument);
2496 /// Returns a sublist of this list.
2497 PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2498 intptr_t step) {
2499 return PyBlockArgumentList(operation, block, startIndex, length, step);
2502 PyOperationRef operation;
2503 MlirBlock block;
2506 /// A list of operation operands. Internally, these are stored as consecutive
2507 /// elements, random access is cheap. The (returned) operand list is associated
2508 /// with the operation whose operands these are, and thus extends the lifetime
2509 /// of this operation.
2510 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2511 public:
2512 static constexpr const char *pyClassName = "OpOperandList";
2513 using SliceableT = Sliceable<PyOpOperandList, PyValue>;
2515 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2516 intptr_t length = -1, intptr_t step = 1)
2517 : Sliceable(startIndex,
2518 length == -1 ? mlirOperationGetNumOperands(operation->get())
2519 : length,
2520 step),
2521 operation(operation) {}
2523 void dunderSetItem(intptr_t index, PyValue value) {
2524 index = wrapIndex(index);
2525 mlirOperationSetOperand(operation->get(), index, value.get());
2528 static void bindDerived(ClassTy &c) {
2529 c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2532 private:
2533 /// Give the parent CRTP class access to hook implementations below.
2534 friend class Sliceable<PyOpOperandList, PyValue>;
2536 intptr_t getRawNumElements() {
2537 operation->checkValid();
2538 return mlirOperationGetNumOperands(operation->get());
2541 PyValue getRawElement(intptr_t pos) {
2542 MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2543 MlirOperation owner;
2544 if (mlirValueIsAOpResult(operand))
2545 owner = mlirOpResultGetOwner(operand);
2546 else if (mlirValueIsABlockArgument(operand))
2547 owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2548 else
2549 assert(false && "Value must be an block arg or op result.");
2550 PyOperationRef pyOwner =
2551 PyOperation::forOperation(operation->getContext(), owner);
2552 return PyValue(pyOwner, operand);
2555 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2556 return PyOpOperandList(operation, startIndex, length, step);
2559 PyOperationRef operation;
2562 /// A list of operation successors. Internally, these are stored as consecutive
2563 /// elements, random access is cheap. The (returned) successor list is
2564 /// associated with the operation whose successors these are, and thus extends
2565 /// the lifetime of this operation.
2566 class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
2567 public:
2568 static constexpr const char *pyClassName = "OpSuccessors";
2570 PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0,
2571 intptr_t length = -1, intptr_t step = 1)
2572 : Sliceable(startIndex,
2573 length == -1 ? mlirOperationGetNumSuccessors(operation->get())
2574 : length,
2575 step),
2576 operation(operation) {}
2578 void dunderSetItem(intptr_t index, PyBlock block) {
2579 index = wrapIndex(index);
2580 mlirOperationSetSuccessor(operation->get(), index, block.get());
2583 static void bindDerived(ClassTy &c) {
2584 c.def("__setitem__", &PyOpSuccessors::dunderSetItem);
2587 private:
2588 /// Give the parent CRTP class access to hook implementations below.
2589 friend class Sliceable<PyOpSuccessors, PyBlock>;
2591 intptr_t getRawNumElements() {
2592 operation->checkValid();
2593 return mlirOperationGetNumSuccessors(operation->get());
2596 PyBlock getRawElement(intptr_t pos) {
2597 MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos);
2598 return PyBlock(operation, block);
2601 PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2602 return PyOpSuccessors(operation, startIndex, length, step);
2605 PyOperationRef operation;
2608 /// A list of operation attributes. Can be indexed by name, producing
2609 /// attributes, or by index, producing named attributes.
2610 class PyOpAttributeMap {
2611 public:
2612 PyOpAttributeMap(PyOperationRef operation)
2613 : operation(std::move(operation)) {}
2615 MlirAttribute dunderGetItemNamed(const std::string &name) {
2616 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2617 toMlirStringRef(name));
2618 if (mlirAttributeIsNull(attr)) {
2619 throw nb::key_error("attempt to access a non-existent attribute");
2621 return attr;
2624 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2625 if (index < 0 || index >= dunderLen()) {
2626 throw nb::index_error("attempt to access out of bounds attribute");
2628 MlirNamedAttribute namedAttr =
2629 mlirOperationGetAttribute(operation->get(), index);
2630 return PyNamedAttribute(
2631 namedAttr.attribute,
2632 std::string(mlirIdentifierStr(namedAttr.name).data,
2633 mlirIdentifierStr(namedAttr.name).length));
2636 void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2637 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2638 attr);
2641 void dunderDelItem(const std::string &name) {
2642 int removed = mlirOperationRemoveAttributeByName(operation->get(),
2643 toMlirStringRef(name));
2644 if (!removed)
2645 throw nb::key_error("attempt to delete a non-existent attribute");
2648 intptr_t dunderLen() {
2649 return mlirOperationGetNumAttributes(operation->get());
2652 bool dunderContains(const std::string &name) {
2653 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2654 operation->get(), toMlirStringRef(name)));
2657 static void bind(nb::module_ &m) {
2658 nb::class_<PyOpAttributeMap>(m, "OpAttributeMap")
2659 .def("__contains__", &PyOpAttributeMap::dunderContains)
2660 .def("__len__", &PyOpAttributeMap::dunderLen)
2661 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2662 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2663 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2664 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2667 private:
2668 PyOperationRef operation;
2671 } // namespace
2673 //------------------------------------------------------------------------------
2674 // Populates the core exports of the 'ir' submodule.
2675 //------------------------------------------------------------------------------
2677 void mlir::python::populateIRCore(nb::module_ &m) {
2678 // disable leak warnings which tend to be false positives.
2679 nb::set_leak_warnings(false);
2680 //----------------------------------------------------------------------------
2681 // Enums.
2682 //----------------------------------------------------------------------------
2683 nb::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity")
2684 .value("ERROR", MlirDiagnosticError)
2685 .value("WARNING", MlirDiagnosticWarning)
2686 .value("NOTE", MlirDiagnosticNote)
2687 .value("REMARK", MlirDiagnosticRemark);
2689 nb::enum_<MlirWalkOrder>(m, "WalkOrder")
2690 .value("PRE_ORDER", MlirWalkPreOrder)
2691 .value("POST_ORDER", MlirWalkPostOrder);
2693 nb::enum_<MlirWalkResult>(m, "WalkResult")
2694 .value("ADVANCE", MlirWalkResultAdvance)
2695 .value("INTERRUPT", MlirWalkResultInterrupt)
2696 .value("SKIP", MlirWalkResultSkip);
2698 //----------------------------------------------------------------------------
2699 // Mapping of Diagnostics.
2700 //----------------------------------------------------------------------------
2701 nb::class_<PyDiagnostic>(m, "Diagnostic")
2702 .def_prop_ro("severity", &PyDiagnostic::getSeverity)
2703 .def_prop_ro("location", &PyDiagnostic::getLocation)
2704 .def_prop_ro("message", &PyDiagnostic::getMessage)
2705 .def_prop_ro("notes", &PyDiagnostic::getNotes)
2706 .def("__str__", [](PyDiagnostic &self) -> nb::str {
2707 if (!self.isValid())
2708 return nb::str("<Invalid Diagnostic>");
2709 return self.getMessage();
2712 nb::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo")
2713 .def("__init__",
2714 [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) {
2715 new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo());
2717 .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity)
2718 .def_ro("location", &PyDiagnostic::DiagnosticInfo::location)
2719 .def_ro("message", &PyDiagnostic::DiagnosticInfo::message)
2720 .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes)
2721 .def("__str__",
2722 [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
2724 nb::class_<PyDiagnosticHandler>(m, "DiagnosticHandler")
2725 .def("detach", &PyDiagnosticHandler::detach)
2726 .def_prop_ro("attached", &PyDiagnosticHandler::isAttached)
2727 .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError)
2728 .def("__enter__", &PyDiagnosticHandler::contextEnter)
2729 .def("__exit__", &PyDiagnosticHandler::contextExit,
2730 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
2731 nb::arg("traceback").none());
2733 //----------------------------------------------------------------------------
2734 // Mapping of MlirContext.
2735 // Note that this is exported as _BaseContext. The containing, Python level
2736 // __init__.py will subclass it with site-specific functionality and set a
2737 // "Context" attribute on this module.
2738 //----------------------------------------------------------------------------
2739 nb::class_<PyMlirContext>(m, "_BaseContext")
2740 .def("__init__",
2741 [](PyMlirContext &self) {
2742 MlirContext context = mlirContextCreateWithThreading(false);
2743 new (&self) PyMlirContext(context);
2745 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2746 .def("_get_context_again",
2747 [](PyMlirContext &self) {
2748 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2749 return ref.releaseObject();
2751 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2752 .def("_get_live_operation_objects",
2753 &PyMlirContext::getLiveOperationObjects)
2754 .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2755 .def("_clear_live_operations_inside",
2756 nb::overload_cast<MlirOperation>(
2757 &PyMlirContext::clearOperationsInside))
2758 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2759 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
2760 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2761 .def("__enter__", &PyMlirContext::contextEnter)
2762 .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
2763 nb::arg("exc_value").none(), nb::arg("traceback").none())
2764 .def_prop_ro_static(
2765 "current",
2766 [](nb::object & /*class*/) {
2767 auto *context = PyThreadContextEntry::getDefaultContext();
2768 if (!context)
2769 return nb::none();
2770 return nb::cast(context);
2772 "Gets the Context bound to the current thread or raises ValueError")
2773 .def_prop_ro(
2774 "dialects",
2775 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2776 "Gets a container for accessing dialects by name")
2777 .def_prop_ro(
2778 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2779 "Alias for 'dialect'")
2780 .def(
2781 "get_dialect_descriptor",
2782 [=](PyMlirContext &self, std::string &name) {
2783 MlirDialect dialect = mlirContextGetOrLoadDialect(
2784 self.get(), {name.data(), name.size()});
2785 if (mlirDialectIsNull(dialect)) {
2786 throw nb::value_error(
2787 (Twine("Dialect '") + name + "' not found").str().c_str());
2789 return PyDialectDescriptor(self.getRef(), dialect);
2791 nb::arg("dialect_name"),
2792 "Gets or loads a dialect by name, returning its descriptor object")
2793 .def_prop_rw(
2794 "allow_unregistered_dialects",
2795 [](PyMlirContext &self) -> bool {
2796 return mlirContextGetAllowUnregisteredDialects(self.get());
2798 [](PyMlirContext &self, bool value) {
2799 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2801 .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2802 nb::arg("callback"),
2803 "Attaches a diagnostic handler that will receive callbacks")
2804 .def(
2805 "enable_multithreading",
2806 [](PyMlirContext &self, bool enable) {
2807 mlirContextEnableMultithreading(self.get(), enable);
2809 nb::arg("enable"))
2810 .def(
2811 "is_registered_operation",
2812 [](PyMlirContext &self, std::string &name) {
2813 return mlirContextIsRegisteredOperation(
2814 self.get(), MlirStringRef{name.data(), name.size()});
2816 nb::arg("operation_name"))
2817 .def(
2818 "append_dialect_registry",
2819 [](PyMlirContext &self, PyDialectRegistry &registry) {
2820 mlirContextAppendDialectRegistry(self.get(), registry);
2822 nb::arg("registry"))
2823 .def_prop_rw("emit_error_diagnostics", nullptr,
2824 &PyMlirContext::setEmitErrorDiagnostics,
2825 "Emit error diagnostics to diagnostic handlers. By default "
2826 "error diagnostics are captured and reported through "
2827 "MLIRError exceptions.")
2828 .def("load_all_available_dialects", [](PyMlirContext &self) {
2829 mlirContextLoadAllAvailableDialects(self.get());
2832 //----------------------------------------------------------------------------
2833 // Mapping of PyDialectDescriptor
2834 //----------------------------------------------------------------------------
2835 nb::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2836 .def_prop_ro("namespace",
2837 [](PyDialectDescriptor &self) {
2838 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2839 return nb::str(ns.data, ns.length);
2841 .def("__repr__", [](PyDialectDescriptor &self) {
2842 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2843 std::string repr("<DialectDescriptor ");
2844 repr.append(ns.data, ns.length);
2845 repr.append(">");
2846 return repr;
2849 //----------------------------------------------------------------------------
2850 // Mapping of PyDialects
2851 //----------------------------------------------------------------------------
2852 nb::class_<PyDialects>(m, "Dialects")
2853 .def("__getitem__",
2854 [=](PyDialects &self, std::string keyName) {
2855 MlirDialect dialect =
2856 self.getDialectForKey(keyName, /*attrError=*/false);
2857 nb::object descriptor =
2858 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2859 return createCustomDialectWrapper(keyName, std::move(descriptor));
2861 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2862 MlirDialect dialect =
2863 self.getDialectForKey(attrName, /*attrError=*/true);
2864 nb::object descriptor =
2865 nb::cast(PyDialectDescriptor{self.getContext(), dialect});
2866 return createCustomDialectWrapper(attrName, std::move(descriptor));
2869 //----------------------------------------------------------------------------
2870 // Mapping of PyDialect
2871 //----------------------------------------------------------------------------
2872 nb::class_<PyDialect>(m, "Dialect")
2873 .def(nb::init<nb::object>(), nb::arg("descriptor"))
2874 .def_prop_ro("descriptor",
2875 [](PyDialect &self) { return self.getDescriptor(); })
2876 .def("__repr__", [](nb::object self) {
2877 auto clazz = self.attr("__class__");
2878 return nb::str("<Dialect ") +
2879 self.attr("descriptor").attr("namespace") + nb::str(" (class ") +
2880 clazz.attr("__module__") + nb::str(".") +
2881 clazz.attr("__name__") + nb::str(")>");
2884 //----------------------------------------------------------------------------
2885 // Mapping of PyDialectRegistry
2886 //----------------------------------------------------------------------------
2887 nb::class_<PyDialectRegistry>(m, "DialectRegistry")
2888 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
2889 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2890 .def(nb::init<>());
2892 //----------------------------------------------------------------------------
2893 // Mapping of Location
2894 //----------------------------------------------------------------------------
2895 nb::class_<PyLocation>(m, "Location")
2896 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2897 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2898 .def("__enter__", &PyLocation::contextEnter)
2899 .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
2900 nb::arg("exc_value").none(), nb::arg("traceback").none())
2901 .def("__eq__",
2902 [](PyLocation &self, PyLocation &other) -> bool {
2903 return mlirLocationEqual(self, other);
2905 .def("__eq__", [](PyLocation &self, nb::object other) { return false; })
2906 .def_prop_ro_static(
2907 "current",
2908 [](nb::object & /*class*/) {
2909 auto *loc = PyThreadContextEntry::getDefaultLocation();
2910 if (!loc)
2911 throw nb::value_error("No current Location");
2912 return loc;
2914 "Gets the Location bound to the current thread or raises ValueError")
2915 .def_static(
2916 "unknown",
2917 [](DefaultingPyMlirContext context) {
2918 return PyLocation(context->getRef(),
2919 mlirLocationUnknownGet(context->get()));
2921 nb::arg("context").none() = nb::none(),
2922 "Gets a Location representing an unknown location")
2923 .def_static(
2924 "callsite",
2925 [](PyLocation callee, const std::vector<PyLocation> &frames,
2926 DefaultingPyMlirContext context) {
2927 if (frames.empty())
2928 throw nb::value_error("No caller frames provided");
2929 MlirLocation caller = frames.back().get();
2930 for (const PyLocation &frame :
2931 llvm::reverse(llvm::ArrayRef(frames).drop_back()))
2932 caller = mlirLocationCallSiteGet(frame.get(), caller);
2933 return PyLocation(context->getRef(),
2934 mlirLocationCallSiteGet(callee.get(), caller));
2936 nb::arg("callee"), nb::arg("frames"),
2937 nb::arg("context").none() = nb::none(),
2938 kContextGetCallSiteLocationDocstring)
2939 .def_static(
2940 "file",
2941 [](std::string filename, int line, int col,
2942 DefaultingPyMlirContext context) {
2943 return PyLocation(
2944 context->getRef(),
2945 mlirLocationFileLineColGet(
2946 context->get(), toMlirStringRef(filename), line, col));
2948 nb::arg("filename"), nb::arg("line"), nb::arg("col"),
2949 nb::arg("context").none() = nb::none(),
2950 kContextGetFileLocationDocstring)
2951 .def_static(
2952 "file",
2953 [](std::string filename, int startLine, int startCol, int endLine,
2954 int endCol, DefaultingPyMlirContext context) {
2955 return PyLocation(context->getRef(),
2956 mlirLocationFileLineColRangeGet(
2957 context->get(), toMlirStringRef(filename),
2958 startLine, startCol, endLine, endCol));
2960 nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
2961 nb::arg("end_line"), nb::arg("end_col"),
2962 nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
2963 .def_static(
2964 "fused",
2965 [](const std::vector<PyLocation> &pyLocations,
2966 std::optional<PyAttribute> metadata,
2967 DefaultingPyMlirContext context) {
2968 llvm::SmallVector<MlirLocation, 4> locations;
2969 locations.reserve(pyLocations.size());
2970 for (auto &pyLocation : pyLocations)
2971 locations.push_back(pyLocation.get());
2972 MlirLocation location = mlirLocationFusedGet(
2973 context->get(), locations.size(), locations.data(),
2974 metadata ? metadata->get() : MlirAttribute{0});
2975 return PyLocation(context->getRef(), location);
2977 nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
2978 nb::arg("context").none() = nb::none(),
2979 kContextGetFusedLocationDocstring)
2980 .def_static(
2981 "name",
2982 [](std::string name, std::optional<PyLocation> childLoc,
2983 DefaultingPyMlirContext context) {
2984 return PyLocation(
2985 context->getRef(),
2986 mlirLocationNameGet(
2987 context->get(), toMlirStringRef(name),
2988 childLoc ? childLoc->get()
2989 : mlirLocationUnknownGet(context->get())));
2991 nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
2992 nb::arg("context").none() = nb::none(),
2993 kContextGetNameLocationDocString)
2994 .def_static(
2995 "from_attr",
2996 [](PyAttribute &attribute, DefaultingPyMlirContext context) {
2997 return PyLocation(context->getRef(),
2998 mlirLocationFromAttribute(attribute));
3000 nb::arg("attribute"), nb::arg("context").none() = nb::none(),
3001 "Gets a Location from a LocationAttr")
3002 .def_prop_ro(
3003 "context",
3004 [](PyLocation &self) { return self.getContext().getObject(); },
3005 "Context that owns the Location")
3006 .def_prop_ro(
3007 "attr",
3008 [](PyLocation &self) { return mlirLocationGetAttribute(self); },
3009 "Get the underlying LocationAttr")
3010 .def(
3011 "emit_error",
3012 [](PyLocation &self, std::string message) {
3013 mlirEmitError(self, message.c_str());
3015 nb::arg("message"), "Emits an error at this location")
3016 .def("__repr__", [](PyLocation &self) {
3017 PyPrintAccumulator printAccum;
3018 mlirLocationPrint(self, printAccum.getCallback(),
3019 printAccum.getUserData());
3020 return printAccum.join();
3023 //----------------------------------------------------------------------------
3024 // Mapping of Module
3025 //----------------------------------------------------------------------------
3026 nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
3027 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
3028 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
3029 .def_static(
3030 "parse",
3031 [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
3032 PyMlirContext::ErrorCapture errors(context->getRef());
3033 MlirModule module = mlirModuleCreateParse(
3034 context->get(), toMlirStringRef(moduleAsm));
3035 if (mlirModuleIsNull(module))
3036 throw MLIRError("Unable to parse module assembly", errors.take());
3037 return PyModule::forModule(module).releaseObject();
3039 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3040 kModuleParseDocstring)
3041 .def_static(
3042 "parse",
3043 [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
3044 PyMlirContext::ErrorCapture errors(context->getRef());
3045 MlirModule module = mlirModuleCreateParse(
3046 context->get(), toMlirStringRef(moduleAsm));
3047 if (mlirModuleIsNull(module))
3048 throw MLIRError("Unable to parse module assembly", errors.take());
3049 return PyModule::forModule(module).releaseObject();
3051 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3052 kModuleParseDocstring)
3053 .def_static(
3054 "parse",
3055 [](const std::filesystem::path &path,
3056 DefaultingPyMlirContext context) {
3057 PyMlirContext::ErrorCapture errors(context->getRef());
3058 MlirModule module = mlirModuleCreateParseFromFile(
3059 context->get(), toMlirStringRef(path.string()));
3060 if (mlirModuleIsNull(module))
3061 throw MLIRError("Unable to parse module assembly", errors.take());
3062 return PyModule::forModule(module).releaseObject();
3064 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3065 kModuleParseDocstring)
3066 .def_static(
3067 "create",
3068 [](DefaultingPyLocation loc) {
3069 MlirModule module = mlirModuleCreateEmpty(loc);
3070 return PyModule::forModule(module).releaseObject();
3072 nb::arg("loc").none() = nb::none(), "Creates an empty module")
3073 .def_prop_ro(
3074 "context",
3075 [](PyModule &self) { return self.getContext().getObject(); },
3076 "Context that created the Module")
3077 .def_prop_ro(
3078 "operation",
3079 [](PyModule &self) {
3080 return PyOperation::forOperation(self.getContext(),
3081 mlirModuleGetOperation(self.get()),
3082 self.getRef().releaseObject())
3083 .releaseObject();
3085 "Accesses the module as an operation")
3086 .def_prop_ro(
3087 "body",
3088 [](PyModule &self) {
3089 PyOperationRef moduleOp = PyOperation::forOperation(
3090 self.getContext(), mlirModuleGetOperation(self.get()),
3091 self.getRef().releaseObject());
3092 PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
3093 return returnBlock;
3095 "Return the block for this module")
3096 .def(
3097 "dump",
3098 [](PyModule &self) {
3099 mlirOperationDump(mlirModuleGetOperation(self.get()));
3101 kDumpDocstring)
3102 .def(
3103 "__str__",
3104 [](nb::object self) {
3105 // Defer to the operation's __str__.
3106 return self.attr("operation").attr("__str__")();
3108 kOperationStrDunderDocstring);
3110 //----------------------------------------------------------------------------
3111 // Mapping of Operation.
3112 //----------------------------------------------------------------------------
3113 nb::class_<PyOperationBase>(m, "_OperationBase")
3114 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
3115 [](PyOperationBase &self) {
3116 return self.getOperation().getCapsule();
3118 .def("__eq__",
3119 [](PyOperationBase &self, PyOperationBase &other) {
3120 return &self.getOperation() == &other.getOperation();
3122 .def("__eq__",
3123 [](PyOperationBase &self, nb::object other) { return false; })
3124 .def("__hash__",
3125 [](PyOperationBase &self) {
3126 return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
3128 .def_prop_ro("attributes",
3129 [](PyOperationBase &self) {
3130 return PyOpAttributeMap(self.getOperation().getRef());
3132 .def_prop_ro(
3133 "context",
3134 [](PyOperationBase &self) {
3135 PyOperation &concreteOperation = self.getOperation();
3136 concreteOperation.checkValid();
3137 return concreteOperation.getContext().getObject();
3139 "Context that owns the Operation")
3140 .def_prop_ro("name",
3141 [](PyOperationBase &self) {
3142 auto &concreteOperation = self.getOperation();
3143 concreteOperation.checkValid();
3144 MlirOperation operation = concreteOperation.get();
3145 MlirStringRef name =
3146 mlirIdentifierStr(mlirOperationGetName(operation));
3147 return nb::str(name.data, name.length);
3149 .def_prop_ro("operands",
3150 [](PyOperationBase &self) {
3151 return PyOpOperandList(self.getOperation().getRef());
3153 .def_prop_ro("regions",
3154 [](PyOperationBase &self) {
3155 return PyRegionList(self.getOperation().getRef());
3157 .def_prop_ro(
3158 "results",
3159 [](PyOperationBase &self) {
3160 return PyOpResultList(self.getOperation().getRef());
3162 "Returns the list of Operation results.")
3163 .def_prop_ro(
3164 "result",
3165 [](PyOperationBase &self) {
3166 auto &operation = self.getOperation();
3167 return PyOpResult(operation.getRef(), getUniqueResult(operation))
3168 .maybeDownCast();
3170 "Shortcut to get an op result if it has only one (throws an error "
3171 "otherwise).")
3172 .def_prop_ro(
3173 "location",
3174 [](PyOperationBase &self) {
3175 PyOperation &operation = self.getOperation();
3176 return PyLocation(operation.getContext(),
3177 mlirOperationGetLocation(operation.get()));
3179 "Returns the source location the operation was defined or derived "
3180 "from.")
3181 .def_prop_ro("parent",
3182 [](PyOperationBase &self) -> nb::object {
3183 auto parent = self.getOperation().getParentOperation();
3184 if (parent)
3185 return parent->getObject();
3186 return nb::none();
3188 .def(
3189 "__str__",
3190 [](PyOperationBase &self) {
3191 return self.getAsm(/*binary=*/false,
3192 /*largeElementsLimit=*/std::nullopt,
3193 /*enableDebugInfo=*/false,
3194 /*prettyDebugInfo=*/false,
3195 /*printGenericOpForm=*/false,
3196 /*useLocalScope=*/false,
3197 /*assumeVerified=*/false,
3198 /*skipRegions=*/false);
3200 "Returns the assembly form of the operation.")
3201 .def("print",
3202 nb::overload_cast<PyAsmState &, nb::object, bool>(
3203 &PyOperationBase::print),
3204 nb::arg("state"), nb::arg("file").none() = nb::none(),
3205 nb::arg("binary") = false, kOperationPrintStateDocstring)
3206 .def("print",
3207 nb::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
3208 bool, nb::object, bool, bool>(
3209 &PyOperationBase::print),
3210 // Careful: Lots of arguments must match up with print method.
3211 nb::arg("large_elements_limit").none() = nb::none(),
3212 nb::arg("enable_debug_info") = false,
3213 nb::arg("pretty_debug_info") = false,
3214 nb::arg("print_generic_op_form") = false,
3215 nb::arg("use_local_scope") = false,
3216 nb::arg("assume_verified") = false,
3217 nb::arg("file").none() = nb::none(), nb::arg("binary") = false,
3218 nb::arg("skip_regions") = false, kOperationPrintDocstring)
3219 .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"),
3220 nb::arg("desired_version").none() = nb::none(),
3221 kOperationPrintBytecodeDocstring)
3222 .def("get_asm", &PyOperationBase::getAsm,
3223 // Careful: Lots of arguments must match up with get_asm method.
3224 nb::arg("binary") = false,
3225 nb::arg("large_elements_limit").none() = nb::none(),
3226 nb::arg("enable_debug_info") = false,
3227 nb::arg("pretty_debug_info") = false,
3228 nb::arg("print_generic_op_form") = false,
3229 nb::arg("use_local_scope") = false,
3230 nb::arg("assume_verified") = false, nb::arg("skip_regions") = false,
3231 kOperationGetAsmDocstring)
3232 .def("verify", &PyOperationBase::verify,
3233 "Verify the operation. Raises MLIRError if verification fails, and "
3234 "returns true otherwise.")
3235 .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"),
3236 "Puts self immediately after the other operation in its parent "
3237 "block.")
3238 .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"),
3239 "Puts self immediately before the other operation in its parent "
3240 "block.")
3241 .def(
3242 "clone",
3243 [](PyOperationBase &self, nb::object ip) {
3244 return self.getOperation().clone(ip);
3246 nb::arg("ip").none() = nb::none())
3247 .def(
3248 "detach_from_parent",
3249 [](PyOperationBase &self) {
3250 PyOperation &operation = self.getOperation();
3251 operation.checkValid();
3252 if (!operation.isAttached())
3253 throw nb::value_error("Detached operation has no parent.");
3255 operation.detachFromParent();
3256 return operation.createOpView();
3258 "Detaches the operation from its parent block.")
3259 .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
3260 .def("walk", &PyOperationBase::walk, nb::arg("callback"),
3261 nb::arg("walk_order") = MlirWalkPostOrder);
3263 nb::class_<PyOperation, PyOperationBase>(m, "Operation")
3264 .def_static(
3265 "create",
3266 [](std::string_view name,
3267 std::optional<std::vector<PyType *>> results,
3268 std::optional<std::vector<PyValue *>> operands,
3269 std::optional<nb::dict> attributes,
3270 std::optional<std::vector<PyBlock *>> successors, int regions,
3271 DefaultingPyLocation location, const nb::object &maybeIp,
3272 bool inferType) {
3273 // Unpack/validate operands.
3274 llvm::SmallVector<MlirValue, 4> mlirOperands;
3275 if (operands) {
3276 mlirOperands.reserve(operands->size());
3277 for (PyValue *operand : *operands) {
3278 if (!operand)
3279 throw nb::value_error("operand value cannot be None");
3280 mlirOperands.push_back(operand->get());
3284 return PyOperation::create(name, results, mlirOperands, attributes,
3285 successors, regions, location, maybeIp,
3286 inferType);
3288 nb::arg("name"), nb::arg("results").none() = nb::none(),
3289 nb::arg("operands").none() = nb::none(),
3290 nb::arg("attributes").none() = nb::none(),
3291 nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
3292 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3293 nb::arg("infer_type") = false, kOperationCreateDocstring)
3294 .def_static(
3295 "parse",
3296 [](const std::string &sourceStr, const std::string &sourceName,
3297 DefaultingPyMlirContext context) {
3298 return PyOperation::parse(context->getRef(), sourceStr, sourceName)
3299 ->createOpView();
3301 nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "",
3302 nb::arg("context").none() = nb::none(),
3303 "Parses an operation. Supports both text assembly format and binary "
3304 "bytecode format.")
3305 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
3306 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
3307 .def_prop_ro("operation", [](nb::object self) { return self; })
3308 .def_prop_ro("opview", &PyOperation::createOpView)
3309 .def_prop_ro(
3310 "successors",
3311 [](PyOperationBase &self) {
3312 return PyOpSuccessors(self.getOperation().getRef());
3314 "Returns the list of Operation successors.");
3316 auto opViewClass =
3317 nb::class_<PyOpView, PyOperationBase>(m, "OpView")
3318 .def(nb::init<nb::object>(), nb::arg("operation"))
3319 .def(
3320 "__init__",
3321 [](PyOpView *self, std::string_view name,
3322 std::tuple<int, bool> opRegionSpec,
3323 nb::object operandSegmentSpecObj,
3324 nb::object resultSegmentSpecObj,
3325 std::optional<nb::list> resultTypeList, nb::list operandList,
3326 std::optional<nb::dict> attributes,
3327 std::optional<std::vector<PyBlock *>> successors,
3328 std::optional<int> regions, DefaultingPyLocation location,
3329 const nb::object &maybeIp) {
3330 new (self) PyOpView(PyOpView::buildGeneric(
3331 name, opRegionSpec, operandSegmentSpecObj,
3332 resultSegmentSpecObj, resultTypeList, operandList,
3333 attributes, successors, regions, location, maybeIp));
3335 nb::arg("name"), nb::arg("opRegionSpec"),
3336 nb::arg("operandSegmentSpecObj").none() = nb::none(),
3337 nb::arg("resultSegmentSpecObj").none() = nb::none(),
3338 nb::arg("results").none() = nb::none(),
3339 nb::arg("operands").none() = nb::none(),
3340 nb::arg("attributes").none() = nb::none(),
3341 nb::arg("successors").none() = nb::none(),
3342 nb::arg("regions").none() = nb::none(),
3343 nb::arg("loc").none() = nb::none(),
3344 nb::arg("ip").none() = nb::none())
3346 .def_prop_ro("operation", &PyOpView::getOperationObject)
3347 .def_prop_ro("opview", [](nb::object self) { return self; })
3348 .def(
3349 "__str__",
3350 [](PyOpView &self) { return nb::str(self.getOperationObject()); })
3351 .def_prop_ro(
3352 "successors",
3353 [](PyOperationBase &self) {
3354 return PyOpSuccessors(self.getOperation().getRef());
3356 "Returns the list of Operation successors.");
3357 opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
3358 opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
3359 opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3360 // It is faster to pass the operation_name, ods_regions, and
3361 // ods_operand_segments/ods_result_segments as arguments to the constructor,
3362 // rather than to access them as attributes.
3363 opViewClass.attr("build_generic") = classmethod(
3364 [](nb::handle cls, std::optional<nb::list> resultTypeList,
3365 nb::list operandList, std::optional<nb::dict> attributes,
3366 std::optional<std::vector<PyBlock *>> successors,
3367 std::optional<int> regions, DefaultingPyLocation location,
3368 const nb::object &maybeIp) {
3369 std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3370 std::tuple<int, bool> opRegionSpec =
3371 nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3372 nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3373 nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3374 return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3375 resultSegmentSpec, resultTypeList,
3376 operandList, attributes, successors,
3377 regions, location, maybeIp);
3379 nb::arg("cls"), nb::arg("results").none() = nb::none(),
3380 nb::arg("operands").none() = nb::none(),
3381 nb::arg("attributes").none() = nb::none(),
3382 nb::arg("successors").none() = nb::none(),
3383 nb::arg("regions").none() = nb::none(),
3384 nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
3385 "Builds a specific, generated OpView based on class level attributes.");
3386 opViewClass.attr("parse") = classmethod(
3387 [](const nb::object &cls, const std::string &sourceStr,
3388 const std::string &sourceName, DefaultingPyMlirContext context) {
3389 PyOperationRef parsed =
3390 PyOperation::parse(context->getRef(), sourceStr, sourceName);
3392 // Check if the expected operation was parsed, and cast to to the
3393 // appropriate `OpView` subclass if successful.
3394 // NOTE: This accesses attributes that have been automatically added to
3395 // `OpView` subclasses, and is not intended to be used on `OpView`
3396 // directly.
3397 std::string clsOpName =
3398 nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3399 MlirStringRef identifier =
3400 mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
3401 std::string_view parsedOpName(identifier.data, identifier.length);
3402 if (clsOpName != parsedOpName)
3403 throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
3404 parsedOpName + "'");
3405 return PyOpView::constructDerived(cls, parsed.getObject());
3407 nb::arg("cls"), nb::arg("source"), nb::kw_only(),
3408 nb::arg("source_name") = "", nb::arg("context").none() = nb::none(),
3409 "Parses a specific, generated OpView based on class level attributes");
3411 //----------------------------------------------------------------------------
3412 // Mapping of PyRegion.
3413 //----------------------------------------------------------------------------
3414 nb::class_<PyRegion>(m, "Region")
3415 .def_prop_ro(
3416 "blocks",
3417 [](PyRegion &self) {
3418 return PyBlockList(self.getParentOperation(), self.get());
3420 "Returns a forward-optimized sequence of blocks.")
3421 .def_prop_ro(
3422 "owner",
3423 [](PyRegion &self) {
3424 return self.getParentOperation()->createOpView();
3426 "Returns the operation owning this region.")
3427 .def(
3428 "__iter__",
3429 [](PyRegion &self) {
3430 self.checkValid();
3431 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
3432 return PyBlockIterator(self.getParentOperation(), firstBlock);
3434 "Iterates over blocks in the region.")
3435 .def("__eq__",
3436 [](PyRegion &self, PyRegion &other) {
3437 return self.get().ptr == other.get().ptr;
3439 .def("__eq__", [](PyRegion &self, nb::object &other) { return false; });
3441 //----------------------------------------------------------------------------
3442 // Mapping of PyBlock.
3443 //----------------------------------------------------------------------------
3444 nb::class_<PyBlock>(m, "Block")
3445 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
3446 .def_prop_ro(
3447 "owner",
3448 [](PyBlock &self) {
3449 return self.getParentOperation()->createOpView();
3451 "Returns the owning operation of this block.")
3452 .def_prop_ro(
3453 "region",
3454 [](PyBlock &self) {
3455 MlirRegion region = mlirBlockGetParentRegion(self.get());
3456 return PyRegion(self.getParentOperation(), region);
3458 "Returns the owning region of this block.")
3459 .def_prop_ro(
3460 "arguments",
3461 [](PyBlock &self) {
3462 return PyBlockArgumentList(self.getParentOperation(), self.get());
3464 "Returns a list of block arguments.")
3465 .def(
3466 "add_argument",
3467 [](PyBlock &self, const PyType &type, const PyLocation &loc) {
3468 return mlirBlockAddArgument(self.get(), type, loc);
3470 "Append an argument of the specified type to the block and returns "
3471 "the newly added argument.")
3472 .def(
3473 "erase_argument",
3474 [](PyBlock &self, unsigned index) {
3475 return mlirBlockEraseArgument(self.get(), index);
3477 "Erase the argument at 'index' and remove it from the argument list.")
3478 .def_prop_ro(
3479 "operations",
3480 [](PyBlock &self) {
3481 return PyOperationList(self.getParentOperation(), self.get());
3483 "Returns a forward-optimized sequence of operations.")
3484 .def_static(
3485 "create_at_start",
3486 [](PyRegion &parent, const nb::sequence &pyArgTypes,
3487 const std::optional<nb::sequence> &pyArgLocs) {
3488 parent.checkValid();
3489 MlirBlock block = createBlock(pyArgTypes, pyArgLocs);
3490 mlirRegionInsertOwnedBlock(parent, 0, block);
3491 return PyBlock(parent.getParentOperation(), block);
3493 nb::arg("parent"), nb::arg("arg_types") = nb::list(),
3494 nb::arg("arg_locs") = std::nullopt,
3495 "Creates and returns a new Block at the beginning of the given "
3496 "region (with given argument types and locations).")
3497 .def(
3498 "append_to",
3499 [](PyBlock &self, PyRegion &region) {
3500 MlirBlock b = self.get();
3501 if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
3502 mlirBlockDetach(b);
3503 mlirRegionAppendOwnedBlock(region.get(), b);
3505 "Append this block to a region, transferring ownership if necessary")
3506 .def(
3507 "create_before",
3508 [](PyBlock &self, const nb::args &pyArgTypes,
3509 const std::optional<nb::sequence> &pyArgLocs) {
3510 self.checkValid();
3511 MlirBlock block =
3512 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3513 MlirRegion region = mlirBlockGetParentRegion(self.get());
3514 mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
3515 return PyBlock(self.getParentOperation(), block);
3517 nb::arg("arg_types"), nb::kw_only(),
3518 nb::arg("arg_locs") = std::nullopt,
3519 "Creates and returns a new Block before this block "
3520 "(with given argument types and locations).")
3521 .def(
3522 "create_after",
3523 [](PyBlock &self, const nb::args &pyArgTypes,
3524 const std::optional<nb::sequence> &pyArgLocs) {
3525 self.checkValid();
3526 MlirBlock block =
3527 createBlock(nb::cast<nb::sequence>(pyArgTypes), pyArgLocs);
3528 MlirRegion region = mlirBlockGetParentRegion(self.get());
3529 mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
3530 return PyBlock(self.getParentOperation(), block);
3532 nb::arg("arg_types"), nb::kw_only(),
3533 nb::arg("arg_locs") = std::nullopt,
3534 "Creates and returns a new Block after this block "
3535 "(with given argument types and locations).")
3536 .def(
3537 "__iter__",
3538 [](PyBlock &self) {
3539 self.checkValid();
3540 MlirOperation firstOperation =
3541 mlirBlockGetFirstOperation(self.get());
3542 return PyOperationIterator(self.getParentOperation(),
3543 firstOperation);
3545 "Iterates over operations in the block.")
3546 .def("__eq__",
3547 [](PyBlock &self, PyBlock &other) {
3548 return self.get().ptr == other.get().ptr;
3550 .def("__eq__", [](PyBlock &self, nb::object &other) { return false; })
3551 .def("__hash__",
3552 [](PyBlock &self) {
3553 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3555 .def(
3556 "__str__",
3557 [](PyBlock &self) {
3558 self.checkValid();
3559 PyPrintAccumulator printAccum;
3560 mlirBlockPrint(self.get(), printAccum.getCallback(),
3561 printAccum.getUserData());
3562 return printAccum.join();
3564 "Returns the assembly form of the block.")
3565 .def(
3566 "append",
3567 [](PyBlock &self, PyOperationBase &operation) {
3568 if (operation.getOperation().isAttached())
3569 operation.getOperation().detachFromParent();
3571 MlirOperation mlirOperation = operation.getOperation().get();
3572 mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
3573 operation.getOperation().setAttached(
3574 self.getParentOperation().getObject());
3576 nb::arg("operation"),
3577 "Appends an operation to this block. If the operation is currently "
3578 "in another block, it will be moved.");
3580 //----------------------------------------------------------------------------
3581 // Mapping of PyInsertionPoint.
3582 //----------------------------------------------------------------------------
3584 nb::class_<PyInsertionPoint>(m, "InsertionPoint")
3585 .def(nb::init<PyBlock &>(), nb::arg("block"),
3586 "Inserts after the last operation but still inside the block.")
3587 .def("__enter__", &PyInsertionPoint::contextEnter)
3588 .def("__exit__", &PyInsertionPoint::contextExit,
3589 nb::arg("exc_type").none(), nb::arg("exc_value").none(),
3590 nb::arg("traceback").none())
3591 .def_prop_ro_static(
3592 "current",
3593 [](nb::object & /*class*/) {
3594 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
3595 if (!ip)
3596 throw nb::value_error("No current InsertionPoint");
3597 return ip;
3599 "Gets the InsertionPoint bound to the current thread or raises "
3600 "ValueError if none has been set")
3601 .def(nb::init<PyOperationBase &>(), nb::arg("beforeOperation"),
3602 "Inserts before a referenced operation.")
3603 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
3604 nb::arg("block"), "Inserts at the beginning of the block.")
3605 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
3606 nb::arg("block"), "Inserts before the block terminator.")
3607 .def("insert", &PyInsertionPoint::insert, nb::arg("operation"),
3608 "Inserts an operation.")
3609 .def_prop_ro(
3610 "block", [](PyInsertionPoint &self) { return self.getBlock(); },
3611 "Returns the block that this InsertionPoint points to.")
3612 .def_prop_ro(
3613 "ref_operation",
3614 [](PyInsertionPoint &self) -> nb::object {
3615 auto refOperation = self.getRefOperation();
3616 if (refOperation)
3617 return refOperation->getObject();
3618 return nb::none();
3620 "The reference operation before which new operations are "
3621 "inserted, or None if the insertion point is at the end of "
3622 "the block");
3624 //----------------------------------------------------------------------------
3625 // Mapping of PyAttribute.
3626 //----------------------------------------------------------------------------
3627 nb::class_<PyAttribute>(m, "Attribute")
3628 // Delegate to the PyAttribute copy constructor, which will also lifetime
3629 // extend the backing context which owns the MlirAttribute.
3630 .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
3631 "Casts the passed attribute to the generic Attribute")
3632 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
3633 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
3634 .def_static(
3635 "parse",
3636 [](const std::string &attrSpec, DefaultingPyMlirContext context) {
3637 PyMlirContext::ErrorCapture errors(context->getRef());
3638 MlirAttribute attr = mlirAttributeParseGet(
3639 context->get(), toMlirStringRef(attrSpec));
3640 if (mlirAttributeIsNull(attr))
3641 throw MLIRError("Unable to parse attribute", errors.take());
3642 return attr;
3644 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3645 "Parses an attribute from an assembly form. Raises an MLIRError on "
3646 "failure.")
3647 .def_prop_ro(
3648 "context",
3649 [](PyAttribute &self) { return self.getContext().getObject(); },
3650 "Context that owns the Attribute")
3651 .def_prop_ro("type",
3652 [](PyAttribute &self) { return mlirAttributeGetType(self); })
3653 .def(
3654 "get_named",
3655 [](PyAttribute &self, std::string name) {
3656 return PyNamedAttribute(self, std::move(name));
3658 nb::keep_alive<0, 1>(), "Binds a name to the attribute")
3659 .def("__eq__",
3660 [](PyAttribute &self, PyAttribute &other) { return self == other; })
3661 .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; })
3662 .def("__hash__",
3663 [](PyAttribute &self) {
3664 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3666 .def(
3667 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3668 kDumpDocstring)
3669 .def(
3670 "__str__",
3671 [](PyAttribute &self) {
3672 PyPrintAccumulator printAccum;
3673 mlirAttributePrint(self, printAccum.getCallback(),
3674 printAccum.getUserData());
3675 return printAccum.join();
3677 "Returns the assembly form of the Attribute.")
3678 .def("__repr__",
3679 [](PyAttribute &self) {
3680 // Generally, assembly formats are not printed for __repr__ because
3681 // this can cause exceptionally long debug output and exceptions.
3682 // However, attribute values are generally considered useful and
3683 // are printed. This may need to be re-evaluated if debug dumps end
3684 // up being excessive.
3685 PyPrintAccumulator printAccum;
3686 printAccum.parts.append("Attribute(");
3687 mlirAttributePrint(self, printAccum.getCallback(),
3688 printAccum.getUserData());
3689 printAccum.parts.append(")");
3690 return printAccum.join();
3692 .def_prop_ro("typeid",
3693 [](PyAttribute &self) -> MlirTypeID {
3694 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3695 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3696 "mlirTypeID was expected to be non-null.");
3697 return mlirTypeID;
3699 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
3700 MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self);
3701 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3702 "mlirTypeID was expected to be non-null.");
3703 std::optional<nb::callable> typeCaster =
3704 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3705 mlirAttributeGetDialect(self));
3706 if (!typeCaster)
3707 return nb::cast(self);
3708 return typeCaster.value()(self);
3711 //----------------------------------------------------------------------------
3712 // Mapping of PyNamedAttribute
3713 //----------------------------------------------------------------------------
3714 nb::class_<PyNamedAttribute>(m, "NamedAttribute")
3715 .def("__repr__",
3716 [](PyNamedAttribute &self) {
3717 PyPrintAccumulator printAccum;
3718 printAccum.parts.append("NamedAttribute(");
3719 printAccum.parts.append(
3720 nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3721 mlirIdentifierStr(self.namedAttr.name).length));
3722 printAccum.parts.append("=");
3723 mlirAttributePrint(self.namedAttr.attribute,
3724 printAccum.getCallback(),
3725 printAccum.getUserData());
3726 printAccum.parts.append(")");
3727 return printAccum.join();
3729 .def_prop_ro(
3730 "name",
3731 [](PyNamedAttribute &self) {
3732 return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
3733 mlirIdentifierStr(self.namedAttr.name).length);
3735 "The name of the NamedAttribute binding")
3736 .def_prop_ro(
3737 "attr",
3738 [](PyNamedAttribute &self) { return self.namedAttr.attribute; },
3739 nb::keep_alive<0, 1>(),
3740 "The underlying generic attribute of the NamedAttribute binding");
3742 //----------------------------------------------------------------------------
3743 // Mapping of PyType.
3744 //----------------------------------------------------------------------------
3745 nb::class_<PyType>(m, "Type")
3746 // Delegate to the PyType copy constructor, which will also lifetime
3747 // extend the backing context which owns the MlirType.
3748 .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
3749 "Casts the passed type to the generic Type")
3750 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3751 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3752 .def_static(
3753 "parse",
3754 [](std::string typeSpec, DefaultingPyMlirContext context) {
3755 PyMlirContext::ErrorCapture errors(context->getRef());
3756 MlirType type =
3757 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3758 if (mlirTypeIsNull(type))
3759 throw MLIRError("Unable to parse type", errors.take());
3760 return type;
3762 nb::arg("asm"), nb::arg("context").none() = nb::none(),
3763 kContextParseTypeDocstring)
3764 .def_prop_ro(
3765 "context", [](PyType &self) { return self.getContext().getObject(); },
3766 "Context that owns the Type")
3767 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3768 .def(
3769 "__eq__", [](PyType &self, nb::object &other) { return false; },
3770 nb::arg("other").none())
3771 .def("__hash__",
3772 [](PyType &self) {
3773 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3775 .def(
3776 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3777 .def(
3778 "__str__",
3779 [](PyType &self) {
3780 PyPrintAccumulator printAccum;
3781 mlirTypePrint(self, printAccum.getCallback(),
3782 printAccum.getUserData());
3783 return printAccum.join();
3785 "Returns the assembly form of the type.")
3786 .def("__repr__",
3787 [](PyType &self) {
3788 // Generally, assembly formats are not printed for __repr__ because
3789 // this can cause exceptionally long debug output and exceptions.
3790 // However, types are an exception as they typically have compact
3791 // assembly forms and printing them is useful.
3792 PyPrintAccumulator printAccum;
3793 printAccum.parts.append("Type(");
3794 mlirTypePrint(self, printAccum.getCallback(),
3795 printAccum.getUserData());
3796 printAccum.parts.append(")");
3797 return printAccum.join();
3799 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3800 [](PyType &self) {
3801 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3802 assert(!mlirTypeIDIsNull(mlirTypeID) &&
3803 "mlirTypeID was expected to be non-null.");
3804 std::optional<nb::callable> typeCaster =
3805 PyGlobals::get().lookupTypeCaster(mlirTypeID,
3806 mlirTypeGetDialect(self));
3807 if (!typeCaster)
3808 return nb::cast(self);
3809 return typeCaster.value()(self);
3811 .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID {
3812 MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
3813 if (!mlirTypeIDIsNull(mlirTypeID))
3814 return mlirTypeID;
3815 auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(self)));
3816 throw nb::value_error(
3817 (origRepr + llvm::Twine(" has no typeid.")).str().c_str());
3820 //----------------------------------------------------------------------------
3821 // Mapping of PyTypeID.
3822 //----------------------------------------------------------------------------
3823 nb::class_<PyTypeID>(m, "TypeID")
3824 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
3825 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
3826 // Note, this tests whether the underlying TypeIDs are the same,
3827 // not whether the wrapper MlirTypeIDs are the same, nor whether
3828 // the Python objects are the same (i.e., PyTypeID is a value type).
3829 .def("__eq__",
3830 [](PyTypeID &self, PyTypeID &other) { return self == other; })
3831 .def("__eq__",
3832 [](PyTypeID &self, const nb::object &other) { return false; })
3833 // Note, this gives the hash value of the underlying TypeID, not the
3834 // hash value of the Python object, nor the hash value of the
3835 // MlirTypeID wrapper.
3836 .def("__hash__", [](PyTypeID &self) {
3837 return static_cast<size_t>(mlirTypeIDHashValue(self));
3840 //----------------------------------------------------------------------------
3841 // Mapping of Value.
3842 //----------------------------------------------------------------------------
3843 nb::class_<PyValue>(m, "Value")
3844 .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
3845 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3846 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3847 .def_prop_ro(
3848 "context",
3849 [](PyValue &self) { return self.getParentOperation()->getContext(); },
3850 "Context in which the value lives.")
3851 .def(
3852 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3853 kDumpDocstring)
3854 .def_prop_ro(
3855 "owner",
3856 [](PyValue &self) -> nb::object {
3857 MlirValue v = self.get();
3858 if (mlirValueIsAOpResult(v)) {
3859 assert(
3860 mlirOperationEqual(self.getParentOperation()->get(),
3861 mlirOpResultGetOwner(self.get())) &&
3862 "expected the owner of the value in Python to match that in "
3863 "the IR");
3864 return self.getParentOperation().getObject();
3867 if (mlirValueIsABlockArgument(v)) {
3868 MlirBlock block = mlirBlockArgumentGetOwner(self.get());
3869 return nb::cast(PyBlock(self.getParentOperation(), block));
3872 assert(false && "Value must be a block argument or an op result");
3873 return nb::none();
3875 .def_prop_ro("uses",
3876 [](PyValue &self) {
3877 return PyOpOperandIterator(
3878 mlirValueGetFirstUse(self.get()));
3880 .def("__eq__",
3881 [](PyValue &self, PyValue &other) {
3882 return self.get().ptr == other.get().ptr;
3884 .def("__eq__", [](PyValue &self, nb::object other) { return false; })
3885 .def("__hash__",
3886 [](PyValue &self) {
3887 return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3889 .def(
3890 "__str__",
3891 [](PyValue &self) {
3892 PyPrintAccumulator printAccum;
3893 printAccum.parts.append("Value(");
3894 mlirValuePrint(self.get(), printAccum.getCallback(),
3895 printAccum.getUserData());
3896 printAccum.parts.append(")");
3897 return printAccum.join();
3899 kValueDunderStrDocstring)
3900 .def(
3901 "get_name",
3902 [](PyValue &self, bool useLocalScope) {
3903 PyPrintAccumulator printAccum;
3904 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
3905 if (useLocalScope)
3906 mlirOpPrintingFlagsUseLocalScope(flags);
3907 MlirAsmState valueState =
3908 mlirAsmStateCreateForValue(self.get(), flags);
3909 mlirValuePrintAsOperand(self.get(), valueState,
3910 printAccum.getCallback(),
3911 printAccum.getUserData());
3912 mlirOpPrintingFlagsDestroy(flags);
3913 mlirAsmStateDestroy(valueState);
3914 return printAccum.join();
3916 nb::arg("use_local_scope") = false)
3917 .def(
3918 "get_name",
3919 [](PyValue &self, PyAsmState &state) {
3920 PyPrintAccumulator printAccum;
3921 MlirAsmState valueState = state.get();
3922 mlirValuePrintAsOperand(self.get(), valueState,
3923 printAccum.getCallback(),
3924 printAccum.getUserData());
3925 return printAccum.join();
3927 nb::arg("state"), kGetNameAsOperand)
3928 .def_prop_ro("type",
3929 [](PyValue &self) { return mlirValueGetType(self.get()); })
3930 .def(
3931 "set_type",
3932 [](PyValue &self, const PyType &type) {
3933 return mlirValueSetType(self.get(), type);
3935 nb::arg("type"))
3936 .def(
3937 "replace_all_uses_with",
3938 [](PyValue &self, PyValue &with) {
3939 mlirValueReplaceAllUsesOfWith(self.get(), with.get());
3941 kValueReplaceAllUsesWithDocstring)
3942 .def(
3943 "replace_all_uses_except",
3944 [](MlirValue self, MlirValue with, PyOperation &exception) {
3945 MlirOperation exceptedUser = exception.get();
3946 mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
3948 nb::arg("with"), nb::arg("exceptions"),
3949 kValueReplaceAllUsesExceptDocstring)
3950 .def(
3951 "replace_all_uses_except",
3952 [](MlirValue self, MlirValue with, nb::list exceptions) {
3953 // Convert Python list to a SmallVector of MlirOperations
3954 llvm::SmallVector<MlirOperation> exceptionOps;
3955 for (nb::handle exception : exceptions) {
3956 exceptionOps.push_back(nb::cast<PyOperation &>(exception).get());
3959 mlirValueReplaceAllUsesExcept(
3960 self, with, static_cast<intptr_t>(exceptionOps.size()),
3961 exceptionOps.data());
3963 nb::arg("with"), nb::arg("exceptions"),
3964 kValueReplaceAllUsesExceptDocstring)
3965 .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
3966 [](PyValue &self) { return self.maybeDownCast(); });
3967 PyBlockArgument::bind(m);
3968 PyOpResult::bind(m);
3969 PyOpOperand::bind(m);
3971 nb::class_<PyAsmState>(m, "AsmState")
3972 .def(nb::init<PyValue &, bool>(), nb::arg("value"),
3973 nb::arg("use_local_scope") = false)
3974 .def(nb::init<PyOperationBase &, bool>(), nb::arg("op"),
3975 nb::arg("use_local_scope") = false);
3977 //----------------------------------------------------------------------------
3978 // Mapping of SymbolTable.
3979 //----------------------------------------------------------------------------
3980 nb::class_<PySymbolTable>(m, "SymbolTable")
3981 .def(nb::init<PyOperationBase &>())
3982 .def("__getitem__", &PySymbolTable::dunderGetItem)
3983 .def("insert", &PySymbolTable::insert, nb::arg("operation"))
3984 .def("erase", &PySymbolTable::erase, nb::arg("operation"))
3985 .def("__delitem__", &PySymbolTable::dunderDel)
3986 .def("__contains__",
3987 [](PySymbolTable &table, const std::string &name) {
3988 return !mlirOperationIsNull(mlirSymbolTableLookup(
3989 table, mlirStringRefCreate(name.data(), name.length())));
3991 // Static helpers.
3992 .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3993 nb::arg("symbol"), nb::arg("name"))
3994 .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3995 nb::arg("symbol"))
3996 .def_static("get_visibility", &PySymbolTable::getVisibility,
3997 nb::arg("symbol"))
3998 .def_static("set_visibility", &PySymbolTable::setVisibility,
3999 nb::arg("symbol"), nb::arg("visibility"))
4000 .def_static("replace_all_symbol_uses",
4001 &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"),
4002 nb::arg("new_symbol"), nb::arg("from_op"))
4003 .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
4004 nb::arg("from_op"), nb::arg("all_sym_uses_visible"),
4005 nb::arg("callback"));
4007 // Container bindings.
4008 PyBlockArgumentList::bind(m);
4009 PyBlockIterator::bind(m);
4010 PyBlockList::bind(m);
4011 PyOperationIterator::bind(m);
4012 PyOperationList::bind(m);
4013 PyOpAttributeMap::bind(m);
4014 PyOpOperandIterator::bind(m);
4015 PyOpOperandList::bind(m);
4016 PyOpResultList::bind(m);
4017 PyOpSuccessors::bind(m);
4018 PyRegionIterator::bind(m);
4019 PyRegionList::bind(m);
4021 // Debug bindings.
4022 PyGlobalDebugFlag::bind(m);
4024 // Attribute builder getter.
4025 PyAttrBuilderMap::bind(m);
4027 nb::register_exception_translator([](const std::exception_ptr &p,
4028 void *payload) {
4029 // We can't define exceptions with custom fields through pybind, so instead
4030 // the exception class is defined in python and imported here.
4031 try {
4032 if (p)
4033 std::rethrow_exception(p);
4034 } catch (const MLIRError &e) {
4035 nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
4036 .attr("MLIRError")(e.message, e.errorDiagnostics);
4037 PyErr_SetObject(PyExc_Exception, obj.ptr());