1 # RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
2 # REQUIRES: host-supports-jit
3 import gc
, sys
, os
, tempfile
5 from mlir
.passmanager
import *
6 from mlir
.execution_engine
import *
7 from mlir
.runtime
import *
8 from ml_dtypes
import bfloat16
, float8_e5m2
10 MLIR_RUNNER_UTILS
= os
.getenv(
11 "MLIR_RUNNER_UTILS", "../../../../lib/libmlir_runner_utils.so"
13 MLIR_C_RUNNER_UTILS
= os
.getenv(
14 "MLIR_C_RUNNER_UTILS", "../../../../lib/libmlir_c_runner_utils.so"
17 # Log everything to stderr and flush so that we have a unified stream to match
18 # errors/info emitted by MLIR to stderr.
20 print(*args
, file=sys
.stderr
)
25 log("\nTEST:", f
.__name
__)
28 assert Context
._get
_live
_count
() == 0
31 # Verify capsule interop.
32 # CHECK-LABEL: TEST: testCapsule
35 module
= Module
.parse(
42 execution_engine
= ExecutionEngine(module
)
43 execution_engine_capsule
= execution_engine
._CAPIPtr
44 # CHECK: mlir.execution_engine.ExecutionEngine._CAPIPtr
45 log(repr(execution_engine_capsule
))
46 execution_engine
._testing
_release
()
47 execution_engine1
= ExecutionEngine
._CAPICreate
(execution_engine_capsule
)
48 # CHECK: _mlirExecutionEngine.ExecutionEngine
49 log(repr(execution_engine1
))
55 # Test invalid ExecutionEngine creation
56 # CHECK-LABEL: TEST: testInvalidModule
57 def testInvalidModule():
60 module
= Module
.parse(
62 func.func @foo() { return }
65 # CHECK: Got RuntimeError: Failure while creating the ExecutionEngine.
67 execution_engine
= ExecutionEngine(module
)
68 except RuntimeError as e
:
69 log("Got RuntimeError: ", e
)
72 run(testInvalidModule
)
75 def lowerToLLVM(module
):
76 pm
= PassManager
.parse(
77 "builtin.module(convert-complex-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,convert-arith-to-llvm,convert-cf-to-llvm,reconcile-unrealized-casts)"
79 pm
.run(module
.operation
)
83 # Test simple ExecutionEngine execution
84 # CHECK-LABEL: TEST: testInvokeVoid
87 module
= Module
.parse(
89 func.func @void() attributes { llvm.emit_c_interface } {
94 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
95 # Nothing to check other than no exception thrown here.
96 execution_engine
.invoke("void")
102 # Test argument passing and result with a simple float addition.
103 # CHECK-LABEL: TEST: testInvokeFloatAdd
104 def testInvokeFloatAdd():
106 module
= Module
.parse(
108 func.func @add(%arg0: f32, %arg1: f32) -> f32 attributes { llvm.emit_c_interface } {
109 %add = arith.addf %arg0, %arg1 : f32
114 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
115 # Prepare arguments: two input floats and one result.
116 # Arguments must be passed as pointers.
117 c_float_p
= ctypes
.c_float
* 1
118 arg0
= c_float_p(42.0)
119 arg1
= c_float_p(2.0)
120 res
= c_float_p(-1.0)
121 execution_engine
.invoke("add", arg0
, arg1
, res
)
122 # CHECK: 42.0 + 2.0 = 44.0
123 log("{0} + {1} = {2}".format(arg0
[0], arg1
[0], res
[0]))
126 run(testInvokeFloatAdd
)
130 # CHECK-LABEL: TEST: testBasicCallback
131 def testBasicCallback():
132 # Define a callback function that takes a float and an integer and returns a float.
133 @ctypes.CFUNCTYPE(ctypes
.c_float
, ctypes
.c_float
, ctypes
.c_int
)
138 # The module just forwards to a runtime function known as "some_callback_into_python".
139 module
= Module
.parse(
141 func.func @add(%arg0: f32, %arg1: i32) -> f32 attributes { llvm.emit_c_interface } {
142 %resf = call @some_callback_into_python(%arg0, %arg1) : (f32, i32) -> (f32)
145 func.func private @some_callback_into_python(f32, i32) -> f32 attributes { llvm.emit_c_interface }
148 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
149 execution_engine
.register_runtime("some_callback_into_python", callback
)
151 # Prepare arguments: one input float, one input int, and one float result.
152 # Arguments must be passed as pointers.
153 c_float_p
= ctypes
.c_float
* 1
154 c_int_p
= ctypes
.c_int
* 1
155 arg0
= c_float_p(42.0)
157 res
= c_float_p(-1.0)
158 execution_engine
.invoke("add", arg0
, arg1
, res
)
159 # CHECK: 42.0 + 2 = 44.0
160 log("{0} + {1} = {2}".format(arg0
[0], arg1
[0], res
[0] * 2))
163 run(testBasicCallback
)
166 # Test callback with an unranked memref
167 # CHECK-LABEL: TEST: testUnrankedMemRefCallback
168 def testUnrankedMemRefCallback():
169 # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
170 @ctypes.CFUNCTYPE(None, ctypes
.POINTER(UnrankedMemRefDescriptor
))
172 arr
= unranked_memref_to_numpy(a
, np
.float32
)
173 log("Inside callback: ")
177 # The module just forwards to a runtime function known as "some_callback_into_python".
178 module
= Module
.parse(
180 func.func @callback_memref(%arg0: memref<*xf32>) attributes { llvm.emit_c_interface } {
181 call @some_callback_into_python(%arg0) : (memref<*xf32>) -> ()
184 func.func private @some_callback_into_python(memref<*xf32>) -> () attributes { llvm.emit_c_interface }
187 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
188 execution_engine
.register_runtime("some_callback_into_python", callback
)
189 inp_arr
= np
.array([[1.0, 2.0], [3.0, 4.0]], np
.float32
)
190 # CHECK: Inside callback:
191 # CHECK{LITERAL}: [[1. 2.]
192 # CHECK{LITERAL}: [3. 4.]]
193 execution_engine
.invoke(
195 ctypes
.pointer(ctypes
.pointer(get_unranked_memref_descriptor(inp_arr
))),
197 inp_arr_1
= np
.array([5, 6, 7], dtype
=np
.float32
)
198 strided_arr
= np
.lib
.stride_tricks
.as_strided(
199 inp_arr_1
, strides
=(4, 0), shape
=(3, 4)
201 # CHECK: Inside callback:
202 # CHECK{LITERAL}: [[5. 5. 5. 5.]
203 # CHECK{LITERAL}: [6. 6. 6. 6.]
204 # CHECK{LITERAL}: [7. 7. 7. 7.]]
205 execution_engine
.invoke(
207 ctypes
.pointer(ctypes
.pointer(get_unranked_memref_descriptor(strided_arr
))),
211 run(testUnrankedMemRefCallback
)
214 # Test callback with a ranked memref.
215 # CHECK-LABEL: TEST: testRankedMemRefCallback
216 def testRankedMemRefCallback():
217 # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
221 make_nd_memref_descriptor(2, np
.ctypeslib
.as_ctypes_type(np
.float32
))
225 arr
= ranked_memref_to_numpy(a
)
226 log("Inside Callback: ")
230 # The module just forwards to a runtime function known as "some_callback_into_python".
231 module
= Module
.parse(
233 func.func @callback_memref(%arg0: memref<2x2xf32>) attributes { llvm.emit_c_interface } {
234 call @some_callback_into_python(%arg0) : (memref<2x2xf32>) -> ()
237 func.func private @some_callback_into_python(memref<2x2xf32>) -> () attributes { llvm.emit_c_interface }
240 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
241 execution_engine
.register_runtime("some_callback_into_python", callback
)
242 inp_arr
= np
.array([[1.0, 5.0], [6.0, 7.0]], np
.float32
)
243 # CHECK: Inside Callback:
244 # CHECK{LITERAL}: [[1. 5.]
245 # CHECK{LITERAL}: [6. 7.]]
246 execution_engine
.invoke(
248 ctypes
.pointer(ctypes
.pointer(get_ranked_memref_descriptor(inp_arr
))),
252 run(testRankedMemRefCallback
)
255 # Test callback with a ranked memref with non-zero offset.
256 # CHECK-LABEL: TEST: testRankedMemRefWithOffsetCallback
257 def testRankedMemRefWithOffsetCallback():
258 # Define a callback function that takes a ranked memref, converts it to a numpy array and prints it.
262 make_nd_memref_descriptor(1, np
.ctypeslib
.as_ctypes_type(np
.float32
))
266 arr
= ranked_memref_to_numpy(a
)
267 log("Inside Callback: ")
271 # The module takes a subview of the argument memref and calls the callback with it
272 module
= Module
.parse(
274 func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
275 %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
276 %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
277 %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<?xf32, strided<[?], offset: ?>>
278 call @some_callback_into_python(%cast) : (memref<?xf32, strided<[?], offset: ?>>) -> ()
281 func.func private @some_callback_into_python(memref<?xf32, strided<[?], offset: ?>>) attributes {llvm.emit_c_interface}
284 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
285 execution_engine
.register_runtime("some_callback_into_python", callback
)
286 inp_arr
= np
.array([0, 0, 0, 1, 2], np
.float32
)
287 # CHECK: Inside Callback:
288 # CHECK{LITERAL}: [1. 2.]
289 execution_engine
.invoke(
291 ctypes
.pointer(ctypes
.pointer(get_ranked_memref_descriptor(inp_arr
))),
295 run(testRankedMemRefWithOffsetCallback
)
298 # Test callback with an unranked memref with non-zero offset
299 # CHECK-LABEL: TEST: testUnrankedMemRefWithOffsetCallback
300 def testUnrankedMemRefWithOffsetCallback():
301 # Define a callback function that takes an unranked memref, converts it to a numpy array and prints it.
302 @ctypes.CFUNCTYPE(None, ctypes
.POINTER(UnrankedMemRefDescriptor
))
304 arr
= unranked_memref_to_numpy(a
, np
.float32
)
305 log("Inside callback: ")
309 # The module takes a subview of the argument memref, casts it to an unranked memref and
310 # calls the callback with it.
311 module
= Module
.parse(
313 func.func @callback_memref(%arg0: memref<5xf32>) attributes {llvm.emit_c_interface} {
314 %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg0 : memref<5xf32> -> memref<f32>, index, index, index
315 %reinterpret_cast = memref.reinterpret_cast %base_buffer to offset: [3], sizes: [2], strides: [1] : memref<f32> to memref<2xf32, strided<[1], offset: 3>>
316 %cast = memref.cast %reinterpret_cast : memref<2xf32, strided<[1], offset: 3>> to memref<*xf32>
317 call @some_callback_into_python(%cast) : (memref<*xf32>) -> ()
320 func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emit_c_interface}
323 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
324 execution_engine
.register_runtime("some_callback_into_python", callback
)
325 inp_arr
= np
.array([1, 2, 3, 4, 5], np
.float32
)
326 # CHECK: Inside callback:
327 # CHECK{LITERAL}: [4. 5.]
328 execution_engine
.invoke(
330 ctypes
.pointer(ctypes
.pointer(get_ranked_memref_descriptor(inp_arr
))),
333 run(testUnrankedMemRefWithOffsetCallback
)
336 # Test addition of two memrefs.
337 # CHECK-LABEL: TEST: testMemrefAdd
340 module
= Module
.parse(
343 func.func @main(%arg0: memref<1xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) attributes { llvm.emit_c_interface } {
344 %0 = arith.constant 0 : index
345 %1 = memref.load %arg0[%0] : memref<1xf32>
346 %2 = memref.load %arg1[] : memref<f32>
347 %3 = arith.addf %1, %2 : f32
348 memref.store %3, %arg2[%0] : memref<1xf32>
353 arg1
= np
.array([32.5]).astype(np
.float32
)
354 arg2
= np
.array(6).astype(np
.float32
)
355 res
= np
.array([0]).astype(np
.float32
)
357 arg1_memref_ptr
= ctypes
.pointer(
358 ctypes
.pointer(get_ranked_memref_descriptor(arg1
))
360 arg2_memref_ptr
= ctypes
.pointer(
361 ctypes
.pointer(get_ranked_memref_descriptor(arg2
))
363 res_memref_ptr
= ctypes
.pointer(
364 ctypes
.pointer(get_ranked_memref_descriptor(res
))
367 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
368 execution_engine
.invoke(
369 "main", arg1_memref_ptr
, arg2_memref_ptr
, res_memref_ptr
371 # CHECK: [32.5] + 6.0 = [38.5]
372 log("{0} + {1} = {2}".format(arg1
, arg2
, res
))
378 # Test addition of two f16 memrefs
379 # CHECK-LABEL: TEST: testF16MemrefAdd
380 def testF16MemrefAdd():
382 module
= Module
.parse(
385 func.func @main(%arg0: memref<1xf16>,
386 %arg1: memref<1xf16>,
387 %arg2: memref<1xf16>) attributes { llvm.emit_c_interface } {
388 %0 = arith.constant 0 : index
389 %1 = memref.load %arg0[%0] : memref<1xf16>
390 %2 = memref.load %arg1[%0] : memref<1xf16>
391 %3 = arith.addf %1, %2 : f16
392 memref.store %3, %arg2[%0] : memref<1xf16>
398 arg1
= np
.array([11.0]).astype(np
.float16
)
399 arg2
= np
.array([22.0]).astype(np
.float16
)
400 arg3
= np
.array([0.0]).astype(np
.float16
)
402 arg1_memref_ptr
= ctypes
.pointer(
403 ctypes
.pointer(get_ranked_memref_descriptor(arg1
))
405 arg2_memref_ptr
= ctypes
.pointer(
406 ctypes
.pointer(get_ranked_memref_descriptor(arg2
))
408 arg3_memref_ptr
= ctypes
.pointer(
409 ctypes
.pointer(get_ranked_memref_descriptor(arg3
))
412 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
413 execution_engine
.invoke(
414 "main", arg1_memref_ptr
, arg2_memref_ptr
, arg3_memref_ptr
416 # CHECK: [11.] + [22.] = [33.]
417 log("{0} + {1} = {2}".format(arg1
, arg2
, arg3
))
419 # test to-numpy utility
421 npout
= ranked_memref_to_numpy(arg3_memref_ptr
[0])
425 run(testF16MemrefAdd
)
428 # Test addition of two complex memrefs
429 # CHECK-LABEL: TEST: testComplexMemrefAdd
430 def testComplexMemrefAdd():
432 module
= Module
.parse(
435 func.func @main(%arg0: memref<1xcomplex<f64>>,
436 %arg1: memref<1xcomplex<f64>>,
437 %arg2: memref<1xcomplex<f64>>) attributes { llvm.emit_c_interface } {
438 %0 = arith.constant 0 : index
439 %1 = memref.load %arg0[%0] : memref<1xcomplex<f64>>
440 %2 = memref.load %arg1[%0] : memref<1xcomplex<f64>>
441 %3 = complex.add %1, %2 : complex<f64>
442 memref.store %3, %arg2[%0] : memref<1xcomplex<f64>>
448 arg1
= np
.array([1.0 + 2.0j
]).astype(np
.complex128
)
449 arg2
= np
.array([3.0 + 4.0j
]).astype(np
.complex128
)
450 arg3
= np
.array([0.0 + 0.0j
]).astype(np
.complex128
)
452 arg1_memref_ptr
= ctypes
.pointer(
453 ctypes
.pointer(get_ranked_memref_descriptor(arg1
))
455 arg2_memref_ptr
= ctypes
.pointer(
456 ctypes
.pointer(get_ranked_memref_descriptor(arg2
))
458 arg3_memref_ptr
= ctypes
.pointer(
459 ctypes
.pointer(get_ranked_memref_descriptor(arg3
))
462 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
463 execution_engine
.invoke(
464 "main", arg1_memref_ptr
, arg2_memref_ptr
, arg3_memref_ptr
466 # CHECK: [1.+2.j] + [3.+4.j] = [4.+6.j]
467 log("{0} + {1} = {2}".format(arg1
, arg2
, arg3
))
469 # test to-numpy utility
471 npout
= ranked_memref_to_numpy(arg3_memref_ptr
[0])
475 run(testComplexMemrefAdd
)
478 # Test addition of two complex unranked memrefs
479 # CHECK-LABEL: TEST: testComplexUnrankedMemrefAdd
480 def testComplexUnrankedMemrefAdd():
482 module
= Module
.parse(
485 func.func @main(%arg0: memref<*xcomplex<f32>>,
486 %arg1: memref<*xcomplex<f32>>,
487 %arg2: memref<*xcomplex<f32>>) attributes { llvm.emit_c_interface } {
488 %A = memref.cast %arg0 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
489 %B = memref.cast %arg1 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
490 %C = memref.cast %arg2 : memref<*xcomplex<f32>> to memref<1xcomplex<f32>>
491 %0 = arith.constant 0 : index
492 %1 = memref.load %A[%0] : memref<1xcomplex<f32>>
493 %2 = memref.load %B[%0] : memref<1xcomplex<f32>>
494 %3 = complex.add %1, %2 : complex<f32>
495 memref.store %3, %C[%0] : memref<1xcomplex<f32>>
501 arg1
= np
.array([5.0 + 6.0j
]).astype(np
.complex64
)
502 arg2
= np
.array([7.0 + 8.0j
]).astype(np
.complex64
)
503 arg3
= np
.array([0.0 + 0.0j
]).astype(np
.complex64
)
505 arg1_memref_ptr
= ctypes
.pointer(
506 ctypes
.pointer(get_unranked_memref_descriptor(arg1
))
508 arg2_memref_ptr
= ctypes
.pointer(
509 ctypes
.pointer(get_unranked_memref_descriptor(arg2
))
511 arg3_memref_ptr
= ctypes
.pointer(
512 ctypes
.pointer(get_unranked_memref_descriptor(arg3
))
515 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
516 execution_engine
.invoke(
517 "main", arg1_memref_ptr
, arg2_memref_ptr
, arg3_memref_ptr
519 # CHECK: [5.+6.j] + [7.+8.j] = [12.+14.j]
520 log("{0} + {1} = {2}".format(arg1
, arg2
, arg3
))
522 # test to-numpy utility
524 npout
= unranked_memref_to_numpy(arg3_memref_ptr
[0], np
.dtype(np
.complex64
))
528 run(testComplexUnrankedMemrefAdd
)
532 # CHECK-LABEL: TEST: testBF16Memref
533 def testBF16Memref():
535 module
= Module
.parse(
538 func.func @main(%arg0: memref<1xbf16>,
539 %arg1: memref<1xbf16>) attributes { llvm.emit_c_interface } {
540 %0 = arith.constant 0 : index
541 %1 = memref.load %arg0[%0] : memref<1xbf16>
542 memref.store %1, %arg1[%0] : memref<1xbf16>
548 arg1
= np
.array([0.5]).astype(bfloat16
)
549 arg2
= np
.array([0.0]).astype(bfloat16
)
551 arg1_memref_ptr
= ctypes
.pointer(
552 ctypes
.pointer(get_ranked_memref_descriptor(arg1
))
554 arg2_memref_ptr
= ctypes
.pointer(
555 ctypes
.pointer(get_ranked_memref_descriptor(arg2
))
558 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
559 execution_engine
.invoke("main", arg1_memref_ptr
, arg2_memref_ptr
)
561 # test to-numpy utility
563 npout
= ranked_memref_to_numpy(arg2_memref_ptr
[0])
570 # Test f8E5M2 memrefs
571 # CHECK-LABEL: TEST: testF8E5M2Memref
572 def testF8E5M2Memref():
574 module
= Module
.parse(
577 func.func @main(%arg0: memref<1xf8E5M2>,
578 %arg1: memref<1xf8E5M2>) attributes { llvm.emit_c_interface } {
579 %0 = arith.constant 0 : index
580 %1 = memref.load %arg0[%0] : memref<1xf8E5M2>
581 memref.store %1, %arg1[%0] : memref<1xf8E5M2>
587 arg1
= np
.array([0.5]).astype(float8_e5m2
)
588 arg2
= np
.array([0.0]).astype(float8_e5m2
)
590 arg1_memref_ptr
= ctypes
.pointer(
591 ctypes
.pointer(get_ranked_memref_descriptor(arg1
))
593 arg2_memref_ptr
= ctypes
.pointer(
594 ctypes
.pointer(get_ranked_memref_descriptor(arg2
))
597 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
598 execution_engine
.invoke("main", arg1_memref_ptr
, arg2_memref_ptr
)
600 # test to-numpy utility
602 npout
= ranked_memref_to_numpy(arg2_memref_ptr
[0])
606 run(testF8E5M2Memref
)
609 # Test element-wise addition of two 2-D memrefs, one of them dynamically shaped.
610 # CHECK-LABEL: TEST: testDynamicMemrefAdd2D
611 def testDynamicMemrefAdd2D():
613 module
= Module
.parse(
616 func.func @memref_add_2d(%arg0: memref<2x2xf32>, %arg1: memref<?x?xf32>, %arg2: memref<2x2xf32>) attributes {llvm.emit_c_interface} {
617 %c0 = arith.constant 0 : index
618 %c2 = arith.constant 2 : index
619 %c1 = arith.constant 1 : index
620 cf.br ^bb1(%c0 : index)
621 ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
622 %1 = arith.cmpi slt, %0, %c2 : index
623 cf.cond_br %1, ^bb2, ^bb6
625 %c0_0 = arith.constant 0 : index
626 %c2_1 = arith.constant 2 : index
627 %c1_2 = arith.constant 1 : index
628 cf.br ^bb3(%c0_0 : index)
629 ^bb3(%2: index): // 2 preds: ^bb2, ^bb4
630 %3 = arith.cmpi slt, %2, %c2_1 : index
631 cf.cond_br %3, ^bb4, ^bb5
633 %4 = memref.load %arg0[%0, %2] : memref<2x2xf32>
634 %5 = memref.load %arg1[%0, %2] : memref<?x?xf32>
635 %6 = arith.addf %4, %5 : f32
636 memref.store %6, %arg2[%0, %2] : memref<2x2xf32>
637 %7 = arith.addi %2, %c1_2 : index
638 cf.br ^bb3(%7 : index)
640 %8 = arith.addi %0, %c1 : index
641 cf.br ^bb1(%8 : index)
648 arg1
= np
.random
.randn(2, 2).astype(np
.float32
)
649 arg2
= np
.random
.randn(2, 2).astype(np
.float32
)
650 res
= np
.random
.randn(2, 2).astype(np
.float32
)
652 arg1_memref_ptr
= ctypes
.pointer(
653 ctypes
.pointer(get_ranked_memref_descriptor(arg1
))
655 arg2_memref_ptr
= ctypes
.pointer(
656 ctypes
.pointer(get_ranked_memref_descriptor(arg2
))
658 res_memref_ptr
= ctypes
.pointer(
659 ctypes
.pointer(get_ranked_memref_descriptor(res
))
662 execution_engine
= ExecutionEngine(lowerToLLVM(module
))
663 execution_engine
.invoke(
664 "memref_add_2d", arg1_memref_ptr
, arg2_memref_ptr
, res_memref_ptr
667 log(np
.allclose(arg1
+ arg2
, res
))
670 run(testDynamicMemrefAdd2D
)
673 # Test loading of shared libraries.
674 # CHECK-LABEL: TEST: testSharedLibLoad
675 def testSharedLibLoad():
677 module
= Module
.parse(
680 func.func @main(%arg0: memref<1xf32>) attributes { llvm.emit_c_interface } {
681 %c0 = arith.constant 0 : index
682 %cst42 = arith.constant 42.0 : f32
683 memref.store %cst42, %arg0[%c0] : memref<1xf32>
684 %u_memref = memref.cast %arg0 : memref<1xf32> to memref<*xf32>
685 call @printMemrefF32(%u_memref) : (memref<*xf32>) -> ()
688 func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interface }
691 arg0
= np
.array([0.0]).astype(np
.float32
)
693 arg0_memref_ptr
= ctypes
.pointer(
694 ctypes
.pointer(get_ranked_memref_descriptor(arg0
))
697 if sys
.platform
== "win32":
699 "../../../../bin/mlir_runner_utils.dll",
700 "../../../../bin/mlir_c_runner_utils.dll",
702 elif sys
.platform
== "darwin":
704 "../../../../lib/libmlir_runner_utils.dylib",
705 "../../../../lib/libmlir_c_runner_utils.dylib",
713 execution_engine
= ExecutionEngine(
714 lowerToLLVM(module
), opt_level
=3, shared_libs
=shared_libs
716 execution_engine
.invoke("main", arg0_memref_ptr
)
717 # CHECK: Unranked Memref
721 run(testSharedLibLoad
)
724 # Test that nano time clock is available.
725 # CHECK-LABEL: TEST: testNanoTime
728 module
= Module
.parse(
731 func.func @main() attributes { llvm.emit_c_interface } {
732 %now = call @nanoTime() : () -> i64
733 %memref = memref.alloca() : memref<1xi64>
734 %c0 = arith.constant 0 : index
735 memref.store %now, %memref[%c0] : memref<1xi64>
736 %u_memref = memref.cast %memref : memref<1xi64> to memref<*xi64>
737 call @printMemrefI64(%u_memref) : (memref<*xi64>) -> ()
740 func.func private @nanoTime() -> i64 attributes { llvm.emit_c_interface }
741 func.func private @printMemrefI64(memref<*xi64>) attributes { llvm.emit_c_interface }
745 if sys
.platform
== "win32":
747 "../../../../bin/mlir_runner_utils.dll",
748 "../../../../bin/mlir_c_runner_utils.dll",
756 execution_engine
= ExecutionEngine(
757 lowerToLLVM(module
), opt_level
=3, shared_libs
=shared_libs
759 execution_engine
.invoke("main")
760 # CHECK: Unranked Memref
767 # Test that the compiled module can be dumped to a non-empty object file.
768 # CHECK-LABEL: TEST: testDumpToObjectFile
769 def testDumpToObjectFile():
770 fd
, object_path
= tempfile
.mkstemp(suffix
=".o")
774 module
= Module
.parse(
777 func.func @main() attributes { llvm.emit_c_interface } {
783 execution_engine
= ExecutionEngine(lowerToLLVM(module
), opt_level
=3)
785 # CHECK: Object file exists: True
786 print(f
"Object file exists: {os.path.exists(object_path)}")
787 # CHECK: Object file is empty: True
788 print(f
"Object file is empty: {os.path.getsize(object_path) == 0}")
790 execution_engine
.dump_to_object_file(object_path
)
792 # CHECK: Object file exists: True
793 print(f
"Object file exists: {os.path.exists(object_path)}")
794 # CHECK: Object file is empty: False
795 print(f
"Object file is empty: {os.path.getsize(object_path) == 0}")
799 os
.remove(object_path
)
802 run(testDumpToObjectFile
)