[mlir][py] Enable loading only specified dialects during creation. (#121421)
[llvm-project.git] / mlir / test / python / execution_engine.py
blob6d3a8db8c24be97b69d8c1f8064156fe689de870
1 # RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
2 # REQUIRES: host-supports-jit
3 import gc, sys, os, tempfile
4 from mlir.ir import *
5 from mlir.passmanager import *
6 from mlir.execution_engine import *
7 from mlir.runtime import *
8 from ml_dtypes import bfloat16, float8_e5m2
# Locations of the runner utility shared libraries. Each can be overridden
# through the environment variable of the same name; otherwise a path relative
# to the working directory is used.
MLIR_RUNNER_UTILS = os.environ.get(
    "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
)
MLIR_C_RUNNER_UTILS = os.environ.get(
    "MLIR_C_RUNNER_UTILS", "../../../../lib/libmlir_c_runner_utils.so"
)
17 # Log everything to stderr and flush so that we have a unified stream to match
18 # errors/info emitted by MLIR to stderr.
def log(*args):
    """Print ``args`` to stderr, space separated, and flush right away."""
    print(*args, file=sys.stderr, flush=True)
def run(f):
    """Run the test callable ``f`` and verify that no Context outlives it.

    A "TEST: <name>" banner is logged first; the label directives in this file
    match against it.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Force a collection before inspecting the live context count.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
31 # Verify capsule interop.
32 # CHECK-LABEL: TEST: testCapsule
def testCapsule():
    """Export an ExecutionEngine as a capsule and rebuild an engine from it."""
    with Context():
        mod = Module.parse(
            r"""
llvm.func @none() {
  llvm.return
}
"""
        )
        engine = ExecutionEngine(mod)
        capsule = engine._CAPIPtr
        # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
        log(repr(capsule))
        # Release the first wrapper before re-creating an engine from the capsule.
        engine._testing_release()
        rebuilt_engine = ExecutionEngine._CAPICreate(capsule)
        # CHECK: _mlirExecutionEngine.ExecutionEngine
        log(repr(rebuilt_engine))
52 run(testCapsule)
55 # Test invalid ExecutionEngine creation
56 # CHECK-LABEL: TEST: testInvalidModule
def testInvalidModule():
    """Creating an engine from a module still in the func dialect must raise."""
    with Context():
        # Builtin function
        mod = Module.parse(
            r"""
func.func @foo() { return }
"""
        )
        # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
        try:
            ExecutionEngine(mod)
        except RuntimeError as err:
            log("Got RuntimeError: ", err)
72 run(testInvalidModule)
def lowerToLLVM(module):
    """Run the LLVM lowering pipeline on ``module`` in place and return it."""
    passes = [
        "convert-complex-to-llvm",
        "finalize-memref-to-llvm",
        "convert-func-to-llvm",
        "convert-arith-to-llvm",
        "convert-cf-to-llvm",
        "reconcile-unrealized-casts",
    ]
    pipeline = "builtin.module(" + ",".join(passes) + ")"
    pm = PassManager.parse(pipeline)
    pm.run(module.operation)
    return module
83 # Test simple ExecutionEngine execution
84 # CHECK-LABEL: TEST: testInvokeVoid
def testInvokeVoid():
    """Invoke a function that takes no arguments and returns nothing."""
    with Context():
        mod = Module.parse(
            r"""
func.func @void() attributes { llvm.emit_c_interface } {
  return
}
"""
        )
        engine = ExecutionEngine(lowerToLLVM(mod))
        # Nothing to check other than no exception thrown here.
        engine.invoke("void")
99 run(testInvokeVoid)
102 # Test argument passing and result with a simple float addition.
103 # CHECK-LABEL: TEST: testInvokeFloatAdd
def testInvokeFloatAdd():
    """Pass two f32 arguments to a JIT-compiled add and read back the sum."""
    with Context():
        mod = Module.parse(
            r"""
func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
  %add = arith.addf %arg0, %arg1 : f32
  return %add : f32
}
"""
        )
        engine = ExecutionEngine(lowerToLLVM(mod))
        # Arguments and the result slot are handed over as pointers, so each
        # value lives in a one-element ctypes array.
        float_cell = ctypes.c_float * 1
        lhs = float_cell(42.0)
        rhs = float_cell(2.0)
        out = float_cell(-1.0)
        engine.invoke("add", lhs, rhs, out)
        # CHECK: 42.0 + 2.0 = 44.0
        log(f"{lhs[0]} + {rhs[0]} = {out[0]}")
126 run(testInvokeFloatAdd)
129 # Test callback
130 # CHECK-LABEL: TEST: testBasicCallback
def testBasicCallback():
    """Have JIT-compiled code call back into a Python function."""

    # The callback takes a float and an integer and returns a float.
    @ctypes.CFUNCTYPE(ctypes.c_float, ctypes.c_float, ctypes.c_int)
    def callback(a, b):
        return a / 2 + b / 2

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        mod = Module.parse(
            r"""
func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
  %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
  return %resf : f32
}
func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
"""
        )
        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.register_runtime("some_callback_into_python", callback)

        # One float input, one int input and one float result, all passed as
        # pointers to one-element ctypes arrays.
        float_cell = ctypes.c_float * 1
        int_cell = ctypes.c_int * 1
        lhs = float_cell(42.0)
        rhs = int_cell(2)
        out = float_cell(-1.0)
        engine.invoke("add", lhs, rhs, out)
        # The callback halves both operands, hence the doubling of the result.
        # CHECK: 42.0 + 2 = 44.0
        log(f"{lhs[0]} + {rhs[0]} = {out[0] * 2}")
163 run(testBasicCallback)
166 # Test callback with an unranked memref
167 # CHECK-LABEL: TEST: testUnrankedMemRefCallback
def testUnrankedMemRefCallback():
    """Pass unranked memrefs through JIT-compiled code into a Python callback."""

    # The callback converts the unranked memref it receives to a numpy array
    # and logs it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        as_numpy = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(as_numpy)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        mod = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.register_runtime("some_callback_into_python", callback)

        contiguous = np.array([[1.0, 2.0], [3.0, 4.0]], np.float32)
        contiguous_ptr = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(contiguous))
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[1. 2.]
        # CHECK{LITERAL}: [3. 4.]]
        engine.invoke("callback_memref", contiguous_ptr)

        # A 3x4 view over three floats: stride 4 bytes along rows, 0 along columns.
        base = np.array([5, 6, 7], dtype=np.float32)
        strided = np.lib.stride_tricks.as_strided(base, strides=(4, 0), shape=(3, 4))
        strided_ptr = ctypes.pointer(
            ctypes.pointer(get_unranked_memref_descriptor(strided))
        )
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [[5. 5. 5. 5.]
        # CHECK{LITERAL}: [6. 6. 6. 6.]
        # CHECK{LITERAL}: [7. 7. 7. 7.]]
        engine.invoke("callback_memref", strided_ptr)
211 run(testUnrankedMemRefCallback)
214 # Test callback with a ranked memref.
215 # CHECK-LABEL: TEST: testRankedMemRefCallback
def testRankedMemRefCallback():
    """Pass a ranked 2x2 f32 memref through JIT-compiled code into a callback."""

    # Descriptor type for a rank-2 memref of float32 elements.
    descriptor_type = make_nd_memref_descriptor(
        2, np.ctypeslib.as_ctypes_type(np.float32)
    )

    # The callback converts the ranked memref it receives to a numpy array and
    # logs it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))
    def callback(a):
        as_numpy = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(as_numpy)

    with Context():
        # The module just forwards to a runtime function known as "some_callback_into_python".
        mod = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
  call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
"""
        )
        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.register_runtime("some_callback_into_python", callback)
        data = np.array([[1.0, 5.0], [6.0, 7.0]], np.float32)
        data_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(data)))
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [[1. 5.]
        # CHECK{LITERAL}: [6. 7.]]
        engine.invoke("callback_memref", data_ptr)
252 run(testRankedMemRefCallback)
255 # Test callback with a ranked memref with non-zero offset.
256 # CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
def testRankedMemRefWithOffsetCallback():
    """Hand a ranked memref with a non-zero offset to a Python callback."""

    # Descriptor type for a rank-1 memref of float32 elements.
    descriptor_type = make_nd_memref_descriptor(
        1, np.ctypeslib.as_ctypes_type(np.float32)
    )

    # The callback converts the ranked memref it receives to a numpy array and
    # logs it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(descriptor_type))
    def callback(a):
        as_numpy = ranked_memref_to_numpy(a)
        log("Inside Callback: ")
        log(as_numpy)

    with Context():
        # The module takes a subview of the argument memref and calls the callback with it
        mod = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
  call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
  return
}
func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
"""
        )
        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.register_runtime("some_callback_into_python", callback)
        data = np.array([0, 0, 0, 1, 2], np.float32)
        data_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(data)))
        # CHECK: Inside Callback:
        # CHECK{LITERAL}: [1. 2.]
        engine.invoke("callback_memref", data_ptr)
295 run(testRankedMemRefWithOffsetCallback)
298 # Test callback with an unranked memref with non-zero offset
299 # CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
def testUnrankedMemRefWithOffsetCallback():
    """Hand an unranked memref with a non-zero offset to a Python callback."""

    # The callback converts the unranked memref it receives to a numpy array
    # and logs it.
    @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
    def callback(a):
        as_numpy = unranked_memref_to_numpy(a, np.float32)
        log("Inside callback: ")
        log(as_numpy)

    with Context():
        # The module takes a subview of the argument memref, casts it to an unranked memref and
        # calls the callback with it.
        mod = Module.parse(
            r"""
func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
  %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
  %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
  call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
  return
}
func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
"""
        )
        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.register_runtime("some_callback_into_python", callback)
        data = np.array([1, 2, 3, 4, 5], np.float32)
        # The entry point itself takes a ranked memref<5xf32>, hence the ranked
        # descriptor on the Python side.
        data_ptr = ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(data)))
        # CHECK: Inside callback:
        # CHECK{LITERAL}: [4. 5.]
        engine.invoke("callback_memref", data_ptr)
333 run(testUnrankedMemRefWithOffsetCallback)
336 # Test addition of two memrefs.
337 # CHECK-LABEL: TEST: testMemrefAdd
def testMemrefAdd():
    """Add a memref<1xf32> and a rank-0 memref<f32> into a third memref."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
    %0 = arith.constant 0 : index
    %1 = memref.load %arg0[%0] : memref<1xf32>
    %2 = memref.load %arg1[] : memref<f32>
    %3 = arith.addf %1, %2 : f32
    memref.store %3, %arg2[%0] : memref<1xf32>
    return
  }
} """
        )
        lhs = np.array([32.5]).astype(np.float32)
        rhs = np.array(6).astype(np.float32)
        out = np.array([0]).astype(np.float32)

        # Each memref argument is passed as a pointer to a pointer to its
        # descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [32.5] + 6.0 = [38.5]
        log(f"{lhs} + {rhs} = {out}")
375 run(testMemrefAdd)
378 # Test addition of two f16 memrefs
379 # CHECK-LABEL: TEST: testF16MemrefAdd
def testF16MemrefAdd():
    """Add two one-element f16 memrefs and read the result back via numpy."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @main(%arg0: memref<1xf16>,
                  %arg1: memref<1xf16>,
                  %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
    %0 = arith.constant 0 : index
    %1 = memref.load %arg0[%0] : memref<1xf16>
    %2 = memref.load %arg1[%0] : memref<1xf16>
    %3 = arith.addf %1, %2 : f16
    memref.store %3, %arg2[%0] : memref<1xf16>
    return
  }
} """
        )

        lhs = np.array([11.0]).astype(np.float16)
        rhs = np.array([22.0]).astype(np.float16)
        out = np.array([0.0]).astype(np.float16)

        # Each memref argument is passed as a pointer to a pointer to its
        # descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [11.] + [22.] = [33.]
        log(f"{lhs} + {rhs} = {out}")

        # test to-numpy utility
        # CHECK: [33.]
        round_tripped = ranked_memref_to_numpy(out_ptr[0])
        log(round_tripped)
425 run(testF16MemrefAdd)
428 # Test addition of two complex memrefs
429 # CHECK-LABEL: TEST: testComplexMemrefAdd
def testComplexMemrefAdd():
    """Add two one-element complex<f64> memrefs and read the result back."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @main(%arg0: memref<1xcomplex<f64>>,
                  %arg1: memref<1xcomplex<f64>>,
                  %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
    %0 = arith.constant 0 : index
    %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
    %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
    %3 = complex.add %1, %2 : complex<f64>
    memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
    return
  }
} """
        )

        lhs = np.array([1.0 + 2.0j]).astype(np.complex128)
        rhs = np.array([3.0 + 4.0j]).astype(np.complex128)
        out = np.array([0.0 + 0.0j]).astype(np.complex128)

        # Each memref argument is passed as a pointer to a pointer to its
        # descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
        log(f"{lhs} + {rhs} = {out}")

        # test to-numpy utility
        # CHECK: [4.+6.j]
        round_tripped = ranked_memref_to_numpy(out_ptr[0])
        log(round_tripped)
475 run(testComplexMemrefAdd)
478 # Test addition of two complex unranked memrefs
479 # CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
def testComplexUnrankedMemrefAdd():
    """Add two complex<f32> values passed in as unranked memrefs."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @main(%arg0: memref<*xcomplex<f32>>,
                  %arg1: memref<*xcomplex<f32>>,
                  %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
    %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
    %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
    %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
    %0 = arith.constant 0 : index
    %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
    %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
    %3 = complex.add %1, %2 : complex<f32>
    memref.store %3, %C[%0] : memref<1xcomplex<f32>>
    return
  }
} """
        )

        lhs = np.array([5.0 + 6.0j]).astype(np.complex64)
        rhs = np.array([7.0 + 8.0j]).astype(np.complex64)
        out = np.array([0.0 + 0.0j]).astype(np.complex64)

        # Each memref argument is passed as a pointer to a pointer to its
        # unranked descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_unranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("main", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
        log(f"{lhs} + {rhs} = {out}")

        # test to-numpy utility
        # CHECK: [12.+14.j]
        round_tripped = unranked_memref_to_numpy(out_ptr[0], np.dtype(np.complex64))
        log(round_tripped)
528 run(testComplexUnrankedMemrefAdd)
531 # Test bf16 memrefs
532 # CHECK-LABEL: TEST: testBF16Memref
def testBF16Memref():
    """Copy one bf16 element between memrefs and read it back via numpy."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @main(%arg0: memref<1xbf16>,
                  %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
    %0 = arith.constant 0 : index
    %1 = memref.load %arg0[%0] : memref<1xbf16>
    memref.store %1, %arg1[%0] : memref<1xbf16>
    return
  }
} """
        )

        src = np.array([0.5]).astype(bfloat16)
        dst = np.array([0.0]).astype(bfloat16)

        # Each memref argument is passed as a pointer to a pointer to its
        # descriptor.
        src_ptr, dst_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (src, dst)
        ]

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("main", src_ptr, dst_ptr)

        # test to-numpy utility
        # CHECK: [0.5]
        round_tripped = ranked_memref_to_numpy(dst_ptr[0])
        log(round_tripped)
567 run(testBF16Memref)
570 # Test f8E5M2 memrefs
571 # CHECK-LABEL: TEST: testF8E5M2Memref
def testF8E5M2Memref():
    """Copy one f8E5M2 element between memrefs and read it back via numpy."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @main(%arg0: memref<1xf8E5M2>,
                  %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
    %0 = arith.constant 0 : index
    %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
    memref.store %1, %arg1[%0] : memref<1xf8E5M2>
    return
  }
} """
        )

        src = np.array([0.5]).astype(float8_e5m2)
        dst = np.array([0.0]).astype(float8_e5m2)

        # Each memref argument is passed as a pointer to a pointer to its
        # descriptor.
        src_ptr, dst_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (src, dst)
        ]

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("main", src_ptr, dst_ptr)

        # test to-numpy utility
        # CHECK: [0.5]
        round_tripped = ranked_memref_to_numpy(dst_ptr[0])
        log(round_tripped)
606 run(testF8E5M2Memref)
609 # Test addition of two 2d_memref
610 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
def testDynamicMemrefAdd2D():
    """Element-wise add of a static 2x2 memref and a dynamically shaped one."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %c1 = arith.constant 1 : index
    cf.br ^bb1(%c0 : index)
  ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
    %1 = arith.cmpi slt, %0, %c2 : index
    cf.cond_br %1, ^bb2, ^bb6
  ^bb2:  // pred: ^bb1
    %c0_0 = arith.constant 0 : index
    %c2_1 = arith.constant 2 : index
    %c1_2 = arith.constant 1 : index
    cf.br ^bb3(%c0_0 : index)
  ^bb3(%2: index):  // 2 preds: ^bb2, ^bb4
    %3 = arith.cmpi slt, %2, %c2_1 : index
    cf.cond_br %3, ^bb4, ^bb5
  ^bb4:  // pred: ^bb3
    %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
    %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
    %6 = arith.addf %4, %5 : f32
    memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
    %7 = arith.addi %2, %c1_2 : index
    cf.br ^bb3(%7 : index)
  ^bb5:  // pred: ^bb3
    %8 = arith.addi %0, %c1 : index
    cf.br ^bb1(%8 : index)
  ^bb6:  // pred: ^bb1
    return
  }
}
        """
        )
        # Two random operands plus a randomly initialised result buffer.
        lhs, rhs, out = [
            np.random.randn(2, 2).astype(np.float32) for _ in range(3)
        ]

        # Each memref argument is passed as a pointer to a pointer to its
        # descriptor.
        lhs_ptr, rhs_ptr, out_ptr = [
            ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(array)))
            for array in (lhs, rhs, out)
        ]

        engine = ExecutionEngine(lowerToLLVM(mod))
        engine.invoke("memref_add_2d", lhs_ptr, rhs_ptr, out_ptr)
        # CHECK: True
        log(np.allclose(lhs + rhs, out))
670 run(testDynamicMemrefAdd2D)
673 # Test loading of shared libraries.
674 # CHECK-LABEL: TEST: testSharedLibLoad
def testSharedLibLoad():
    """Load the runner utility libraries and call printMemrefF32 from JIT code."""
    with Context():
        mod = Module.parse(
            """
module  {
  func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
    %c0 = arith.constant 0 : index
    %cst42 = arith.constant 42.0 : f32
    memref.store %cst42, %arg0[%c0] : memref<1xf32>
    %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
    call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
    return
  }
  func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
} """
        )
        buffer = np.array([0.0]).astype(np.float32)

        buffer_ptr = ctypes.pointer(
            ctypes.pointer(get_ranked_memref_descriptor(buffer))
        )

        # Windows and macOS use fixed relative paths; every other platform uses
        # the module-level (environment-overridable) locations.
        per_platform_libs = {
            "win32": [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ],
            "darwin": [
                "../../../../lib/libmlir_runner_utils.dylib",
                "../../../../lib/libmlir_c_runner_utils.dylib",
            ],
        }
        shared_libs = per_platform_libs.get(
            sys.platform, [MLIR_RUNNER_UTILS, MLIR_C_RUNNER_UTILS]
        )

        engine = ExecutionEngine(
            lowerToLLVM(mod), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main", buffer_ptr)
        # CHECK: Unranked Memref
        # CHECK-NEXT: [42]
721 run(testSharedLibLoad)
724 # Test that nano time clock is available.
725 # CHECK-LABEL: TEST: testNanoTime
def testNanoTime():
    """Call nanoTime from JIT code and print the value with printMemrefI64."""
    with Context():
        mod = Module.parse(
            """
module {
  func.func @main() attributes { llvm.emit_c_interface } {
    %now = call @nanoTime() : () -> i64
    %memref = memref.alloca() : memref<1xi64>
    %c0 = arith.constant 0 : index
    memref.store %now, %memref[%c0] : memref<1xi64>
    %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
    call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
    return
  }
  func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
  func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
}"""
        )

        # Windows uses fixed relative DLL paths; every other platform uses the
        # module-level (environment-overridable) locations.
        shared_libs = (
            [
                "../../../../bin/mlir_runner_utils.dll",
                "../../../../bin/mlir_c_runner_utils.dll",
            ]
            if sys.platform == "win32"
            else [MLIR_RUNNER_UTILS, MLIR_C_RUNNER_UTILS]
        )

        engine = ExecutionEngine(
            lowerToLLVM(mod), opt_level=3, shared_libs=shared_libs
        )
        engine.invoke("main")
        # CHECK: Unranked Memref
        # CHECK: [{{.*}}]
764 run(testNanoTime)
# Test dumping the compiled module to an object file.
768 # CHECK-LABEL: TEST: testDumpToObjectFile
def testDumpToObjectFile():
    """Dump a compiled module to an object file and verify the file is filled.

    A temporary ".o" file is created empty, its existence and size are
    reported before and after ``dump_to_object_file``, and the file is
    removed afterwards whether or not the test body succeeds.

    Results are reported through ``log`` (stderr, flushed) like every other
    test in this file. With ``print`` the lines go to stdout, which is
    buffered when piped and is therefore not ordered relative to the
    "TEST:" banners written to stderr.
    """
    fd, object_path = tempfile.mkstemp(suffix=".o")

    try:
        with Context():
            module = Module.parse(
                """
    module {
    func.func @main() attributes { llvm.emit_c_interface } {
      return
    }
    }"""
            )

            execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)

            # mkstemp has just created the file, so it exists but holds no data.
            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: True
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

            execution_engine.dump_to_object_file(object_path)

            # CHECK: Object file exists: True
            log(f"Object file exists: {os.path.exists(object_path)}")
            # CHECK: Object file is empty: False
            log(f"Object file is empty: {os.path.getsize(object_path) == 0}")

    finally:
        os.close(fd)
        os.remove(object_path)
802 run(testDumpToObjectFile)