[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Bindings / Python / Pass.cpp
blob 1d0e5ce2115a0a293e063232eccc085b10e53c20
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
#include "Pass.h"

#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"

#include <memory>
#include <string>

namespace py = pybind11;
using namespace py::literals;
using namespace mlir;
using namespace mlir::python;
20 namespace {
22 /// Owning Wrapper around a PassManager.
23 class PyPassManager {
24 public:
25 PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
26 PyPassManager(PyPassManager &&other) noexcept
27 : passManager(other.passManager) {
28 other.passManager.ptr = nullptr;
30 ~PyPassManager() {
31 if (!mlirPassManagerIsNull(passManager))
32 mlirPassManagerDestroy(passManager);
34 MlirPassManager get() { return passManager; }
36 void release() { passManager.ptr = nullptr; }
37 pybind11::object getCapsule() {
38 return py::reinterpret_steal<py::object>(
39 mlirPythonPassManagerToCapsule(get()));
42 static pybind11::object createFromCapsule(pybind11::object capsule) {
43 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44 if (mlirPassManagerIsNull(rawPm))
45 throw py::error_already_set();
46 return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
49 private:
50 MlirPassManager passManager;
53 } // namespace
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(py::module &m) {
57 //----------------------------------------------------------------------------
58 // Mapping of the top-level PassManager
59 //----------------------------------------------------------------------------
60 py::class_<PyPassManager>(m, "PassManager", py::module_local())
61 .def(py::init<>([](const std::string &anchorOp,
62 DefaultingPyMlirContext context) {
63 MlirPassManager passManager = mlirPassManagerCreateOnOperation(
64 context->get(),
65 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
66 return new PyPassManager(passManager);
67 }),
68 "anchor_op"_a = py::str("any"), "context"_a = py::none(),
69 "Create a new PassManager for the current (or provided) Context.")
70 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
71 &PyPassManager::getCapsule)
72 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
73 .def("_testing_release", &PyPassManager::release,
74 "Releases (leaks) the backing pass manager (testing)")
75 .def(
76 "enable_ir_printing",
77 [](PyPassManager &passManager, bool printBeforeAll,
78 bool printAfterAll, bool printModuleScope, bool printAfterChange,
79 bool printAfterFailure) {
80 mlirPassManagerEnableIRPrinting(
81 passManager.get(), printBeforeAll, printAfterAll,
82 printModuleScope, printAfterChange, printAfterFailure);
84 "print_before_all"_a = false, "print_after_all"_a = true,
85 "print_module_scope"_a = false, "print_after_change"_a = false,
86 "print_after_failure"_a = false,
87 "Enable IR printing, default as mlir-print-ir-after-all.")
88 .def(
89 "enable_verifier",
90 [](PyPassManager &passManager, bool enable) {
91 mlirPassManagerEnableVerifier(passManager.get(), enable);
93 "enable"_a, "Enable / disable verify-each.")
94 .def_static(
95 "parse",
96 [](const std::string &pipeline, DefaultingPyMlirContext context) {
97 MlirPassManager passManager = mlirPassManagerCreate(context->get());
98 PyPrintAccumulator errorMsg;
99 MlirLogicalResult status = mlirParsePassPipeline(
100 mlirPassManagerGetAsOpPassManager(passManager),
101 mlirStringRefCreate(pipeline.data(), pipeline.size()),
102 errorMsg.getCallback(), errorMsg.getUserData());
103 if (mlirLogicalResultIsFailure(status))
104 throw py::value_error(std::string(errorMsg.join()));
105 return new PyPassManager(passManager);
107 "pipeline"_a, "context"_a = py::none(),
108 "Parse a textual pass-pipeline and return a top-level PassManager "
109 "that can be applied on a Module. Throw a ValueError if the pipeline "
110 "can't be parsed")
111 .def(
112 "add",
113 [](PyPassManager &passManager, const std::string &pipeline) {
114 PyPrintAccumulator errorMsg;
115 MlirLogicalResult status = mlirOpPassManagerAddPipeline(
116 mlirPassManagerGetAsOpPassManager(passManager.get()),
117 mlirStringRefCreate(pipeline.data(), pipeline.size()),
118 errorMsg.getCallback(), errorMsg.getUserData());
119 if (mlirLogicalResultIsFailure(status))
120 throw py::value_error(std::string(errorMsg.join()));
122 "pipeline"_a,
123 "Add textual pipeline elements to the pass manager. Throws a "
124 "ValueError if the pipeline can't be parsed.")
125 .def(
126 "run",
127 [](PyPassManager &passManager, PyOperationBase &op,
128 bool invalidateOps) {
129 if (invalidateOps) {
130 op.getOperation().getContext()->clearOperationsInside(op);
132 // Actually run the pass manager.
133 PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
134 MlirLogicalResult status = mlirPassManagerRunOnOp(
135 passManager.get(), op.getOperation().get());
136 if (mlirLogicalResultIsFailure(status))
137 throw MLIRError("Failure while executing pass pipeline",
138 errors.take());
140 "operation"_a, "invalidate_ops"_a = true,
141 "Run the pass manager on the provided operation, raising an "
142 "MLIRError on failure.")
143 .def(
144 "__str__",
145 [](PyPassManager &self) {
146 MlirPassManager passManager = self.get();
147 PyPrintAccumulator printAccum;
148 mlirPrintPassPipeline(
149 mlirPassManagerGetAsOpPassManager(passManager),
150 printAccum.getCallback(), printAccum.getUserData());
151 return printAccum.join();
153 "Print the textual representation for this PassManager, suitable to "
154 "be passed to `parse` for round-tripping.");