1 # RUN: %PYTHON %s | FileCheck %s
7 from mlir
.dialects
.builtin
import ModuleOp
8 from mlir
.dialects
import arith
9 from mlir
.dialects
._ods
_common
import _cext
13 print("\nTEST:", f
.__name
__)
16 assert Context
._get
_live
_count
() == 0
20 def expect_index_error(callback
):
23 raise RuntimeError("Expected IndexError")
28 # Verify iterator based traversal of the op/region/block hierarchy.
29 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
31 def testTraverseOpRegionBlockIterators():
33 ctx
.allow_unregistered_dialects
= True
34 module
= Module
.parse(
36 func.func @f1(%arg0: i32) -> i32 {
37 %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
44 assert op
.context
is ctx
45 # Get the block using iterators off of the named collections.
46 regions
= list(op
.regions
)
47 blocks
= list(regions
[0].blocks
)
48 # CHECK: MODULE REGIONS=1 BLOCKS=1
49 print(f
"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
52 # CHECK: .verify = True
53 print(f
".verify = {module.operation.verify()}")
55 # Get the blocks from the default collection.
56 default_blocks
= list(regions
[0])
57 # They should compare equal regardless of how obtained.
58 assert default_blocks
== blocks
60 # Should be able to get the operations from either the named collection
62 operations
= list(blocks
[0].operations
)
63 default_operations
= list(blocks
[0])
64 assert default_operations
== operations
def walk_operations(indent, op):
    """Print the region/block/operation tree below ``op``, depth first.

    Every region of ``op`` is reported, then every block obtained by
    iterating that region, then every operation obtained by iterating that
    block. Each operation is printed and then walked recursively with a
    longer indent.

    Args:
        indent: Prefix prepended to every line printed at this level.
        op: Object exposing a ``regions`` collection; iterating a region
            yields blocks and iterating a block yields operations.
    """
    for i, blocks_in_region in enumerate(op.regions):
        print(f"{indent}REGION {i}:")
        for j, ops_in_block in enumerate(blocks_in_region):
            print(f"{indent} BLOCK {j}:")
            for k, child_op in enumerate(ops_in_block):
                print(f"{indent} OP {k}: {child_op}")
                # Anything nested under this operation is reported one
                # indentation step deeper than the operation itself.
                deeper_indent = indent + " "
                walk_operations(deeper_indent, child_op)
80 # CHECK: OP 0: %0 = "custom.addi"
81 # CHECK: OP 1: func.return
82 walk_operations("", op
)
84 # CHECK: Region iter: <mlir.{{.+}}.RegionIterator
85 # CHECK: Block iter: <mlir.{{.+}}.BlockIterator
86 # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
87 print(" Region iter:", iter(op
.regions
))
88 print(" Block iter:", iter(op
.regions
[0]))
89 print("Operation iter:", iter(op
.regions
[0].blocks
[0]))
92 # Verify index based traversal of the op/region/block hierarchy.
93 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
95 def testTraverseOpRegionBlockIndices():
97 ctx
.allow_unregistered_dialects
= True
98 module
= Module
.parse(
100 func.func @f1(%arg0: i32) -> i32 {
101 %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
108 def walk_operations(indent
, op
):
109 for i
in range(len(op
.regions
)):
110 region
= op
.regions
[i
]
111 print(f
"{indent}REGION {i}:")
112 for j
in range(len(region
.blocks
)):
113 block
= region
.blocks
[j
]
114 print(f
"{indent} BLOCK {j}:")
115 for k
in range(len(block
.operations
)):
116 child_op
= block
.operations
[k
]
117 print(f
"{indent} OP {k}: {child_op}")
119 f
"{indent} OP {k}: parent {child_op.operation.parent.name}"
121 walk_operations(indent
+ " ", child_op
)
126 # CHECK: OP 0: parent builtin.module
129 # CHECK: OP 0: %0 = "custom.addi"
130 # CHECK: OP 0: parent func.func
131 # CHECK: OP 1: func.return
132 # CHECK: OP 1: parent func.func
133 walk_operations("", module
.operation
)
136 # CHECK-LABEL: TEST: testBlockAndRegionOwners
138 def testBlockAndRegionOwners():
140 ctx
.allow_unregistered_dialects
= True
141 module
= Module
.parse(
152 assert module
.operation
.regions
[0].owner
== module
.operation
153 assert module
.operation
.regions
[0].blocks
[0].owner
== module
.operation
155 func
= module
.body
.operations
[0]
156 assert func
.operation
.regions
[0].owner
== func
157 assert func
.operation
.regions
[0].blocks
[0].owner
== func
160 # CHECK-LABEL: TEST: testBlockArgumentList
162 def testBlockArgumentList():
163 with
Context() as ctx
:
164 module
= Module
.parse(
166 func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
172 func
= module
.body
.operations
[0]
173 entry_block
= func
.regions
[0].blocks
[0]
174 assert len(entry_block
.arguments
) == 3
175 # CHECK: Argument 0, type i32
176 # CHECK: Argument 1, type f64
177 # CHECK: Argument 2, type index
178 for arg
in entry_block
.arguments
:
179 print(f
"Argument {arg.arg_number}, type {arg.type}")
180 new_type
= IntegerType
.get_signless(8 * (arg
.arg_number
+ 1))
181 arg
.set_type(new_type
)
183 # CHECK: Argument 0, type i8
184 # CHECK: Argument 1, type i16
185 # CHECK: Argument 2, type i24
186 for arg
in entry_block
.arguments
:
187 print(f
"Argument {arg.arg_number}, type {arg.type}")
189 # Check that slicing works for block argument lists.
190 # CHECK: Argument 1, type i16
191 # CHECK: Argument 2, type i24
192 for arg
in entry_block
.arguments
[1:]:
193 print(f
"Argument {arg.arg_number}, type {arg.type}")
195 # Check that we can concatenate slices of argument lists.
197 print("Length: ", len(entry_block
.arguments
[:2] + entry_block
.arguments
[1:]))
202 for t
in entry_block
.arguments
.types
:
205 # Check that slicing and type access compose.
206 # CHECK: Sliced type: i16
207 # CHECK: Sliced type: i24
208 for t
in entry_block
.arguments
[1:].types
:
209 print("Sliced type: ", t
)
211 # Check that slice addition works as expected.
212 # CHECK: Argument 2, type i24
213 # CHECK: Argument 0, type i8
214 restructured
= entry_block
.arguments
[-1:] + entry_block
.arguments
[:1]
215 for arg
in restructured
:
216 print(f
"Argument {arg.arg_number}, type {arg.type}")
219 # CHECK-LABEL: TEST: testOperationOperands
221 def testOperationOperands():
222 with
Context() as ctx
:
223 ctx
.allow_unregistered_dialects
= True
224 module
= Module
.parse(
226 func.func @f1(%arg0: i32) {
227 %0 = "test.producer"() : () -> i64
228 "test.consumer"(%arg0, %0) : (i32, i64) -> ()
232 func
= module
.body
.operations
[0]
233 entry_block
= func
.regions
[0].blocks
[0]
234 consumer
= entry_block
.operations
[1]
235 assert len(consumer
.operands
) == 2
236 # CHECK: Operand 0, type i32
237 # CHECK: Operand 1, type i64
238 for i
, operand
in enumerate(consumer
.operands
):
239 print(f
"Operand {i}, type {operand.type}")
242 # CHECK-LABEL: TEST: testOperationOperandsSlice
244 def testOperationOperandsSlice():
245 with
Context() as ctx
:
246 ctx
.allow_unregistered_dialects
= True
247 module
= Module
.parse(
250 %0 = "test.producer0"() : () -> i64
251 %1 = "test.producer1"() : () -> i64
252 %2 = "test.producer2"() : () -> i64
253 %3 = "test.producer3"() : () -> i64
254 %4 = "test.producer4"() : () -> i64
255 "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
259 func
= module
.body
.operations
[0]
260 entry_block
= func
.regions
[0].blocks
[0]
261 consumer
= entry_block
.operations
[5]
262 assert len(consumer
.operands
) == 5
263 for left
, right
in zip(consumer
.operands
, consumer
.operands
[::-1][::-1]):
266 # CHECK: test.producer0
267 # CHECK: test.producer1
268 # CHECK: test.producer2
269 # CHECK: test.producer3
270 # CHECK: test.producer4
271 full_slice
= consumer
.operands
[:]
272 for operand
in full_slice
:
275 # CHECK: test.producer0
276 # CHECK: test.producer1
277 first_two
= consumer
.operands
[0:2]
278 for operand
in first_two
:
281 # CHECK: test.producer3
282 # CHECK: test.producer4
283 last_two
= consumer
.operands
[3:]
284 for operand
in last_two
:
287 # CHECK: test.producer0
288 # CHECK: test.producer2
289 # CHECK: test.producer4
290 even
= consumer
.operands
[::2]
294 # CHECK: test.producer2
295 fourth
= consumer
.operands
[::2][1::2]
296 for operand
in fourth
:
300 # CHECK-LABEL: TEST: testOperationOperandsSet
302 def testOperationOperandsSet():
303 with
Context() as ctx
, Location
.unknown(ctx
):
304 ctx
.allow_unregistered_dialects
= True
305 module
= Module
.parse(
308 %0 = "test.producer0"() : () -> i64
309 %1 = "test.producer1"() : () -> i64
310 %2 = "test.producer2"() : () -> i64
311 "test.consumer"(%0) : (i64) -> ()
315 func
= module
.body
.operations
[0]
316 entry_block
= func
.regions
[0].blocks
[0]
317 producer1
= entry_block
.operations
[1]
318 producer2
= entry_block
.operations
[2]
319 consumer
= entry_block
.operations
[3]
320 assert len(consumer
.operands
) == 1
321 type = consumer
.operands
[0].type
323 # CHECK: test.producer1
324 consumer
.operands
[0] = producer1
.result
325 print(consumer
.operands
[0])
327 # CHECK: test.producer2
328 consumer
.operands
[-1] = producer2
.result
329 print(consumer
.operands
[0])
332 # CHECK-LABEL: TEST: testDetachedOperation
334 def testDetachedOperation():
336 ctx
.allow_unregistered_dialects
= True
337 with Location
.unknown(ctx
):
338 i32
= IntegerType
.get_signed(32)
339 op1
= Operation
.create(
344 "foo": StringAttr
.get("foo_value"),
345 "bar": StringAttr
.get("bar_value"),
348 # CHECK: %0:2 = "custom.op1"() ({
349 # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
352 # TODO: Check successors once enough infra exists to do it properly.
355 # CHECK-LABEL: TEST: testOperationInsertionPoint
357 def testOperationInsertionPoint():
359 ctx
.allow_unregistered_dialects
= True
360 module
= Module
.parse(
362 func.func @f1(%arg0: i32) -> i32 {
363 %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
371 with Location
.unknown(ctx
):
372 op1
= Operation
.create("custom.op1")
373 op2
= Operation
.create("custom.op2")
375 func
= module
.body
.operations
[0]
376 entry_block
= func
.regions
[0].blocks
[0]
377 ip
= InsertionPoint
.at_block_begin(entry_block
)
381 # CHECK: "custom.op1"()
382 # CHECK: "custom.op2"()
383 # CHECK: %0 = "custom.addi"
386 # Trying to add a previously added op should raise.
392 assert False, "expected insert of attached op to raise"
395 # CHECK-LABEL: TEST: testOperationWithRegion
397 def testOperationWithRegion():
399 ctx
.allow_unregistered_dialects
= True
400 with Location
.unknown(ctx
):
401 i32
= IntegerType
.get_signed(32)
402 op1
= Operation
.create("custom.op1", regions
=1)
403 block
= op1
.regions
[0].blocks
.append(i32
, i32
)
404 # CHECK: "custom.op1"() ({
405 # CHECK: ^bb0(%arg0: si32, %arg1: si32):
406 # CHECK: "custom.terminator"() : () -> ()
407 # CHECK: }) : () -> ()
408 terminator
= Operation
.create("custom.terminator")
409 ip
= InsertionPoint(block
)
410 ip
.insert(terminator
)
413 # Now add the whole operation to another op.
414 # TODO: Verify lifetime hazard by nulling out the new owning module and
416 # TODO: Also verify accessing the terminator once both parents are nulled
418 module
= Module
.parse(
420 func.func @f1(%arg0: i32) -> i32 {
421 %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
426 func
= module
.body
.operations
[0]
427 entry_block
= func
.regions
[0].blocks
[0]
428 ip
= InsertionPoint
.at_block_begin(entry_block
)
431 # CHECK: "custom.op1"()
432 # CHECK: "custom.terminator"
433 # CHECK: %0 = "custom.addi"
437 # CHECK-LABEL: TEST: testOperationResultList
439 def testOperationResultList():
441 module
= Module
.parse(
444 %0:3 = call @f2() : () -> (i32, f64, index)
445 call @f3() : () -> ()
448 func.func private @f2() -> (i32, f64, index)
449 func.func private @f3() -> ()
453 caller
= module
.body
.operations
[0]
454 call
= caller
.regions
[0].blocks
[0].operations
[0]
455 assert len(call
.results
) == 3
456 # CHECK: Result 0, type i32
457 # CHECK: Result 1, type f64
458 # CHECK: Result 2, type index
459 for res
in call
.results
:
460 print(f
"Result {res.result_number}, type {res.type}")
462 # CHECK: Result type i32
463 # CHECK: Result type f64
464 # CHECK: Result type index
465 for t
in call
.results
.types
:
466 print(f
"Result type {t}")
469 expect_index_error(lambda: call
.results
[3])
470 expect_index_error(lambda: call
.results
[-4])
472 no_results_call
= caller
.regions
[0].blocks
[0].operations
[1]
473 assert len(no_results_call
.results
) == 0
474 assert no_results_call
.results
.owner
== no_results_call
477 # CHECK-LABEL: TEST: testOperationResultListSlice
479 def testOperationResultListSlice():
480 with
Context() as ctx
:
481 ctx
.allow_unregistered_dialects
= True
482 module
= Module
.parse(
485 "some.op"() : () -> (i1, i2, i3, i4, i5)
490 func
= module
.body
.operations
[0]
491 entry_block
= func
.regions
[0].blocks
[0]
492 producer
= entry_block
.operations
[0]
494 assert len(producer
.results
) == 5
495 for left
, right
in zip(producer
.results
, producer
.results
[::-1][::-1]):
497 assert left
.result_number
== right
.result_number
499 # CHECK: Result 0, type i1
500 # CHECK: Result 1, type i2
501 # CHECK: Result 2, type i3
502 # CHECK: Result 3, type i4
503 # CHECK: Result 4, type i5
504 full_slice
= producer
.results
[:]
505 for res
in full_slice
:
506 print(f
"Result {res.result_number}, type {res.type}")
508 # CHECK: Result 1, type i2
509 # CHECK: Result 2, type i3
510 # CHECK: Result 3, type i4
511 middle
= producer
.results
[1:4]
513 print(f
"Result {res.result_number}, type {res.type}")
515 # CHECK: Result 1, type i2
516 # CHECK: Result 3, type i4
517 odd
= producer
.results
[1::2]
519 print(f
"Result {res.result_number}, type {res.type}")
521 # CHECK: Result 3, type i4
522 # CHECK: Result 1, type i2
523 inverted_middle
= producer
.results
[-2:0:-2]
524 for res
in inverted_middle
:
525 print(f
"Result {res.result_number}, type {res.type}")
528 # CHECK-LABEL: TEST: testOperationAttributes
530 def testOperationAttributes():
532 ctx
.allow_unregistered_dialects
= True
533 module
= Module
.parse(
535 "some.op"() { some.attribute = 1 : i8,
536 other.attribute = 3.0,
537 dependent = "text" } : () -> ()
541 op
= module
.body
.operations
[0]
542 assert len(op
.attributes
) == 3
543 iattr
= op
.attributes
["some.attribute"]
544 fattr
= op
.attributes
["other.attribute"]
545 sattr
= op
.attributes
["dependent"]
546 # CHECK: Attribute type i8, value 1
547 print(f
"Attribute type {iattr.type}, value {iattr.value}")
548 # CHECK: Attribute type f64, value 3.0
549 print(f
"Attribute type {fattr.type}, value {fattr.value}")
550 # CHECK: Attribute value text
551 print(f
"Attribute value {sattr.value}")
552 # CHECK: Attribute value b'text'
553 print(f
"Attribute value {sattr.value_bytes}")
555 # We don't know in which order the attributes are stored.
556 # CHECK-DAG: NamedAttribute(dependent="text")
557 # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
558 # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
559 for attr
in op
.attributes
:
562 # Check that exceptions are raised as expected.
564 op
.attributes
["does_not_exist"]
568 assert False, "expected KeyError on accessing a non-existent attribute"
575 assert False, "expected IndexError on accessing an out-of-bounds attribute"
578 # CHECK-LABEL: TEST: testOperationPrint
580 def testOperationPrint():
582 module
= Module
.parse(
584 func.func @f1(%arg0: i32) -> i32 {
585 %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
592 # Test print to stdout.
593 # CHECK: return %arg0 : i32
594 module
.operation
.print()
596 # Test print to text file.
598 # CHECK: <class 'str'>
599 # CHECK: return %arg0 : i32
600 module
.operation
.print(file=f
)
601 str_value
= f
.getvalue()
602 print(str_value
.__class
__)
605 # Test roundtrip to bytecode.
606 bytecode_stream
= io
.BytesIO()
607 module
.operation
.write_bytecode(bytecode_stream
, desired_version
=1)
608 bytecode
= bytecode_stream
.getvalue()
609 assert bytecode
.startswith(b
"ML\xefR"), "Expected bytecode to start with MLïR"
610 module_roundtrip
= Module
.parse(bytecode
, ctx
)
612 module_roundtrip
.operation
.print(file=f
)
613 roundtrip_value
= f
.getvalue()
614 assert str_value
== roundtrip_value
, "Mismatch after roundtrip bytecode"
616 # Test print to binary file.
618 # CHECK: <class 'bytes'>
619 # CHECK: return %arg0 : i32
620 module
.operation
.print(file=f
, binary
=True)
621 bytes_value
= f
.getvalue()
622 print(bytes_value
.__class
__)
625 # Test print local_scope.
626 # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
627 module
.operation
.print(enable_debug_info
=True, use_local_scope
=True)
629 # Test printing using state.
630 state
= AsmState(module
.operation
)
631 # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
632 module
.operation
.print(state
)
634 # Test print with options.
635 # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
636 # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
637 module
.operation
.print(
638 large_elements_limit
=2,
639 enable_debug_info
=True,
640 pretty_debug_info
=True,
641 print_generic_op_form
=True,
642 use_local_scope
=True,
645 # Test print with skip_regions option
646 # CHECK: func.func @f1(%arg0: i32) -> i32
647 # CHECK-NOT: func.return
648 module
.body
.operations
[0].print(
653 # CHECK-LABEL: TEST: testKnownOpView
655 def testKnownOpView():
656 with
Context(), Location
.unknown():
657 Context
.current
.allow_unregistered_dialects
= True
658 module
= Module
.parse(
660 %1 = "custom.f32"() : () -> f32
661 %2 = "custom.f32"() : () -> f32
662 %3 = arith.addf %1, %2 : f32
663 %4 = arith.constant 0 : i32
668 # addf should map to a known OpView class in the arithmetic dialect.
669 # We know the OpView for it defines an 'lhs' attribute.
670 addf
= module
.body
.operations
[2]
671 # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
673 # CHECK: "custom.f32"()
676 # One of the custom ops should resolve to the default OpView.
677 custom
= module
.body
.operations
[0]
678 # CHECK: OpView object
681 # Check again to make sure negative caching works.
682 custom
= module
.body
.operations
[0]
683 # CHECK: OpView object
686 # constant should map to an extension OpView class in the arithmetic dialect.
687 constant
= module
.body
.operations
[3]
688 # CHECK: <mlir.dialects.arith.ConstantOp object
689 print(repr(constant
))
690 # Checks that the arith extension is being registered successfully
691 # (literal_value is a property on the extension class but not on the default OpView).
692 # CHECK: literal value 0
693 print("literal value", constant
.literal_value
)
695 # Checks that "late" registration/replacement (i.e., post all module loading/initialization)
696 # is working correctly.
697 @_cext.register_operation(arith
._Dialect
, replace
=True)
698 class ConstantOp(arith
.ConstantOp
):
699 def __init__(self
, result
, value
, *, loc
=None, ip
=None):
700 if isinstance(value
, int):
701 super().__init
__(IntegerAttr
.get(result
, value
), loc
=loc
, ip
=ip
)
702 elif isinstance(value
, float):
703 super().__init
__(FloatAttr
.get(result
, value
), loc
=loc
, ip
=ip
)
705 super().__init
__(value
, loc
=loc
, ip
=ip
)
707 constant
= module
.body
.operations
[3]
708 # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
709 print(repr(constant
))
712 # CHECK-LABEL: TEST: testSingleResultProperty
714 def testSingleResultProperty():
715 with
Context(), Location
.unknown():
716 Context
.current
.allow_unregistered_dialects
= True
717 module
= Module
.parse(
719 "custom.no_result"() : () -> ()
720 %0:2 = "custom.two_result"() : () -> (f32, f32)
721 %1 = "custom.one_result"() : () -> f32
727 module
.body
.operations
[0].result
728 except ValueError as e
:
729 # CHECK: Cannot call .result on operation custom.no_result which has 0 results
732 assert False, "Expected exception"
735 module
.body
.operations
[1].result
736 except ValueError as e
:
737 # CHECK: Cannot call .result on operation custom.two_result which has 2 results
740 assert False, "Expected exception"
742 # CHECK: %1 = "custom.one_result"() : () -> f32
743 print(module
.body
.operations
[2])
746 def create_invalid_operation():
# This module has two regions and is therefore invalid; verify that we fall
# back to the generic printer for safety.
749 op
= Operation
.create("builtin.module", regions
=2)
750 op
.regions
[0].blocks
.append()
754 # CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
756 def testInvalidOperationStrSoftFails():
758 with Location
.unknown(ctx
):
759 invalid_op
= create_invalid_operation()
# Verify that we fall back to the generic printer for safety.
761 # CHECK: "builtin.module"() ({
762 # CHECK: }) : () -> ()
766 except MLIRError
as e
:
767 # CHECK: Exception: <
768 # CHECK: Verification failed:
769 # CHECK: error: unknown: 'builtin.module' op requires one region
770 # CHECK: note: unknown: see current operation:
771 # CHECK: "builtin.module"() ({
774 # CHECK: }) : () -> ()
776 print(f
"Exception: <{e}>")
779 # CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
781 def testInvalidModuleStrSoftFails():
783 with Location
.unknown(ctx
):
784 module
= Module
.create()
785 with
InsertionPoint(module
.body
):
786 invalid_op
= create_invalid_operation()
# Verify that we fall back to the generic printer for safety.
788 # CHECK: "builtin.module"() ({
789 # CHECK: }) : () -> ()
793 # CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
795 def testInvalidOperationGetAsmBinarySoftFails():
797 with Location
.unknown(ctx
):
798 invalid_op
= create_invalid_operation()
# Verify that we fall back to the generic printer for safety.
800 # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
801 print(invalid_op
.get_asm(binary
=True))
804 # CHECK-LABEL: TEST: testCreateWithInvalidAttributes
806 def testCreateWithInvalidAttributes():
808 with Location
.unknown(ctx
):
811 "builtin.module", attributes
={None: StringAttr
.get("name")}
813 except Exception as e
:
814 # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
817 Operation
.create("builtin.module", attributes
={42: StringAttr
.get("name")})
818 except Exception as e
:
819 # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
822 Operation
.create("builtin.module", attributes
={"some_key": ctx
})
823 except Exception as e
:
824 # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
827 Operation
.create("builtin.module", attributes
={"some_key": None})
828 except Exception as e
:
829 # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
833 # CHECK-LABEL: TEST: testOperationName
835 def testOperationName():
837 ctx
.allow_unregistered_dialects
= True
838 module
= Module
.parse(
840 %0 = "custom.op1"() : () -> f32
841 %1 = "custom.op2"() : () -> i32
842 %2 = "custom.op1"() : () -> f32
850 for op
in module
.body
.operations
:
851 print(op
.operation
.name
)
854 # CHECK-LABEL: TEST: testCapsuleConversions
856 def testCapsuleConversions():
858 ctx
.allow_unregistered_dialects
= True
859 with Location
.unknown(ctx
):
860 m
= Operation
.create("custom.op1").operation
861 m_capsule
= m
._CAPIPtr
862 assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule
)
863 m2
= Operation
._CAPICreate
(m_capsule
)
867 # CHECK-LABEL: TEST: testOperationErase
869 def testOperationErase():
871 ctx
.allow_unregistered_dialects
= True
872 with Location
.unknown(ctx
):
874 with
InsertionPoint(m
.body
):
875 op
= Operation
.create("custom.op1")
877 # CHECK: "custom.op1"
882 # CHECK-NOT: "custom.op1"
885 # Ensure we can create another operation
886 Operation
.create("custom.op2")
889 # CHECK-LABEL: TEST: testOperationClone
891 def testOperationClone():
893 ctx
.allow_unregistered_dialects
= True
894 with Location
.unknown(ctx
):
896 with
InsertionPoint(m
.body
):
897 op
= Operation
.create("custom.op1")
899 # CHECK: "custom.op1"
902 clone
= op
.operation
.clone()
905 # CHECK: "custom.op1"
909 # CHECK-LABEL: TEST: testOperationLoc
911 def testOperationLoc():
913 ctx
.allow_unregistered_dialects
= True
915 loc
= Location
.name("loc")
916 op
= Operation
.create("custom.op", loc
=loc
)
917 assert op
.location
== loc
918 assert op
.operation
.location
== loc
921 # CHECK-LABEL: TEST: testModuleMerge
923 def testModuleMerge():
925 m1
= Module
.parse("func.func private @foo()")
928 func.func private @bar()
929 func.func private @qux()
932 foo
= m1
.body
.operations
[0]
933 bar
= m2
.body
.operations
[0]
934 qux
= m2
.body
.operations
[1]
939 # CHECK: func private @bar
940 # CHECK: func private @foo
941 # CHECK: func private @qux
949 # CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
951 def testAppendMoveFromAnotherBlock():
953 m1
= Module
.parse("func.func private @foo()")
954 m2
= Module
.parse("func.func private @bar()")
955 func
= m1
.body
.operations
[0]
959 # CHECK: func private @bar
960 # CHECK: func private @foo
968 # CHECK-LABEL: TEST: testDetachFromParent
970 def testDetachFromParent():
972 m1
= Module
.parse("func.func private @foo()")
973 func
= m1
.body
.operations
[0].detach_from_parent()
976 func
.detach_from_parent()
977 except ValueError as e
:
978 if "has no parent" not in str(e
):
981 assert False, "expected ValueError when detaching a detached operation"
984 # CHECK-NOT: func private @foo
987 # CHECK-LABEL: TEST: testOperationHash
989 def testOperationHash():
991 ctx
.allow_unregistered_dialects
= True
992 with ctx
, Location
.unknown():
993 op
= Operation
.create("custom.op1")
994 assert hash(op
) == hash(op
.operation
)
997 # CHECK-LABEL: TEST: testOperationParse
999 def testOperationParse():
1000 with
Context() as ctx
:
1001 ctx
.allow_unregistered_dialects
= True
1003 # Generic operation parsing.
1004 m
= Operation
.parse("module {}")
1005 o
= Operation
.parse('"test.foo"() : () -> ()')
1006 assert isinstance(m
, ModuleOp
)
1007 assert type(o
) is OpView
1009 # Parsing specific operation.
1010 m
= ModuleOp
.parse("module {}")
1011 assert isinstance(m
, ModuleOp
)
1013 ModuleOp
.parse('"test.foo"() : () -> ()')
1014 except MLIRError
as e
:
1015 # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
1016 print(f
"error: {e}")
1018 assert False, "expected error"
1020 o
= Operation
.parse('"test.foo"() : () -> ()', source_name
="my-source-string")
1021 # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
1023 f
"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
1027 # CHECK-LABEL: TEST: testOpWalk
1031 ctx
.allow_unregistered_dialects
= True
1032 module
= Module
.parse(
1045 return WalkResult
.ADVANCE
1047 # Test post-order walk (default).
1048 # CHECK-NEXT: Post-order
1049 # CHECK-NEXT: func.return
1050 # CHECK-NEXT: func.func
1051 # CHECK-NEXT: builtin.module
1053 module
.operation
.walk(callback
)
1055 # Test pre-order walk.
1056 # CHECK-NEXT: Pre-order
1057 # CHECK-NEXT: builtin.module
# CHECK-NEXT: func.func
1059 # CHECK-NEXT: func.return
1061 module
.operation
.walk(callback
, WalkOrder
.PRE_ORDER
)
1064 # CHECK-NEXT: Interrupt post-order
1065 # CHECK-NEXT: func.return
1066 print("Interrupt post-order")
1070 return WalkResult
.INTERRUPT
1072 module
.operation
.walk(callback
)
1075 # CHECK-NEXT: Skip pre-order
1076 # CHECK-NEXT: builtin.module
1077 print("Skip pre-order")
1081 return WalkResult
.SKIP
1083 module
.operation
.walk(callback
, WalkOrder
.PRE_ORDER
)
1087 # CHECK-NEXT: func.return
1088 # CHECK-NEXT: Exception raised
1094 return WalkResult
.ADVANCE
1097 module
.operation
.walk(callback
)
1098 except RuntimeError:
1099 print("Exception raised")