[AArch64] Add cost model for @experimental.vector.match (#118512)
[llvm-project.git] / mlir / test / Dialect / SPIRV / IR / composite-ops.mlir
blob3fc8dfb2767d1ea51221b25798cfe562b67c5c44
1 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
3 //===----------------------------------------------------------------------===//
4 // spirv.CompositeConstruct
5 //===----------------------------------------------------------------------===//
7 // CHECK-LABEL: func @composite_construct_vector
8 func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
9   // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
10   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
11   return %0: vector<3xf32>
14 // CHECK-LABEL: func @composite_construct_struct
15 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
16   // CHECK: spirv.CompositeConstruct
17   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
18   return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
21 // CHECK-LABEL: func @composite_construct_mixed_scalar_vector
22 func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
23   // CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
24   %0 = spirv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xf32>, f32) -> vector<4xf32>
25   return %0: vector<4xf32>
28 // CHECK-LABEL: func @composite_construct_coopmatrix_khr
29 func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
30   // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
31   %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
32   return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
35 // -----
37 func.func @composite_construct_invalid_result_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
38   // expected-error @+1 {{has incorrect number of operands: expected 3, but provided 2}}
39   %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, f32) -> vector<3xf32>
40   return %0: vector<3xf32>
43 // -----
45 func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xi32> {
46   // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32'}}
47   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xi32>
48   return %0: vector<3xi32>
51 // -----
53 func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
54   !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
55   // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
56   %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
57   return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
60 // -----
62 func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
63   !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
64   // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
65   %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
66   return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
69 // -----
71 func.func @composite_construct_array(%arg0: f32) -> !spirv.array<4xf32> {
72   // expected-error @+1 {{expected to return a vector or cooperative matrix when the number of constituents is less than what the result needs}}
73   %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.array<4xf32>
74   return %0: !spirv.array<4xf32>
77 // -----
79 func.func @composite_construct_vector_wrong_element_type(%arg0: f32, %arg1: f32, %arg2 : vector<2xi32>) -> vector<4xf32> {
80   // expected-error @+1 {{operand element type mismatch: expected to be 'f32', but provided 'i32'}}
81   %0 = spirv.CompositeConstruct %arg0, %arg2, %arg1 : (f32, vector<2xi32>, f32) -> vector<4xf32>
82   return %0: vector<4xf32>
85 // -----
87 func.func @composite_construct_vector_wrong_count(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
88   // expected-error @+1 {{op has incorrect number of operands: expected 4, but provided 3}}
89   %0 = spirv.CompositeConstruct %arg0, %arg2 : (f32, vector<2xf32>) -> vector<4xf32>
90   return %0: vector<4xf32>
93 // -----
95 //===----------------------------------------------------------------------===//
96 // spirv.CompositeExtractOp
97 //===----------------------------------------------------------------------===//
99 func.func @composite_extract_array(%arg0: !spirv.array<4xf32>) -> f32 {
100   // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : !spirv.array<4 x f32>
101   %0 = spirv.CompositeExtract %arg0[1 : i32] : !spirv.array<4xf32>
102   return %0: f32
105 // -----
107 func.func @composite_extract_struct(%arg0 : !spirv.struct<(f32, !spirv.array<4xf32>)>) -> f32 {
108   // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32, 2 : i32] : !spirv.struct<(f32, !spirv.array<4 x f32>)>
109   %0 = spirv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spirv.struct<(f32, !spirv.array<4xf32>)>
110   return %0 : f32
113 // -----
115 func.func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 {
116   // CHECK: {{%.*}} = spirv.CompositeExtract {{%.*}}[1 : i32] : vector<4xf32>
117   %0 = spirv.CompositeExtract %arg0[1 : i32] : vector<4xf32>
118   return %0 : f32
121 // -----
123 func.func @composite_extract_no_ssa_operand() -> () {
124   // expected-error @+1 {{expected SSA operand}}
125   %0 = spirv.CompositeExtract [4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
126   return
129 // -----
131 func.func @composite_extract_invalid_index_type_1() -> () {
132   %0 = spirv.Constant 10 : i32
133   %1 = spirv.Variable : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32>>, Function>
134   %2 = spirv.Load "Function" %1 ["Volatile"] : !spirv.array<4x!spirv.array<4xf32>>
135   // expected-error @+1 {{expected attribute value}}
136   %3 = spirv.CompositeExtract %2[%0] : !spirv.array<4x!spirv.array<4xf32>>
137   return
140 // -----
142 func.func @composite_extract_invalid_index_type_2(%arg0 : !spirv.array<4x!spirv.array<4xf32>>) -> () {
143   // expected-error @+1 {{attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}}
144   %0 = spirv.CompositeExtract %arg0[1] : !spirv.array<4x!spirv.array<4xf32>>
145   return
148 // -----
150 func.func @composite_extract_invalid_index_identifier(%arg0 : !spirv.array<4x!spirv.array<4xf32>>) -> () {
151   // expected-error @+1 {{expected attribute value}}
152   %0 = spirv.CompositeExtract %arg0 ]1 : i32) : !spirv.array<4x!spirv.array<4xf32>>
153   return
156 // -----
158 func.func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
159   // expected-error @+1 {{index 4 out of bounds for '!spirv.array<4 x !spirv.array<4 x f32>>'}}
160   %0 = spirv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spirv.array<4x!spirv.array<4xf32>>
161   return
164 // -----
166 func.func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spirv.array<4x!spirv.array<4xf32>>
167 ) -> () {
168   // expected-error @+1 {{index 4 out of bounds for '!spirv.array<4 x f32>'}}
169   %0 = spirv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spirv.array<4x!spirv.array<4xf32>>
170   return
173 // -----
175 func.func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spirv.struct<(f32, !spirv.array<4xf32>)>) -> () {
176   // expected-error @+1 {{index 2 out of bounds for '!spirv.struct<(f32, !spirv.array<4 x f32>)>'}}
177   %0 = spirv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spirv.struct<(f32, !spirv.array<4xf32>)>
178   return
181 // -----
183 func.func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () {
184   // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}}
185   %0 = spirv.CompositeExtract %arg0[4 : i32] : vector<4xf32>
186   return
189 // -----
191 func.func @composite_extract_invalid_types_1(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
192   // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 3}}
193   %0 = spirv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spirv.array<4x!spirv.array<4xf32>>
194   return
197 // -----
199 func.func @composite_extract_invalid_types_2(%arg0: f32) -> () {
200   // expected-error @+1 {{cannot extract from non-composite type 'f32' with index 1}}
201   %0 = spirv.CompositeExtract %arg0[1 : i32] : f32
202   return
205 // -----
207 func.func @composite_extract_invalid_extracted_type(%arg0: !spirv.array<4x!spirv.array<4xf32>>) -> () {
208   // expected-error @+1 {{expected at least one index for spirv.CompositeExtract}}
209   %0 = spirv.CompositeExtract %arg0[] : !spirv.array<4x!spirv.array<4xf32>>
210   return
213 // -----
215 func.func @composite_extract_result_type_mismatch(%arg0: !spirv.array<4xf32>) -> i32 {
216   // expected-error @+1 {{invalid result type: expected 'f32' but provided 'i32'}}
217   %0 = "spirv.CompositeExtract"(%arg0) {indices = [2: i32]} : (!spirv.array<4xf32>) -> (i32)
218   return %0: i32
221 // -----
223 //===----------------------------------------------------------------------===//
224 // spirv.CompositeInsert
225 //===----------------------------------------------------------------------===//
227 func.func @composite_insert_array(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
228   // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into !spirv.array<4 x f32>
229   %0 = spirv.CompositeInsert %arg1, %arg0[1 : i32] : f32 into !spirv.array<4xf32>
230   return %0: !spirv.array<4xf32>
233 // -----
235 func.func @composite_insert_struct(%arg0: !spirv.struct<(!spirv.array<4xf32>, f32)>, %arg1: !spirv.array<4xf32>) -> !spirv.struct<(!spirv.array<4xf32>, f32)> {
236   // CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : !spirv.array<4 x f32> into !spirv.struct<(!spirv.array<4 x f32>, f32)>
237   %0 = spirv.CompositeInsert %arg1, %arg0[0 : i32] : !spirv.array<4xf32> into !spirv.struct<(!spirv.array<4xf32>, f32)>
238   return %0: !spirv.struct<(!spirv.array<4xf32>, f32)>
241 // -----
243 func.func @composite_insert_no_indices(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
244   // expected-error @+1 {{expected at least one index}}
245   %0 = spirv.CompositeInsert %arg1, %arg0[] : f32 into !spirv.array<4xf32>
246   return %0: !spirv.array<4xf32>
249 // -----
251 func.func @composite_insert_out_of_bounds(%arg0: !spirv.array<4xf32>, %arg1: f32) -> !spirv.array<4xf32> {
252   // expected-error @+1 {{index 4 out of bounds}}
253   %0 = spirv.CompositeInsert %arg1, %arg0[4 : i32] : f32 into !spirv.array<4xf32>
254   return %0: !spirv.array<4xf32>
257 // -----
259 func.func @composite_insert_invalid_object_type(%arg0: !spirv.array<4xf32>, %arg1: f64) -> !spirv.array<4xf32> {
260   // expected-error @+1 {{object operand type should be 'f32', but found 'f64'}}
261   %0 = spirv.CompositeInsert %arg1, %arg0[3 : i32] : f64 into !spirv.array<4xf32>
262   return %0: !spirv.array<4xf32>
265 // -----
267 func.func @composite_insert_invalid_result_type(%arg0: !spirv.array<4xf32>, %arg1 : f32) -> !spirv.array<4xf64> {
268   // expected-error @+1 {{result type should be the same as the composite type, but found '!spirv.array<4 x f32>' vs '!spirv.array<4 x f64>'}}
269   %0 = "spirv.CompositeInsert"(%arg1, %arg0) {indices = [0: i32]} : (f32, !spirv.array<4xf32>) -> !spirv.array<4xf64>
270   return %0: !spirv.array<4xf64>
273 // -----
275 //===----------------------------------------------------------------------===//
276 // spirv.VectorExtractDynamic
277 //===----------------------------------------------------------------------===//
279 func.func @vector_dynamic_extract(%vec: vector<4xf32>, %id : i32) -> f32 {
280   // CHECK: spirv.VectorExtractDynamic %{{.*}}[%{{.*}}] : vector<4xf32>, i32
281   %0 = spirv.VectorExtractDynamic %vec[%id] : vector<4xf32>, i32
282   return %0 : f32
285 //===----------------------------------------------------------------------===//
286 // spirv.VectorInsertDynamic
287 //===----------------------------------------------------------------------===//
289 func.func @vector_dynamic_insert(%val: f32, %vec: vector<4xf32>, %id : i32) -> vector<4xf32> {
290   // CHECK: spirv.VectorInsertDynamic %{{.*}}, %{{.*}}[%{{.*}}] : vector<4xf32>, i32
291   %0 = spirv.VectorInsertDynamic %val, %vec[%id] : vector<4xf32>, i32
292   return %0 : vector<4xf32>
295 // -----
297 //===----------------------------------------------------------------------===//
298 // spirv.VectorShuffle
299 //===----------------------------------------------------------------------===//
301 func.func @vector_shuffle(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
302   // CHECK: %{{.+}} = spirv.VectorShuffle [1 : i32, 3 : i32, -1 : i32] %{{.+}}, %arg1 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
303   %0 = spirv.VectorShuffle [1: i32, 3: i32, 0xffffffff: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
304   return %0: vector<3xf32>
307 // -----
309 func.func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
310   // expected-error @+1 {{result type element count (3) mismatch with the number of component selectors (4)}}
311   %0 = spirv.VectorShuffle [1: i32, 3: i32, 5: i32, 2: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
312   return %0: vector<3xf32>
315 // -----
317 func.func @vector_shuffle_extra_selector(%vector1: vector<4xf32>, %vector2: vector<2xf32>) -> vector<3xf32> {
318   // expected-error @+1 {{component selector 7 out of range: expected to be in [0, 6) or 0xffffffff}}
319   %0 = spirv.VectorShuffle [1: i32, 7: i32, 5: i32] %vector1, %vector2 : vector<4xf32>, vector<2xf32> -> vector<3xf32>
320   return %0: vector<3xf32>