//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/Dialect/Transform.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

#include <string>

namespace nb = nanobind;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;
22 void populateDialectTransformSubmodule(const nb::module_
&m
) {
23 //===-------------------------------------------------------------------===//
25 //===-------------------------------------------------------------------===//
28 mlir_type_subclass(m
, "AnyOpType", mlirTypeIsATransformAnyOpType
,
29 mlirTransformAnyOpTypeGetTypeID
);
30 anyOpType
.def_classmethod(
32 [](nb::object cls
, MlirContext ctx
) {
33 return cls(mlirTransformAnyOpTypeGet(ctx
));
35 "Get an instance of AnyOpType in the given context.", nb::arg("cls"),
36 nb::arg("context").none() = nb::none());
38 //===-------------------------------------------------------------------===//
40 //===-------------------------------------------------------------------===//
43 mlir_type_subclass(m
, "AnyParamType", mlirTypeIsATransformAnyParamType
,
44 mlirTransformAnyParamTypeGetTypeID
);
45 anyParamType
.def_classmethod(
47 [](nb::object cls
, MlirContext ctx
) {
48 return cls(mlirTransformAnyParamTypeGet(ctx
));
50 "Get an instance of AnyParamType in the given context.", nb::arg("cls"),
51 nb::arg("context").none() = nb::none());
53 //===-------------------------------------------------------------------===//
55 //===-------------------------------------------------------------------===//
58 mlir_type_subclass(m
, "AnyValueType", mlirTypeIsATransformAnyValueType
,
59 mlirTransformAnyValueTypeGetTypeID
);
60 anyValueType
.def_classmethod(
62 [](nb::object cls
, MlirContext ctx
) {
63 return cls(mlirTransformAnyValueTypeGet(ctx
));
65 "Get an instance of AnyValueType in the given context.", nb::arg("cls"),
66 nb::arg("context").none() = nb::none());
68 //===-------------------------------------------------------------------===//
70 //===-------------------------------------------------------------------===//
73 mlir_type_subclass(m
, "OperationType", mlirTypeIsATransformOperationType
,
74 mlirTransformOperationTypeGetTypeID
);
75 operationType
.def_classmethod(
77 [](nb::object cls
, const std::string
&operationName
, MlirContext ctx
) {
78 MlirStringRef cOperationName
=
79 mlirStringRefCreate(operationName
.data(), operationName
.size());
80 return cls(mlirTransformOperationTypeGet(ctx
, cOperationName
));
82 "Get an instance of OperationType for the given kind in the given "
84 nb::arg("cls"), nb::arg("operation_name"),
85 nb::arg("context").none() = nb::none());
86 operationType
.def_property_readonly(
89 MlirStringRef operationName
=
90 mlirTransformOperationTypeGetOperationName(type
);
91 return nb::str(operationName
.data
, operationName
.length
);
93 "Get the name of the payload operation accepted by the handle.");
95 //===-------------------------------------------------------------------===//
97 //===-------------------------------------------------------------------===//
100 mlir_type_subclass(m
, "ParamType", mlirTypeIsATransformParamType
,
101 mlirTransformParamTypeGetTypeID
);
102 paramType
.def_classmethod(
104 [](nb::object cls
, MlirType type
, MlirContext ctx
) {
105 return cls(mlirTransformParamTypeGet(ctx
, type
));
107 "Get an instance of ParamType for the given type in the given context.",
108 nb::arg("cls"), nb::arg("type"), nb::arg("context").none() = nb::none());
109 paramType
.def_property_readonly(
112 MlirType paramType
= mlirTransformParamTypeGetType(type
);
115 "Get the type this ParamType is associated with.");
118 NB_MODULE(_mlirDialectsTransform
, m
) {
119 m
.doc() = "MLIR Transform dialect.";
120 populateDialectTransformSubmodule(m
);