1 // RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
2 // RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
4 include "mlir/IR/OpBase.td"
5 include "mlir/Interfaces/InferTypeOpInterface.td"
// Test dialect that owns every op defined in this file (via NS_Op).
// Inherent attributes are not stored as properties here
// (usePropertiesForAttributes = 0). NOTE(review): presumably this keeps the
// generated builders on the attribute-dictionary path matched by the checks
// below (e.g. odsState.attributes.getDictionary(...)) -- confirm against ODS docs.
7 def Test_Dialect : Dialect {
9 let usePropertiesForAttributes = 0;
// Base class for the test ops below: forwards the mnemonic and trait list to
// Op and binds the op to Test_Dialect.
11 class NS_Op<string mnemonic, list<Trait> traits> :
12 Op<Test_Dialect, mnemonic, traits>;
// OpA: a single fixed I32 result and no declared operands.
// The collective builder (TypeRange resultTypes, ValueRange operands) must
// assert that exactly one result type was passed, then add all of them.
14 def OpA : NS_Op<"one_normal_result_op", []> {
15 let results = (outs I32:$result);
18 // CHECK-LABEL: void OpA::build
19 // CHECK: ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands
20 // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
21 // CHECK-NEXT: odsState.addTypes(resultTypes);
// OpB: SameOperandsAndResultType with one I32 operand and one I32 result.
// Two builders are expected:
//  - one taking the result type `y` explicitly, which is added as-is;
//  - one taking only the operand `x`, which calls OpB::inferReturnTypes and,
//    on success, adds the inferred types to the state.
23 def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
24 let arguments = (ins I32:$x);
25 let results = (outs I32:$y);
28 // CHECK-LABEL: OpB definitions
29 // CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type y, ::mlir::Value x)
30 // CHECK: odsState.addTypes(y);
31 // CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value x)
32 // CHECK: ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
33 // CHECK: if (::mlir::succeeded(OpB::inferReturnTypes(odsBuilder.getContext(),
34 // CHECK: odsState.location, odsState.operands,
35 // CHECK: odsState.attributes.getDictionary(odsState.getContext()),
36 // CHECK: odsState.regions, inferredReturnTypes)))
37 // CHECK: odsState.addTypes(inferredReturnTypes);
// OpC: three fixed I32 results, the middle one intentionally unnamed.
// The separate-parameter builder must give the unnamed result a synthesized
// parameter name (`resultType1`, matching its position) and add each type in
// declaration order. The TypeRange builder must assert exactly 3 types.
39 def OpC : NS_Op<"three_normal_result_op", []> {
40 let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
43 // CHECK-LABEL: OpC definitions
44 // CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::mlir::Type resultType1, ::mlir::Type z)
45 // CHECK-NEXT: odsState.addTypes(x)
46 // CHECK-NEXT: odsState.addTypes(resultType1)
47 // CHECK-NEXT: odsState.addTypes(z)
49 // CHECK: void OpC::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes) {
50 // CHECK-NEXT: assert(resultTypes.size() == 3u && "mismatched number of results");
51 // CHECK-NEXT: odsState.addTypes(resultTypes);
// IntegerTypeAttr: a type-attribute constraint wrapping an IntegerType.
// OpD: FirstAttrDerivedResultType where the first attribute is that type
// attribute. The result type is read from the attribute itself: cast to
// TypeAttr, then its wrapped type (getValue()). The trailing F32Attr is not
// the first attribute, so it does not drive the result type.
53 def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
54 def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
55 let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
56 let results = (outs AnyTensor:$y);
59 // CHECK-LABEL: OpD definitions
60 // CHECK: void OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
61 // CHECK: odsState.addTypes({::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()});
// OpE: FirstAttrDerivedResultType where the first attribute is a value
// attribute (F32Attr), not a type attribute. Here the result type is the
// attribute's own type, obtained through TypedAttr::getType().
63 def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
64 let arguments = (ins I32:$x, F32Attr:$attr);
65 let results = (outs AnyTensor:$y);
68 // CHECK-LABEL: OpE definitions
69 // CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
70 // CHECK: odsState.addTypes({::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()});
// OpF: a single variadic result. The builder takes the result types as a
// TypeRange parameter named after the result (`x`) and adds it directly.
72 def OpF : NS_Op<"one_variadic_result_op", []> {
73 let results = (outs Variadic<I32>:$x);
76 // CHECK-LABEL: void OpF::build
77 // CHECK-SAME: ::mlir::TypeRange x
79 // CHECK: odsState.addTypes(x);
// OpG: one fixed result followed by one variadic result.
//  - The separate-parameter builder takes `Type x` and `TypeRange y` and adds
//    them in order.
//  - The collective TypeRange builder can only bound the count from below, so
//    it asserts at least 1 type (the fixed result).
81 def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
83 let results = (outs I32:$x, Variadic<I32>:$y);
86 // CHECK-LABEL: OpG definitions
88 // CHECK: void OpG::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type x, ::mlir::TypeRange y)
89 // CHECK-NEXT: odsState.addTypes(x);
90 // CHECK-NEXT: odsState.addTypes(y);
92 // CHECK: void OpG::build
93 // CHECK: ::mlir::TypeRange resultTypes
94 // CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types");
95 // CHECK-NEXT: odsState.addTypes(resultTypes);
// OpI: SameVariadicResultSize with a fixed tensor result between two variadic
// tensor results.
//  - Declaration output: getOutput1 returns the whole ODS result group 0;
//    getOutput2 returns the first value of ODS group 1, cast to
//    TypedValue<TensorType>.
//  - Definition output: the builder adds each group's types in order.
97 def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
98 let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Variadic<AnyTensor>:$output3);
101 // DECL-LABEL: ::mlir::Operation::result_range getOutput1
102 // DECL-NEXT: return getODSResults(0);
104 // DECL-LABEL: ::mlir::TypedValue<::mlir::TensorType> getOutput2
105 // DECL-NEXT: return ::llvm::cast<::mlir::TypedValue<::mlir::TensorType>>(*getODSResults(1).begin());
107 // CHECK-LABEL: OpI::build
108 // CHECK-NEXT: odsState.addTypes(output1);
109 // CHECK-NEXT: odsState.addTypes(output2);
110 // CHECK-NEXT: odsState.addTypes(output3);
112 // Test that if the only operand is variadic, the first value in the pack is
113 // used to set the result type.
// OpK: SameOperandsAndResultType where the only operand is variadic.
// The generic (operands, attributes) builder must take the result type from
// the first operand in the pack, i.e. operands[0].getType().
115 def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsAndResultType]> {
116 let arguments = (ins Variadic<AnyTensor>:$input);
117 let results = (outs AnyTensor:$result);
120 // CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
121 // CHECK: odsState.addTypes({operands[0].getType()});
123 // Test inferred result types from type constraints, with operands interleaved with attributes.
// OpL1: AllTypesMatch ties result `b` to operand `a`, which is declared after
// the attribute `attr1`.
//  - The attribute must not shift the operand index: `a` is still operands[0].
//  - inferReturnTypes must fail when that operand is absent.
//  - It then copies operands[0]'s type into result 0.
125 def OpL1 : NS_Op<"op_with_all_types_constraint",
126 [AllTypesMatch<["a", "b"]>]> {
127 let arguments = (ins I32Attr:$attr1, AnyType:$a);
128 let results = (outs Res<AnyType, "output b", []>:$b);
131 // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
133 // CHECK: if (operands.size() <= 0)
134 // CHECK-NEXT: return ::mlir::failure();
135 // CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
136 // CHECK: inferredReturnTypes[0] = odsInferredType0;
// OpL2: two independent AllTypesMatch constraints:
//  - result `b` takes the type of operand `c` (operand index 2);
//  - result `d` takes the type of operand `a` (operand index 0).
// Only a single bounds guard, on the highest operand index used (<= 2), is
// expected; a separate guard for index 0 must not be emitted.
// NOTE(review): OpL1, OpL2 and OpL3 share the mnemonic
// "op_with_all_types_constraint"; the generator output checked here does not
// depend on it, but confirm if a uniqueness check is ever added.
138 def OpL2 : NS_Op<"op_with_all_types_constraint",
139 [AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> {
140 let arguments = (ins I32Attr:$attr1, AnyType:$a, AnyType:$a2, AnyType:$c);
141 let results = (outs Res<AnyType, "output b", []>:$b, AnyType:$d);
144 // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
146 // CHECK: if (operands.size() <= 2)
147 // CHECK-NEXT: return ::mlir::failure();
148 // CHECK-NOT: if (operands.size() <= 0)
149 // CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
150 // CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
151 // CHECK: inferredReturnTypes[0] = odsInferredType0;
152 // CHECK: inferredReturnTypes[1] = odsInferredType1;
// OpL3: AllTypesMatch ties result `b` to `a`, which is an attribute (I32Attr),
// not an operand. The inferred result type must come from the attribute's
// type (odsInferredTypeAttr0.getType()) rather than from an operand.
154 def OpL3 : NS_Op<"op_with_all_types_constraint",
155 [AllTypesMatch<["a", "b"]>]> {
156 let arguments = (ins I32Attr:$a);
157 let results = (outs AnyType:$b);
160 // CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
162 // CHECK: ::mlir::Type odsInferredType0 = odsInferredTypeAttr0.getType();
163 // CHECK: inferredReturnTypes[0] = odsInferredType0;
// OpL4: a chain of TypesMatchWith edges: input -> a (fromInput), a -> b
// (infer0), b -> c (infer1). The traits list the edges out of dependency
// order. inferReturnTypes must still resolve them in dependency order, with
// each transform applied to the previously inferred type, before assigning
// all three results.
165 def OpL4 : NS_Op<"two_inference_edges", [
166 TypesMatchWith<"", "a", "b", "infer0($_self)">,
167 TypesMatchWith<"", "b", "c", "infer1($_self)">,
168 TypesMatchWith<"", "input", "a", "fromInput($_self)">]> {
169 let arguments = (ins I32:$input);
170 let results = (outs AnyType:$a, AnyType:$b, AnyType:$c);
173 // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
174 // CHECK: if (operands.size() <= 0)
175 // CHECK-NEXT: return ::mlir::failure();
176 // CHECK: odsInferredType0 = fromInput(operands[0].getType())
177 // CHECK: odsInferredType1 = infer0(odsInferredType0)
178 // CHECK: odsInferredType2 = infer1(odsInferredType1)
179 // CHECK: inferredReturnTypes[0] = odsInferredType0
180 // CHECK: inferredReturnTypes[1] = odsInferredType1
181 // CHECK: inferredReturnTypes[2] = odsInferredType2
// OpM: AttrSizedResultSegments with variadic, fixed and optional results.
// The builder takes the optional result as a nullable Type. It records the
// per-group sizes as a dense i32 array attribute:
//  - output1.size() for the variadic group;
//  - 1 for the fixed result;
//  - 1 or 0 depending on whether output3 was provided.
183 def OpM : NS_Op<"mix_diff_size_variadic_and_normal_results_op", [AttrSizedResultSegments]> {
184 let results = (outs Variadic<AnyTensor>:$output1, AnyTensor:$output2, Optional<AnyTensor>:$output3);
187 // CHECK-LABEL: OpM::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange output1, ::mlir::Type output2, /*optional*/::mlir::Type output3)
188 // CHECK: odsState.addAttribute(getResultSegmentSizesAttrName(odsState.name), odsBuilder.getDenseI32ArrayAttr({static_cast<int32_t>(output1.size()), 1, (output3 ? 1 : 0)}));