[MLIR][LLVM] Fold extract of extract (#125980)
[llvm-project.git] / mlir / lib / Bindings / Python / Pass.cpp
blob 858c3bd5745feebb074812e8536fcb8ffd8570a1
1 //===- Pass.cpp - Pass Management -----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
9 #include "Pass.h"
11 #include "IRModule.h"
12 #include "mlir-c/Pass.h"
13 #include "mlir/Bindings/Python/Nanobind.h"
14 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
16 namespace nb = nanobind;
17 using namespace nb::literals;
18 using namespace mlir;
19 using namespace mlir::python;
21 namespace {
23 /// Owning Wrapper around a PassManager.
24 class PyPassManager {
25 public:
26 PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
27 PyPassManager(PyPassManager &&other) noexcept
28 : passManager(other.passManager) {
29 other.passManager.ptr = nullptr;
31 ~PyPassManager() {
32 if (!mlirPassManagerIsNull(passManager))
33 mlirPassManagerDestroy(passManager);
35 MlirPassManager get() { return passManager; }
37 void release() { passManager.ptr = nullptr; }
38 nb::object getCapsule() {
39 return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
42 static nb::object createFromCapsule(nb::object capsule) {
43 MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
44 if (mlirPassManagerIsNull(rawPm))
45 throw nb::python_error();
46 return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
49 private:
50 MlirPassManager passManager;
53 } // namespace
55 /// Create the `mlir.passmanager` here.
56 void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
57 //----------------------------------------------------------------------------
58 // Mapping of the top-level PassManager
59 //----------------------------------------------------------------------------
60 nb::class_<PyPassManager>(m, "PassManager")
61 .def(
62 "__init__",
63 [](PyPassManager &self, const std::string &anchorOp,
64 DefaultingPyMlirContext context) {
65 MlirPassManager passManager = mlirPassManagerCreateOnOperation(
66 context->get(),
67 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
68 new (&self) PyPassManager(passManager);
70 "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
71 "Create a new PassManager for the current (or provided) Context.")
72 .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
73 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
74 .def("_testing_release", &PyPassManager::release,
75 "Releases (leaks) the backing pass manager (testing)")
76 .def(
77 "enable_ir_printing",
78 [](PyPassManager &passManager, bool printBeforeAll,
79 bool printAfterAll, bool printModuleScope, bool printAfterChange,
80 bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
81 bool enableDebugInfo, bool printGenericOpForm,
82 std::optional<std::string> optionalTreePrintingPath) {
83 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
84 if (largeElementsLimit)
85 mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
86 *largeElementsLimit);
87 if (enableDebugInfo)
88 mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
89 /*prettyForm=*/false);
90 if (printGenericOpForm)
91 mlirOpPrintingFlagsPrintGenericOpForm(flags);
92 std::string treePrintingPath = "";
93 if (optionalTreePrintingPath.has_value())
94 treePrintingPath = optionalTreePrintingPath.value();
95 mlirPassManagerEnableIRPrinting(
96 passManager.get(), printBeforeAll, printAfterAll,
97 printModuleScope, printAfterChange, printAfterFailure, flags,
98 mlirStringRefCreate(treePrintingPath.data(),
99 treePrintingPath.size()));
100 mlirOpPrintingFlagsDestroy(flags);
102 "print_before_all"_a = false, "print_after_all"_a = true,
103 "print_module_scope"_a = false, "print_after_change"_a = false,
104 "print_after_failure"_a = false,
105 "large_elements_limit"_a.none() = nb::none(),
106 "enable_debug_info"_a = false, "print_generic_op_form"_a = false,
107 "tree_printing_dir_path"_a.none() = nb::none(),
108 "Enable IR printing, default as mlir-print-ir-after-all.")
109 .def(
110 "enable_verifier",
111 [](PyPassManager &passManager, bool enable) {
112 mlirPassManagerEnableVerifier(passManager.get(), enable);
114 "enable"_a, "Enable / disable verify-each.")
115 .def_static(
116 "parse",
117 [](const std::string &pipeline, DefaultingPyMlirContext context) {
118 MlirPassManager passManager = mlirPassManagerCreate(context->get());
119 PyPrintAccumulator errorMsg;
120 MlirLogicalResult status = mlirParsePassPipeline(
121 mlirPassManagerGetAsOpPassManager(passManager),
122 mlirStringRefCreate(pipeline.data(), pipeline.size()),
123 errorMsg.getCallback(), errorMsg.getUserData());
124 if (mlirLogicalResultIsFailure(status))
125 throw nb::value_error(errorMsg.join().c_str());
126 return new PyPassManager(passManager);
128 "pipeline"_a, "context"_a.none() = nb::none(),
129 "Parse a textual pass-pipeline and return a top-level PassManager "
130 "that can be applied on a Module. Throw a ValueError if the pipeline "
131 "can't be parsed")
132 .def(
133 "add",
134 [](PyPassManager &passManager, const std::string &pipeline) {
135 PyPrintAccumulator errorMsg;
136 MlirLogicalResult status = mlirOpPassManagerAddPipeline(
137 mlirPassManagerGetAsOpPassManager(passManager.get()),
138 mlirStringRefCreate(pipeline.data(), pipeline.size()),
139 errorMsg.getCallback(), errorMsg.getUserData());
140 if (mlirLogicalResultIsFailure(status))
141 throw nb::value_error(errorMsg.join().c_str());
143 "pipeline"_a,
144 "Add textual pipeline elements to the pass manager. Throws a "
145 "ValueError if the pipeline can't be parsed.")
146 .def(
147 "run",
148 [](PyPassManager &passManager, PyOperationBase &op,
149 bool invalidateOps) {
150 if (invalidateOps) {
151 op.getOperation().getContext()->clearOperationsInside(op);
153 // Actually run the pass manager.
154 PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
155 MlirLogicalResult status = mlirPassManagerRunOnOp(
156 passManager.get(), op.getOperation().get());
157 if (mlirLogicalResultIsFailure(status))
158 throw MLIRError("Failure while executing pass pipeline",
159 errors.take());
161 "operation"_a, "invalidate_ops"_a = true,
162 "Run the pass manager on the provided operation, raising an "
163 "MLIRError on failure.")
164 .def(
165 "__str__",
166 [](PyPassManager &self) {
167 MlirPassManager passManager = self.get();
168 PyPrintAccumulator printAccum;
169 mlirPrintPassPipeline(
170 mlirPassManagerGetAsOpPassManager(passManager),
171 printAccum.getCallback(), printAccum.getUserData());
172 return printAccum.join();
174 "Print the textual representation for this PassManager, suitable to "
175 "be passed to `parse` for round-tripping.");