[flang] Fix crash in HLFIR generation (#118399)
[llvm-project.git] / mlir / test / Transforms / loop-invariant-subset-hoisting.mlir
blob3a78287a0dcad2fc8f8974bb6e1e905be42f138d
1 // RUN: mlir-opt %s  -split-input-file -loop-invariant-subset-hoisting | FileCheck %s
3 // CHECK-LABEL: func @hoist_matching_extract_insert(
4 //  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
5 func.func @hoist_matching_extract_insert(%arg: tensor<?xf32>) -> tensor<?xf32> {
6   %lb = "test.foo"() : () -> (index)
7   %ub = "test.foo"() : () -> (index)
8   %step = "test.foo"() : () -> (index)
10   %c0 = arith.constant 0 : index
11   %c1 = arith.constant 1 : index
12   %add = arith.addi %c0, %c1 : index
13   %sub = arith.subi %add, %c1 : index
15   // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]]
16   // CHECK: %[[for:.*]]:2 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]])
17   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
18     // CHECK: tensor.extract_slice %[[t]][9] [5] [1]
19     %standalone = tensor.extract_slice %t[9][5][1] : tensor<?xf32> to tensor<5xf32>
20     "test.foo"(%standalone) : (tensor<5xf32>) -> ()
22     %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
23     // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
24     %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
25     // Obfuscate the IR by inserting at offset %sub instead of 0; both of them
26     // have the same value.
27     %3 = tensor.insert_slice %2 into %t[%sub][5][1] : tensor<5xf32> into tensor<?xf32>
28     // CHECK: scf.yield %[[t]], %[[foo]]
29     scf.yield %3 : tensor<?xf32>
30   }
31   // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#1 into %[[for]]#0
33   // CHECK: return %[[insert]]
34   return %0 : tensor<?xf32>
37 // -----
39 func.func @subset_of_subset(%arg: tensor<?xf32>) -> tensor<?xf32> {
40   %lb = "test.foo"() : () -> (index)
41   %ub = "test.foo"() : () -> (index)
42   %step = "test.foo"() : () -> (index)
44   // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]]
45   // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[extract1]]
46   // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted1:.*]] = %[[extract1]], %[[hoisted2:.*]] = %[[extract2]])
47   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
48     %extract1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
49     %extract2 = tensor.extract_slice %extract1[1][2][1] : tensor<5xf32> to tensor<2xf32>
51     // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted2]])
52     %2 = "test.foo"(%extract2) : (tensor<2xf32>) -> (tensor<2xf32>)
54     %insert1 = tensor.insert_slice %2 into %extract1[1][2][1] : tensor<2xf32> into tensor<5xf32>
55     %insert2 = tensor.insert_slice %insert1 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
57     // CHECK: scf.yield %[[t]], %[[hoisted1]], %[[foo]]
58     scf.yield %insert2 : tensor<?xf32>
59   }
60   // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#1[1] [2] [1]
61   // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[insert2]] into %[[for]]#0[0] [5] [1]
63   // CHECK: return %[[insert1]]
64   return %0 : tensor<?xf32>
67 // -----
69 // CHECK-LABEL: func @hoist_matching_chain(
70 //  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
71 func.func @hoist_matching_chain(%arg: tensor<?xf32>) -> tensor<?xf32> {
72   %lb = "test.foo"() : () -> (index)
73   %ub = "test.foo"() : () -> (index)
74   %step = "test.foo"() : () -> (index)
75   %sz = "test.foo"() : () -> (index)
77   // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][%{{.*}}] [5] [1]
78   // CHECK: %[[extract1:.*]] = tensor.extract_slice %[[arg]][0] [%{{.*}}] [1]
79   // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted2:.*]] = %[[extract2]], %[[hoisted1:.*]] = %[[extract1]])
80   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
81     %1 = tensor.extract_slice %t[0][%sz][1] : tensor<?xf32> to tensor<?xf32>
82     %2 = tensor.extract_slice %t[%sz][5][1] : tensor<?xf32> to tensor<5xf32>
83     // CHECK-DAG: %[[foo1:.*]] = "test.foo"(%[[hoisted1]])
84     // CHECK-DAG: %[[foo2:.*]] = "test.foo"(%[[hoisted2]])
85     %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
86     %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
87     %5 = tensor.insert_slice %foo2 into %t[%sz][5][1] : tensor<5xf32> into tensor<?xf32>
88     %6 = tensor.insert_slice %foo1 into %5[0][%sz][1] : tensor<?xf32> into tensor<?xf32>
89     // CHECK: scf.yield %[[t]], %[[foo2]], %[[foo1]]
90     scf.yield %6 : tensor<?xf32>
91   }
92   // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[0] [%{{.*}}] [1]
93   // CHECK: %[[insert1:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert2]][%{{.*}}] [5] [1]
95   // CHECK: return %[[insert1]]
96   return %0 : tensor<?xf32>
99 // -----
101 // CHECK-LABEL: func @do_not_hoist_overlapping_subsets(
102 func.func @do_not_hoist_overlapping_subsets(%arg: tensor<?xf32>) -> tensor<?xf32> {
103   %lb = "test.foo"() : () -> (index)
104   %ub = "test.foo"() : () -> (index)
105   %step = "test.foo"() : () -> (index)
106   %sz1 = "test.foo"() : () -> (index)
107   %sz2 = "test.foo"() : () -> (index)
109   // CHECK: scf.for
110   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
111     // These two slices are potentially overlapping. Do not hoist.
112     // CHECK: tensor.extract_slice
113     // CHECK: tensor.extract_slice
114     %1 = tensor.extract_slice %t[0][%sz1][1] : tensor<?xf32> to tensor<?xf32>
115     %2 = tensor.extract_slice %t[10][%sz2][1] : tensor<?xf32> to tensor<?xf32>
116     // CHECK: "test.foo"
117     // CHECK: "test.foo"
118     %foo1 = "test.foo"(%1) : (tensor<?xf32>) -> (tensor<?xf32>)
119     %foo2 = "test.foo"(%2) : (tensor<?xf32>) -> (tensor<?xf32>)
120     // CHECK: tensor.insert_slice
121     // CHECK: tensor.insert_slice
122     %5 = tensor.insert_slice %foo2 into %t[0][%sz1][1] : tensor<?xf32> into tensor<?xf32>
123     %6 = tensor.insert_slice %foo1 into %5[10][%sz2][1] : tensor<?xf32> into tensor<?xf32>
124     // CHECK: scf.yield
125     scf.yield %6 : tensor<?xf32>
126   }
128   return %0 : tensor<?xf32>
131 // -----
133 // CHECK-LABEL: func @multiple_yields(
134 //  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
135 func.func @multiple_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
136   %lb = "test.foo"() : () -> (index)
137   %ub = "test.foo"() : () -> (index)
138   %step = "test.foo"() : () -> (index)
140   // CHECK: %[[extract1:.*]] = tensor.extract_slice
141   // CHECK: %[[extract2:.*]] = tensor.extract_slice
142   // CHECK: scf.for {{.*}} iter_args(%{{.*}} = %[[arg]], %{{.*}} = %[[arg]], %{{.*}} = %[[extract1]], %{{.*}} = %[[extract2]])
143   %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
144       -> (tensor<?xf32>, tensor<?xf32>) {
145     %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
146     %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
147     // CHECK: "test.foo"
148     // CHECK: "test.foo"
149     %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
150     %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
151     %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
152     %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
153     // CHECK: scf.yield
154     scf.yield %5, %6 : tensor<?xf32>, tensor<?xf32>
155   }
156   // CHECK: tensor.insert_slice
157   // CHECK: tensor.insert_slice
159   return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
162 // -----
164 // CHECK-LABEL: func @do_not_hoist_swapping_yields(
165 func.func @do_not_hoist_swapping_yields(%arg: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
166   %lb = "test.foo"() : () -> (index)
167   %ub = "test.foo"() : () -> (index)
168   %step = "test.foo"() : () -> (index)
170   // CHECK: scf.for
171   %0:2 = scf.for %iv = %lb to %ub step %step iter_args(%t1 = %arg, %t2 = %arg)
172       -> (tensor<?xf32>, tensor<?xf32>) {
173     // CHECK: tensor.extract_slice
174     // CHECK: tensor.extract_slice
175     %1 = tensor.extract_slice %t1[0][5][1] : tensor<?xf32> to tensor<5xf32>
176     %2 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
177     // CHECK: "test.foo"
178     // CHECK: "test.foo"
179     %foo1 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
180     %foo2 = "test.foo"(%2) : (tensor<5xf32>) -> (tensor<5xf32>)
181     // CHECK: tensor.insert_slice
182     // CHECK: tensor.insert_slice
183     %5 = tensor.insert_slice %foo2 into %t1[0][5][1] : tensor<5xf32> into tensor<?xf32>
184     %6 = tensor.insert_slice %foo1 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
185     // Swapping yields: do not hoist.
186     // CHECK: scf.yield
187     scf.yield %6, %5 : tensor<?xf32>, tensor<?xf32>
188   }
190   return %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>
193 // -----
195 // CHECK-LABEL: func @non_subset_op(
196 func.func @non_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
197   %lb = "test.foo"() : () -> (index)
198   %ub = "test.foo"() : () -> (index)
199   %step = "test.foo"() : () -> (index)
201   // CHECK: scf.for
202   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
203     // If any value along the use-def chain from the region iter_arg to the
204     // terminator is used by a non-subset op, no subset op along that chain can
205     // be hoisted. That is because it is unknown which parts of the value are
206     // accessed by the non-subset op.
207     // CHECK: "test.non_subset_op"
208     "test.non_subset_op"(%t) : (tensor<?xf32>) -> ()
209     // CHECK: tensor.extract_slice
210     %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
211     // CHECK: "test.foo"
212     %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
213     // CHECK: tensor.insert_slice
214     %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
215     // CHECK: scf.yield
216     scf.yield %3 : tensor<?xf32>
217   }
219   return %0 : tensor<?xf32>
222 // -----
224 // CHECK-LABEL: func @non_loop_invariant_subset_op(
225 func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
226   %lb = "test.foo"() : () -> (index)
227   %ub = "test.foo"() : () -> (index)
228   %step = "test.foo"() : () -> (index)
230   // CHECK: scf.for
231   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
232     // Subset ops that are not loop-invariant cannot be hoisted.
233     // CHECK: tensor.extract_slice
234     %1 = tensor.extract_slice %t[%iv][5][1] : tensor<?xf32> to tensor<5xf32>
235     // CHECK: "test.foo"
236     %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
237     // CHECK: tensor.insert_slice
238     %3 = tensor.insert_slice %2 into %t[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
239     // CHECK: scf.yield
240     scf.yield %3 : tensor<?xf32>
241   }
243   return %0 : tensor<?xf32>
246 // -----
248 // CHECK-LABEL: func @nested_hoisting(
249 //  CHECK-SAME:     %[[arg:.*]]: tensor<?xf32>
250 func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
251   %lb = "test.foo"() : () -> (index)
252   %ub = "test.foo"() : () -> (index)
253   %step = "test.foo"() : () -> (index)
255   // CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
256   // CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
257   // CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
258   %0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
259     %1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
260     // CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
261     %2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
262     %3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
263     // CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
264     %4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
265       %5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
266       // CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
267       %6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
268       %7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
269       // CHECK: scf.yield %[[t2]], %[[foo2]]
270       scf.yield %7 : tensor<?xf32>
271     }
272     // CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
273     scf.yield %4 : tensor<?xf32>
274   }
275   // CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
276   // CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
277   // CHECK: return %[[insert2]]
278   return %0 : tensor<?xf32>
281 // -----
283 // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor
284 func.func @hoist_vector_transfer_pairs_tensor(
285     %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
286     %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
287     %val: index, %lb : index, %ub : index, %step: index) ->
288     (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
289      tensor<?x?xf32>, tensor<?x?xf32>) {
290   %c0 = arith.constant 0 : index
291   %cst = arith.constant 0.0 : f32
293 // CHECK: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
294 // CHECK: scf.for {{.*}} iter_args({{.*}}) ->
295 // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
296 // CHECK:   vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
297 // CHECK:   scf.for {{.*}} iter_args({{.*}}) ->
298 // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
299 // CHECK:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
300 // CHECK:     "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
301 // CHECK:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
302 // CHECK:     "test.some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
303 // CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
304 // CHECK:     "test.some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
305 // CHECK:     "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
306 // CHECK:     "test.some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
307 // CHECK:     vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
308 // CHECK:     vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
309 // CHECK:     vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
310 // CHECK:     "test.some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
311 // CHECK:     scf.yield {{.*}} :
312 // CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
313 // CHECK:   }
314 // CHECK:   vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
315 // CHECK:   scf.yield {{.*}} :
316 // CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
317 // CHECK: }
318 // CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
319   %0:6 = scf.for %i = %lb to %ub step %step
320   iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
321             %arg3 = %tensor3,  %arg4 = %tensor4, %arg5 = %tensor5)
322   -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
323      tensor<?x?xf32>, tensor<?x?xf32>)  {
324     %1:6 = scf.for %j = %lb to %ub step %step
325     iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
326               %arg9 = %arg3,  %arg10 = %arg4, %arg11 = %arg5)
327     -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
328        tensor<?x?xf32>, tensor<?x?xf32>)  {
329       %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
330       %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
331       %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
332       "test.some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
333       %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
334       %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
335       "test.some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
336       %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
337       %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
338       %u2 = "test.some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32>
339       %u3 = "test.some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
340       %u4 = "test.some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
341       %u5 = "test.some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
342       %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
343       %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
344       %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
345       %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
346       %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
347       %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
348       "test.some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
349       scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
350         tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
351         tensor<?x?xf32>, tensor<?x?xf32>
352       }
353       scf.yield %1#0,  %1#1, %1#2, %1#3, %1#4, %1#5 :
354         tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
355         tensor<?x?xf32>, tensor<?x?xf32>
356   }
357   return %0#0,  %0#1, %0#2, %0#3, %0#4,  %0#5 :
358         tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
359         tensor<?x?xf32>, tensor<?x?xf32>
362 // -----
364 // CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
365 //  CHECK-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
366 //  CHECK-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
367 //  CHECK-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
368 //  CHECK-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
369 func.func @hoist_vector_transfer_pairs_disjoint_tensor(
370     %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
371     %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
372     %val: index, %lb : index, %ub : index, %step: index,
373     %random_index : index) ->
374     (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
375   %c0 = arith.constant 0 : index
376   %c1 = arith.constant 1 : index
377   %c3 = arith.constant 3 : index
378   %cst = arith.constant 0.0 : f32
380 // CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
381 // CHECK: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
382 // CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
383 // CHECK: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
384 // CHECK: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
385 // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
386 // CHECK:   scf.for {{.*}} iter_args({{.*}}) ->
387 // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
388 // CHECK:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
389 // CHECK:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
390 // CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
391 // CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
392 // CHECK:     "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
393 // CHECK:     "test.some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
394 // CHECK:     "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
395 // CHECK:     "test.some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
396 // CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
397 // CHECK:     "test.some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
398 // CHECK:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
399 // CHECK:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
400 // CHECK:     scf.yield {{.*}} :
401 // CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
402 // CHECK:   }
403 // CHECK:   scf.yield {{.*}} :
404 // CHECK-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
405 // CHECK: }
406 // CHECK: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#7, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
407 // CHECK:                   vector.transfer_write %[[R]]#6, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
408 // CHECK: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#5, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
409 // CHECK:                   vector.transfer_write %[[R]]#4, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
410   %0:4 = scf.for %i = %lb to %ub step %step
411   iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
412             %arg3 = %tensor3)
413   -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
414     %1:4 = scf.for %j = %lb to %ub step %step
415     iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
416               %arg7 = %arg3)
417     -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
418       %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
419       %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
420       %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
421       %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
422       %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
423       %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
424       %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
425       %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
426       %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
427       %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
428       %u20 = "test.some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
429       %u21 = "test.some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
430       %u30 = "test.some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
431       %u31 = "test.some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
432       %u10 = "test.some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
433       %u11 = "test.some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
434       %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
435       %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
436       %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
437       %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
438       %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
439       %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
440       %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
441       %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
442       scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
443     }
444     scf.yield %1#0,  %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
445   }
446   return %0#0,  %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
449 // -----
451 // CHECK-LABEL: func @hoist_vector_transfer_pairs_tensor_and_slices
452 //  CHECK-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
453 //  CHECK-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
454 //  CHECK-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
455 //  CHECK-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
456 //  CHECK-SAME:   %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
457 //  CHECK-SAME:   %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
458 func.func @hoist_vector_transfer_pairs_tensor_and_slices(
459     %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
460     %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
461     %val: index, %lb : index, %ub : index, %step: index) ->
462     (
463       tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>//, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
464     ) {
465   %c0 = arith.constant 0 : index
466   %cst = arith.constant 0.0 : f32
468   //      CHECK: scf.for %[[I:.*]] = {{.*}} iter_args(
469   // CHECK-SAME:   %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
470   // CHECK-SAME:   %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
471   // CHECK-SAME:   %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
472   // CHECK-SAME: ) ->
473   // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
474   %0:3 = scf.for %i = %lb to %ub step %step
475   iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
476     -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
478     // Hoisted
479     // CHECK:   %[[ST0:.*]] = tensor.extract_slice %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
480     // CHECK:   %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
482     //      CHECK:   %[[R:.*]]:5 = scf.for %[[J:.*]] = {{.*}} iter_args(
483     // CHECK-SAME:   %[[TENSOR0_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR0_ARG]]
484     // CHECK-SAME:   %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
485     // CHECK-SAME:   %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
486     // CHECK-SAME:   %[[ST0_ARG_L2:[0-9a-zA-Z]+]] = %[[ST0]]
487     // CHECK-SAME:   %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
488     // CHECK-SAME: ) ->
489     // CHECK-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>)
490     %1:3 = scf.for %j = %lb to %ub step %step
491     iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
492     -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
493       // Hoists.
494       %st0 = tensor.extract_slice %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
495       %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
497       // CHECK:     %[[ST1:.*]] = tensor.extract_slice %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
498       // CHECK:     %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
499       // Does not hoist (slice depends on %j)
500       %st1 = tensor.extract_slice %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
501       %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
503       // CHECK:     %[[ST2:.*]] = tensor.extract_slice %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
504       // CHECK:     %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
505       // Does not hoist, 2 slice %arg8.
506       %st2 = tensor.extract_slice %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
507       %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
509       // CHECK:     %[[U0:.*]] = "test.some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
510       // CHECK:     %[[U1:.*]] = "test.some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
511       // CHECK:     %[[U2:.*]] = "test.some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
512       %u0 = "test.some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
513       %u1 = "test.some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
514       %u2 = "test.some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
516       // Hoists
517       %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
519       // CHECK-DAG:     %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
520       // Does not hoist (associated slice depends on %j).
521       %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
523       // CHECK-DAG:     %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
524       // Does not hoist, 2 slice / insert_slice for %arg8.
525       %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
527       // Hoists.
528       %sti0 = tensor.insert_slice %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
530       // CHECK-DAG:     tensor.insert_slice %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
531       // Does not hoist (depends on %j).
532       %sti1 = tensor.insert_slice %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
534       // CHECK-DAG:     tensor.insert_slice %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
535       // Does not hoist, 2 slice / insert_slice for %arg8.
536       %sti2 = tensor.insert_slice %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
537       // Extract with a different stride to make sure we cannot fold this extract with the above insert.
538       %st22 = tensor.extract_slice %sti2[%i, %c0][%step, %step][2, 1] : tensor<?x?xf32> to tensor<?x?xf32>
539       %sti22 = tensor.insert_slice %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
541       // CHECK:     scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
542       // CHECK:   }
543       scf.yield %sti0, %sti1, %sti22:
544         tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
545     }
547     // Hoisted
548     // CHECK:   %[[STI0:.*]] = vector.transfer_write %[[R]]#4, %[[R]]#3{{.*}} : vector<1xf32>, tensor<?x?xf32>
549     // CHECK:   tensor.insert_slice %[[STI0]] into %[[R]]#0[%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>
551     // CHECK:   scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
552     scf.yield %1#0, %1#1, %1#2 :
553       tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
555     // CHECK: }
556   }
557   return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
560 // -----
562 // CHECK-LABEL: func @hoist_vector_transfer_write_pairs_disjoint_tensor(
563 //  CHECK-SAME:   %[[T:.*]]: tensor<?x?xf32>,
564 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
565 //   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
566 //   CHECK-DAG:   %[[R0:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C0]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
567 //   CHECK-DAG:   %[[R1:.*]] = vector.transfer_read %[[T]][%[[C0]], %[[C3]]], %{{.*}} : tensor<?x?xf32>, vector<2xf32>
568 //       CHECK:   %[[F:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[TL:.*]] = %[[T]], %[[R2:.*]] = %[[R0]], %[[R3:.*]] = %[[R1]]) -> (tensor<?x?xf32>, vector<2xf32>, vector<2xf32>) {
569 //       CHECK:     %[[R4:.*]] = "test.some_use"(%[[R2]]) : (vector<2xf32>) -> vector<2xf32>
570 //       CHECK:     %[[R5:.*]] = "test.some_use"(%[[R3]]) : (vector<2xf32>) -> vector<2xf32>
571 //       CHECK:     scf.yield %[[TL]], %[[R4]], %[[R5]] : tensor<?x?xf32>, vector<2xf32>, vector<2xf32>
572 //       CHECK:   }
573 //       CHECK:   %[[W0:.*]] = vector.transfer_write %[[F]]#2, %[[F]]#0[%[[C0]], %[[C3]]] : vector<2xf32>, tensor<?x?xf32>
574 //       CHECK:   %[[W1:.*]] = vector.transfer_write %[[F]]#1, %[[W0]][%[[C0]], %[[C0]]] : vector<2xf32>, tensor<?x?xf32>
575 //       CHECK:  return %[[W1]] : tensor<?x?xf32>
576 func.func @hoist_vector_transfer_write_pairs_disjoint_tensor(
577     %tensor: tensor<?x?xf32>,
578     %val: index, %lb : index, %ub : index, %step: index) ->
579     (tensor<?x?xf32>) {
580   %c0 = arith.constant 0 : index
581   %c1 = arith.constant 1 : index
582   %c3 = arith.constant 3 : index
583   %cst = arith.constant 0.0 : f32
584   %1 = scf.for %j = %lb to %ub step %step iter_args(%arg5 = %tensor)
585     -> (tensor<?x?xf32>) {
586     %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
587     %u00 = "test.some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
588     %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
590     // Hoist by properly bypassing the disjoint write %w10.
591     %r01 = vector.transfer_read %w10[%c0, %c3], %cst: tensor<?x?xf32>, vector<2xf32>
592     %u01 = "test.some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
593     %w11 = vector.transfer_write %u01, %w10[%c0, %c3] : vector<2xf32>, tensor<?x?xf32>
594     scf.yield %w11 : tensor<?x?xf32>
595   }
596   return %1 : tensor<?x?xf32>