[LLVM] Fix Maintainers.md formatting (NFC)
[llvm-project.git] / mlir / lib / Bindings / Python / DialectTransform.cpp
blob 6b57e652aa9d8b0e18a66436501fe470c3206981
//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
9 #include "mlir-c/Dialect/Transform.h"
10 #include "mlir-c/IR.h"
11 #include "mlir-c/Support.h"
12 #include "mlir/Bindings/Python/PybindAdaptors.h"
13 #include <pybind11/cast.h>
14 #include <pybind11/detail/common.h>
15 #include <pybind11/pybind11.h>
16 #include <pybind11/pytypes.h>
17 #include <string>
19 namespace py = pybind11;
20 using namespace mlir;
21 using namespace mlir::python;
22 using namespace mlir::python::adaptors;
24 void populateDialectTransformSubmodule(const pybind11::module &m) {
25 //===-------------------------------------------------------------------===//
26 // AnyOpType
27 //===-------------------------------------------------------------------===//
29 auto anyOpType =
30 mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType,
31 mlirTransformAnyOpTypeGetTypeID);
32 anyOpType.def_classmethod(
33 "get",
34 [](py::object cls, MlirContext ctx) {
35 return cls(mlirTransformAnyOpTypeGet(ctx));
37 "Get an instance of AnyOpType in the given context.", py::arg("cls"),
38 py::arg("context") = py::none());
40 //===-------------------------------------------------------------------===//
41 // AnyParamType
42 //===-------------------------------------------------------------------===//
44 auto anyParamType =
45 mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType,
46 mlirTransformAnyParamTypeGetTypeID);
47 anyParamType.def_classmethod(
48 "get",
49 [](py::object cls, MlirContext ctx) {
50 return cls(mlirTransformAnyParamTypeGet(ctx));
52 "Get an instance of AnyParamType in the given context.", py::arg("cls"),
53 py::arg("context") = py::none());
55 //===-------------------------------------------------------------------===//
56 // AnyValueType
57 //===-------------------------------------------------------------------===//
59 auto anyValueType =
60 mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType,
61 mlirTransformAnyValueTypeGetTypeID);
62 anyValueType.def_classmethod(
63 "get",
64 [](py::object cls, MlirContext ctx) {
65 return cls(mlirTransformAnyValueTypeGet(ctx));
67 "Get an instance of AnyValueType in the given context.", py::arg("cls"),
68 py::arg("context") = py::none());
70 //===-------------------------------------------------------------------===//
71 // OperationType
72 //===-------------------------------------------------------------------===//
74 auto operationType =
75 mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType,
76 mlirTransformOperationTypeGetTypeID);
77 operationType.def_classmethod(
78 "get",
79 [](py::object cls, const std::string &operationName, MlirContext ctx) {
80 MlirStringRef cOperationName =
81 mlirStringRefCreate(operationName.data(), operationName.size());
82 return cls(mlirTransformOperationTypeGet(ctx, cOperationName));
84 "Get an instance of OperationType for the given kind in the given "
85 "context",
86 py::arg("cls"), py::arg("operation_name"),
87 py::arg("context") = py::none());
88 operationType.def_property_readonly(
89 "operation_name",
90 [](MlirType type) {
91 MlirStringRef operationName =
92 mlirTransformOperationTypeGetOperationName(type);
93 return py::str(operationName.data, operationName.length);
95 "Get the name of the payload operation accepted by the handle.");
97 //===-------------------------------------------------------------------===//
98 // ParamType
99 //===-------------------------------------------------------------------===//
101 auto paramType =
102 mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType,
103 mlirTransformParamTypeGetTypeID);
104 paramType.def_classmethod(
105 "get",
106 [](py::object cls, MlirType type, MlirContext ctx) {
107 return cls(mlirTransformParamTypeGet(ctx, type));
109 "Get an instance of ParamType for the given type in the given context.",
110 py::arg("cls"), py::arg("type"), py::arg("context") = py::none());
111 paramType.def_property_readonly(
112 "type",
113 [](MlirType type) {
114 MlirType paramType = mlirTransformParamTypeGetType(type);
115 return paramType;
117 "Get the type this ParamType is associated with.");
120 PYBIND11_MODULE(_mlirDialectsTransform, m) {
121 m.doc() = "MLIR Transform dialect.";
122 populateDialectTransformSubmodule(m);