1 // RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-consecutive-insert-extract-slice -canonicalize -mlir-print-local-scope %s | FileCheck %s
// Test case: two chained tensor.extract_slice ops of the same rank, both with
// unit strides. The first slice is taken from %src, the second from the result
// of the first.
// The expectations after the function require a single extract_slice taken
// directly from the source argument: the static offsets are added per
// dimension (0+7, 1+8, 2+9), the two dynamic offsets are summed by an
// affine.apply, and the sizes are those of the second slice.
3 func.func @extract_slice_same_rank(
4 %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x16x32x?xf32> {
5 %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
6 %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [8, 16, 32, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<8x16x32x?xf32>
7 return %1: tensor<8x16x32x?xf32>
// Expected output: %size0 is matched but not captured because only %size1
// appears in the folded slice.
10 // CHECK-LABEL: func.func @extract_slice_same_rank
11 // CHECK-SAME: (%[[SOURCE:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
12 // CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
13 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][7, 9, 11, %[[OFFSET]]] [8, 16, 32, %[[SIZE1]]] [1, 1, 1, 1]
14 // CHECK: return %[[EXTRACT]] : tensor<8x16x32x?xf32>
// Test case: the second (consumer) extract_slice is rank-reducing. Its sizes
// [1, 16, 1, %size1] on a 4-D input produce the 2-D type tensor<16x?xf32>.
// The expectation requires one extract_slice that goes straight from the 4-D
// source type to tensor<16x?xf32>, with the static offsets summed per
// dimension (0+7, 1+8, 2+9) and the consumer's sizes kept.
18 func.func @extract_slice_rank_reducing_consumer(
19 %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<16x?xf32> {
20 %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [128, 128, 128, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x128x128x?xf32>
21 %1 = tensor.extract_slice %0[7, 8, 9, %offset1] [1, 16, 1, %size1] [1, 1, 1, 1] : tensor<128x128x128x?xf32> to tensor<16x?xf32>
22 return %1: tensor<16x?xf32>
// Expected output: operands are matched loosely; only the static offsets,
// sizes, strides and the source/result types are pinned.
25 // CHECK-LABEL: func.func @extract_slice_rank_reducing_consumer
26 // CHECK: tensor.extract_slice %{{.+}}[7, 9, 11, %{{.+}}] [1, 16, 1, %{{.+}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<16x?xf32>
// Test case: the first (producer) extract_slice is rank-reducing. Its sizes
// [1, 128, 1, %size0] on the 4-D source produce tensor<128x?xf32>, and the
// second slice then operates on that 2-D value.
// The expectation requires a single 4-D extract_slice on the source: the
// second slice's offsets are combined with the producer's offsets in source
// dimensions 1 and 3 (1+7 = 8, and the affine.apply sum of the dynamic
// offsets), while dimensions 0 and 2 keep the producer's offsets (0 and 2)
// with size 1.
30 func.func @extract_slice_rank_reducing_producer(
31 %src: tensor<?x?x?x?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index) -> tensor<8x?xf32> {
32 %0 = tensor.extract_slice %src[0, 1, 2, %offset0] [1, 128, 1, %size0] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<128x?xf32>
33 %1 = tensor.extract_slice %0[7, %offset1] [8, %size1] [1, 1] : tensor<128x?xf32> to tensor<8x?xf32>
34 return %1: tensor<8x?xf32>
// Expected output: %size0 is matched but not captured because only %size1
// appears in the folded slice.
37 // CHECK-LABEL: func.func @extract_slice_rank_reducing_producer
38 // CHECK-SAME: (%[[SRC:.+]]: tensor<?x?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index)
39 // CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[OFFSET1]], %[[OFFSET0]]]
40 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][0, 8, 2, %[[OFFSET]]] [1, 8, 1, %[[SIZE1]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<8x?xf32>
41 // CHECK: return %[[EXTRACT]] : tensor<8x?xf32>
// Test case: two chained 1-D tensor.extract_slice ops whose offsets, sizes and
// strides are all dynamic (function arguments), so strides are not known to
// be 1.
// The expectations require a single extract_slice on the source with:
//   offset = %offset1 * %stride0 + %offset0
//   stride = %stride1 * %stride0
//   size   = %size1
// The offset and stride are each computed by an affine.apply.
45 func.func @extract_slice_non_one_stride(
46 %src: tensor<?xf32>, %offset0: index, %offset1: index, %size0: index, %size1: index, %stride0: index, %stride1: index) -> tensor<?xf32> {
47 %0 = tensor.extract_slice %src[%offset0] [%size0] [%stride0] : tensor<?xf32> to tensor<?xf32>
48 %1 = tensor.extract_slice %0[%offset1] [%size1] [%stride1] : tensor<?xf32> to tensor<?xf32>
49 return %1: tensor<?xf32>
// Expected output: %size0 is matched but not captured because it does not
// appear in the folded slice.
52 // CHECK-LABEL: func.func @extract_slice_non_one_stride
53 // CHECK-SAME: (%[[SRC:.+]]: tensor<?xf32>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index, %{{.+}}: index, %[[SIZE1:.+]]: index, %[[STRIDE0:.+]]: index, %[[STRIDE1:.+]]: index)
54 // CHECK: %[[OFFSET:.+]] = affine.apply affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>()[%[[OFFSET1]], %[[STRIDE0]], %[[OFFSET0]]]
55 // CHECK: %[[STRIDE:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%[[STRIDE1]], %[[STRIDE0]]]
56 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SRC]][%[[OFFSET]]] [%[[SIZE1]]] [%[[STRIDE]]] : tensor<?xf32> to tensor<?xf32>
57 // CHECK: return %[[EXTRACT]] : tensor<?xf32>
// Test case: two chained tensor.insert_slice ops with static shapes. The first
// inserts the 1-D %src into %mid at offsets [0, 0, 0] with sizes [1, 16, 1],
// which equal %mid's full static shape tensor<1x16x1xf32>. The second inserts
// that result into the 4-D %dst.
// The expectations require a single insert_slice of %src directly into %dst
// using the second insert's offsets, sizes and strides; %mid no longer
// appears in the matched ops.
61 func.func @insert_slice_rank_reducing(
62 %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x16x1xf32>, %src: tensor<16xf32>, %offset: index) -> tensor<128x128x128x128xf32> {
63 %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, 16, 1] [1, 1, 1] : tensor<16xf32> into tensor<1x16x1xf32>
64 %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, 16, 1] [1, 1, 1, 1] : tensor<1x16x1xf32> into tensor<128x128x128x128xf32>
65 return %1: tensor<128x128x128x128xf32>
// Expected output: the %mid argument is matched but not captured.
68 // CHECK-LABEL: func.func @insert_slice_rank_reducing
69 // CHECK-SAME: (%[[DST:.+]]: tensor<128x128x128x128xf32>, %{{.+}}: tensor<1x16x1xf32>, %[[SRC:.+]]: tensor<16xf32>, %[[IDX:.+]]: index)
70 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SRC]] into %[[DST]][6, 7, 8, %[[IDX]]] [1, 1, 16, 1] [1, 1, 1, 1]
71 // CHECK: return %[[INSERT]]
// Test case: same structure as a static chained insert_slice, but the middle
// dimension is dynamic (%mid is tensor<1x?x1xf32>, %src is tensor<?xf32>, and
// both inserts use %size).
// The expectation requires two tensor.insert_slice ops to remain in the
// output, i.e. the two inserts are not merged here.
// NOTE(review): presumably the fold is skipped because, with a dynamic size,
// the first insert cannot be shown to overwrite all of %mid -- confirm against
// the pattern implementation.
75 func.func @insert_slice_rank_reducing_dynamic_shape(
76 %dst: tensor<128x128x128x128xf32>, %mid: tensor<1x?x1xf32>, %src: tensor<?xf32>, %offset: index, %size: index) -> tensor<128x128x128x128xf32> {
77 %0 = tensor.insert_slice %src into %mid[0, 0, 0] [1, %size, 1] [1, 1, 1] : tensor<?xf32> into tensor<1x?x1xf32>
78 %1 = tensor.insert_slice %0 into %dst[6, 7, 8, %offset] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<1x?x1xf32> into tensor<128x128x128x128xf32>
79 return %1: tensor<128x128x128x128xf32>
// Expected output: both insert_slice ops are still present.
82 // CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
83 // CHECK-COUNT-2: tensor.insert_slice
// Test case: a tensor.insert_slice (tensor<f32> into the full tensor<1x1xf32>
// %t2) feeds a tensor.parallel_insert_slice inside an scf.forall.
// The expectations (placed before the function here) require that no
// tensor.insert_slice remains and that the parallel_insert_slice inserts the
// rank-0 tensor<f32> value directly into the tensor<1x2xf32> shared output.
// NOTE(review): the first expected offset is the literal 0 rather than an SSA
// value; presumably -canonicalize replaces the induction variable of the
// dimension whose upper bound is %c1 -- confirm.
87 // CHECK-LABEL: func.func @parallel_insert_slice
88 // CHECK-NOT: tensor.insert_slice
89 // CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[0, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
90 func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
91 %c1 = arith.constant 1 : index
92 %c2 = arith.constant 2 : index
// Iterate over a 1x2 space; %arg4 is the shared output initialised from %t0.
93 %r = scf.forall (%arg2, %arg3) in (%c1, %c2) shared_outs(%arg4 = %t0) -> (tensor<1x2xf32>) {
// Producer insert: writes %t1 into %t2 at [0, 0] with sizes [1, 1].
94 %inserted_slice = tensor.insert_slice %t1 into %t2[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<1x1xf32>
95 scf.forall.in_parallel {
// Consumer insert: writes the 1x1 result at the position given by the loop indices.
96 tensor.parallel_insert_slice %inserted_slice into %arg4[%arg2, %arg3] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x2xf32>
99 return %r : tensor<1x2xf32>