[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / Vector / vector-dropleadunitdim-transforms.mlir
blob9526d610e490e7457b273d9f1217df70661c02cf
1 // RUN: mlir-opt %s -test-vector-to-vector-lowering -split-input-file| FileCheck %s
3 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
4 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
5 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
7 // CHECK-LABEL: cast_away_contraction_leading_one_dims
8 //  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
9 //  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
10 //  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
11 //  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
12 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
13 //  CHECK-SAME:   %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
14 //  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16x16xf32> to vector<1x16x16xf32>
15 //  CHECK-NEXT:  return %[[R4]] : vector<1x16x16xf32>
17 #contraction_accesses0 = [
18   affine_map<(l, i, j, k) -> (l, i, k)>,
19   affine_map<(l, i, j, k) -> (l, k, j)>,
20   affine_map<(l, i, j, k) -> (l, i, j)>
22 #contraction_trait0 = {
23   indexing_maps = #contraction_accesses0,
24   iterator_types = ["parallel", "parallel", "parallel", "reduction"]
27 func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
28   %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2  : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
29   return %0: vector<1x16x16xf32>
32 // -----
33 // CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
34 // CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
35 // CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
37 // CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_const_mask
38 // CHECK:           %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
39 // CHECK:           %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
40 // CHECK:           %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
41 // CHECK:           %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
42 // CHECK:           %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
43 // CHECK-SAME:        vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
44 // CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
45 // CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
46 // CHECK:           %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
47 // CHECK:           return %[[RES]] : vector<1x16x16xf32>
49 #contraction_accesses0 = [
50   affine_map<(l, i, j, k) -> (l, i, k)>,
51   affine_map<(l, i, j, k) -> (l, k, j)>,
52   affine_map<(l, i, j, k) -> (l, i, j)>
54 #contraction_trait0 = {
55   indexing_maps = #contraction_accesses0,
56   iterator_types = ["parallel", "parallel", "parallel", "reduction"]
59 func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
60   %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
61   %0 = vector.mask %mask {
62     vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
63   } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
64   return %0 : vector<1x16x16xf32>
67 // -----
68 // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
69 // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
70 // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
72 // CHECK-LABEL:   func.func @cast_away_contraction_leading_one_dim_under_mask
73 // CHECK:           %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
74 // CHECK:           %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
75 // CHECK:           %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
76 // CHECK:           %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
77 // CHECK:           %[[CONTRACT:.*]] = vector.mask %[[M]] {
78 // CHECK-SAME:      vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
79 // CHECK-SAME:          %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
80 // CHECK-SAME:      } : vector<16x16x8xi1> -> vector<16x16xf32>
81 // CHECK-NEXT:      %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
82 // CHECK-NEXT:      return %[[RES]] : vector<1x16x16xf32>
84 #contraction_accesses0 = [
85   affine_map<(l, i, j, k) -> (l, i, k)>,
86   affine_map<(l, i, j, k) -> (l, k, j)>,
87   affine_map<(l, i, j, k) -> (l, i, j)>
89 #contraction_trait0 = {
90   indexing_maps = #contraction_accesses0,
91   iterator_types = ["parallel", "parallel", "parallel", "reduction"]
94 func.func @cast_away_contraction_leading_one_dim_under_mask(
95   %arg0: vector<1x16x8xf32>,
96   %arg1: vector<1x8x16xf32>,
97   %arg2: vector<1x16x16xf32>,
98   %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
99   %0 = vector.mask %mask {
100     vector.contract #contraction_trait0 %arg0, %arg1, %arg2  : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
101   } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
102   return %0: vector<1x16x16xf32>
105 // -----
107 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
108 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
109 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
111 // CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded
112 //  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
113 //  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
114 //  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %{{.*}}[0, 0] : vector<16xf32> from vector<1x1x16xf32>
115 //  CHECK-NEXT:   %[[R3:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
116 //  CHECK-SAME:   iterator_types = ["parallel", "reduction"], kind = #vector.kind<mul>}
117 //  CHECK-SAME:   %[[R1]], %[[R0]], %[[R2]] : vector<8xf32>, vector<8x16xf32> into vector<16xf32>
118 //  CHECK-NEXT:   %[[R4:.+]] = vector.broadcast %[[R3]] : vector<16xf32> to vector<1x16xf32>
119 //  CHECK-NEXT:   %[[R5:.+]] = vector.broadcast %[[R4]] : vector<1x16xf32> to vector<1x1x16xf32>
120 //  CHECK-NEXT:  return %[[R5]] : vector<1x1x16xf32>
122 #contraction_accesses1 = [
123   affine_map<(l, i, j, k) -> (i, l, k)>,
124   affine_map<(l, i, j, k) -> (l, k, j)>,
125   affine_map<(l, i, j, k) -> (l, i, j)>
127 #contraction_trait1 = {
128   indexing_maps = #contraction_accesses1,
129   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
130   kind = #vector.kind<mul>
133 func.func @cast_away_contraction_leading_one_dims_transposeneeded(%arg0: vector<1x1x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x1x16xf32>) -> vector<1x1x16xf32> {
134   %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2  : vector<1x1x8xf32>, vector<1x8x16xf32> into vector<1x1x16xf32>
135   return %0: vector<1x1x16xf32>
138 // -----
139 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
140 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
141 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
143 // CHECK-LABEL: cast_away_contraction_leading_one_dims_transposeneeded2
144 //  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}[1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
145 //  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %[[R0]][0] : vector<8x16xf32> from vector<1x8x16xf32>
146 //  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %{{.*}}[2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
147 //  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<2x8xf32> from vector<1x2x8xf32>
148 //  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0] : vector<2x16xf32> from vector<1x2x16xf32>
149 //  CHECK-NEXT:   %[[R5:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
150 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
151 //  CHECK-SAME:   %[[R1]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
152 //  CHECK-NEXT:   %[[R6:.+]] = vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
153 //  CHECK-NEXT:  return %[[R6]] : vector<1x2x16xf32>
155 #contraction_accesses2 = [
156   affine_map<(l, i, j, k) -> (k, l, j)>,
157   affine_map<(l, i, j, k) -> (i, k, l)>,
158   affine_map<(l, i, j, k) -> (l, i, j)>
160 #contraction_trait2 = {
161   indexing_maps = #contraction_accesses2,
162   iterator_types = ["parallel", "parallel", "parallel", "reduction"]
166 func.func @cast_away_contraction_leading_one_dims_transposeneeded2(%arg0: vector<8x1x16xf32>, %arg1: vector<2x8x1xf32>, %arg2: vector<1x2x16xf32>) -> vector<1x2x16xf32> {
167   %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2  : vector<8x1x16xf32>, vector<2x8x1xf32> into vector<1x2x16xf32>
168   return %0: vector<1x2x16xf32>
171 // -----
172 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
173 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
174 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
177 // CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4
178 //  CHECK-NEXT:   %[[R0:.+]] =  vector.extract %{{.*}}[0] : vector<8x1x16xf32> from vector<1x8x1x16xf32>
179 //  CHECK-NEXT:   %[[R1:.+]] =  vector.extract %{{.*}}[0] : vector<2x8x1xf32> from vector<1x2x8x1xf32>
180 //  CHECK-NEXT:   %[[R2:.+]] =  vector.transpose %[[R0]], [1, 0, 2] : vector<8x1x16xf32> to vector<1x8x16xf32>
181 //  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R2]][0] : vector<8x16xf32> from vector<1x8x16xf32>
182 //  CHECK-NEXT:   %[[R4:.+]] =  vector.transpose %[[R1]], [2, 0, 1] : vector<2x8x1xf32> to vector<1x2x8xf32>
183 //  CHECK-NEXT:   %[[R5:.+]] =  vector.extract %[[R4]][0] : vector<2x8xf32> from vector<1x2x8xf32>
184 //  CHECK-NEXT:   %[[R6:.+]] =  vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
185 //  CHECK-NEXT:   %[[R7:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
186 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
187 //  CHECK-SAME:   %[[R3]], %[[R5]], %[[R6]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
188 //  CHECK-NEXT:   %[[R8:.+]] =  vector.broadcast %[[R7]] : vector<2x16xf32> to vector<1x2x16xf32>
189 //  CHECK-NEXT:   %[[R9:.+]] =  vector.broadcast %[[R8]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
190 //  CHECK-NEXT:  return %[[R9]] : vector<1x1x2x16xf32>
192 #contraction_accesses2 = [
193   affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
194   affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
195   affine_map<(m, l, i, j, k) -> (m, l, i, j)>
197 #contraction_trait2 = {
198   indexing_maps = #contraction_accesses2,
199   iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"]
203 func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
204   %0 = vector.contract #contraction_trait2 %arg0, %arg1, %arg2  : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
205   return %0: vector<1x1x2x16xf32>
208 // -----
209 // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
210 // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
211 // CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
213 // CHECK-LABEL: cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose
214 //  CHECK-NEXT:   %[[R0:.+]] =  vector.transpose %{{.*}}, [2, 0, 1, 3] : vector<1x8x1x16xf32> to vector<1x1x8x16xf32>
215 //  CHECK-NEXT:   %[[R1:.+]] =  vector.transpose %{{.*}}, [3, 0, 1, 2] : vector<1x2x8x1xf32> to vector<1x1x2x8xf32>
216 //  CHECK-NEXT:   %[[R2:.+]] =  vector.extract %[[R0]][0, 0] : vector<8x16xf32> from vector<1x1x8x16xf32>
217 //  CHECK-NEXT:   %[[R3:.+]] =  vector.extract %[[R1]][0, 0] : vector<2x8xf32> from vector<1x1x2x8xf32>
218 //  CHECK-NEXT:   %[[R4:.+]] =  vector.extract %{{.*}}[0, 0] : vector<2x16xf32> from vector<1x1x2x16xf32>
219 //  CHECK-NEXT:   %[[R5:.+]] =  vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
220 //  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
221 //  CHECK-SAME:   %[[R2]], %[[R3]], %[[R4]] : vector<8x16xf32>, vector<2x8xf32> into vector<2x16xf32>
222 //  CHECK-NEXT:   %[[R6:.+]] =  vector.broadcast %[[R5]] : vector<2x16xf32> to vector<1x2x16xf32>
223 //  CHECK-NEXT:   %[[R7:.+]] =  vector.broadcast %[[R6]] : vector<1x2x16xf32> to vector<1x1x2x16xf32>
224 //  CHECK-NEXT:  return %[[R7]] : vector<1x1x2x16xf32>
226 #contraction_accesses3 = [
227   affine_map<(m, l, i, j, k) -> (m, k, l, j)>,
228   affine_map<(m, l, i, j, k) -> (m, i, k, l)>,
229   affine_map<(m, l, i, j, k) -> (l, m, i, j)>
231 #contraction_trait3 = {
232   indexing_maps = #contraction_accesses3,
233   iterator_types = ["parallel","parallel", "parallel", "parallel", "reduction"]
236 func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctranspose(%arg0: vector<1x8x1x16xf32>, %arg1: vector<1x2x8x1xf32>, %arg2: vector<1x1x2x16xf32>) -> vector<1x1x2x16xf32> {
237   %0 = vector.contract #contraction_trait3 %arg0, %arg1, %arg2  : vector<1x8x1x16xf32>, vector<1x2x8x1xf32> into vector<1x1x2x16xf32>
238   return %0: vector<1x1x2x16xf32>
241 // -----
243 // CHECK-LABEL:   func.func @cast_away_contraction_does_not_transpose_leading_unit_dims
244 // CHECK-NOT:  vector.transpose
245 // CHECK:           vector.contract
246 func.func @cast_away_contraction_does_not_transpose_leading_unit_dims(%lhs: vector<1x1x8xi32>,
247                           %rhs: vector<1x8x8xi32>,
248                           %acc: vector<1x8xi32>) -> vector<1x8xi32> {
249   %result = vector.contract {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<1x1x8xi32>, vector<1x8x8xi32> into vector<1x8xi32>
250   return %result : vector<1x8xi32>
253 // -----
254 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
255 func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
256   // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
257   // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
258   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
259   // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
260   // CHECK: return %[[RET]]
261   return %0: vector<1x1x8xf16>
264 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable
265 func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> {
266   // CHECK:     %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
267   // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16>
268   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16>
269   // CHECK:     %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16>
270   // CHECK: return %[[RET]]
271   return %0: vector<1x1x[8]xf16>
274 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
275 func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
276   // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16>
277   // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x8xf16> from vector<1x8x8xf16>
278   // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
279   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
280   // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
281   // CHECK: return %[[RET]]
282   return %0: vector<1x8x8xf16>
285 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable
286 func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> {
287   // CHECK:    %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16>
288   // CHECK:    %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16>
289   // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16>
290   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16>
291   // CHECK:    %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16>
292   // CHECK: return %[[RET]]
293   return %0: vector<1x8x[8]xf16>
296 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
297 //  CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
298 func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
299   // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1xf16> from vector<1x1xf16>
300   // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
301   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
302   // CHECK: return %[[B]]
303   return %0: vector<1x1x1xf16>
306 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable
307 //  CHECK-SAME: %[[ARG0:.+]]: vector<1x[1]xf16>, %{{.+}}: vector<1x1x[1]xf16>
308 func.func @cast_away_insert_strided_slice_leading_one_dims_one_element_scalable(%arg0: vector<1x[1]xf16>, %arg1: vector<1x1x[1]xf16>) -> vector<1x1x[1]xf16> {
309   // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<[1]xf16> from vector<1x[1]xf16>
310   // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<[1]xf16> to vector<1x1x[1]xf16>
311   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[1]xf16> into vector<1x1x[1]xf16>
312   // CHECK: return %[[B]]
313   return %0: vector<1x1x[1]xf16>
316 // CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
317 func.func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
318   // CHECK: %[[C0:.+]] = arith.constant 0 : index
319   %c0 = arith.constant 0 : index
320   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
321   %f0 = arith.constant 0. : f16
322   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
323   // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
324   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
325   // CHECK: return %[[CAST]]
326   return %0: vector<1x4xf16>
329 // CHECK-LABEL: func @cast_away_masked_transfer_read_leading_one_dims
330 func.func @cast_away_masked_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xi1>) -> vector<1x4xf16> {
331   // CHECK: %[[C0:.+]] = arith.constant 0 : index
332   %c0 = arith.constant 0 : index
333   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
334   %f0 = arith.constant 0. : f16
335   // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
336   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
337   // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
338   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
339   // CHECK: return %[[CAST]]
340   return %0: vector<1x4xf16>
343 // CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_one_element
344 func.func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
345   %c0 = arith.constant 0 : index
346   %f0 = arith.constant 0. : f16
347   // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
348   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
349   return %0: vector<1x1xf16>
352 // -----
354 // CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
355 // CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_read
356 func.func @cast_away_nontrivial_map_masked_transfer_read(%arg0: memref<1x4x8xf16>, %arg1: vector<1x4x1xi1>) -> vector<1x1x4xf16> {
357   // CHECK: %[[C0:.+]] = arith.constant 0 : index
358   %c0 = arith.constant 0 : index
359   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
360   %f0 = arith.constant 0. : f16
361   // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
362   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[F0]], %[[MASK_CAST]] {in_bounds = [true]
363   // CHECK-SAME: permutation_map = #[[$MAP]]} : memref<1x4x8xf16>, vector<4xf16>
364   // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x1x4xf16>
365   %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %f0, %arg1 {in_bounds = [true, true, true],
366                             permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : memref<1x4x8xf16>, vector<1x1x4xf16>
367   // CHECK: return %[[CAST]]
368   return %0: vector<1x1x4xf16>
371 // -----
373 // CHECK-LABEL: func @not_insert_cast_fo4_transfer_read_under_mask
374 // CHECK:      %[[MASK:.+]] = vector.constant_mask
375 // CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
376 // CHECK:      %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
377 // CHECK-SAME:   vector.transfer_read {{.*}} : memref<1x1x4xf16>, vector<1x4xf16> }
378 // CHECK:      return %[[RET]] : vector<1x4xf16>
379 func.func @not_insert_cast_fo4_transfer_read_under_mask(%arg0: memref<1x1x4xf16>) -> vector<1x4xf16> {
380   %c0 = arith.constant 0 : index
381   %f0 = arith.constant 0. : f16
382   %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
383   %ret = vector.mask %mask {
384     vector.transfer_read %arg0[%c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x4xf16>, vector<1x4xf16>
385   } : vector<1x4xi1> -> vector<1x4xf16>
386   return %ret: vector<1x4xf16>
389 // -----
391 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
392 func.func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
393   // CHECK: %[[C0:.+]] = arith.constant 0 : index
394   %c0 = arith.constant 0 : index
395   // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
396   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
398   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
399   return
402 // CHECK-LABEL: func @cast_away_masked_transfer_write_leading_one_dims
403 func.func @cast_away_masked_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %arg2: vector<1x4xi1>) {
404   // CHECK: %[[C0:.+]] = arith.constant 0 : index
405   %c0 = arith.constant 0 : index
406   // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xf16> from vector<1x4xf16>
407   // CHECK: %[[MASK_CAST:.+]] = vector.extract %{{.*}}[0] : vector<4xi1> from vector<1x4xi1>
408   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
410   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0], %arg2 {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
411   return
414 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
415 func.func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
416   %c0 = arith.constant 0 : index
417   // CHECK: vector.extract %{{.+}}[0] : vector<1xf16> from vector<1x1xf16>
418   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
419   return
422 // -----
424 // CHECK-LABEL: func @not_insert_cast_for_transfer_write_under_mask
425 // CHECK:      %[[MASK:.+]] = vector.constant_mask
426 // CHECK:      %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
427 // CHECK:      vector.mask %[[CASTED_MASK]] {
428 // CHECK-SAME:   vector.transfer_write {{.*}} : vector<1x4xf16>, memref<1x1x4xf16> }
429 // CHECK:      return
430 func.func @not_insert_cast_for_transfer_write_under_mask(%arg0: memref<1x1x4xf16>, %arg1: vector<1x4xf16>) {
431   %c0 = arith.constant 0 : index
432   %mask = vector.constant_mask [1, 3] : vector<1x4xi1>
433   vector.mask %mask {
434     vector.transfer_write %arg1, %arg0[%c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x1x4xf16>
435   } : vector<1x4xi1>
436   return
439 // -----
441 // CHECK:       #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d1)>
442 // CHECK-LABEL: func @cast_away_nontrivial_map_masked_transfer_write
443 func.func @cast_away_nontrivial_map_masked_transfer_write(%arg0: memref<1x4x8xf16>, %arg1: vector<1x1x4xf16>, %arg2: vector<1x4x1xi1>) {
444   // CHECK: %[[C0:.+]] = arith.constant 0 : index
445   %c0 = arith.constant 0 : index
446   // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0, 0] : vector<4xf16> from vector<1x1x4xf16>
447   // CHECK: %[[MASK_CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4x1xi1> to vector<4xi1>
448   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %[[MASK_CAST]] {in_bounds = [true]
449   // CHECK-SAME: permutation_map = #[[$MAP]]} : vector<4xf16>, memref<1x4x8xf16>
451   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0], %arg2 {in_bounds = [true, true, true],
452                         permutation_map = affine_map<(d0, d1, d2) -> (d0, d2, d1)>} : vector<1x1x4xf16>, memref<1x4x8xf16>
453   return
456 // -----
458 // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
459 func.func @cast_away_elementwise_leading_one_dims(
460   %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
461   %arg3: vector<1x4xf32>, %arg4: i1) ->
462   (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
463   // CHECK:  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
464   // CHECK:  vector.extract %{{.*}}[0, 0] : vector<8xf32> from vector<1x1x8xf32>
465   // CHECK:  arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
466   // CHECK:  vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
467   %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
468   // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
469   // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
470   // CHECK:  arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
471   // CHECK:  vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
472   %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
473   // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
474   // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
475   // CHECK:  select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
476   // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
477   %2 = arith.select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
478   // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
479   // CHECK:  vector.extract %{{.*}}[0] : vector<4xf32> from vector<1x4xf32>
480   // CHECK:  select %arg4, %12, %{{.*}} : vector<4xf32>
481   // CHECK:  vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
482   %3 = arith.select %arg4, %arg3, %arg2 : vector<1x4xf32>
483   return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
486 // -----
488 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
489 //  CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
490 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
491 //       CHECK:   %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
492 //       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
493 //       CHECK:   return %[[BCAST]]
494 func.func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
495   %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
496   return %0: vector<1x1x4xf32>
499 // -----
501 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_scalable(
502 // CHECK-SAME:    %[[S:.*]]: f32,
503 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
504 func.func @cast_away_insert_leading_one_dims_scalar_scalable(%s: f32, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
505 // CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0, 0] : vector<[4]xf32> from vector<1x1x[4]xf32>
506 // CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<[4]xf32>
507 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
508 // CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
509   %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x[4]xf32>
510   return %0: vector<1x1x[4]xf32>
513 // -----
515 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(
516 // CHECK-SAME:    %[[S:.*]]: f32,
517 // CHECK-SAME:    %[[V:.*]]: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
518 func.func @cast_away_insert_leading_one_dims_scalar_skip_scalable_dim(%s: f32, %v: vector<1x[1]x4xf32>) -> vector<1x[1]x4xf32> {
519 // CHECK:           %[[EXTRACT:.*]] = vector.extract %[[V]][0] : vector<[1]x4xf32> from vector<1x[1]x4xf32>
520 // CHECK:           %[[INSERT:.*]] = vector.insert %[[S]], %[[EXTRACT]] [0, 0] : f32 into vector<[1]x4xf32>
521 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<[1]x4xf32> to vector<1x[1]x4xf32>
522 // CHECK:           return %[[BCAST]] : vector<1x[1]x4xf32>
523   %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x[1]x4xf32>
524   return %0: vector<1x[1]x4xf32>
527 // -----
529 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
530 //  CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
531 //       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
532 //       CHECK:   return %[[BCAST]]
533 func.func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
534   %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
535   return %0: vector<1x1x4xf32>
538 // -----
540 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank1_scalable(
541 // CHECK-SAME:    %[[S:.*]]: vector<[4]xf32>,
542 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
543 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[S]] : vector<[4]xf32> to vector<1x1x[4]xf32>
544 // CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
545 func.func @cast_away_insert_leading_one_dims_rank1_scalable(%s: vector<[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
546   %0 = vector.insert %s, %v [0, 0] : vector<[4]xf32> into vector<1x1x[4]xf32>
547   return %0: vector<1x1x[4]xf32>
550 // -----
552 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
553 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
554 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
555 //       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
556 //       CHECK:   return %[[BCAST]]
557 func.func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
558   %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
559   return %0: vector<1x1x4xf32>
562 // -----
564 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_scalable(
565 // CHECK-SAME:    %[[S:.*]]: vector<1x[4]xf32>,
566 // CHECK-SAME:    %[[V:.*]]: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
567 // CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
568 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : vector<[4]xf32> to vector<1x1x[4]xf32>
569 // CHECK:           return %[[BCAST]] : vector<1x1x[4]xf32>
570 func.func @cast_away_insert_leading_one_dims_rank2_scalable(%s: vector<1x[4]xf32>, %v: vector<1x1x[4]xf32>) -> vector<1x1x[4]xf32> {
571   %0 = vector.insert %s, %v [0] : vector<1x[4]xf32> into vector<1x1x[4]xf32>
572   return %0: vector<1x1x[4]xf32>
575 // -----
577 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
578 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
579 //       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
580 //       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<2x1x4xf32> from vector<1x2x1x4xf32>
581 //       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
582 //       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
583 //       CHECK:   return %[[BCAST]]
584 func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
585   %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
586   return %0: vector<1x2x1x4xf32>
589 // -----
591 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(
592 // CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
593 // CHECK-SAME:      %[[V:.*]]: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
594 // CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
595 // CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0] : vector<2x1x[4]xf32> from vector<1x2x1x[4]xf32>
596 // CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<[4]xf32> into vector<2x1x[4]xf32>
597 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<2x1x[4]xf32> to vector<1x2x1x[4]xf32>
598 // CHECK:           return %[[BCAST]] : vector<1x2x1x[4]xf32>
599 func.func @cast_away_insert_leading_one_dims_rank2_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<1x2x1x[4]xf32>) -> vector<1x2x1x[4]xf32> {
600   %0 = vector.insert %s, %v [0, 1] : vector<1x[4]xf32> into vector<1x2x1x[4]xf32>
601   return %0: vector<1x2x1x[4]xf32>
604 // -----
606 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
607 //  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
608 //       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<4xf32> from vector<1x4xf32>
609 //       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
610 //       CHECK:   return %[[INSERT]]
611 func.func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
612   %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
613   return %0: vector<8x1x4xf32>
616 // -----
618 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(
619 // CHECK-SAME:      %[[S:.*]]: vector<1x[4]xf32>,
620 // CHECK-SAME:      %[[V:.*]]: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
621 // CHECK:           %[[EXTRACT:.*]] = vector.extract %[[S]][0] : vector<[4]xf32> from vector<1x[4]xf32>
622 // CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<[4]xf32> into vector<8x1x[4]xf32>
623 // CHECK:           return %[[INSERT]] : vector<8x1x[4]xf32>
624 func.func @cast_away_insert_leading_one_dims_non_one_dest_scalable(%s: vector<1x[4]xf32>, %v: vector<8x1x[4]xf32>) -> vector<8x1x[4]xf32> {
625   %0 = vector.insert %s, %v [5] : vector<1x[4]xf32> into vector<8x1x[4]xf32>
626   return %0: vector<8x1x[4]xf32>
629 // -----
631 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest
632 //  CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>)
633 //       CHECK:   %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<8xi1> from vector<1x8xi1>
634 //       CHECK:   %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<8x1x8xi1> from vector<1x1x8x1x8xi1>
635 //       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1>
636 //       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1>
637 //       CHECK:   return %[[BCAST]]
638 func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> {
639   %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1>
640   return %0: vector<1x1x8x1x8xi1>
643 // -----
645 // CHECK-LABEL:   func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(
646 // CHECK-SAME:      %[[S:.*]]: vector<1x[8]xi1>,
647 // CHECK-SAME:      %[[V:.*]]: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
648 // CHECK:           %[[EXTRACTS:.*]] = vector.extract %[[S]][0] : vector<[8]xi1> from vector<1x[8]xi1>
649 // CHECK:           %[[EXTRACTV:.*]] = vector.extract %[[V]][0, 0] : vector<8x1x[8]xi1> from vector<1x1x8x1x[8]xi1>
650 // CHECK:           %[[INSERT:.*]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<[8]xi1> into vector<8x1x[8]xi1>
651 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[INSERT]] : vector<8x1x[8]xi1> to vector<1x1x8x1x[8]xi1>
652 // CHECK:           return %[[BCAST]] : vector<1x1x8x1x[8]xi1>
653 func.func @cast_away_insert_leading_one_dims_one_two_dest_scalable(%s: vector<1x[8]xi1>, %v: vector<1x1x8x1x[8]xi1>) -> vector<1x1x8x1x[8]xi1> {
654   %0 = vector.insert %s, %v [0, 0, 7] : vector<1x[8]xi1> into vector<1x1x8x1x[8]xi1>
655   return %0: vector<1x1x8x1x[8]xi1>
658 // -----
660 // CHECK-LABEL:   func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
661 // CHECK:           %[[MASK:.*]] = vector.constant_mask [6, 1, 1] : vector<8x2x1xi1>
662 // CHECK:           %[[BCAST:.*]] = vector.broadcast %[[MASK]] : vector<8x2x1xi1> to vector<1x1x8x2x1xi1>
663 // CHECK:           return %[[BCAST]] : vector<1x1x8x2x1xi1>
664 func.func @cast_away_constant_mask() -> vector<1x1x8x2x1xi1> {
665   %0 = vector.constant_mask [1, 1, 6, 1, 1] : vector<1x1x8x2x1xi1>
666   return %0: vector<1x1x8x2x1xi1>
669 // -----
671 // CHECK-LABEL:   func.func @drop_unit_dims_scalar_cond_select(
672 // CHECK:           arith.select {{.*}} : vector<16xi1>
673 func.func @drop_unit_dims_scalar_cond_select(%cond: i1, %arg0: vector<1x16xi1>, %arg1: vector<1x16xi1>) -> vector<1x16xi1> {
674   %sel = arith.select %cond, %arg0, %arg1 : vector<1x16xi1>
675   return %sel : vector<1x16xi1>