[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / SPIRV / IR / integer-dot-product-ops.mlir
blobb04e5603019b41418c3cd1c6d225a81238dca72d
1 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
3 // This test covers the Integer Dot Product ops defined in the
4 // SPV_KHR_integer_dot_product extension.
6 //===----------------------------------------------------------------------===//
7 // spirv.SDot
8 //===----------------------------------------------------------------------===//
10 // CHECK: @sdot_scalar_i32
11 func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 {
12   // CHECK-NEXT: spirv.SDot
13   %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
14   return %r : i32
17 // CHECK: @sdot_scalar_i64
18 func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 {
19   // CHECK-NEXT: spirv.SDot
20   %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
21   return %r : i64
24 // CHECK: @sdot_vector_4xi8
25 func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
26   // CHECK-NEXT: spirv.SDot
27   %r = spirv.SDot %a, %b : vector<4xi8> -> i32
28   return %r : i32
31 // CHECK: @sdot_vector_4xi16
32 func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
33   // CHECK-NEXT: spirv.SDot
34   %r = spirv.SDot %a, %b : vector<4xi16> -> i64
35   return %r : i64
38 // CHECK: @sdot_vector_8xi8
39 func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
40   // CHECK-NEXT: spirv.SDot
41   %r = spirv.SDot %a, %b : vector<8xi8> -> i64
42   return %r : i64
45 // -----
47 // expected-note @+1 {{prior use here}}
48 func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
49   // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
50   %r = spirv.SDot %a, %b : i32 -> i32
51   return %r : i32
53 // -----
55 func.func @sdot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
56   // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
57   %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : vector<4xi8> -> i32
58   return %r : i32
61 // -----
63 func.func @sdot_scalar_bad_types(%a: i32, %b: i32) -> i16 {
64   // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
65   %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i16
66   return %r : i16
69 // -----
71 func.func @sdot_scalar_bad_types(%a: i64, %b: i64) -> i64 {
72   // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
73   %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i64 -> i64
74   return %r : i64
77 // -----
79 //===----------------------------------------------------------------------===//
80 // spirv.SUDot
81 //===----------------------------------------------------------------------===//
83 // CHECK: @sudot_scalar_i32
84 func.func @sudot_scalar_i32(%a: i32, %b: i32) -> i32 {
85   // CHECK-NEXT: spirv.SUDot
86   %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
87   return %r : i32
90 // CHECK: @sudot_scalar_i64
91 func.func @sudot_scalar_i64(%a: i32, %b: i32) -> i64 {
92   // CHECK-NEXT: spirv.SUDot
93   %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
94   return %r : i64
97 // CHECK: @sudot_vector_4xi8
98 func.func @sudot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
99   // CHECK-NEXT: spirv.SUDot
100   %r = spirv.SUDot %a, %b : vector<4xi8> -> i32
101   return %r : i32
104 // CHECK: @sudot_vector_4xi16
105 func.func @sudot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
106   // CHECK-NEXT: spirv.SUDot
107   %r = spirv.SUDot %a, %b : vector<4xi16> -> i64
108   return %r : i64
111 // CHECK: @sudot_vector_8xi8
112 func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
113   // CHECK-NEXT: spirv.SUDot
114   %r = spirv.SUDot %a, %b : vector<8xi8> -> i64
115   return %r : i64
118 // -----
120 //===----------------------------------------------------------------------===//
121 // spirv.UDot
122 //===----------------------------------------------------------------------===//
124 // CHECK: @udot_scalar_i32
125 func.func @udot_scalar_i32(%a: i32, %b: i32) -> i32 {
126   // CHECK-NEXT: spirv.UDot
127   %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
128   return %r : i32
131 // CHECK: @udot_scalar_i64
132 func.func @udot_scalar_i64(%a: i32, %b: i32) -> i64 {
133   // CHECK-NEXT: spirv.UDot
134   %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
135   return %r : i64
138 // CHECK: @udot_vector_4xi8
139 func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
140   // CHECK-NEXT: spirv.UDot
141   %r = spirv.UDot %a, %b : vector<4xi8> -> i32
142   return %r : i32
145 // -----
147 //===----------------------------------------------------------------------===//
148 // spirv.SDotAccSat
149 //===----------------------------------------------------------------------===//
151 // CHECK: @sdot_acc_sat_scalar_i32
152 func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
153   // CHECK-NEXT: spirv.SDotAccSat
154   %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
155   return %r : i32
158 // CHECK: @sdot_acc_sat_scalar_i64
159 func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
160   // CHECK-NEXT: spirv.SDotAccSat
161   %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
162   return %r : i64
165 // CHECK: @sdot_acc_sat_vector_4xi8
166 func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
167   // CHECK-NEXT: spirv.SDotAccSat
168   %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi8> -> i32
169   return %r : i32
172 // CHECK: @sdot_acc_sat_vector_4xi16
173 func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
174   // CHECK-NEXT: spirv.SDotAccSat
175   %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi16> -> i64
176   return %r : i64
179 // CHECK: @sdot_acc_sat_vector_8xi8
180 func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
181   // CHECK-NEXT: spirv.SDotAccSat
182   %r = spirv.SDotAccSat %a, %b, %acc : vector<8xi8> -> i64
183   return %r : i64
186 // -----
188 // expected-note @+1 {{prior use here}}
189 func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc : i32) -> i32 {
190   // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
191   %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
192   return %r : i32
195 // -----
197 func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i32, %acc : i16) -> i16 {
198   // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
199   %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i16
200   return %r : i16
203 // -----
205 func.func @sdot_acc_sat_scalar_bad_types(%a: i64, %b: i64, %acc : i64) -> i64 {
206   // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
207   %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i64 -> i64
208   return %r : i64
211 // -----
213 // expected-note @+1 {{prior use here}}
214 func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc : i32) -> i64 {
215   // expected-error @+1 {{use of value '%acc' expects different type than prior uses: 'i64' vs 'i32'}}
216   %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
217   return %r : i64
220 // -----
222 //===----------------------------------------------------------------------===//
223 // spirv.SUDotAccSat
224 //===----------------------------------------------------------------------===//
226 // CHECK: @sudot_acc_sat_scalar_i32
227 func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
228   // CHECK-NEXT: spirv.SUDotAccSat
229   %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
230   return %r : i32
233 // CHECK: @sudot_acc_sat_scalar_i64
234 func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
235   // CHECK-NEXT: spirv.SUDotAccSat
236   %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
237   return %r : i64
240 // CHECK: @sudot_acc_sat_vector_4xi8
241 func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
242   // CHECK-NEXT: spirv.SUDotAccSat
243   %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi8> -> i32
244   return %r : i32
247 // CHECK: @sudot_acc_sat_vector_4xi16
248 func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
249   // CHECK-NEXT: spirv.SUDotAccSat
250   %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi16> -> i64
251   return %r : i64
254 // CHECK: @sudot_acc_sat_vector_8xi8
255 func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
256   // CHECK-NEXT: spirv.SUDotAccSat
257   %r = spirv.SUDotAccSat %a, %b, %acc : vector<8xi8> -> i64
258   return %r : i64
261 // -----
263 //===----------------------------------------------------------------------===//
264 // spirv.UDotAccSat
265 //===----------------------------------------------------------------------===//
267 // CHECK: @udot_acc_sat_scalar_i32
268 func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
269   // CHECK-NEXT: spirv.UDotAccSat
270   %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
271   return %r : i32
274 // CHECK: @udot_acc_sat_scalar_i64
275 func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
276   // CHECK-NEXT: spirv.UDotAccSat
277   %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
278   return %r : i64
281 // CHECK: @udot_acc_sat_vector_4xi8
282 func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
283   // CHECK-NEXT: spirv.UDotAccSat
284   %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi8> -> i32
285   return %r : i32
288 // CHECK: @udot_acc_sat_vector_4xi16
289 func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
290   // CHECK-NEXT: spirv.UDotAccSat
291   %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi16> -> i64
292   return %r : i64
295 // CHECK: @udot_acc_sat_vector_8xi8
296 func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
297   // CHECK-NEXT: spirv.UDotAccSat
298   %r = spirv.UDotAccSat %a, %b, %acc : vector<8xi8> -> i64
299   return %r : i64