[VPlan] Run recipe removal and simplification after optimizeForVFAndUF. (#125926)
[llvm-project.git] / mlir / test / python / dialects / scf.py
blobde61f4613868f7b9233803ea00292baf235dff32
1 # RUN: %PYTHON %s | FileCheck %s
3 from mlir.ir import *
4 from mlir.dialects import arith
5 from mlir.dialects import func
6 from mlir.dialects import memref
7 from mlir.dialects import scf
8 from mlir.passmanager import PassManager
def constructAndPrintInModule(f):
    """Run ``f`` against a fresh module, print the result, and return ``f``.

    A ``TEST: <name>`` banner is printed first so each test's
    ``CHECK-LABEL: TEST: ...`` line can anchor on it. ``f`` is called with the
    insertion point at the module body, then the whole module is printed.
    Returning ``f`` lets this be used as a decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                f()
            print(mod)
    return f
# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
    """Build an scf.for with two iter_args that are yielded back unchanged."""
    idx = IndexType.get()

    @func.FuncOp.from_py_func(idx, idx, idx)
    def simple_loop(lb, ub, step):
        # Both loop-carried values are seeded with the lower bound.
        for_op = scf.ForOp(lb, ub, step, [lb, lb])
        with InsertionPoint(for_op.body):
            # Forward the carried values to the next iteration as-is.
            carried = for_op.inner_iter_args
            scf.YieldOp(carried)
        return


# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]
# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
    """Build an scf.for whose single iter_arg is replaced by the induction var."""
    idx = IndexType.get()

    @func.FuncOp.from_py_func(idx, idx, idx)
    def induction_var(lb, ub, step):
        for_op = scf.ForOp(lb, ub, step, [lb])
        with InsertionPoint(for_op.body):
            # The value yielded to the next iteration is the loop's IV.
            iv = for_op.induction_variable
            scf.YieldOp([iv])
        return


# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]
# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Exercise the ``scf.for_`` generator sugar.

    The ``range_loop_*`` functions cover each ``range``-style argument form
    (SSA values or Python ints for start/stop/step, with the start and/or step
    omitted). The ``loop_yield_*`` functions carry one or two iter_args and
    store the loop results after the loop.
    """
    index_type = IndexType.get()
    memref_t = MemRefType.get([10], index_type)
    # Alias scf.for_ so the loops read like Python's ``range``. The trailing
    # underscore keeps the builtin ``range`` unshadowed.
    range_ = scf.for_

    # CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_1(lb, ub, step, memref_v):
        for i in range_(lb, ub, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_2(lb, ub, step, memref_v):
        for i in range_(lb, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_3(lb, ub, step, memref_v):
        for i in range_(0, ub, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_4(lb, ub, step, memref_v):
        for i in range_(0, 10, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_5(lb, ub, step, memref_v):
        for i in range_(0, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_6(lb, ub, step, memref_v):
        for i in range_(0, 10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    # CHECK: }
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_7(lb, ub, step, memref_v):
        for i in range_(10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]] : index
    # CHECK: }
    # CHECK: memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_1(lb, ub, step, memref_v):
        # ``acc`` (not ``sum``) avoids shadowing the builtin.
        acc = arith.ConstantOp.create_index(0)
        c0 = arith.ConstantOp.create_index(0)
        # for_ hands back (iv, inner iter_arg, loop result). Rebinding ``acc``
        # makes it name the scf.for result once the loop body is finished.
        for i, loc_acc, acc in scf.for_(0, 100, 1, [acc]):
            loc_acc = arith.addi(loc_acc, i)
            scf.yield_([loc_acc])
        memref.store(acc, memref_v, [c0])

    # The stores after the loop must consume the loop results (RES#0/RES#1),
    # not the init constants, so they are pinned explicitly below.
    # CHECK: func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[c0:.*]] = arith.constant 0 : index
    # CHECK: %[[c2:.*]] = arith.constant 2 : index
    # CHECK: %[[REF1:.*]] = arith.constant 0 : index
    # CHECK: %[[REF2:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[RES:.*]]:2 = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
    # CHECK: %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
    # CHECK: }
    # CHECK: memref.store %[[RES]]#0, %[[VAL_3]]{{\[}}%[[REF1]]] : memref<10xindex>
    # CHECK: memref.store %[[RES]]#1, %[[VAL_3]]{{\[}}%[[REF2]]] : memref<10xindex>
    # CHECK: return
    # CHECK: }
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_2(lb, ub, step, memref_v):
        sum1 = arith.ConstantOp.create_index(0)
        sum2 = arith.ConstantOp.create_index(2)
        c0 = arith.ConstantOp.create_index(0)
        c1 = arith.ConstantOp.create_index(1)
        # With several iter_args, the inner args and the results unpack as
        # lists. After the loop, sum1/sum2 name the two scf.for results.
        for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(0, 100, 1, [sum1, sum2]):
            loc_sum1 = arith.addi(loc_sum1, i)
            loc_sum2 = arith.addi(loc_sum2, i)
            scf.yield_([loc_sum1, loc_sum2])
        memref.store(sum1, memref_v, [c0])
        memref.store(sum2, memref_v, [c1])
@constructAndPrintInModule
def testOpsAsArguments():
    """Pass ops directly as scf.ForOp bounds and iter_args.

    The bounds are the results of ``arith.ConstantOp.create_index`` and the
    iter_args are the two-result ``func.CallOp`` itself, rather than
    individual values.
    """
    index_type = IndexType.get()
    callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
    f = func.FuncOp("ops_as_arguments", ([], []))
    with InsertionPoint(f.add_entry_block()):
        lb = arith.ConstantOp.create_index(0)
        ub = arith.ConstantOp.create_index(42)
        step = arith.ConstantOp.create_index(2)
        iter_args = func.CallOp(callee, [])
        loop = scf.ForOp(lb, ub, step, iter_args)
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        func.ReturnOp([])


# Bounds are matched through the captured LB/UB/STEP variables rather than
# hard-coded SSA names, so the check tracks the constants actually built.
# CHECK-LABEL: TEST: testOpsAsArguments
# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK: %[[LB:.*]] = arith.constant 0
# CHECK: %[[UB:.*]] = arith.constant 42
# CHECK: %[[STEP:.*]] = arith.constant 2
# CHECK: %[[ARGS:.*]]:2 = call @callee()
# CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}
# CHECK: return
# CHECK-LABEL: TEST: testIfWithoutElse
@constructAndPrintInModule
def testIfWithoutElse():
    """Build an scf.if with only a then-region holding an i32 constant + add."""
    # ``i1`` rather than ``bool`` to avoid shadowing the builtin.
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if(cond):
        if_op = scf.IfOp(cond)
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            arith.AddIOp(one, one)
            scf.YieldOp([])
        return


# The scf.if line reuses ARG0 (instead of redefining it) so the test verifies
# that the branch condition is the function argument.
# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return
# CHECK-LABEL: TEST: testNestedIf
@constructAndPrintInModule
def testNestedIf():
    """Build an scf.if nested inside the then-region of another scf.if."""
    # ``i1`` rather than ``bool`` to avoid shadowing the builtin.
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1, i1)
    def nested_if(b, c):
        outer_if = scf.IfOp(b)
        with InsertionPoint(outer_if.then_block) as ip:
            # ``ip`` is passed explicitly; the enclosing ``with`` already
            # targets the same block.
            inner_if = scf.IfOp(c, ip=ip)
            with InsertionPoint(inner_if.then_block):
                one = arith.ConstantOp(i32, 1)
                arith.AddIOp(one, one)
                scf.YieldOp([])
            scf.YieldOp([])
        return


# Each scf.if line reuses the captured argument so the test verifies which
# argument guards which level of nesting.
# CHECK: func @nested_if(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: scf.if %[[ARG1]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]
# CHECK: return
# CHECK-LABEL: TEST: testIfWithElse
@constructAndPrintInModule
def testIfWithElse():
    """Build a two-result scf.if with then/else regions and add its results."""
    # ``i1`` rather than ``bool`` to avoid shadowing the builtin.
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if_else(cond):
        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            scf.YieldOp([x_true, y_true])
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            scf.YieldOp([x_false, y_false])
        # Consume both scf.if results after the op.
        arith.AddIOp(if_op.results[0], if_op.results[1])
        return


# The scf.if line reuses ARG0 (instead of redefining it) so the test verifies
# that the branch condition is the function argument.
# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK: %[[ZERO:.*]] = arith.constant 0
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK: %[[TWO:.*]] = arith.constant 2
# CHECK: %[[THREE:.*]] = arith.constant 3
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1
# CHECK: return