1 //===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 #include <pybind11/cast.h>
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 #include <pybind11/pytypes.h>
19 namespace py
= pybind11
;
21 using namespace mlir::python
;
22 using namespace mlir::python::adaptors
;
24 void populateDialectTransformSubmodule(const pybind11::module
&m
) {
25 //===-------------------------------------------------------------------===//
27 //===-------------------------------------------------------------------===//
30 mlir_type_subclass(m
, "AnyOpType", mlirTypeIsATransformAnyOpType
,
31 mlirTransformAnyOpTypeGetTypeID
);
32 anyOpType
.def_classmethod(
34 [](py::object cls
, MlirContext ctx
) {
35 return cls(mlirTransformAnyOpTypeGet(ctx
));
37 "Get an instance of AnyOpType in the given context.", py::arg("cls"),
38 py::arg("context") = py::none());
40 //===-------------------------------------------------------------------===//
42 //===-------------------------------------------------------------------===//
45 mlir_type_subclass(m
, "AnyParamType", mlirTypeIsATransformAnyParamType
,
46 mlirTransformAnyParamTypeGetTypeID
);
47 anyParamType
.def_classmethod(
49 [](py::object cls
, MlirContext ctx
) {
50 return cls(mlirTransformAnyParamTypeGet(ctx
));
52 "Get an instance of AnyParamType in the given context.", py::arg("cls"),
53 py::arg("context") = py::none());
55 //===-------------------------------------------------------------------===//
57 //===-------------------------------------------------------------------===//
60 mlir_type_subclass(m
, "AnyValueType", mlirTypeIsATransformAnyValueType
,
61 mlirTransformAnyValueTypeGetTypeID
);
62 anyValueType
.def_classmethod(
64 [](py::object cls
, MlirContext ctx
) {
65 return cls(mlirTransformAnyValueTypeGet(ctx
));
67 "Get an instance of AnyValueType in the given context.", py::arg("cls"),
68 py::arg("context") = py::none());
70 //===-------------------------------------------------------------------===//
72 //===-------------------------------------------------------------------===//
75 mlir_type_subclass(m
, "OperationType", mlirTypeIsATransformOperationType
,
76 mlirTransformOperationTypeGetTypeID
);
77 operationType
.def_classmethod(
79 [](py::object cls
, const std::string
&operationName
, MlirContext ctx
) {
80 MlirStringRef cOperationName
=
81 mlirStringRefCreate(operationName
.data(), operationName
.size());
82 return cls(mlirTransformOperationTypeGet(ctx
, cOperationName
));
84 "Get an instance of OperationType for the given kind in the given "
86 py::arg("cls"), py::arg("operation_name"),
87 py::arg("context") = py::none());
88 operationType
.def_property_readonly(
91 MlirStringRef operationName
=
92 mlirTransformOperationTypeGetOperationName(type
);
93 return py::str(operationName
.data
, operationName
.length
);
95 "Get the name of the payload operation accepted by the handle.");
97 //===-------------------------------------------------------------------===//
99 //===-------------------------------------------------------------------===//
102 mlir_type_subclass(m
, "ParamType", mlirTypeIsATransformParamType
,
103 mlirTransformParamTypeGetTypeID
);
104 paramType
.def_classmethod(
106 [](py::object cls
, MlirType type
, MlirContext ctx
) {
107 return cls(mlirTransformParamTypeGet(ctx
, type
));
109 "Get an instance of ParamType for the given type in the given context.",
110 py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
111 paramType
.def_property_readonly(
114 MlirType paramType
= mlirTransformParamTypeGetType(type
);
117 "Get the type this ParamType is associated with.");
120 PYBIND11_MODULE(_mlirDialectsTransform
, m
) {
121 m
.doc() = "MLIR Transform dialect.";
122 populateDialectTransformSubmodule(m
);