[X86] Better handling of impossibly large stack frames (#124217)
[llvm-project.git] / mlir / test / Interfaces / TilingInterface / tile-and-fuse-consumer.mlir
bloba2871b30698c527d4080de310f02084d2230c928
1 // RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
3 #map = affine_map<(d0) -> (d0)>
4 module {
5   func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
6     %c4 = arith.constant 4 : index
7     %c64 = arith.constant 64 : index
8     %c0 = arith.constant 0 : index
9     %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
10       %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
11       %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
12         ^bb0(%in: f32, %in_16: f32, %out: f32):
13           %13 = arith.mulf %in, %in_16 : f32
14           %14 = arith.addf %out, %13 : f32
15           linalg.yield %14 : f32
16         } -> tensor<32xf32>
17       %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
18       scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
19     }
20     %in_operand_2 = tensor.empty() : tensor<64xf32>
21     %out_operand_3 = tensor.empty() : tensor<64xf32>
22     %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
23     return %2 : tensor<64xf32>
24   }
27 module attributes {transform.with_named_sequence} {
28   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
29     %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
30       : (!transform.any_op) -> !transform.any_op
31     %a, %b = transform.test.fuse_consumer %yield
32       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
33     transform.yield
34   }
36 //      CHECK: func.func @fuse_tileable_consumer_scf_for(
37 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
38 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
39 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
40 //      CHECK:   %[[C0:.*]] = arith.constant 0 : index
41 //      CHECK:   %0 = tensor.empty() : tensor<64xf32>
42 //      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
43 // CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
44 // CHECK-SAME:   {
45 //      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
46 //      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
47 // CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
48 //      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
49 //      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
50 //      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
51 //      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
52 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
53 // CHECK-SAME:              outs(%[[SLICE_OUT]] :
54 //      CHECK:      %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
55 //      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
56 //      CHECK:   }
57 //      CHECK:   return %[[FINAL_RESULT]]#2 :
59 // -----
61 module {
62   func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
63     %c4 = arith.constant 4 : index
64     %c64 = arith.constant 64 : index
65     %c0 = arith.constant 0 : index
66     %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
67       %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
68       %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
69       %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
70       scf.forall.in_parallel {
71          tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
72          tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
73       }
74     }
75     %in_operand_2 = tensor.empty() : tensor<64x64xf32>
76     %out_operand_3 = tensor.empty() : tensor<64x64xf32>
77     %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
78     return %2 : tensor<64x64xf32>
79   }
82 module attributes {transform.with_named_sequence} {
83   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
84     %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
85       : (!transform.any_op) -> !transform.any_op
86     %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
87         : (!transform.any_op)
88         -> (!transform.any_op, !transform.any_op)
89     %a, %b = transform.test.fuse_consumer %first_slice_op
90       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
91     transform.yield
92   }
94 //      CHECK: func.func @fuse_tileable_consumer_scf_forall(
95 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
96 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
97 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
98 //      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
99 //      CHECK:   %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
100 // CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
101 // CHECK-SAME:   {
102 //      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
103 //      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
104 //      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
105 // CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
106 //      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
107 //      CHECK:      %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
108 //      CHECK:      %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
109 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
110 // CHECK-SAME:              outs(%[[SLICE_OUT]] :
111 //      CHECK:      scf.forall.in_parallel {
112 //      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
113 //      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
114 //      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
115 //      CHECK:       }
116 //      CHECK:   }
117 //      CHECK:   return %[[FINAL_RESULT]]#2 :
119 // -----
121 #map = affine_map<(d0) -> (d0)>
122 module {
123   func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
124     %c4 = arith.constant 4 : index
125     %c64 = arith.constant 64 : index
126     %c0 = arith.constant 0 : index
127     %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
128       %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
129       %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
130         ^bb0(%in: f32, %in_16: f32, %out: f32):
131           %13 = arith.mulf %in, %in_16 : f32
132           %14 = arith.addf %out, %13 : f32
133           linalg.yield %14 : f32
134         } -> tensor<32xf32>
135       %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
136       scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
137     }
138     %in_operand_2 = tensor.empty() : tensor<64xf32>
139     %out_operand_3 = tensor.empty() : tensor<64xf32>
140     %out_operand_4 = tensor.empty() : tensor<64xf32>
141     %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) {
142       ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
143           %13 = arith.mulf %in, %in_16 : f32
144           %14 = arith.subf %out_0, %13 : f32
145           %15 = arith.addf %out_1, %in : f32
146           linalg.yield %14, %15 : f32, f32
147     } -> (tensor<64xf32>, tensor<64xf32>)
148     return %2#1 : tensor<64xf32>
149   }
152 module attributes {transform.with_named_sequence} {
153   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
154     %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
155       : (!transform.any_op) -> !transform.any_op
156     %a, %b = transform.test.fuse_consumer %yield
157       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
158     transform.yield
159   }
161 //      CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
162 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
163 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
164 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
165 //      CHECK:   %[[C0:.*]] = arith.constant 0 : index
166 //      CHECK:   %0 = tensor.empty() : tensor<64xf32>
167 //      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
168 // CHECK-SAME:      iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %0, %[[ELEM_OUT_ARG_1:.*]] = %0)
169 // CHECK-SAME:   {
170 //      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
171 //      CHECK:      %[[MAT_OUT:.*]] = linalg.generic
172 // CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
173 //      CHECK:      %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
174 //      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
175 //      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
176 //      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
177 //      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
178 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
179 // CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
180 //      CHECK:      %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
181 //      CHECK:      %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
182 //      CHECK:      scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
183 //      CHECK:   }
184 //      CHECK:   return %[[FINAL_RESULT]]#3 :
186 // -----
188 #map = affine_map<(d0, d1) -> (d0, d1)>
189 module {
190     func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
191       %c4 = arith.constant 4 : index
192       %c64 = arith.constant 64 : index
193       %c0 = arith.constant 0 : index
194       %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
195         %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
196         %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
197         %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
198         scf.forall.in_parallel {
199           tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
200           tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
201         }
202       }
203       %1 = tensor.empty() : tensor<64x64xf32>
204       %2 = tensor.empty() : tensor<64x64xf32>
205       %3 = tensor.empty() : tensor<64x64xf32>
206       %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
207       ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
208         %6 = arith.mulf %in, %in_0 : f32
209         %7 = arith.subf %out, %6 : f32
210         %8 = arith.addf %out_1, %in : f32
211         linalg.yield %7, %8 : f32, f32
212       } -> (tensor<64x64xf32>, tensor<64x64xf32>)
213       %5 = tensor.empty() : tensor<2048xf32>
214       %unpack = tensor.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
215       return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
216     }
219 module attributes {transform.with_named_sequence} {
220   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
221     %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
222       : (!transform.any_op) -> !transform.any_op
223     %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
224         : (!transform.any_op)
225         -> (!transform.any_op, !transform.any_op)
226     %a, %b = transform.test.fuse_consumer %first_slice_op
227       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
228     transform.yield
229   }
231 //      CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
232 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
233 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
234 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>
235 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
236 //      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
237 //      CHECK:   %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
238 // CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
239 // CHECK-SAME:   {
240 //      CHECK:      %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
241 //      CHECK:      %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
242 //      CHECK:      %[[MAT_OUT:.*]] = linalg.matmul
243 // CHECK-SAME:              outs(%[[MAT_OUT_SLICE]] :
244 //      CHECK:      %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
245 //      CHECK:      %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
246 //      CHECK:      %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
247 //      CHECK:      %[[ELEM_OUT:.*]]:2 = linalg.generic
248 // CHECK-SAME:              ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
249 // CHECK-SAME:              outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
250 //      CHECK:      scf.forall.in_parallel {
251 //      CHECK:          tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
252 //      CHECK:          tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
253 //      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
254 //      CHECK:          tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
255 //      CHECK:       }
256 //      CHECK:   }
257 //      CHECK:   %[[UNPACK:.*]] = tensor.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32>
258 //      CHECK:   return %[[FINAL_RESULT]]#3, %[[UNPACK]] :
260 // -----
262 #map = affine_map<(d0, d1) -> (d0, d1)>
263 module {
264     func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
265         %c4 = arith.constant 4 : index
266         %c64 = arith.constant 64 : index
267         %c0 = arith.constant 0 : index
268         %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
269             %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
270             %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
271                 ^bb0(%in: f32, %in_16: f32, %out: f32):
272                 %13 = arith.mulf %in, %in_16 : f32
273                 %14 = arith.addf %out, %13 : f32
274                 linalg.yield %14 : f32
275             } -> tensor<32x32xf32>
276             scf.forall.in_parallel {
277                 tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
278             }
279         }
280         %output = tensor.empty() : tensor<2048xf32>
281         %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
282         return %unpack : tensor<2048xf32>
283     }
285   
286 module attributes {transform.with_named_sequence} {
287     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
288         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
289         : (!transform.any_op) -> !transform.any_op
290         %a, %b = transform.test.fuse_consumer %slice_op
291         : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
292         transform.yield
293     }
295 //  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
296 //  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
297 //      CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
298 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
299 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
300 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
301 //      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
302 //      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
303 // CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
304 // CHECK-SAME:   {
305 //      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
306 //      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
307 // CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
308 //  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
309 //  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
310 //      CHECK:      %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
311 //      CHECK:      %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
312 // CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
313 // CHECK-SAME:                              into %[[TILED_UNPACK_DEST]]
314 //      CHECK:      scf.forall.in_parallel {
315 //      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
316 //      CHECK:          tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
317 //      CHECK:       }
318 //      CHECK:   }
319 //      CHECK:   return %[[FINAL_RESULT]]#1 :
321 // -----
323 #map = affine_map<(d0, d1) -> (d0, d1)>
324 module {
325     func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
326         %c4 = arith.constant 4 : index
327         %c64 = arith.constant 64 : index
328         %c0 = arith.constant 0 : index
329         %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
330             %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
331             %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
332                 ^bb0(%in: f32, %in_16: f32, %out: f32):
333                 %13 = arith.mulf %in, %in_16 : f32
334                 %14 = arith.addf %out, %13 : f32
335                 linalg.yield %14 : f32
336             } -> tensor<32x32xf32>
337             scf.forall.in_parallel {
338                 tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
339             }
340         }
341         %output = tensor.empty() : tensor<2047xf32>
342         %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
343         return %unpack : tensor<2047xf32>
344     }
346   
347 module attributes {transform.with_named_sequence} {
348     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
349         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
350         : (!transform.any_op) -> !transform.any_op
351         %a, %b = transform.test.fuse_consumer %slice_op
352         : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
353         transform.yield
354     }
356 //  CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
357 //  CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
358 //      CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
359 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
360 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
361 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
362 //      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
363 //      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
364 // CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
365 // CHECK-SAME:   {
366 //      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
367 //      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
368 // CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
369 //  CHECK-DAG:      %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
370 //  CHECK-DAG:      %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
371 //      CHECK:      %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
372 //      CHECK:      %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
373 // CHECK-SAME:                              outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
374 // CHECK-SAME:                              into %[[TILED_UNPACK_DEST]]
375 //      CHECK:      scf.forall.in_parallel {
376 //      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
377 //      CHECK:          tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
378 //      CHECK:       }
379 //      CHECK:   }
380 //      CHECK:   return %[[FINAL_RESULT]]#1 :
382 // -----
384 #map = affine_map<(d0, d1) -> (d0, d1)>
385 module {
386     func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
387         %c4 = arith.constant 4 : index
388         %c64 = arith.constant 64 : index
389         %c0 = arith.constant 0 : index
390         %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
391             %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
392             %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
393                 ^bb0(%in: f32, %in_16: f32, %out: f32):
394                 %13 = arith.mulf %in, %in_16 : f32
395                 %14 = arith.addf %out, %13 : f32
396                 linalg.yield %14 : f32
397             } -> tensor<32x32xf32>
398             scf.forall.in_parallel {
399                 tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
400             }
401         }
402         %output = tensor.empty() : tensor<4x32x16xf32>
403         %pack = tensor.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
404         return %pack : tensor<4x32x16xf32>
405     }
407   
408 module attributes {transform.with_named_sequence} {
409     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
410         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
411         : (!transform.any_op) -> !transform.any_op
412         %a, %b = transform.test.fuse_consumer %slice_op
413         : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
414         transform.yield
415     }
417 //      CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
418 //      CHECK: func.func @fuse_pack_consumer_into_scf_forall(
419 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
420 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
421 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
422 //      CHECK:   %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
423 //      CHECK:   %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
424 // CHECK-SAME:      shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
425 // CHECK-SAME:   {
426 //      CHECK:      %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
427 //      CHECK:      %[[GENERIC_OUT:.*]] = linalg.generic
428 // CHECK-SAME:              outs(%[[GENERIC_OUT_SLICE]] :
429 //      CHECK:      %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
430 //      CHECK:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
431 //      CHECK:      %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
432 // CHECK-SAME:                              inner_dims_pos = [0] inner_tiles = [16]
433 // CHECK-SAME:                              into %[[TILED_PACK_DEST]]
434 //      CHECK:      scf.forall.in_parallel {
435 //      CHECK:          tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
436 //      CHECK:          tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]],  %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
438 // -----
440 module {
441   func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
442     %c0 = arith.constant 0 : index
443     %c64 = arith.constant 64 : index
444     %c256 = arith.constant 256 : index
445     %cst = arith.constant 0.000000e+00 : f32
446     %dest0 = tensor.empty() : tensor<256x256xf32>
447     %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
448     %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) {
449       %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
450         %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
451         %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
452         %extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
453         %3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32>
454         %insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
455         scf.yield %insert_slice : tensor<256x256xf32>
456       }
457       scf.yield %2 : tensor<256x256xf32>
458     }
459     %4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
460     return %4 : tensor<256x256xf32>
461   }
464 module attributes {transform.with_named_sequence} {
465   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
466     %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
467       : (!transform.any_op) -> !transform.any_op
468     %a, %b = transform.test.fuse_consumer %slice_op
469       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
470     transform.yield
471   }
473 //      CHECK: func.func @fuse_add_consumer_into_nested_scf_for(
474 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
475 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
476 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
477 //      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
478 //      CHECK:   %[[dest1:.*]] = linalg.fill
479 // CHECK-SAME:          outs(%[[dest0]] :
480 //      CHECK:   %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
481 // CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]])
482 // CHECK-SAME:   {
483 //      CHECK:       %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
484 // CHECK-SAME:         iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
485 // CHECK-SAME:         {
486 //      CHECK:            %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
487 //      CHECK:            %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
488 //      CHECK:            %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
489 //      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
490 // CHECK-SAME:                  outs(%[[MAT_OUT_SLICE]] :
491 //      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
492 //      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
493 //      CHECK:            %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
494 //      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
495 // CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
496 // CHECK-SAME:              outs(%[[ADD_OUT_SLICE]] :
497 //      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
498 //      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
499 //      CHECK:         }
500 //      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
501 //      CHECK:   }
502 //      CHECK:   return %[[LOOP_RESULT1]]#1 :
504 // -----
506 // This test case checks fusion of consumer even if the producer has multiple uses.
507 // The multiple uses of the producer essentially means that besides the consumer
508 // op in concern, the only other uses of the producer are allowed in :-
509 // 1. scf.yield
510 // 2. tensor.parallel_insert_slice
512 module {
513   module {
514     func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
515       %c0 = arith.constant 0 : index
516       %c64 = arith.constant 64 : index
517       %c256 = arith.constant 256 : index
518       %cst = arith.constant 0.000000e+00 : f32
519       %0 = tensor.empty() : tensor<256x256xf32>
520       %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
521       %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
522         %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) {
523           %extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
524           %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
525           %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
526           %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
527           %inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
528           scf.yield %inserted_slice : tensor<256x256xf32>
529         }
530         %4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
531         scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32>
532       }
533       return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32>
534     }
535   }
536   module attributes {transform.with_named_sequence} {
537     transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
538       %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
539       %consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
540       transform.yield
541     }
542   }
544 //      CHECK: func.func @fuse_consumer_for_multi_use_producer(
545 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
546 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
547 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
548 //      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
549 //      CHECK:   %[[dest1:.*]] = linalg.fill
550 // CHECK-SAME:          outs(%[[dest0]] :
551 //      CHECK:   %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
552 // CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]])
553 // CHECK-SAME:   {
554 //      CHECK:       %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
555 // CHECK-SAME:         iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]])
556 // CHECK-SAME:         {
557 //      CHECK:            %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
558 //      CHECK:            %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
559 //      CHECK:            %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
560 //      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
561 // CHECK-SAME:                  outs(%[[MAT_OUT_SLICE]] :
562 //      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
563 //      CHECK:            %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
564 //      CHECK:            %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
565 //      CHECK:            %[[TILED_ADD_OUT:.*]] = linalg.add
566 // CHECK-SAME:              ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
567 // CHECK-SAME:              outs(%[[ADD_OUT_SLICE]] :
568 //      CHECK:            %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
569 //      CHECK:            scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
570 //      CHECK:         }
571 //      CHECK:         scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
572 //      CHECK:   }
573 //      CHECK:   return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
575 // -----
577 module {
578   func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
579     %c0 = arith.constant 0 : index
580     %c64 = arith.constant 64 : index
581     %c256 = arith.constant 256 : index
582     %cst = arith.constant 0.000000e+00 : f32
583     %dest0 = tensor.empty() : tensor<256x256xf32>
584     %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
585         %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
586         %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
587         %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
588         %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
589         %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
590         scf.yield %insert_slice : tensor<256x256xf32>
591     }
592     %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
593     %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
594     return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
595   }
598 module attributes {transform.with_named_sequence} {
599   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
600     %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
601       : (!transform.any_op) -> !transform.any_op
602     %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
603       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
604     transform.yield
605   }
607 //      CHECK: func.func @fuse_add_multiple_tilable_consumers(
608 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
609 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
610 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
611 //      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
612 //      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
613 // CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) 
614 // CHECK-SAME:   {
615 //      CHECK:          %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
616 //      CHECK:          %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
617 //      CHECK:          %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
618 //      CHECK:          %[[TILED_ADD_OUT:.*]] = linalg.add
619 // CHECK-SAME:                ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
620 // CHECK-SAME:                outs(%[[ADD_OUT_SLICE]] :
621 //      CHECK:          %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
622 //      CHECK:          %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
623 //      CHECK:          %[[TILED_EXP_OUT:.*]] = linalg.exp
624 // CHECK-SAME:                ins(%[[TILED_ADD_OUT]] :
625 // CHECK-SAME:                outs(%[[EXP_OUT_SLICE]] :
626 //      CHECK:          %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
627 //      CHECK:          %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
628 //      CHECK:          %[[TILED_MUL_OUT:.*]] = linalg.mul
629 // CHECK-SAME:                ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
630 // CHECK-SAME:                outs(%[[MUL_OUT_SLICE]] :
631 //      CHECK:          %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
632 //      CHECK:          %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
633 //      CHECK:          scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
634 //      CHECK:   }
635 //      CHECK:   return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
637 // -----
639 module {
640   func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
641     %c0 = arith.constant 0 : index
642     %c64 = arith.constant 64 : index
643     %c256 = arith.constant 256 : index
644     %cst = arith.constant 0.000000e+00 : f32
645     %dest0 = tensor.empty() : tensor<256x256xf32>
646     %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
647         %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
648         %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
649         %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
650         %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
651         %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
652         scf.yield %insert_slice : tensor<256x256xf32>
653     }
654     %dest1 = tensor.empty() : tensor<258x258xf32>
655     %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
656     %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
657     return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
658   }
661 module attributes {transform.with_named_sequence} {
662   transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
663     %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
664       : (!transform.any_op) -> !transform.any_op
665     %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
666     %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
667       : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
668     transform.yield
669   }
671 //      CHECK: func.func @no_fuse_only_dps_consumer(
672 //      CHECK:   %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
673 //      CHECK:     linalg.add
674 //      CHECK:     linalg.mul
675 //      CHECK:     scf.yield
676 //      CHECK:   }
677 //      CHECK:   %[[RES_SLICE:.+]] = tensor.insert_slice
678 //      CHECK:   return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]