[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / python / integration / dialects / linalg / opsrun.py
blobf77900bc277736c3081bd8da405820ddd246a1fb
1 # RUN: %PYTHON %s 2>&1 | FileCheck %s
3 import ctypes
4 import sys
5 from mlir.ir import *
6 from mlir.dialects import builtin
7 from mlir.dialects import func
8 from mlir.dialects import linalg
9 from mlir.passmanager import *
10 from mlir.execution_engine import *
12 from mlir.dialects.linalg.opdsl.lang import *
# Send every message to stderr and flush right away, so Python-side output and
# the errors/info MLIR emits on stderr form one stream that can be matched.
def log(*args):
    """Print ``args`` to stderr, space-separated with a newline, then flush."""
    print(*args, file=sys.stderr, flush=True)
# MLIR text appended to the generated module by the elementwise tests.
# `@main` fills a 0-d lhs with 1.0, a 4x8 rhs with 2.0 and two 4x8 outputs with
# 0.0, calls both elementwise functions, and returns the sum of element [0, 0]
# of the two outputs. The closing brace of `@main` is required for the text to
# parse.
elemwise_boiler = """
func.func @main() -> f32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0.0 : f32
  %v1 = arith.constant 1.0 : f32
  %v2 = arith.constant 2.0 : f32

  %lhs = memref.alloc() : memref<f32>
  %rhs = memref.alloc() : memref<4x8xf32>
  %O0 = memref.alloc() : memref<4x8xf32>
  %O1 = memref.alloc() : memref<4x8xf32>
  linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
  linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
  linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)

  call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
  call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
    (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
  %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>

  %0 = arith.addf %res0, %res1 : f32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : f32
}
"""
# MLIR text appended to the generated module by the fill tests.
# `@main` allocates 0-d, 1-d and 2-d i32 buffers, calls the three fill
# functions with the f32 constants 1.0, 2.0 and 3.0, loads one element from
# each buffer and returns their i32 sum. The closing brace of `@main` is
# required for the text to parse.
fill_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O0 = memref.alloc() : memref<i32>
  %O1 = memref.alloc() : memref<16xi32>
  %O2 = memref.alloc() : memref<4x16xi32>

  %val0 = arith.constant 1.0 : f32
  %val1 = arith.constant 2.0 : f32
  %val2 = arith.constant 3.0 : f32

  call @fill_0d_on_buffers(%val0, %O0) : (f32, memref<i32>) -> ()
  call @fill_1d_on_buffers(%val1, %O1) : (f32, memref<16xi32>) -> ()
  call @fill_2d_on_buffers(%val2, %O2) : (f32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %res0 = memref.load %O0[] : memref<i32>
  %c8 = arith.constant 8 : index
  %res1 = memref.load %O1[%c8] : memref<16xi32>
  %c2 = arith.constant 2 : index
  %res2 = memref.load %O2[%c2, %c8] : memref<4x16xi32>

  %0 = arith.addi %res0, %res1 : i32
  %1 = arith.addi %0, %res2 : i32

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %1 : i32
}
"""
# MLIR text appended to the generated module by the fill_rng tests.
# `@main` allocates a 4x16 i32 buffer, calls `@fill_rng_on_buffers` with the
# range [-1000.0, 1000.0] and seed 42, and returns element [0, 0]. The closing
# brace of `@main` is required for the text to parse.
fill_rng_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %O = memref.alloc() : memref<4x16xi32>
  %min = arith.constant -1000.0 : f64
  %max = arith.constant 1000.0 : f64
  %seed = arith.constant 42 : i32

  call @fill_rng_on_buffers(%min, %max, %seed, %O) :
    (f64, f64, i32, memref<4x16xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %O[%c0, %c0] : memref<4x16xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
# MLIR text for a convolution test driver.
# `@main` fills a 1x4x16x1 f64 input with 1.0, a 2x2x1 f64 filter with 2.0 and
# a 1x2x4x1 i32 output with 0, calls `@conv_on_buffers`, and returns output
# element [0, 0, 0, 0]. The closing brace of `@main` and the closing triple
# quote are both required: without them the literal is unterminated.
conv_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v1 = arith.constant 1.0 : f64
  %v2 = arith.constant 2.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %filter = memref.alloc() : memref<2x2x1xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v2 : f64) outs(%filter : memref<2x2x1xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  call @conv_on_buffers(%input, %filter, %output) :
    (memref<1x4x16x1xf64>, memref<2x2x1xf64>, memref<1x2x4x1xi32>) -> ()

  %c0 = arith.constant 0 : index
  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
# MLIR text appended to the generated module by the pooling tests.
# `@main` fills a 1x4x16x1 f64 input and a 2x2 f64 shape operand with 1.0 and
# a 1x2x4x1 i32 output with 0, then overwrites three input elements with 42.0,
# 77.0 and -13.0 before calling `@pooling_on_buffers`. It returns output
# element [0, 0, 0, 0]. The closing brace of `@main` and the closing triple
# quote are both required: without them the literal is unterminated.
pooling_boiler = """
func.func @main() -> i32 attributes {llvm.emit_c_interface} {
  %v0 = arith.constant 0 : i32
  %v42 = arith.constant 42.0 : f64
  %v77 = arith.constant 77.0 : f64
  %v-13 = arith.constant -13.0 : f64
  %v1 = arith.constant 1.0 : f64

  %input = memref.alloc() : memref<1x4x16x1xf64>
  %shape = memref.alloc() : memref<2x2xf64>
  %output = memref.alloc() : memref<1x2x4x1xi32>
  linalg.fill ins(%v1 : f64) outs(%input : memref<1x4x16x1xf64>)
  linalg.fill ins(%v1 : f64) outs(%shape : memref<2x2xf64>)
  linalg.fill ins(%v0 : i32) outs(%output : memref<1x2x4x1xi32>)

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  memref.store %v42, %input[%c0, %c0, %c0, %c0] : memref<1x4x16x1xf64>
  memref.store %v77, %input[%c0, %c0, %c1, %c0] : memref<1x4x16x1xf64>
  memref.store %v-13, %input[%c0, %c1, %c0, %c0] : memref<1x4x16x1xf64>

  call @pooling_on_buffers(%input, %shape, %output) :
    (memref<1x4x16x1xf64>, memref<2x2xf64>, memref<1x2x4x1xi32>) -> ()

  %0 = memref.load %output[%c0, %c0, %c0, %c0] : memref<1x2x4x1xi32>

  // TODO: FFI-based solution to allow testing and printing with python code.
  return %0 : i32
}
"""
def transform(module, boilerplate):
    """Merge ``module`` with ``boilerplate`` and run the lowering pipeline.

    The top-level operations of ``module`` are printed, joined with newlines,
    concatenated with the ``boilerplate`` MLIR text and re-parsed into a fresh
    module. The pass pipeline below is then run on that module, which is
    returned.
    """
    # TODO: Allow cloning functions from one module to another.
    # Atm we have to resort to string concatenation.
    top_level_ops = module.operation.regions[0].blocks[0].operations
    printed_ops = [str(op) for op in top_level_ops]
    combined = Module.parse("\n".join(printed_ops) + boilerplate)

    # Passes are added one at a time, in exactly this order.
    pass_specs = (
        "func.func(convert-linalg-to-loops)",
        "func.func(lower-affine)",
        "func.func(convert-math-to-llvm)",
        "func.func(convert-scf-to-cf)",
        "func.func(arith-expand)",
        "func.func(memref-expand)",
        "convert-vector-to-llvm",
        "finalize-memref-to-llvm",
        "convert-func-to-llvm",
        "reconcile-unrealized-casts",
    )
    pipeline = PassManager("builtin.module")
    for spec in pass_specs:
        pipeline.add(spec)
    pipeline.run(combined.operation)
    return combined
def test_elemwise_builtin():
    """Build two elementwise functions, JIT-run them and log the result.

    ``elemwise_exp_add_on_buffers`` uses the default unary/binary functions and
    ``elemwise_log_mul_on_buffers`` selects ``UnaryFn.log`` / ``BinaryFn.mul``.
    The value returned by ``@main`` in ``elemwise_boiler`` is logged to stderr
    for the FileCheck directive at the end of the function.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out])
                linalg.elemwise_binary(out, rhs, outs=[out])

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
                linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul)

        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # Prepare arguments: one result f32.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        res = c_float_p(-1.0)
        execution_engine.invoke("main", res)

        log("RESULT: ", res[0])
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_builtin()
def test_elemwise_generic():
    """Same as the builtin elementwise test, but with ``emit_generic=True``.

    Every ``linalg.elemwise_unary`` / ``linalg.elemwise_binary`` call passes
    ``emit_generic=True``. The value returned by ``@main`` in
    ``elemwise_boiler`` is logged to stderr for the FileCheck directive at the
    end of the function.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_exp_add_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
                linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(
                MemRefType.get((), f32),
                MemRefType.get((4, 8), f32),
                MemRefType.get((4, 8), f32),
            )
            def elemwise_log_mul_on_buffers(lhs, rhs, out):
                linalg.elemwise_unary(
                    lhs, outs=[out], fun=UnaryFn.log, emit_generic=True
                )
                linalg.elemwise_binary(
                    out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True
                )

        execution_engine = ExecutionEngine(transform(module, elemwise_boiler))

        # TODO: FFI-based solution to allow testing and printing with python code.
        # Prepare arguments: one result f32.
        # Arguments must be passed as pointers.
        c_float_p = ctypes.c_float * 1
        res = c_float_p(-1.0)
        execution_engine.invoke("main", res)

        log("RESULT: ", res[0])
        # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
        # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
        # CHECK: RESULT: 4.71828


test_elemwise_generic()
def test_fill_builtin():
    """Fill 0-d, 1-d and 2-d i32 buffers from f32 scalars and log the sum.

    The three fill functions are called by ``@main`` in ``fill_boiler``, whose
    i32 return value is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out])

        lowered = transform(module, fill_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # CHECK: RESULT: 6


test_fill_builtin()
def test_fill_generic():
    """Same as the builtin fill test, but with ``emit_generic=True``.

    The three fill functions are called by ``@main`` in ``fill_boiler``, whose
    i32 return value is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f32 = F32Type.get()
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f32, MemRefType.get([], i32))
            def fill_0d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, MemRefType.get([16], i32))
            def fill_1d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

            @func.FuncOp.from_py_func(f32, MemRefType.get([4, 16], i32))
            def fill_2d_on_buffers(value, out):
                linalg.fill(value, outs=[out], emit_generic=True)

        lowered = transform(module, fill_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # CHECK: RESULT: 6


test_fill_generic()
def test_fill_rng_builtin():
    """Fill a 4x16 i32 buffer with ``linalg.fill_rng_2d`` and log one element.

    ``@main`` in ``fill_rng_boiler`` supplies the range and seed and returns
    element [0, 0], which is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f64 = F64Type.get()
        buffer_type = MemRefType.get((4, 16), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, buffer_type)
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out])

        lowered = transform(module, fill_rng_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # CHECK: RESULT: -480


test_fill_rng_builtin()
def test_fill_rng_generic():
    """Same as the builtin fill_rng test, but with ``emit_generic=True``.

    ``@main`` in ``fill_rng_boiler`` supplies the range and seed and returns
    element [0, 0], which is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f64 = F64Type.get()
        buffer_type = MemRefType.get((4, 16), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(f64, f64, i32, buffer_type)
            def fill_rng_on_buffers(min, max, seed, out):
                linalg.fill_rng_2d(min, max, seed, outs=[out], emit_generic=True)

        lowered = transform(module, fill_rng_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # CHECK: RESULT: -480


test_fill_rng_generic()
def test_max_pooling_builtin():
    """Run ``linalg.pooling_nhwc_max`` with explicit strides and dilations.

    ``@main`` in ``pooling_boiler`` prepares the buffers and returns output
    element [0, 0, 0, 0], which is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f64 = F64Type.get()
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, shape_type, output_type)
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input, shape, outs=[output], strides=[2, 4], dilations=[1, 2]
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_builtin()
def test_max_pooling_generic():
    """Same as the builtin max-pooling test, but with ``emit_generic=True``.

    ``@main`` in ``pooling_boiler`` prepares the buffers and returns output
    element [0, 0, 0, 0], which is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f64 = F64Type.get()
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, shape_type, output_type)
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_max(
                    input,
                    shape,
                    outs=[output],
                    strides=[2, 4],
                    dilations=[1, 2],
                    emit_generic=True,
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # 77 is not selected due to the dilation 2 in the second dimension.
        # CHECK: RESULT: 42


test_max_pooling_generic()
def test_min_pooling_builtin():
    """Run ``linalg.pooling_nhwc_min`` with explicit strides only.

    ``@main`` in ``pooling_boiler`` prepares the buffers and returns output
    element [0, 0, 0, 0], which is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f64 = F64Type.get()
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, shape_type, output_type)
            # Set the strides and use the default dilations.
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_min(input, shape, outs=[output], strides=[2, 4])

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # CHECK: RESULT: -13


test_min_pooling_builtin()
def test_min_pooling_generic():
    """Same as the builtin min-pooling test, but with ``emit_generic=True``.

    ``@main`` in ``pooling_boiler`` prepares the buffers and returns output
    element [0, 0, 0, 0], which is logged to stderr for FileCheck.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        f64 = F64Type.get()
        input_type = MemRefType.get((1, 4, 16, 1), f64)
        shape_type = MemRefType.get((2, 2), f64)
        output_type = MemRefType.get((1, 2, 4, 1), i32)
        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(input_type, shape_type, output_type)
            # Set the strides and use the default dilations.
            def pooling_on_buffers(input, shape, output):
                linalg.pooling_nhwc_min(
                    input, shape, outs=[output], strides=[2, 4], emit_generic=True
                )

        lowered = transform(module, pooling_boiler)
        engine = ExecutionEngine(lowered)

        # TODO: FFI-based solution to allow testing and printing with python code.
        # The single i32 result is handed over as a one-element ctypes array,
        # because arguments must be passed as pointers.
        result_slot = (ctypes.c_int * 1)(-1)
        engine.invoke("main", result_slot)

        log("RESULT: ", result_slot[0])
        # CHECK: RESULT: -13


test_min_pooling_generic()