1 //===- Pass.cpp - Pass Management -----------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
12 #include "mlir-c/Bindings/Python/Interop.h"
13 #include "mlir-c/Pass.h"
15 namespace py
= pybind11
;
16 using namespace py::literals
;
18 using namespace mlir::python
;
22 /// Owning wrapper around an `MlirPassManager` C API handle.
///
/// A null handle marks an instance that no longer owns anything. The move
/// constructor and `release()` both set `passManager.ptr` to nullptr, and
/// the cleanup code below skips `mlirPassManagerDestroy` when
/// `mlirPassManagerIsNull` reports a null handle (presumably it tests
/// `ptr == nullptr` — confirm in mlir-c/Pass.h).
///
/// NOTE(review): this excerpt is missing several structural lines of the
/// class (the class header, access specifiers, closing braces and the
/// signature enclosing the cleanup code), and each line carries a leading
/// original-line number. Restore from upstream before compiling.
/// Takes ownership of `passManager`.
/// NOTE(review): single-argument constructor is not `explicit`; consider
/// marking it so that `MlirPassManager` does not implicitly convert.
25 PyPassManager(MlirPassManager passManager
) : passManager(passManager
) {}
/// Move constructor: takes over `other`'s handle and nulls `other`'s copy so
/// the moved-from wrapper will not destroy it a second time.
26 PyPassManager(PyPassManager
&&other
) noexcept
27 : passManager(other
.passManager
) {
28 other
.passManager
.ptr
= nullptr;
// Cleanup: destroy the owned pass manager unless the handle is null
// (i.e. the wrapper was moved from or released).
// NOTE(review): the enclosing signature is not visible in this excerpt;
// presumably this is the body of `~PyPassManager()` — confirm upstream.
31 if (!mlirPassManagerIsNull(passManager
))
32 mlirPassManagerDestroy(passManager
);
/// Returns a copy of the wrapped handle; ownership stays with this wrapper.
34 MlirPassManager
get() { return passManager
; }
/// Drops ownership without destroying the pass manager, so it is leaked.
/// Exposed to Python as `_testing_release` and intended for testing only.
36 void release() { passManager
.ptr
= nullptr; }
/// Wraps the handle in a Python capsule (exposed as the
/// MLIR_PYTHON_CAPI_PTR_ATTR property). `reinterpret_steal` adopts the
/// returned PyObject without adding a reference; this assumes
/// `mlirPythonPassManagerToCapsule` returns a new reference — confirm.
/// This wrapper keeps its handle (it is not nulled here).
37 pybind11::object
getCapsule() {
38 return py::reinterpret_steal
<py::object
>(
39 mlirPythonPassManagerToCapsule(get()));
/// Creates a Python `PassManager` from a capsule (exposed as
/// MLIR_PYTHON_CAPI_FACTORY_ATTR).
/// Throws py::error_already_set if the capsule yields a null handle. This
/// assumes the capsule conversion has set a Python error in that case —
/// TODO confirm.
/// The temporary wrapper is moved into the Python object
/// (return_value_policy::move), which then owns the handle.
/// NOTE(review): `getCapsule` does not give up ownership. Feeding its result
/// back here would presumably leave two owning wrappers for one handle
/// (double destroy). Verify the intended ownership contract of capsules.
42 static pybind11::object
createFromCapsule(pybind11::object capsule
) {
43 MlirPassManager rawPm
= mlirPythonCapsuleToPassManager(capsule
.ptr());
44 if (mlirPassManagerIsNull(rawPm
))
45 throw py::error_already_set();
46 return py::cast(PyPassManager(rawPm
), py::return_value_policy::move
);
/// The owned C API handle; null once moved from or released.
50 MlirPassManager passManager
;
55 /// Populate the `mlir.passmanager` submodule with the `PassManager` bindings.
56 void mlir::python::populatePassManagerSubmodule(py::module
&m
) {
57 //----------------------------------------------------------------------------
58 // Mapping of the top-level PassManager
59 //----------------------------------------------------------------------------
60 py::class_
<PyPassManager
>(m
, "PassManager", py::module_local())
61 .def(py::init
<>([](const std::string
&anchorOp
,
62 DefaultingPyMlirContext context
) {
63 MlirPassManager passManager
= mlirPassManagerCreateOnOperation(
65 mlirStringRefCreate(anchorOp
.data(), anchorOp
.size()));
66 return new PyPassManager(passManager
);
68 "anchor_op"_a
= py::str("any"), "context"_a
= py::none(),
69 "Create a new PassManager for the current (or provided) Context.")
70 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR
,
71 &PyPassManager::getCapsule
)
72 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyPassManager::createFromCapsule
)
73 .def("_testing_release", &PyPassManager::release
,
74 "Releases (leaks) the backing pass manager (testing)")
77 [](PyPassManager
&passManager
, bool printBeforeAll
,
78 bool printAfterAll
, bool printModuleScope
, bool printAfterChange
,
79 bool printAfterFailure
) {
80 mlirPassManagerEnableIRPrinting(
81 passManager
.get(), printBeforeAll
, printAfterAll
,
82 printModuleScope
, printAfterChange
, printAfterFailure
);
84 "print_before_all"_a
= false, "print_after_all"_a
= true,
85 "print_module_scope"_a
= false, "print_after_change"_a
= false,
86 "print_after_failure"_a
= false,
87 "Enable IR printing, default as mlir-print-ir-after-all.")
90 [](PyPassManager
&passManager
, bool enable
) {
91 mlirPassManagerEnableVerifier(passManager
.get(), enable
);
93 "enable"_a
, "Enable / disable verify-each.")
96 [](const std::string
&pipeline
, DefaultingPyMlirContext context
) {
97 MlirPassManager passManager
= mlirPassManagerCreate(context
->get());
98 PyPrintAccumulator errorMsg
;
99 MlirLogicalResult status
= mlirParsePassPipeline(
100 mlirPassManagerGetAsOpPassManager(passManager
),
101 mlirStringRefCreate(pipeline
.data(), pipeline
.size()),
102 errorMsg
.getCallback(), errorMsg
.getUserData());
103 if (mlirLogicalResultIsFailure(status
))
104 throw py::value_error(std::string(errorMsg
.join()));
105 return new PyPassManager(passManager
);
107 "pipeline"_a
, "context"_a
= py::none(),
108 "Parse a textual pass-pipeline and return a top-level PassManager "
109 "that can be applied on a Module. Throw a ValueError if the pipeline "
113 [](PyPassManager
&passManager
, const std::string
&pipeline
) {
114 PyPrintAccumulator errorMsg
;
115 MlirLogicalResult status
= mlirOpPassManagerAddPipeline(
116 mlirPassManagerGetAsOpPassManager(passManager
.get()),
117 mlirStringRefCreate(pipeline
.data(), pipeline
.size()),
118 errorMsg
.getCallback(), errorMsg
.getUserData());
119 if (mlirLogicalResultIsFailure(status
))
120 throw py::value_error(std::string(errorMsg
.join()));
123 "Add textual pipeline elements to the pass manager. Throws a "
124 "ValueError if the pipeline can't be parsed.")
127 [](PyPassManager
&passManager
, PyOperationBase
&op
,
128 bool invalidateOps
) {
130 op
.getOperation().getContext()->clearOperationsInside(op
);
132 // Actually run the pass manager.
133 PyMlirContext::ErrorCapture
errors(op
.getOperation().getContext());
134 MlirLogicalResult status
= mlirPassManagerRunOnOp(
135 passManager
.get(), op
.getOperation().get());
136 if (mlirLogicalResultIsFailure(status
))
137 throw MLIRError("Failure while executing pass pipeline",
140 "operation"_a
, "invalidate_ops"_a
= true,
141 "Run the pass manager on the provided operation, raising an "
142 "MLIRError on failure.")
145 [](PyPassManager
&self
) {
146 MlirPassManager passManager
= self
.get();
147 PyPrintAccumulator printAccum
;
148 mlirPrintPassPipeline(
149 mlirPassManagerGetAsOpPassManager(passManager
),
150 printAccum
.getCallback(), printAccum
.getUserData());
151 return printAccum
.join();
153 "Print the textual representation for this PassManager, suitable to "
154 "be passed to `parse` for round-tripping.");