1 //===- Pass.cpp - Pass Management -----------------------------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
12 #include "mlir-c/Pass.h"
13 #include "mlir/Bindings/Python/Nanobind.h"
14 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
16 namespace nb
= nanobind
;
17 using namespace nb::literals
;
19 using namespace mlir::python
;
23 /// Owning Wrapper around a PassManager.
26 PyPassManager(MlirPassManager passManager
) : passManager(passManager
) {}
27 PyPassManager(PyPassManager
&&other
) noexcept
28 : passManager(other
.passManager
) {
29 other
.passManager
.ptr
= nullptr;
32 if (!mlirPassManagerIsNull(passManager
))
33 mlirPassManagerDestroy(passManager
);
35 MlirPassManager
get() { return passManager
; }
37 void release() { passManager
.ptr
= nullptr; }
38 nb::object
getCapsule() {
39 return nb::steal
<nb::object
>(mlirPythonPassManagerToCapsule(get()));
42 static nb::object
createFromCapsule(nb::object capsule
) {
43 MlirPassManager rawPm
= mlirPythonCapsuleToPassManager(capsule
.ptr());
44 if (mlirPassManagerIsNull(rawPm
))
45 throw nb::python_error();
46 return nb::cast(PyPassManager(rawPm
), nb::rv_policy::move
);
50 MlirPassManager passManager
;
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(nb::module_
&m
) {
57 //----------------------------------------------------------------------------
58 // Mapping of the top-level PassManager
59 //----------------------------------------------------------------------------
60 nb::class_
<PyPassManager
>(m
, "PassManager")
63 [](PyPassManager
&self
, const std::string
&anchorOp
,
64 DefaultingPyMlirContext context
) {
65 MlirPassManager passManager
= mlirPassManagerCreateOnOperation(
67 mlirStringRefCreate(anchorOp
.data(), anchorOp
.size()));
68 new (&self
) PyPassManager(passManager
);
70 "anchor_op"_a
= nb::str("any"), "context"_a
.none() = nb::none(),
71 "Create a new PassManager for the current (or provided) Context.")
72 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR
, &PyPassManager::getCapsule
)
73 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR
, &PyPassManager::createFromCapsule
)
74 .def("_testing_release", &PyPassManager::release
,
75 "Releases (leaks) the backing pass manager (testing)")
78 [](PyPassManager
&passManager
, bool printBeforeAll
,
79 bool printAfterAll
, bool printModuleScope
, bool printAfterChange
,
80 bool printAfterFailure
, std::optional
<int64_t> largeElementsLimit
,
81 bool enableDebugInfo
, bool printGenericOpForm
,
82 std::optional
<std::string
> optionalTreePrintingPath
) {
83 MlirOpPrintingFlags flags
= mlirOpPrintingFlagsCreate();
84 if (largeElementsLimit
)
85 mlirOpPrintingFlagsElideLargeElementsAttrs(flags
,
88 mlirOpPrintingFlagsEnableDebugInfo(flags
, /*enable=*/true,
89 /*prettyForm=*/false);
90 if (printGenericOpForm
)
91 mlirOpPrintingFlagsPrintGenericOpForm(flags
);
92 std::string treePrintingPath
= "";
93 if (optionalTreePrintingPath
.has_value())
94 treePrintingPath
= optionalTreePrintingPath
.value();
95 mlirPassManagerEnableIRPrinting(
96 passManager
.get(), printBeforeAll
, printAfterAll
,
97 printModuleScope
, printAfterChange
, printAfterFailure
, flags
,
98 mlirStringRefCreate(treePrintingPath
.data(),
99 treePrintingPath
.size()));
100 mlirOpPrintingFlagsDestroy(flags
);
102 "print_before_all"_a
= false, "print_after_all"_a
= true,
103 "print_module_scope"_a
= false, "print_after_change"_a
= false,
104 "print_after_failure"_a
= false,
105 "large_elements_limit"_a
.none() = nb::none(),
106 "enable_debug_info"_a
= false, "print_generic_op_form"_a
= false,
107 "tree_printing_dir_path"_a
.none() = nb::none(),
108 "Enable IR printing, default as mlir-print-ir-after-all.")
111 [](PyPassManager
&passManager
, bool enable
) {
112 mlirPassManagerEnableVerifier(passManager
.get(), enable
);
114 "enable"_a
, "Enable / disable verify-each.")
117 [](const std::string
&pipeline
, DefaultingPyMlirContext context
) {
118 MlirPassManager passManager
= mlirPassManagerCreate(context
->get());
119 PyPrintAccumulator errorMsg
;
120 MlirLogicalResult status
= mlirParsePassPipeline(
121 mlirPassManagerGetAsOpPassManager(passManager
),
122 mlirStringRefCreate(pipeline
.data(), pipeline
.size()),
123 errorMsg
.getCallback(), errorMsg
.getUserData());
124 if (mlirLogicalResultIsFailure(status
))
125 throw nb::value_error(errorMsg
.join().c_str());
126 return new PyPassManager(passManager
);
128 "pipeline"_a
, "context"_a
.none() = nb::none(),
129 "Parse a textual pass-pipeline and return a top-level PassManager "
130 "that can be applied on a Module. Throw a ValueError if the pipeline "
134 [](PyPassManager
&passManager
, const std::string
&pipeline
) {
135 PyPrintAccumulator errorMsg
;
136 MlirLogicalResult status
= mlirOpPassManagerAddPipeline(
137 mlirPassManagerGetAsOpPassManager(passManager
.get()),
138 mlirStringRefCreate(pipeline
.data(), pipeline
.size()),
139 errorMsg
.getCallback(), errorMsg
.getUserData());
140 if (mlirLogicalResultIsFailure(status
))
141 throw nb::value_error(errorMsg
.join().c_str());
144 "Add textual pipeline elements to the pass manager. Throws a "
145 "ValueError if the pipeline can't be parsed.")
148 [](PyPassManager
&passManager
, PyOperationBase
&op
,
149 bool invalidateOps
) {
151 op
.getOperation().getContext()->clearOperationsInside(op
);
153 // Actually run the pass manager.
154 PyMlirContext::ErrorCapture
errors(op
.getOperation().getContext());
155 MlirLogicalResult status
= mlirPassManagerRunOnOp(
156 passManager
.get(), op
.getOperation().get());
157 if (mlirLogicalResultIsFailure(status
))
158 throw MLIRError("Failure while executing pass pipeline",
161 "operation"_a
, "invalidate_ops"_a
= true,
162 "Run the pass manager on the provided operation, raising an "
163 "MLIRError on failure.")
166 [](PyPassManager
&self
) {
167 MlirPassManager passManager
= self
.get();
168 PyPrintAccumulator printAccum
;
169 mlirPrintPassPipeline(
170 mlirPassManagerGetAsOpPassManager(passManager
),
171 printAccum
.getCallback(), printAccum
.getUserData());
172 return printAccum
.join();
174 "Print the textual representation for this PassManager, suitable to "
175 "be passed to `parse` for round-tripping.");