[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / Dialect / Tensor / canonicalize.mlir
blob0b54c207dea84ea3b9393c043a5297054879445a
1 // RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
4 // CHECK-LABEL: expand_shape_identity_fold
5 // CHECK-NEXT: return
6 func.func @expand_shape_identity_fold(%arg0 : tensor<5xf32>) -> tensor<5xf32> {
7   %0 = tensor.expand_shape %arg0 [[0]] output_shape [5] : tensor<5xf32> into tensor<5xf32>
8   return %0 : tensor<5xf32>
11 // -----
13 // CHECK-LABEL: expand_shape_rank0_identity_fold
14 // CHECK-NEXT: return
15 func.func @expand_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
16   %0 = tensor.expand_shape %arg0 [] output_shape [] : tensor<f32> into tensor<f32>
17   return %0 : tensor<f32>
20 // -----
22 // CHECK-LABEL: collapse_shape_identity_fold
23 // CHECK-NEXT: return
24 func.func @collapse_shape_identity_fold(%arg0 : tensor<5x4xf32>) -> tensor<5x4xf32> {
25   %0 = tensor.collapse_shape %arg0 [[0], [1]] : tensor<5x4xf32> into tensor<5x4xf32>
26   return %0 : tensor<5x4xf32>
29 // -----
31 // CHECK-LABEL: collapse_shape_rank0_identity_fold
32 // CHECK-NEXT: return
33 func.func @collapse_shape_rank0_identity_fold(%arg0 : tensor<f32>) -> tensor<f32> {
34   %0 = tensor.collapse_shape %arg0 [] : tensor<f32> into tensor<f32>
35   return %0 : tensor<f32>
38 // -----
40 // CHECK-LABEL: @tensor_bitcast_chain_ok
41 // CHECK-SAME: %[[IN:.*]]: tensor<2xi32>
42 func.func @tensor_bitcast_chain_ok(%input: tensor<2xi32>) -> tensor<2xf32> {
43   // CHECK-NEXT: %[[RES:.*]] = tensor.bitcast %[[IN]] : tensor<2xi32> to tensor<2xf32>
44   %0 = tensor.bitcast %input : tensor<2xi32> to tensor<2xui32>
45   %1 = tensor.bitcast %0 : tensor<2xui32> to tensor<2xf32>
46   // CHECK-NEXT: return %[[RES]]
47   return %1 : tensor<2xf32>
50 // -----
52 // CHECK-LABEL: @tensor_bitcast_chain_nop
53 // CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
54 func.func @tensor_bitcast_chain_nop(%input: tensor<4xi32>) -> tensor<4xi32> {
55   %0 = tensor.bitcast %input : tensor<4xi32> to tensor<4xui32>
56   %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xi32>
57   // CHECK-NEXT: return %[[IN]]
58   return %1 : tensor<4xi32>
61 // -----
63 // Checks that NOP casts are removed.
64 // CHECK-LABEL: cast_values
65 func.func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> {
66   // NOP cast
67   %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
68   // CHECK-NEXT: %[[RET:.*]] = tensor.cast %arg0 : tensor<*xi32> to tensor<2xi32>
69   %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32>
70   // NOP cast
71   %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32>
72   // CHECK-NEXT: return %[[RET]] : tensor<2xi32>
73   return %4 : tensor<2xi32>
76 // -----
78 // CHECK-LABEL: @tensor.cast_chain_ok
79 // CHECK-SAME: %[[IN:.*]]: tensor<*xi32>
80 func.func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> {
81   // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32>
82   %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32>
83   %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32>
84   // CHECK-NEXT: return %[[RES]]
85   return %1 : tensor<4x8xi32>
88 // -----
90 // CHECK-LABEL: @tensor.cast_chain_regain
91 // CHECK-SAME: %[[IN:.*]]: tensor<4xi32>
92 func.func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> {
93   %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32>
94   %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32>
95   // CHECK-NEXT: return %[[IN]]
96   return %1 : tensor<4xi32>
99 // -----
101 // CHECK-LABEL: @tensor.cast_chain_keep
102 // CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32>
103 func.func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> {
104   // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
105   %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32>
106   // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
107   %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32>
108   // CHECK-NEXT: return %[[C2]]
109   return %1 : tensor<?x8xi32>
112 // -----
114 // CHECK-LABEL: @tensor.cast_chain_invalid
115 // CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32>
116 func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> {
117   // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]]
118   %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32>
119   // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]]
120   %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32>
121   // CHECK-NEXT: return %[[C2]]
122   return %1 : tensor<8x4xi32>
125 // -----
127 // CHECK-LABEL: fold_concat
128 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
129 func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
130   %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
131   // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
132   %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
133   // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
134   return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
137 // -----
139 // CHECK-LABEL: func @fold_extract
140 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
141   %const_0 = arith.constant 0 : index
142   %const_1 = arith.constant 1 : index
143   %const_3 = arith.constant 3 : index
144   // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32
145   // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16
146   // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16
148   // Fold an extract into a splat.
149   // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32
150   %0 = arith.constant dense<4.0> : tensor<4xf32>
151   %ext_1 = tensor.extract %0[%arg0] : tensor<4xf32>
153   // Fold an extract into a sparse with a sparse index.
154   %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]],  [-5.0, -2.0]> : tensor<4x4x4xf16>
155   %ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16>
157   // Fold an extract into a sparse with a non sparse index.
158   %2 = arith.constant sparse<[[1, 1, 1]],  [-2.0]> : tensor<2x2x2xf16>
159   %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16>
161   // Fold an extract into a dense tensor.
162   %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32>
163   %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32>
165   // Fold an extract into a complex constant.
166   // CHECK-DAG: [[C5:%.+]] = complex.constant [1.200000e+00 : f32, 2.300000e+00 : f32] : complex<f32>
167   %4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
168   %ext_5 = tensor.extract %4[] : tensor<complex<f32>>
170   // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
171   return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
174 // -----
176 // Ensure extract dense resource elements not crash.
178 // CHECK-LABEL: func @extract_dense_resource_nofold
179 func.func @extract_dense_resource_nofold() -> i64 {
180   // CHECK:      %[[EXT:.+]] = tensor.extract
181   // CHECK-NEXT:   return %[[EXT]]
182   %c0 = arith.constant 0 : index
183   %cst = arith.constant dense_resource<__elided__> : tensor<1xi64>
184   %extracted = tensor.extract %cst[%c0] : tensor<1xi64>
185   return %extracted : i64
188 // -----
190 // CHECK-LABEL: func @fold_insert
191 func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
192   // Fold an insert into a splat.
193   // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32>
194   %0 = arith.constant dense<4.0> : tensor<4xf32>
195   %1 = arith.constant 4.0 : f32
196   %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
197   // CHECK-NEXT: return %[[C4]]
198   return %ins_1 : tensor<4xf32>
201 // -----
203 // CHECK-LABEL: func @extract_from_tensor.cast
204 // CHECK-SAME: %[[TENSOR:.*]]: tensor<9xf32>
205 func.func @extract_from_tensor.cast(%tensor: tensor<9xf32>) -> f32 {
206   // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
207   %c0 = arith.constant 0 : index
208   // CHECK-NOT: tensor.cast
209   %casted = tensor.cast %tensor : tensor<9xf32> to tensor<?xf32>
210   // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]]
211   %result = tensor.extract %casted[%c0] : tensor<?xf32>
212   return %result : f32
215 // -----
217 // CHECK-LABEL: func @extract_from_tensor.from_elements
218 func.func @extract_from_tensor.from_elements(%element : index) -> index {
219   // CHECK-SAME: ([[ARG:%.*]]: index)
220   %c0 = arith.constant 0 : index
221   %tensor = tensor.from_elements %element : tensor<1xindex>
222   %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex>
223   // CHECK: [[ARG]] : index
224   return %extracted_element : index
227 // -----
229 // CHECK-LABEL: func @extract_from_tensor.from_elements_0d
230 func.func @extract_from_tensor.from_elements_0d(%element : index) -> index {
231   // CHECK-SAME: ([[ARG:%.*]]: index)
232   %c0 = arith.constant 0 : index
233   %tensor = tensor.from_elements %element : tensor<index>
234   %extracted_element = tensor.extract %tensor[] : tensor<index>
235   // CHECK: [[ARG]] : index
236   return %extracted_element : index
239 // -----
241 // CHECK-LABEL: func @extract_from_tensor.from_elements_3d
242 func.func @extract_from_tensor.from_elements_3d()
243     -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
244   %f0 = arith.constant 0.0 : f32
245   %f1 = arith.constant 1.0 : f32
246   %f2 = arith.constant 2.0 : f32
247   %f3 = arith.constant 3.0 : f32
248   %f4 = arith.constant 4.0 : f32
249   %f5 = arith.constant 5.0 : f32
250   %f6 = arith.constant 6.0 : f32
251   %f7 = arith.constant 7.0 : f32
252   %f8 = arith.constant 8.0 : f32
253   %f9 = arith.constant 9.0 : f32
254   %f10 = arith.constant 10.0 : f32
255   %f11 = arith.constant 11.0 : f32
257   %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
258          : tensor<3x2x2xf32>
259   %c0 = arith.constant 0 : index
260   %c1 = arith.constant 1 : index
261   %c2 = arith.constant 2 : index
263   %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
264   %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
265   %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
266   %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
267   %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
268   %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
269   %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
270   %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
271   %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
272   %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
273   %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
274   %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
275   return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
276          : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
278 // CHECK-DAG: %[[F0:.*]] = arith.constant 0.0
279 // CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00
280 // CHECK-DAG: %[[F2:.*]] = arith.constant 2.0
281 // CHECK-DAG: %[[F3:.*]] = arith.constant 3.0
282 // CHECK-DAG: %[[F4:.*]] = arith.constant 4.0
283 // CHECK-DAG: %[[F5:.*]] = arith.constant 5.0
284 // CHECK-DAG: %[[F6:.*]] = arith.constant 6.0
285 // CHECK-DAG: %[[F7:.*]] = arith.constant 7.0
286 // CHECK-DAG: %[[F8:.*]] = arith.constant 8.0
287 // CHECK-DAG: %[[F9:.*]] = arith.constant 9.0
288 // CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01
289 // CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01
291 // CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]],
292 // CHECK-SAME:   %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]]
294 // -----
296 // CHECK-LABEL: func @extract_from_tensor.from_elements_variable_3d
297 // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: f32
298 // CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: f32
299 // CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: f32
300 // CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: f32
301 // CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: f32
302 // CHECK-SAME: %[[ARG_5:[a-zA-Z0-9_]+]]: f32
303 // CHECK-SAME: %[[ARG_6:[a-zA-Z0-9_]+]]: f32
304 // CHECK-SAME: %[[ARG_7:[a-zA-Z0-9_]+]]: f32
305 // CHECK-SAME: %[[ARG_8:[a-zA-Z0-9_]+]]: f32
306 // CHECK-SAME: %[[ARG_9:[a-zA-Z0-9_]+]]: f32
307 // CHECK-SAME: %[[ARG_10:[a-zA-Z0-9_]+]]: f32
308 // CHECK-SAME: %[[ARG_11:[a-zA-Z0-9_]+]]: f32
309 func.func @extract_from_tensor.from_elements_variable_3d(
310     %f0: f32, %f1: f32, %f2: f32, %f3: f32, %f4: f32, %f5: f32,
311     %f6: f32, %f7: f32, %f8: f32, %f9: f32, %f10: f32, %f11: f32)
312     -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
314   %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11
315          : tensor<3x2x2xf32>
316   %c0 = arith.constant 0 : index
317   %c1 = arith.constant 1 : index
318   %c2 = arith.constant 2 : index
320   %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32>
321   %r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32>
322   %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32>
323   %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32>
324   %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32>
325   %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32>
326   %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32>
327   %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32>
328   %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32>
329   %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32>
330   %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32>
331   %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32>
332   return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11
333          : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
335 // CHECK: return %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]],
336 // CHECK-SAME: %[[ARG_6]], %[[ARG_7]], %[[ARG_8]], %[[ARG_9]], %[[ARG_10]], %[[ARG_11]]
338 // -----
340 // CHECK-LABEL: func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
341 // CHECK-NEXT:  %cst = arith.constant dense<[(1,2), (3,2), (1,2)]> : tensor<3xcomplex<i32>>
342 // CHECK-NEXT:  return %cst : tensor<3xcomplex<i32>>
343 func.func @extract_from_elements_complex_i() -> tensor<3xcomplex<i32>> {
344   %c1 = arith.constant dense<(1, 2)> : tensor<complex<i32>>
345   %complex1 = tensor.extract %c1[] : tensor<complex<i32>>
346   %c2 = arith.constant dense<(3, 2)> : tensor<complex<i32>>
347   %complex2 = tensor.extract %c2[] : tensor<complex<i32>>
348   %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<i32>>
349   return %tensor : tensor<3xcomplex<i32>>
352 // -----
354 // CHECK-LABEL:  func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
355 // CHECK-NEXT:   %cst = arith.constant dense<[(1.200000e+00,2.300000e+00), (3.200000e+00,2.100000e+00), (1.200000e+00,2.300000e+00)]> : tensor<3xcomplex<f32>>
356 // CHECK-NEXT:   return %cst : tensor<3xcomplex<f32>>
357 func.func @extract_from_elements_complex_f() -> tensor<3xcomplex<f32>> {
358   %c1 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
359   %complex1 = tensor.extract %c1[] : tensor<complex<f32>>
360   %c2 = arith.constant dense<(3.2, 2.1)> : tensor<complex<f32>>
361   %complex2 = tensor.extract %c2[] : tensor<complex<f32>>
362   %tensor = tensor.from_elements %complex1, %complex2, %complex1 : tensor<3xcomplex<f32>>
363   return %tensor : tensor<3xcomplex<f32>>
366 // -----
368 // Ensure the optimization doesn't segfault from bad constants
369 // CHECK-LABEL: func @extract_negative_from_tensor.from_elements
370 func.func @extract_negative_from_tensor.from_elements(%element : index) -> index {
371   // CHECK-SAME: ([[ARG:%.*]]: index)
372   %c-1 = arith.constant -1 : index
373   %tensor = tensor.from_elements %element : tensor<1xindex>
374   %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex>
375   // CHECK: tensor.from_elements
376   // CHECK: %[[RESULT:.*]] = tensor.extract
377   // CHECK: return %[[RESULT]]
378   return %extracted_element : index
381 // -----
383 // Ensure the optimization doesn't segfault from bad constants
384 // CHECK-LABEL: func @extract_oob_from_tensor.from_elements
385 func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
386   // CHECK-SAME: ([[ARG:%.*]]: index)
387   %c1 = arith.constant 1 : index
388   %tensor = tensor.from_elements %element : tensor<1xindex>
389   %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex>
390   // CHECK: tensor.from_elements
391   // CHECK: %[[RESULT:.*]] = tensor.extract
392   // CHECK: return %[[RESULT]]
393   return %extracted_element : index
396 // -----
398 // Ensure the optimization doesn't segfault from bad constants
399 // CHECK-LABEL: func @extract_oob_from_tensor.from_elements
400 func.func @extract_oob_from_tensor.from_elements(%element : index) -> index {
401   // CHECK-SAME: ([[ARG:%.*]]: index)
402   %c2 = arith.constant 2 : index
403   %tensor = tensor.from_elements %element : tensor<1xindex>
404   %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex>
405   // CHECK: tensor.from_elements
406   // CHECK: %[[RESULT:.*]] = tensor.extract
407   // CHECK: return %[[RESULT]]
408   return %extracted_element : index
411 // -----
413 // CHECK-LABEL: func @extract_from_tensor.generate
414 // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
415 func.func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index {
416   %size = tensor.rank %tensor : tensor<*xf32>
417   // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]]
418   %0 = tensor.generate %size {
419     ^bb0(%arg0: index):
420     %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
421     tensor.yield %1 : index
422   } : tensor<?xindex>
423   %1 = tensor.extract %0[%idx] : tensor<?xindex>
424   // CHECK-NEXT: return %[[RES]]
425   return %1 : index
428 // -----
430 // CHECK-LABEL: func @extract_from_tensor.generate_2d
431 // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32>
432 func.func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index {
433   %size = tensor.rank %tensor : tensor<*xf32>
434   // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]]
435   // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]]
436   // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]]
437   %0 = tensor.generate %size, %size {
438     ^bb0(%arg0: index, %arg1: index):
439     %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
440     %2 = tensor.dim %tensor, %arg1 : tensor<*xf32>
441     %3 = arith.addi %1, %2 : index
442     tensor.yield %3 : index
443   } : tensor<?x?xindex>
444   %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex>
445   // CHECK-NEXT: return %[[RES]]
446   return %4 : index
449 // -----
451 // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects
452 // CHECK-SAME: %[[IDX:.*]]: index
453 func.func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index {
454   %size = tensor.rank %tensor : tensor<*xf32>
455   // CHECK: %[[DTENSOR:.*]] = tensor.generate
456   %0 = tensor.generate %size {
457     ^bb0(%arg0: index):
458     %1 = tensor.dim %tensor, %arg0 : tensor<*xf32>
459     memref.store %1, %mem[%arg0] : memref<?xindex>
460     tensor.yield %1 : index
461   } : tensor<?xindex>
462   // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]]
463   %1 = tensor.extract %0[%idx] : tensor<?xindex>
464   // CHECK-NEXT: return %[[RES]]
465   return %1 : index
468 // -----
470 // CHECK-LABEL: @static_tensor.generate
471 // CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index)
472 func.func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> {
473   %c5 = arith.constant 5 : index
474   // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]]
475   %0 = tensor.generate %size1, %c5, %size4 {
476     ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index):
477     %1 = arith.constant 32 : index
478     tensor.yield %1 : index
479   // CHECK: : tensor<3x?x5x7x?xindex>
480   } : tensor<3x?x?x7x?xindex>
481   // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex>
482   return %0 : tensor<3x?x?x7x?xindex>
485 // -----
487 // CHECK-LABEL: @from_elements.constant
488 func.func @from_elements.constant() -> tensor<3xindex> {
489   // CHECK: %[[CST:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xindex>
490   // CHECK: return %[[CST]]
491   %c1 = arith.constant 1 : index
492   %c2 = arith.constant 2 : index
493   %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex>
494   return %tensor : tensor<3xindex>
497 // -----
499 func.func @slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
500     %arg2 : index) -> tensor<?x?x?xf32>
502   %c0 = arith.constant 0 : index
503   %c1 = arith.constant 1 : index
504   %c4 = arith.constant 4 : index
505   %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
506   return %0 : tensor<?x?x?xf32>
508 // CHECK-LABEL: func @slice_canonicalize
509 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
510 //       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
511 //  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
512 //  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
513 //       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SLICE]]
514 //       CHECK:   return %[[RESULT]]
516 // -----
518 func.func @rank_reducing_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
519     %arg2 : index) -> tensor<?x?xf32>
521   %c0 = arith.constant 0 : index
522   %c1 = arith.constant 1 : index
523   %c4 = arith.constant 4 : index
524   %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
525   return %0 : tensor<?x?xf32>
527 // CHECK-LABEL: func @rank_reducing_slice_canonicalize
528 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?xf32>
529 //       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
530 //  CHECK-SAME:      [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
531 //  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x?xf32>
532 //       CHECK:   %[[RESULT:.+]] = tensor.cast %[[SLICE]]
533 //       CHECK:   return %[[RESULT]]
535 // -----
537 // CHECK-LABEL: func @trivial_slice
538 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
539 //   CHECK-NOT:   tensor.extract_slice
540 //       CHECK:   return %[[ARG0]] :  tensor<4x6x16x32xi8>
541 func.func @trivial_slice(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
542   %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
543   return %0 : tensor<4x6x16x32xi8>
546 // -----
548 // CHECK-LABEL: func @trivial_insert_slice
549 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
550 //   CHECK-NOT:   tensor.extract_slice
551 //       CHECK:   return %[[ARG0]] :  tensor<4x6x16x32xi8>
552 func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
553   %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
554   return %0 : tensor<4x6x16x32xi8>
557 // -----
559 // CHECK-LABEL: func @empty_insert_slice
560 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
561 //  CHECK-SAME:   %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
562 //   CHECK-NOT:   tensor.extract_slice
563 //       CHECK:   return %[[ARG1]] :  tensor<3x3xi8>
564 func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
565   %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
566   return %0 : tensor<3x3xi8>
569 // -----
571 // CHECK-LABEL: func @rank_reducing_tensor_of_cast
572 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
573 //       CHECK:   %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
574 // Tensor cast is moved after slice and then gets canonicalized away.
575 //   CHECK-NOT:   tensor.cast
576 //       CHECK:   return %[[S]] : tensor<16x32xi8>
577 func.func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> {
578   %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
579   %1 = tensor.extract_slice %0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
580   return %1 : tensor<16x32xi8>
583 // -----
585 // CHECK-LABEL: func @rank_reducing_insert_slice_of_cast
586 //  CHECK-SAME:   %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
587 //  CHECK-SAME:   %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
588 //       CHECK:   %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
589 // Tensor cast is folded away.
590 //   CHECK-NOT:   tensor.cast
591 //       CHECK:   return %[[S]] : tensor<4x6x16x32xi8>
592 func.func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
593   %c0 = arith.constant 0: index
594   %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
595   %sz = tensor.dim %cast, %c0: tensor<?x32xi8>
596   %res = tensor.insert_slice %cast into %b[0, 1, 0, 0] [1, 1, %sz, 32] [1, 1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
597   return %res : tensor<4x6x16x32xi8>
600 // -----
602 func.func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
603     %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
605   %c0 = arith.constant 0 : index
606   %c1 = arith.constant 1 : index
607   %c4 = arith.constant 4 : index
608   %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
609   return %0 : tensor<?x?x?xf32>
611 // CHECK-LABEL: func @insert_slice_canonicalize
612 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
613 //       CHECK:   %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
614 //       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
615 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
616 //  CHECK-SAME:      : tensor<4x1x?xf32> into tensor<?x?x?xf32>
617 //       CHECK:   return %[[RESULT]]
619 // -----
621 // Do not insert a cast for the following example. The new source type wouldn't be "more static" than the old one.
622 func.func @insert_slice_canonicalize_encoding(%arg0 : tensor<2x2xf32, "foo">,
623                                               %arg1 : tensor<4x4xf32, "foo">) -> tensor<4x4xf32, "foo">
625   %0 = tensor.insert_slice %arg0 into %arg1[0, 0] [2, 2] [1, 1] : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
626   return %0 : tensor<4x4xf32, "foo">
628 // CHECK-LABEL: func @insert_slice_canonicalize_encoding
629 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x2xf32, "foo">
630 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<4x4xf32, "foo">
631 //       CHECK-NOT: tensor.cast
632 //       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[ARG1]]
633 //  CHECK-SAME:      [0, 0] [2, 2] [1, 1]
634 //  CHECK-SAME:      : tensor<2x2xf32, "foo"> into tensor<4x4xf32, "foo">
635 //       CHECK:   return %[[RESULT]]
637 // -----
639 func.func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
640     %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
642   %c0 = arith.constant 0 : index
643   %c1 = arith.constant 1 : index
644   %c4 = arith.constant 4 : index
645   %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
646   %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
647   return %1 : tensor<?x?x?xf32>
649 // CHECK-LABEL: func @slice_to_insert_slice_canonicalize
650 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
651 //  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
652 //       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
653 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}} [1, 1, 1]
654 //  CHECK-SAME:      : tensor<?x?x?xf32> to tensor<4x1x?xf32>
655 //       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]]
656 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
657 //  CHECK-SAME:      : tensor<4x1x?xf32> into tensor<?x?x?xf32>
658 //       CHECK:   return %[[RESULT]]
660 // -----
662 func.func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index,
663     %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
665   %c0 = arith.constant 0 : index
666   %c1 = arith.constant 1 : index
667   %c4 = arith.constant 4 : index
668   %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
669   return %0 : tensor<?x?x?xf32>
671 // CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize
672 //  CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
673 //       CHECK:   %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<?x?xf32> to tensor<4x?xf32>
674 //       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
675 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
676 //  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
677 //       CHECK:   return %[[RESULT]]
679 // -----
681 func.func @rank_reducing_slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
682     %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
684   %c0 = arith.constant 0 : index
685   %c1 = arith.constant 1 : index
686   %c4 = arith.constant 4 : index
687   %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32>
688   %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32>
689   return %1 : tensor<?x?x?xf32>
691 // CHECK-LABEL: func @rank_reducing_slice_to_insert_slice_canonicalize
692 //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
693 //  CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
694 //       CHECK:   %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
695 //  CHECK-SAME:     [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
696 //  CHECK-SAME:     : tensor<?x?x?xf32> to tensor<4x?xf32>
697 //       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG3]]
698 //  CHECK-SAME:      [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
699 //  CHECK-SAME:      : tensor<4x?xf32> into tensor<?x?x?xf32>
700 //       CHECK:   return %[[RESULT]]
702 // -----
704 func.func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
705     %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
706   %c0 = arith.constant 0 : index
707   %c1 = arith.constant 1 : index
708   %c2 = arith.constant 2 : index
709   %c8 = arith.constant 8 : index
710   %0 = tensor.dim %arg0, %c1 : tensor<2x?xi32>
711   %1 = tensor.extract %arg1[] : tensor<i32>
712   %2 = tensor.generate %arg2, %c8 {
713   ^bb0(%arg4: index, %arg5: index):
714     tensor.yield %1 : i32
715   } : tensor<?x?xi32>
716   %3 = tensor.insert_slice %arg0 into %2[0, %arg3] [2, %0] [1, 1] : tensor<2x?xi32> into tensor<?x?xi32>
717   return %3 : tensor<?x?xi32>
719 // CHECK-LABEL: func @insert_slice_propagate_dest_cast
720 //       CHECK:   %[[UPDATED:.+]] = tensor.insert_slice %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
721 //  CHECK-SAME:     tensor<2x?xi32> into tensor<?x8xi32>
722 //       CHECK:   %[[CAST:.+]] = tensor.cast %[[UPDATED]]
723 //       CHECK:   return %[[CAST]]
725 // -----
727 func.func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
728   %c9 = arith.constant 9 : index
729   %c3 = arith.constant 3 : index
730   %2 = tensor.extract %arg1[] : tensor<i32>
731   %4 = tensor.generate %c3, %c9 {
732   ^bb0(%arg2: index, %arg3: index):
733     tensor.yield %2 : i32
734   } : tensor<?x?xi32>
735   %5 = tensor.insert_slice %arg0 into %4[0, 1] [2, 3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32>
736   %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32>
737   return %6 : tensor<3x9xi32>
739 // CHECK-LABEL: func @insert_slice_output_dest_canonicalize
740 //  CHECK-SAME:   %[[ARG0:[a-zA-z0-9_]+]]: tensor<2x3xi32>
741 //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32>
742 //       CHECK:   %[[PAD:.+]] = tensor.extract %[[ARG1]]
743 //       CHECK:   %[[GENERATE:.+]] = tensor.generate
744 //       CHECK:   %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[GENERATE]]
745 //       CHECK:   return %[[RESULT]]
747 // -----
749 // Test case: Folding of tensor.dim(tensor.generate %idx) -> %idx
750 // CHECK-LABEL: func @dim_of_tensor.generate(
751 //  CHECK-SAME:     %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
752 //   CHECK-NOT:   tensor.dim
753 //       CHECK:   return %[[IDX1]] : index
754 func.func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
755   %c3 = arith.constant 3 : index
756   %0 = tensor.generate %arg0, %arg1 {
757   ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
758     tensor.yield %c3 : index
759   } : tensor<2x?x4x?x5xindex>
760   %1 = tensor.dim %0, %c3 : tensor<2x?x4x?x5xindex>
761   return %1 : index
764 // -----
766 // Test case: Folding tensor.dim(tensor.cast %0, %idx) -> tensor.dim %0, %idx
767 // CHECK-LABEL: func @fold_dim_of_tensor.cast
768 //  CHECK-SAME:   %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
769 //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
770 //   CHECK-DAG:   %[[C4:.+]] = arith.constant 4 : index
771 //       CHECK:   %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
772 //  CHECK-NEXT:   return %[[C4]], %[[T0]]
773 func.func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
774   %c0 = arith.constant 0 : index
775   %c1 = arith.constant 1 : index
776   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
777   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
778   %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
779   return %1, %2: index, index
782 // -----
784 // CHECK-LABEL: func @insert_slice_cast
785 func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
786   // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?xf32>
787   %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
788   // CHECK: %[[RES:.*]] = tensor.insert_slice %[[ARG0]]
789   // CHECK-SAME: [{{.*}}, {{.*}}] [1, {{.*}}] [{{.*}}, {{.*}}]
790   // CHECK-SAME: : tensor<1x?xf32> into tensor<?x?xf32>
791   %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
792   // CHECK: return %[[RES]] : tensor<?x?xf32>
793   return %1 : tensor<?x?xf32>
796 // -----
798 // CHECK-LABEL: func @insert_slice_cast_no_fold
799 func.func @insert_slice_cast_no_fold(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
800   %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x5xf32>
801   // CHECK: %[[CAST:.*]] = tensor.cast
802   // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
803   // CHECK-SAME: [{{.*}}, {{.*}}] [{{.*}}, 5] [{{.*}}, {{.*}}]
804   // CHECK-SAME: : tensor<?x5xf32> into tensor<?x?xf32>
805   %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, 5] [%arg6, %arg7] : tensor<?x5xf32> into tensor<?x?xf32>
806   // CHECK: return %[[RES]] : tensor<?x?xf32>
807   return %1 : tensor<?x?xf32>
810 // -----
812 // CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
813 // CHECK-SAME:      %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
814 //      CHECK:    %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
815 //      CHECK:    %[[r:.*]] =  tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
816 //      CHECK:    return %[[r]]
817 func.func @insert_tensor_cast_on_insert_slice_src(
818     %arg0 : tensor<?x5x?xf32>,  %arg1 : tensor<?x?x?xf32>, %sz0: index, %sz2: index) -> tensor<?x?x?xf32> {
819   %c64 = arith.constant 64: index
820   %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [%c64, 5, %c64] [1, 1, 1]
821     : tensor<?x5x?xf32> into tensor<?x?x?xf32>
822   return %r : tensor<?x?x?xf32>
825 // -----
827 // CHECK-LABEL: func @fold_extract_insert
828 //  CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32>
829 func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) {
830   %c0 = arith.constant 0: index
831   %c1 = arith.constant 1: index
832   %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
833   %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32>
834   // CHECK: return %[[SLICE]]
835   return %1 : tensor<4x?x8xf32>
838 // -----
840 // CHECK-LABEL: func @fold_gather_constant_splat
841 //   CHECK-NOT: tensor.gather
842 //       CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
843 func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
844   %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
845   %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
846     (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
847   return %0 : tensor<1x2x 1x1x1xf32>
850 // -----
852 // CHECK-LABEL: func @fold_reshape_constant_splat
853 //   CHECK-NOT: tensor.reshape
854 //       CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
855 func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
856   %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
857   %0 = tensor.reshape %cst(%shape)
858              : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
859   return %0 : tensor<4xf32>
862 // -----
864 // CHECK-LABEL: func @fold_reshape_chain
865 //  CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<*xf32>
866 //  CHECK-SAME: %[[SHAPE_0:[a-zA-Z0-9_]+]]: tensor<?xindex>
867 //  CHECK-SAME: %[[SHAPE_1:[a-zA-Z0-9_]+]]: tensor<?xindex>
868 //  CHECK-SAME: %[[SHAPE_2:[a-zA-Z0-9_]+]]: tensor<?xindex>
869 //       CHECK: %[[RESULT:.*]] = tensor.reshape %[[INPUT]](%[[SHAPE_2]])
870 //       CHECK: return %[[RESULT]]
871 func.func @fold_reshape_chain(%input: tensor<*xf32>, %shape_0: tensor<?xindex>, %shape_1: tensor<?xindex>, %shape_2: tensor<?xindex>) -> tensor<*xf32> {
872   %0 = tensor.reshape %input(%shape_0) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
873   %1 = tensor.reshape %0(%shape_1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
874   %2 = tensor.reshape %1(%shape_2) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
875   return %2 : tensor<*xf32>
878 // -----
880 // CHECK-LABEL: func @fold_reshape_1d
881 //  CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]+]]: tensor<?xf32>
882 //  CHECK-SAME: %[[SHAPE:[a-zA-Z0-9_]+]]: tensor<1xindex>
883 //       CHECK: return %[[INPUT]]
884 func.func @fold_reshape_1d(%input: tensor<?xf32>, %shape: tensor<1xindex>) -> tensor<?xf32> {
885   %0 = tensor.reshape %input(%shape) : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
886   return %0 : tensor<?xf32>
889 // -----
891 // CHECK-LABEL: func @fold_extract_constant_splat
892 //   CHECK-NOT: tensor.extract_slice
893 //       CHECK: arith.constant dense<42> : tensor<4x4xi32>
894 func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
895   %cst = arith.constant dense<42> : tensor<1024x1024xi32>
896   %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
897   return %1 : tensor<4x4xi32>
900 // -----
902 // CHECK-LABEL: func @fold_pack_constant_splat
903 //   CHECK-NOT: tensor.pack
904 //       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
905 func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
906   %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
907   %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
908     inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
909   return %0 : tensor<8x16x8x32xf32>
912 // -----
914 // CHECK-LABEL: func @fold_padding_value_pack_constant_splat
915 //   CHECK-NOT: tensor.pack
916 //       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
917 func.func @fold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
918   %pad = arith.constant 1.000000e-01 : f32
919   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
920   %0 = tensor.pack %cst
921     padding_value(%pad : f32)
922     outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
923     inner_tiles = [8, 32] into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
924   return %0 : tensor<8x16x8x32xf32>
928 // -----
930 // CHECK-LABEL: func @nofold_padding_value_pack_constant_splat
931 //       CHECK: arith.constant dense<1.000000e-01> : tensor<63x127xf32>
932 //       CHECK: tensor.pack
933 func.func @nofold_padding_value_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
934   %pad = arith.constant 0.0 : f32
935   %cst = arith.constant dense<1.000000e-01> : tensor<63x127xf32>
936   %0 = tensor.pack %cst
937     padding_value(%pad : f32)
938     outer_dims_perm = [1, 0]
939     inner_dims_pos = [0, 1]
940     inner_tiles = [8, 32]
941     into %dest : tensor<63x127xf32> -> tensor<8x16x8x32xf32>
942   return %0 : tensor<8x16x8x32xf32>
945 // -----
947 func.func @fold_padding_value_pack(%arg0: tensor<1200x500000xf32>) -> tensor<31250x1200x16x1xf32> {
948   %cst = arith.constant 0.000000e+00 : f32
949   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
950   %pack = tensor.pack %arg0
951     padding_value(%cst : f32)
952     outer_dims_perm = [1, 0]
953     inner_dims_pos = [1, 0]
954     inner_tiles = [16, 1]
955     into %0 : tensor<1200x500000xf32> -> tensor<31250x1200x16x1xf32>
956   return %pack : tensor<31250x1200x16x1xf32>
958 // CHECK-LABEL: func @fold_padding_value_pack
959 // CHECK-NOT:     padding_value
961 // -----
963 func.func @infer_src_shape_pack(%src: tensor<?x?x?x?xf32>, %dest: tensor<10x20x30x40x16xf32>) -> tensor<10x20x30x40x16xf32> {
964   %cst = arith.constant 0.000000e+00 : f32
965    %pack = tensor.pack %src
966     padding_value(%cst : f32)
967     outer_dims_perm = [2, 1, 3, 0]
968     inner_dims_pos = [2]
969     inner_tiles = [16]
970     into %dest : tensor<?x?x?x?xf32> -> tensor<10x20x30x40x16xf32>
971   return %pack : tensor<10x20x30x40x16xf32>
973 // CHECK-LABEL: func.func @infer_src_shape_pack
974 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
975 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
976 // CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
977 // CHECK:         %[[PACK:.+]] = tensor.pack %[[CAST_SRC]] {{.+}} into %[[DEST]]
978 // CHECK:         return %[[PACK]]
980 // -----
982 func.func @infer_dest_shape_pack(%src: tensor<30x20x?x10xf32>, %dest: tensor<?x?x?x?x16xf32>) -> tensor<?x?x?x?x16xf32> {
983   %cst = arith.constant 0.000000e+00 : f32
984    %pack = tensor.pack %src
985     padding_value(%cst : f32)
986     outer_dims_perm = [2, 1, 3, 0]
987     inner_dims_pos = [2]
988     inner_tiles = [16]
989     into %dest : tensor<30x20x?x10xf32> -> tensor<?x?x?x?x16xf32>
990   return %pack : tensor<?x?x?x?x16xf32>
992 // CHECK-LABEL: func.func @infer_dest_shape_pack
993 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
994 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
995 // CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
996 // CHECK:         %[[PACK:.+]] = tensor.pack %[[SRC]] {{.+}} into %[[CAST_DEST]]
997 // CHECK:         %[[CAST_PACK:.+]] = tensor.cast %[[PACK]] : tensor<?x20x10x30x16xf32> to tensor<?x?x?x?x16xf32>
998 // CHECK:         return %[[CAST_PACK]]
1000 // -----
1002 func.func @no_infer_pack_shape(%arg0: tensor<?x32x100xf32>, %arg1: index) -> tensor<32x7x?x16x1xf32> {
1003   %cst = arith.constant 0.000000e+00 : f32
1004   %0 = tensor.empty(%arg1) : tensor<32x7x?x16x1xf32>
1005   %pack = tensor.pack %arg0 padding_value(%cst : f32) outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<?x32x100xf32> -> tensor<32x7x?x16x1xf32>
1006   return %pack : tensor<32x7x?x16x1xf32>
1008 // CHECK-LABEL: func.func @no_infer_pack_shape
1009 // CHECK-NOT:     tensor.cast
1011 // -----
1013 func.func @fold_padding_value_pack_negative1(%arg0: tensor<1200x499999xf32>) -> tensor<31250x1200x16x1xf32> {
1014   %cst = arith.constant 0.000000e+00 : f32
1015   %0 = tensor.empty() : tensor<31250x1200x16x1xf32>
1016   %pack = tensor.pack %arg0
1017     padding_value(%cst : f32)
1018     outer_dims_perm = [1, 0]
1019     inner_dims_pos = [1, 0]
1020     inner_tiles = [16, 1]
1021     into %0 : tensor<1200x499999xf32> -> tensor<31250x1200x16x1xf32>
1022   return %pack : tensor<31250x1200x16x1xf32>
1024 // CHECK-LABEL: func @fold_padding_value_pack_negative1
1025 // CHECK:         tensor.pack
1026 // CHECK-SAME:      padding_value
1028 // -----
1030 func.func @fold_padding_value_pack_negative2(%arg0: tensor<1200x?xf32>, %arg1: tensor<?x1200x16x1xf32>) -> tensor<?x1200x16x1xf32> {
1031   %cst = arith.constant 0.000000e+00 : f32
1032   %pack = tensor.pack %arg0
1033     padding_value(%cst : f32)
1034     outer_dims_perm = [1, 0]
1035     inner_dims_pos = [1, 0]
1036     inner_tiles = [16, 1]
1037     into %arg1 : tensor<1200x?xf32> -> tensor<?x1200x16x1xf32>
1038   return %pack : tensor<?x1200x16x1xf32>
1040 // CHECK-LABEL: func @fold_padding_value_pack_negative2
1041 // CHECK:         tensor.pack
1042 // CHECK-SAME:      padding_value
1044 // -----
1046 func.func @fold_padding_value_pack_negative3(%arg0: tensor<1200x500000xf32>, %arg1: tensor<?x1200x?x1xf32>, %tile : index) -> tensor<?x1200x?x1xf32> {
1047   %cst = arith.constant 0.000000e+00 : f32
1048   %pack = tensor.pack %arg0
1049     padding_value(%cst : f32)
1050     outer_dims_perm = [1, 0]
1051     inner_dims_pos = [1, 0]
1052     inner_tiles = [%tile, 1]
1053     into %arg1 : tensor<1200x500000xf32> -> tensor<?x1200x?x1xf32>
1054   return %pack : tensor<?x1200x?x1xf32>
1056 // CHECK-LABEL: func @fold_padding_value_pack_negative3
1057 // CHECK:         tensor.pack
1058 // CHECK-SAME:      padding_value
1060 // -----
1062 // CHECK-LABEL: func @fold_unpack_constant_splat
1063 //   CHECK-NOT: tensor.unpack
1064 //       CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
1065 func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
1066   %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
1067   %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
1068     inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
1069   return %0 : tensor<128x256xf32>
1072 // -----
1074 func.func @infer_dest_shape_unpack(%src: tensor<10x20x30x40x16xf32>, %dest: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
1075   %unpack = tensor.unpack %src
1076     outer_dims_perm = [2, 1, 3, 0]
1077     inner_dims_pos = [2]
1078     inner_tiles = [16]
1079     into %dest : tensor<10x20x30x40x16xf32> -> tensor<?x?x?x?xf32>
1080   return %unpack : tensor<?x?x?x?xf32>
1082 // CHECK-LABEL: func.func @infer_dest_shape_unpack
1083 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
1084 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
1085 // CHECK:         %[[CAST_DEST:.+]] = tensor.cast %[[DEST]] : tensor<?x?x?x?xf32> to tensor<40x20x?x30xf32>
1086 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[SRC]] {{.+}} into %[[CAST_DEST]]
1087 // CHECK:         %[[CAST_UNPACK:.+]] = tensor.cast %[[UNPACK]] : tensor<40x20x?x30xf32> to tensor<?x?x?x?xf32>
1088 // CHECK:         return %[[CAST_UNPACK]]
1090 // -----
1092 func.func @infer_src_shape_unpack(%src: tensor<?x?x?x?x16xf32>, %dest: tensor<30x20x?x10xf32>) -> tensor<30x20x?x10xf32> {
1093   %unpack = tensor.unpack %src
1094     outer_dims_perm = [2, 1, 3, 0]
1095     inner_dims_pos = [2]
1096     inner_tiles = [16]
1097     into %dest : tensor<?x?x?x?x16xf32> -> tensor<30x20x?x10xf32>
1098   return %unpack : tensor<30x20x?x10xf32>
1100 // CHECK-LABEL: func.func @infer_src_shape_unpack
1101 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
1102 // CHECK-SAME:    %[[DEST:[0-9a-zA-Z]+]]
1103 // CHECK:         %[[CAST_SRC:.+]] = tensor.cast %[[SRC]] : tensor<?x?x?x?x16xf32> to tensor<?x20x10x30x16xf32>
1104 // CHECK:         %[[UNPACK:.+]] = tensor.unpack %[[CAST_SRC]]
1105 // CHECK:         return %[[UNPACK]]
1107 // -----
1109 func.func @no_infer_unpack_shape(%arg1: tensor<32x7x?x16x1xf32>, %arg2: index) -> tensor<?x32x100xf32> {
1110   %cst = arith.constant 0.000000e+00 : f32
1111   %0 = tensor.empty(%arg2) : tensor<?x32x100xf32>
1112   %unpack = tensor.unpack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [16, 1] into %0 : tensor<32x7x?x16x1xf32> -> tensor<?x32x100xf32>
1113   return %unpack : tensor<?x32x100xf32>
1115 // CHECK-LABEL: func.func @no_infer_unpack_shape
1116 // CHECK-NOT:     tensor.cast
1118 // -----
1121 // CHECK-LABEL: func @fold_overlapping_insert
1122 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
1123 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
1124   %c0 = arith.constant 0: index
1125   %c1 = arith.constant 1: index
1126   %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
1127   // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]]
1128   %1 = tensor.insert_slice %slice2 into %0[0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32>
1129   // CHECK: return %[[INSERT]]
1130   return %1 : tensor<?x?x?xf32>
1133 // -----
1135 func.func @compose_expand_of_expand(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
1136     -> tensor<?x6x4x?x5xf32> {
1137   %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1138       : tensor<?x?xf32> into tensor<?x4x?xf32>
1139   %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5] : tensor<?x4x?xf32> into tensor<?x6x4x?x5xf32>
1140   return %1 : tensor<?x6x4x?x5xf32>
1142 // CHECK-LABEL: compose_expand_of_expand
1143 //       CHECK:   tensor.expand_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]] output_shape [%arg3, 6, 4, %arg4, 5]
1144 //   CHECK-NOT:   tensor.expand_shape
1146 // -----
1148 func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)
1149     -> tensor<1x1x1xf32> {
1150   %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
1151   %1 = tensor.expand_shape %0 [[0, 1, 2]] output_shape [1, 1, 1]
1152       : tensor<1xf32> into tensor<1x1x1xf32>
1153   return %1 : tensor<1x1x1xf32>
1155 // CHECK-LABEL: compose_expand_of_expand_of_zero_dim
1156 //       CHECK:   tensor.expand_shape %{{.*}} [] output_shape [1, 1, 1]
1157 //  CHECK-SAME:     tensor<f32> into tensor<1x1x1xf32>
1159 // -----
1161 // CHECK-LABEL: func.func @collapse_of_cast(
1162 // CHECK-SAME:         %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
1163 // CHECK-NEXT:    %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
1164 // CHECK-NEXT:    %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
1165 // CHECK-NEXT:    return %[[CAST]] : tensor<?x32xf32>
1166 func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
1167   %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
1168   %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
1169   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
1170   return %2 : tensor<?x32xf32>
1173 // -----
1175 func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
1176   %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [3, 4, 4]
1177       : tensor<12x4xf32> into tensor<3x4x4xf32>
1178   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1179       : tensor<3x4x4xf32> into tensor<12x4xf32>
1180   return %1 : tensor<12x4xf32>
1182 // CHECK-LABEL: @fold_collapse_of_expand
1183 //   CHECK-NOT:   tensor.{{.*}}_shape
1185 // -----
1187 func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index)
1188     -> tensor<?x?xf32> {
1189   %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1190       : tensor<?x?xf32> into tensor<?x4x?xf32>
1191   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1192       : tensor<?x4x?xf32> into tensor<?x?xf32>
1193   return %1 : tensor<?x?xf32>
1195 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
1196 //   CHECK-NOT:   tensor.{{.*}}_shape
1198 // -----
1200 func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
1201     -> tensor<?x?xf32> {
1202   %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
1203       : tensor<?x?xf32> into tensor<?x?x?xf32>
1204   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1205       : tensor<?x?x?xf32> into tensor<?x?xf32>
1206   return %1 : tensor<?x?xf32>
1208 // CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
1209 //   CHECK-NOT:   tensor.{{.*}}_shape
1211 // -----
1213 func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
1214     -> tensor<?x?x?xf32> {
1215   %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
1216       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
1217   %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
1218       : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
1219   return %1 : tensor<?x?x?xf32>
1221 // CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
1222 //       CHECK:   tensor.expand_shape
1223 //       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
1224 //       CHECK:   return %[[COLLAPSE]]
1226 // -----
1228 func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
1229   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1230       : tensor<3x4x4xf32> into tensor<12x4xf32>
1231   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
1232       : tensor<12x4xf32> into tensor<3x4x4xf32>
1233   return %1 : tensor<3x4x4xf32>
1235 // CHECK-LABEL: @fold_expand_of_collapse
1236 //   CHECK-NOT:   tensor.{{.*}}_shape
1238 // -----
1240 func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
1241     -> tensor<?x4x?xf32> {
1242   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1243       : tensor<?x4x?xf32> into tensor<?x?xf32>
1244   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
1245       : tensor<?x?xf32> into tensor<?x4x?xf32>
1246   return %1 : tensor<?x4x?xf32>
1248 // CHECK-LABEL: @fold_expand_of_collapse_dynamic
1249 //   CHECK-NOT:   tensor.{{.*}}_shape
1251 // -----
1253 func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
1254     -> tensor<?x?x?xf32> {
1255   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
1256       : tensor<?x?x?xf32> into tensor<?x?xf32>
1257   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
1258       : tensor<?x?xf32> into tensor<?x?x?xf32>
1259   return %1 : tensor<?x?x?xf32>
1261 // CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
1262 //       CHECK:   tensor.collapse_shape
1263 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
1264 //       CHECK:   return %[[EXPAND]]
1266 // -----
1268 func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
1269   %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
1270   %c0 = arith.constant 0 : index
1271   %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
1272   %c384= arith.constant 384 : index
1273   %div = arith.divui %dim, %c384 : index
1274   %expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
1275   return %expanded : tensor<?x384xf32>
1277 //       CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
1278 // CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
1279 //  CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
1280 //       CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
1281 //       CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
1282 //       CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
1283 //       CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
1284 //       CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
1285 //       CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
1286 //       CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
1287 //       CHECK: return %[[RESULT]]
1289 // -----
1291 func.func @compose_expand_of_collapse(%arg0 : tensor<2x3x4x5x6x7x8xf32>)
1292     -> tensor<24x5x42x8xf32> {
1293   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3, 4, 5, 6]]
1294       : tensor<2x3x4x5x6x7x8xf32> into tensor<40320xf32>
1295   %1 = tensor.expand_shape %0 [[0, 1, 2, 3]] output_shape [24, 5, 42, 8]
1296       : tensor<40320xf32> into tensor<24x5x42x8xf32>
1297   return %1 : tensor<24x5x42x8xf32>
1299 //      CHECK: func @compose_expand_of_collapse
1300 // CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8xf32>
1301 //      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1302 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
1303 //      CHECK:   return %[[RESULT]]
1305 // -----
1307 func.func @compose_expand_of_collapse_7D(%arg0 : tensor<24x5x42x8xf32>)
1308     -> tensor<2x3x4x5x6x7x8xf32> {
1309   %0 = tensor.collapse_shape %arg0 [[0, 1, 2, 3]]
1310       : tensor<24x5x42x8xf32> into tensor<40320xf32>
1311   %1 = tensor.expand_shape %0 [[0, 1, 2, 3, 4, 5, 6]] output_shape [2, 3, 4, 5, 6, 7, 8]
1312       : tensor<40320xf32> into tensor<2x3x4x5x6x7x8xf32>
1313   return %1 : tensor<2x3x4x5x6x7x8xf32>
1315 //      CHECK: func @compose_expand_of_collapse_7D
1316 // CHECK-SAME:   %[[ARG0:.+]]: tensor<24x5x42x8xf32>
1317 //      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
1318 // CHECK-SAME:     [0, 1, 2], [3], [4, 5], [6]
1319 //      CHECK:   return %[[RESULT]]
1321 // -----
1323 func.func @compose_collapse_of_expand(%arg : tensor<?x?x?xi64>, %arg1: index, %arg2: index, %arg3: index)
1324     -> tensor<?x?xi64> {
1325   %0 = tensor.expand_shape %arg [[0], [1], [2, 3]] output_shape [%arg1, %arg2, %arg3, 1]
1326     : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
1327   %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
1328     : tensor<?x?x?x1xi64> into tensor<?x?xi64>
1329   return %1 : tensor<?x?xi64>
1331 // CHECK-LABEL: func @compose_collapse_of_expand
1332 //       CHECK:   (%[[ARG:.*]]: tensor<?x?x?xi64>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
1333 //  CHECK-NEXT: tensor.collapse_shape %[[ARG]]
1334 //  CHECK-SAME:   [0, 1], [2]
1335 //  CHECK-SAME:   : tensor<?x?x?xi64> into tensor<?x?xi64>
1337 // -----
1339 func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
1340     -> tensor<4x512xf32> {
1341   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3]] output_shape [1, 4, 1, 512]
1342     : tensor<2048xf32> into tensor<1x4x1x512xf32>
1343   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
1344     : tensor<1x4x1x512xf32> into tensor<4x512xf32>
1345   return %1 : tensor<4x512xf32>
1347 //       CHECK: func @compose_collapse_of_expand_1D
1348 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [4, 512]
1349 //  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>
1351 // -----
1353 func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
1354     -> tensor<1x1x1x1xf32> {
1355   %0 = tensor.collapse_shape %arg0 []
1356       : tensor<1x1x1xf32> into tensor<f32>
1357   %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1, 1]
1358       : tensor<f32> into tensor<1x1x1x1xf32>
1359   return %1 : tensor<1x1x1x1xf32>
1361 //      CHECK: func @compose_expand_of_collapse_0_rank_to_expand
1362 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x1xf32>
1363 //      CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[ARG0]]
1364 // CHECK-SAME:     {{\[}}[0], [1], [2, 3]] output_shape [1, 1, 1, 1]
1365 //      CHECK:   return %[[RESULT]]
1367 // -----
1369 func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1xf32>)
1370     -> tensor<1x1x1xf32> {
1371   %0 = tensor.collapse_shape %arg0 []
1372       : tensor<1x1x1x1xf32> into tensor<f32>
1373   %1 = tensor.expand_shape %0 [] output_shape [1, 1, 1]
1374       : tensor<f32> into tensor<1x1x1xf32>
1375   return %1 : tensor<1x1x1xf32>
1377 //      CHECK: func @compose_expand_of_collapse_0_rank_to_collapse
1378 // CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x1x1xf32>
1379 //      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1380 // CHECK-SAME:     [0], [1], [2, 3]
1381 //      CHECK:   return %[[RESULT]]
1383 // -----
1385 // CHECK-LABEL: func @zero_rank_reshape_multi
1386 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
1387   // CHECK: return %arg0
1388   %0 = tensor.expand_shape %arg0 [] output_shape [1] : tensor<f32> into tensor<1xf32>
1389   %1 = tensor.expand_shape %0 [[0, 1]] output_shape [1, 1] : tensor<1xf32> into tensor<1x1xf32>
1390   %2 = tensor.collapse_shape %1 [] : tensor<1x1xf32> into tensor<f32>
1391   return %2 : tensor<f32>
1394 // -----
1396 func.func @compose_collapse_of_collapse(%arg0 : tensor<?x?x?x?x?xf32>)
1397     -> tensor<?x?xf32> {
1398   %0 = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]]
1399       : tensor<?x?x?x?x?xf32> into tensor<?x?x?xf32>
1400   %1 = tensor.collapse_shape %0 [[0, 1], [2]]
1401       : tensor<?x?x?xf32> into tensor<?x?xf32>
1402   return %1 : tensor<?x?xf32>
1404 // CHECK-LABEL: func @compose_collapse_of_collapse
1405 //       CHECK:   tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2], [3, 4]]
1406 //   CHECK-NOT:   tensor.collapse_shape
1408 // -----
1410 func.func @compose_collapse_of_collapse_zero_dim(%arg0 : tensor<1x1x1xf32>)
1411     -> tensor<f32> {
1412   %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
1413       : tensor<1x1x1xf32> into tensor<1xf32>
1414   %1 = tensor.collapse_shape %0 [] : tensor<1xf32> into tensor<f32>
1415   return %1 : tensor<f32>
1417 // CHECK-LABEL: func @compose_collapse_of_collapse_zero_dim
1418 //       CHECK:   tensor.collapse_shape %{{.*}} []
1419 //  CHECK-SAME:     tensor<1x1x1xf32> into tensor<f32>
1421 // -----
1423 func.func @fold_collapse_of_expand_1D(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32> {
1424   %0 = tensor.expand_shape %arg0 [[0, 1, 2], [3]] output_shape [1, 4, 1, 512]
1425     : tensor<4x512xf32> into tensor<1x4x1x512xf32>
1426   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
1427     : tensor<1x4x1x512xf32> into tensor<2048xf32>
1428   return %1 : tensor<2048xf32>
1430 //       CHECK: func @fold_collapse_of_expand_1D
1431 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1]]
1432 //  CHECK-SAME:   tensor<4x512xf32> into tensor<2048xf32>
1434 // -----
1436 func.func @fold_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x1xf32>)
1437     -> tensor<4x512x1x1xf32> {
1438   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3], [4], [5]] output_shape [1, 4, 1, 512, 1, 1] : tensor<2048x1x1xf32> into tensor<1x4x1x512x1x1xf32>
1439   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3], [4], [5]]
1440     : tensor<1x4x1x512x1x1xf32> into tensor<4x512x1x1xf32>
1441   return %1 : tensor<4x512x1x1xf32>
1443 //       CHECK: func @fold_collapse_of_expand_unit_dims
1444 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3]] output_shape [4, 512, 1, 1]
1445 //  CHECK-SAME:   tensor<2048x1x1xf32> into tensor<4x512x1x1xf32>
1447 // -----
1449 func.func @compose_collapse_of_expand_unit_dims(%arg0 : tensor<2048x1x2048xf32>)
1450     -> tensor<4x512x1x512x4xf32> {
1451   %0 = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4], [5], [6, 7, 8]] output_shape [1, 4, 1, 512, 1, 1, 512, 1, 4] : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
1452   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4], [5], [6, 7], [8]]
1453     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
1454   return %1 : tensor<4x512x1x512x4xf32>
1456 //       CHECK: func @compose_collapse_of_expand_unit_dims
1457 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [4, 512, 1, 512, 4]
1458 //  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
1460 // -----
1462 func.func @compose_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
1463     -> tensor<2x1xf32> {
1464   %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1]
1465       : tensor<2xf32> into tensor<2x1x1xf32>
1466   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
1467       : tensor<2x1x1xf32> into tensor<2x1xf32>
1468   return %1 : tensor<2x1xf32>
1470 //       CHECK: func @compose_collapse_of_expand_trailing_unit_dims
1471 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
1472 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
1474 // -----
1476 func.func @compose_collapse_of_collapse_unit_dims_dynamic(
1477     %arg0 : tensor<?x1x?x1x1x?x?x1x1xf32>) -> tensor<?x?x?x?xf32> {
1478   %0 = tensor.collapse_shape %arg0 [[0], [1, 2], [3], [4], [5], [6, 7, 8]]
1479     : tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x1x1x?x?xf32>
1480   %1 = tensor.collapse_shape %0 [[0], [1], [2, 3, 4], [5]]
1481     : tensor<?x?x1x1x?x?xf32> into tensor<?x?x?x?xf32>
1482   return %1 : tensor<?x?x?x?xf32>
1484 //       CHECK: func @compose_collapse_of_collapse_unit_dims_dynamic
1485 //       CHECK: tensor.collapse_shape
1486 //  CHECK-SAME:   [0], [1, 2], [3, 4, 5], [6, 7, 8]
1487 //  CHECK-SAME:   tensor<?x1x?x1x1x?x?x1x1xf32> into tensor<?x?x?x?xf32>
1489 // -----
1491 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<2xf32>)
1492     -> tensor<2x1xf32> {
1493   %0 = tensor.expand_shape %arg0 [[0, 1, 2]] output_shape [2, 1, 1] : tensor<2xf32> into tensor<2x1x1xf32>
1494   %1 = tensor.collapse_shape %0 [[0], [1, 2]]
1495       : tensor<2x1x1xf32> into tensor<2x1xf32>
1496   return %1 : tensor<2x1xf32>
1498 //       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
1499 //       CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1]] output_shape [2, 1]
1500 //  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
1502 // -----
1504 func.func @fold_collapse_of_collapse_trailing_unit_dims_dynamic(
1505     %arg0: tensor<1x1x?x1x1x1xf32>) -> tensor<?xf32> {
1506   %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5]]
1507       : tensor<1x1x?x1x1x1xf32> into tensor<?x1x1x1xf32>
1508   %1 = tensor.collapse_shape %0 [[0, 1, 2, 3]]
1509       : tensor<?x1x1x1xf32> into tensor<?xf32>
1510   return %1 : tensor<?xf32>
1512 //       CHECK: func @fold_collapse_of_collapse_trailing_unit_dims_dynamic
1513 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4, 5]]
1514 //  CHECK-SAME:   tensor<1x1x?x1x1x1xf32> into tensor<?xf32>
1516 // -----
1518 func.func @fold_collapse_of_expand_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>)
1519     -> tensor<12x42xf32> {
1520   %0 = tensor.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [12, 42, 1, 1, 1] : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
1521   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3, 4]]
1522       : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
1523   return %1 : tensor<12x42xf32>
1525 //       CHECK: func @fold_collapse_of_expand_trailing_unit_dims
1526 //       CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
1527 //  CHECK-SAME:   tensor<12x42x1x1xf32> into tensor<12x42xf32>
1529 // -----
1531 func.func @fold_collapse_of_expand_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>, %sz0: index, %sz1: index, %sz2: index)
1532     -> tensor<?x?xf32> {
1533   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [%sz0, %sz1, 1, %sz2]
1534       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
1535   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
1536       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
1537   return %1 : tensor<?x?xf32>
1539 // CHECK-LABEL: func @fold_collapse_of_expand_unit_dims_in_middle
1540 //  CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>
1541 //       CHECK: tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
1542 //  CHECK-SAME:   tensor<?x?x?xf32> into tensor<?x?xf32>
1544 // -----
1546 func.func @no_fold_collapse_of_expand_incompatible(%arg0 : tensor<4x6x8xf32>)
1547     -> tensor<2x6x16xf32> {
1548   %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3], [4]] output_shape [2, 2, 3, 2, 8]
1549       : tensor<4x6x8xf32> into tensor<2x2x3x2x8xf32>
1550   %1 = tensor.collapse_shape %0 [[0], [1, 2], [3, 4]]
1551       : tensor<2x2x3x2x8xf32> into tensor<2x6x16xf32>
1552   return %1 : tensor<2x6x16xf32>
1554 // CHECK-LABEL: func @no_fold_collapse_of_expand_incompatible
1555 //       CHECK:   tensor.expand_shape
1556 //       CHECK:   tensor.collapse_shape
1558 // -----
1560 func.func @no_fold_collapse_of_expand_empty_expr(%arg0: tensor<3x2x2xf32>)
1561     -> tensor<12x1xf32> {
1562   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
1563       : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
1564   %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
1565       : tensor<3x2x2x1xf32> into tensor<12x1xf32>
1566   return %1 : tensor<12x1xf32>
1568 //      CHECK: func @no_fold_collapse_of_expand_empty_expr
1569 // CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
1570 //      CHECK:    %[[RARG0:.+]] = tensor.expand_shape %[[ARG0]]
1571 // CHECK-SAME:      {{\[}}[0], [1], [2, 3]] output_shape [3, 2, 2, 1]
1572 //      CHECK:    %[[RES:.+]] = tensor.collapse_shape %[[RARG0]]
1573 // CHECK-SAME:      [0, 1, 2], [3]
1574 //      CHECK:    return %[[RES:.+]] : tensor<12x1xf32>
1576 // -----
1578 func.func @reshape_splat_constant_int32() -> tensor<2x4x2xi32> {
1579   %c0 = arith.constant dense<42> : tensor<2x8xi32>
1580   %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1581       : tensor<2x8xi32> into tensor<2x4x2xi32>
1582   return %0 : tensor<2x4x2xi32>
1584 // CHECK-LABEL: @reshape_splat_constant_int32
1585 //       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi32>
1586 //   CHECK-NOT:   tensor.expand_shape
1587 //       CHECK:   return %[[CST]]
1588 // -----
1589 func.func @expand_shape_splat(%arg : f32) -> tensor<2x2x2xf32> {
1590   %c0 = tensor.splat %arg : tensor<2x4xf32>
1591   %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, 2]
1592       : tensor<2x4xf32> into tensor<2x2x2xf32>
1593   return %0 : tensor<2x2x2xf32>
1595 // CHECK-LABEL: @expand_shape_splat
1596 // CHECK-SAME:    %[[ARG0:.+]]: f32
1597 //       CHECK:   %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x2x2xf32>
1598 //   CHECK-NOT:   tensor.expand_shape
1599 //       CHECK:   return %[[CST]]
1601 // -----
1603 // CHECK-LABEL: @expand_shape_splat_dynamic_no_fold
1604 // CHECK-SAME: (%[[F:.+]]: f32, %[[M:.+]]: index, %[[SZ0:.+]]: index)
1605 func.func @expand_shape_splat_dynamic_no_fold(%arg: f32, %m: index, %sz0: index) -> tensor<2x2x?xf32> {
1606   // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]] : tensor<2x?xf32>
1607   // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[SPLAT]]
1608   %c0 = tensor.splat %arg[%m] : tensor<2x?xf32>
1609   %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 2, %sz0] : tensor<2x?xf32> into tensor<2x2x?xf32>
1610   return %0 : tensor<2x2x?xf32>
1613 // -----
1615 func.func @collapse_shape_splat(%arg : f32) -> tensor<2x4xf32> {
1616   %c0 = tensor.splat %arg : tensor<2x2x2xf32>
1617   %0 = tensor.collapse_shape %c0 [[0], [1, 2]]
1618       : tensor<2x2x2xf32> into tensor<2x4xf32>
1619   return %0 : tensor<2x4xf32>
1621 // CHECK-LABEL: @collapse_shape_splat
1622 // CHECK-SAME:    %[[ARG0:.+]]: f32
1623 //       CHECK:   %[[CST:.*]] = tensor.splat %[[ARG0:.+]] : tensor<2x4xf32>
1624 //   CHECK-NOT:   tensor.collapse_shape
1625 //       CHECK:   return %[[CST]]
1627 // -----
1629 // CHECK-LABEL: @collapse_shape_splat_dynamic_no_fold
1630 // CHECK-SAME: %[[F:.+]]: f32
1631 // CHECK-SAME: %[[M:.+]]: index
1632 func.func @collapse_shape_splat_dynamic_no_fold(%f: f32, %m: index) -> tensor<2x?xf32> {
1633   // CHECK: %[[SPLAT:.+]] = tensor.splat %[[F]][%[[M]]]
1634   // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[SPLAT]]
1635   %c0 = tensor.splat %f[%m] : tensor<2x2x?xf32>
1636   %0 = tensor.collapse_shape %c0 [[0], [1, 2]] : tensor<2x2x?xf32> into tensor<2x?xf32>
1637   return %0 : tensor<2x?xf32>
1640 // -----
1642 func.func @reshape_splat_constant_int16() -> tensor<2x4x2xi16> {
1643   %c0 = arith.constant dense<42> : tensor<2x8xi16>
1644   %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1645       : tensor<2x8xi16> into tensor<2x4x2xi16>
1646   return %0 : tensor<2x4x2xi16>
1648 // CHECK-LABEL: @reshape_splat_constant_int16
1649 //       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xi16>
1650 //   CHECK-NOT:   tensor.expand_shape
1651 //       CHECK:   return %[[CST]]
1653 // -----
1655 func.func @reshape_splat_constant_float32() -> tensor<2x4x2xf32> {
1656   %c0 = arith.constant dense<42.0> : tensor<2x8xf32>
1657   %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1658       : tensor<2x8xf32> into tensor<2x4x2xf32>
1659   return %0 : tensor<2x4x2xf32>
1661 // CHECK-LABEL: @reshape_splat_constant_float32
1662 //       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf32>
1663 //   CHECK-NOT:   tensor.expand_shape
1664 //       CHECK:   return %[[CST]]
1666 // -----
1668 func.func @reshape_splat_constant_float64() -> tensor<2x4x2xf64> {
1669   %c0 = arith.constant dense<42.0> : tensor<2x8xf64>
1670   %0 = tensor.expand_shape %c0 [[0], [1, 2]] output_shape [2, 4, 2]
1671       : tensor<2x8xf64> into tensor<2x4x2xf64>
1672   return %0 : tensor<2x4x2xf64>
1674 // CHECK-LABEL: @reshape_splat_constant_float64
1675 //       CHECK:   %[[CST:.*]] = arith.constant dense<{{.*}}> : tensor<2x4x2xf64>
1676 //   CHECK-NOT:   tensor.expand_shape
1677 //       CHECK:   return %[[CST]]
1679 // -----
1681 // CHECK-LABEL: func @fold_rank
1682 func.func @fold_rank() -> (index) {
1683   %const_0 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]>
1684     : tensor<2x1x4xi32>
1686   // Fold a ank into a constant
1687   // CHECK-NEXT: [[C3:%.+]] = arith.constant 3 : index
1688   %rank_0 = tensor.rank %const_0 : tensor<2x1x4xi32>
1690   // CHECK-NEXT: return [[C3]]
1691   return %rank_0 : index
1694 // -----
1696 // CHECK-LABEL: func @pad_same_static_shape(
1697 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
1698 //   CHECK-NOT:   tensor.pad
1699 //       CHECK:   return %[[ARG0]]
1700 func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
1701     -> tensor<5x6xf32> {
1702   %cst = arith.constant 0.000000e+00 : f32
1703   %0 = tensor.pad %arg0 low[%a, 0] high[0, %a] {
1704         ^bb0(%arg1: index, %arg2: index):
1705           tensor.yield %cst : f32
1706   } : tensor<5x6xf32> to tensor<5x6xf32>
1707   return %0 : tensor<5x6xf32>
1710 // -----
1712 // CHECK-LABEL:   func @pad_fold_static(
1713 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1714 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1715 // CHECK-NOT:       arith.constant 4 : index
1716 // CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1717 // CHECK-SAME:        low[0, 4, 1, 1] high[0, 4, 1, 1]  {
1718 // CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
1719 // CHECK:             tensor.yield %[[CST]] : f32
1720 // CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x72x?x?xf32>
1721 // CHECK:           tensor.cast
1722 func.func @pad_fold_static(%arg0: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1723   %c0 = arith.constant 0 : index
1724   %cst = arith.constant 0.000000e+00 : f32
1725   %padding = arith.constant 4 : index
1726   %padded = tensor.pad %arg0 low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
1727     ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
1728     tensor.yield %cst: f32
1729   } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1730   return %padded : tensor<?x?x?x?xf32>
1733 // -----
1735 // CHECK-LABEL: func @pad_nofold_same_static_shape(
1736 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>
1737 //       CHECK:   %[[PAD:.*]] = tensor.pad
1738 //       CHECK:   return %[[PAD]]
1739 func.func @pad_nofold_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
1740     -> tensor<5x6xf32> {
1741   %cst = arith.constant 0.000000e+00 : f32
1742   %0 = tensor.pad %arg0 nofold low[%a, 0] high[0, %a] {
1743         ^bb0(%arg1: index, %arg2: index):
1744           tensor.yield %cst : f32
1745   } : tensor<5x6xf32> to tensor<5x6xf32>
1746   return %0 : tensor<5x6xf32>
1749 // -----
1751 // CHECK-LABEL:   func @pad_after_cast_different_shape(
1752 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
1753 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1754 // CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1755 // CHECK-SAME:        low[0, 0, 1, 1] high[0, 0, 1, 1]  {
1756 // CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
1757 // CHECK:             tensor.yield %[[CST]] : f32
1758 // CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x64x?x?xf32>
1759 // CHECK:           %[[DYNAMIC:.*]] = tensor.cast %[[PADDED:.*]] :
1760 // CHECK-SAME:         tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1761 // CHECK:           return %[[DYNAMIC]] : tensor<?x?x?x?xf32>
1762 // CHECK:         }
1763 func.func @pad_after_cast_different_shape(%arg0: tensor<?x64x?x?xf32>)
1764     -> tensor<?x?x?x?xf32> {
1765   %cst = arith.constant 0.000000e+00 : f32
1766   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1767   %padded = tensor.pad %dynamic low[0, 0, 1, 1] high[0, 0, 1, 1]  {
1768     ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
1769     tensor.yield %cst: f32
1770   } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
1771   return %padded: tensor<?x?x?x?xf32>
1774 // -----
1776 // CHECK-LABEL:   func @pad_after_cast_same_shape(
1777 // CHECK-SAME:      %[[INPUT:.*]]: tensor<?x64x?x?xf32>,
1778 // CHECK-SAME:      %[[PADDING:.*]]: index) -> tensor<?x?x?x?xf32> {
1779 // CHECK:           %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1780 // CHECK:           %[[PADDED:.*]] = tensor.pad %[[INPUT]]
1781 // CHECK-SAME:        low[0, %[[PADDING]], 1, 1] high[0, %[[PADDING]], 1, 1]  {
1782 // CHECK:           ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
1783 // CHECK:             tensor.yield %[[CST]] : f32
1784 // CHECK:           } : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1785 // CHECK:           return %[[PADDED:.*]] : tensor<?x?x?x?xf32>
1786 // CHECK:         }
1787 func.func @pad_after_cast_same_shape(%arg0: tensor<?x64x?x?xf32>, %padding : index)
1788     -> tensor<?x?x?x?xf32> {
1789   %cst = arith.constant 0.000000e+00 : f32
1790   %dynamic = tensor.cast %arg0 : tensor<?x64x?x?xf32> to tensor<?x?x?x?xf32>
1791   %padded = tensor.pad %dynamic low[0, %padding, 1, 1] high[0, %padding, 1, 1]  {
1792     ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
1793     tensor.yield %cst: f32
1794   } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
1795   return %padded: tensor<?x?x?x?xf32>
1798 // -----
1800 // CHECK-LABEL: func @pad_of_cast(
1801 // CHECK-NOT:     tensor.cast
1802 // CHECK:         tensor.pad
1803 // CHECK:         tensor<8x?xf32> to tensor<8x32xf32>
1804 func.func @pad_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
1805   %c0 = arith.constant 0 : index
1806   %cst = arith.constant 0.000000e+00 : f32
1807   %0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
1808   %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %s]  {
1809   ^bb0(%arg9: index, %arg10: index):
1810     tensor.yield %cst : f32
1811   } : tensor<?x?xf32> to tensor<8x32xf32>
1812   return %1 : tensor<8x32xf32>
1815 // -----
1817 // CHECK-LABEL: @cast_of_pad_more_static
1818 func.func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
1819   %cst = arith.constant 0.000000e+00 : f32
1820   // CHECK: %[[PAD:.*]] = tensor.pad
1821   // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
1822   %padded = tensor.pad %arg0 low[%padding, %padding] high[0, 0] {
1823   ^bb0(%arg1: index, %arg2: index):
1824     tensor.yield %cst : f32
1825   } : tensor<?x?xf32> to tensor<?x?xf32>
1826   // CHECK-NOT: tensor.cast
1827   %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
1828   // CHECK: return %[[PAD]]
1829   return %casted : tensor<32x32xf32>
1832 // -----
1834 // CHECK-LABEL: @cast_of_pad_less_static
1835 func.func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
1836   %cst = arith.constant 0.000000e+00 : f32
1837   // CHECK: tensor.pad
1838   %padded = tensor.pad %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
1839   ^bb0(%arg1: index, %arg2: index, %arg3: index):
1840     tensor.yield %cst : f32
1841   } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
1842   // CHECK: %[[CAST:.*]] = tensor.cast
1843   %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
1844   // CHECK: return %[[CAST]]
1845   return %casted : tensor<?x32x32xf32>
1848 // -----
1850 func.func @pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
1851   %c0 = arith.constant 0 : index
1852   %cst = arith.constant 0.0 : f32
1853   %0 = tensor.cast %arg0 : tensor<4x4xf32> to tensor<?x?xf32>
1854   %1 = tensor.pad %0 low[%c0, %c0] high[%c0, %c0]  {
1855     ^bb0(%arg1: index, %arg2: index):
1856       tensor.yield %cst : f32
1857   } : tensor<?x?xf32> to tensor<4x4xf32>
1858   return %1 : tensor<4x4xf32>
1860 // CHECK-LABEL: @pad_cast
1861 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
1862 // CHECK: return %[[ARG0]]
1864 // -----
1866 // CHECK-LABEL: func @fold_pad_source_cast(
1867 //  CHECK-SAME:                  %[[ARG0:.*]]: tensor<4x?xf32>
1868 //   CHECK-NOT:   tensor.cast
1869 //       CHECK:   %[[RESULT:.*]] = tensor.pad %[[ARG0]]
1870 func.func @fold_pad_source_cast(%arg0: tensor<4x?xf32>) -> tensor<4x4xf32> {
1871   %cst = arith.constant 0.0 : f32
1872   %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
1873   %1 = tensor.pad %0 low[0, 0] high[0, 1]  {
1874     ^bb0(%arg1: index, %arg2: index):
1875       tensor.yield %cst : f32
1876   } : tensor<?x?xf32> to tensor<4x4xf32>
1877   return %1 : tensor<4x4xf32>
1880 // -----
1882 // CHECK-LABEL: func @pad_static_zero_cast(
1883 //  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
1884 //   CHECK-NOT:   tensor.pad
1885 //       CHECK:   %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1886 //       CHECK:   return %[[RESULT]]
1887 func.func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
1888   %c0 = arith.constant 0 : index
1889   %0 = tensor.pad %arg0 low[0, %c0, 0] high[0, 0, %c0] {
1890     ^bb0(%arg1: index, %arg2: index, %arg3: index):
1891       tensor.yield %pad_value : f32
1892     } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1894   return %0 : tensor<2x3x4xf32>
1897 // -----
1899 // CHECK-LABEL: func @pad_nofold_static_zero(
1900 //  CHECK-SAME:                  %[[ARG0:.*]]: tensor<?x?x?xf32>
1901 //       CHECK:   %[[PAD:.*]] = tensor.pad
1902 //       CHECK:   return %[[PAD]]
1903 func.func @pad_nofold_static_zero(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
1904   %c0 = arith.constant 0 : index
1905   %0 = tensor.pad %arg0 nofold low[0, %c0, 0] high[0, 0, %c0] {
1906     ^bb0(%arg1: index, %arg2: index, %arg3: index):
1907       tensor.yield %pad_value : f32
1908     } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
1910   return %0 : tensor<2x3x4xf32>
1913 // -----
1915 // CHECK-LABEL: func @fold_orthogonal_pad_chains(
1916 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
1917 //  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
1918 func.func @fold_orthogonal_pad_chains(%arg0: tensor<64x64xf32>,
1919                                       %sz0 : index, %sz1 : index,
1920                                       %pw0 : index, %pw1 : index) -> tensor<8x4xf32> {
1921   //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
1922   //  CHECK-SAME:                     [16, 4] [%[[SZ0]], %[[SZ1]]]
1923   //       CHECK:   %[[PAD:.*]] = tensor.pad %[[T0]] nofold
1924   //  CHECK-SAME:                     high[%[[PW0]], %[[PW1]]]
1925   //       CHECK:   return %[[PAD]]
1926   %pad_value = arith.constant 0.0 : f32
1927   %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
1928   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
1929     ^bb0(%arg1: index, %arg2: index):
1930       tensor.yield %pad_value : f32
1931     } : tensor<?x64xf32> to tensor<8x64xf32>
1932   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1933   %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
1934     ^bb0(%arg1: index, %arg2: index):
1935       tensor.yield %pad_value : f32
1936     } : tensor<8x?xf32> to tensor<8x4xf32>
1937   func.return %3 : tensor<8x4xf32>
1940 // -----
1942 // CHECK-LABEL: func @dont_fold_pad_chains(
1943 //  CHECK-SAME:   %[[ARG0:.*]]: tensor<64x64xf32>,
1944 //  CHECK-SAME:   %[[SZ0:.*]]: index, %[[SZ1:.*]]: index, %[[PW0:.*]]: index, %[[PW1:.*]]: index
1945 func.func @dont_fold_pad_chains(%arg0: tensor<64x64xf32>,
1946                                 %sz0 : index, %sz1 : index,
1947                                 %pw0 : index, %pw1 : index) -> (tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>) {
1948   //       CHECK:   %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
1949   //       CHECK:   %[[T1:.*]] = tensor.pad %[[T0]]
1950   %pad_value = arith.constant 0.0 : f32
1951   %0 = tensor.extract_slice %arg0[16, 0] [%sz0, 64] [1, 1] : tensor<64x64xf32> to tensor<?x64xf32>
1952   %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] {
1953     ^bb0(%arg1: index, %arg2: index):
1954       tensor.yield %pad_value : f32
1955     } : tensor<?x64xf32> to tensor<8x64xf32>
1957   // Don't fold if the padding values are different.
1958   //       CHECK:   %[[T2:.*]] = tensor.extract_slice %[[T1]]
1959   //  CHECK-SAME:                     [0, 4] [8, %[[SZ1]]]
1960   //       CHECK:   %[[PAD0:.*]] = tensor.pad %[[T2]]
1961   %different_value = arith.constant 1.0 : f32
1962   %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1963   %3 = tensor.pad %2 nofold low[0, 0] high[0, %pw1] {
1964     ^bb0(%arg1: index, %arg2: index):
1965       tensor.yield %different_value : f32
1966     } : tensor<8x?xf32> to tensor<8x4xf32>
1968   // Don't fold if the pad ops have common padding dimensions.
1969   //       CHECK:   %[[T3:.*]] = tensor.extract_slice %[[T1]]
1970   //  CHECK-SAME:                     [4, 0] [%[[SZ1]], 64]
1971   //       CHECK:   %[[PAD1:.*]] = tensor.pad %[[T3]]
1972   %4 = tensor.extract_slice %1[4, 0] [%sz1, 64] [1, 1] : tensor<8x64xf32> to tensor<?x64xf32>
1973   %5 = tensor.pad %4 nofold low[0, 0] high[%pw1, 0] {
1974     ^bb0(%arg1: index, %arg2: index):
1975       tensor.yield %pad_value : f32
1976     } : tensor<?x64xf32> to tensor<4x64xf32>
1978   // Don't fold if padded source tensor dimension is accessed at an offset.
1979   //       CHECK:   %[[T4:.*]] = tensor.extract_slice %[[T1]]
1980   //  CHECK-SAME:                     [%[[SZ0]], 4] [8, %[[SZ1]]
1981   //       CHECK:   %[[PAD2:.*]] = tensor.pad %[[T4]]
1982   %6 = tensor.extract_slice %1[%sz0, 4] [8, %sz1] [1, 1] : tensor<8x64xf32> to tensor<8x?xf32>
1983   %7 = tensor.pad %6 nofold low[0, 0] high[0, %pw1] {
1984     ^bb0(%arg1: index, %arg2: index):
1985       tensor.yield %pad_value : f32
1986     } : tensor<8x?xf32> to tensor<8x4xf32>
1988   // Don't fold if a padded source tensor dimension is sliced.
1989   //       CHECK:   %[[T5:.*]] = tensor.extract_slice %[[T1]]
1990   //  CHECK-SAME:                     [0, 4] [6, %[[SZ1]]
1991   //       CHECK:   %[[PAD3:.*]] = tensor.pad %[[T5]]
1992   %8 = tensor.extract_slice %1[0, 4] [6, %sz1] [1, 1] : tensor<8x64xf32> to tensor<6x?xf32>
1993   %9 = tensor.pad %8 nofold low[0, 0] high[0, %pw1] {
1994     ^bb0(%arg1: index, %arg2: index):
1995       tensor.yield %pad_value : f32
1996     } : tensor<6x?xf32> to tensor<6x4xf32>
1998   //       CHECK:   return %[[PAD0]], %[[PAD1]], %[[PAD2]], %[[PAD3]]
1999   func.return %3, %5, %7, %9 : tensor<8x4xf32>, tensor<4x64xf32>, tensor<8x4xf32>, tensor<6x4xf32>
2002 // -----
2004 // CHECK-LABEL: func @merge_constant_padding
2005 //  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<2x3xf32>
2006 //  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
2007 //       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[1, 3] high[4, 2]
2008 //       CHECK:     tensor.yield %[[PADVAL]]
2009 //       CHECK:   return %[[PAD]]
2010 func.func @merge_constant_padding(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
2011   %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
2012     ^bb0(%b0: index, %b1 : index):
2013       tensor.yield %pad_value : f32
2014     } : tensor<2x3xf32> to tensor<4x4xf32>
2015   %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2016     ^bb0(%b2: index, %b3 : index):
2017       tensor.yield %pad_value : f32
2018     } : tensor<4x4xf32> to tensor<7x8xf32>
2019   return %pad1 : tensor<7x8xf32>
2022 // -----
2024 //       CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1)>
2025 // CHECK-LABEL: func @merge_constant_padding_dynamic
2026 //  CHECK-SAME:   %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?xf32>
2027 //  CHECK-SAME:   %[[IDX:[A-Za-z0-9]+]]: index
2028 //  CHECK-SAME:   %[[PADVAL:[A-Za-z0-9]+]]: f32
2029 //       CHECK:   %[[HIGH:.+]] = affine.apply #[[$MAP]]()[%[[IDX]]]
2030 //       CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]] low[%[[IDX]], 3] high[%[[HIGH]], 2]
2031 //       CHECK:     tensor.yield %[[PADVAL]]
2032 //       CHECK:   return %[[PAD]]
2033 func.func @merge_constant_padding_dynamic(%arg0: tensor<?x?xf32>, %idx: index, %pad_value: f32) -> tensor<?x?xf32> {
2034   %pad0 = tensor.pad %arg0 low[%idx, 1] high[1, 0] {
2035     ^bb0(%b0: index, %b1 : index):
2036       tensor.yield %pad_value : f32
2037     } : tensor<?x?xf32> to tensor<?x?xf32>
2038   %pad1 = tensor.pad %pad0 low[0, 2] high[%idx, 2] {
2039     ^bb0(%b2: index, %b3 : index):
2040       tensor.yield %pad_value : f32
2041     } : tensor<?x?xf32> to tensor<?x?xf32>
2042   return %pad1 : tensor<?x?xf32>
2045 // -----
2047 // Verify that folding does not happen if it would drop a nofold attribute
2048 // CHECK-LABEL: func @dont_merge_constant_padding_nofold
2049 //       CHECK:   tensor.pad {{.*}} nofold
2050 //       CHECK:   tensor.pad
2051 func.func @dont_merge_constant_padding_nofold(%arg0: tensor<2x3xf32>, %pad_value: f32) -> tensor<7x8xf32> {
2052   %pad0 = tensor.pad %arg0 nofold low[1, 1] high[1, 0] {
2053     ^bb0(%b0: index, %b1 : index):
2054       tensor.yield %pad_value : f32
2055     } : tensor<2x3xf32> to tensor<4x4xf32>
2056   %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2057     ^bb0(%b2: index, %b3 : index):
2058       tensor.yield %pad_value : f32
2059     } : tensor<4x4xf32> to tensor<7x8xf32>
2060   return %pad1 : tensor<7x8xf32>
2063 // -----
2065 // Verify that folding does not happen if it would drop a nofold attribute
2066 // CHECK-LABEL: func @dont_merge_constant_padding_different_vals
2067 //       CHECK:   tensor.pad
2068 //       CHECK:   tensor.pad
2069 func.func @dont_merge_constant_padding_different_vals(
2070     %arg0: tensor<2x3xf32>,
2071     %pad_value0: f32,
2072     %pad_value1: f32) -> tensor<7x8xf32> {
2073   %pad0 = tensor.pad %arg0 low[1, 1] high[1, 0] {
2074     ^bb0(%b0: index, %b1 : index):
2075       tensor.yield %pad_value0 : f32
2076     } : tensor<2x3xf32> to tensor<4x4xf32>
2077   %pad1 = tensor.pad %pad0 low[0, 2] high[3, 2] {
2078     ^bb0(%b2: index, %b3 : index):
2079       tensor.yield %pad_value1 : f32
2080     } : tensor<4x4xf32> to tensor<7x8xf32>
2081   return %pad1 : tensor<7x8xf32>
2084 // -----
2086 // CHECK-LABEL: func @fold_collapse_shape_from_elements
2087 func.func @fold_collapse_shape_from_elements(%arg0: i32) -> tensor<i32> {
2088   // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<i32>
2089   // CHECK: return %[[FROM]] : tensor<i32>
2090   %0 = tensor.from_elements %arg0 : tensor<1xi32>
2091   %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
2092   return %1 : tensor<i32>
2095 // -----
2097 // CHECK-LABEL: func @fold_expand_shape_from_elements
2098 func.func @fold_expand_shape_from_elements(%arg0: i32) -> tensor<1xi32> {
2099   // CHECK: %[[FROM:.+]] = tensor.from_elements %arg0 : tensor<1xi32>
2100   // CHECK: return %[[FROM]] : tensor<1xi32>
2101   %0 = tensor.from_elements %arg0 : tensor<i32>
2102   %1 = tensor.expand_shape %0 [] output_shape [1] : tensor<i32> into tensor<1xi32>
2103   return %1 : tensor<1xi32>
2106 // -----
2108 // CHECK-LABEL: func @propagate_index_cast
2109 func.func @propagate_index_cast(%arg0: tensor<1xi32>) -> index {
2110   // CHECK: %[[IDX:.+]] = arith.constant 0
2111   // CHECK: %[[EXT:.+]] = tensor.extract %arg0[%[[IDX]]] : tensor<1xi32>
2112   // CHECK: %[[CAST:.+]] = arith.index_cast %[[EXT]]
2113   // CHECK: return %[[CAST]] : index
2114   %c0 = arith.constant 0 : index
2115   %0 = arith.index_cast %arg0 : tensor<1xi32> to tensor<1xindex>
2116   %1 = tensor.extract %0[%c0] : tensor<1xindex>
2117   return %1 : index
2120 // -----
2122 // CHECK-LABEL: func @splat_fold
2123 func.func @splat_fold() -> tensor<4xf32> {
2124   %c = arith.constant 1.0 : f32
2125   %t = tensor.splat %c : tensor<4xf32>
2126   return %t : tensor<4xf32>
2128   // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
2129   // CHECK-NEXT: return [[T]] : tensor<4xf32>
2132 // -----
2134 // CHECK-LABEL: func @splat_dynamic_no_fold
2135 // CHECK-SAME: %[[M:.+]]: index
2136 func.func @splat_dynamic_no_fold(%m: index) -> tensor<4x?xf32> {
2137   // CHECK: %[[F:.+]] = arith.constant
2138   %f = arith.constant 1.0 : f32
2140   // CHECK: tensor.splat %[[F]][%[[M]]] : tensor<4x?xf32>
2141   %t = tensor.splat %f[%m] : tensor<4x?xf32>
2142   return %t : tensor<4x?xf32>
2145 // -----
2147 // CHECK-LABEL: func @cast_extract_slice
2148 func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
2149     -> tensor<16x512xf32> {
2150 // CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32>
2151   %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32>
2152   %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
2153 // CHECK: return %[[E]] : tensor<16x512xf32>
2154   return %1 : tensor<16x512xf32>
2157 // -----
2159 // CHECK-LABEL: func @cast_extract_slice_rank_reduce
2160 func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
2161     -> tensor<16xf32> {
2162 // CHECK: %[[E:.*]]  = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32>
2163   %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32>
2164   %1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
2165 // CHECK: return %[[E]] : tensor<16xf32>
2166   return %1 : tensor<16xf32>
2169 // -----
2171 // CHECK-LABEL: func.func @canonicalize_parallel_insert_slice_indices(
2172 //  CHECK-SAME:     %[[arg0:[0-9a-z]*]]: tensor<1x5xf32>,
2173 //  CHECK-SAME:     %[[arg1:[0-9a-z]*]]: tensor<?x?xf32>,
2174 //  CHECK-SAME:     %[[num_threads:[0-9a-z]*]]: index
2175 func.func @canonicalize_parallel_insert_slice_indices(
2176     %arg0 : tensor<1x5xf32>, %arg1: tensor<?x?xf32>,
2177     %num_threads : index) -> tensor<?x?xf32>
2179   %cst = arith.constant 4.200000e+01 : f32
2180   %c0 = arith.constant 0 : index
2181   %c1 = arith.constant 1 : index
2183   //  CHECK-NOT: tensor.cast
2184   //      CHECK: scf.forall (%[[tidx:[0-9a-z]*]]) in (%[[num_threads]]) shared_outs(%[[o:.*]] = %[[arg1]]) -> (tensor<?x?xf32>) {
2185   // CHECK-NEXT:   scf.forall.in_parallel {
2186   // CHECK-NEXT:     tensor.parallel_insert_slice %[[arg0]] into %[[o]][%[[tidx]], 0] [1, 5] [1, 1]
2187   %2 = scf.forall (%tidx) in (%num_threads) shared_outs(%o = %arg1) -> (tensor<?x?xf32>) {
2188     %3 = tensor.cast %arg0 : tensor<1x5xf32> to tensor<?x5xf32>
2189     scf.forall.in_parallel {
2190       tensor.parallel_insert_slice %3 into %o[%tidx, %c0] [%c1, 5] [%c1, %c1] : tensor<?x5xf32> into tensor<?x?xf32>
2191     }
2192   }
2193   return %2 : tensor<?x?xf32>
2196 // -----
2198 // CHECK-LABEL: func.func @fold_insert_slice_after_extract_slice
2199 //  CHECK-SAME: (%[[INPUT:.+]]: tensor<1x2x2x4xf32>)
2200 func.func @fold_insert_slice_after_extract_slice(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2201   %c0 = arith.constant 0 : index
2202   %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2203   %1 = tensor.insert_slice %0 into %input[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2204   // CHECK: return %[[INPUT]]
2205   return %1: tensor<1x2x2x4xf32>
2208 // -----
2210 // CHECK-LABEL: func.func @dont_fold_mismatched_source_dst
2211 func.func @dont_fold_mismatched_source_dst(%input0: tensor<1x2x2x4xf32>, %input1: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2212   %c0 = arith.constant 0 : index
2213   // CHECK: tensor.extract_slice
2214   %0 = tensor.extract_slice %input0[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2215   // CHECK: tensor.insert_slice
2216   %1 = tensor.insert_slice %0 into %input1[%c0, 0, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2217   return %1: tensor<1x2x2x4xf32>
2220 // -----
2222 // CHECK-LABEL: func.func @dont_fold_mismatched_parameters
2223 func.func @dont_fold_mismatched_parameters(%input: tensor<1x2x2x4xf32>) -> tensor<1x2x2x4xf32> {
2224   %c0 = arith.constant 0 : index
2225   // CHECK: tensor.extract_slice
2226   %0 = tensor.extract_slice %input[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x2x4xf32> to tensor<1x2x4xf32>
2227   // CHECK: tensor.insert_slice
2228   %1 = tensor.insert_slice %0 into %input[%c0, 1, %c0, 0] [1, 1, 2, 4] [1, 1, 1, 1] : tensor<1x2x4xf32> into tensor<1x2x2x4xf32>
2229   return %1: tensor<1x2x2x4xf32>
2232 // -----
2234 func.func @empty_canonicalize() -> (tensor<4x5x?xf32>) {
2235   %c6 = arith.constant 6 : index
2236   %0 = tensor.empty(%c6) : tensor<4x5x?xf32>
2237   return %0 : tensor<4x5x?xf32>
2239 // CHECK: func @empty_canonicalize
2240 // CHECK:   %[[T0:.+]] = tensor.empty() : tensor<4x5x6xf32>
2241 // CHECK:   %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
2242 // CHECK:   return %[[T1]]
2244 // -----
2246 func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
2247   %0 = tensor.empty(%arg0) : tensor<?x12xf32>
2248   %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
2249   return %1 : tensor<1x12xf32>
2251 //      CHECK: func @fold_empty_tensor_with_cast(%[[ARG0:.+]]: index)
2252 //      CHECK:   %[[T0:.+]] = tensor.empty() : tensor<1x12xf32>
2253 //      CHECK:   return %[[T0]] : tensor<1x12xf32>
2255 // -----
2257 func.func private @some_use(%i : index, %j : index)
2259 // CHECK-LABEL: func @empty_tensor_canonicalize
2260 //  CHECK-SAME:   %[[I:.*]]: index
2261 func.func @empty_tensor_canonicalize(%i : index) {
2262   %c0 = arith.constant 0 : index
2263   %c1 = arith.constant 1 : index
2265   // CHECK-NOT: tensor.empty
2266   %0 = tensor.empty(%i) : tensor<?x42xf32>
2268   // CHECK-NOT: tensor.dim
2269   %1 = tensor.dim %0, %c0: tensor<?x42xf32>
2270   %2 = tensor.dim %0, %c1: tensor<?x42xf32>
2272   // CHECK: %[[c42:.*]] = arith.constant 42 : index
2273   // CHECK: call @some_use(%[[I]], %[[c42]])
2274   call @some_use(%1, %2) : (index, index) -> ()
2276   return
2279 // -----
2281 //       CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
2282 // CHECK-LABEL: func @dim_of_expand_shape(
2283 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
2284 //       CHECK:   %[[c1:.*]] = arith.constant 1 : index
2285 //       CHECK:   %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
2286 //       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
2287 //       CHECK:   return %[[apply]]
2288 func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
2289   %c2 = arith.constant 2 : index
2290   %0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
2291       : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
2292   %1 = tensor.dim %0, %c2 : tensor<?x1x?x5x1x8xf32>
2293   return %1 : index
2296 // -----
2298 //       CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
2299 // CHECK-LABEL: func @dim_of_collapse_shape(
2300 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x7x?xf32>
2301 //   CHECK-DAG:   %[[c1:.*]] = arith.constant 1 : index
2302 //   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : index
2303 //   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
2304 //   CHECK-DAG:   %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
2305 //   CHECK-DAG:   %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
2306 //   CHECK-DAG:   %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
2307 //       CHECK:   %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
2308 //       CHECK:   return %[[apply]]
2309 func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
2310   %c1 = arith.constant 1 : index
2311   %0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]
2312       : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
2313   %1 = tensor.dim %0, %c1 : tensor<?x?xf32>
2314   return %1 : index
2317 // -----
2319 // CHECK-LABEL: func @collapse_expand_fold_to_cast(
2320 //  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
2321 //       CHECK:   return %[[t]]
2322 func.func @collapse_expand_fold_to_cast(%t: tensor<?xf32>, %sz0: index) -> (tensor<?xf32>)
2324   %0 = tensor.expand_shape %t [[0, 1]] output_shape [1, %sz0] : tensor<?xf32> into tensor<1x?xf32>
2325   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
2326   return %1 : tensor<?xf32>
2329 // -----
2331 // Chain: NC -> NCnc -> NCnc -> NC
2332 // CHECK: func.func @unpack_pack(
2333 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2334 // CHECK: return %[[T]] : tensor<128x128xf32>
2335 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2336   %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2337   %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2338   %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2339   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2340   return %unpacked : tensor<128x128xf32>
2343 // -----
2345 // Chain: NC -> NCcn -> NCnc -> NC
2346 // CHECK: func.func @unpack_pack(
2347 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2348 // CHECK-NOT: return %[[T]] : tensor<128x128xf32>
2349 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2350   %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2351   %packed = tensor.pack %t inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2352   %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2353   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
2354 <128x128xf32>
2355   return %unpacked : tensor<128x128xf32>
2358 // -----
2360 // Chain: NC -> CNcn -> NCnc -> NC
2361 // CHECK: func.func @unpack_pack(
2362 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>)
2363 // CHECK-NOT: return %[[T]] : tensor<128x128xf32>
2364 func.func @unpack_pack(%t: tensor<128x128xf32>) -> tensor<128x128xf32> {
2365   %tensor_empty = tensor.empty() : tensor<16x16x8x8xf32>
2366   %packed = tensor.pack %t outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [8, 8] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2367   %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2368   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<16x16x8x8xf32> -> tensor
2369 <128x128xf32>
2370   return %unpacked : tensor<128x128xf32>
2373 // -----
2375 // Chain: NC -> NCnc -> NCnc -> NC
2376 // CHECK: func.func @unpack_pack(
2377 // CHECK-SAME: %[[T:.+]]: tensor<128x128xf32>,
2378 // CHECK: return %[[T]] : tensor<128x128xf32>
2379 func.func @unpack_pack(%t: tensor<128x128xf32>, %tile1: index, %tile2: index) -> tensor<128x128xf32> {
2380   %tensor_empty = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2381   %packed = tensor.pack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2382   %tensor_empty1 = tensor.empty() : tensor<128x128xf32>
2383   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<16x16x?x?xf32> -> tensor
2384 <128x128xf32>
2385   return %unpacked : tensor<128x128xf32>
2388 // -----
2390 // CHECK: func.func @unpack_pack_with_padding_no_canonicalization(
2391 // CHECK:         tensor.pack
2392 // CHECK:         tensor.unpack
2393 func.func @unpack_pack_with_padding_no_canonicalization(%t: tensor<256x512xbf16>) -> tensor<224x512xbf16> {
2394   %tensor_empty = tensor.empty() : tensor<4x16x64x32xbf16>
2395   %tensor_empty1 = tensor.empty() : tensor<224x512xbf16>
2396   %packed = tensor.pack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty : tensor<256x512xbf16> -> tensor<4x16x64x32xbf16>
2397   %unpacked = tensor.unpack %packed inner_dims_pos = [0, 1] inner_tiles = [64, 32] into %tensor_empty1 : tensor<4x16x64x32xbf16> -> tensor<224x512xbf16>
2398   return %unpacked : tensor<224x512xbf16>
2401 // -----
2403 // Chain NCnc -> NC -> NC -> NCnc
2404 // CHECK: func.func @pack_unpack(
2405 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2406 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2407 func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2408   %tensor_empty = tensor.empty() : tensor<128x128xf32>
2409   %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2410   %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2411   %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2412   return %packed : tensor<16x16x?x?xf32>
2415 // -----
2417 // Chain NCnc -> NC -> NC -> NCnc
2418 // CHECK: func.func @pack_unpack(
2419 // CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
2420 // CHECK: return %[[T]] : tensor<16x16x8x8xf32>
2421 func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
2422   %tensor_empty = tensor.empty() : tensor<128x128xf32>
2423   %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
2424   %tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
2425   %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
2426   return %packed : tensor<16x16x8x8xf32>
2429 // -----
2431 // CHECK: func.func @pack_unpack_same_tiles(
2432 // CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
2433 // CHECK: return %[[T]] : tensor<?x?x?x?xf32>
2434 func.func @pack_unpack_same_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2435                        %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
2436   %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2437   %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2438   %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2439   %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2440   return %packed : tensor<?x?x?x?xf32>
2443 // -----
2445 // CHECK: func.func @pack_unpack_different_tiles(
2446 // CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
2447 // CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
2448 func.func @pack_unpack_different_tiles(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2449                        %tile1: index, %tile2: index) -> tensor<?x?x?x?xf32> {
2450   %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2451   %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2452   %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2453   %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile2, %tile1] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2454   return %packed : tensor<?x?x?x?xf32>
2457 // -----
2459 // CHECK: func.func @pack_unpack_dynamic_with_padding(
2460 // CHECK-SAME:  %[[T:.+]]: tensor<?x?x?x?xf32>,
2461 // CHECK-NOT: return %[[T]] : tensor<?x?x?x?xf32>
2462 func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: index, %dim2: index, %dim3: index, %dim4: index, %dim5: index, %dim6: index,
2463                        %tile1: index, %tile2: index, %pad: f32) -> tensor<?x?x?x?xf32> {
2464   %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2465   %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
2466   %tensor_empty1 = tensor.empty(%dim3, %dim4, %dim5, %dim6) : tensor<?x?x?x?xf32>
2467   %packed = tensor.pack %unpacked padding_value(%pad: f32) inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<?x?xf32> -> tensor<?x?x?x?xf32>
2468   return %packed : tensor<?x?x?x?xf32>
2471 // -----
2473 // CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
2474 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2475 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2476 func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2477   %tensor_empty = tensor.empty() : tensor<128x128xf32>
2478   %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2479   %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2480   %packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2481   return %packed : tensor<16x16x?x?xf32>
2484 // -----
2486 // CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
2487 // CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
2488 // CHECK: return %[[T]] : tensor<16x16x?x?xf32>
2489 func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
2490   %tensor_empty = tensor.empty() : tensor<128x128xf32>
2491   %unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
2492   %tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
2493   %packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
2494   return %packed : tensor<16x16x?x?xf32>
2497 // -----
2499 // CHECK: func.func @invalid_empty_negative_size
2500 // CHECK: %[[IDX:.*]] = index.constant
2501 // CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
2502 func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
2503   %c1 = arith.constant 1 : index
2504   %cn2 = arith.constant 2 : index
2505   %0 = index.sub %c1, %cn2
2506   %1 = tensor.empty(%0) : tensor<4x5x?xf32>
2507   return %1 : tensor<4x5x?xf32>
2510 // -----
2512 // Fold DstStyleOp -> tensor.unpack operations.
2513 func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
2514   %cst = arith.constant 0.0 : f32
2515   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
2516   %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
2517   return %unpack : tensor<?x?xf32>
2519 // CHECK-LABEL: func @fold_dst_style_ops_into_unpack
2520 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x16x64xf32>
2521 //  CHECK-SAME:     %[[INIT:.+]]: tensor<?x?xf32>
2522 //       CHECK:   %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
2523 //  CHECK-SAME:       into %[[INIT]]
2524 //       CHECK:   return %[[UNPACK]]
2526 // -----
2528 // The IR in this test case in invalid. This test tests that the canonicalizer
2529 // does not crash.
2531 // CHECK-LABEL: func @invalid_slice_ops(
2532 //       CHECK:   %[[c:.*]] = arith.constant -5 : index
2533 //       CHECK:   tensor.extract_slice {{.*}}%[[c]]
2534 //       CHECK:   tensor.insert_slice {{.*}}%[[c]]
2535 func.func @invalid_slice_ops(%t: tensor<?xf32>, %t2: tensor<?xf32>) -> tensor<?xf32> {
2536   %c = arith.constant -5 : index
2537   %0 = tensor.extract_slice %t[0][%c][1] : tensor<?xf32> to tensor<?xf32>
2538   %1 = tensor.insert_slice %0 into %t2[2][%c][1] : tensor<?xf32> into tensor<?xf32>
2539   return %1 : tensor<?xf32>
2542 // -----
2544 // CHECK-LABEL: func @generate_negative_size_verifies(
2545 //       CHECK:   %[[c:.*]] = arith.constant -8 : index
2546 //       CHECK:   tensor.generate %[[c]]
2547 //       CHECK:   : tensor<?x8xi32>
2548 func.func @generate_negative_size_verifies() -> tensor<?x8xi32> {
2549   %cst = arith.constant 0 : i32
2550   %c0 = arith.constant 0 : index
2551   %size = affine.max affine_map<(d0) -> (d0 mod 64 - 8)>(%c0)
2552   %tensor = tensor.generate %size {
2553   ^bb0(%arg0: index, %arg1: index):
2554     tensor.yield %cst : i32
2555   } : tensor<?x8xi32>
2556   return %tensor : tensor<?x8xi32>
2559 // -----
2561 func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> tensor<10x20x4x4xf32> {
2562   %dim1 = arith.constant 40 : index
2563   %dim2 = arith.constant 80 : index
2564   %tensor_empty = tensor.empty(%dim1, %dim2) : tensor<?x?xf32>
2565   %unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty : tensor<10x20x4x4xf32> -> tensor<?x?xf32>
2566   %cast = tensor.cast %unpacked : tensor<?x?xf32> to tensor<40x80xf32>
2567   %tensor_empty1 = tensor.empty() : tensor<10x20x4x4xf32>
2568   %packed = tensor.pack %cast inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %tensor_empty1 : tensor<40x80xf32> -> tensor<10x20x4x4xf32>
2569   return %packed : tensor<10x20x4x4xf32>
2571 // CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
2572 // CHECK-SAME:    %[[SRC:[0-9a-zA-Z]+]]
2573 // CHECK:         return %[[SRC]]
2575 // -----
2577 // Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2578 // CHECK-LABEL: func @dim_of_reshape(
2579 //  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
2580 //  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: tensor<?xindex>
2581 //  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
2582 //  CHECK-NEXT:   %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
2583 //   CHECK-NOT:   tensor.store
2584 //   CHECK-NOT:   tensor.dim
2585 //   CHECK-NOT: tensor.reshape
2586 //       CHECK:   return %[[DIM]] : index
2587 func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
2588     -> index {
2589   %c3 = arith.constant 3 : index
2590   %0 = tensor.reshape %arg0(%arg1)
2591       : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2592   // Update the shape to test that the load ends up in the right place.
2593   tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
2594   %1 = tensor.dim %0, %c3 : tensor<*xf32>
2595   return %1 : index
2598 // -----
2600 // Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
2601 // CHECK-LABEL: func @dim_of_reshape_i32(
2602 //       CHECK:  tensor.extract
2603 //  CHECK-NEXT:  %[[CAST:.*]] = arith.index_cast
2604 //   CHECK-NOT:  tensor.dim
2605 //   CHECK-NOT:  tensor.reshape
2606 //       CHECK:  return %[[CAST]] : index
2607 func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
2608     -> index {
2609     %c3 = arith.constant 3 : index
2610     %0 = tensor.reshape %arg0(%arg1)
2611         : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
2612     %1 = tensor.dim %0, %c3 : tensor<*xf32>
2613     return %1 : index
2616 // -----
2618 // Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2619 // CHECK-LABEL: func @dim_of_reshape_for(
2620 //       CHECK: scf.for
2621 //  CHECK-NEXT: tensor.extract
2622 //   CHECK-NOT: tensor.dim
2623 //   CHECK-NOT: tensor.reshape
2624 func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
2625     %c0 = arith.constant 0 : index
2626     %c1 = arith.constant 1 : index
2627     %c4 = arith.constant 4 : index
2629     %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2631     %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
2632       %2 = tensor.dim %0, %arg2 : tensor<*xf32>
2633       %3 = arith.muli %arg3, %2 : index
2634       scf.yield %3 : index
2635     }
2636     return %1 : index
2639 // -----
2641 // Test case: tensor.dim(tensor.reshape %v %shp, %idx) is folded into tensor.extract %shp[%idx]
2642 // CHECK-LABEL: func @dim_of_reshape_undominated(
2643 //       CHECK: arith.muli
2644 //  CHECK-NEXT: tensor.extract
2645 //   CHECK-NOT: tensor.dim
2646 //   CHECK-NOT: tensor.reshape
2647 func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
2648     %c4 = arith.constant 4 : index
2649     %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
2650     %0 = arith.muli %arg2, %c4 : index
2651     %dim = tensor.dim %reshape, %0 : tensor<*xf32>
2652     return %dim : index
2653   }
2655 // -----
2657 // CHECK-LABEL: @reshape_fold_2d
2658 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2659 func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
2660   %c0 = arith.constant 0 : index
2661   %c1 = arith.constant 1 : index
2662   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
2663   %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
2664   %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
2665   %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2666   // CHECK: return %[[ARG0]]
2667   return %reshape : tensor<?x?xi32>
2670 // -----
2672 // CHECK-LABEL: @reshape_nofold_2d
2673 // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2674 func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
2675   %c0 = arith.constant 0 : index
2676   %c1 = arith.constant 1 : index
2677   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
2678   %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
2679   %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
2680   // CHECK: tensor.reshape
2681   %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2682   return %reshape : tensor<?x?xi32>
2685 // -----
2687 // CHECK-LABEL: @reshape_nofold_2d_ins
2688 func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
2689   %ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
2690   // CHECK: tensor.reshape
2691   %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2692   return %reshape : tensor<?x?xi32>
2695 // -----
2697 // CHECK-LABEL: @reshape_fold_3d_cst
2698 // CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
2699 func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
2700   %c1 = arith.constant 1 : index
2701   %c2 = arith.constant 2 : index
2702   %d0 = arith.constant 5 : index
2703   %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
2704   %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
2705   %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
2706   %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
2707   // CHECK: return %[[ARG0]]
2708   return %reshape : tensor<5x?x?xi32>
2711 // -----
2713 // Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
2714 // CHECK-LABEL: func @dim_out_of_bounds(
2715 //       CHECK: %[[IDX:.*]] = index.constant 28
2716 //  CHECK-NEXT: bufferization.alloc_tensor
2717 //  CHECK-NEXT: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[IDX]]
2718 //  CHECK-NEXT: memref.alloc
2719 //  CHECK-NEXT: memref.cast
2720 //  CHECK-NEXT: affine.vector_load %{{.*}}[{{.*}}, {{.*}}, symbol(%[[DIM]])]
2721 //  CHECK-NEXT: return
2722 func.func @dim_out_of_bounds() -> vector<7xi32> {
2723     %c1 = arith.constant 1 : index
2724     %idx28 = index.constant 28
2725     %c29 = arith.constant 29 : index
2726     %3 = bufferization.alloc_tensor(%c29) : tensor<?xi16>
2727     %dim = tensor.dim %3, %idx28 : tensor<?xi16>
2728     %alloc_21 = memref.alloc(%c29) : memref<?x26x2xi32>
2729     %16 = affine.vector_load %alloc_21[%c1, %c1, %dim] : memref<?x26x2xi32>, vector<7xi32>
2730     return %16 : vector<7xi32>
2733 // -----
2735 // CHECK-LABEL:   func.func @fold_cast_multiple_results(
2736 // CHECK-SAME:         %[[ARG1:.*]]: tensor<2x2xf32>,
2737 // CHECK-SAME:         %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
2738 // CHECK:           %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
2739 // CHECK-SAME:      outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
2740 // CHECK:           return %[[RES]]#1 : index
2741 func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
2742   %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
2743   %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
2744   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
2745   return %0#1 : index
2747 // -----
2749 // CHECK-LABEL:   func.func @fold_cast_pack_dynamic_tile_size
2750 // CHECK-SAME:      %[[DEST:.*]]: tensor<1x1x8x1xi32>,
2751 // CHECK-SAME:      %[[SRC:.*]]: tensor<7x?xi32>,
2752 // CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
2753 // CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
2754 // CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
2755 // CHECK-SAME:        some_attr
2756 // CHECK-SAME:        : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
2757 // CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
2758 func.func @fold_cast_pack_dynamic_tile_size(
2759   %dest: tensor<1x1x8x1xi32>,
2760   %src: tensor<7x?xi32>,
2761   %pad: i32) -> tensor<1x1x8x1xi32> {
2763     %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
2764     %c8 = arith.constant 8 : index
2765     %pack = tensor.pack %src padding_value(%pad : i32)
2766       inner_dims_pos = [0, 1]
2767       inner_tiles = [%c8, 1]
2768       into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
2769     %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
2770     return %res : tensor<1x1x8x1xi32>
2773 // -----
2775 // CHECK-LABEL:   func.func @pack_dont_drop_attributes(
2776 // CHECK: tensor.pack {{.*}}  {test_attr}
2777 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
2778   %c32_i64 = arith.constant 32 : i64
2779   %cst = arith.constant 0.000000e+00 : f16
2780   %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
2781   return %pack : tensor<128x?x100x16x1xf16>
2784 // -----
2786 func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
2787     -> tensor<10x1x10xf32> {
2788   %c1 = arith.constant 1 : index 
2789   %c10 = arith.constant 10 : index 
2790   %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
2791   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
2792       : tensor<?x?xf32> into tensor<?x?x?xf32>
2793   %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
2794   return %2 : tensor<10x1x10xf32>
2796 // CHECK-LABEL:  func.func @fold_expand_of_cast
2797 //       CHECK:   %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
2798 //       CHECK:   return %[[RES]]
2800 // -----
2802 func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
2803     -> tensor<?x?x?xf32> {
2804   %c1 = arith.constant 1 : index
2805   %c10 = arith.constant 10 : index
2806   %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
2807   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
2808       : tensor<?x?xf32> into tensor<?x?x?xf32>
2809   return %1 : tensor<?x?x?xf32>
2811 // CHECK-LABEL:  func.func @sink_expand_of_cast
2812 //   CHECK-DAG:   %[[C10:.*]] = arith.constant 10
2813 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1
2814 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] 
2815 //  CHECK-SAME:     output_shape [%[[C10]], %[[C1]], 10]
2816 //       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
2817 //       CHECK:   return %[[RES]]
2819 // -----
2821 func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
2822     -> tensor<?x?x?xf32> {
2823   %c10 = arith.constant 10 : index
2824   %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
2825   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
2826       : tensor<?x?xf32> into tensor<?x?x?xf32>
2827   return %1 : tensor<?x?x?xf32>
2829 // CHECK-LABEL:  func.func @partial_sink_expand_of_cast
2830 //       CHECK:   %[[CAST:.+]] = tensor.cast
2831 //  CHECK-SAME:     tensor<10x10xf32> to tensor<?x10xf32>
2832 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] 
2833 //  CHECK-SAME:     output_shape [%{{.*}}, %{{.*}}, 10]
2834 //       CHECK:   %[[RES:.+]] = tensor.cast %[[EXPAND]]
2835 //  CHECK-SAME:     tensor<?x?x10xf32> to tensor<?x?x?xf32>
2836 //       CHECK:   return %[[RES]]