// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=MAXIMAL

// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
// Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir
// Part IV of fusion tests in mlir/test/Transforms/loop-fusion-4.mlir

// CHECK-LABEL: func @should_fuse_at_depth_above_loop_carried_dependence(%{{.*}}: memref<64x4xf32>, %{{.*}}: memref<64x4xf32>) {
func.func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4xf32>, %arg1: memref<64x4xf32>) {
  %out = memref.alloc() : memref<64x4xf32>
  %0 = arith.constant 0.0 : f32
  affine.for %i0 = 0 to 64 {
    affine.for %i1 = 0 to 4 {
      affine.store %0, %out[%i0, %i1] : memref<64x4xf32>
    }
  }
  affine.for %i2 = 0 to 4 {
    affine.for %i3 = 0 to 4 {
      affine.for %i4 = 0 to 16 {
        %v = affine.load %arg1[16 * %i3 - %i4 + 15, %i2] : memref<64x4xf32>
        "op0"(%v) : (f32) -> ()
      }
      affine.for %i5 = 0 to 4 {
        affine.for %i6 = 0 to 16 {
          %v = affine.load %arg0[16 * %i5 - %i6 + 15, %i3] : memref<64x4xf32>
          "op1"(%v) : (f32) -> ()
        }
        affine.for %i7 = 0 to 16 {
          %r = "op2"() : () -> (f32)
          %v = affine.load %out[16 * %i5 + %i7, %i2] : memref<64x4xf32>
          %s = arith.addf %v, %r : f32
          affine.store %s, %out[16 * %i5 + %i7, %i2] : memref<64x4xf32>
        }
      }
    }
  }

  // We can fuse source loop nest '%i0' into dst loop nest '%i2', but the
  // depth at which we can insert the src loop nest slice into the dst loop
  // nest must be decreased because of a loop carried dependence on loop '%i3'.
  // As a result, the source loop nest is inserted at dst loop nest depth 1,
  // just above the loop with the carried dependence. In addition, the source
  // loop nest iteration bounds on its loop '%i1' are reduced to 1, so the
  // memref size can be reduced to 64x1xf32.
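
  // Illustrative sketch (not verified by FileCheck): had the slice been placed
  // at depth 2, i.e. inside '%i3', the recomputation of '%out' would clobber
  // the partial sums that later '%i3' iterations still read through '%i7':
  //
  //   affine.for %i2 = 0 to 4 {
  //     affine.for %i3 = 0 to 4 {     // dependence carried on this loop
  //       <src slice rewriting %out>  // illegal here
  //       ...
  //
  // Hence the slice is inserted at depth 1, directly under '%i2'.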
  // CHECK:       memref.alloc() : memref<64x1xf32>
  // CHECK:       affine.for %{{.*}} = 0 to 4 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 64 {
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 4 {
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 16 {
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}} * 16 - %{{.*}} + 15, %{{.*}}] : memref<64x4xf32>
  // CHECK-NEXT:        "op0"(%{{.*}}) : (f32) -> ()
  // CHECK-NEXT:      }
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 4 {
  // CHECK-NEXT:        affine.for %{{.*}} = 0 to 16 {
  // CHECK-NEXT:          affine.load %{{.*}}[%{{.*}} * 16 - %{{.*}} + 15, %{{.*}}] : memref<64x4xf32>
  // CHECK-NEXT:          "op1"(%{{.*}}) : (f32) -> ()
  // CHECK-NEXT:        }
  // CHECK-NEXT:        affine.for %{{.*}} = 0 to 16 {
  // CHECK-NEXT:          %{{.*}} = "op2"() : () -> f32
  // CHECK:               affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
  // CHECK-NEXT:          arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK:               affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
  // CHECK-NEXT:        }
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_only_two_loops_and_remove_producer() {
func.func @should_fuse_only_two_loops_and_remove_producer() {
  %a = memref.alloc() : memref<10xf32>
  %b = memref.alloc() : memref<10xf32>

  %cf7 = arith.constant 7.0 : f32

  affine.for %i0 = 0 to 10 {
    affine.store %cf7, %a[%i0] : memref<10xf32>
  }
  affine.for %i1 = 0 to 10 {
    %v0 = affine.load %a[%i1] : memref<10xf32>
    affine.store %v0, %b[%i1] : memref<10xf32>
  }
  affine.for %i2 = 0 to 10 {
    %v1 = affine.load %a[%i2] : memref<10xf32>
    affine.store %v1, %b[%i2] : memref<10xf32>
  }

  // On the first visit to '%i2', the fusion algorithm cannot fuse loop nest
  // '%i0' into '%i2' because of the dependences that '%i0' and '%i2' each have
  // on '%i1'. Then, '%i0' is fused into '%i1', and no private memref is
  // created for memref '%a', so that '%i0' can be removed while the dependence
  // of '%i2' on '%a' is still preserved.
  // TODO: Alternatively, we could fuse '%i0' into '%i1' with a private memref
  // (as shown in the sketch below); the dependence between '%i0' and '%i1' on
  // memref '%a' would then no longer exist, and '%i0' could be fused into
  // '%i2' as well. Note that this approach would duplicate the computation in
  // loop nest '%i0' into loop nests '%i1' and '%i2', which would limit its
  // profitability.
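
  // Hypothetical sketch of the TODO alternative (not what the pass emits):
  // with a private memref '%a_priv' (a name invented for illustration), '%i0'
  // would be duplicated into both consumers:
  //
  //   affine.for %i1 = 0 to 10 {
  //     affine.store %cf7, %a_priv[0] : memref<1xf32>  // duplicated '%i0' body
  //     %v = affine.load %a_priv[0] : memref<1xf32>
  //     affine.store %v, %b[%i1] : memref<10xf32>
  //   }
  //   // ... and likewise for '%i2'.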
  // CHECK:       affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:  }
  // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_after_one_loop_interchange() {
func.func @should_fuse_after_one_loop_interchange() {
  %a = memref.alloc() : memref<10xf32>

  %cf0 = arith.constant 0.0 : f32
  affine.for %i0 = 0 to 10 {
    affine.store %cf0, %a[%i0] : memref<10xf32>
  }

  affine.for %i1 = 0 to 5 {
    affine.for %i2 = 0 to 10 {
      %v0 = affine.load %a[%i2] : memref<10xf32>
      affine.store %v0, %a[%i2] : memref<10xf32>
    }
  }

  // The dependence between the load and the store is carried on loop '%i1',
  // so loop nest '%i1' cannot be fused with loop '%i0' without violating this
  // dependence. Once loops '%i1' and '%i2' are interchanged, loop '%i0' can
  // be fused at loop depth 1, because the loop carrying the dependence has
  // been interchanged and is now at depth 2.
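
  // Conceptually (sketch only), the interchanged destination nest is:
  //
  //   affine.for %i2 = 0 to 10 {    // now outermost
  //     affine.for %i1 = 0 to 5 {   // carries the dependence, now at depth 2
  //       %v0 = affine.load %a[%i2] : memref<10xf32>
  //       affine.store %v0, %a[%i2] : memref<10xf32>
  //     }
  //   }
  //
  // which leaves depth 1 free for the '%i0' slice.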
  // CHECK:       affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 5 {
  // CHECK-NEXT:      affine.load %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_after_two_loop_interchanges() {
func.func @should_fuse_after_two_loop_interchanges() {
  %a = memref.alloc() : memref<6x8xf32>

  %cf0 = arith.constant 0.0 : f32
  affine.for %i0 = 0 to 6 {
    affine.for %i1 = 0 to 8 {
      affine.store %cf0, %a[%i0, %i1] : memref<6x8xf32>
    }
  }

  affine.for %i2 = 0 to 4 {
    affine.for %i3 = 0 to 6 {
      affine.for %i4 = 0 to 2 {
        affine.for %i5 = 0 to 8 {
          %v0 = affine.load %a[%i3, %i5] : memref<6x8xf32>
          %v1 = arith.addf %v0, %v0 : f32
          affine.store %v1, %a[%i3, %i5] : memref<6x8xf32>
        }
      }
    }
  }

  // The dependence between the load and the store is carried on loops '%i2'
  // and '%i4', so the nest cannot be fused with loop '%i0' without violating
  // this dependence. Once loop '%i2' is interchanged with loop '%i3', and
  // again with loop '%i5', loop '%i0' can be fused at loop depth 2, because
  // the loops carrying the dependences have been interchanged to depths
  // greater than 2.
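
  // Sketch of the resulting loop order (illustration only): after the two
  // interchanges the dst nest is ordered (%i3, %i5, %i2, %i4), so both
  // dependence-carrying loops sit at depths 3 and 4, and the '%i0' slice can
  // be inserted at depth 2 under (%i3, %i5):
  //
  //   affine.for %i3 = 0 to 6 {          // interchanged outward
  //     affine.for %i5 = 0 to 8 {        // interchanged outward
  //       affine.for %i2 = 0 to 4 {      // carries dependence, now depth 3
  //         affine.for %i4 = 0 to 2 { ... }  // carries dependence, now depth 4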
  // CHECK:       affine.for %{{.*}} = 0 to 6 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 8 {
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 4 {
  // CHECK-NEXT:        affine.for %{{.*}} = 0 to 2 {
  // CHECK-NEXT:          affine.load %{{.*}}[0, 0] : memref<1x1xf32>
  // CHECK-NEXT:          arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:          affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
  // CHECK-NEXT:        }
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_live_out_writer
func.func @should_fuse_live_out_writer(%arg0 : memref<10xf32>) -> memref<10xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  affine.for %i0 = 0 to 10 {
    affine.store %cst, %arg0[%i0] : memref<10xf32>
  }
  affine.for %i1 = 0 to 10 {
    %1 = affine.load %arg0[%i1] : memref<10xf32>
    affine.store %1, %arg0[%i1] : memref<10xf32>
  }
  return %arg0 : memref<10xf32>
}

// CHECK:      %{{.*}} = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return %{{.*}} : memref<10xf32>

// -----

// The fused slice has 16 iterations along '%i0'.

// CHECK-DAG: [[$MAP_LB:#map[0-9]*]] = affine_map<(d0) -> (d0 * 16)>
// CHECK-DAG: [[$MAP_UB:#map[0-9]*]] = affine_map<(d0) -> (d0 * 16 + 16)>
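
// For example (illustration only): at dst IV %i = 1 the slice loop runs from
// [[$MAP_LB]](1) = 16 to [[$MAP_UB]](1) = 32, i.e. the 16 rows of tile 1.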

// CHECK-LABEL: slice_tile
func.func @slice_tile(%arg0: memref<128x8xf32>, %arg1: memref<32x8xf32>, %0 : f32) -> memref<32x8xf32> {
  affine.for %i0 = 0 to 32 {
    affine.for %i1 = 0 to 8 {
      affine.store %0, %arg1[%i0, %i1] : memref<32x8xf32>
    }
  }
  affine.for %i = 0 to 2 {
    affine.for %j = 0 to 8 {
      affine.for %k = 0 to 8 {
        affine.for %kk = 0 to 16 {
          %v = affine.load %arg0[16 * %k + %kk, %j] : memref<128x8xf32>
          %r = "foo"(%v) : (f32) -> f32
        }
        affine.for %ii = 0 to 16 {
          %v = affine.load %arg1[16 * %i + %ii, %j] : memref<32x8xf32>
          %s = arith.addf %v, %v : f32
          affine.store %s, %arg1[16 * %i + %ii, %j] : memref<32x8xf32>
        }
      }
    }
  }
  return %arg1 : memref<32x8xf32>
}

// CHECK:       affine.for %{{.*}} = 0 to 2 {
// CHECK-NEXT:    affine.for %{{.*}} = 0 to 8 {
// CHECK-NEXT:      affine.for %{{.*}} = [[$MAP_LB]](%{{.*}}) to [[$MAP_UB]](%{{.*}}) {
// CHECK-NEXT:        affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<32x8xf32>
// CHECK-NEXT:      }
// CHECK-NEXT:      affine.for %{{.*}} = 0 to 8 {
// CHECK-NEXT:        affine.for %{{.*}} = 0 to 16 {
// CHECK-NEXT:          affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<128x8xf32>
// CHECK-NEXT:          "foo"(%{{.*}}) : (f32) -> f32
// CHECK-NEXT:        }
// CHECK-NEXT:        affine.for %{{.*}} = 0 to 16 {
// CHECK-NEXT:          affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<32x8xf32>
// CHECK-NEXT:          arith.addf %{{.*}}, %{{.*}} : f32
// CHECK-NEXT:          affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<32x8xf32>
// CHECK-NEXT:        }
// CHECK-NEXT:      }
// CHECK-NEXT:    }
// CHECK-NEXT:  }
// CHECK-NEXT:  return %{{.*}} : memref<32x8xf32>

// -----

// Test case which illustrates fix for b/126454413
// CHECK-LABEL: func @test_add_slice_bounds
func.func @test_add_slice_bounds() {
  %a = memref.alloc() : memref<10xf32>
  %b = memref.alloc() : memref<10xf32>
  %cf7 = arith.constant 7.0 : f32
  %c0 = arith.constant 0 : index

  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      affine.for %i2 = 0 to 10 {
        %a0 = affine.apply affine_map<(d0) -> (d0)> (%i0)
        %a1 = affine.apply affine_map<(d0) -> (d0)> (%i0)
        %a2 = affine.apply affine_map<(d0, d1) -> (d0 - d1)> (%a0, %a1)
        affine.store %cf7, %a[%a2] : memref<10xf32>
      }
    }
  }
  affine.for %i3 = 0 to 10 {
    affine.for %i4 = 0 to 10 {
      affine.for %i5 = 0 to 10 {
        %v0 = affine.load %a[%c0] : memref<10xf32>
      }
    }
  }

  // CHECK:       affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:        affine.apply #map(%{{.*}})
  // CHECK-NEXT:        affine.apply #map(%{{.*}})
  // CHECK-NEXT:        affine.apply #map1(%{{.*}}, %{{.*}})
  // CHECK-NEXT:        affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  return
}

// -----

// CHECK-LABEL: func @should_fuse_init_loops_siblings_then_shared_producer
func.func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) {
  %0 = memref.alloc() : memref<10x10xf32>
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 7.000000e+00 : f32
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      affine.store %cst_1, %0[%i0, %i1] : memref<10x10xf32>
    }
  }
  affine.for %i2 = 0 to 3 {
    affine.for %i3 = 0 to 3 {
      affine.store %cst, %arg0[%i2, %i3] : memref<10x10xf32>
    }
  }
  affine.for %i4 = 0 to 3 {
    affine.for %i5 = 0 to 3 {
      %1 = affine.load %0[%i4, %i5] : memref<10x10xf32>
      %2 = affine.load %arg0[%i4, %i5] : memref<10x10xf32>
      %3 = arith.mulf %1, %2 : f32
      affine.store %3, %arg0[%i4, %i5] : memref<10x10xf32>
    }
  }
  affine.for %i6 = 0 to 3 {
    affine.for %i7 = 0 to 3 {
      affine.store %cst_0, %arg1[%i6, %i7] : memref<10x10xf32>
    }
  }
  affine.for %i8 = 0 to 3 {
    affine.for %i9 = 0 to 3 {
      %4 = affine.load %0[%i8, %i9] : memref<10x10xf32>
      %5 = affine.load %arg1[%i8, %i9] : memref<10x10xf32>
      %6 = arith.addf %4, %5 : f32
      affine.store %6, %arg1[%i8, %i9] : memref<10x10xf32>
    }
  }

  // Pass 1: should fuse single-use producer loop nests into their unique user,
  // so '%i2' will fuse into '%i4' and '%i6' will fuse into '%i8'.
  // Pass 2: should fuse sibling loop nests which share no dependence edges,
  // so it should fuse '%i4' into '%i8'.
  // Pass 3: should fuse single-use producer loop nest '%i0' into '%i8'. Note
  // that loop nest '%i0' now has a single user after Pass 2 fused its
  // two users together.
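
  // Rough pass-by-pass sketch (illustration, not FileCheck'd):
  //   after Pass 1:  {%i0}, {%i2 + %i4}, {%i6 + %i8}
  //   after Pass 2:  {%i0}, {%i2 + %i4 + %i6 + %i8}
  //   after Pass 3:  a single fused nest, with the '%i0' values privatized
  //                  into the memref<1x1xf32> buffer checked below.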
  // CHECK:       affine.for %{{.*}} = 0 to 3 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 3 {
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[0, 0] : memref<1x1xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
  // CHECK-NEXT:      arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[0, 0] : memref<1x1xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
  // CHECK-NEXT:      arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @two_matrix_vector_products
func.func @two_matrix_vector_products() {
  %in_matrix = memref.alloc() : memref<10x10xf32>
  %in_vec0 = memref.alloc() : memref<10xf32>
  %in_vec1 = memref.alloc() : memref<10xf32>
  %out_vec0 = memref.alloc() : memref<10xf32>
  %out_vec1 = memref.alloc() : memref<10xf32>
  %cf7 = arith.constant 7.0 : f32

  // Populate input matrix.
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      affine.store %cf7, %in_matrix[%i0, %i1] : memref<10x10xf32>
    }
  }
  // out_vec0 = in_matrix x in_vec0
  affine.for %i2 = 0 to 10 {
    affine.for %i3 = 0 to 10 {
      %v0 = affine.load %in_matrix[%i2, %i3] : memref<10x10xf32>
      %v1 = affine.load %in_vec0[%i3] : memref<10xf32>
      %v2 = arith.mulf %v0, %v1 : f32
      %v3 = affine.load %out_vec0[%i3] : memref<10xf32>
      %v4 = arith.addf %v2, %v3 : f32
      affine.store %v4, %out_vec0[%i3] : memref<10xf32>
    }
  }
  // out_vec1 = in_matrix x in_vec1
  affine.for %i4 = 0 to 10 {
    affine.for %i5 = 0 to 10 {
      %v5 = affine.load %in_matrix[%i4, %i5] : memref<10x10xf32>
      %v6 = affine.load %in_vec1[%i5] : memref<10xf32>
      %v7 = arith.mulf %v5, %v6 : f32
      %v8 = affine.load %out_vec1[%i5] : memref<10xf32>
      %v9 = arith.addf %v7, %v8 : f32
      affine.store %v9, %out_vec1[%i5] : memref<10xf32>
    }
  }

  // CHECK:       affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<10x1xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}, 0] : memref<10x1xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:      arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:      arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}, 0] : memref<10x1xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:      arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:      arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_not_slice_past_slice_barrier
func.func @should_not_slice_past_slice_barrier() {
  %0 = memref.alloc() : memref<100x16xf32>
  affine.for %i0 = 0 to 100 {
    affine.for %i1 = 0 to 16 {
      %1 = "op1"() : () -> f32
      affine.store %1, %0[%i0, %i1] : memref<100x16xf32>
    } {slice_fusion_barrier = true}
  }
  affine.for %i2 = 0 to 100 {
    affine.for %i3 = 0 to 16 {
      %2 = affine.load %0[%i2, %i3] : memref<100x16xf32>
      "op2"(%2) : (f32) -> ()
    }
  }

  // The 'slice_fusion_barrier' attribute on '%i1' prevents slicing the
  // iteration space of '%i1' and any enclosing loop nests.
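
  // For contrast, a hypothetical no-barrier result (invented for illustration,
  // with '%priv' a made-up name): slicing at depth 2 would privatize per
  // iteration instead:
  //
  //   affine.for %i2 = 0 to 100 {
  //     affine.for %i3 = 0 to 16 {
  //       %1 = "op1"() : () -> f32
  //       affine.store %1, %priv[0, 0] : memref<1x1xf32>
  //       %2 = affine.load %priv[0, 0] : memref<1x1xf32>
  //       "op2"(%2) : (f32) -> ()
  //     }
  //   }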
  // CHECK:       affine.for %{{.*}} = 0 to 100 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 16 {
  // CHECK-NEXT:      %{{.*}} = "op1"() : () -> f32
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[0, %{{.*}}] : memref<1x16xf32>
  // CHECK-NEXT:    } {slice_fusion_barrier = true}
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 16 {
  // CHECK-NEXT:      affine.load %{{.*}}[0, %{{.*}}] : memref<1x16xf32>
  // CHECK-NEXT:      "op2"(%{{.*}}) : (f32) -> ()
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

#map = affine_map<(d0, d1) -> (d0 * 16 + d1)>
func.func @fuse_across_dim_mismatch(%arg0: memref<4x4x16x1xf32>, %arg1: memref<144x9xf32>, %arg2: memref<9xf32>) {
  %1 = memref.alloc() : memref<144x4xf32>
  %2 = arith.constant 0.0 : f32
  affine.for %i2 = 0 to 9 {
    affine.for %i3 = 0 to 4 {
      affine.for %i5 = 0 to 16 {
        %7 = affine.apply #map(%i2, %i5)
        affine.store %2, %1[%7, %i3] : memref<144x4xf32>
      }
    }
  }
  affine.for %i6 = 0 to 9 {
    affine.for %i7 = 0 to 9 {
      affine.for %i8 = 0 to 4 {
        affine.for %i10 = 0 to 16 {
          %10 = affine.apply #map(%i6, %i10)
          %11 = affine.load %1[%10, %i8] : memref<144x4xf32>
        }
      }
    }
  }
  return
}

// MAXIMAL:       #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
// MAXIMAL-LABEL: func @fuse_across_dim_mismatch
// MAXIMAL:        memref.alloc() : memref<1x1xf32>
// MAXIMAL:        affine.for %{{.*}} = 0 to 9 {
// MAXIMAL-NEXT:     affine.for %{{.*}} = 0 to 9 {
// MAXIMAL-NEXT:       affine.for %{{.*}} = 0 to 4 {
// MAXIMAL-NEXT:         affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT:           affine.apply #[[$MAP]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:           affine.store %{{.*}}, %{{.*}}[0, 0] : memref<1x1xf32>
// MAXIMAL-NEXT:           affine.apply #[[$MAP]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:           affine.load %{{.*}}[0, 0] : memref<1x1xf32>
// MAXIMAL-NEXT:         }
// MAXIMAL-NEXT:       }
// MAXIMAL-NEXT:     }
// MAXIMAL-NEXT:   }

// -----

#map3 = affine_map<(d0, d1) -> ((d0 * 72 + d1) floordiv 2304)>
#map4 = affine_map<(d0, d1) -> (((d0 * 72 + d1) mod 2304) floordiv 1152)>
#map5 = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) floordiv 9) floordiv 8)>
#map6 = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) floordiv 3)>
#map7 = affine_map<(d0, d1) -> (((((d0 * 72 + d1) mod 2304) mod 1152) mod 9) mod 3)>
#map10 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
#map11 = affine_map<(d0, d1) -> (d0 * 16 + d1)>
#map12 = affine_map<(d0, d1) -> (d0 * 16 - d1 + 15)>
func.func @fuse_across_varying_dims_complex(%arg0: f32) {
  %c0 = arith.constant 0 : index
  %0 = memref.alloc() : memref<2x2x3x3x16x1xf32>
  %1 = memref.alloc() : memref<64x9xf32>
  %2 = memref.alloc() : memref<144x4xf32>
  affine.for %i0 = 0 to 64 {
    affine.for %i1 = 0 to 9 {
      %4 = affine.apply #map3(%i0, %i1)
      %5 = affine.apply #map4(%i0, %i1)
      %6 = affine.apply #map5(%i0, %i1)
      %7 = affine.apply #map6(%i0, %i1)
      %8 = affine.apply #map7(%i0, %i1)
      %9 = affine.load %0[%4, %5, %7, %8, %6, %c0] : memref<2x2x3x3x16x1xf32>
      affine.store %9, %1[%i0, %i1] : memref<64x9xf32>
    }
  }
  affine.for %i2 = 0 to 9 {
    affine.for %i3 = 0 to 4 {
      affine.for %i4 = 0 to 16 {
        %10 = affine.apply #map10(%i3, %i4)
        %11 = affine.load %1[%10, %i2] : memref<64x9xf32>
      }
      affine.for %i5 = 0 to 16 {
        %14 = affine.apply #map11(%i2, %i5)
        affine.store %arg0, %2[%14, %i3] : memref<144x4xf32>
      }
    }
  }
  affine.for %i6 = 0 to 9 {
    affine.for %i7 = 0 to 9 {
      affine.for %i8 = 0 to 4 {
        affine.for %i9 = 0 to 16 {
          %15 = affine.apply #map12(%i8, %i9)
          %16 = affine.load %1[%15, %i7] : memref<64x9xf32>
        }
      }
    }
  }
  return
}

// MAXIMAL-DAG: [[$MAP0:#map[0-9]*]] = affine_map<(d0, d1) -> ((d0 * 72 + d1) floordiv 2304)>
// MAXIMAL-DAG: [[$MAP1:#map[0-9]*]] = affine_map<(d0, d1) -> (((d0 * 72 + d1) mod 2304) floordiv 1152)>
// MAXIMAL-DAG: [[$MAP2:#map[0-9]*]] = affine_map<(d0, d1) -> ((((d0 * 72 + d1) mod 1152) floordiv 9) floordiv 8)>
// MAXIMAL-DAG: [[$MAP3:#map[0-9]*]] = affine_map<(d0, d1) -> ((d1 mod 9) floordiv 3)>
// MAXIMAL-DAG: [[$MAP4:#map[0-9]*]] = affine_map<(d0, d1) -> (d1 mod 3)>
// MAXIMAL-DAG: [[$MAP7:#map[0-9]*]] = affine_map<(d0, d1) -> (d0 * 16 + d1)>
// MAXIMAL-DAG: [[$MAP8:#map[0-9]*]] = affine_map<(d0, d1) -> (d0 * 16 - d1 + 15)>
// MAXIMAL-LABEL: func @fuse_across_varying_dims_complex
// MAXIMAL-NEXT:  memref.alloc() : memref<64x1xf32>
// MAXIMAL-NEXT:  arith.constant 0 : index
// MAXIMAL-NEXT:  memref.alloc() : memref<2x2x3x3x16x1xf32>
// MAXIMAL-NEXT:  memref.alloc() : memref<144x4xf32>
// MAXIMAL-NEXT:  affine.for %{{.*}} = 0 to 9 {
// MAXIMAL-NEXT:    affine.for %{{.*}} = 0 to 4 {
// MAXIMAL-NEXT:      affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT:        affine.for %{{.*}} = 0 to 64 {
// MAXIMAL-NEXT:          affine.apply [[$MAP0]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:          affine.apply [[$MAP1]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:          affine.apply [[$MAP2]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:          affine.apply [[$MAP3]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:          affine.apply [[$MAP4]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:          affine.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<2x2x3x3x16x1xf32>
// MAXIMAL-NEXT:          affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32>
// MAXIMAL-NEXT:        }
// MAXIMAL-NEXT:        affine.apply [[$MAP7]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:        affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
// MAXIMAL-NEXT:        affine.for %{{.*}} = 0 to 9 {
// MAXIMAL-NEXT:          affine.for %{{.*}} = 0 to 4 {
// MAXIMAL-NEXT:            affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT:              affine.apply [[$MAP8]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:              affine.load %{{.*}}[%{{.*}} * 16 - %{{.*}} + 15, 0] : memref<64x1xf32>
// MAXIMAL-NEXT:            }
// MAXIMAL-NEXT:          }
// MAXIMAL-NEXT:        }
// MAXIMAL-NEXT:        affine.for %{{.*}} = 0 to 16 {
// MAXIMAL-NEXT:          affine.apply [[$MAP7]](%{{.*}}, %{{.*}})
// MAXIMAL-NEXT:          affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<144x4xf32>
// MAXIMAL-NEXT:        }
// MAXIMAL-NEXT:      }
// MAXIMAL-NEXT:    }
// MAXIMAL-NEXT:  }

// -----

// CHECK-LABEL: func @should_fuse_with_slice_union
func.func @should_fuse_with_slice_union() {
  %a = memref.alloc() : memref<100xf32>
  %c0 = arith.constant 0 : index
  %cf0 = arith.constant 0.0 : f32

  affine.for %i0 = 0 to 100 {
    affine.store %cf0, %a[%i0]: memref<100xf32>
  }

  affine.for %i1 = 10 to 20 {
    %v0 = affine.load %a[%i1]: memref<100xf32>
    affine.for %i2 = 15 to 25 {
      %v1 = affine.load %a[%i2]: memref<100xf32>
    }
  }

  // The union of the two slice bounds (one computed between the store and each
  // of the loads) is used in the fusion cost calculation, in index remapping,
  // and to size the private memref. As a result, the temporary memref shrinks
  // from 100xf32 to 15xf32 and is indexed by the fused loops according to the
  // union bounds.
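
  // Worked bounds (illustration only): the slice for the '%i1' load spans
  // [10, 20) and the slice for the '%i2' load spans [15, 25); their union is
  // [10, 25), i.e. 15 elements. Hence memref<15xf32>, with every access
  // remapped by subtracting the union's lower bound of 10.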
  // CHECK:       affine.for %{{.*}} = 10 to 20 {
  // CHECK-NEXT:    affine.for %{{.*}} = 10 to 25 {
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}} - 10] : memref<15xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}} - 10] : memref<15xf32>
  // CHECK-NEXT:    affine.for %{{.*}} = 15 to 25 {
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}} - 10] : memref<15xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @affine_add_mm_fused
func.func @affine_add_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>) {
  affine.for %i2 = 0 to 1024 {
    affine.for %i3 = 0 to 1024 {
      %0 = affine.load %arg3[%i2, %i3] : memref<1024x1024xf32>
      %1 = affine.load %arg2[%i2, %i3] : memref<1024x1024xf32>
      %2 = arith.addf %1, %0 : f32
      affine.store %2, %arg2[%i2, %i3] : memref<1024x1024xf32>
    }
  }
  affine.for %i4 = 0 to 1024 {
    affine.for %i5 = 0 to 1024 {
      affine.for %i6 = 0 to 1024 {
        %3 = affine.load %arg1[%i6, %i5] : memref<1024x1024xf32>
        %4 = affine.load %arg0[%i4, %i6] : memref<1024x1024xf32>
        %5 = arith.mulf %4, %3 : f32
        %6 = affine.load %arg2[%i4, %i5] : memref<1024x1024xf32>
        %7 = arith.addf %6, %5 : f32
        affine.store %7, %arg2[%i4, %i5] : memref<1024x1024xf32>
      }
    }
  }

  // Should fuse the elementwise add loop nest at loop depth 2, above the
  // loop-carried dependence between the load/store on '%arg2' carried by
  // reduction loop '%i6'.
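
  // Sketch of why depth 2 (illustration only): at depth 3 the add would re-run
  // inside the reduction and clobber the partial sums in '%arg2'; at depth 2
  // it runs once per (i, j) before the '%i6' reduction:
  //
  //   affine.for %i = 0 to 1024 {
  //     affine.for %j = 0 to 1024 {
  //       <elementwise add into %arg2[%i, %j]>
  //       affine.for %k = 0 to 1024 { <matmul accumulation into %arg2[%i, %j]> }
  //     }
  //   }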
  // CHECK:       affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @affine_2mm_fused
func.func @affine_2mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) {
  %cst = arith.constant 0.000000e+00 : f32
  affine.for %i0 = 0 to 1024 {
    affine.for %i1 = 0 to 1024 {
      affine.store %cst, %arg2[%i0, %i1] : memref<1024x1024xf32>
    }
  }
  affine.for %i2 = 0 to 1024 {
    affine.for %i3 = 0 to 1024 {
      affine.store %cst, %arg4[%i2, %i3] : memref<1024x1024xf32>
    }
  }
  affine.for %i4 = 0 to 1024 {
    affine.for %i5 = 0 to 1024 {
      affine.for %i6 = 0 to 1024 {
        %0 = affine.load %arg1[%i6, %i5] : memref<1024x1024xf32>
        %1 = affine.load %arg0[%i4, %i6] : memref<1024x1024xf32>
        %2 = arith.mulf %1, %0 : f32
        %3 = affine.load %arg2[%i4, %i5] : memref<1024x1024xf32>
        %4 = arith.addf %3, %2 : f32
        affine.store %4, %arg2[%i4, %i5] : memref<1024x1024xf32>
      }
    }
  }
  affine.for %i7 = 0 to 1024 {
    affine.for %i8 = 0 to 1024 {
      affine.for %i9 = 0 to 1024 {
        %5 = affine.load %arg1[%i9, %i8] : memref<1024x1024xf32>
        %6 = affine.load %arg0[%i7, %i9] : memref<1024x1024xf32>
        %7 = arith.mulf %6, %5 : f32
        %8 = affine.load %arg4[%i7, %i8] : memref<1024x1024xf32>
        %9 = arith.addf %8, %7 : f32
        affine.store %9, %arg4[%i7, %i8] : memref<1024x1024xf32>
      }
    }
  }

  // Should fuse the MM initialization loops into their consumers, then fuse
  // the two matmul loops together for input reuse on '%arg0' and '%arg1'.
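
  // Resulting shape, roughly (see the CHECK lines below for the exact form):
  //
  //   affine.for %i = 0 to 1024 {   // shared outer loop: reuses rows of %arg0
  //     affine.for %j = 0 to 1024 { <init + matmul 1 into %arg2> }
  //     affine.for %j = 0 to 1024 { <init + matmul 2 into %arg4> }
  //   }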
  // CHECK:       affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @affine_2_dependent_mm_fused
func.func @affine_2_dependent_mm_fused(%arg0: memref<1024x1024xf32>, %arg1: memref<1024x1024xf32>, %arg2: memref<1024x1024xf32>, %arg3: memref<1024x1024xf32>, %arg4: memref<1024x1024xf32>) {
  affine.for %i0 = 0 to 1024 {
    affine.for %i1 = 0 to 1024 {
      affine.for %i2 = 0 to 1024 {
        %0 = affine.load %arg1[%i2, %i1] : memref<1024x1024xf32>
        %1 = affine.load %arg0[%i0, %i2] : memref<1024x1024xf32>
        %2 = arith.mulf %1, %0 : f32
        %3 = affine.load %arg2[%i0, %i1] : memref<1024x1024xf32>
        %4 = arith.addf %3, %2 : f32
        affine.store %4, %arg2[%i0, %i1] : memref<1024x1024xf32>
      }
    }
  }
  affine.for %i3 = 0 to 1024 {
    affine.for %i4 = 0 to 1024 {
      affine.for %i5 = 0 to 1024 {
        %5 = affine.load %arg3[%i5, %i4] : memref<1024x1024xf32>
        %6 = affine.load %arg2[%i3, %i5] : memref<1024x1024xf32>
        %7 = arith.mulf %6, %5 : f32
        %8 = affine.load %arg4[%i3, %i4] : memref<1024x1024xf32>
        %9 = arith.addf %8, %7 : f32
        affine.store %9, %arg4[%i3, %i4] : memref<1024x1024xf32>
      }
    }
  }

  // CHECK:       affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:      affine.for %{{.*}} = 0 to 1024 {
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.mulf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:        arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:        affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<1024x1024xf32>
  // CHECK-NEXT:      }
  // CHECK-NEXT:    }
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_self_dependence_multi_store_producer() {
func.func @should_fuse_self_dependence_multi_store_producer() {
  %m = memref.alloc() : memref<10xf32>
  %local_m = memref.alloc() : memref<10xf32>
  %cf7 = arith.constant 7.0 : f32

  affine.for %i0 = 0 to 10 {
    affine.store %cf7, %local_m[%i0] : memref<10xf32>
    %v0 = affine.load %local_m[%i0] : memref<10xf32>
    affine.store %v0, %m[%i0] : memref<10xf32>
  }
  affine.for %i1 = 0 to 10 {
    %v1 = affine.load %m[%i1] : memref<10xf32>
  }

  // CHECK:       affine.for %[[i0:.*]] = 0 to 10 {
  // CHECK-NEXT:    affine.store %{{.*}}, [[LOCAL_M:%.*]][%[[i0]]] : memref<10xf32>
  // CHECK-NEXT:    [[v0:%.*]] = affine.load [[LOCAL_M]][%[[i0]]] : memref<10xf32>
  // CHECK-NEXT:    affine.store [[v0]], %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_dead_multi_store_producer() {
func.func @should_fuse_dead_multi_store_producer() {
  %m = memref.alloc() : memref<10xf32>
  %dead_m = memref.alloc() : memref<10xf32>
  %cf7 = arith.constant 7.0 : f32

  affine.for %i0 = 0 to 10 {
    affine.store %cf7, %dead_m[%i0] : memref<10xf32>
    affine.store %cf7, %m[%i0] : memref<10xf32>
  }
  affine.for %i1 = 0 to 10 {
    %v0 = affine.load %m[%i1] : memref<10xf32>
  }

  // CHECK:       affine.for %[[i0:.*]] = 0 to 10 {
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// CHECK-LABEL: func @should_fuse_function_live_out_multi_store_producer
func.func @should_fuse_function_live_out_multi_store_producer(%live_in_out_m : memref<10xf32>) {
  %m = memref.alloc() : memref<10xf32>
  %cf7 = arith.constant 7.0 : f32

  affine.for %i0 = 0 to 10 {
    affine.store %cf7, %live_in_out_m[%i0] : memref<10xf32>
    affine.store %cf7, %m[%i0] : memref<10xf32>
  }
  affine.for %i1 = 0 to 10 {
    %v0 = affine.load %m[%i1] : memref<10xf32>
  }

  // CHECK:       affine.for %[[i0:.*]] = 0 to 10 {
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%[[i0]]] : memref<10xf32>
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:    affine.load %{{.*}}[0] : memref<1xf32>
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// Add further tests in mlir/test/Transforms/loop-fusion-4.mlir