# RUN: %PYTHON %s | FileCheck %s

# mlir.ir provides Context, Location, Module, InsertionPoint and the builtin
# types (IndexType, IntegerType, MemRefType) used throughout this test.
from mlir.ir import *
from mlir.dialects import arith
from mlir.dialects import func
from mlir.dialects import memref
from mlir.dialects import scf
from mlir.passmanager import PassManager


def constructAndPrintInModule(f):
    """Build the IR emitted by ``f`` in a fresh module and print it.

    Used as a decorator on each test. It prints a ``TEST: <name>`` banner,
    which the ``CHECK-LABEL`` lines match. It then runs ``f`` with the new
    module's body as the insertion point and prints the resulting module so
    FileCheck can match the ``CHECK`` lines. Finally it returns ``f``
    unchanged, so the decorated name stays bound to the function.
    """
    print("\nTEST:", f.__name__)
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f()
        print(module)
    return f


# CHECK-LABEL: TEST: testSimpleLoop
@constructAndPrintInModule
def testSimpleLoop():
    """Build an scf.for with two iter_args that are yielded back unchanged."""
    index_type = IndexType.get()

    @func.FuncOp.from_py_func(index_type, index_type, index_type)
    def simple_loop(lb, ub, step):
        # Both loop-carried values are seeded with the lower bound.
        loop = scf.ForOp(lb, ub, step, [lb, lb])
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)


# CHECK: func @simple_loop(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %{{.*}} = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: iter_args(%[[I1:.*]] = %[[ARG0]], %[[I2:.*]] = %[[ARG0]])
# CHECK: scf.yield %[[I1]], %[[I2]]


# CHECK-LABEL: TEST: testInductionVar
@constructAndPrintInModule
def testInductionVar():
    """Feed the loop's own induction variable back as its single iter_arg."""
    idx = IndexType.get()
    signature = (idx, idx, idx)

    @func.FuncOp.from_py_func(*signature)
    def induction_var(lb, ub, step):
        for_op = scf.ForOp(lb, ub, step, [lb])
        iv = for_op.induction_variable
        with InsertionPoint(for_op.body):
            scf.YieldOp([iv])
        return


# CHECK: func @induction_var(%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
# CHECK: scf.for %[[IV:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
# CHECK: scf.yield %[[IV]]


# CHECK-LABEL: TEST: testForSugar
@constructAndPrintInModule
def testForSugar():
    """Exercise the ``scf.for_`` generator sugar.

    Inside this test ``range`` is rebound to ``scf.for_``, so the loops read
    like ordinary Python ``for ... in range(...)`` loops while emitting
    ``scf.for`` ops. Per the CHECK lines, Python-int bounds show up as
    ``arith.constant`` index values, and omitted start/step behave like
    Python's ``range`` defaults (0 and 1).
    """
    index_type = IndexType.get()
    memref_t = MemRefType.get([10], index_type)
    # Intentionally shadow the builtin: this sugar is what is under test.
    range = scf.for_

    # CHECK: func.func @range_loop_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: scf.for %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] {
    # CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_4]] : index
    # CHECK: memref.store %[[VAL_5]], %[[VAL_3]]{{\[}}%[[VAL_4]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_1(lb, ub, step, memref_v):
        for i in range(lb, ub, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            # The scf.for body must end in an (empty) scf.yield terminator.
            scf.yield_([])

    # CHECK: func.func @range_loop_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_0]] to %[[VAL_4]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_2(lb, ub, step, memref_v):
        for i in range(lb, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_3(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_1]] step %[[VAL_5]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_3(lb, ub, step, memref_v):
        for i in range(0, ub, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_4(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: scf.for %[[VAL_6:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_2]] {
    # CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_6]], %[[VAL_6]] : index
    # CHECK: memref.store %[[VAL_7]], %[[VAL_3]]{{\[}}%[[VAL_6]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_4(lb, ub, step, memref_v):
        for i in range(0, 10, step):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_5(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_5(lb, ub, step, memref_v):
        for i in range(0, 10, 1):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_6(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_6(lb, ub, step, memref_v):
        # Omitted step: expected to default to 1 (see CHECK above).
        for i in range(0, 10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @range_loop_7(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 10 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
    # CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_4]] to %[[VAL_5]] step %[[VAL_6]] {
    # CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_7]], %[[VAL_7]] : index
    # CHECK: memref.store %[[VAL_8]], %[[VAL_3]]{{\[}}%[[VAL_7]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def range_loop_7(lb, ub, step, memref_v):
        # Single-argument form: start 0, stop 10, step 1 (see CHECK above).
        for i in range(10):
            add = arith.addi(i, i)
            memref.store(add, memref_v, [i])
            scf.yield_([])

    # CHECK: func.func @loop_yield_1(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_10:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER:.*]] = %[[VAL_4]]) -> (index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]] : index
    # CHECK: memref.store %[[VAL_10]], %[[VAL_3]]{{\[}}%[[VAL_5]]] : memref<10xindex>
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_1(lb, ub, step, memref_v):
        init = arith.ConstantOp.create_index(0)
        c0 = arith.ConstantOp.create_index(0)
        # With one iter_arg the sugar unpacks into (induction variable,
        # loop-carried value, scf.for result); the result is stored below.
        for i, loc_sum, total in scf.for_(0, 100, 1, [init]):
            loc_sum = arith.addi(loc_sum, i)
            scf.yield_([loc_sum])
        memref.store(total, memref_v, [c0])

    # CHECK: func.func @loop_yield_2(%[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: memref<10xindex>) {
    # CHECK: %[[c0:.*]] = arith.constant 0 : index
    # CHECK: %[[c2:.*]] = arith.constant 2 : index
    # CHECK: %[[REF1:.*]] = arith.constant 0 : index
    # CHECK: %[[REF2:.*]] = arith.constant 1 : index
    # CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
    # CHECK: %[[VAL_7:.*]] = arith.constant 100 : index
    # CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
    # CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] iter_args(%[[ITER1:.*]] = %[[c0]], %[[ITER2:.*]] = %[[c2]]) -> (index, index) {
    # CHECK: %[[VAL_9:.*]] = arith.addi %[[ITER1]], %[[IV]] : index
    # CHECK: %[[VAL_10:.*]] = arith.addi %[[ITER2]], %[[IV]] : index
    # CHECK: scf.yield %[[VAL_9]], %[[VAL_10]] : index, index
    @func.FuncOp.from_py_func(index_type, index_type, index_type, memref_t)
    def loop_yield_2(lb, ub, step, memref_v):
        sum1 = arith.ConstantOp.create_index(0)
        sum2 = arith.ConstantOp.create_index(2)
        c0 = arith.ConstantOp.create_index(0)
        c1 = arith.ConstantOp.create_index(1)
        # With several iter_args the carried values and the results are
        # unpacked as sequences.
        for i, [loc_sum1, loc_sum2], [sum1, sum2] in scf.for_(
            0, 100, 1, [sum1, sum2]
        ):
            loc_sum1 = arith.addi(loc_sum1, i)
            loc_sum2 = arith.addi(loc_sum2, i)
            scf.yield_([loc_sum1, loc_sum2])
        memref.store(sum1, memref_v, [c0])
        memref.store(sum2, memref_v, [c1])


# CHECK-LABEL: TEST: testOpsAsArguments
@constructAndPrintInModule
def testOpsAsArguments():
    """Pass ops (constants and a two-result call) directly as scf.ForOp operands."""
    index_type = IndexType.get()
    callee = func.FuncOp("callee", ([], [index_type, index_type]), visibility="private")
    f = func.FuncOp("ops_as_arguments", ([], []))
    with InsertionPoint(f.add_entry_block()):
        lb = arith.ConstantOp.create_index(0)
        ub = arith.ConstantOp.create_index(42)
        step = arith.ConstantOp.create_index(2)
        # Both results of the call become the loop's iter_args.
        iter_args = func.CallOp(callee, [])
        loop = scf.ForOp(lb, ub, step, iter_args)
        with InsertionPoint(loop.body):
            scf.YieldOp(loop.inner_iter_args)
        # This function body is built by hand (not via from_py_func), so it
        # needs its own terminator.
        func.ReturnOp([])


# CHECK: func private @callee() -> (index, index)
# CHECK: func @ops_as_arguments() {
# CHECK: %[[LB:.*]] = arith.constant 0
# CHECK: %[[UB:.*]] = arith.constant 42
# CHECK: %[[STEP:.*]] = arith.constant 2
# CHECK: %[[ARGS:.*]]:2 = call @callee()
# CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]]
# CHECK: iter_args(%{{.*}} = %[[ARGS]]#0, %{{.*}} = %[[ARGS]]#1)
# CHECK: scf.yield %{{.*}}, %{{.*}}


# CHECK-LABEL: TEST: testIfWithoutElse
@constructAndPrintInModule
def testIfWithoutElse():
    """Build a result-less scf.if that has only a then-region."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if(cond):
        if_op = scf.IfOp(cond)
        with InsertionPoint(if_op.then_block):
            one = arith.ConstantOp(i32, 1)
            arith.AddIOp(one, one)
            # The then-region must end in an (empty) scf.yield terminator.
            scf.YieldOp([])


# CHECK: func @simple_if(%[[ARG0:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]


# CHECK-LABEL: TEST: testNestedIf
@constructAndPrintInModule
def testNestedIf():
    """Build an scf.if nested inside the then-region of another scf.if."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1, i1)
    def nested_if(b, c):
        outer_if = scf.IfOp(b)
        with InsertionPoint(outer_if.then_block) as ip:
            # Place the inner scf.if explicitly via ``ip=`` in the outer
            # then-block.
            inner_if = scf.IfOp(c, ip=ip)
            with InsertionPoint(inner_if.then_block):
                one = arith.ConstantOp(i32, 1)
                arith.AddIOp(one, one)
                scf.YieldOp([])
            # Terminate the outer then-region after the inner scf.if.
            scf.YieldOp([])


# CHECK: func @nested_if(%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
# CHECK: scf.if %[[ARG0]]
# CHECK: scf.if %[[ARG1]]
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: %[[ADD:.*]] = arith.addi %[[ONE]], %[[ONE]]


# CHECK-LABEL: TEST: testIfWithElse
@constructAndPrintInModule
def testIfWithElse():
    """Build a two-result scf.if with then/else regions and use its results."""
    i1 = IntegerType.get_signless(1)
    i32 = IntegerType.get_signless(32)

    @func.FuncOp.from_py_func(i1)
    def simple_if_else(cond):
        if_op = scf.IfOp(cond, [i32, i32], hasElse=True)
        with InsertionPoint(if_op.then_block):
            x_true = arith.ConstantOp(i32, 0)
            y_true = arith.ConstantOp(i32, 1)
            scf.YieldOp([x_true, y_true])
        with InsertionPoint(if_op.else_block):
            x_false = arith.ConstantOp(i32, 2)
            y_false = arith.ConstantOp(i32, 3)
            scf.YieldOp([x_false, y_false])
        # Consume both results of the scf.if (checked as RET#0 / RET#1).
        arith.AddIOp(if_op.results[0], if_op.results[1])


# CHECK: func @simple_if_else(%[[ARG0:.*]]: i1)
# CHECK: %[[RET:.*]]:2 = scf.if %[[ARG0]]
# CHECK: %[[ZERO:.*]] = arith.constant 0
# CHECK: %[[ONE:.*]] = arith.constant 1
# CHECK: scf.yield %[[ZERO]], %[[ONE]]
# CHECK: } else {
# CHECK: %[[TWO:.*]] = arith.constant 2
# CHECK: %[[THREE:.*]] = arith.constant 3
# CHECK: scf.yield %[[TWO]], %[[THREE]]
# CHECK: arith.addi %[[RET]]#0, %[[RET]]#1