1 // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
3 //===----------------------------------------------------------------------===//
4 // CooperativeMatrix (KHR) extension ops.
5 //===----------------------------------------------------------------------===//
// CooperativeMatrixLength takes only a matrix type (no SSA operands) and its
// i32 result is returned directly.
7 // CHECK-LABEL: @cooperative_matrix_length
8 spirv.func @cooperative_matrix_length() -> i32 "None" {
9 // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
10 %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
11 spirv.ReturnValue %0 : i32
// Load with only the mandatory layout attribute (<RowMajor>) and an i32 stride.
16 // CHECK-LABEL: @cooperative_matrix_load
17 spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
18 // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
19 // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
20 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
21 !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
// Load with the optional memory operand (<Volatile>) following the layout.
25 // CHECK-LABEL: @cooperative_matrix_load_memoperand
26 spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
27 // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
28 // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
29 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
30 !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
// The pointee may be a vector (vector<4xi32>) rather than a scalar.
34 // CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
35 spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
36 // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
37 // CHECK-SAME: !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
38 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Volatile> :
39 !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
// Load from a Function storage-class pointer into a MatrixAcc-use matrix.
43 // CHECK-LABEL: @cooperative_matrix_load_function
44 spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
45 // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
46 // CHECK-SAME: !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
47 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
48 !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
// The stride operand is not limited to i32; an i16 stride is accepted too.
52 // CHECK-LABEL: @cooperative_matrix_load_stride_i16
53 spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i16) "None" {
54 // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
55 // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
56 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
57 !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
// Store operand order: pointer, matrix object, stride, then the layout
// attribute. The type list follows the same order.
61 // CHECK-LABEL: @cooperative_matrix_store
62 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
63 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
64 // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor> :
65 // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
66 spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
67 !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
// Store with the optional memory operand (<Volatile>) following the layout.
71 // CHECK-LABEL: @cooperative_matrix_store_memoperand
72 spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
73 %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
74 %stride : i32) "None" {
75 // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
76 // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
77 spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Volatile> :
78 !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
// Store with an i16 stride instead of i32.
82 // CHECK-LABEL: @cooperative_matrix_store_stride_i16
83 spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>,
84 %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
85 %stride : i16) "None" {
86 // CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor> :
87 // CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
88 spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor> :
89 !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
// Negative test: a pointer to a struct is rejected by the load op; the
// pointee must be a scalar or vector type.
95 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
96 // expected-error @+1 {{Pointer must point to a scalar or vector type}}
97 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
98 !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
// Negative test: the layout attribute after the stride is mandatory, so
// going straight to the type list is a parse error.
104 spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
105 // expected-error @+1 {{expected ','}}
106 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
107 !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
// Negative test: the load op rejects the 'MakePointerAvailable' memory
// operand. The symbol name mirrors @cooperative_matrix_store_bad_operand.
113 spirv.func @cooperative_matrix_load_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
114 // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
115 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
116 !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
// Negative test: the 'Aligned' memory operand is reported as unhandled by
// the load op.
122 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
123 // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
124 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
125 !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
// Negative test: 'Aligned' is still reported when it is combined with
// Volatile in a single '|'-separated memory operand.
131 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
132 // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
133 %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
134 !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
// Negative test: a store without the layout attribute fails to parse at ':'.
140 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
141 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
142 // expected-error @+1 {{expected ','}}
143 spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride :
144 !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
// Negative test: a trailing comma with no '<...>' layout attribute after it
// is a parse error.
150 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
151 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
152 // expected-error @+1 {{expected '<'}}
153 spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
154 !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
// Negative test: the stored object (operand #1) must be a cooperative
// matrix; passing the i32 stride in that position is rejected.
160 spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
161 %stride : i32) "None" {
162 // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
163 spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, <RowMajor> :
164 !spirv.ptr<i32, StorageBuffer>, i32, i32
// Negative test: the store op rejects the 'MakePointerVisible' memory operand.
170 spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
171 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
172 // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
173 spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <MakePointerVisible> :
174 !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
// Negative test: the 'Aligned' memory operand is reported as unhandled by
// the store op. Named after @cooperative_matrix_load_aligned instead of
// reusing the name of the valid @cooperative_matrix_store case.
180 spirv.func @cooperative_matrix_store_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
181 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
182 // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
183 spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
184 !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
// MulAdd shapes: A is 8x16, B is 16x4, and the accumulator C / result is 8x4.
// The dimension names M, N, K used by the negative tests refer to these.
// No output pattern is attached, so this case only exercises parsing and
// verification.
190 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
191 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
192 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
193 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
194 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
195 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
196 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// The optional Matrix Operands attribute follows the accumulator. It may be
// a single flag (<AccSat>) or several flags joined with '|'.
200 spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
201 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
202 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
203 %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
204 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
205 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
206 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
207 %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned> :
208 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
209 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
210 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
211 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned | AccSat> :
212 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
213 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
214 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// MulAdd on square f32 matrices.
218 spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
219 %b : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB>,
220 %c : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>) "None" {
221 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
222 !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
223 !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
224 !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
// Mixed element widths: i8 A and B operands with an i32 accumulator.
228 spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
229 %b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
230 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
231 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
232 !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
233 !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB> ->
234 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// A, B and C may each use a different integer width (i8, i16, i32).
238 spirv.func @cooperative_matrix_muladd_i8_i16_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
239 %b : !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB>,
240 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
241 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
242 !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
243 !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB> ->
244 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// MulAdd is accepted at Workgroup scope too (all three operands share it),
// here with f16 elements.
248 spirv.func @cooperative_matrix_muladd_workgroup(%a : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
249 %b : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB>,
250 %c : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>) "None" {
251 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
252 !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
253 !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB> ->
254 !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>
// Negative test: operand #0 has use MatrixB, but MatrixA is required.
260 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
261 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
262 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
263 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #0 must be of use 'MatrixA'}}
264 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
265 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
266 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
267 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// Negative test: the accumulator operand is omitted, so parsing fails where
// the third operand's comma should be.
273 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
274 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
275 // expected-error @+1 {{expected ','}}
276 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b :
277 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
278 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
279 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// Negative test: a Matrix Operands attribute where the accumulator SSA value
// belongs is a parse error.
285 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
286 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
287 // expected-error @+1 {{expected SSA operand}}
288 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, <ASigned> :
289 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
290 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
291 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// Negative test: after the accumulator only a '<...>' Matrix Operands
// attribute may follow; a fourth SSA value is a parse error.
297 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
298 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
299 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
300 // expected-error @+1 {{expected '<'}}
301 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, %c :
302 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
303 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
304 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// Negative test: operand #1 has use MatrixA, but MatrixB is required.
310 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
311 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA>,
312 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
313 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #1 must be of use 'MatrixB'}}
314 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
315 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
316 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA> ->
317 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// Negative test: operand #2 (the accumulator) has use MatrixB, but MatrixAcc
// is required.
323 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
324 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
325 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>) "None" {
326 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #2 must be of use 'MatrixAcc'}}
327 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
328 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
329 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
330 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>
// Negative test: dimension 'M' mismatch. A has 8 rows but the accumulator
// has 10.
336 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
337 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
338 %c : !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>) "None" {
339 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'M'}}
340 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
341 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
342 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
343 !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>
// Negative test: B is 4x16 while the accumulator has 4 columns. The reported
// error is the 'N' mismatch (B's row count also disagrees with A's 16 columns).
349 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
350 %b : !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB>,
351 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
352 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'N'}}
353 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
354 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
355 !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB> ->
356 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// Negative test: dimension 'K' mismatch. A has 16 columns but B has 12 rows.
362 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
363 %b : !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB>,
364 %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
365 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'K'}}
366 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
367 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
368 !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB> ->
369 !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
// Negative test: A and B are Subgroup-scoped but the accumulator is
// Workgroup-scoped.
375 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
376 %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
377 %c : !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>) "None" {
378 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix scope mismatch}}
379 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
380 !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
381 !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
382 !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>
// Negative test: Matrix Operands (here <AccSat>) are only allowed when every
// matrix has an integer element type; f16 matrices are rejected.
388 spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
389 %b : !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB>,
390 %c : !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>) "None" {
391 // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op Matrix Operands require all matrix element types to be Integer Types}}
392 %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
393 !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
394 !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB> ->
395 !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>
401 //===----------------------------------------------------------------------===//
402 // Standard ops that can be used with CooperativeMatrix types
403 //===----------------------------------------------------------------------===//
// Type aliases for 2x2 Subgroup-scoped matrices with MatrixA / MatrixB use,
// in i32 and f32 flavors, used by the element-wise op tests that follow.
405 !matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
406 !matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
408 !matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
409 !matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
411 // These tests are kept in the same order as the list of compatible ops in the
412 // SPV_KHR_cooperative_matrix extension spec.
// SNegate accepts integer coop matrices of either MatrixA or MatrixB use.
414 // CHECK-LABEL: @snegate
415 spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
416 // CHECK: spirv.SNegate {{%.*}} : !spirv.coopmatrix
417 // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
418 %p = spirv.SNegate %a : !matA_i32
419 %q = spirv.SNegate %b : !matB_i32
// FNegate accepts float coop matrices of either MatrixA or MatrixB use.
423 // CHECK-LABEL: @fnegate
424 spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
425 // CHECK: spirv.FNegate {{%.*}} : !spirv.coopmatrix
426 // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
427 %p = spirv.FNegate %a : !matA_f32
428 %q = spirv.FNegate %b : !matB_f32
// IAdd on two integer coop matrices of identical type, once per matrix use.
432 // CHECK-LABEL: @iadd
433 spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
434 // CHECK: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
435 // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
436 %p = spirv.IAdd %a, %a : !matA_i32
437 %q = spirv.IAdd %b, %b : !matB_i32
// FAdd on two float coop matrices of identical type, once per matrix use.
441 // CHECK-LABEL: @fadd
442 spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
443 // CHECK: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
444 // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
445 %p = spirv.FAdd %a, %a : !matA_f32
446 %q = spirv.FAdd %b, %b : !matB_f32
// ISub on two integer coop matrices of identical type, once per matrix use.
450 // CHECK-LABEL: @isub
451 spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
452 // CHECK: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
453 // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
454 %p = spirv.ISub %a, %a : !matA_i32
455 %q = spirv.ISub %b, %b : !matB_i32
// FSub on two float coop matrices of identical type, once per matrix use.
459 // CHECK-LABEL: @fsub
460 spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
461 // CHECK: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
462 // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
463 %p = spirv.FSub %a, %a : !matA_f32
464 %q = spirv.FSub %b, %b : !matB_f32
// FMul on two float coop matrices of identical type, once per matrix use.
468 // CHECK-LABEL: @fmul
469 spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
470 // CHECK: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
471 // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
472 %p = spirv.FMul %a, %a : !matA_f32
473 %q = spirv.FMul %b, %b : !matB_f32
// IMul on two integer coop matrices of identical type, once per matrix use.
477 // CHECK-LABEL: @imul
478 spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
479 // CHECK: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
480 // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
481 %p = spirv.IMul %a, %a : !matA_i32
482 %q = spirv.IMul %b, %b : !matB_i32
// FDiv on two float coop matrices of identical type, once per matrix use.
486 // CHECK-LABEL: @fdiv
487 spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
488 // CHECK: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
489 // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
490 %p = spirv.FDiv %a, %a : !matA_f32
491 %q = spirv.FDiv %b, %b : !matB_f32
// SDiv on two integer coop matrices of identical type, once per matrix use.
495 // CHECK-LABEL: @sdiv
496 spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
497 // CHECK: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
498 // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
499 %p = spirv.SDiv %a, %a : !matA_i32
500 %q = spirv.SDiv %b, %b : !matB_i32
// UDiv on two integer coop matrices of identical type, once per matrix use.
504 // CHECK-LABEL: @udiv
505 spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
506 // CHECK: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
507 // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
508 %p = spirv.UDiv %a, %a : !matA_i32
509 %q = spirv.UDiv %b, %b : !matB_i32
// MatrixTimesScalar scales an f32 coop matrix by an f32 scalar. The printed
// form lists both the matrix type and the scalar type.
513 // CHECK-LABEL: @matrix_times_scalar
514 spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
515 // CHECK: spirv.MatrixTimesScalar {{%.*}} : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, f32
516 %p = spirv.MatrixTimesScalar %a, %b : !matA_f32, f32
522 // For binary arithmetic instructions with coop matrix operands, the types must match.
// Negative test: IAdd operands that differ only in matrix use (MatrixA vs
// MatrixB) violate the same-type requirement. Generic op syntax is used so
// that distinct operand types can be spelled out.
525 spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
526 %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
527 // expected-error @+1 {{op requires the same type for all operands and results}}
528 %q = "spirv.IAdd"(%a, %b) :
529 (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
530 -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
// Negative test: FAdd mixing a MatrixA and a MatrixAcc operand violates the
// same-type requirement.
536 spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
537 %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
538 // expected-error @+1 {{op requires the same type for all operands and results}}
539 %q = "spirv.FAdd"(%a, %b) :
540 (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
541 -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
// Negative test: the scalar type (f16) must match the matrix element type
// (f32) for MatrixTimesScalar.
547 spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, %b: f16) "None" {
548 // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
549 %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16