//===- TransformInterpreter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Pybind classes for the transform dialect interpreter.
//
//===----------------------------------------------------------------------===//
13 #include "mlir-c/Dialect/Transform/Interpreter.h"
14 #include "mlir-c/IR.h"
15 #include "mlir-c/Support.h"
16 #include "mlir/Bindings/Python/PybindAdaptors.h"
18 #include <pybind11/detail/common.h>
19 #include <pybind11/pybind11.h>
21 namespace py
= pybind11
;
24 struct PyMlirTransformOptions
{
25 PyMlirTransformOptions() { options
= mlirTransformOptionsCreate(); };
26 PyMlirTransformOptions(PyMlirTransformOptions
&&other
) {
27 options
= other
.options
;
28 other
.options
.ptr
= nullptr;
30 PyMlirTransformOptions(const PyMlirTransformOptions
&) = delete;
32 ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options
); }
34 MlirTransformOptions options
;
// Populates module `m` with the transform interpreter bindings visible below:
// a module-local `TransformOptions` class backed by PyMlirTransformOptions,
// plus lambdas associated with the names "apply_named_sequence" and
// "copy_symbols_and_merge_into".
// NOTE(review): this text looks like a lossy extraction -- every line carries
// a leading number, tokens are split across lines, and several statements are
// cut off (for example the operand after `"...Diagnostic message " +` and the
// calls that attach the lambdas to the class/module). Restore the function
// from the canonical file before relying on it; the notes below describe only
// what is visible here.
38 static void populateTransformInterpreterSubmodule(py::module
&m
) {
// Declares the Python-visible class "TransformOptions" as module-local.
39 py::class_
<PyMlirTransformOptions
>(m
, "TransformOptions", py::module_local())
// Lambda returning the result of
// mlirTransformOptionsGetExpensiveChecksEnabled for the wrapped handle.
// NOTE(review): the name this lambda pair is registered under is not visible.
43 [](const PyMlirTransformOptions
&self
) {
44 return mlirTransformOptionsGetExpensiveChecksEnabled(self
.options
);
// Lambda forwarding `value` to mlirTransformOptionsEnableExpensiveChecks.
46 [](PyMlirTransformOptions
&self
, bool value
) {
47 mlirTransformOptionsEnableExpensiveChecks(self
.options
, value
);
// Name followed by a reading lambda and a writing lambda; presumably a
// read/write property on the class -- confirm against the canonical file.
50 "enforce_single_top_level_transform_op",
51 [](const PyMlirTransformOptions
&self
) {
52 return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
55 [](PyMlirTransformOptions
&self
, bool value
) {
56 mlirTransformOptionsEnforceSingleTopLevelTransformOp(self
.options
,
// "apply_named_sequence": takes the payload root, the transform root, the
// transform module and an options wrapper.
61 "apply_named_sequence",
62 [](MlirOperation payloadRoot
, MlirOperation transformRoot
,
63 MlirOperation transformModule
, const PyMlirTransformOptions
&options
) {
// Diagnostics scope constructed on the context of `transformRoot`.
64 mlir::python::CollectDiagnosticsToStringScope
scope(
65 mlirOperationGetContext(transformRoot
));
67 // Calling back into Python to invalidate everything under the payload
68 // root. This is awkward, but we don't have access to PyMlirContext
69 // object here otherwise.
70 py::object obj
= py::cast(payloadRoot
);
71 obj
.attr("context").attr("_clear_live_operations_inside")(payloadRoot
);
// Applies the named sequence through the C API with the wrapped options.
73 MlirLogicalResult result
= mlirTransformApplyNamedSequence(
74 payloadRoot
, transformRoot
, transformModule
, options
.options
);
// NOTE(review): the statement guarded by this success check is not visible
// here; presumably an early return, with the throw below as the failure path.
75 if (mlirLogicalResultIsSuccess(result
))
78 throw py::value_error(
79 "Failed to apply named transform sequence.\nDiagnostic message " +
// Keyword argument names; `transform_options` defaults to a freshly
// constructed PyMlirTransformOptions.
82 py::arg("payload_root"), py::arg("transform_root"),
83 py::arg("transform_module"),
84 py::arg("transform_options") = PyMlirTransformOptions());
// "copy_symbols_and_merge_into": calls mlirMergeSymbolsIntoFromClone on
// (target, other) and throws py::value_error when the result is a failure.
87 "copy_symbols_and_merge_into",
88 [](MlirOperation target
, MlirOperation other
) {
// Diagnostics scope constructed on the context of `target`.
89 mlir::python::CollectDiagnosticsToStringScope
scope(
90 mlirOperationGetContext(target
));
92 MlirLogicalResult result
= mlirMergeSymbolsIntoFromClone(target
, other
);
93 if (mlirLogicalResultIsFailure(result
)) {
94 throw py::value_error(
95 "Failed to merge symbols.\nDiagnostic message " +
99 py::arg("target"), py::arg("other"));
// pybind11 module definition for `_mlirTransformInterpreter`: sets the module
// docstring and delegates all registration to
// populateTransformInterpreterSubmodule.
// NOTE(review): the closing brace of this definition is not visible in this
// excerpt, and each line carries a leading number that looks like extraction
// residue.
102 PYBIND11_MODULE(_mlirTransformInterpreter
, m
) {
103 m
.doc() = "MLIR Transform dialect interpreter functionality.";
104 populateTransformInterpreterSubmodule(m
);