[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / traits.mlir
blob4d583435adeeee6cee00b07f2346f78e01249c24
1 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
3 // Verify that ops with broadcastable trait verifies operand and result type
4 // combinations and emits an error for invalid combinations.
6 func.func @broadcast_scalar_scalar_scalar(tensor<i32>, tensor<i32>) -> tensor<i32> {
7 ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
8   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
9   return %0 : tensor<i32>
12 // -----
14 func.func @broadcast_tensor_scalar_tensor(tensor<4xi32>, tensor<i32>) -> tensor<4xi32> {
15 ^bb0(%arg0: tensor<4xi32>, %arg1: tensor<i32>):
16   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<i32>) -> tensor<4xi32>
17   return %0 : tensor<4xi32>
20 // -----
22 // Check only one dimension has size 1
23 func.func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32> {
24 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
25   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x2xi32>
26   return %0 : tensor<4x3x2xi32>
29 // -----
31 // Check multiple dimensions have size 1
32 func.func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32> {
33 ^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
34   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x5xi32>
35   return %0 : tensor<8x7x6x5xi32>
38 // -----
40 // Check leading unknown dimension
41 func.func @broadcast_tensor_tensor_tensor(tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
42 ^bb0(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
43   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
44   return %0 : tensor<?x7x6x5xi32>
47 // -----
49 // Check unknown dimension in the middle
50 func.func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32> {
51 ^bb0(%arg0: tensor<8x1x?x1xi32>, %arg1: tensor<7x1x5xi32>):
52   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x?x5xi32>
53   return %0 : tensor<8x7x?x5xi32>
56 // -----
58 // Check incompatible vector and tensor result type
59 func.func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
60 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
61   // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
62   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
63   return %0 : vector<4xf32>
66 // -----
68 // Check incompatible operand types with known dimension
69 func.func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32> {
70 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x3xi32>):
71   // expected-error @+1 {{operands don't have broadcast-compatible shapes}}
72   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<3x3xi32>) -> tensor<4x3x2xi32>
73   return %0 : tensor<4x3x2xi32>
76 // -----
78 // Check incompatible result type with known dimension
79 func.func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32> {
80 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
81   // expected-error @+1 {{op result type '4x3x3' not broadcast compatible with broadcasted operands's shapes '4x3x2'}}
82   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
83   return %0 : tensor<4x3x3xi32>
86 // -----
88 // Check incompatible result type with known dimension
89 func.func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
90 ^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
91   // expected-error @+1 {{op result type '8x7x6x1' not broadcast compatible with broadcasted operands's shapes '8x7x6x5'}}
92   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
93   return %0 : tensor<8x7x6x1xi32>
96 // -----
98 func.func @broadcast_tensor_tensor_tensor(tensor<2xi32>, tensor<2xi32>) -> tensor<*xi32> {
99 ^bb0(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>):
100   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<*xi32>
101   return %0 : tensor<*xi32>
104 // -----
106 func.func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32> {
107 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<?xi32>):
108   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<?xi32>) -> tensor<4x3x2xi32>
109   return %0 : tensor<4x3x2xi32>
112 // -----
114 // It is alright to have an implicit dynamic-to-static cast in a dimension size
115 // as long as the runtime result size is consistent with the result tensor's
116 // static dimension.
117 func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<2xi32> {
118   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
119   return %0 : tensor<2xi32>
122 // -----
124 func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x6x1xi32>, %arg1: tensor<*xi32>) -> tensor<?x6x?xi32> {
125   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x6x1xi32>, tensor<*xi32>) -> tensor<?x6x?xi32>
126   return %0 : tensor<?x6x?xi32>
129 // -----
131 // Unranked operands but ranked result
132 func.func @broadcast_tensor_tensor_tensor(tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32> {
133 ^bb0(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>):
134   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<2xi32>
135   return %0 : tensor<2xi32>
138 // -----
140 // Unranked operand and compatible ranked result
141 func.func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
142 ^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
143   %0 = "test.broadcastable"(%arg0, %arg0, %arg1) : (tensor<3x2xi32>, tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
144   return %0 : tensor<4x3x2xi32>
147 // -----
149 func.func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32> {
150 ^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
151   // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '3x2'}}
152   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32>
153   return %0 : tensor<2xi32>
156 // -----
158 // Correct use of broadcast semantics for input dimensions
159 func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?x1x6x1xi32>, %arg1: tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32> {
160   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<?x7x6x5xi32>
161   return %0 : tensor<?x7x6x5xi32>
164 // -----
166 // Incorrect attempt to use broadcast semantics for result
167 func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<5xi32> {
168   // expected-error @+1 {{op result type '5' not broadcast compatible with broadcasted operands's shapes '1'}}
169   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<5xi32>
170   return %0 : tensor<5xi32>
173 // -----
175 func.func @broadcastDifferentResultType(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> {
176 ^bb0(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>):
177   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
178   return %0 : tensor<4xi1>