1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8 //===----------------------------------------------------------------------===//
10 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
11 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
18 #include "PybindUtils.h"
20 #include "mlir-c/AffineExpr.h"
21 #include "mlir-c/AffineMap.h"
22 #include "mlir-c/Diagnostics.h"
23 #include "mlir-c/IR.h"
24 #include "mlir-c/IntegerSet.h"
25 #include "mlir-c/Transforms.h"
26 #include "mlir/Bindings/Python/PybindAdaptors.h"
27 #include "llvm/ADT/DenseMap.h"
34 class PyDiagnosticHandler
;
35 class PyInsertionPoint
;
37 class DefaultingPyLocation
;
39 class DefaultingPyMlirContext
;
42 class PyOperationBase
;
47 /// Template for a reference to a concrete type which captures a python
48 /// reference to its underlying python object.
52 PyObjectRef(T
*referrent
, pybind11::object object
)
53 : referrent(referrent
), object(std::move(object
)) {
54 assert(this->referrent
&&
55 "cannot construct PyObjectRef with null referrent");
56 assert(this->object
&& "cannot construct PyObjectRef with null object");
58 PyObjectRef(PyObjectRef
&&other
) noexcept
59 : referrent(other
.referrent
), object(std::move(other
.object
)) {
60 other
.referrent
= nullptr;
61 assert(!other
.object
);
63 PyObjectRef(const PyObjectRef
&other
)
64 : referrent(other
.referrent
), object(other
.object
/* copies */) {}
65 ~PyObjectRef() = default;
70 return object
.ref_count();
73 /// Releases the object held by this instance, returning it.
74 /// This is the proper thing to return from a function that wants to return
75 /// the reference. Note that this does not work from initializers.
76 pybind11::object
releaseObject() {
77 assert(referrent
&& object
);
79 auto stolen
= std::move(object
);
83 T
*get() { return referrent
; }
85 assert(referrent
&& object
);
88 pybind11::object
getObject() {
89 assert(referrent
&& object
);
92 operator bool() const { return referrent
&& object
; }
96 pybind11::object object
;
99 /// Tracks an entry in the thread context stack. New entries are pushed onto
100 /// here for each with block that activates a new InsertionPoint, Context or
103 /// Pushing either a Location or InsertionPoint also pushes its associated
104 /// Context. Pushing a Context will not modify the Location or InsertionPoint
/// unless they are from a different context, in which case, they are
107 class PyThreadContextEntry
{
109 enum class FrameKind
{
115 PyThreadContextEntry(FrameKind frameKind
, pybind11::object context
,
116 pybind11::object insertionPoint
,
117 pybind11::object location
)
118 : context(std::move(context
)), insertionPoint(std::move(insertionPoint
)),
119 location(std::move(location
)), frameKind(frameKind
) {}
121 /// Gets the top of stack context and return nullptr if not defined.
122 static PyMlirContext
*getDefaultContext();
124 /// Gets the top of stack insertion point and return nullptr if not defined.
125 static PyInsertionPoint
*getDefaultInsertionPoint();
127 /// Gets the top of stack location and returns nullptr if not defined.
128 static PyLocation
*getDefaultLocation();
130 PyMlirContext
*getContext();
131 PyInsertionPoint
*getInsertionPoint();
132 PyLocation
*getLocation();
133 FrameKind
getFrameKind() { return frameKind
; }
135 /// Stack management.
136 static PyThreadContextEntry
*getTopOfStack();
137 static pybind11::object
pushContext(PyMlirContext
&context
);
138 static void popContext(PyMlirContext
&context
);
139 static pybind11::object
pushInsertionPoint(PyInsertionPoint
&insertionPoint
);
140 static void popInsertionPoint(PyInsertionPoint
&insertionPoint
);
141 static pybind11::object
pushLocation(PyLocation
&location
);
142 static void popLocation(PyLocation
&location
);
144 /// Gets the thread local stack.
145 static std::vector
<PyThreadContextEntry
> &getStack();
148 static void push(FrameKind frameKind
, pybind11::object context
,
149 pybind11::object insertionPoint
, pybind11::object location
);
151 /// An object reference to the PyContext.
152 pybind11::object context
;
153 /// An object reference to the current insertion point.
154 pybind11::object insertionPoint
;
155 /// An object reference to the current location.
156 pybind11::object location
;
157 // The kind of push that was performed.
161 /// Wrapper around MlirContext.
162 using PyMlirContextRef
= PyObjectRef
<PyMlirContext
>;
163 class PyMlirContext
{
165 PyMlirContext() = delete;
166 PyMlirContext(const PyMlirContext
&) = delete;
167 PyMlirContext(PyMlirContext
&&) = delete;
169 /// For the case of a python __init__ (py::init) method, pybind11 is quite
170 /// strict about needing to return a pointer that is not yet associated to
/// a py::object. Since the forContext() method acts like a pool, possibly
172 /// returning a recycled context, it does not satisfy this need. The usual
173 /// way in python to accomplish such a thing is to override __new__, but
174 /// that is also not supported by pybind11. Instead, we use this entry
175 /// point which always constructs a fresh context (which cannot alias an
176 /// existing one because it is fresh).
177 static PyMlirContext
*createNewContextForInit();
179 /// Returns a context reference for the singleton PyMlirContext wrapper for
180 /// the given context.
181 static PyMlirContextRef
forContext(MlirContext context
);
184 /// Accesses the underlying MlirContext.
185 MlirContext
get() { return context
; }
187 /// Gets a strong reference to this context, which will ensure it is kept
188 /// alive for the life of the reference.
189 PyMlirContextRef
getRef() {
190 return PyMlirContextRef(this, pybind11::cast(this));
193 /// Gets a capsule wrapping the void* within the MlirContext.
194 pybind11::object
getCapsule();
196 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
197 /// Note that PyMlirContext instances are uniqued, so the returned object
198 /// may be a pre-existing object. Ownership of the underlying MlirContext
199 /// is taken by calling this function.
200 static pybind11::object
createFromCapsule(pybind11::object capsule
);
202 /// Gets the count of live context objects. Used for testing.
203 static size_t getLiveCount();
205 /// Get a list of Python objects which are still in the live context map.
206 std::vector
<PyOperation
*> getLiveOperationObjects();
208 /// Gets the count of live operations associated with this context.
209 /// Used for testing.
210 size_t getLiveOperationCount();
212 /// Clears the live operations map, returning the number of entries which were
213 /// invalidated. To be used as a safety mechanism so that API end-users can't
214 /// corrupt by holding references they shouldn't have accessed in the first
216 size_t clearLiveOperations();
218 /// Removes an operation from the live operations map and sets it invalid.
219 /// This is useful for when some non-bindings code destroys the operation and
/// the bindings need to be made aware. For example, in the case when pass
223 /// Note that this does *NOT* clear the nested operations.
224 void clearOperation(MlirOperation op
);
226 /// Clears all operations nested inside the given op using
227 /// `clearOperation(MlirOperation)`.
228 void clearOperationsInside(PyOperationBase
&op
);
229 void clearOperationsInside(MlirOperation op
);
/// Clears the operation _and_ all operations inside using
232 /// `clearOperation(MlirOperation)`.
233 void clearOperationAndInside(PyOperationBase
&op
);
235 /// Gets the count of live modules associated with this context.
236 /// Used for testing.
237 size_t getLiveModuleCount();
239 /// Enter and exit the context manager.
240 pybind11::object
contextEnter();
241 void contextExit(const pybind11::object
&excType
,
242 const pybind11::object
&excVal
,
243 const pybind11::object
&excTb
);
245 /// Attaches a Python callback as a diagnostic handler, returning a
246 /// registration object (internally a PyDiagnosticHandler).
247 pybind11::object
attachDiagnosticHandler(pybind11::object callback
);
249 /// Controls whether error diagnostics should be propagated to diagnostic
250 /// handlers, instead of being captured by `ErrorCapture`.
251 void setEmitErrorDiagnostics(bool value
) { emitErrorDiagnostics
= value
; }
255 PyMlirContext(MlirContext context
);
257 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
258 // preserving the relationship that an MlirContext maps to a single
259 // PyMlirContext wrapper. This could be replaced in the future with an
260 // extension mechanism on the MlirContext for stashing user pointers.
261 // Note that this holds a handle, which does not imply ownership.
262 // Mappings will be removed when the context is destructed.
263 using LiveContextMap
= llvm::DenseMap
<void *, PyMlirContext
*>;
264 static LiveContextMap
&getLiveContexts();
266 // Interns all live modules associated with this context. Modules tracked
267 // in this map are valid. When a module is invalidated, it is removed
268 // from this map, and while it still exists as an instance, any
269 // attempt to access it will raise an error.
270 using LiveModuleMap
=
271 llvm::DenseMap
<const void *, std::pair
<pybind11::handle
, PyModule
*>>;
272 LiveModuleMap liveModules
;
274 // Interns all live operations associated with this context. Operations
275 // tracked in this map are valid. When an operation is invalidated, it is
276 // removed from this map, and while it still exists as an instance, any
277 // attempt to access it will raise an error.
278 using LiveOperationMap
=
279 llvm::DenseMap
<void *, std::pair
<pybind11::handle
, PyOperation
*>>;
280 LiveOperationMap liveOperations
;
282 bool emitErrorDiagnostics
= false;
285 friend class PyModule
;
286 friend class PyOperation
;
289 /// Used in function arguments when None should resolve to the current context
290 /// manager set instance.
291 class DefaultingPyMlirContext
292 : public Defaulting
<DefaultingPyMlirContext
, PyMlirContext
> {
294 using Defaulting::Defaulting
;
295 static constexpr const char kTypeDescription
[] = "mlir.ir.Context";
296 static PyMlirContext
&resolve();
299 /// Base class for all objects that directly or indirectly depend on an
300 /// MlirContext. The lifetime of the context will extend at least to the
301 /// lifetime of these instances.
302 /// Immutable objects that depend on a context extend this directly.
303 class BaseContextObject
{
305 BaseContextObject(PyMlirContextRef ref
) : contextRef(std::move(ref
)) {
306 assert(this->contextRef
&&
307 "context object constructed with null context ref");
310 /// Accesses the context reference.
311 PyMlirContextRef
&getContext() { return contextRef
; }
314 PyMlirContextRef contextRef
;
317 /// Wrapper around an MlirLocation.
318 class PyLocation
: public BaseContextObject
{
320 PyLocation(PyMlirContextRef contextRef
, MlirLocation loc
)
321 : BaseContextObject(std::move(contextRef
)), loc(loc
) {}
323 operator MlirLocation() const { return loc
; }
324 MlirLocation
get() const { return loc
; }
326 /// Enter and exit the context manager.
327 pybind11::object
contextEnter();
328 void contextExit(const pybind11::object
&excType
,
329 const pybind11::object
&excVal
,
330 const pybind11::object
&excTb
);
332 /// Gets a capsule wrapping the void* within the MlirLocation.
333 pybind11::object
getCapsule();
335 /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
/// The result is returned by value, so it is always a fresh PyLocation
/// wrapper rather than a pre-existing Python object. Ownership of the
/// underlying MlirLocation is taken by calling this function.
339 static PyLocation
createFromCapsule(pybind11::object capsule
);
345 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs
346 /// are only valid for the duration of a diagnostic callback and attempting
347 /// to access them outside of that will raise an exception. This applies to
348 /// nested diagnostics (in the notes) as well.
351 PyDiagnostic(MlirDiagnostic diagnostic
) : diagnostic(diagnostic
) {}
353 bool isValid() { return valid
; }
354 MlirDiagnosticSeverity
getSeverity();
355 PyLocation
getLocation();
356 pybind11::str
getMessage();
357 pybind11::tuple
getNotes();
359 /// Materialized diagnostic information. This is safe to access outside the
360 /// diagnostic callback.
361 struct DiagnosticInfo
{
362 MlirDiagnosticSeverity severity
;
365 std::vector
<DiagnosticInfo
> notes
;
367 DiagnosticInfo
getInfo();
370 MlirDiagnostic diagnostic
;
373 /// If notes have been materialized from the diagnostic, then this will
374 /// be populated with the corresponding objects (all castable to
376 std::optional
<pybind11::tuple
> materializedNotes
;
380 /// Represents a diagnostic handler attached to the context. The handler's
381 /// callback will be invoked with PyDiagnostic instances until the detach()
382 /// method is called or the context is destroyed. A diagnostic handler can be
383 /// the subject of a `with` block, which will detach it when the block exits.
385 /// Since diagnostic handlers can call back into Python code which can do
386 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions,
387 /// etc), this is generally not deemed to be a great user-level API. Users
388 /// should generally use some form of DiagnosticCollector. If the handler raises
389 /// any exceptions, they will just be emitted to stderr and dropped.
391 /// The unique usage of this class means that its lifetime management is
392 /// different from most other parts of the API. Instances are always created
393 /// in an attached state and can transition to a detached state by either:
394 /// a) The context being destroyed and unregistering all handlers.
395 /// b) An explicit call to detach().
396 /// The object may remain live from a Python perspective for an arbitrary time
397 /// after detachment, but there is nothing the user can do with it (since there
398 /// is no way to attach an existing handler object).
399 class PyDiagnosticHandler
{
401 PyDiagnosticHandler(MlirContext context
, pybind11::object callback
);
402 ~PyDiagnosticHandler();
404 bool isAttached() { return registeredID
.has_value(); }
405 bool getHadError() { return hadError
; }
407 /// Detaches the handler. Does nothing if not attached.
410 pybind11::object
contextEnter() { return pybind11::cast(this); }
411 void contextExit(const pybind11::object
&excType
,
412 const pybind11::object
&excVal
,
413 const pybind11::object
&excTb
) {
419 pybind11::object callback
;
420 std::optional
<MlirDiagnosticHandlerID
> registeredID
;
421 bool hadError
= false;
422 friend class PyMlirContext
;
425 /// RAII object that captures any error diagnostics emitted to the provided
427 struct PyMlirContext::ErrorCapture
{
428 ErrorCapture(PyMlirContextRef ctx
)
429 : ctx(ctx
), handlerID(mlirContextAttachDiagnosticHandler(
430 ctx
->get(), handler
, /*userData=*/this,
431 /*deleteUserData=*/nullptr)) {}
433 mlirContextDetachDiagnosticHandler(ctx
->get(), handlerID
);
434 assert(errors
.empty() && "unhandled captured errors");
437 std::vector
<PyDiagnostic::DiagnosticInfo
> take() {
438 return std::move(errors
);
442 PyMlirContextRef ctx
;
443 MlirDiagnosticHandlerID handlerID
;
444 std::vector
<PyDiagnostic::DiagnosticInfo
> errors
;
446 static MlirLogicalResult
handler(MlirDiagnostic diag
, void *userData
);
449 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
450 /// order to differentiate it from the `Dialect` base class which is extended by
451 /// plugins which extend dialect functionality through extension python code.
452 /// This should be seen as the "low-level" object and `Dialect` as the
453 /// high-level, user facing object.
454 class PyDialectDescriptor
: public BaseContextObject
{
456 PyDialectDescriptor(PyMlirContextRef contextRef
, MlirDialect dialect
)
457 : BaseContextObject(std::move(contextRef
)), dialect(dialect
) {}
459 MlirDialect
get() { return dialect
; }
465 /// User-level object for accessing dialects with dotted syntax such as:
467 class PyDialects
: public BaseContextObject
{
469 PyDialects(PyMlirContextRef contextRef
)
470 : BaseContextObject(std::move(contextRef
)) {}
472 MlirDialect
getDialectForKey(const std::string
&key
, bool attrError
);
475 /// User-level dialect object. For dialects that have a registered extension,
476 /// this will be the base class of the extension dialect type. For un-extended,
477 /// objects of this type will be returned directly.
480 PyDialect(pybind11::object descriptor
) : descriptor(std::move(descriptor
)) {}
482 pybind11::object
getDescriptor() { return descriptor
; }
485 pybind11::object descriptor
;
488 /// Wrapper around an MlirDialectRegistry.
489 /// Upon construction, the Python wrapper takes ownership of the
490 /// underlying MlirDialectRegistry.
491 class PyDialectRegistry
{
493 PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {}
494 PyDialectRegistry(MlirDialectRegistry registry
) : registry(registry
) {}
495 ~PyDialectRegistry() {
496 if (!mlirDialectRegistryIsNull(registry
))
497 mlirDialectRegistryDestroy(registry
);
499 PyDialectRegistry(PyDialectRegistry
&) = delete;
500 PyDialectRegistry(PyDialectRegistry
&&other
) noexcept
501 : registry(other
.registry
) {
502 other
.registry
= {nullptr};
505 operator MlirDialectRegistry() const { return registry
; }
506 MlirDialectRegistry
get() const { return registry
; }
508 pybind11::object
getCapsule();
509 static PyDialectRegistry
createFromCapsule(pybind11::object capsule
);
512 MlirDialectRegistry registry
;
515 /// Used in function arguments when None should resolve to the current context
516 /// manager set instance.
517 class DefaultingPyLocation
518 : public Defaulting
<DefaultingPyLocation
, PyLocation
> {
520 using Defaulting::Defaulting
;
521 static constexpr const char kTypeDescription
[] = "mlir.ir.Location";
522 static PyLocation
&resolve();
524 operator MlirLocation() const { return *get(); }
527 /// Wrapper around MlirModule.
528 /// This is the top-level, user-owned object that contains regions/ops/blocks.
530 using PyModuleRef
= PyObjectRef
<PyModule
>;
531 class PyModule
: public BaseContextObject
{
533 /// Returns a PyModule reference for the given MlirModule. This may return
534 /// a pre-existing or new object.
535 static PyModuleRef
forModule(MlirModule module
);
536 PyModule(PyModule
&) = delete;
537 PyModule(PyMlirContext
&&) = delete;
540 /// Gets the backing MlirModule.
541 MlirModule
get() { return module
; }
543 /// Gets a strong reference to this module.
544 PyModuleRef
getRef() {
545 return PyModuleRef(this,
546 pybind11::reinterpret_borrow
<pybind11::object
>(handle
));
549 /// Gets a capsule wrapping the void* within the MlirModule.
/// The inverse operation is provided by createFromCapsule().
553 pybind11::object
getCapsule();
555 /// Creates a PyModule from the MlirModule wrapped by a capsule.
556 /// Note that PyModule instances are uniqued, so the returned object
557 /// may be a pre-existing object. Ownership of the underlying MlirModule
558 /// is taken by calling this function.
559 static pybind11::object
createFromCapsule(pybind11::object capsule
);
562 PyModule(PyMlirContextRef contextRef
, MlirModule module
);
564 pybind11::handle handle
;
569 /// Base class for PyOperation and PyOpView which exposes the primary, user
570 /// visible methods for manipulating it.
571 class PyOperationBase
{
573 virtual ~PyOperationBase() = default;
574 /// Implements the bound 'print' method and helps with others.
575 void print(std::optional
<int64_t> largeElementsLimit
, bool enableDebugInfo
,
576 bool prettyDebugInfo
, bool printGenericOpForm
, bool useLocalScope
,
577 bool assumeVerified
, py::object fileObject
, bool binary
,
579 void print(PyAsmState
&state
, py::object fileObject
, bool binary
);
581 pybind11::object
getAsm(bool binary
,
582 std::optional
<int64_t> largeElementsLimit
,
583 bool enableDebugInfo
, bool prettyDebugInfo
,
584 bool printGenericOpForm
, bool useLocalScope
,
585 bool assumeVerified
, bool skipRegions
);
587 // Implement the bound 'writeBytecode' method.
588 void writeBytecode(const pybind11::object
&fileObject
,
589 std::optional
<int64_t> bytecodeVersion
);
591 // Implement the walk method.
592 void walk(std::function
<MlirWalkResult(MlirOperation
)> callback
,
593 MlirWalkOrder walkOrder
);
595 /// Moves the operation before or after the other operation.
596 void moveAfter(PyOperationBase
&other
);
597 void moveBefore(PyOperationBase
&other
);
599 /// Verify the operation. Throws `MLIRError` if verification fails, and
600 /// returns `true` otherwise.
603 /// Each must provide access to the raw Operation.
604 virtual PyOperation
&getOperation() = 0;
607 /// Wrapper around PyOperation.
608 /// Operations exist in either an attached (dependent) or detached (top-level)
609 /// state. In the detached state (as on creation), an operation is owned by
610 /// the creator and its lifetime extends either until its reference count
611 /// drops to zero or it is attached to a parent, at which point its lifetime
612 /// is bounded by its top-level parent reference.
614 using PyOperationRef
= PyObjectRef
<PyOperation
>;
615 class PyOperation
: public PyOperationBase
, public BaseContextObject
{
617 ~PyOperation() override
;
618 PyOperation
&getOperation() override
{ return *this; }
620 /// Returns a PyOperation for the given MlirOperation, optionally associating
621 /// it with a parentKeepAlive.
622 static PyOperationRef
623 forOperation(PyMlirContextRef contextRef
, MlirOperation operation
,
624 pybind11::object parentKeepAlive
= pybind11::object());
626 /// Creates a detached operation. The operation must not be associated with
627 /// any existing live operation.
628 static PyOperationRef
629 createDetached(PyMlirContextRef contextRef
, MlirOperation operation
,
630 pybind11::object parentKeepAlive
= pybind11::object());
632 /// Parses a source string (either text assembly or bytecode), creating a
633 /// detached operation.
634 static PyOperationRef
parse(PyMlirContextRef contextRef
,
635 const std::string
&sourceStr
,
636 const std::string
&sourceName
);
638 /// Detaches the operation from its parent block and updates its state
640 void detachFromParent() {
641 mlirOperationRemoveFromParent(getOperation());
643 parentKeepAlive
= pybind11::object();
646 /// Gets the backing operation.
647 operator MlirOperation() const { return get(); }
648 MlirOperation
get() const {
653 PyOperationRef
getRef() {
654 return PyOperationRef(
655 this, pybind11::reinterpret_borrow
<pybind11::object
>(handle
));
658 bool isAttached() { return attached
; }
659 void setAttached(const pybind11::object
&parent
= pybind11::object()) {
660 assert(!attached
&& "operation already attached");
664 assert(attached
&& "operation already detached");
667 void checkValid() const;
669 /// Gets the owning block or raises an exception if the operation has no
673 /// Gets the parent operation or raises an exception if the operation has
675 std::optional
<PyOperationRef
> getParentOperation();
677 /// Gets a capsule wrapping the void* within the MlirOperation.
678 pybind11::object
getCapsule();
680 /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
681 /// Ownership of the underlying MlirOperation is taken by calling this
683 static pybind11::object
createFromCapsule(pybind11::object capsule
);
685 /// Creates an operation. See corresponding python docstring.
686 static pybind11::object
687 create(const std::string
&name
, std::optional
<std::vector
<PyType
*>> results
,
688 std::optional
<std::vector
<PyValue
*>> operands
,
689 std::optional
<pybind11::dict
> attributes
,
690 std::optional
<std::vector
<PyBlock
*>> successors
, int regions
,
691 DefaultingPyLocation location
, const pybind11::object
&ip
,
694 /// Creates an OpView suitable for this operation.
695 pybind11::object
createOpView();
697 /// Erases the underlying MlirOperation, removes its pointer from the
698 /// parent context's live operations map, and sets the valid bit false.
701 /// Invalidate the operation.
702 void setInvalid() { valid
= false; }
704 /// Clones this operation.
705 pybind11::object
clone(const pybind11::object
&ip
);
708 PyOperation(PyMlirContextRef contextRef
, MlirOperation operation
);
709 static PyOperationRef
createInstance(PyMlirContextRef contextRef
,
710 MlirOperation operation
,
711 pybind11::object parentKeepAlive
);
713 MlirOperation operation
;
714 pybind11::handle handle
;
715 // Keeps the parent alive, regardless of whether it is an Operation or
717 // TODO: As implemented, this facility is only sufficient for modeling the
718 // trivial module parent back-reference. Generalize this to also account for
719 // transitions from detached to attached and address TODOs in the
720 // ir_operation.py regarding testing corresponding lifetime guarantees.
721 pybind11::object parentKeepAlive
;
722 bool attached
= true;
725 friend class PyOperationBase
;
726 friend class PySymbolTable
;
729 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
730 /// providing more instance-specific accessors and serve as the base class for
/// custom ODS-style operation classes. Since this class is subclassed on the
732 /// python side, it must present an __init__ method that operates in pure
734 class PyOpView
: public PyOperationBase
{
736 PyOpView(const pybind11::object
&operationObject
);
737 PyOperation
&getOperation() override
{ return operation
; }
739 pybind11::object
getOperationObject() { return operationObject
; }
741 static pybind11::object
buildGeneric(
742 const pybind11::object
&cls
, std::optional
<pybind11::list
> resultTypeList
,
743 pybind11::list operandList
, std::optional
<pybind11::dict
> attributes
,
744 std::optional
<std::vector
<PyBlock
*>> successors
,
745 std::optional
<int> regions
, DefaultingPyLocation location
,
746 const pybind11::object
&maybeIp
);
748 /// Construct an instance of a class deriving from OpView, bypassing its
749 /// `__init__` method. The derived class will typically define a constructor
750 /// that provides a convenient builder, but we need to side-step this when
751 /// constructing an `OpView` for an already-built operation.
753 /// The caller is responsible for verifying that `operation` is a valid
754 /// operation to construct `cls` with.
755 static pybind11::object
constructDerived(const pybind11::object
&cls
,
756 const PyOperation
&operation
);
759 PyOperation
&operation
; // For efficient, cast-free access from C++
760 pybind11::object operationObject
; // Holds the reference.
763 /// Wrapper around an MlirRegion.
764 /// Regions are managed completely by their containing operation. Unlike the
765 /// C++ API, the python API does not support detached regions.
768 PyRegion(PyOperationRef parentOperation
, MlirRegion region
)
769 : parentOperation(std::move(parentOperation
)), region(region
) {
770 assert(!mlirRegionIsNull(region
) && "python region cannot be null");
772 operator MlirRegion() const { return region
; }
774 MlirRegion
get() { return region
; }
775 PyOperationRef
&getParentOperation() { return parentOperation
; }
777 void checkValid() { return parentOperation
->checkValid(); }
780 PyOperationRef parentOperation
;
784 /// Wrapper around an MlirAsmState.
787 PyAsmState(MlirValue value
, bool useLocalScope
) {
788 flags
= mlirOpPrintingFlagsCreate();
789 // The OpPrintingFlags are not exposed Python side, create locally and
790 // associate lifetime with the state.
792 mlirOpPrintingFlagsUseLocalScope(flags
);
793 state
= mlirAsmStateCreateForValue(value
, flags
);
796 PyAsmState(PyOperationBase
&operation
, bool useLocalScope
) {
797 flags
= mlirOpPrintingFlagsCreate();
798 // The OpPrintingFlags are not exposed Python side, create locally and
799 // associate lifetime with the state.
801 mlirOpPrintingFlagsUseLocalScope(flags
);
803 mlirAsmStateCreateForOperation(operation
.getOperation().get(), flags
);
805 ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags
); }
806 // Delete copy constructors.
807 PyAsmState(PyAsmState
&other
) = delete;
808 PyAsmState(const PyAsmState
&other
) = delete;
810 MlirAsmState
get() { return state
; }
814 MlirOpPrintingFlags flags
;
817 /// Wrapper around an MlirBlock.
818 /// Blocks are managed completely by their containing operation. Unlike the
819 /// C++ API, the python API does not support detached blocks.
822 PyBlock(PyOperationRef parentOperation
, MlirBlock block
)
823 : parentOperation(std::move(parentOperation
)), block(block
) {
824 assert(!mlirBlockIsNull(block
) && "python block cannot be null");
827 MlirBlock
get() { return block
; }
828 PyOperationRef
&getParentOperation() { return parentOperation
; }
830 void checkValid() { return parentOperation
->checkValid(); }
832 /// Gets a capsule wrapping the void* within the MlirBlock.
833 pybind11::object
getCapsule();
836 PyOperationRef parentOperation
;
840 /// An insertion point maintains a pointer to a Block and a reference operation.
841 /// Calls to insert() will insert a new operation before the
842 /// reference operation. If the reference operation is null, then appends to
843 /// the end of the block.
844 class PyInsertionPoint
{
846 /// Creates an insertion point positioned after the last operation in the
847 /// block, but still inside the block.
848 PyInsertionPoint(PyBlock
&block
);
849 /// Creates an insertion point positioned before a reference operation.
850 PyInsertionPoint(PyOperationBase
&beforeOperationBase
);
852 /// Shortcut to create an insertion point at the beginning of the block.
853 static PyInsertionPoint
atBlockBegin(PyBlock
&block
);
854 /// Shortcut to create an insertion point before the block terminator.
855 static PyInsertionPoint
atBlockTerminator(PyBlock
&block
);
857 /// Inserts an operation.
858 void insert(PyOperationBase
&operationBase
);
860 /// Enter and exit the context manager.
861 pybind11::object
contextEnter();
862 void contextExit(const pybind11::object
&excType
,
863 const pybind11::object
&excVal
,
864 const pybind11::object
&excTb
);
866 PyBlock
&getBlock() { return block
; }
867 std::optional
<PyOperationRef
> &getRefOperation() { return refOperation
; }
870 // Trampoline constructor that avoids null initializing members while
871 // looking up parents.
872 PyInsertionPoint(PyBlock block
, std::optional
<PyOperationRef
> refOperation
)
873 : refOperation(std::move(refOperation
)), block(std::move(block
)) {}
875 std::optional
<PyOperationRef
> refOperation
;
878 /// Wrapper around the generic MlirType.
879 /// The lifetime of a type is bound by the PyContext that created it.
880 class PyType
: public BaseContextObject
{
882 PyType(PyMlirContextRef contextRef
, MlirType type
)
883 : BaseContextObject(std::move(contextRef
)), type(type
) {}
884 bool operator==(const PyType
&other
) const;
885 operator MlirType() const { return type
; }
886 MlirType
get() const { return type
; }
888 /// Gets a capsule wrapping the void* within the MlirType.
889 pybind11::object
getCapsule();
891 /// Creates a PyType from the MlirType wrapped by a capsule.
/// The result is returned by value, so it is always a fresh PyType wrapper
/// rather than a pre-existing Python object. Ownership of the underlying
/// MlirType is taken by calling this function.
895 static PyType
createFromCapsule(pybind11::object capsule
);
901 /// A TypeID provides an efficient and unique identifier for a specific C++
902 /// type. This allows for a C++ type to be compared, hashed, and stored in an
903 /// opaque context. This class wraps around the generic MlirTypeID.
906 PyTypeID(MlirTypeID typeID
) : typeID(typeID
) {}
907 // Note, this tests whether the underlying TypeIDs are the same,
908 // not whether the wrapper MlirTypeIDs are the same, nor whether
909 // the PyTypeID objects are the same (i.e., PyTypeID is a value type).
910 bool operator==(const PyTypeID
&other
) const;
911 operator MlirTypeID() const { return typeID
; }
912 MlirTypeID
get() { return typeID
; }
914 /// Gets a capsule wrapping the void* within the MlirTypeID.
915 pybind11::object
getCapsule();
917 /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
918 static PyTypeID
createFromCapsule(pybind11::object capsule
);
924 /// CRTP base classes for Python types that subclass Type and should be
925 /// castable from it (i.e. via something like IntegerType(t)).
926 /// By default, type class hierarchies are one level deep (i.e. a
927 /// concrete type class extends PyType); however, intermediate python-visible
928 /// base classes can be modeled by specifying a BaseTy.
929 template <typename DerivedTy
, typename BaseTy
= PyType
>
930 class PyConcreteType
: public BaseTy
{
932 // Derived classes must define statics for:
933 // IsAFunctionTy isaFunction
934 // const char *pyClassName
935 using ClassTy
= pybind11::class_
<DerivedTy
, BaseTy
>;
936 using IsAFunctionTy
= bool (*)(MlirType
);
937 using GetTypeIDFunctionTy
= MlirTypeID (*)();
938 static constexpr GetTypeIDFunctionTy getTypeIdFunction
= nullptr;
940 PyConcreteType() = default;
941 PyConcreteType(PyMlirContextRef contextRef
, MlirType t
)
942 : BaseTy(std::move(contextRef
), t
) {}
943 PyConcreteType(PyType
&orig
)
944 : PyConcreteType(orig
.getContext(), castFrom(orig
)) {}
946 static MlirType
castFrom(PyType
&orig
) {
947 if (!DerivedTy::isaFunction(orig
)) {
948 auto origRepr
= pybind11::repr(pybind11::cast(orig
)).cast
<std::string
>();
949 throw py::value_error((llvm::Twine("Cannot cast type to ") +
950 DerivedTy::pyClassName
+ " (from " + origRepr
+
957 static void bind(pybind11::module
&m
) {
958 auto cls
= ClassTy(m
, DerivedTy::pyClassName
, pybind11::module_local());
959 cls
.def(pybind11::init
<PyType
&>(), pybind11::keep_alive
<0, 1>(),
960 pybind11::arg("cast_from_type"));
963 [](PyType
&otherType
) -> bool {
964 return DerivedTy::isaFunction(otherType
);
966 pybind11::arg("other"));
967 cls
.def_property_readonly_static(
968 "static_typeid", [](py::object
& /*class*/) -> MlirTypeID
{
969 if (DerivedTy::getTypeIdFunction
)
970 return DerivedTy::getTypeIdFunction();
971 throw py::attribute_error(
972 (DerivedTy::pyClassName
+ llvm::Twine(" has no typeid.")).str());
974 cls
.def_property_readonly("typeid", [](PyType
&self
) {
975 return py::cast(self
).attr("typeid").cast
<MlirTypeID
>();
977 cls
.def("__repr__", [](DerivedTy
&self
) {
978 PyPrintAccumulator printAccum
;
979 printAccum
.parts
.append(DerivedTy::pyClassName
);
980 printAccum
.parts
.append("(");
981 mlirTypePrint(self
, printAccum
.getCallback(), printAccum
.getUserData());
982 printAccum
.parts
.append(")");
983 return printAccum
.join();
986 if (DerivedTy::getTypeIdFunction
) {
987 PyGlobals::get().registerTypeCaster(
988 DerivedTy::getTypeIdFunction(),
989 pybind11::cpp_function(
990 [](PyType pyType
) -> DerivedTy
{ return pyType
; }));
993 DerivedTy::bindDerived(cls
);
996 /// Implemented by derived classes to add methods to the Python subclass.
997 static void bindDerived(ClassTy
&m
) {}
1000 /// Wrapper around the generic MlirAttribute.
1001 /// The lifetime of a type is bound by the PyContext that created it.
1002 class PyAttribute
: public BaseContextObject
{
1004 PyAttribute(PyMlirContextRef contextRef
, MlirAttribute attr
)
1005 : BaseContextObject(std::move(contextRef
)), attr(attr
) {}
1006 bool operator==(const PyAttribute
&other
) const;
1007 operator MlirAttribute() const { return attr
; }
1008 MlirAttribute
get() const { return attr
; }
1010 /// Gets a capsule wrapping the void* within the MlirAttribute.
1011 pybind11::object
getCapsule();
1013 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
1014 /// Note that PyAttribute instances are uniqued, so the returned object
1015 /// may be a pre-existing object. Ownership of the underlying MlirAttribute
1016 /// is taken by calling this function.
1017 static PyAttribute
createFromCapsule(pybind11::object capsule
);
1023 /// Represents a Python MlirNamedAttr, carrying an optional owned name.
1024 /// TODO: Refactor this and the C-API to be based on an Identifier owned
1025 /// by the context so as to avoid ownership issues here.
1026 class PyNamedAttribute
{
1028 /// Constructs a PyNamedAttr that retains an owned name. This should be
1029 /// used in any code that originates an MlirNamedAttribute from a python
1031 /// The lifetime of the PyNamedAttr must extend to the lifetime of the
1032 /// passed attribute.
1033 PyNamedAttribute(MlirAttribute attr
, std::string ownedName
);
1035 MlirNamedAttribute namedAttr
;
1038 // Since the MlirNamedAttr contains an internal pointer to the actual
1039 // memory of the owned string, it must be heap allocated to remain valid.
1040 // Otherwise, strings that fit within the small object optimization threshold
1041 // will have their memory address change as the containing object is moved,
1042 // resulting in an invalid aliased pointer.
1043 std::unique_ptr
<std::string
> ownedName
;
1046 /// CRTP base classes for Python attributes that subclass Attribute and should
1047 /// be castable from it (i.e. via something like StringAttr(attr)).
1048 /// By default, attribute class hierarchies are one level deep (i.e. a
1049 /// concrete attribute class extends PyAttribute); however, intermediate
1050 /// python-visible base classes can be modeled by specifying a BaseTy.
1051 template <typename DerivedTy
, typename BaseTy
= PyAttribute
>
1052 class PyConcreteAttribute
: public BaseTy
{
1054 // Derived classes must define statics for:
1055 // IsAFunctionTy isaFunction
1056 // const char *pyClassName
1057 using ClassTy
= pybind11::class_
<DerivedTy
, BaseTy
>;
1058 using IsAFunctionTy
= bool (*)(MlirAttribute
);
1059 using GetTypeIDFunctionTy
= MlirTypeID (*)();
1060 static constexpr GetTypeIDFunctionTy getTypeIdFunction
= nullptr;
1062 PyConcreteAttribute() = default;
1063 PyConcreteAttribute(PyMlirContextRef contextRef
, MlirAttribute attr
)
1064 : BaseTy(std::move(contextRef
), attr
) {}
1065 PyConcreteAttribute(PyAttribute
&orig
)
1066 : PyConcreteAttribute(orig
.getContext(), castFrom(orig
)) {}
1068 static MlirAttribute
castFrom(PyAttribute
&orig
) {
1069 if (!DerivedTy::isaFunction(orig
)) {
1070 auto origRepr
= pybind11::repr(pybind11::cast(orig
)).cast
<std::string
>();
1071 throw py::value_error((llvm::Twine("Cannot cast attribute to ") +
1072 DerivedTy::pyClassName
+ " (from " + origRepr
+
1079 static void bind(pybind11::module
&m
) {
1080 auto cls
= ClassTy(m
, DerivedTy::pyClassName
, pybind11::buffer_protocol(),
1081 pybind11::module_local());
1082 cls
.def(pybind11::init
<PyAttribute
&>(), pybind11::keep_alive
<0, 1>(),
1083 pybind11::arg("cast_from_attr"));
1086 [](PyAttribute
&otherAttr
) -> bool {
1087 return DerivedTy::isaFunction(otherAttr
);
1089 pybind11::arg("other"));
1090 cls
.def_property_readonly(
1091 "type", [](PyAttribute
&attr
) { return mlirAttributeGetType(attr
); });
1092 cls
.def_property_readonly_static(
1093 "static_typeid", [](py::object
& /*class*/) -> MlirTypeID
{
1094 if (DerivedTy::getTypeIdFunction
)
1095 return DerivedTy::getTypeIdFunction();
1096 throw py::attribute_error(
1097 (DerivedTy::pyClassName
+ llvm::Twine(" has no typeid.")).str());
1099 cls
.def_property_readonly("typeid", [](PyAttribute
&self
) {
1100 return py::cast(self
).attr("typeid").cast
<MlirTypeID
>();
1102 cls
.def("__repr__", [](DerivedTy
&self
) {
1103 PyPrintAccumulator printAccum
;
1104 printAccum
.parts
.append(DerivedTy::pyClassName
);
1105 printAccum
.parts
.append("(");
1106 mlirAttributePrint(self
, printAccum
.getCallback(),
1107 printAccum
.getUserData());
1108 printAccum
.parts
.append(")");
1109 return printAccum
.join();
1112 if (DerivedTy::getTypeIdFunction
) {
1113 PyGlobals::get().registerTypeCaster(
1114 DerivedTy::getTypeIdFunction(),
1115 pybind11::cpp_function([](PyAttribute pyAttribute
) -> DerivedTy
{
1120 DerivedTy::bindDerived(cls
);
1123 /// Implemented by derived classes to add methods to the Python subclass.
1124 static void bindDerived(ClassTy
&m
) {}
1127 /// Wrapper around the generic MlirValue.
1128 /// Values are managed completely by the operation that resulted in their
1129 /// definition. For op result value, this is the operation that defines the
1130 /// value. For block argument values, this is the operation that contains the
1131 /// block to which the value is an argument (blocks cannot be detached in Python
1132 /// bindings so such operation always exists).
1135 // The virtual here is "load bearing" in that it enables RTTI
1136 // for PyConcreteValue CRTP classes that support maybeDownCast.
1137 // See PyValue::maybeDownCast.
1138 virtual ~PyValue() = default;
1139 PyValue(PyOperationRef parentOperation
, MlirValue value
)
1140 : parentOperation(std::move(parentOperation
)), value(value
) {}
1141 operator MlirValue() const { return value
; }
1143 MlirValue
get() { return value
; }
1144 PyOperationRef
&getParentOperation() { return parentOperation
; }
1146 void checkValid() { return parentOperation
->checkValid(); }
1148 /// Gets a capsule wrapping the void* within the MlirValue.
1149 pybind11::object
getCapsule();
1151 pybind11::object
maybeDownCast();
1153 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
1154 /// the underlying MlirValue is still tied to the owning operation.
1155 static PyValue
createFromCapsule(pybind11::object capsule
);
1158 PyOperationRef parentOperation
;
1162 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
1163 class PyAffineExpr
: public BaseContextObject
{
1165 PyAffineExpr(PyMlirContextRef contextRef
, MlirAffineExpr affineExpr
)
1166 : BaseContextObject(std::move(contextRef
)), affineExpr(affineExpr
) {}
1167 bool operator==(const PyAffineExpr
&other
) const;
1168 operator MlirAffineExpr() const { return affineExpr
; }
1169 MlirAffineExpr
get() const { return affineExpr
; }
1171 /// Gets a capsule wrapping the void* within the MlirAffineExpr.
1172 pybind11::object
getCapsule();
1174 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
1175 /// Note that PyAffineExpr instances are uniqued, so the returned object
1176 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
1177 /// is taken by calling this function.
1178 static PyAffineExpr
createFromCapsule(pybind11::object capsule
);
1180 PyAffineExpr
add(const PyAffineExpr
&other
) const;
1181 PyAffineExpr
mul(const PyAffineExpr
&other
) const;
1182 PyAffineExpr
floorDiv(const PyAffineExpr
&other
) const;
1183 PyAffineExpr
ceilDiv(const PyAffineExpr
&other
) const;
1184 PyAffineExpr
mod(const PyAffineExpr
&other
) const;
1187 MlirAffineExpr affineExpr
;
1190 class PyAffineMap
: public BaseContextObject
{
1192 PyAffineMap(PyMlirContextRef contextRef
, MlirAffineMap affineMap
)
1193 : BaseContextObject(std::move(contextRef
)), affineMap(affineMap
) {}
1194 bool operator==(const PyAffineMap
&other
) const;
1195 operator MlirAffineMap() const { return affineMap
; }
1196 MlirAffineMap
get() const { return affineMap
; }
1198 /// Gets a capsule wrapping the void* within the MlirAffineMap.
1199 pybind11::object
getCapsule();
1201 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
1202 /// Note that PyAffineMap instances are uniqued, so the returned object
1203 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
1204 /// is taken by calling this function.
1205 static PyAffineMap
createFromCapsule(pybind11::object capsule
);
1208 MlirAffineMap affineMap
;
1211 class PyIntegerSet
: public BaseContextObject
{
1213 PyIntegerSet(PyMlirContextRef contextRef
, MlirIntegerSet integerSet
)
1214 : BaseContextObject(std::move(contextRef
)), integerSet(integerSet
) {}
1215 bool operator==(const PyIntegerSet
&other
) const;
1216 operator MlirIntegerSet() const { return integerSet
; }
1217 MlirIntegerSet
get() const { return integerSet
; }
1219 /// Gets a capsule wrapping the void* within the MlirIntegerSet.
1220 pybind11::object
getCapsule();
1222 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
1223 /// Note that PyIntegerSet instances may be uniqued, so the returned object
1224 /// may be a pre-existing object. Integer sets are owned by the context.
1225 static PyIntegerSet
createFromCapsule(pybind11::object capsule
);
1228 MlirIntegerSet integerSet
;
1231 /// Bindings for MLIR symbol tables.
1232 class PySymbolTable
{
1234 /// Constructs a symbol table for the given operation.
1235 explicit PySymbolTable(PyOperationBase
&operation
);
1237 /// Destroys the symbol table.
1238 ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable
); }
1240 /// Returns the symbol (opview) with the given name, throws if there is no
1241 /// such symbol in the table.
1242 pybind11::object
dunderGetItem(const std::string
&name
);
1244 /// Removes the given operation from the symbol table and erases it.
1245 void erase(PyOperationBase
&symbol
);
1247 /// Removes the operation with the given name from the symbol table and erases
1248 /// it, throws if there is no such symbol in the table.
1249 void dunderDel(const std::string
&name
);
1251 /// Inserts the given operation into the symbol table. The operation must have
1252 /// the symbol trait.
1253 MlirAttribute
insert(PyOperationBase
&symbol
);
1255 /// Gets and sets the name of a symbol op.
1256 static MlirAttribute
getSymbolName(PyOperationBase
&symbol
);
1257 static void setSymbolName(PyOperationBase
&symbol
, const std::string
&name
);
1259 /// Gets and sets the visibility of a symbol op.
1260 static MlirAttribute
getVisibility(PyOperationBase
&symbol
);
1261 static void setVisibility(PyOperationBase
&symbol
,
1262 const std::string
&visibility
);
1264 /// Replaces all symbol uses within an operation. See the API
1265 /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
1266 static void replaceAllSymbolUses(const std::string
&oldSymbol
,
1267 const std::string
&newSymbol
,
1268 PyOperationBase
&from
);
1270 /// Walks all symbol tables under and including 'from'.
1271 static void walkSymbolTables(PyOperationBase
&from
, bool allSymUsesVisible
,
1272 pybind11::object callback
);
1274 /// Casts the bindings class into the C API structure.
1275 operator MlirSymbolTable() { return symbolTable
; }
1278 PyOperationRef operation
;
1279 MlirSymbolTable symbolTable
;
1282 /// Custom exception that allows access to error diagnostic information. This is
1283 /// converted to the `ir.MLIRError` python exception when thrown.
1285 MLIRError(llvm::Twine message
,
1286 std::vector
<PyDiagnostic::DiagnosticInfo
> &&errorDiagnostics
= {})
1287 : message(message
.str()), errorDiagnostics(std::move(errorDiagnostics
)) {}
1288 std::string message
;
1289 std::vector
<PyDiagnostic::DiagnosticInfo
> errorDiagnostics
;
1292 void populateIRAffine(pybind11::module
&m
);
1293 void populateIRAttributes(pybind11::module
&m
);
1294 void populateIRCore(pybind11::module
&m
);
1295 void populateIRInterfaces(pybind11::module
&m
);
1296 void populateIRTypes(pybind11::module
&m
);
1298 } // namespace python
1301 namespace pybind11
{
1305 struct type_caster
<mlir::python::DefaultingPyMlirContext
>
1306 : MlirDefaultingCaster
<mlir::python::DefaultingPyMlirContext
> {};
1308 struct type_caster
<mlir::python::DefaultingPyLocation
>
1309 : MlirDefaultingCaster
<mlir::python::DefaultingPyLocation
> {};
1311 } // namespace detail
1312 } // namespace pybind11
1314 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H