1 // RUN: mlir-opt -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' %s | FileCheck %s
3 // Test fusion of affine nests inside other region-holding ops (scf.for and scf.while in the test cases below).
// Two element-wise copy nests (%A -> %B, then %B -> %C) placed inside an outer
// affine.for. The directives after the body expect, inside the 0-to-100 loop,
// the two allocations followed by a single 0-to-10 nest and no further
// affine.for, i.e. the producer and consumer nests end up as one loop.
6 // CHECK-LABEL: func @fusion_inner_simple
7 func.func @fusion_inner_simple(%A : memref<10xf32>) {
// NOTE(review): %cst is not referenced by any op visible in this function.
8 %cst = arith.constant 0.0 : f32
// Outer loop that holds the nests to be fused.
10 affine.for %i = 0 to 100 {
// Buffers local to each outer iteration: %B is the intermediate, %C the result.
11 %B = memref.alloc() : memref<10xf32>
12 %C = memref.alloc() : memref<10xf32>
// Producer nest: copies %A into %B element by element.
14 affine.for %j = 0 to 10 {
15 %v = affine.load %A[%j] : memref<10xf32>
16 affine.store %v, %B[%j] : memref<10xf32>
// Consumer nest: reads back %B and copies it into %C.
19 affine.for %j = 0 to 10 {
20 %v = affine.load %B[%j] : memref<10xf32>
21 affine.store %v, %C[%j] : memref<10xf32>
// Expected output: outer loop, both allocs, then exactly one inner nest.
25 // CHECK: affine.for %{{.*}} = 0 to 100
26 // CHECK-NEXT: memref.alloc
27 // CHECK-NEXT: memref.alloc
28 // CHECK-NEXT: affine.for %{{.*}} = 0 to 10
29 // CHECK-NOT: affine.for
// Same producer/consumer copy nests as in @fusion_inner_simple, but the
// enclosing loop is an scf.for with SSA bounds instead of an affine.for.
// The directives expect two allocations followed by a single 0-to-10 nest.
34 // CHECK-LABEL: func @fusion_inner_simple_scf
35 func.func @fusion_inner_simple_scf(%A : memref<10xf32>) {
// Lower bound, step and upper bound of the enclosing scf.for.
36 %c0 = arith.constant 0 : index
37 %c1 = arith.constant 1 : index
38 %c100 = arith.constant 100 : index
// NOTE(review): %cst is not referenced by any op visible in this function.
39 %cst = arith.constant 0.0 : f32
// Non-affine container op holding the affine nests.
41 scf.for %i = %c0 to %c100 step %c1 {
42 %B = memref.alloc() : memref<10xf32>
43 %C = memref.alloc() : memref<10xf32>
// Producer nest: copies %A into %B.
45 affine.for %j = 0 to 10 {
46 %v = affine.load %A[%j] : memref<10xf32>
47 affine.store %v, %B[%j] : memref<10xf32>
// Consumer nest: copies %B into %C.
50 affine.for %j = 0 to 10 {
51 %v = affine.load %B[%j] : memref<10xf32>
52 affine.store %v, %C[%j] : memref<10xf32>
// NOTE(review): the directive that anchors the first "-NEXT" match below is
// not visible in this excerpt; presumably it matches the scf.for - confirm.
56 // CHECK-NEXT: memref.alloc
57 // CHECK-NEXT: memref.alloc
58 // CHECK-NEXT: affine.for %{{.*}} = 0 to 10
59 // CHECK-NOT: affine.for
// Three sibling nests inside one outer affine.for, interleaved with
// allocations and a dealloc: a column copy out of %alloc_5, a copy of that
// copy, and a multiply-accumulate into %alloc_10 that reads the second copy.
// The directives expect a 0-to-2 nest and a later 0-to-4 nest inside the
// outer loop, followed by the dealloc and the return.
63 // CHECK-LABEL: func @fusion_inner_multiple_nests
64 func.func @fusion_inner_multiple_nests() {
// Buffers that live outside the outer loop.
65 %alloc_5 = memref.alloc() {alignment = 64 : i64} : memref<4x4xi8>
66 %alloc_10 = memref.alloc() : memref<8x4xi32>
67 affine.for %arg8 = 0 to 4 {
68 %alloc_14 = memref.alloc() : memref<4xi8>
// NOTE(review): %alloc_15 is only loaded from in the visible lines; no
// store to it is visible here.
69 %alloc_15 = memref.alloc() : memref<8x4xi8>
// Nest 1: copies column %arg8 of %alloc_5 into %alloc_14.
70 affine.for %arg9 = 0 to 4 {
71 %0 = affine.load %alloc_5[%arg9, %arg8] : memref<4x4xi8>
72 affine.store %0, %alloc_14[%arg9] : memref<4xi8>
74 %alloc_16 = memref.alloc() : memref<4xi8>
// Nest 2: copies %alloc_14 into %alloc_16.
75 affine.for %arg9 = 0 to 4 {
76 %0 = affine.load %alloc_14[%arg9] : memref<4xi8>
77 affine.store %0, %alloc_16[%arg9] : memref<4xi8>
// Nest 3: reads only element 0 of %alloc_16, multiplies it with an element
// of %alloc_15, sign-extends to i32, adds a value loaded from row
// %arg9 * 4 of %alloc_10 and stores the sum to row %arg9 * 4 + 3.
79 affine.for %arg9 = 0 to 2 {
80 %0 = affine.load %alloc_15[%arg9 * 4, 0] : memref<8x4xi8>
81 %1 = affine.load %alloc_16[0] : memref<4xi8>
82 %2 = affine.load %alloc_10[%arg9 * 4, %arg8] : memref<8x4xi32>
83 %3 = arith.muli %0, %1 : i8
84 %4 = arith.extsi %3 : i8 to i32
85 %5 = arith.addi %4, %2 : i32
86 affine.store %5, %alloc_10[%arg9 * 4 + 3, %arg8] : memref<8x4xi32>
88 memref.dealloc %alloc_16 : memref<4xi8>
90 // CHECK: affine.for %{{.*}} = 0 to 4 {
91 // Everything inside is fused into two nests (the second will be DCE'd).
// NOTE(review): the two memref<1xi8> buffers expected below do not appear in
// the input; presumably they are created by the fusion pass - confirm.
92 // CHECK-NEXT: memref.alloc() : memref<4xi8>
93 // CHECK-NEXT: memref.alloc() : memref<1xi8>
94 // CHECK-NEXT: memref.alloc() : memref<1xi8>
95 // CHECK-NEXT: memref.alloc() : memref<8x4xi8>
96 // CHECK-NEXT: memref.alloc() : memref<4xi8>
97 // CHECK-NEXT: affine.for %{{.*}} = 0 to 2 {
99 // CHECK: affine.for %{{.*}} = 0 to 4 {
101 // CHECK-NEXT: memref.dealloc
103 // CHECK-NEXT: return
// Same producer/consumer copy nests as in @fusion_inner_simple, placed in a
// region of an scf.while. The directives expect a single 0-to-10 nest with no
// further affine.for after it.
107 // CHECK-LABEL: func @fusion_inside_scf_while
108 func.func @fusion_inside_scf_while(%A : memref<10xf32>) {
// NOTE(review): %c0, %c1 and %c100 are not referenced by any op visible in
// this function.
109 %c0 = arith.constant 0 : index
110 %c1 = arith.constant 1 : index
111 %c100 = arith.constant 100 : index
// Initial value of the loop-carried f32 and the operand of the comparison.
112 %cst = arith.constant 0.0 : f32
// Condition part: continue while %arg3 compares "ult" against %cst.
114 %0 = scf.while (%arg3 = %cst) : (f32) -> (f32) {
115 %1 = arith.cmpf ult, %arg3, %cst : f32
116 scf.condition(%1) %arg3 : f32
// NOTE(review): the header of the region that starts here is not visible in
// this excerpt; the ops below presumably form the "do" region - confirm.
120 %B = memref.alloc() : memref<10xf32>
121 %C = memref.alloc() : memref<10xf32>
// Producer nest: copies %A into %B.
123 affine.for %j = 0 to 10 {
124 %v = affine.load %A[%j] : memref<10xf32>
125 affine.store %v, %B[%j] : memref<10xf32>
// Consumer nest: copies %B into %C.
128 affine.for %j = 0 to 10 {
129 %v = affine.load %B[%j] : memref<10xf32>
130 affine.store %v, %C[%j] : memref<10xf32>
// NOTE(review): %arg5 is not declared in the visible lines; presumably it is
// the block argument of this region - confirm.
132 %1 = arith.mulf %arg5, %cst : f32
136 // CHECK: affine.for %{{.*}} = 0 to 10
137 // CHECK-NOT: affine.for
// Private constant 10x2 f32 global, zero-initialized via a dense splat; it is
// read with memref.get_global in @fusion_inner_long.
143 memref.global "private" constant @__constant_10x2xf32 : memref<10x2xf32> = dense<0.000000e+00>
// Longer case with two levels of scf.for, each carrying a memref through
// iter_args, and chains of affine nests at both levels. After each chain the
// directives expect one 10x2 nest (0-to-10 containing 0-to-2) and no further
// affine.for.
// NOTE(review): the closing lines of this function are not visible in this
// excerpt; only the visible ops are described.
145 // CHECK-LABEL: func @fusion_inner_long
146 func.func @fusion_inner_long(%arg0: memref<10x2xf32>, %arg1: memref<10x10xf32>, %arg2: memref<10x2xf32>, %s: index) {
// NOTE(review): %c0, %c9 and %arg1 are not referenced by any op visible in
// this function.
147 %c0 = arith.constant 0 : index
148 %cst_0 = arith.constant 1.000000e-03 : f32
149 %c9 = arith.constant 9 : index
// i32 bounds and step for the two scf.for loops.
150 %c10_i32 = arith.constant 10 : i32
151 %c1_i32 = arith.constant 1 : i32
152 %c100_i32 = arith.constant 100 : i32
153 %c0_i32 = arith.constant 0 : i32
// Initial value of the inner loop's iter_arg.
154 %0 = memref.get_global @__constant_10x2xf32 : memref<10x2xf32>
// Outer loop: 0 to 100 in i32, carrying a 10x2 memref that starts as %arg0.
155 %1 = scf.for %arg3 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg4 = %arg0) -> (memref<10x2xf32>) : i32 {
// Fills %alloc with the loop index cast to i32.
// NOTE(review): no load from %alloc is visible in this function.
156 %alloc = memref.alloc() {alignment = 64 : i64} : memref<10xi32>
157 affine.for %arg5 = 0 to 10 {
158 %3 = arith.index_cast %arg5 : index to i32
159 affine.store %3, %alloc[%arg5] : memref<10xi32>
// Inner loop: 0 to 10 in i32, carrying a 10x2 memref that starts as %0.
// NOTE(review): %arg6 is not referenced in the visible lines.
161 %2 = scf.for %arg5 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg6 = %0) -> (memref<10x2xf32>) : i32 {
// Step 1: copies row %s of %arg4 into %alloc_5.
162 %alloc_5 = memref.alloc() : memref<2xf32>
163 affine.for %arg7 = 0 to 2 {
164 %16 = affine.load %arg4[%s, %arg7] : memref<10x2xf32>
165 affine.store %16, %alloc_5[%arg7] : memref<2xf32>
// Step 2: copies %alloc_5 into row 0 of the 1x2 buffer %alloc_6.
167 %alloc_6 = memref.alloc() {alignment = 64 : i64} : memref<1x2xf32>
168 affine.for %arg7 = 0 to 2 {
169 %16 = affine.load %alloc_5[%arg7] : memref<2xf32>
170 affine.store %16, %alloc_6[0, %arg7] : memref<1x2xf32>
// Step 3: writes row 0 of %alloc_6 into each of the 10 rows of %alloc_7.
172 %alloc_7 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
173 affine.for %arg7 = 0 to 10 {
174 affine.for %arg8 = 0 to 2 {
175 %16 = affine.load %alloc_6[0, %arg8] : memref<1x2xf32>
176 affine.store %16, %alloc_7[%arg7, %arg8] : memref<10x2xf32>
// Step 4: element-wise %alloc_7 - %arg4, stored into %alloc_8.
179 %alloc_8 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
180 affine.for %arg7 = 0 to 10 {
181 affine.for %arg8 = 0 to 2 {
182 %16 = affine.load %alloc_7[%arg7, %arg8] : memref<10x2xf32>
183 %17 = affine.load %arg4[%arg7, %arg8] : memref<10x2xf32>
184 %18 = arith.subf %16, %17 : f32
185 affine.store %18, %alloc_8[%arg7, %arg8] : memref<10x2xf32>
// %alloc_8 becomes the value carried by the inner loop.
188 scf.yield %alloc_8 : memref<10x2xf32>
// Expected output for the chain above: one 10x2 nest.
191 // CHECK: affine.for %{{.*}} = 0 to 10
192 // CHECK-NEXT: affine.for %{{.*}} = 0 to 2
193 // CHECK-NOT: affine.for
// Fills %alloc_2 with the constant %cst_0.
196 %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
197 affine.for %arg5 = 0 to 10 {
198 affine.for %arg6 = 0 to 2 {
199 affine.store %cst_0, %alloc_2[%arg5, %arg6] : memref<10x2xf32>
// Element-wise product of %alloc_2 and the inner loop's result %2.
202 %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<10x2xf32>
203 affine.for %arg5 = 0 to 10 {
204 affine.for %arg6 = 0 to 2 {
205 %3 = affine.load %alloc_2[%arg5, %arg6] : memref<10x2xf32>
206 %4 = affine.load %2[%arg5, %arg6] : memref<10x2xf32>
207 %5 = arith.mulf %3, %4 : f32
208 affine.store %5, %alloc_3[%arg5, %arg6] : memref<10x2xf32>
// %alloc_3 becomes the value carried by the outer loop.
211 scf.yield %alloc_3 : memref<10x2xf32>
212 // The nests above will be fused as well.
213 // CHECK: affine.for %{{.*}} = 0 to 10
214 // CHECK-NEXT: affine.for %{{.*}} = 0 to 2
215 // CHECK-NOT: affine.for
// Copies the outer loop's result %1 into the output argument %arg2.
218 affine.for %arg3 = 0 to 10 {
219 affine.for %arg4 = 0 to 2 {
220 %2 = affine.load %1[%arg3, %arg4] : memref<10x2xf32>
221 affine.store %2, %arg2[%arg3, %arg4] : memref<10x2xf32>