// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s

// This test covers the Integer Dot Product ops defined in the
// SPV_KHR_integer_dot_product extension.

//===----------------------------------------------------------------------===//
// spirv.SDot
//===----------------------------------------------------------------------===//

// Valid forms. Scalar cases pass i32 operands packed as
// PackedVectorFormat4x8Bit; vector cases pass integer vectors and no format
// attribute. Both i32 and i64 results are exercised.

// CHECK: @sdot_scalar_i32
func.func @sdot_scalar_i32(%a: i32, %b: i32) -> i32 {
  // CHECK-NEXT: spirv.SDot
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK: @sdot_scalar_i64
func.func @sdot_scalar_i64(%a: i32, %b: i32) -> i64 {
  // CHECK-NEXT: spirv.SDot
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK: @sdot_vector_4xi8
func.func @sdot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.SDot
  %r = spirv.SDot %a, %b : vector<4xi8> -> i32
  return %r : i32
}

// CHECK: @sdot_vector_4xi16
func.func @sdot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.SDot
  %r = spirv.SDot %a, %b : vector<4xi16> -> i64
  return %r : i64
}

// CHECK: @sdot_vector_8xi8
func.func @sdot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.SDot
  %r = spirv.SDot %a, %b : vector<8xi8> -> i64
  return %r : i64
}

// -----

// Invalid forms. Each case sits in its own split so that its diagnostic
// cannot affect the other cases.

// Both operands take the single type written before '->'.
// expected-note @+1 {{prior use here}}
func.func @sdot_scalar_bad_types(%a: i32, %b: i64) -> i32 {
  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
  %r = spirv.SDot %a, %b : i32 -> i32
  return %r : i32
}

// -----

// A packed vector format is rejected for vector operands.
func.func @sdot_vector_4xi8_bad_attr(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // expected-error @+1 {{op with invalid format attribute for vector operands of type 'vector<4xi8>'}}
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : vector<4xi8> -> i32
  return %r : i32
}

// -----

// The result must be at least as wide as the packed operands.
func.func @sdot_scalar_bad_result_width(%a: i32, %b: i32) -> i16 {
  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i16
  return %r : i16
}

// -----

// PackedVectorFormat4x8Bit needs 32-bit scalar operands.
func.func @sdot_scalar_bad_operand_width(%a: i64, %b: i64) -> i64 {
  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
  %r = spirv.SDot %a, %b, <PackedVectorFormat4x8Bit> : i64 -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SUDot
//===----------------------------------------------------------------------===//

// Valid forms: packed i32 scalars (PackedVectorFormat4x8Bit) and integer
// vectors without a format attribute, with i32 and i64 results.

// CHECK: @sudot_scalar_i32
func.func @sudot_scalar_i32(%a: i32, %b: i32) -> i32 {
  // CHECK-NEXT: spirv.SUDot
  %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK: @sudot_scalar_i64
func.func @sudot_scalar_i64(%a: i32, %b: i32) -> i64 {
  // CHECK-NEXT: spirv.SUDot
  %r = spirv.SUDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK: @sudot_vector_4xi8
func.func @sudot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.SUDot
  %r = spirv.SUDot %a, %b : vector<4xi8> -> i32
  return %r : i32
}

// CHECK: @sudot_vector_4xi16
func.func @sudot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.SUDot
  %r = spirv.SUDot %a, %b : vector<4xi16> -> i64
  return %r : i64
}

// CHECK: @sudot_vector_8xi8
func.func @sudot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.SUDot
  %r = spirv.SUDot %a, %b : vector<8xi8> -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.UDot
//===----------------------------------------------------------------------===//

// Valid forms: packed i32 scalars (PackedVectorFormat4x8Bit) and integer
// vectors without a format attribute, with i32 and i64 results. The vector
// shapes match those exercised for the other dot product ops in this file.

// CHECK: @udot_scalar_i32
func.func @udot_scalar_i32(%a: i32, %b: i32) -> i32 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK: @udot_scalar_i64
func.func @udot_scalar_i64(%a: i32, %b: i32) -> i64 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK: @udot_vector_4xi8
func.func @udot_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>) -> i32 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b : vector<4xi8> -> i32
  return %r : i32
}

// CHECK: @udot_vector_4xi16
func.func @udot_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>) -> i64 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b : vector<4xi16> -> i64
  return %r : i64
}

// CHECK: @udot_vector_8xi8
func.func @udot_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>) -> i64 {
  // CHECK-NEXT: spirv.UDot
  %r = spirv.UDot %a, %b : vector<8xi8> -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SDotAccSat
//===----------------------------------------------------------------------===//

// Valid forms. The accumulator operand always has the result type; operands
// are either packed i32 scalars (PackedVectorFormat4x8Bit) or integer vectors
// without a format attribute.

// CHECK: @sdot_acc_sat_scalar_i32
func.func @sdot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SDotAccSat
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK: @sdot_acc_sat_scalar_i64
func.func @sdot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK: @sdot_acc_sat_vector_4xi8
func.func @sdot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SDotAccSat
  %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi8> -> i32
  return %r : i32
}

// CHECK: @sdot_acc_sat_vector_4xi16
func.func @sdot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat
  %r = spirv.SDotAccSat %a, %b, %acc : vector<4xi16> -> i64
  return %r : i64
}

// CHECK: @sdot_acc_sat_vector_8xi8
func.func @sdot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SDotAccSat
  %r = spirv.SDotAccSat %a, %b, %acc : vector<8xi8> -> i64
  return %r : i64
}

// -----

// Invalid forms. Each case sits in its own split so that its diagnostic
// cannot affect the other cases.

// Both operands take the single type written before '->'.
// expected-note @+1 {{prior use here}}
func.func @sdot_acc_sat_scalar_bad_types(%a: i32, %b: i64, %acc : i32) -> i32 {
  // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'i32' vs 'i64'}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// -----

// The result must be at least as wide as the packed operands.
func.func @sdot_acc_sat_scalar_bad_result_width(%a: i32, %b: i32, %acc : i16) -> i16 {
  // expected-error @+1 {{op result type has insufficient bit-width (16 bits) for the specified vector operand type (32 bits)}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i16
  return %r : i16
}

// -----

// PackedVectorFormat4x8Bit needs 32-bit scalar operands.
func.func @sdot_acc_sat_scalar_bad_operand_width(%a: i64, %b: i64, %acc : i64) -> i64 {
  // expected-error @+1 {{op with specified Packed Vector Format (PackedVectorFormat4x8Bit) requires integer vector operands to be 32-bits wide}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i64 -> i64
  return %r : i64
}

// -----

// The accumulator takes the result type written after '->'.
// expected-note @+1 {{prior use here}}
func.func @sdot_acc_sat_scalar_bad_accumulator(%a: i32, %b: i32, %acc : i32) -> i64 {
  // expected-error @+1 {{use of value '%acc' expects different type than prior uses: 'i64' vs 'i32'}}
  %r = spirv.SDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SUDotAccSat
//===----------------------------------------------------------------------===//

// Valid forms. The accumulator operand always has the result type; operands
// are either packed i32 scalars (PackedVectorFormat4x8Bit) or integer vectors
// without a format attribute.

// CHECK: @sudot_acc_sat_scalar_i32
func.func @sudot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK: @sudot_acc_sat_scalar_i64
func.func @sudot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %r = spirv.SUDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK: @sudot_acc_sat_vector_4xi8
func.func @sudot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi8> -> i32
  return %r : i32
}

// CHECK: @sudot_acc_sat_vector_4xi16
func.func @sudot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %r = spirv.SUDotAccSat %a, %b, %acc : vector<4xi16> -> i64
  return %r : i64
}

// CHECK: @sudot_acc_sat_vector_8xi8
func.func @sudot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.SUDotAccSat
  %r = spirv.SUDotAccSat %a, %b, %acc : vector<8xi8> -> i64
  return %r : i64
}

// -----

//===----------------------------------------------------------------------===//
// spirv.UDotAccSat
//===----------------------------------------------------------------------===//

// Valid forms. The accumulator operand always has the result type; operands
// are either packed i32 scalars (PackedVectorFormat4x8Bit) or integer vectors
// without a format attribute.

// CHECK: @udot_acc_sat_scalar_i32
func.func @udot_acc_sat_scalar_i32(%a: i32, %b: i32, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.UDotAccSat
  %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i32
  return %r : i32
}

// CHECK: @udot_acc_sat_scalar_i64
func.func @udot_acc_sat_scalar_i64(%a: i32, %b: i32, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat
  %r = spirv.UDotAccSat %a, %b, %acc, <PackedVectorFormat4x8Bit> : i32 -> i64
  return %r : i64
}

// CHECK: @udot_acc_sat_vector_4xi8
func.func @udot_acc_sat_vector_4xi8(%a: vector<4xi8>, %b: vector<4xi8>, %acc : i32) -> i32 {
  // CHECK-NEXT: spirv.UDotAccSat
  %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi8> -> i32
  return %r : i32
}

// CHECK: @udot_acc_sat_vector_4xi16
func.func @udot_acc_sat_vector_4xi16(%a: vector<4xi16>, %b: vector<4xi16>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat
  %r = spirv.UDotAccSat %a, %b, %acc : vector<4xi16> -> i64
  return %r : i64
}

// CHECK: @udot_acc_sat_vector_8xi8
func.func @udot_acc_sat_vector_8xi8(%a: vector<8xi8>, %b: vector<8xi8>, %acc : i64) -> i64 {
  // CHECK-NEXT: spirv.UDotAccSat
  %r = spirv.UDotAccSat %a, %b, %acc : vector<8xi8> -> i64
  return %r : i64
}