[gn build] Port fef54d0393fd
[llvm-project.git] / mlir / test / Dialect / SPIRV / IR / khr-cooperative-matrix-ops.mlir
blobd3e1dbc229ef99a11eceb6f09cf634300848ce30
1 // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
3 //===----------------------------------------------------------------------===//
4 // CooperativeMatrix (KHR) extension ops.
5 //===----------------------------------------------------------------------===//
7 // CHECK-LABEL: @cooperative_matrix_length
8 spirv.func @cooperative_matrix_length() -> i32 "None" {
9   // CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
10   %0 = spirv.KHR.CooperativeMatrixLength : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
11   spirv.ReturnValue %0 : i32
14 // -----
16 // CHECK-LABEL: @cooperative_matrix_load
17 spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
18   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
19   // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
20   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
21     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
22   spirv.Return
25 // CHECK-LABEL: @cooperative_matrix_load_memoperand
26 spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
27   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
28   // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
29   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile> :
30     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
31   spirv.Return
34 // CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
35 spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
36   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
37   // CHECK-SAME:   !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
38   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor>, <Volatile> :
39     !spirv.ptr<vector<4xi32>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
40   spirv.Return
43 // CHECK-LABEL: @cooperative_matrix_load_function
44 spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %stride : i32) "None" {
45   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
46   // CHECK-SAME:   !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
47   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
48     !spirv.ptr<i32, Function>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixAcc>
49   spirv.Return
52 // CHECK-LABEL: @cooperative_matrix_load_stride_i16
53 spirv.func @cooperative_matrix_load_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i16) "None" {
54   // CHECK:      {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor> :
55   // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
56   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
57     !spirv.ptr<i32, StorageBuffer>, i16 -> !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>
58   spirv.Return
61 // CHECK-LABEL: @cooperative_matrix_store
62 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
63                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
64   // CHECK:      spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor> :
65   // CHECK-SAME:   !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
66   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
67     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
68   spirv.Return
71 // CHECK-LABEL: @cooperative_matrix_store_memoperand
72 spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
73                                                 %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
74                                                 %stride : i32) "None" {
75   // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
76   // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
77   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Volatile> :
78     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
79   spirv.Return
82 // CHECK-LABEL: @cooperative_matrix_store_stride_i16
83 spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>,
84                                                 %m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
85                                                 %stride : i16) "None" {
86   // CHECK:       spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor> :
87   // CHECK-SAME:    !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
88   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor> :
89     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
90   spirv.Return
93 // -----
95 spirv.func @cooperative_matrix_load_bad_ptr(%ptr : !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, %stride : i32) "None" {
96   // expected-error @+1 {{Pointer must point to a scalar or vector type}}
97   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor> :
98     !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
99   spirv.Return
102 // -----
104 spirv.func @cooperative_matrix_load_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
105   // expected-error @+1 {{expected ','}}
106   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride :
107     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
108   spirv.Return
111 // -----
113 spirv.func @cooperative_matrix_load_bad_operad(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
114   // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
115   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <MakePointerAvailable> :
116     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
117   spirv.Return
120 // -----
122 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
123   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
124   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
125     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
126   spirv.Return
129 // -----
131 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
132   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
133   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
134     !spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
135   spirv.Return
138 // -----
140 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
141                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
142   // expected-error @+1 {{expected ','}}
143   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride :
144     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
145   spirv.Return
148 // -----
150 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
151                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
152   // expected-error @+1 {{expected '<'}}
153   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
154     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
155   spirv.Return
158 // -----
160 spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
161                                                      %stride : i32) "None" {
162   // expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
163   spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, <RowMajor> :
164     !spirv.ptr<i32, StorageBuffer>, i32, i32
165   spirv.Return
168 // -----
170 spirv.func @cooperative_matrix_store_bad_operand(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
171                                                  %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
172   // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
173   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <MakePointerVisible> :
174     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
175   spirv.Return
178 // -----
180 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
181                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
182   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
183   spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
184     !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
185   spirv.Return
188 // -----
190 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
191                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
192                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
193   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
194         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
195         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
196           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
197   spirv.Return
200 spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
201                                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
202                                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
203   %p = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
204         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
205         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
206           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
207   %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned> :
208         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
209         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
210           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
211   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <ASigned | BSigned | AccSat> :
212         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
213         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
214           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
215   spirv.Return
218 spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
219                                           %b : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB>,
220                                           %c : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>) "None" {
221   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
222         !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>,
223         !spirv.coopmatrix<4x4xf32, Subgroup, MatrixB> ->
224           !spirv.coopmatrix<4x4xf32, Subgroup, MatrixAcc>
225   spirv.Return
228 spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
229                                              %b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
230                                              %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
231   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
232         !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
233         !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB> ->
234           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
235   spirv.Return
238 spirv.func @cooperative_matrix_muladd_i8_i16_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
239                                                  %b : !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB>,
240                                                  %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
241   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
242         !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
243         !spirv.coopmatrix<16x4xi16, Subgroup, MatrixB> ->
244           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
245   spirv.Return
248 spirv.func @cooperative_matrix_muladd_workgroup(%a : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
249                                                 %b : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB>,
250                                                 %c : !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>) "None" {
251   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
252         !spirv.coopmatrix<4x4xf16, Workgroup, MatrixA>,
253         !spirv.coopmatrix<4x4xf16, Workgroup, MatrixB> ->
254           !spirv.coopmatrix<4x4xf16, Workgroup, MatrixAcc>
255   spirv.Return
258 // -----
260 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
261                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
262                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
263   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #0 must be of use 'MatrixA'}}
264   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
265         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
266         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
267           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
268   spirv.Return
271 // -----
273 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
274                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
275   // expected-error @+1 {{expected ','}}
276   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b :
277         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
278         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
279           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
280   spirv.Return
283 // -----
285 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
286                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>) "None" {
287   // expected-error @+1 {{expected SSA operand}}
288   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, <ASigned> :
289         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
290         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
291           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
292   spirv.Return
295 // -----
297 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
298                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
299                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
300   // expected-error @+1 {{expected '<'}}
301   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, %c :
302         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
303         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
304           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
305   spirv.Return
308 // -----
310 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
311                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA>,
312                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
313   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #1 must be of use 'MatrixB'}}
314   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
315         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
316         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixA> ->
317           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
318   spirv.Return
321 // -----
323 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
324                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
325                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>) "None" {
326   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op operand #2 must be of use 'MatrixAcc'}}
327   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
328         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
329         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
330           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixB>
331   spirv.Return
334 // -----
336 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
337                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
338                                       %c : !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>) "None" {
339   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'M'}}
340   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
341         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
342         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
343           !spirv.coopmatrix<10x4xi32, Subgroup, MatrixAcc>
344   spirv.Return
347 // -----
349 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
350                                       %b : !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB>,
351                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
352   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'N'}}
353   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
354         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
355         !spirv.coopmatrix<4x16xi32, Subgroup, MatrixB> ->
356           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
357   spirv.Return
360 // -----
362 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
363                                       %b : !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB>,
364                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
365   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix size mismatch on dimension 'K'}}
366   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
367         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
368         !spirv.coopmatrix<12x4xi32, Subgroup, MatrixB> ->
369           !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>
370   spirv.Return
373 // -----
375 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
376                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
377                                       %c : !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>) "None" {
378   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op matrix scope mismatch}}
379   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
380         !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
381         !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB> ->
382           !spirv.coopmatrix<8x4xi32, Workgroup, MatrixAcc>
383   spirv.Return
386 // -----
388 spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
389                                                       %b : !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB>,
390                                                       %c : !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>) "None" {
391   // expected-error @+1 {{'spirv.KHR.CooperativeMatrixMulAdd' op Matrix Operands require all matrix element types to be Integer Types}}
392   %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c, <AccSat> :
393         !spirv.coopmatrix<8x16xf16, Subgroup, MatrixA>,
394         !spirv.coopmatrix<16x4xf16, Subgroup, MatrixB> ->
395           !spirv.coopmatrix<8x4xf16, Subgroup, MatrixAcc>
396   spirv.Return
399 // -----
401 //===----------------------------------------------------------------------===//
402 // Standard ops that can be used CooperativeMatrix types
403 //===----------------------------------------------------------------------===//
405 !matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
406 !matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
408 !matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
409 !matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
411 // These tests are kept in the same order as the list of compatible ops in the
412 // SPV_KHR_cooperative_matrix extension spec.
414 // CHECK-LABEL: @snegate
415 spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
416   // CHECK:      spirv.SNegate {{%.*}} : !spirv.coopmatrix
417   // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
418   %p = spirv.SNegate %a : !matA_i32
419   %q = spirv.SNegate %b : !matB_i32
420   spirv.Return
423 // CHECK-LABEL: @fnegate
424 spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
425   // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
426   // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
427   %p = spirv.FNegate %a : !matA_f32
428   %q = spirv.FNegate %b : !matB_f32
429   spirv.Return
432 // CHECK-LABEL: @iadd
433 spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
434   // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
435   // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
436   %p = spirv.IAdd %a, %a : !matA_i32
437   %q = spirv.IAdd %b, %b : !matB_i32
438   spirv.Return
441 // CHECK-LABEL: @fadd
442 spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
443   // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
444   // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
445   %p = spirv.FAdd %a, %a : !matA_f32
446   %q = spirv.FAdd %b, %b : !matB_f32
447   spirv.Return
450 // CHECK-LABEL: @isub
451 spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
452   // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
453   // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
454   %p = spirv.ISub %a, %a : !matA_i32
455   %q = spirv.ISub %b, %b : !matB_i32
456   spirv.Return
459 // CHECK-LABEL: @fsub
460 spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
461   // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
462   // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
463   %p = spirv.FSub %a, %a : !matA_f32
464   %q = spirv.FSub %b, %b : !matB_f32
465   spirv.Return
468 // CHECK-LABEL: @fmul
469 spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
470   // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
471   // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
472   %p = spirv.FMul %a, %a : !matA_f32
473   %q = spirv.FMul %b, %b : !matB_f32
474   spirv.Return
477 // CHECK-LABEL: @imul
478 spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
479   // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
480   // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
481   %p = spirv.IMul %a, %a : !matA_i32
482   %q = spirv.IMul %b, %b : !matB_i32
483   spirv.Return
486 // CHECK-LABEL: @fdiv
487 spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
488   // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
489   // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
490   %p = spirv.FDiv %a, %a : !matA_f32
491   %q = spirv.FDiv %b, %b : !matB_f32
492   spirv.Return
495 // CHECK-LABEL: @sdiv
496 spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
497   // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
498   // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
499   %p = spirv.SDiv %a, %a : !matA_i32
500   %q = spirv.SDiv %b, %b : !matB_i32
501   spirv.Return
504 // CHECK-LABEL: @udiv
505 spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
506   // CHECK:      spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
507   // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
508   %p = spirv.UDiv %a, %a : !matA_i32
509   %q = spirv.UDiv %b, %b : !matB_i32
510   spirv.Return
513 // CHECK-LABEL: @matrix_times_scalar
514 spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
515   // CHECK: spirv.MatrixTimesScalar {{%.*}} : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, f32
516   %p = spirv.MatrixTimesScalar %a, %b : !matA_f32, f32
517   spirv.Return
520 // -----
522 // For binary arithmetic instructions with coop matrix operands, the types must
523 // match.
525 spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
526                  %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
527   // expected-error @+1 {{op requires the same type for all operands and results}}
528   %q = "spirv.IAdd"(%a, %b) :
529     (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
530     -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
531   spirv.Return
534 // -----
536 spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
537                  %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
538   // expected-error @+1 {{op requires the same type for all operands and results}}
539   %q = "spirv.FAdd"(%a, %b) :
540     (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
541     -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
542   spirv.Return
545 // -----
547 spirv.func @matrix_times_scalar(%a: !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, %b: f16) "None" {
548   // expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
549   %p = spirv.MatrixTimesScalar %a, %b : !spirv.coopmatrix<2x2xf32, Workgroup, MatrixA>, f16
550   spirv.Return