1 // RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
// Input IR: an scf.for with two iter_args whose body runs a linalg.generic
// (multiply, then add to the output element) on a 32-element slice of %arg4 and
// writes it back with tensor.insert_slice. The loop yields (%arg5, %4), so the
// inserted tensor becomes loop result #1, which is read by the
// linalg.elemwise_binary add placed after the loop.
3 #map = affine_map<(d0) -> (d0)>
5 func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
6 %c4 = arith.constant 4 : index
7 %c64 = arith.constant 64 : index
8 %c0 = arith.constant 0 : index
9 %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
10 %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
11 %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
12 ^bb0(%in: f32, %in_16: f32, %out: f32):
13 %13 = arith.mulf %in, %in_16 : f32
14 %14 = arith.addf %out, %13 : f32
15 linalg.yield %14 : f32
17 %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
18 scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
20 %in_operand_2 = tensor.empty() : tensor<64xf32>
21 %out_operand_3 = tensor.empty() : tensor<64xf32>
22 %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
23 return %2 : tensor<64xf32>
// Transform script: matches the tensor.insert_slice op in the payload and passes
// that handle to transform.test.fuse_consumer, which returns two op handles
// (%a, %b).
27 module attributes {transform.with_named_sequence} {
28 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
29 %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
30 : (!transform.any_op) -> !transform.any_op
31 %a, %b = transform.test.fuse_consumer %yield
32 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: the scf.for carries a third iter_arg initialized from the
// tensor.empty, linalg.elemwise_binary runs on 32-element slices inside the loop
// body, its result is inserted into the new iter_arg, and the function returns
// loop result #2.
// The tensor.empty result is captured as OUT_INIT instead of spelling out the
// printer-generated SSA name, so the test does not depend on value numbering.
36 // CHECK: func.func @fuse_tileable_consumer_scf_for(
37 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
38 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
39 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
40 // CHECK: %[[C0:.*]] = arith.constant 0 : index
41 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
42 // CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
43 // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
45 // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
46 // CHECK: %[[MAT_OUT:.*]] = linalg.generic
47 // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
48 // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
49 // CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
50 // CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
51 // CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
52 // CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
53 // CHECK-SAME: outs(%[[SLICE_OUT]] :
54 // CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
55 // CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
57 // CHECK: return %[[FINAL_RESULT]]#2 :
// Input IR: an scf.forall over a (2, 2) grid with two shared_outs. A
// linalg.matmul computed into a slice of %arg5 is written to %arg6 with
// tensor.parallel_insert_slice, and a slice of %arg6 is written to %arg5. Loop
// result #1 is read by the linalg.elemwise_binary add placed after the loop.
62 func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
63 %c4 = arith.constant 4 : index
64 %c64 = arith.constant 64 : index
65 %c0 = arith.constant 0 : index
66 %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
67 %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
68 %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
69 %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
70 scf.forall.in_parallel {
71 tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
72 tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
75 %in_operand_2 = tensor.empty() : tensor<64x64xf32>
76 %out_operand_3 = tensor.empty() : tensor<64x64xf32>
77 %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %in_operand_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%out_operand_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
78 return %2 : tensor<64x64xf32>
// Transform script: matches every tensor.parallel_insert_slice op in the
// payload, splits the resulting handle into two handles, and applies
// transform.test.fuse_consumer to the first of them.
82 module attributes {transform.with_named_sequence} {
83 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
84 %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
85 : (!transform.any_op) -> !transform.any_op
86 %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
88 -> (!transform.any_op, !transform.any_op)
89 %a, %b = transform.test.fuse_consumer %first_slice_op
90 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: the scf.forall has three results, with an extra shared_out
// initialized from the tensor.empty. linalg.elemwise_binary runs on 32x32
// slices inside the loop body, a third tensor.parallel_insert_slice writes its
// result, and the function returns loop result #2.
94 // CHECK: func.func @fuse_tileable_consumer_scf_forall(
95 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
96 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
97 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>)
98 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
99 // CHECK: %[[FINAL_RESULT:.*]]:3 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
100 // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %[[OUT_INIT]])
102 // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
103 // CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
104 // CHECK: %[[MAT_OUT:.*]] = linalg.matmul
105 // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
106 // CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
107 // CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
108 // CHECK: %[[ELEM_OUT:.*]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
109 // CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
110 // CHECK-SAME: outs(%[[SLICE_OUT]] :
111 // CHECK: scf.forall.in_parallel {
112 // CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
113 // CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
114 // CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
117 // CHECK: return %[[FINAL_RESULT]]#2 :
// Input IR: the same scf.for producer loop as in
// @fuse_tileable_consumer_scf_for, but the consumer after the loop is a
// linalg.generic with two outs operands and two results. It reads loop result
// #1, and the function returns the consumer's second result (%2#1).
121 #map = affine_map<(d0) -> (d0)>
123 func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
124 %c4 = arith.constant 4 : index
125 %c64 = arith.constant 64 : index
126 %c0 = arith.constant 0 : index
127 %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
128 %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
129 %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
130 ^bb0(%in: f32, %in_16: f32, %out: f32):
131 %13 = arith.mulf %in, %in_16 : f32
132 %14 = arith.addf %out, %13 : f32
133 linalg.yield %14 : f32
135 %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
136 scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
138 %in_operand_2 = tensor.empty() : tensor<64xf32>
139 %out_operand_3 = tensor.empty() : tensor<64xf32>
140 %out_operand_4 = tensor.empty() : tensor<64xf32>
141 %2:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3, %out_operand_4 : tensor<64xf32>, tensor<64xf32>) {
142 ^bb0(%in: f32, %in_16: f32, %out_0: f32, %out_1: f32):
143 %13 = arith.mulf %in, %in_16 : f32
144 %14 = arith.subf %out_0, %13 : f32
145 %15 = arith.addf %out_1, %in : f32
146 linalg.yield %14, %15 : f32, f32
147 } -> (tensor<64xf32>, tensor<64xf32>)
148 return %2#1 : tensor<64xf32>
// Transform script: matches the tensor.insert_slice op in the payload and passes
// that handle to transform.test.fuse_consumer, which returns two op handles
// (%a, %b).
152 module attributes {transform.with_named_sequence} {
153 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
154 %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
155 : (!transform.any_op) -> !transform.any_op
156 %a, %b = transform.test.fuse_consumer %yield
157 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: the scf.for carries four iter_args, the last two both
// initialized from the single tensor.empty. The two-result linalg.generic runs
// on 32-element slices inside the loop body, each result is inserted into its
// own iter_arg, and the function returns loop result #3.
// The tensor.empty result is captured as OUT_INIT instead of spelling out the
// printer-generated SSA name, so the test does not depend on value numbering.
161 // CHECK: func.func @fuse_tileable_consumer_scf_for_multi_yielding_consumer(
162 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
163 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
164 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
165 // CHECK: %[[C0:.*]] = arith.constant 0 : index
166 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64xf32>
167 // CHECK: %[[FINAL_RESULT:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]]
168 // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
170 // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
171 // CHECK: %[[MAT_OUT:.*]] = linalg.generic
172 // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
173 // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
174 // CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV]]] [32] [1]
175 // CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
176 // CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
177 // CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
178 // CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
179 // CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
180 // CHECK: %[[INSERT_ELEM_0:.*]] = tensor.insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV]]] [32] [1]
181 // CHECK: %[[INSERT_ELEM_1:.*]] = tensor.insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV]]] [32] [1]
182 // CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM_0]], %[[INSERT_ELEM_1]] :
184 // CHECK: return %[[FINAL_RESULT]]#3 :
// Input IR: an scf.forall over a (2, 2) grid with shared_outs %arg6 (64x32) and
// %arg7 (64x64). The linalg.matmul result is written to %arg7. After the loop,
// a two-result linalg.generic reads loop result #1 and a tensor.unpack reads
// loop result #0; the function returns %4#1 and the unpack result.
188 #map = affine_map<(d0, d1) -> (d0, d1)>
190 func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>, %arg3: tensor<64x32xf32>) -> (tensor<64x64xf32>, tensor<2048xf32>) {
191 %c4 = arith.constant 4 : index
192 %c64 = arith.constant 64 : index
193 %c0 = arith.constant 0 : index
194 %0:2 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %arg3, %arg7 = %arg2) -> (tensor<64x32xf32>, tensor<64x64xf32>) {
195 %extracted_slice = tensor.extract_slice %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
196 %extracted_slice_0 = tensor.extract_slice %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
197 %6 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
198 scf.forall.in_parallel {
199 tensor.parallel_insert_slice %6 into %arg7[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
200 tensor.parallel_insert_slice %extracted_slice_0 into %arg6[%arg4, %arg5] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
203 %1 = tensor.empty() : tensor<64x64xf32>
204 %2 = tensor.empty() : tensor<64x64xf32>
205 %3 = tensor.empty() : tensor<64x64xf32>
206 %4:2 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%0#1, %1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%2, %3 : tensor<64x64xf32>, tensor<64x64xf32>) {
207 ^bb0(%in: f32, %in_0: f32, %out: f32, %out_1: f32):
208 %6 = arith.mulf %in, %in_0 : f32
209 %7 = arith.subf %out, %6 : f32
210 %8 = arith.addf %out_1, %in : f32
211 linalg.yield %7, %8 : f32, f32
212 } -> (tensor<64x64xf32>, tensor<64x64xf32>)
213 %5 = tensor.empty() : tensor<2048xf32>
214 %unpack = tensor.unpack %0#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %5 : tensor<64x32xf32> -> tensor<2048xf32>
215 return %4#1, %unpack : tensor<64x64xf32>, tensor<2048xf32>
// Transform script: matches every tensor.parallel_insert_slice op in the
// payload, splits the resulting handle into two handles, and applies
// transform.test.fuse_consumer to the first of them.
219 module attributes {transform.with_named_sequence} {
220 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
221 %slice_ops = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
222 : (!transform.any_op) -> !transform.any_op
223 %first_slice_op, %second_slice_op = transform.split_handle %slice_ops
224 : (!transform.any_op)
225 -> (!transform.any_op, !transform.any_op)
226 %a, %b = transform.test.fuse_consumer %first_slice_op
227 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: the scf.forall has four results, the last two shared_outs
// both initialized from the tensor.empty. The two-result linalg.generic runs on
// 32x32 slices inside the loop body and each result gets its own
// tensor.parallel_insert_slice. The tensor.unpack stays after the loop and
// reads loop result #0; the function returns loop result #3 and the unpack.
231 // CHECK: func.func @fuse_tileable_consumer_scf_forall_multi_yielding_consumer(
232 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
233 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
234 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x64xf32>
235 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
236 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<64x64xf32>
237 // CHECK: %[[FINAL_RESULT:.*]]:4 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
238 // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG3]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG_0:.*]] = %[[OUT_INIT]], %[[ELEM_OUT_ARG_1:.*]] = %[[OUT_INIT]])
240 // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
241 // CHECK: %[[SECOND_ARG_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
242 // CHECK: %[[MAT_OUT:.*]] = linalg.matmul
243 // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
244 // CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %[[OUT_INIT]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
245 // CHECK: %[[SLICE_OUT_0:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
246 // CHECK: %[[SLICE_OUT_1:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
247 // CHECK: %[[ELEM_OUT:.*]]:2 = linalg.generic
248 // CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
249 // CHECK-SAME: outs(%[[SLICE_OUT_0]], %[[SLICE_OUT_1]] :
250 // CHECK: scf.forall.in_parallel {
251 // CHECK: tensor.parallel_insert_slice %[[MAT_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
252 // CHECK: tensor.parallel_insert_slice %[[SECOND_ARG_SLICE]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
253 // CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#0 into %[[ELEM_OUT_ARG_0]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
254 // CHECK: tensor.parallel_insert_slice %[[ELEM_OUT]]#1 into %[[ELEM_OUT_ARG_1]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
257 // CHECK: %[[UNPACK:.*]] = tensor.unpack %[[FINAL_RESULT]]#0 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %{{.*}} : tensor<64x32xf32> -> tensor<2048xf32>
258 // CHECK: return %[[FINAL_RESULT]]#3, %[[UNPACK]] :
// Input IR: an scf.forall written in lower-bound/upper-bound/step form,
// (0, 0) to (64, 32) step (32, 32), with a single shared_out. A linalg.generic
// result is written back with tensor.parallel_insert_slice, and the loop result
// is consumed by a tensor.unpack (inner tile 32) into a tensor<2048xf32>.
262 #map = affine_map<(d0, d1) -> (d0, d1)>
264 func.func @fuse_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2048xf32> {
265 %c4 = arith.constant 4 : index
266 %c64 = arith.constant 64 : index
267 %c0 = arith.constant 0 : index
268 %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
269 %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
270 %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
271 ^bb0(%in: f32, %in_16: f32, %out: f32):
272 %13 = arith.mulf %in, %in_16 : f32
273 %14 = arith.addf %out, %13 : f32
274 linalg.yield %14 : f32
275 } -> tensor<32x32xf32>
276 scf.forall.in_parallel {
277 tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
280 %output = tensor.empty() : tensor<2048xf32>
281 %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2048xf32>
282 return %unpack : tensor<2048xf32>
// Transform script: matches the tensor.parallel_insert_slice op in the payload
// and passes that handle to transform.test.fuse_consumer, which returns two op
// handles (%a, %b).
286 module attributes {transform.with_named_sequence} {
287 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
288 %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
289 : (!transform.any_op) -> !transform.any_op
290 %a, %b = transform.test.fuse_consumer %slice_op
291 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: two affine maps are emitted, one for the unpack result
// offset (d0 * 32) and one for its size, used by affine.min over
// (1024, d0 * -32 + 2048). The scf.forall has two results, the second
// shared_out being the tensor<2048xf32> destination. The tensor.unpack runs
// inside the loop on the linalg.generic result, writing into a slice of that
// shared_out, and the function returns loop result #1.
295 // CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
296 // CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
297 // CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
298 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
299 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
300 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
301 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
302 // CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
303 // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
305 // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
306 // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
307 // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
308 // CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
309 // CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
310 // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
311 // CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
312 // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
313 // CHECK-SAME: into %[[TILED_UNPACK_DEST]]
314 // CHECK: scf.forall.in_parallel {
315 // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
316 // CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
319 // CHECK: return %[[FINAL_RESULT]]#1 :
// Input IR: same loop structure as @fuse_unpack_consumer_into_scf_forall, but
// the tensor.unpack destination is tensor<2047xf32>, whose size is not a
// multiple of the inner tile size 32.
323 #map = affine_map<(d0, d1) -> (d0, d1)>
325 func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
326 %c4 = arith.constant 4 : index
327 %c64 = arith.constant 64 : index
328 %c0 = arith.constant 0 : index
329 %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
330 %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
331 %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
332 ^bb0(%in: f32, %in_16: f32, %out: f32):
333 %13 = arith.mulf %in, %in_16 : f32
334 %14 = arith.addf %out, %13 : f32
335 linalg.yield %14 : f32
336 } -> tensor<32x32xf32>
337 scf.forall.in_parallel {
338 tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
341 %output = tensor.empty() : tensor<2047xf32>
342 %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
343 return %unpack : tensor<2047xf32>
// Transform script: matches the tensor.parallel_insert_slice op in the payload
// and passes that handle to transform.test.fuse_consumer, which returns two op
// handles (%a, %b).
347 module attributes {transform.with_named_sequence} {
348 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
349 %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
350 : (!transform.any_op) -> !transform.any_op
351 %a, %b = transform.test.fuse_consumer %slice_op
352 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: same structure as the aligned unpack test, except that the
// size map is (1024, d0 * -32 + 2047) and the extra shared_out is the
// tensor<2047xf32> destination. The tensor.unpack runs inside the loop and the
// function returns loop result #1.
356 // CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
357 // CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
358 // CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
359 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
360 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
361 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
362 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
363 // CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
364 // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
366 // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
367 // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
368 // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
369 // CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
370 // CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
371 // CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
372 // CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
373 // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
374 // CHECK-SAME: into %[[TILED_UNPACK_DEST]]
375 // CHECK: scf.forall.in_parallel {
376 // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
377 // CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
380 // CHECK: return %[[FINAL_RESULT]]#1 :
// Input IR: an scf.forall over a (2, 2) grid with a single shared_out. A
// linalg.generic result is written back with tensor.parallel_insert_slice, and
// the loop result is consumed by a tensor.pack (inner_dims_pos = [0],
// inner_tiles = [16]) into a tensor<4x32x16xf32>.
384 #map = affine_map<(d0, d1) -> (d0, d1)>
386 func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
387 %c4 = arith.constant 4 : index
388 %c64 = arith.constant 64 : index
389 %c0 = arith.constant 0 : index
390 %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
391 %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
392 %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
393 ^bb0(%in: f32, %in_16: f32, %out: f32):
394 %13 = arith.mulf %in, %in_16 : f32
395 %14 = arith.addf %out, %13 : f32
396 linalg.yield %14 : f32
397 } -> tensor<32x32xf32>
398 scf.forall.in_parallel {
399 tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
402 %output = tensor.empty() : tensor<4x32x16xf32>
403 %pack = tensor.pack %1 inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
404 return %pack : tensor<4x32x16xf32>
// Transform script: matches the tensor.parallel_insert_slice op in the payload
// and passes that handle to transform.test.fuse_consumer, which returns two op
// handles (%a, %b).
408 module attributes {transform.with_named_sequence} {
409 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
410 %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
411 : (!transform.any_op) -> !transform.any_op
412 %a, %b = transform.test.fuse_consumer %slice_op
413 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: an affine map (d0 floordiv 16) computes the pack result
// offset from the first induction variable. The scf.forall has two results, the
// second shared_out being the tensor<4x32x16xf32> destination. The tensor.pack
// runs inside the loop on the linalg.generic result, writing into a
// [2, 32, 16] slice of that shared_out.
417 // CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
418 // CHECK: func.func @fuse_pack_consumer_into_scf_forall(
419 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
420 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
421 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
422 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
423 // CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
424 // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
426 // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
427 // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
428 // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
429 // CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
430 // CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
431 // CHECK: %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
432 // CHECK-SAME: inner_dims_pos = [0] inner_tiles = [16]
433 // CHECK-SAME: into %[[TILED_PACK_DEST]]
434 // CHECK: scf.forall.in_parallel {
435 // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
436 // CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
// Input IR: two nested scf.for loops (0 to 256, step 64) that compute a 64x64
// linalg.matmul tile and insert it with tensor.insert_slice in the inner loop.
// The outer loop result is consumed by a linalg.add with %arg2 that writes
// into the tensor.empty value %dest0.
441 func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
442 %c0 = arith.constant 0 : index
443 %c64 = arith.constant 64 : index
444 %c256 = arith.constant 256 : index
445 %cst = arith.constant 0.000000e+00 : f32
446 %dest0 = tensor.empty() : tensor<256x256xf32>
447 %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
448 %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) {
449 %2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
450 %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
451 %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
452 %extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
453 %3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32>
454 %insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
455 scf.yield %insert_slice : tensor<256x256xf32>
457 scf.yield %2 : tensor<256x256xf32>
459 %4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
460 return %4 : tensor<256x256xf32>
// Transform script: matches the tensor.insert_slice op in the payload (the one
// in the inner loop) and passes that handle to transform.test.fuse_consumer,
// which returns two op handles (%a, %b).
464 module attributes {transform.with_named_sequence} {
465 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
466 %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
467 : (!transform.any_op) -> !transform.any_op
468 %a, %b = transform.test.fuse_consumer %slice_op
469 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: both scf.for loops carry a second iter_arg (the outer one
// initialized from the tensor.empty), linalg.add runs on 64x64 slices inside
// the inner loop, both loops yield the matmul and add tensors, and the
// function returns outer loop result #1.
// C0 is captured in this test rather than relying on a binding left over from
// an earlier test case; FileCheck variables are global to the whole check file.
473 // CHECK: func.func @fuse_add_consumer_into_nested_scf_for(
474 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
475 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
476 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
477 // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
478 // CHECK: %[[dest1:.*]] = linalg.fill
479 // CHECK-SAME: outs(%[[dest0]] :
480 // CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
481 // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]])
483 // CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
484 // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
486 // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
487 // CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
488 // CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
489 // CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
490 // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
491 // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
492 // CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
493 // CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
494 // CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
495 // CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
496 // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
497 // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
498 // CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
500 // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
502 // CHECK: return %[[LOOP_RESULT1]]#1 :
506 // This test case checks fusion of the consumer even if the producer has multiple uses.
507 // Here, "multiple uses" means that, besides the consumer op of interest, the
508 // producer's only other uses are in one of the following ops:
509 // 1. scf.yield
510 // 2. tensor.parallel_insert_slice
// Input IR: an outer scf.for with two iter_args containing an inner scf.for
// that computes a 64x64 linalg.matmul tile and inserts it with
// tensor.insert_slice. The inner loop result %3 has two uses inside the outer
// loop body: the linalg.add (with %arg5, writing into %0) and the outer
// scf.yield.
514 func.func @fuse_consumer_for_multi_use_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
515 %c0 = arith.constant 0 : index
516 %c64 = arith.constant 64 : index
517 %c256 = arith.constant 256 : index
518 %cst = arith.constant 0.000000e+00 : f32
519 %0 = tensor.empty() : tensor<256x256xf32>
520 %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
521 %2:2 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %1, %arg5 = %arg2) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
522 %3 = scf.for %arg6 = %c0 to %c256 step %c64 iter_args(%arg7 = %arg4) -> (tensor<256x256xf32>) {
523 %extracted_slice = tensor.extract_slice %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
524 %extracted_slice_0 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
525 %extracted_slice_1 = tensor.extract_slice %arg1[0, %arg6] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
526 %5 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_1 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice : tensor<64x64xf32>) -> tensor<64x64xf32>
527 %inserted_slice = tensor.insert_slice %5 into %arg7[%arg3, %arg6] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
528 scf.yield %inserted_slice : tensor<256x256xf32>
530 %4 = linalg.add ins(%3, %arg5 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
531 scf.yield %3, %4 : tensor<256x256xf32>, tensor<256x256xf32>
533 return %2#0, %2#1 : tensor<256x256xf32>, tensor<256x256xf32>
// Transform script: matches the tensor.insert_slice op in the payload and passes
// that handle to transform.test.fuse_consumer, which returns two op handles
// (%consumer, %fused_consumer).
536 module attributes {transform.with_named_sequence} {
537 transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
538 %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
539 %consumer, %fused_consumer = transform.test.fuse_consumer %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: the `linalg.add` is tiled into the inner loop, both loops
// carry two iter_args, and the function returns both results of the outer loop.
544 // CHECK: func.func @fuse_consumer_for_multi_use_producer(
545 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
546 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
547 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// Bind the zero index constant here so the loop lower-bound matches below do
// not depend on a capture left over from an earlier test case in this file.
// CHECK: %[[C0:.*]] = arith.constant 0 : index
548 // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
549 // CHECK: %[[dest1:.*]] = linalg.fill
550 // CHECK-SAME: outs(%[[dest0]] :
551 // CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
552 // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[ARG2]])
554 // CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
555 // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[dest0]])
557 // CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
558 // CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
559 // CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
560 // CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
561 // CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
562 // CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
563 // CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG1]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
564 // CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
565 // CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
566 // CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
567 // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
568 // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
569 // CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
571 // CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
573 // CHECK: return %[[LOOP_RESULT1]]#0, %[[LOOP_RESULT1]]#1 :
// Test input: a row-tiled `linalg.add` loop whose result %1 is consumed after
// the loop by two ops, `linalg.mul` and `linalg.exp`. The function returns the
// mul result first and the exp result second.
// (The previously declared f32 zero constant had no uses and was dropped.)
578 func.func @fuse_add_multiple_tilable_consumers(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
579 %c0 = arith.constant 0 : index
580 %c64 = arith.constant 64 : index
581 %c256 = arith.constant 256 : index
583 %dest0 = tensor.empty() : tensor<256x256xf32>
584 %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
585 %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
586 %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
587 %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
588 %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
589 %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
590 scf.yield %insert_slice : tensor<256x256xf32>
592 %4 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
593 %5 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
594 return %4, %5 : tensor<256x256xf32>, tensor<256x256xf32>
// Transform script: match `tensor.insert_slice` ops in the payload and invoke
// the test consumer-fusion op with `num_consumer_to_fuse = 2`.
598 module attributes {transform.with_named_sequence} {
599 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
600 %slice_op = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
601 : (!transform.any_op) -> !transform.any_op
602 %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 2
603 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: a single loop with three iter_args that contains the tiled
// add, exp and mul; the return uses loop results #2 (mul) and #1 (exp).
607 // CHECK: func.func @fuse_add_multiple_tilable_consumers(
608 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x256xf32>
609 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<256x256xf32>
610 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// Bind the zero index constant here so the loop lower-bound match below does
// not depend on a capture left over from an earlier test case in this file.
// CHECK: %[[C0:.*]] = arith.constant 0 : index
611 // CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
612 // CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
613 // CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
615 // CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
616 // CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
617 // CHECK: %[[ADD_INS1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[IV1]], 0] [64, 256] [1, 1]
618 // CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
619 // CHECK-SAME: ins(%[[ADD_INS0_SLICE]], %[[ADD_INS1_SLICE]] :
620 // CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
621 // CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
622 // CHECK: %[[EXP_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
623 // CHECK: %[[TILED_EXP_OUT:.*]] = linalg.exp
624 // CHECK-SAME: ins(%[[TILED_ADD_OUT]] :
625 // CHECK-SAME: outs(%[[EXP_OUT_SLICE]] :
626 // CHECK: %[[MUL_INS2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], 0] [64, 256] [1, 1]
627 // CHECK: %[[MUL_OUT_SLICE:.*]] = tensor.extract_slice %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
628 // CHECK: %[[TILED_MUL_OUT:.*]] = linalg.mul
629 // CHECK-SAME: ins(%[[TILED_ADD_OUT]], %[[MUL_INS2_SLICE]] :
630 // CHECK-SAME: outs(%[[MUL_OUT_SLICE]] :
631 // CHECK: %[[INSERT_EXP:.*]] = tensor.insert_slice %[[TILED_EXP_OUT]] into %[[SECOND_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
632 // CHECK: %[[INSERT_MUL:.*]] = tensor.insert_slice %[[TILED_MUL_OUT]] into %[[THIRD_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
633 // CHECK: scf.yield %[[INSERT_ADD]], %[[INSERT_EXP]], %[[INSERT_MUL]] :
635 // CHECK: return %[[LOOP_RESULT]]#2, %[[LOOP_RESULT]]#1 :
// Test input: a row-tiled `linalg.add` loop whose result %1 has two users after
// the loop: a `tensor.insert_slice` into a larger 258x258 tensor, and a
// `linalg.mul`. The function returns the mul result and the inserted tensor.
// (The previously declared f32 zero constant had no uses and was dropped.)
640 func.func @no_fuse_only_dps_consumer(%arg0: tensor<256x256xf32>, %arg1: tensor<256x256xf32>, %arg2: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<258x258xf32>) {
641 %c0 = arith.constant 0 : index
642 %c64 = arith.constant 64 : index
643 %c256 = arith.constant 256 : index
645 %dest0 = tensor.empty() : tensor<256x256xf32>
646 %1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest0) -> (tensor<256x256xf32>) {
647 %extracted_slice_1 = tensor.extract_slice %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
648 %extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
649 %extracted_slice_3 = tensor.extract_slice %arg1[%arg3, 0] [64, 256] [1, 1] : tensor<256x256xf32> to tensor<64x256xf32>
650 %3 = linalg.add ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x256xf32>, tensor<64x256xf32>) outs(%extracted_slice_1 : tensor<64x256xf32>) -> tensor<64x256xf32>
651 %insert_slice = tensor.insert_slice %3 into %arg4[%arg3, 0] [64, 256] [1, 1] : tensor<64x256xf32> into tensor<256x256xf32>
652 scf.yield %insert_slice : tensor<256x256xf32>
654 %dest1 = tensor.empty() : tensor<258x258xf32>
655 %4 = tensor.insert_slice %1 into %dest1[0, 0] [256, 256] [1, 1] : tensor<256x256xf32> into tensor<258x258xf32>
656 %5 = linalg.mul ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
657 return %5, %4 : tensor<256x256xf32>, tensor<258x258xf32>
// Transform script: match `tensor.insert_slice` ops, split the handle in two,
// and run the test consumer-fusion op on the first handle only, with
// `num_consumer_to_fuse = 1`. The second handle (%other_slice) is left unused.
// NOTE(review): the first handle is presumably the insert_slice inside the
// loop (match order) - confirm.
661 module attributes {transform.with_named_sequence} {
662 transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
663 %slice_ops = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
664 : (!transform.any_op) -> !transform.any_op
665 %slice_op, %other_slice = transform.split_handle %slice_ops : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
666 %a, %b = transform.test.fuse_consumer %slice_op num_consumer_to_fuse = 1
667 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// Expected output: the loop produces two results, a `tensor.insert_slice`
// still appears after the loop is opened, and the function returns loop
// result #1 together with the value of that insert_slice.
671 // CHECK: func.func @no_fuse_only_dps_consumer(
672 // CHECK: %[[LOOP_RESULT:.*]]:2 = scf.for {{.*}} {
677 // CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
678 // CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]