1 // RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
// An identity tensor.expand_shape (source type == result type) should fold to its operand.
4 // CHECK-LABEL: expand_shape_identity_fold
6 func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
7 %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32>
8 return %0 : tensor<5xf32>
// Rank-0 variant: empty reassociation on a 0-d tensor is also an identity.
13 // CHECK-LABEL: expand_shape_rank0_identity_fold
15 func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
16 %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor<f32> into tensor<f32>
17 return %0 : tensor<f32>
// An identity tensor.collapse_shape should likewise fold to its operand.
22 // CHECK-LABEL: collapse_shape_identity_fold
24 func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
25 %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
26 return %0 : tensor<5x4xf32>
// Rank-0 variant of the collapse_shape identity fold.
31 // CHECK-LABEL: collapse_shape_rank0_identity_fold
33 func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
34 %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
35 return %0 : tensor<f32>
// Two back-to-back tensor.bitcast ops collapse into a single bitcast to the final type.
40 // CHECK-LABEL: @tensor_bitcast_chain_ok
41 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
42 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
43 // CHECK-NEXT: %[[RES:.*]] = tensor.bitcast %[[IN]] : tensor<2xi32> to tensor<2xf32>
44 %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32>
45 %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32>
46 // CHECK-NEXT: return %[[RES]]
47 return %1 : tensor<2xf32>
// A bitcast round-trip back to the original type folds away entirely; the
// block argument is returned directly.
52 // CHECK-LABEL: @tensor_bitcast_chain_nop
53 // CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
54 func.func @tensor_bitcast_chain_nop(%input: tensor<4xi32>) -> tensor<4xi32> {
55 %0 = tensor.bitcast %input : tensor<4xi32> to tensor<4xui32>
56 %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
57 // CHECK-NEXT: return %[[IN]]
58 return %1 : tensor<4xi32>
63 // Checks that NOP casts are removed.
64 // CHECK-LABEL: cast_values
65 func.func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
// Cast to the identical unranked type is a no-op and should be removed.
67 %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
68 // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
69 %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
// Cast to the identical ranked type is likewise a no-op.
71 %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
72 // CHECK-NEXT: return %[[RET]] : tensor<2xi32>
73 return %4 : tensor<2xi32>
// A cast chain unranked -> partially-static -> fully-static merges into one
// cast straight to the most static type.
78 // CHECK-LABEL: @tensor.cast_chain_ok
79 // CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
80 func.func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
81 // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
82 %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
83 %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
84 // CHECK-NEXT: return %[[RES]]
85 return %1 : tensor<4x8xi32>
// Casting static -> dynamic -> the same static type is a round-trip and folds
// to the original value.
90 // CHECK-LABEL: @tensor.cast_chain_regain
91 // CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
92 func.func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
93 %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
94 %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
95 // CHECK-NEXT: return %[[IN]]
96 return %1 : tensor<4xi32>
// Each cast refines a different dimension, so neither subsumes the other:
// both casts must be kept.
101 // CHECK-LABEL: @tensor.cast_chain_keep
102 // CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
103 func.func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
104 // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
105 %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
106 // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
107 %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
108 // CHECK-NEXT: return %[[C2]]
109 return %1 : tensor<?x8xi32>
// 4x8 and 8x4 are incompatible static shapes, so merging the chain would be
// invalid; both casts must be kept.
114 // CHECK-LABEL: @tensor.cast_chain_invalid
115 // CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
116 func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
117 // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
118 %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
119 // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
120 %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
121 // CHECK-NEXT: return %[[C2]]
122 return %1 : tensor<8x4xi32>
// A single-operand tensor.concat folds away: to a cast when the result type is
// more static than the operand, or to the operand itself when types match.
127 // CHECK-LABEL: fold_concat
128 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
129 func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
130 %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
131 // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
132 %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
133 // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
134 return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
// tensor.extract with constant indices folds to a scalar constant for splat,
// sparse, dense, and complex elements attributes.
139 // CHECK-LABEL: func @fold_extract
140 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
141 %const_0 = arith.constant 0 : index
142 %const_1 = arith.constant 1 : index
143 %const_3 = arith.constant 3 : index
144 // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
145 // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
146 // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
148 // Fold an extract into a splat. The index may be dynamic because every
// element of a splat has the same value.
149 // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32
150 %0 = arith.constant dense<4.0> : tensor<4xf32>
151 %ext_1 = tensor.extract %0[%arg0] : tensor<4xf32>
153 // Fold an extract into a sparse constant at an index that holds a stored value.
154 %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16>
155 %ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>
157 // Fold an extract into a sparse constant at a non-stored index (yields the implicit zero).
158 %2 = arith.constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16>
159 %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>
161 // Fold an extract into a dense tensor.
162 %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
163 %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
165 // Fold an extract into a complex constant.
166 // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
167 %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
168 %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
170 // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
171 return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
176 // Ensure that extracting from dense resource elements does not crash; the
// extract must not be constant-folded (the resource payload is elided).
178 // CHECK-LABEL: func @extract_dense_resource_nofold
179 func.func @extract_dense_resource_nofold() -> i64 {
180 // CHECK: %[[EXT:.+]] = tensor.extract
181 // CHECK-NEXT: return %[[EXT]]
182 %c0 = arith.constant 0 : index
183 %cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
184 %extracted = tensor.extract %cst[%c0] : tensor<1xi64>
185 return %extracted : i64
// Inserting a value into a splat of that same value is a no-op and folds to
// the splat constant.
190 // CHECK-LABEL: func @fold_insert
191 func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
192 // Fold an insert into a splat.
193 // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
194 %0 = arith.constant dense<4.0> : tensor<4xf32>
195 %1 = arith.constant 4.0 : f32
196 %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
197 // CHECK-NEXT: return %[[C4]]
198 return %ins_1 : tensor<4xf32>
// tensor.extract looks through a tensor.cast of its source; the cast disappears.
203 // CHECK-LABEL: func @extract_from_tensor.cast
204 // CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
205 func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
206 // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
207 %c0 = arith.constant 0 : index
208 // CHECK-NOT: tensor.cast
209 %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
210 // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
211 %result = tensor.extract %casted[%c0] : tensor<?xf32>
// Extracting element 0 of a single-element from_elements folds to that element.
217 // CHECK-LABEL: func @extract_from_tensor.from_elements
218 func.func @extract_from_tensor.from_elements(%element : index) -> index {
219 // CHECK-SAME: ([[ARG:%.*]]: index)
220 %c0 = arith.constant 0 : index
221 %tensor = tensor.from_elements %element : tensor<1xindex>
222 %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
223 // CHECK: [[ARG]] : index
224 return %extracted_element : index
// Same fold for the rank-0 (no-index) case.
229 // CHECK-LABEL: func @extract_from_tensor.from_elements_0d
230 func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
231 // CHECK-SAME: ([[ARG:%.*]]: index)
232 %c0 = arith.constant 0 : index
233 %tensor = tensor.from_elements %element : tensor<index>
234 %extracted_element = tensor.extract %tensor[] : tensor<index>
235 // CHECK: [[ARG]] : index
236 return %extracted_element : index
// Every extract from a 3-D from_elements of constants folds back to the
// corresponding scalar constant (row-major element order).
241 // CHECK-LABEL: func @extract_from_tensor.from_elements_3d
242 func.func @extract_from_tensor.from_elements_3d()
243 -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
244 %f0 = arith.constant 0.0 : f32
245 %f1 = arith.constant 1.0 : f32
246 %f2 = arith.constant 2.0 : f32
247 %f3 = arith.constant 3.0 : f32
248 %f4 = arith.constant 4.0 : f32
249 %f5 = arith.constant 5.0 : f32
250 %f6 = arith.constant 6.0 : f32
251 %f7 = arith.constant 7.0 : f32
252 %f8 = arith.constant 8.0 : f32
253 %f9 = arith.constant 9.0 : f32
254 %f10 = arith.constant 10.0 : f32
255 %f11 = arith.constant 11.0 : f32
257 %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
259 %c0 = arith.constant 0 : index
260 %c1 = arith.constant 1 : index
261 %c2 = arith.constant 2 : index
263 %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
264 %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
265 %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
266 %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
267 %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
268 %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
269 %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
270 %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
271 %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
272 %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
273 %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
274 %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
275 return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
276 : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
278 // CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
279 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
280 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
281 // CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
282 // CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
283 // CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
284 // CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
285 // CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
286 // CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
287 // CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
288 // CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
289 // CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
291 // CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
292 // CHECK-SAME: %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]
// Same as the 3-D test above, but the elements are function arguments: each
// extract folds directly to the corresponding from_elements operand.
296 // CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
297 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
298 // CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
299 // CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
300 // CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
301 // CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
302 // CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
303 // CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
304 // CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
305 // CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
306 // CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
307 // CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
308 // CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
309 func.func @extract_from_tensor.from_elements_variable_3d(
310 %f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
311 %f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
312 -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
314 %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
316 %c0 = arith.constant 0 : index
317 %c1 = arith.constant 1 : index
318 %c2 = arith.constant 2 : index
320 %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
321 %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
322 %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
323 %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
324 %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
325 %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
326 %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
327 %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
328 %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
329 %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
330 %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
331 %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
332 return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
333 : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
335 // CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
336 // CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]
// Extracts from complex 0-d constants fold, and the resulting from_elements of
// constants folds into a single dense complex constant (integer case).
340 // CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
341 // CHECK-NEXT: %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
342 // CHECK-NEXT: return %cst : tensor<3xcomplex<i32>>
343 func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
344 %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
345 %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
346 %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
347 %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
348 %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
349 return %tensor : tensor<3xcomplex<i32>>
// Same fold for the floating-point complex case.
354 // CHECK-LABEL: func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
355 // CHECK-NEXT: %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
356 // CHECK-NEXT: return %cst : tensor<3xcomplex<f32>>
357 func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
358 %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
359 %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
360 %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
361 %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
362 %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
363 return %tensor : tensor<3xcomplex<f32>>
368 // Ensure the optimization doesn't segfault from bad constants.
// Index -1 is out of bounds, so the extract must NOT fold.
369 // CHECK-LABEL: func @extract_negative_from_tensor.from_elements
370 func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {
371 // CHECK-SAME: ([[ARG:%.*]]: index)
372 %c-1 = arith.constant -1 : index
373 %tensor = tensor.from_elements %element : tensor<1xindex>
374 %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
375 // CHECK: tensor.from_elements
376 // CHECK: %[[RESULT:.*]] = tensor.extract
377 // CHECK: return %[[RESULT]]
378 return %extracted_element : index
383 // Ensure the optimization doesn't segfault from bad constants.
// Index 1 is one past the end of a tensor<1xindex>, so the extract must NOT fold.
384 // CHECK-LABEL: func @extract_oob_from_tensor.from_elements
385 func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
386 // CHECK-SAME: ([[ARG:%.*]]: index)
387 %c1 = arith.constant 1 : index
388 %tensor = tensor.from_elements %element : tensor<1xindex>
389 %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
390 // CHECK: tensor.from_elements
391 // CHECK: %[[RESULT:.*]] = tensor.extract
392 // CHECK: return %[[RESULT]]
393 return %extracted_element : index
398 // Ensure the optimization doesn't segfault from bad constants.
// Index 2 is even further out of bounds; again no folding.
399 // CHECK-LABEL: func @extract_oob_from_tensor.from_elements
400 func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
401 // CHECK-SAME: ([[ARG:%.*]]: index)
402 %c2 = arith.constant 2 : index
403 %tensor = tensor.from_elements %element : tensor<1xindex>
404 %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
405 // CHECK: tensor.from_elements
406 // CHECK: %[[RESULT:.*]] = tensor.extract
407 // CHECK: return %[[RESULT]]
408 return %extracted_element : index
// Extracting from a pure tensor.generate inlines the generator body at the
// extract indices, eliminating the generate op.
413 // CHECK-LABEL: func @extract_from_tensor.generate
414 // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
415 func.func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
416 %size = tensor.rank %tensor : tensor<*xf32>
417 // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]]
418 %0 = tensor.generate %size {
420 %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
421 tensor.yield %1 : index
423 %1 = tensor.extract %0[%idx] : tensor<?xindex>
424 // CHECK-NEXT: return %[[RES]]
// 2-D variant: the multi-statement generator body is inlined with the extract
// indices substituted for the block arguments.
430 // CHECK-LABEL: func @extract_from_tensor.generate_2d
431 // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
432 func.func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
433 %size = tensor.rank %tensor : tensor<*xf32>
434 // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]]
435 // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]]
436 // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]]
437 %0 = tensor.generate %size, %size {
438 ^bb0(%arg0: index, %arg1: index):
439 %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
440 %2 = tensor.dim %tensor, %arg1 : tensor<*xf32>
441 %3 = arith.addi %1, %2 : index
442 tensor.yield %3 : index
443 } : tensor<?x?xindex>
444 %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
445 // CHECK-NEXT: return %[[RES]]
// The generator body writes to a memref, so it has a side effect and the
// generate op must be preserved; the extract stays as-is.
451 // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
452 // CHECK-SAME: %[[IDX:.*]]: index
453 func.func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
454 %size = tensor.rank %tensor : tensor<*xf32>
455 // CHECK: %[[DTENSOR:.*]] = tensor.generate
456 %0 = tensor.generate %size {
458 %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
459 memref.store %1, %mem[%arg0] : memref<?xindex>
460 tensor.yield %1 : index
462 // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
463 %1 = tensor.extract %0[%idx] : tensor<?xindex>
464 // CHECK-NEXT: return %[[RES]]
// Constant dynamic extents of tensor.generate are folded into the result type
// (? -> 5); a cast reconciles the new type with the old one.
470 // CHECK-LABEL: @static_tensor.generate
471 // CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
472 func.func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
473 %c5 = arith.constant 5 : index
474 // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]]
475 %0 = tensor.generate %size1, %c5, %size4 {
476 ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
477 %1 = arith.constant 32 : index
478 tensor.yield %1 : index
479 // CHECK: : tensor<3x?x5x7x?xindex>
480 } : tensor<3x?x?x7x?xindex>
481 // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
482 return %0 : tensor<3x?x?x7x?xindex>
// A from_elements of constants folds to a dense constant.
487 // CHECK-LABEL: @from_elements.constant
488 func.func @from_elements.constant() -> tensor<3xindex> {
489 // CHECK: %[[CST:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xindex>
490 // CHECK: return %[[CST]]
491 %c1 = arith.constant 1 : index
492 %c2 = arith.constant 2 : index
493 %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
494 return %tensor : tensor<3xindex>
// Constant offsets/sizes/strides of extract_slice become static attributes,
// the result type is inferred more statically, and a cast restores the old type.
499 func.func @slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
500 %arg2 : index) -> tensor<?x?x?xf32>
502 %c0 = arith.constant 0 : index
503 %c1 = arith.constant 1 : index
504 %c4 = arith.constant 4 : index
505 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
506 return %0 : tensor<?x?x?xf32>
508 // CHECK-LABEL: func @slice_canonicalize
509 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
510 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
511 // CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
512 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
513 // CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
514 // CHECK: return %[[RESULT]]
// Same canonicalization for a rank-reducing extract_slice (the unit dim stays dropped).
518 func.func @rank_reducing_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
519 %arg2 : index) -> tensor<?x?xf32>
521 %c0 = arith.constant 0 : index
522 %c1 = arith.constant 1 : index
523 %c4 = arith.constant 4 : index
524 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
525 return %0 : tensor<?x?xf32>
527 // CHECK-LABEL: func @rank_reducing_slice_canonicalize
528 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>
529 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
530 // CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
531 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
532 // CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
533 // CHECK: return %[[RESULT]]
// A full-tensor extract_slice (zero offsets, full sizes, unit strides) folds
// to the source tensor.
537 // CHECK-LABEL: func @trivial_slice
538 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
539 // CHECK-NOT: tensor.extract_slice
540 // CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
541 func.func @trivial_slice(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
542 %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
543 return %0 : tensor<4x6x16x32xi8>
// A full-tensor insert_slice folds to the inserted source tensor.
548 // CHECK-LABEL: func @trivial_insert_slice
549 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
550 // CHECK-NOT: tensor.extract_slice
551 // CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
552 func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
553 %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
554 return %0 : tensor<4x6x16x32xi8>
// Inserting a zero-element slice leaves the destination unchanged, so the op
// folds to the destination tensor.
559 // CHECK-LABEL: func @empty_insert_slice
560 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
561 // CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
562 // CHECK-NOT: tensor.extract_slice
563 // CHECK: return %[[ARG1]] : tensor<3x3xi8>
564 func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
565 %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
566 return %0 : tensor<3x3xi8>
// extract_slice of a cast source: the slice is taken from the original tensor
// and the cast is folded away.
571 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
572 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
573 // CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
574 // Tensor cast is moved after slice and then gets canonicalized away.
575 // CHECK-NOT: tensor.cast
576 // CHECK: return %[[S]] : tensor<16x32xi8>
577 func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
578 %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
579 %1 = tensor.extract_slice %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
580 return %1 : tensor<16x32xi8>
// insert_slice of a cast source: the cast (and the tensor.dim that queried its
// size) fold away, leaving a fully static insert_slice.
585 // CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
586 // CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
587 // CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
588 // CHECK: %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
589 // Tensor cast is folded away.
590 // CHECK-NOT: tensor.cast
591 // CHECK: return %[[S]] : tensor<4x6x16x32xi8>
592 func.func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
593 %c0 = arith.constant 0: index
594 %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
595 %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
596 %res = tensor.insert_slice %cast into %b[0, 1, 0, 0] [1, 1, %sz, 32] [1, 1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
597 return %res : tensor<4x6x16x32xi8>
// Constant operands of insert_slice become static attributes; the source is
// cast to the more static inferred type before the insert.
602 func.func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
603 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
605 %c0 = arith.constant 0 : index
606 %c1 = arith.constant 1 : index
607 %c4 = arith.constant 4 : index
608 %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
609 return %0 : tensor<?x?x?xf32>
611 // CHECK-LABEL: func @insert_slice_canonicalize
612 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
613 // CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
614 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
615 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
616 // CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
617 // CHECK: return %[[RESULT]]
621 // Do not insert a cast for the following example. The new source type wouldn't be "more static" than the old one.
622 func.func @insert_slice_canonicalize_encoding(%arg0 : tensor<2x2xf32, "foo">,
623 %arg1 : tensor<4x4xf32, "foo">) -> tensor<4x4xf32, "foo">
625 %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [2, 2] [1, 1] : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
626 return %0 : tensor<4x4xf32, "foo">
628 // CHECK-LABEL: func @insert_slice_canonicalize_encoding
629 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x2xf32, "foo">
630 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x4xf32, "foo">
631 // CHECK-NOT: tensor.cast
632 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[ARG1]]
633 // CHECK-SAME: [0, 0] [2, 2] [1, 1]
634 // CHECK-SAME: : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
635 // CHECK: return %[[RESULT]]
// Both the extract_slice and the matching insert_slice are canonicalized to
// static offsets/sizes/strides, with the slice result type made more static.
639 func.func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
640 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
642 %c0 = arith.constant 0 : index
643 %c1 = arith.constant 1 : index
644 %c4 = arith.constant 4 : index
645 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
646 %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
647 return %1 : tensor<?x?x?xf32>
649 // CHECK-LABEL: func @slice_to_insert_slice_canonicalize
650 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
651 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
652 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
653 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}} [1, 1, 1]
654 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32>
655 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]]
656 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
657 // CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
658 // CHECK: return %[[RESULT]]
// Rank-reducing insert_slice: the dynamic source is cast to the inferred
// static type before the insert.
662 func.func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
663 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
665 %c0 = arith.constant 0 : index
666 %c1 = arith.constant 1 : index
667 %c4 = arith.constant 4 : index
668 %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
669 return %0 : tensor<?x?x?xf32>
671 // CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
672 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
673 // CHECK: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
674 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
675 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
676 // CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
677 // CHECK: return %[[RESULT]]
// Rank-reducing extract_slice/insert_slice pair: both ops get static
// offsets/sizes/strides with a more static intermediate slice type.
681 func.func @rank_reducing_slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
682 %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
684 %c0 = arith.constant 0 : index
685 %c1 = arith.constant 1 : index
686 %c4 = arith.constant 4 : index
687 %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
688 %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
689 return %1 : tensor<?x?x?xf32>
691 // CHECK-LABEL: func @rank_reducing_slice_to_insert_slice_canonicalize
692 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
693 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
694 // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
695 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
696 // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32>
697 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG3]]
698 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
699 // CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32>
700 // CHECK: return %[[RESULT]]
// The generate destination gets a more static type (its constant extent %c8 is
// folded in), and a cast after the insert_slice restores the original type.
704 func.func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
705 %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
706 %c0 = arith.constant 0 : index
707 %c1 = arith.constant 1 : index
708 %c2 = arith.constant 2 : index
709 %c8 = arith.constant 8 : index
710 %0 = tensor.dim %arg0, %c1 : tensor<2x?xi32>
711 %1 = tensor.extract %arg1[] : tensor<i32>
712 %2 = tensor.generate %arg2, %c8 {
713 ^bb0(%arg4: index, %arg5: index):
714 tensor.yield %1 : i32
716 %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
717 return %3 : tensor<?x?xi32>
719 // CHECK-LABEL: func @insert_slice_propagate_dest_cast
720 // CHECK: %[[UPDATED:.+]] = tensor.insert_slice %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
721 // CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
722 // CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
723 // CHECK: return %[[CAST]]
// The cast on the insert_slice result folds into the generate destination,
// making the whole chain statically shaped with no trailing cast.
727 func.func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
728 %c9 = arith.constant 9 : index
729 %c3 = arith.constant 3 : index
730 %2 = tensor.extract %arg1[] : tensor<i32>
731 %4 = tensor.generate %c3, %c9 {
732 ^bb0(%arg2: index, %arg3: index):
733 tensor.yield %2 : i32
735 %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
736 %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
737 return %6 : tensor<3x9xi32>
739 // CHECK-LABEL: func @insert_slice_output_dest_canonicalize
740 // CHECK-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<2x3xi32>
741 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
742 // CHECK: %[[PAD:.+]] = tensor.extract %[[ARG1]]
743 // CHECK: %[[GENERATE:.+]] = tensor.generate
744 // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[GENERATE]]
745 // CHECK: return %[[RESULT]]
749 // Test case: Folding of tensor.dim(tensor.generate %idx) -> %idx
750 // CHECK-LABEL: func @dim_of_tensor.generate(
751 // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
752 // CHECK-NOT: tensor.dim
753 // CHECK: return %[[IDX1]] : index
754 func.func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
755 %c3 = arith.constant 3 : index
756 %0 = tensor.generate %arg0, %arg1 {
757 ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
758 tensor.yield %c3 : index
759 } : tensor<2x?x4x?x5xindex>
// Dim 3 is the second dynamic extent, i.e. %arg1.
760 %1 = tensor.dim %0, %c3 : tensor<2x?x4x?x5xindex>
766 // Test case: Folding tensor.dim(tensor.cast %0, %idx) -> tensor.dim %0, %idx
767 // CHECK-LABEL: func @fold_dim_of_tensor.cast
768 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
769 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
770 // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
771 // CHECK: %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
772 // CHECK-NEXT: return %[[C4]], %[[T0]]
773 func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
774 %c0 = arith.constant 0 : index
775 %c1 = arith.constant 1 : index
776 %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
777 %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
778 %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
779 return %1, %2: index, index
784 // CHECK-LABEL: func @insert_slice_cast
785 func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
786 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
787 %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
788 // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
789 // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
790 // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
791 %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
792 // CHECK: return %[[RES]] : tensor<?x?xf32>
793 return %1 : tensor<?x?xf32>
798 // CHECK-LABEL: func @insert_slice_cast_no_fold
799 func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
800 %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
801 // CHECK: %[[CAST:.*]] = tensor.cast
802 // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
803 // CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
804 // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
805 %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
806 // CHECK: return %[[RES]] : tensor<?x?xf32>
807 return %1 : tensor<?x?xf32>
812 // CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
813 // CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
814 // CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
815 // CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
816 // CHECK: return %[[r]]
817 func.func @insert_tensor_cast_on_insert_slice_src(
818 %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
819 %c64 = arith.constant 64: index
820 %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
821 : tensor<?x5x?xf32> into tensor<?x?x?xf32>
822 return %r : tensor<?x?x?xf32>
827 // CHECK-LABEL: func @fold_extract_insert
828 // CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
829 func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
830 %c0 = arith.constant 0: index
831 %c1 = arith.constant 1: index
832 %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
833 %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
834 // CHECK: return %[[SLICE]]
835 return %1 : tensor<4x?x8xf32>
840 // CHECK-LABEL: func @fold_gather_constant_splat
841 // CHECK-NOT: tensor.gather
842 // CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
843 func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
844 %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
845 %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
846 (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
847 return %0 : tensor<1x2x 1x1x1xf32>
852 // CHECK-LABEL: func @fold_reshape_constant_splat
853 // CHECK-NOT: tensor.reshape
854 // CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
855 func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
856 %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
857 %0 = tensor.reshape %cst(%shape)
858 : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
859 return %0 : tensor<4xf32>
864 // CHECK-LABEL: func @fold_reshape_chain
865 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
866 // CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
867 // CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
868 // CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
869 // CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
870 // CHECK: return %[[RESULT]]
871 func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
872 %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
873 %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
874 %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
875 return %2 : tensor<*xf32>
880 // CHECK-LABEL: func @fold_reshape_1d
881 // CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
882 // CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
883 // CHECK: return %[[INPUT]]
884 func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
885 %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
886 return %0 : tensor<?xf32>
891 // CHECK-LABEL: func @fold_extract_constant_splat
892 // CHECK-NOT: tensor.extract_slice
893 // CHECK: arith.constant dense<42> : tensor<4x4xi32>
894 func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
895 %cst = arith.constant dense<42> : tensor<1024x1024xi32>
896 %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
897 return %1 : tensor<4x4xi32>
902 // CHECK-LABEL: func @fold_pack_constant_splat
903 // CHECK-NOT: tensor.pack
904 // CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
905 func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
906 %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
907 %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
908 inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
909 return %0 : tensor<8x16x8x32xf32>
914 // CHECK-LABEL: func @fold_padding_value_pack_constant_splat
915 // CHECK-NOT: tensor.pack
916 // CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
917 func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
918 %pad = arith.constant 1.000000e-01 : f32
919 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
920 %0 = tensor.pack %cst
921 padding_value(%pad : f32)
922 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
923 inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
924 return %0 : tensor<8x16x8x32xf32>
930 // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
931 // CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
932 // CHECK: tensor.pack
933 func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
934 %pad = arith.constant 0.0 : f32
935 %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
936 %0 = tensor.pack %cst
937 padding_value(%pad : f32)
938 outer_dims_perm = [1, 0]
939 inner_dims_pos = [0, 1]
940 inner_tiles = [8, 32]
941 into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
942 return %0 : tensor<8x16x8x32xf32>
947 func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
948 %cst = arith.constant 0.000000e+00 : f32
949 %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
950 %pack = tensor.pack %arg0
951 padding_value(%cst : f32)
952 outer_dims_perm = [1, 0]
953 inner_dims_pos = [1, 0]
954 inner_tiles = [16, 1]
955 into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
956 return %pack : tensor<31250x1200x16x1xf32>
958 // CHECK-LABEL: func @fold_padding_value_pack
959 // CHECK-NOT: padding_value
963 func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
964 %cst = arith.constant 0.000000e+00 : f32
965 %pack = tensor.pack %src
966 padding_value(%cst : f32)
967 outer_dims_perm = [2, 1, 3, 0]
970 into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
971 return %pack : tensor<10x20x30x40x16xf32>
973 // CHECK-LABEL: func.func @infer_src_shape_pack
974 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
975 // CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
976 // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
977 // CHECK: %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
978 // CHECK: return %[[PACK]]
982 func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
983 %cst = arith.constant 0.000000e+00 : f32
984 %pack = tensor.pack %src
985 padding_value(%cst : f32)
986 outer_dims_perm = [2, 1, 3, 0]
989 into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
990 return %pack : tensor<?x?x?x?x16xf32>
992 // CHECK-LABEL: func.func @infer_dest_shape_pack
993 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
994 // CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
995 // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
996 // CHECK: %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
997 // CHECK: %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
998 // CHECK: return %[[CAST_PACK]]
1002 func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
1003 %cst = arith.constant 0.000000e+00 : f32
1004 %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
1005 %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
1006 return %pack : tensor<32x7x?x16x1xf32>
1008 // CHECK-LABEL: func.func @no_infer_pack_shape
1009 // CHECK-NOT: tensor.cast
1013 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
1014 %cst = arith.constant 0.000000e+00 : f32
1015 %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
1016 %pack = tensor.pack %arg0
1017 padding_value(%cst : f32)
1018 outer_dims_perm = [1, 0]
1019 inner_dims_pos = [1, 0]
1020 inner_tiles = [16, 1]
1021 into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
1022 return %pack : tensor<31250x1200x16x1xf32>
1024 // CHECK-LABEL: func @fold_padding_value_pack_negative1
1025 // CHECK: tensor.pack
1026 // CHECK-SAME: padding_value
1030 func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
1031 %cst = arith.constant 0.000000e+00 : f32
1032 %pack = tensor.pack %arg0
1033 padding_value(%cst : f32)
1034 outer_dims_perm = [1, 0]
1035 inner_dims_pos = [1, 0]
1036 inner_tiles = [16, 1]
1037 into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
1038 return %pack : tensor<?x1200x16x1xf32>
1040 // CHECK-LABEL: func @fold_padding_value_pack_negative2
1041 // CHECK: tensor.pack
1042 // CHECK-SAME: padding_value
1046 func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
1047 %cst = arith.constant 0.000000e+00 : f32
1048 %pack = tensor.pack %arg0
1049 padding_value(%cst : f32)
1050 outer_dims_perm = [1, 0]
1051 inner_dims_pos = [1, 0]
1052 inner_tiles = [%tile, 1]
1053 into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
1054 return %pack : tensor<?x1200x?x1xf32>
1056 // CHECK-LABEL: func @fold_padding_value_pack_negative3
1057 // CHECK: tensor.pack
1058 // CHECK-SAME: padding_value
1062 // CHECK-LABEL: func @fold_unpack_constant_splat
1063 // CHECK-NOT: tensor.unpack
1064 // CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
1065 func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
1066 %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
1067 %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
1068 inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
1069 return %0 : tensor<128x256xf32>
1074 func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
1075 %unpack = tensor.unpack %src
1076 outer_dims_perm = [2, 1, 3, 0]
1077 inner_dims_pos = [2]
1079 into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
1080 return %unpack : tensor<?x?x?x?xf32>
1082 // CHECK-LABEL: func.func @infer_dest_shape_unpack
1083 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
1084 // CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
1085 // CHECK: %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
1086 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
1087 // CHECK: %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
1088 // CHECK: return %[[CAST_UNPACK]]
1092 func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
1093 %unpack = tensor.unpack %src
1094 outer_dims_perm = [2, 1, 3, 0]
1095 inner_dims_pos = [2]
1097 into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
1098 return %unpack : tensor<30x20x?x10xf32>
1100 // CHECK-LABEL: func.func @infer_src_shape_unpack
1101 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
1102 // CHECK-SAME: %[[DEST:[0-9a-zA-Z]+]]
1103 // CHECK: %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
1104 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
1105 // CHECK: return %[[UNPACK]]
1109 func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
1110 %cst = arith.constant 0.000000e+00 : f32
1111 %0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
1112 %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
1113 return %unpack : tensor<?x32x100xf32>
1115 // CHECK-LABEL: func.func @no_infer_unpack_shape
1116 // CHECK-NOT: tensor.cast
1121 // CHECK-LABEL: func @fold_overlapping_insert
1122 // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
1123 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
1124 %c0 = arith.constant 0: index
1125 %c1 = arith.constant 1: index
1126 %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
1127 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
1128 %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
1129 // CHECK: return %[[INSERT]]
1130 return %1 : tensor<?x?x?xf32>
1135 func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
1136 -> tensor<?x6x4x?x5xf32> {
1137 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1138 : tensor<?x?xf32> into tensor<?x4x?xf32>
1139 %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
1140 return %1 : tensor<?x6x4x?x5xf32>
1142 // CHECK-LABEL: compose_expand_of_expand
1143 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5]
1144 // CHECK-NOT: tensor.expand_shape
1148 func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
1149 -> tensor<1x1x1xf32> {
1150 %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
1151 %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
1152 : tensor<1xf32> into tensor<1x1x1xf32>
1153 return %1 : tensor<1x1x1xf32>
1155 // CHECK-LABEL: compose_expand_of_expand_of_zero_dim
1156 // CHECK: tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1]
1157 // CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
1161 // CHECK-LABEL: func.func @collapse_of_cast(
1162 // CHECK-SAME: %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
1163 // CHECK-NEXT: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
1164 // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
1165 // CHECK-NEXT: return %[[CAST]] : tensor<?x32xf32>
1166 func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
1167 %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
1168 %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
1169 %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
1170 return %2 : tensor<?x32xf32>
1175 func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
1176 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
1177 : tensor<12x4xf32> into tensor<3x4x4xf32>
1178 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1179 : tensor<3x4x4xf32> into tensor<12x4xf32>
1180 return %1 : tensor<12x4xf32>
1182 // CHECK-LABEL: @fold_collapse_of_expand
1183 // CHECK-NOT: tensor.{{.*}}_shape
1187 func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index)
1188 -> tensor<?x?xf32> {
1189 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1190 : tensor<?x?xf32> into tensor<?x4x?xf32>
1191 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1192 : tensor<?x4x?xf32> into tensor<?x?xf32>
1193 return %1 : tensor<?x?xf32>
1195 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
1196 // CHECK-NOT: tensor.{{.*}}_shape
1200 func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
1201 -> tensor<?x?xf32> {
1202 %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
1203 : tensor<?x?xf32> into tensor<?x?x?xf32>
1204 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1205 : tensor<?x?x?xf32> into tensor<?x?xf32>
1206 return %1 : tensor<?x?xf32>
1208 // CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
1209 // CHECK-NOT: tensor.{{.*}}_shape
1213 func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
1214 -> tensor<?x?x?xf32> {
1215 %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
1216 : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
1217 %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
1218 : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
1219 return %1 : tensor<?x?x?xf32>
1221 // CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
1222 // CHECK: tensor.expand_shape
1223 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
1224 // CHECK: return %[[COLLAPSE]]
1228 func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
1229 %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1230 : tensor<3x4x4xf32> into tensor<12x4xf32>
1231 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
1232 : tensor<12x4xf32> into tensor<3x4x4xf32>
1233 return %1 : tensor<3x4x4xf32>
1235 // CHECK-LABEL: @fold_expand_of_collapse
1236 // CHECK-NOT: tensor.{{.*}}_shape
1240 func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
1241 -> tensor<?x4x?xf32> {
1242 %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1243 : tensor<?x4x?xf32> into tensor<?x?xf32>
1244 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1245 : tensor<?x?xf32> into tensor<?x4x?xf32>
1246 return %1 : tensor<?x4x?xf32>
1248 // CHECK-LABEL: @fold_expand_of_collapse_dynamic
1249 // CHECK-NOT: tensor.{{.*}}_shape
1253 func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
1254 -> tensor<?x?x?xf32> {
1255 %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1256 : tensor<?x?x?xf32> into tensor<?x?xf32>
1257 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
1258 : tensor<?x?xf32> into tensor<?x?x?xf32>
1259 return %1 : tensor<?x?x?xf32>
1261 // CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
1262 // CHECK: tensor.collapse_shape
1263 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape
1264 // CHECK: return %[[EXPAND]]
1268 func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
1269 %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
1270 %c0 = arith.constant 0 : index
1271 %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
1272 %c384= arith.constant 384 : index
1273 %div = arith.divui %dim, %c384 : index
1274 %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
1275 return %expanded : tensor<?x384xf32>
1277 // CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
1278 // CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
1279 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
1280 // CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
1281 // CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
1282 // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
1283 // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
1284 // CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
1285 // CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
1286 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
1287 // CHECK: return %[[RESULT]]
1291 func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
1292 -> tensor<24x5x42x8xf32> {
1293 %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
1294 : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
1295 %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8]
1296 : tensor<40320xf32> into tensor<24x5x42x8xf32>
1297 return %1 : tensor<24x5x42x8xf32>
1299 // CHECK: func @compose_expand_of_collapse
1300 // CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
1301 // CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1302 // CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
1303 // CHECK: return %[[RESULT]]
1307 func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
1308 -> tensor<2x3x4x5x6x7x8xf32> {
1309 %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
1310 : tensor<24x5x42x8xf32> into tensor<40320xf32>
1311 %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8]
1312 : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
1313 return %1 : tensor<2x3x4x5x6x7x8xf32>
1315 // CHECK: func @compose_expand_of_collapse_7D
1316 // CHECK-SAME: %[[ARG0:.+]]: tensor<24x5x42x8xf32>
1317 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
1318 // CHECK-SAME: [0, 1, 2], [3], [4, 5], [6]
1319 // CHECK: return %[[RESULT]]
1323 func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>, %arg1: index, %arg2: index, %arg3: index)
1324 -> tensor<?x?xi64> {
1325 %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1]
1326 : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
1327 %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
1328 : tensor<?x?x?x1xi64> into tensor<?x?xi64>
1329 return %1 : tensor<?x?xi64>
1331 // CHECK-LABEL: func @compose_collapse_of_expand
1332 // CHECK: (%[[ARG:.*]]: tensor<?x?x?xi64>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
1333 // CHECK-NEXT: tensor.collapse_shape %[[ARG]]
1334 // CHECK-SAME: [0, 1], [2]
1335 // CHECK-SAME: : tensor<?x?x?xi64> into tensor<?x?xi64>
1339 func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
1340 -> tensor<4x512xf32> {
1341 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512]
1342 : tensor<2048xf32> into tensor<1x4x1x512xf32>
1343 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
1344 : tensor<1x4x1x512xf32> into tensor<4x512xf32>
1345 return %1 : tensor<4x512xf32>
1347 // CHECK: func @compose_collapse_of_expand_1D
1348 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512]
1349 // CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
1353 func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
1354 -> tensor<1x1x1x1xf32> {
1355 %0 = tensor.collapse_shape %arg0 []
1356 : tensor<1x1x1xf32> into tensor<f32>
1357 %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1]
1358 : tensor<f32> into tensor<1x1x1x1xf32>
1359 return %1 : tensor<1x1x1x1xf32>
1361 // CHECK: func @compose_expand_of_collapse_0_rank_to_expand
1362 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1xf32>
1363 // CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
1364 // CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [1, 1, 1, 1]
1365 // CHECK: return %[[RESULT]]
1369 func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1xf32>)
1370 -> tensor<1x1x1xf32> {
1371 %0 = tensor.collapse_shape %arg0 []
1372 : tensor<1x1x1x1xf32> into tensor<f32>
1373 %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1]
1374 : tensor<f32> into tensor<1x1x1xf32>
1375 return %1 : tensor<1x1x1xf32>
1377 // CHECK: func @compose_expand_of_collapse_0_rank_to_collapse
1378 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x1x1xf32>
1379 // CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1380 // CHECK-SAME: [0], [1], [2, 3]
1381 // CHECK: return %[[RESULT]]
1385 // CHECK-LABEL: func @zero_rank_reshape_multi
1386 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
1387 // CHECK: return %arg0
1388 %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
1389 %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32>
1390 %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
1391 return %2 : tensor<f32>
1396 func.func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
1397 -> tensor<?x?xf32> {
1398 %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
1399 : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
1400 %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1401 : tensor<?x?x?xf32> into tensor<?x?xf32>
1402 return %1 : tensor<?x?xf32>
1404 // CHECK-LABEL: func @compose_collapse_of_collapse
1405 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
1406 // CHECK-NOT: tensor.collapse_shape
1410 func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
1412 %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
1413 : tensor<1x1x1xf32> into tensor<1xf32>
1414 %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
1415 return %1 : tensor<f32>
1417 // CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
1418 // CHECK: tensor.collapse_shape %{{.*}} []
1419 // CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
1423 func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
1424 %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512]
1425 : tensor<4x512xf32> into tensor<1x4x1x512xf32>
1426 %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
1427 : tensor<1x4x1x512xf32> into tensor<2048xf32>
1428 return %1 : tensor<2048xf32>
1430 // CHECK: func @fold_collapse_of_expand_1D
1431 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
1432 // CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
1436 func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
1437 -> tensor<4x512x1x1xf32> {
1438 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
1439 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
1440 : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
1441 return %1 : tensor<4x512x1x1xf32>
1443 // CHECK: func @fold_collapse_of_expand_unit_dims
1444 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1]
1445 // CHECK-SAME: tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
1449 func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
1450 -> tensor<4x512x1x512x4xf32> {
1451 %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
1452 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
1453 : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
1454 return %1 : tensor<4x512x1x512x4xf32>
1456 // CHECK: func @compose_collapse_of_expand_unit_dims
1457 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4]
1458 // CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
1462 func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
1463 -> tensor<2x1xf32> {
1464 %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1]
1465 : tensor<2xf32> into tensor<2x1x1xf32>
1466 %1 = tensor.collapse_shape %0 [[0], [1, 2]]
1467 : tensor<2x1x1xf32> into tensor<2x1xf32>
1468 return %1 : tensor<2x1xf32>
1470 // CHECK: func @compose_collapse_of_expand_trailing_unit_dims
1471 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
1472 // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
1476 func.func @compose_collapse_of_collapse_unit_dims_dynamic(
1477 %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
1478 %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
1479 : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
1480 %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
1481 : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
1482 return %1 : tensor<?x?x?x?xf32>
1484 // CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
1485 // CHECK: tensor.collapse_shape
1486 // CHECK-SAME: [0], [1, 2], [3, 4, 5], [6, 7, 8]
1487 // CHECK-SAME: tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
1491 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
1492 -> tensor<2x1xf32> {
1493 %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32>
1494 %1 = tensor.collapse_shape %0 [[0], [1, 2]]
1495 : tensor<2x1x1xf32> into tensor<2x1xf32>
1496 return %1 : tensor<2x1xf32>
1498 // CHECK: func @fold_collapse_of_expand_trailing_unit_dims
1499 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
1500 // CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
1504 func.func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
1505 %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
1506 %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
1507 : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
1508 %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
1509 : tensor<?x1x1x1xf32> into tensor<?xf32>
1510 return %1 : tensor<?xf32>
1512 // CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
1513 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
1514 // CHECK-SAME: tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
// Static variant: adding a trailing unit dim via expand and collapsing it
// away again must fold into one collapse_shape of the original operand.
1518 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
1519 -> tensor<12x42xf32> {
1520 %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
1521 %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]]
1522 : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
1523 return %1 : tensor<12x42xf32>
1525 // CHECK: func @fold_collapse_of_expand_trailing_unit_dims
1526 // CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
1527 // CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32>
// A unit dim inserted in the middle and immediately collapsed away folds
// to a single collapse_shape of the original dynamic operand.
1531 func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
1532 -> tensor<?x?xf32> {
1533 %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2]
1534 : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
1535 %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
1536 : tensor<?x?x1x?xf32> into tensor<?x?xf32>
1537 return %1 : tensor<?x?xf32>
1539 // CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
1540 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
1541 // CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
1542 // CHECK-SAME: tensor<?x?x?xf32> into tensor<?x?xf32>
// Negative test: the expand/collapse reassociations are incompatible
// (4x6x8 regrouped as 2x6x16), so both ops must survive canonicalization.
1546 func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
1547 -> tensor<2x6x16xf32> {
1548 %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8]
1549 : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
1550 %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]]
1551 : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
1552 return %1 : tensor<2x6x16xf32>
1554 // CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
1555 // CHECK: tensor.expand_shape
1556 // CHECK: tensor.collapse_shape
// Negative test: the collapse crosses the expand's grouping, so the pair
// must stay as expand_shape followed by collapse_shape.
// Fix: the final CHECK used `%[[RES:.+]]`, which *re-binds* RES to anything
// instead of verifying that the collapse result is returned; use `%[[RES]]`.
1560 func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
1561 -> tensor<12x1xf32> {
1562 %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
1563 : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
1564 %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
1565 : tensor<3x2x2x1xf32> into tensor<12x1xf32>
1566 return %1 : tensor<12x1xf32>
1568 // CHECK: func @no_fold_collapse_of_expand_empty_expr
1569 // CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
1570 // CHECK: %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
1571 // CHECK-SAME: {{\[}}[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
1572 // CHECK: %[[RES:.+]] = tensor.collapse_shape %[[RARG0]]
1573 // CHECK-SAME: [0, 1, 2], [3]
1574 // CHECK: return %[[RES]] : tensor<12x1xf32>
// expand_shape of a splat arith.constant folds into a constant of the
// expanded type; no reshape op remains.
1578 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
1579 %c0 = arith.constant dense<42> : tensor<2x8xi32>
1580 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1581 : tensor<2x8xi32> into tensor<2x4x2xi32>
1582 return %0 : tensor<2x4x2xi32>
1584 // CHECK-LABEL: @reshape_splat_constant_int32
1585 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32>
1586 // CHECK-NOT: tensor.expand_shape
1587 // CHECK: return %[[CST]]
// expand_shape of a tensor.splat folds into a splat of the expanded type.
// Fix: the splat CHECK used `%[[ARG0:.+]]`, re-binding ARG0 instead of
// verifying the splat operand is the function argument; use `%[[ARG0]]`.
1589 func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
1590 %c0 = tensor.splat %arg : tensor<2x4xf32>
1591 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2]
1592 : tensor<2x4xf32> into tensor<2x2x2xf32>
1593 return %0 : tensor<2x2x2xf32>
1595 // CHECK-LABEL: @expand_shape_splat
1596 // CHECK-SAME: %[[ARG0:.+]]: f32
1597 // CHECK: %[[CST:.*]] = tensor.splat %[[ARG0]] : tensor<2x2x2xf32>
1598 // CHECK-NOT: tensor.expand_shape
1599 // CHECK: return %[[CST]]
// Negative test: a splat with a dynamic size is not folded through
// expand_shape; both ops must remain.
1603 // CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
1604 // CHECK-SAME: (%[[F:.+]]: f32, %[[M:.+]]: index, %[[SZ0:.+]]: index)
1605 func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> {
1606 // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32>
1607 // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
1608 %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
1609 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32>
1610 return %0 : tensor<2x2x?xf32>
// collapse_shape of a tensor.splat folds into a splat of the collapsed type.
// Fix: the splat CHECK used `%[[ARG0:.+]]`, re-binding ARG0 instead of
// verifying the splat operand is the function argument; use `%[[ARG0]]`.
1615 func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
1616 %c0 = tensor.splat %arg : tensor<2x2x2xf32>
1617 %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
1618 : tensor<2x2x2xf32> into tensor<2x4xf32>
1619 return %0 : tensor<2x4xf32>
1621 // CHECK-LABEL: @collapse_shape_splat
1622 // CHECK-SAME: %[[ARG0:.+]]: f32
1623 // CHECK: %[[CST:.*]] = tensor.splat %[[ARG0]] : tensor<2x4xf32>
1624 // CHECK-NOT: tensor.collapse_shape
1625 // CHECK: return %[[CST]]
// Negative test: a splat with a dynamic size is not folded through
// collapse_shape; both ops must remain.
1629 // CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold
1630 // CHECK-SAME: %[[F:.+]]: f32
1631 // CHECK-SAME: %[[M:.+]]: index
1632 func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
1633 // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
1634 // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
1635 %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32>
1636 %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
1637 return %0 : tensor<2x?xf32>
// Same splat-constant reshape fold as the int32 case, for i16 elements.
1642 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
1643 %c0 = arith.constant dense<42> : tensor<2x8xi16>
1644 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1645 : tensor<2x8xi16> into tensor<2x4x2xi16>
1646 return %0 : tensor<2x4x2xi16>
1648 // CHECK-LABEL: @reshape_splat_constant_int16
1649 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi16>
1650 // CHECK-NOT: tensor.expand_shape
1651 // CHECK: return %[[CST]]
// Same splat-constant reshape fold, for f32 elements.
1655 func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
1656 %c0 = arith.constant dense<42.0> : tensor<2x8xf32>
1657 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1658 : tensor<2x8xf32> into tensor<2x4x2xf32>
1659 return %0 : tensor<2x4x2xf32>
1661 // CHECK-LABEL: @reshape_splat_constant_float32
1662 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf32>
1663 // CHECK-NOT: tensor.expand_shape
1664 // CHECK: return %[[CST]]
// Same splat-constant reshape fold, for f64 elements.
1668 func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
1669 %c0 = arith.constant dense<42.0> : tensor<2x8xf64>
1670 %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1671 : tensor<2x8xf64> into tensor<2x4x2xf64>
1672 return %0 : tensor<2x4x2xf64>
1674 // CHECK-LABEL: @reshape_splat_constant_float64
1675 // CHECK: %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64>
1676 // CHECK-NOT: tensor.expand_shape
1677 // CHECK: return %[[CST]]
// tensor.rank of a statically-shaped constant folds to an index constant.
1681 // CHECK-LABEL: func @fold_rank
1682 func.func @fold_rank() -> (index) {
1683 %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
1686 // Fold a rank into a constant
1687 // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
1688 %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32>
1690 // CHECK-NEXT: return [[C3]]
1691 return %rank_0 : index
// A pad whose source and result types are the identical static shape is a
// no-op and folds to its source, even with symbolic low/high operands.
1696 // CHECK-LABEL: func @pad_same_static_shape(
1697 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
1698 // CHECK-NOT: tensor.pad
1699 // CHECK: return %[[ARG0]]
1700 func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
1701 -> tensor<5x6xf32> {
1702 %cst = arith.constant 0.000000e+00 : f32
1703 %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
1704 ^bb0(%arg1: index, %arg2: index):
1705 tensor.yield %cst : f32
1706 } : tensor<5x6xf32> to tensor<5x6xf32>
1707 return %0 : tensor<5x6xf32>
// Constant pad amounts are folded into static low/high values, refining the
// result type (64 + 4 + 4 = 72); a cast restores the original dynamic type.
1712 // CHECK-LABEL: func @pad_fold_static(
1713 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1714 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1715 // CHECK-NOT: arith.constant 4 : index
1716 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1717 // CHECK-SAME: low[0, 4, 1, 1] high[0, 4, 1, 1] {
1718 // CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
1719 // CHECK: tensor.yield %[[CST]] : f32
1720 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
1721 // CHECK: tensor.cast
1722 func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1723 %c0 = arith.constant 0 : index
1724 %cst = arith.constant 0.000000e+00 : f32
1725 %padding = arith.constant 4 : index
1726 %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
1727 ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
1728 tensor.yield %cst: f32
1729 } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1730 return %padded : tensor<?x?x?x?xf32>
// Negative test: the `nofold` attribute blocks the same-static-shape fold
// exercised by @pad_same_static_shape; the pad must remain.
1735 // CHECK-LABEL: func @pad_nofold_same_static_shape(
1736 // CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>
1737 // CHECK: %[[PAD:.*]] = tensor.pad
1738 // CHECK: return %[[PAD]]
1739 func.func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
1740 -> tensor<5x6xf32> {
1741 %cst = arith.constant 0.000000e+00 : f32
1742 %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
1743 ^bb0(%arg1: index, %arg2: index):
1744 tensor.yield %cst : f32
1745 } : tensor<5x6xf32> to tensor<5x6xf32>
1746 return %0 : tensor<5x6xf32>
// A cast feeding a pad folds into the pad's source, letting the pad keep
// the more-static operand type; a cast after the pad restores the dynamic
// result type.
// Fix: the tensor.cast CHECK used `%[[PADDED:.*]]`, which re-binds PADDED
// instead of verifying the cast operand is the pad result; use `%[[PADDED]]`.
1751 // CHECK-LABEL: func @pad_after_cast_different_shape(
1752 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1753 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1754 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1755 // CHECK-SAME: low[0, 0, 1, 1] high[0, 0, 1, 1] {
1756 // CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
1757 // CHECK: tensor.yield %[[CST]] : f32
1758 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
1759 // CHECK: %[[DYNAMIC:.*]] = tensor.cast %[[PADDED]] :
1760 // CHECK-SAME: tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1761 // CHECK: return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
1763 func.func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
1764 -> tensor<?x?x?x?xf32> {
1765 %cst = arith.constant 0.000000e+00 : f32
1766 %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1767 %padded = tensor.pad %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1] {
1768 ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
1769 tensor.yield %cst: f32
1770 } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
1771 return %padded: tensor<?x?x?x?xf32>
// With dynamic pad amounts the result type cannot be refined, but the cast
// still folds into the pad's source; the pad result is returned directly.
// Fix: the return CHECK used `%[[PADDED:.*]]`, which re-binds PADDED instead
// of verifying the pad result is returned; use `%[[PADDED]]`.
1776 // CHECK-LABEL: func @pad_after_cast_same_shape(
1777 // CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
1778 // CHECK-SAME: %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
1779 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1780 // CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1781 // CHECK-SAME: low[0, %[[PADDING]], 1, 1] high[0, %[[PADDING]], 1, 1] {
1782 // CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
1783 // CHECK: tensor.yield %[[CST]] : f32
1784 // CHECK: } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1785 // CHECK: return %[[PADDED]] : tensor<?x?x?x?xf32>
1787 func.func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
1788 -> tensor<?x?x?x?xf32> {
1789 %cst = arith.constant 0.000000e+00 : f32
1790 %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1791 %padded = tensor.pad %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1] {
1792 ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
1793 tensor.yield %cst: f32
1794 } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
1795 return %padded: tensor<?x?x?x?xf32>
// The cast that erases the static dim 8 folds into the pad, so the pad
// operates directly on tensor<8x?xf32> and no cast remains.
1800 // CHECK-LABEL: func @pad_of_cast(
1801 // CHECK-NOT: tensor.cast
1802 // CHECK: tensor.pad
1803 // CHECK: tensor<8x?xf32> to tensor<8x32xf32>
1804 func.func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
1805 %c0 = arith.constant 0 : index
1806 %cst = arith.constant 0.000000e+00 : f32
1807 %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
1808 %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %s] {
1809 ^bb0(%arg9: index, %arg10: index):
1810 tensor.yield %cst : f32
1811 } : tensor<?x?xf32> to tensor<8x32xf32>
1812 return %1 : tensor<8x32xf32>
// A cast that makes the pad's result *more* static folds into the pad's
// result type, eliminating the cast.
1817 // CHECK-LABEL: @cast_of_pad_more_static
1818 func.func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
1819 %cst = arith.constant 0.000000e+00 : f32
1820 // CHECK: %[[PAD:.*]] = tensor.pad
1821 // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
1822 %padded = tensor.pad %arg0 low[%padding, %padding] high[0, 0] {
1823 ^bb0(%arg1: index, %arg2: index):
1824 tensor.yield %cst : f32
1825 } : tensor<?x?xf32> to tensor<?x?xf32>
1826 // CHECK-NOT: tensor.cast
1827 %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
1828 // CHECK: return %[[PAD]]
1829 return %casted : tensor<32x32xf32>
// Negative test: a cast that makes the pad's result *less* static (drops the
// leading 32) must not be folded into the pad; both ops remain.
1834 // CHECK-LABEL: @cast_of_pad_less_static
1835 func.func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
1836 %cst = arith.constant 0.000000e+00 : f32
1837 // CHECK: tensor.pad
1838 %padded = tensor.pad %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
1839 ^bb0(%arg1: index, %arg2: index, %arg3: index):
1840 tensor.yield %cst : f32
1841 } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
1842 // CHECK: %[[CAST:.*]] = tensor.cast
1843 %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
1844 // CHECK: return %[[CAST]]
1845 return %casted : tensor<?x32x32xf32>
// A cast to dynamic followed by an all-zero pad back to the original static
// type folds away completely; the argument is returned unchanged.
// Fix: CHECK-LABEL named `@pad_cast`, which only matched because it is a
// prefix of the real symbol `@pad_cast_fold`; name the function exactly so
// the label cannot accidentally match another test.
1850 func.func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
1851 %c0 = arith.constant 0 : index
1852 %cst = arith.constant 0.0 : f32
1853 %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
1854 %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %c0] {
1855 ^bb0(%arg1: index, %arg2: index):
1856 tensor.yield %cst : f32
1857 } : tensor<?x?xf32> to tensor<4x4xf32>
1858 return %1 : tensor<4x4xf32>
1860 // CHECK-LABEL: @pad_cast_fold
1861 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
1862 // CHECK: return %[[ARG0]]
// The cast on the pad's source folds away so the pad consumes the
// more-static tensor<4x?xf32> argument directly.
1866 // CHECK-LABEL: func @fold_pad_source_cast(
1867 // CHECK-SAME: %[[ARG0:.*]]: tensor<4x?xf32>
1868 // CHECK-NOT: tensor.cast
1869 // CHECK: %[[RESULT:.*]] = tensor.pad %[[ARG0]]
1870 func.func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
1871 %cst = arith.constant 0.0 : f32
1872 %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
1873 %1 = tensor.pad %0 low[0, 0] high[0, 1] {
1874 ^bb0(%arg1: index, %arg2: index):
1875 tensor.yield %cst : f32
1876 } : tensor<?x?xf32> to tensor<4x4xf32>
1877 return %1 : tensor<4x4xf32>
// A pad whose low/high amounts are all (possibly folded) zeros is replaced
// by a tensor.cast to the static result type.
1882 // CHECK-LABEL: func @pad_static_zero_cast(
1883 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
1884 // CHECK-NOT: tensor.pad
1885 // CHECK: %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1886 // CHECK: return %[[RESULT]]
1887 func.func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
1888 %c0 = arith.constant 0 : index
1889 %0 = tensor.pad %arg0 low[0, %c0, 0] high[0, 0, %c0] {
1890 ^bb0(%arg1: index, %arg2: index, %arg3: index):
1891 tensor.yield %pad_value : f32
1892 } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1894 return %0 : tensor<2x3x4xf32>
// Negative test: `nofold` blocks the zero-pad-to-cast rewrite exercised by
// @pad_static_zero_cast; the pad must remain.
1899 // CHECK-LABEL: func @pad_nofold_static_zero(
1900 // CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
1901 // CHECK: %[[PAD:.*]] = tensor.pad
1902 // CHECK: return %[[PAD]]
1903 func.func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
1904 %c0 = arith.constant 0 : index
1905 %0 = tensor.pad %arg0 nofold low[0, %c0, 0] high[0, 0, %c0] {
1906 ^bb0(%arg1: index, %arg2: index, %arg3: index):
1907 tensor.yield %pad_value : f32
1908 } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1910 return %0 : tensor<2x3x4xf32>
// Two pad/extract_slice pairs that act on disjoint dimensions with the same
// pad value compose into a single extract_slice followed by a single pad.
1915 // CHECK-LABEL: func @fold_orthogonal_pad_chains(
1916 // CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>,
1917 // CHECK-SAME: %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
1918 func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
1919 %sz0 : index, %sz1 : index,
1920 %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
1921 // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
1922 // CHECK-SAME: [16, 4] [%[[SZ0]], %[[SZ1]]]
1923 // CHECK: %[[PAD:.*]] = tensor.pad %[[T0]] nofold
1924 // CHECK-SAME: high[%[[PW0]], %[[PW1]]]
1925 // CHECK: return %[[PAD]]
1926 %pad_value = arith.constant 0.0 : f32
1927 %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
1928 %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
1929 ^bb0(%arg1: index, %arg2: index):
1930 tensor.yield %pad_value : f32
1931 } : tensor<?x64xf32> to tensor<8x64xf32>
1932 %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1933 %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
1934 ^bb0(%arg1: index, %arg2: index):
1935 tensor.yield %pad_value : f32
1936 } : tensor<8x?xf32> to tensor<8x4xf32>
1937 func.return %3 : tensor<8x4xf32>
// Negative cases for the pad-chain fold of @fold_orthogonal_pad_chains:
// different pad values, overlapping padded dimensions, offset access into a
// padded dimension, and slicing a padded dimension must all block the fold.
1942 // CHECK-LABEL: func @dont_fold_pad_chains(
1943 // CHECK-SAME: %[[ARG0:.*]]: tensor<64x64xf32>,
1944 // CHECK-SAME: %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
1945 func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
1946 %sz0 : index, %sz1 : index,
1947 %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
1948 // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
1949 // CHECK: %[[T1:.*]] = tensor.pad %[[T0]]
1950 %pad_value = arith.constant 0.0 : f32
1951 %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
1952 %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
1953 ^bb0(%arg1: index, %arg2: index):
1954 tensor.yield %pad_value : f32
1955 } : tensor<?x64xf32> to tensor<8x64xf32>
1957 // Don't fold if the padding values are different.
1958 // CHECK: %[[T2:.*]] = tensor.extract_slice %[[T1]]
1959 // CHECK-SAME: [0, 4] [8, %[[SZ1]]]
1960 // CHECK: %[[PAD0:.*]] = tensor.pad %[[T2]]
1961 %different_value = arith.constant 1.0 : f32
1962 %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1963 %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
1964 ^bb0(%arg1: index, %arg2: index):
1965 tensor.yield %different_value : f32
1966 } : tensor<8x?xf32> to tensor<8x4xf32>
1968 // Don't fold if the pad ops have common padding dimensions.
1969 // CHECK: %[[T3:.*]] = tensor.extract_slice %[[T1]]
1970 // CHECK-SAME: [4, 0] [%[[SZ1]], 64]
1971 // CHECK: %[[PAD1:.*]] = tensor.pad %[[T3]]
1972 %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
1973 %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
1974 ^bb0(%arg1: index, %arg2: index):
1975 tensor.yield %pad_value : f32
1976 } : tensor<?x64xf32> to tensor<4x64xf32>
1978 // Don't fold if padded source tensor dimension is accessed at an offset.
1979 // CHECK: %[[T4:.*]] = tensor.extract_slice %[[T1]]
1980 // CHECK-SAME: [%[[SZ0]], 4] [8, %[[SZ1]]
1981 // CHECK: %[[PAD2:.*]] = tensor.pad %[[T4]]
1982 %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1983 %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
1984 ^bb0(%arg1: index, %arg2: index):
1985 tensor.yield %pad_value : f32
1986 } : tensor<8x?xf32> to tensor<8x4xf32>
1988 // Don't fold if a padded source tensor dimension is sliced.
1989 // CHECK: %[[T5:.*]] = tensor.extract_slice %[[T1]]
1990 // CHECK-SAME: [0, 4] [6, %[[SZ1]]
1991 // CHECK: %[[PAD3:.*]] = tensor.pad %[[T5]]
1992 %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
1993 %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
1994 ^bb0(%arg1: index, %arg2: index):
1995 tensor.yield %pad_value : f32
1996 } : tensor<6x?xf32> to tensor<6x4xf32>
1998 // CHECK: return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
1999 func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
// Two consecutive pads with the same constant value merge into a single pad
// whose low/high amounts are the element-wise sums (1+0, 1+2 / 1+3, 0+2).
2004 // CHECK-LABEL: func @merge_constant_padding
2005 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
2006 // CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
2007 // CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
2008 // CHECK: tensor.yield %[[PADVAL]]
2009 // CHECK: return %[[PAD]]
2010 func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
2011 %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
2012 ^bb0(%b0: index, %b1 : index):
2013 tensor.yield %pad_value : f32
2014 } : tensor<2x3xf32> to tensor<4x4xf32>
2015 %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2016 ^bb0(%b2: index, %b3 : index):
2017 tensor.yield %pad_value : f32
2018 } : tensor<4x4xf32> to tensor<7x8xf32>
2019 return %pad1 : tensor<7x8xf32>
2024 // CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
2025 // CHECK-LABEL: func @merge_constant_padding_dynamic
2026 // CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
2027 // CHECK-SAME: %[[IDX:[A-Za-z0-9]+]]: index
2028 // CHECK-SAME: %[[PADVAL:[A-Za-z0-9]+]]: f32
2029 // CHECK: %[[HIGH:.+]] = affine.apply #[[$MAP]]()[%[[IDX]]]
2030 // CHECK: %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
2031 // CHECK: tensor.yield %[[PADVAL]]
2032 // CHECK: return %[[PAD]]
2033 func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
2034 %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
2035 ^bb0(%b0: index, %b1 : index):
2036 tensor.yield %pad_value : f32
2037 } : tensor<?x?xf32> to tensor<?x?xf32>
2038 %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
2039 ^bb0(%b2: index, %b3 : index):
2040 tensor.yield %pad_value : f32
2041 } : tensor<?x?xf32> to tensor<?x?xf32>
2042 return %pad1 : tensor<?x?xf32>
2047 // Verify that folding does not happen if it would drop a nofold attribute
2048 // CHECK-LABEL: func @dont_merge_constant_padding_nofold
2049 // CHECK: tensor.pad {{.*}} nofold
2050 // CHECK: tensor.pad
2051 func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
// Inner pad carries `nofold`; merging with the outer pad would discard it,
// so both pads must survive.
2052 %pad0 = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
2053 ^bb0(%b0: index, %b1 : index):
2054 tensor.yield %pad_value : f32
2055 } : tensor<2x3xf32> to tensor<4x4xf32>
2056 %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2057 ^bb0(%b2: index, %b3 : index):
2058 tensor.yield %pad_value : f32
2059 } : tensor<4x4xf32> to tensor<7x8xf32>
2060 return %pad1 : tensor<7x8xf32>
2065 // Verify that folding does not happen if the padding values are different
2066 // CHECK-LABEL: func @dont_merge_constant_padding_different_vals
2067 // CHECK: tensor.pad
2068 // CHECK: tensor.pad
2069 func.func @dont_merge_constant_padding_different_vals(
2070 %arg0: tensor<2x3xf32>,
// NOTE(review): the body yields %pad_value0, but its declaration is not
// visible in this excerpt — confirm `%pad_value0: f32,` precedes the line
// below in the full signature.
2072 %pad_value1: f32) -> tensor<7x8xf32> {
2073 %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
2074 ^bb0(%b0: index, %b1 : index):
2075 tensor.yield %pad_value0 : f32
2076 } : tensor<2x3xf32> to tensor<4x4xf32>
2077 %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2078 ^bb0(%b2: index, %b3 : index):
2079 tensor.yield %pad_value1 : f32
2080 } : tensor<4x4xf32> to tensor<7x8xf32>
2081 return %pad1 : tensor<7x8xf32>
// collapse_shape of a single-element from_elements folds to a rank-0
// from_elements.
2086 // CHECK-LABEL: func @fold_collapse_shape_from_elements
2087 func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
2088 // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>
2089 // CHECK: return %[[FROM]] : tensor<i32>
2090 %0 = tensor.from_elements %arg0 : tensor<1xi32>
2091 %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
2092 return %1 : tensor<i32>
// expand_shape of a rank-0 from_elements folds to a rank-1 from_elements.
2097 // CHECK-LABEL: func @fold_expand_shape_from_elements
2098 func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
2099 // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32>
2100 // CHECK: return %[[FROM]] : tensor<1xi32>
2101 %0 = tensor.from_elements %arg0 : tensor<i32>
2102 %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<i32> into tensor<1xi32>
2103 return %1 : tensor<1xi32>
// A tensor-wide arith.index_cast is pushed past tensor.extract: extract the
// i32 element first, then index_cast the scalar.
2108 // CHECK-LABEL: func @propagate_index_cast
2109 func.func @propagate_index_cast(%arg0: tensor<1xi32>) -> index {
2110 // CHECK: %[[IDX:.+]] = arith.constant 0
2111 // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
2112 // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
2113 // CHECK: return %[[CAST]] : index
2114 %c0 = arith.constant 0 : index
2115 %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
2116 %1 = tensor.extract %0[%c0] : tensor<1xindex>
// tensor.splat of a constant folds to a dense splat arith.constant.
2122 // CHECK-LABEL: func @splat_fold
2123 func.func @splat_fold() -> tensor<4xf32> {
2124 %c = arith.constant 1.0 : f32
2125 %t = tensor.splat %c : tensor<4xf32>
2126 return %t : tensor<4xf32>
2128 // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
2129 // CHECK-NEXT: return [[T]] : tensor<4xf32>
// Negative test: a splat with a dynamic dimension operand cannot fold to a
// dense constant and must be preserved.
2134 // CHECK-LABEL: func @splat_dynamic_no_fold
2135 // CHECK-SAME: %[[M:.+]]: index
2136 func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
2137 // CHECK: %[[F:.+]] = arith.constant
2138 %f = arith.constant 1.0 : f32
2140 // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
2141 %t = tensor.splat %f[%m] : tensor<4x?xf32>
2142 return %t : tensor<4x?xf32>
// A cast that makes an extract_slice result more static folds into the
// slice by replacing the dynamic size with the static 16.
2147 // CHECK-LABEL: func @cast_extract_slice
2148 func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
2149 -> tensor<16x512xf32> {
2150 // CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32>
2151 %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32>
2152 %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
2153 // CHECK: return %[[E]] : tensor<16x512xf32>
2154 return %1 : tensor<16x512xf32>
// Same cast-into-extract_slice fold as above, with a rank-reducing slice.
2159 // CHECK-LABEL: func @cast_extract_slice_rank_reduce
2160 func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
2162 // CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32>
2163 %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32>
2164 %1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
2165 // CHECK: return %[[E]] : tensor<16xf32>
2166 return %1 : tensor<16xf32>
// Inside scf.forall, the cast feeding parallel_insert_slice folds away and
// the constant offset/size/stride operands become static attributes.
2171 // CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
2172 // CHECK-SAME: %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
2173 // CHECK-SAME: %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
2174 // CHECK-SAME: %[[num_threads:[0-9a-z]*]]: index
2175 func.func @canonicalize_parallel_insert_slice_indices(
2176 %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
2177 %num_threads : index) -> tensor<?x?xf32>
2179 %cst = arith.constant 4.200000e+01 : f32
2180 %c0 = arith.constant 0 : index
2181 %c1 = arith.constant 1 : index
2183 // CHECK-NOT: tensor.cast
2184 // CHECK: scf.forall (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<?x?xf32>) {
2185 // CHECK-NEXT: scf.forall.in_parallel {
2186 // CHECK-NEXT: tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1]
2187 %2 = scf.forall (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
2188 %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
2189 scf.forall.in_parallel {
2190 tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
2193 return %2 : tensor<?x?xf32>
// Re-inserting a slice into the tensor it was extracted from, at the same
// offsets/sizes/strides, is a no-op: the original tensor is returned.
2198 // CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
2199 // CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
2200 func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2201 %c0 = arith.constant 0 : index
2202 %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2203 %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2204 // CHECK: return %[[INPUT]]
2205 return %1: tensor<1x2x2x4xf32>
// Negative test: the insert destination differs from the extract source, so
// the extract/insert pair is not a no-op and must be kept.
2210 // CHECK-LABEL: func.func @dont_fold_mismatched_source_dst
2211 func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2212 %c0 = arith.constant 0 : index
2213 // CHECK: tensor.extract_slice
2214 %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2215 // CHECK: tensor.insert_slice
2216 %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2217 return %1: tensor<1x2x2x4xf32>
// Negative test: the insert offset (dim 1 at offset 1) differs from the
// extract offset, so the pair must not fold away.
2222 // CHECK-LABEL: func.func @dont_fold_mismatched_parameters
2223 func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2224 %c0 = arith.constant 0 : index
2225 // CHECK: tensor.extract_slice
2226 %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2227 // CHECK: tensor.insert_slice
2228 %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2229 return %1: tensor<1x2x2x4xf32>
// A constant dynamic size folds into tensor.empty, yielding a fully static
// empty plus a cast back to the dynamic result type.
2234 func.func @empty_canonicalize() -> (tensor<4x5x?xf32>) {
2235 %c6 = arith.constant 6 : index
2236 %0 = tensor.empty(%c6) : tensor<4x5x?xf32>
2237 return %0 : tensor<4x5x?xf32>
2239 // CHECK: func @empty_canonicalize
2240 // CHECK: %[[T0:.+]] = tensor.empty() : tensor<4x5x6xf32>
2241 // CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
2242 // CHECK: return %[[T1]]
// A cast that makes a tensor.empty more static folds into the empty itself,
// dropping the unused dynamic size operand.
2246 func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
2247 %0 = tensor.empty(%arg0) : tensor<?x12xf32>
2248 %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
2249 return %1 : tensor<1x12xf32>
2251 // CHECK: func @fold_empty_tensor_with_cast(%[[ARG0:.+]]: index)
2252 // CHECK: %[[T0:.+]] = tensor.empty() : tensor<1x12xf32>
2253 // CHECK: return %[[T0]] : tensor<1x12xf32>
2257 func.func private @some_use(%i : index, %j : index)
// tensor.dim of tensor.empty folds: the dynamic dim to its size operand %i,
// the static dim to a constant; the unused empty itself is erased.
2259 // CHECK-LABEL: func @empty_tensor_canonicalize
2260 // CHECK-SAME: %[[I:.*]]: index
2261 func.func @empty_tensor_canonicalize(%i : index) {
2262 %c0 = arith.constant 0 : index
2263 %c1 = arith.constant 1 : index
2265 // CHECK-NOT: tensor.empty
2266 %0 = tensor.empty(%i) : tensor<?x42xf32>
2268 // CHECK-NOT: tensor.dim
2269 %1 = tensor.dim %0, %c0: tensor<?x42xf32>
2270 %2 = tensor.dim %0, %c1: tensor<?x42xf32>
2272 // CHECK: %[[c42:.*]] = arith.constant 42 : index
2273 // CHECK: call @some_use(%[[I]], %[[c42]])
2274 call @some_use(%1, %2) : (index, index) -> ()
// tensor.dim of an expand_shape result (the dynamic group dim) rewrites to
// an affine.apply on the source dim: d1 / (5 * 8) = d1 floordiv 40.
2281 // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
2282 // CHECK-LABEL: func @dim_of_expand_shape(
2283 // CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
2284 // CHECK: %[[c1:.*]] = arith.constant 1 : index
2285 // CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
2286 // CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
2287 // CHECK: return %[[apply]]
2288 func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
2289 %c2 = arith.constant 2 : index
2290 %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
2291 : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
2292 %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
// tensor.dim of a collapse_shape result rewrites to the product of the
// collapsed source dims (three dynamic dims times the static 7).
2298 // CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
2299 // CHECK-LABEL: func @dim_of_collapse_shape(
2300 // CHECK-SAME: %[[t:.*]]: tensor<?x?x?x7x?xf32>
2301 // CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
2302 // CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
2303 // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
2304 // CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
2305 // CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
2306 // CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
2307 // CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
2308 // CHECK: return %[[apply]]
2309 func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
2310 %c1 = arith.constant 1 : index
2311 %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
2312 : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
2313 %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
// Expanding with a leading unit dim and collapsing it back is the identity;
// the pair folds away and the argument is returned.
2319 // CHECK-LABEL: func @collapse_expand_fold_to_cast(
2320 // CHECK-SAME: %[[t:.*]]: tensor<?xf32>
2321 // CHECK: return %[[t]]
2322 func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>, %sz0: index) -> (tensor<?xf32>)
2324 %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
2325 %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
2326 return %1 : tensor<?xf32>
// pack followed by an unpack with identical inner_dims_pos/inner_tiles is
// the identity; the chain folds to the original tensor.
2331 // Chain: NC -> NCnc -> NCnc -> NC
2332 // CHECK: func.func @unpack_pack(
2333 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2334 // CHECK: return %[[T]] : tensor<128x128xf32>
2335 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2336 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2337 %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2338 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2339 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2340 return %unpacked : tensor<128x128xf32>
2345 // Chain: NC -> NCcn -> NCnc -> NC
2346 // CHECK: func.func @unpack_pack(
2347 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2348 // CHECK-NOT: return %[[T]] : tensor<128x128xf32>
2349 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2350 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2351 %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2352 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2353 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
2355 return %unpacked : tensor<128x128xf32>
2360 // Chain: NC -> CNcn -> NCnc -> NC
2361 // CHECK: func.func @unpack_pack(
2362 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2363 // CHECK-NOT: return %[[T]] : tensor<128x128xf32>
2364 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2365 %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2366 %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2367 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2368 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
2370 return %unpacked : tensor<128x128xf32>
2375 // Chain: NC -> NCnc -> NCnc -> NC
2376 // CHECK: func.func @unpack_pack(
2377 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
2378 // CHECK: return %[[T]] : tensor<128x128xf32>
2379 func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
2380 %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2381 %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2382 %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2383 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor
2385 return %unpacked : tensor<128x128xf32>
2390 // CHECK: func.func @unpack_pack_with_padding_no_canonicalization(
2391 // CHECK: tensor.pack
2392 // CHECK: tensor.unpack
2393 func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> {
2394 %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
2395 %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
2396 %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
2397 %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
2398 return %unpacked : tensor<224x512xbf16>
2403 // Chain NCnc -> NC -> NC -> NCnc
2404 // CHECK: func.func @pack_unpack(
2405 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2406 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2407 func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2408 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2409 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2410 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2411 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2412 return %packed : tensor<16x16x?x?xf32>
2417 // Chain NCnc -> NC -> NC -> NCnc
2418 // CHECK: func.func @pack_unpack(
2419 // CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
2420 // CHECK: return %[[T]] : tensor<16x16x8x8xf32>
2421 func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
2422 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2423 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2424 %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
2425 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2426 return %packed : tensor<16x16x8x8xf32>
// unpack -> pack round-trip: both ops use the same inner_dims_pos and the
// identical dynamic inner_tiles, so the pair folds to the original packed
// operand %t.
2431 // CHECK: func.func @pack_unpack_same_tiles(
2432 // CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
2433 // CHECK: return %[[T]] : tensor<?x?x?x?xf32>
2434 func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2435 %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
2436 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2437 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2438 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2439 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2440 return %packed : tensor<?x?x?x?xf32>
// Negative test: the pack uses inner_tiles = [%tile2, %tile1], swapped with
// respect to the unpack, so this is not a provable round-trip and must not
// fold (hence CHECK-NOT on returning %t directly).
2445 // CHECK: func.func @pack_unpack_different_tiles(
2446 // CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
2447 // CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
2448 func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2449 %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
2450 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2451 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2452 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2453 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2454 return %packed : tensor<?x?x?x?xf32>
// Negative test: the pack carries a padding_value, and with fully dynamic
// shapes the padded region's contents cannot be proven equal to the
// original, so the unpack/pack pair must not fold away.
2459 // CHECK: func.func @pack_unpack_dynamic_with_padding(
2460 // CHECK-SAME: %[[T:.+]]: tensor<?x?x?x?xf32>,
2461 // CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
2462 func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2463 %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> {
2464 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2465 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2466 %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2467 %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2468 return %packed : tensor<?x?x?x?xf32>
// The pack's outer_dims_perm = [0, 1] is the identity permutation, which is
// equivalent to the unpack having no outer_dims_perm at all, so the
// round-trip still folds to %t.
2473 // CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
2474 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2475 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2476 func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2477 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2478 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2479 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2480 %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2481 return %packed : tensor<16x16x?x?xf32>
// Mirror of the previous test: the identity outer_dims_perm = [0, 1] sits on
// the unpack instead of the pack; the pair is still a round-trip and folds
// to %t.
2486 // CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
2487 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2488 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2489 func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2490 %tensor_empty = tensor.empty() : tensor<128x128xf32>
2491 %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2492 %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2493 %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2494 return %packed : tensor<16x16x?x?xf32>
// %0 = index.sub 1, 2 evaluates to -1. Folding that constant into the result
// type would produce an invalid (negative-sized) tensor type, so the
// canonicalizer must keep the dynamic size operand on tensor.empty.
2499 // CHECK: func.func @invalid_empty_negative_size
2500 // CHECK: %[[IDX:.*]] = index.constant
2501 // CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
2502 func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
2503 %c1 = arith.constant 1 : index
2504 %cn2 = arith.constant 2 : index
2505 %0 = index.sub %c1, %cn2
2506 %1 = tensor.empty(%0) : tensor<4x5x?xf32>
2507 return %1 : tensor<4x5x?xf32>
2512 // Fold DstStyleOp -> tensor.unpack operations.
2513 func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
2514 %cst = arith.constant 0.0 : f32
2515 %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
2516 %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
2517 return %unpack : tensor<?x?xf32>
2519 // CHECK-LABEL: func @fold_dst_style_ops_into_unpack
2520 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
2521 // CHECK-SAME: %[[INIT:.+]]: tensor<?x?xf32>
2522 // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
2523 // CHECK-SAME: into %[[INIT]]
2524 // CHECK: return %[[UNPACK]]
2528 // The IR in this test case is invalid. This test tests that the canonicalizer
2531 // CHECK-LABEL: func @invalid_slice_ops(
2532 // CHECK: %[[c:.*]] = arith.constant -5 : index
2533 // CHECK: tensor.extract_slice {{.*}}%[[c]]
2534 // CHECK: tensor.insert_slice {{.*}}%[[c]]
2535 func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
2536 %c = arith.constant -5 : index
2537 %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
2538 %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
2539 return %1 : tensor<?xf32>
// The affine.max folds to (0 mod 64) - 8 = -8, a negative size. The constant
// may be materialized, but it must not be folded into the tensor type;
// tensor.generate keeps its dynamic dimension so the IR still verifies.
2544 // CHECK-LABEL: func @generate_negative_size_verifies(
2545 // CHECK: %[[c:.*]] = arith.constant -8 : index
2546 // CHECK: tensor.generate %[[c]]
2547 // CHECK: : tensor<?x8xi32>
2548 func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
2549 %cst = arith.constant 0 : i32
2550 %c0 = arith.constant 0 : index
2551 %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
2552 %tensor = tensor.generate %size {
2553 ^bb0(%arg0: index, %arg1: index):
2554 tensor.yield %cst : i32
2556 return %tensor : tensor<?x8xi32>
// The tensor.cast pins the unpacked result to a static 40x80 shape, which
// lets the canonicalizer recognize the unpack/pack pair as an exact
// round-trip (same inner_dims_pos and tiles) and fold the whole chain to the
// source %t.
2561 func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
2562 %dim1 = arith.constant 40 : index
2563 %dim2 = arith.constant 80 : index
2564 %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2565 %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
2566 %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
2567 %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
2568 %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
2569 return %packed : tensor<10x20x4x4xf32>
2571 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
2572 // CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
2573 // CHECK: return %[[SRC]]
2577 // Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2578 // CHECK-LABEL: func @dim_of_reshape(
2579 // CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
2580 // CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
2581 // CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
2582 // CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
2583 // CHECK-NOT: tensor.store
2584 // CHECK-NOT: tensor.dim
2585 // CHECK-NOT: tensor.reshape
2586 // CHECK: return %[[DIM]] : index
2587 func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
2589 %c3 = arith.constant 3 : index
2590 %0 = tensor.reshape %arg0(%arg1)
2591 : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2592 // Update the shape to test that the load ends up in the right place.
2593 tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
2594 %1 = tensor.dim %0, %c3 : tensor<*xf32>
2600 // Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2601 // CHECK-LABEL: func @dim_of_reshape_i32(
2602 // CHECK: tensor.extract
2603 // CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
2604 // CHECK-NOT: tensor.dim
2605 // CHECK-NOT: tensor.reshape
2606 // CHECK: return %[[CAST]] : index
2607 func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
2609 %c3 = arith.constant 3 : index
2610 %0 = tensor.reshape %arg0(%arg1)
2611 : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
2612 %1 = tensor.dim %0, %c3 : tensor<*xf32>
2618 // Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2619 // CHECK-LABEL: func @dim_of_reshape_for(
2621 // CHECK-NEXT: tensor.extract
2622 // CHECK-NOT: tensor.dim
2623 // CHECK-NOT: tensor.reshape
2624 func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
2625 %c0 = arith.constant 0 : index
2626 %c1 = arith.constant 1 : index
2627 %c4 = arith.constant 4 : index
2629 %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2631 %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
2632 %2 = tensor.dim %0, %arg2 : tensor<*xf32>
2633 %3 = arith.muli %arg3, %2 : index
2634 scf.yield %3 : index
2641 // Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2642 // CHECK-LABEL: func @dim_of_reshape_undominated(
2643 // CHECK: arith.muli
2644 // CHECK-NEXT: tensor.extract
2645 // CHECK-NOT: tensor.dim
2646 // CHECK-NOT: tensor.reshape
2647 func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
2648 %c4 = arith.constant 4 : index
2649 %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2650 %0 = arith.muli %arg2, %c4 : index
2651 %dim = tensor.dim %reshape, %0 : tensor<*xf32>
// Reshaping a tensor to a shape assembled from its own dims (d0, d1 in
// order) is an identity; the reshape folds to the argument.
2657 // CHECK-LABEL: @reshape_fold_2d
2658 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2659 func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
2660 %c0 = arith.constant 0 : index
2661 %c1 = arith.constant 1 : index
2662 %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
2663 %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
2664 %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
2665 %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2666 // CHECK: return %[[ARG0]]
2667 return %reshape : tensor<?x?xi32>
// Negative test: the target shape is built from the dims in swapped order
// (d1, d0), so the reshape is not a provable identity and must be kept.
2672 // CHECK-LABEL: @reshape_nofold_2d
2673 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2674 func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
2675 %c0 = arith.constant 0 : index
2676 %c1 = arith.constant 1 : index
2677 %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
2678 %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
2679 %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
2680 // CHECK: tensor.reshape
2681 %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2682 return %reshape : tensor<?x?xi32>
2687 // CHECK-LABEL: @reshape_nofold_2d_ins
2688 func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
2689 %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
2690 // CHECK: tensor.reshape
2691 %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2692 return %reshape : tensor<?x?xi32>
2697 // CHECK-LABEL: @reshape_fold_3d_cst
2698 // CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
2699 func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
2700 %c1 = arith.constant 1 : index
2701 %c2 = arith.constant 2 : index
2702 %d0 = arith.constant 5 : index
2703 %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
2704 %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
2705 %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
2706 %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
2707 // CHECK: return %[[ARG0]]
2708 return %reshape : tensor<5x?x?xi32>
2713 // Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
2714 // CHECK-LABEL: func @dim_out_of_bounds(
2715 // CHECK: %[[IDX:.*]] = index.constant 28
2716 // CHECK-NEXT: bufferization.alloc_tensor
2717 // CHECK-NEXT: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[IDX]]
2718 // CHECK-NEXT: memref.alloc
2719 // CHECK-NEXT: memref.cast
2720 // CHECK-NEXT: affine.vector_load %{{.*}}[{{.*}}, {{.*}}, symbol(%[[DIM]])]
2721 // CHECK-NEXT: return
2722 func.func @dim_out_of_bounds() -> vector<7xi32> {
2723 %c1 = arith.constant 1 : index
2724 %idx28 = index.constant 28
2725 %c29 = arith.constant 29 : index
2726 %3 = bufferization.alloc_tensor(%c29) : tensor<?xi16>
2727 %dim = tensor.dim %3, %idx28 : tensor<?xi16>
2728 %alloc_21 = memref.alloc(%c29) : memref<?x26x2xi32>
2729 %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
2730 return %16 : vector<7xi32>
2735 // CHECK-LABEL: func.func @fold_cast_multiple_results(
2736 // CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>,
2737 // CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
2738 // CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
2739 // CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
2740 // CHECK: return %[[RES]]#1 : index
2741 func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
2742 %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
2743 %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
2744 %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
2749 // CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
2750 // CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>,
2751 // CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>,
2752 // CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
2753 // CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
2754 // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
2755 // CHECK-SAME: some_attr
2756 // CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
2757 // CHECK: return %[[PACK]] : tensor<1x1x8x1xi32>
2758 func.func @fold_cast_pack_dynamic_tile_size(
2759 %dest: tensor<1x1x8x1xi32>,
2760 %src: tensor<7x?xi32>,
2761 %pad: i32) -> tensor<1x1x8x1xi32> {
2763 %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
2764 %c8 = arith.constant 8 : index
2765 %pack = tensor.pack %src padding_value(%pad : i32)
2766 inner_dims_pos = [0, 1]
2767 inner_tiles = [%c8, 1]
2768 into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
2769 %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
2770 return %res : tensor<1x1x8x1xi32>
2775 // CHECK-LABEL: func.func @pack_dont_drop_attributes(
2776 // CHECK: tensor.pack {{.*}} {test_attr}
2777 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
2778 %c32_i64 = arith.constant 32 : i64
2779 %cst = arith.constant 0.000000e+00 : f16
2780 %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
2781 return %pack : tensor<128x?x100x16x1xf16>
// A cast-to-dynamic feeding an expand_shape: the static 10x10 source shape
// folds into the expand's output_shape ([10, 1, 10]), which also makes the
// trailing cast back to tensor<10x1x10xf32> a no-op.
2786 func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
2787 -> tensor<10x1x10xf32> {
2788 %c1 = arith.constant 1 : index
2789 %c10 = arith.constant 10 : index
2790 %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
2791 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
2792 : tensor<?x?xf32> into tensor<?x?x?xf32>
2793 %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
2794 return %2 : tensor<10x1x10xf32>
2796 // CHECK-LABEL: func.func @fold_expand_of_cast
2797 // CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
2798 // CHECK: return %[[RES]]
// The source is only partially static (?x10), so the expand cannot become
// fully static; instead the cast is sunk below the expand_shape and a cast
// on the expanded result restores the original fully-dynamic type.
2802 func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
2803 -> tensor<?x?x?xf32> {
2804 %c1 = arith.constant 1 : index
2805 %c10 = arith.constant 10 : index
2806 %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
2807 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
2808 : tensor<?x?xf32> into tensor<?x?x?xf32>
2809 return %1 : tensor<?x?x?xf32>
2811 // CHECK-LABEL: func.func @sink_expand_of_cast
2812 // CHECK-DAG: %[[C10:.*]] = arith.constant 10
2813 // CHECK-DAG: %[[C1:.*]] = arith.constant 1
2814 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
2815 // CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
2816 // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
2817 // CHECK: return %[[RES]]
// Only the trailing dimension's output size (%c10) is statically known; the
// first two expanded sizes are opaque SSA values. The cast is therefore only
// partially folded: the input cast keeps ?x10, the expand's result becomes
// ?x?x10, and a final cast restores the fully-dynamic return type.
2821 func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
2822 -> tensor<?x?x?xf32> {
2823 %c10 = arith.constant 10 : index
2824 %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
2825 %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
2826 : tensor<?x?xf32> into tensor<?x?x?xf32>
2827 return %1 : tensor<?x?x?xf32>
2829 // CHECK-LABEL: func.func @partial_sink_expand_of_cast
2830 // CHECK: %[[CAST:.+]] = tensor.cast
2831 // CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
2832 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
2833 // CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
2834 // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
2835 // CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
2836 // CHECK: return %[[RES]]