[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / python / ir / operation.py
blobd5d72c98b66ad0ef9483983b19e0f173318b7796
1 # RUN: %PYTHON %s | FileCheck %s
3 import gc
4 import io
5 import itertools
6 from mlir.ir import *
7 from mlir.dialects.builtin import ModuleOp
8 from mlir.dialects import arith
9 from mlir.dialects._ods_common import _cext
def run(f):
    """Execute the test function ``f`` immediately and hand it back unchanged.

    Meant to be used as a decorator: a banner with the function's name is
    printed, the function is called with no arguments, a garbage collection
    is forced, and ``Context._get_live_count()`` must then report zero.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Force a collection before counting the contexts that are still alive.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
def expect_index_error(callback):
    """Call ``callback`` and require that it raises ``IndexError``.

    An ``IndexError`` is swallowed. A normal return is reported as a
    ``RuntimeError``; any other exception propagates unchanged.
    """
    try:
        callback()
    except IndexError:
        return
    raise RuntimeError("Expected IndexError")
28 # Verify iterator based traversal of the op/region/block hierarchy.
29 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
@run
def testTraverseOpRegionBlockIterators():
    """Traverse op -> region -> block -> op using the iterator protocol."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        context,
    )
    root_op = module.operation
    assert root_op.context is context
    # Get the block using iterators off of the named collections.
    named_regions = list(root_op.regions)
    named_blocks = list(named_regions[0].blocks)
    # CHECK: MODULE REGIONS=1 BLOCKS=1
    print(f"MODULE REGIONS={len(named_regions)} BLOCKS={len(named_blocks)}")

    # Should verify.
    # CHECK: .verify = True
    print(f".verify = {module.operation.verify()}")

    # Iterating a region directly must yield the same blocks as `.blocks`.
    blocks_via_region = list(named_regions[0])
    assert blocks_via_region == named_blocks

    # Iterating a block directly must yield the same ops as `.operations`.
    first_block = named_blocks[0]
    ops_via_collection = list(first_block.operations)
    ops_via_block = list(first_block)
    assert ops_via_block == ops_via_collection

    def walk_operations(indent, op):
        # Recursively print every region, block and operation nested in `op`.
        for i, region in enumerate(op.regions):
            print(f"{indent}REGION {i}:")
            for j, block in enumerate(region):
                print(f"{indent} BLOCK {j}:")
                for k, child_op in enumerate(block):
                    print(f"{indent} OP {k}: {child_op}")
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 1: func.return
    walk_operations("", root_op)

    # CHECK: Region iter: <mlir.{{.+}}.RegionIterator
    # CHECK: Block iter: <mlir.{{.+}}.BlockIterator
    # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
    print(" Region iter:", iter(root_op.regions))
    print(" Block iter:", iter(root_op.regions[0]))
    print("Operation iter:", iter(root_op.regions[0].blocks[0]))
92 # Verify index based traversal of the op/region/block hierarchy.
93 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIndices
@run
def testTraverseOpRegionBlockIndices():
    """Traverse op -> region -> block -> op using integer indexing only."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        context,
    )

    def walk_operations(indent, op):
        # The index loops are deliberate: this test exercises `len()` and
        # `[]` on the collections rather than their iterators.
        region_count = len(op.regions)
        for i in range(region_count):
            region = op.regions[i]
            print(f"{indent}REGION {i}:")
            block_count = len(region.blocks)
            for j in range(block_count):
                block = region.blocks[j]
                print(f"{indent} BLOCK {j}:")
                op_count = len(block.operations)
                for k in range(op_count):
                    child_op = block.operations[k]
                    print(f"{indent} OP {k}: {child_op}")
                    print(
                        f"{indent} OP {k}: parent {child_op.operation.parent.name}"
                    )
                    walk_operations(indent + " ", child_op)

    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: func
    # CHECK: OP 0: parent builtin.module
    # CHECK: REGION 0:
    # CHECK: BLOCK 0:
    # CHECK: OP 0: %0 = "custom.addi"
    # CHECK: OP 0: parent func.func
    # CHECK: OP 1: func.return
    # CHECK: OP 1: parent func.func
    walk_operations("", module.operation)
136 # CHECK-LABEL: TEST: testBlockAndRegionOwners
@run
def testBlockAndRegionOwners():
    """Regions and blocks must report the operation that contains them."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """,
        context,
    )

    # The module's own region and entry block are owned by the module op.
    module_region = module.operation.regions[0]
    assert module_region.owner == module.operation
    assert module_region.blocks[0].owner == module.operation

    # The same holds one level down for the function.
    func = module.body.operations[0]
    func_region = func.operation.regions[0]
    assert func_region.owner == func
    assert func_region.blocks[0].owner == func
160 # CHECK-LABEL: TEST: testBlockArgumentList
@run
def testBlockArgumentList():
    """Iterate, retype, slice and concatenate a block's argument list."""
    with Context() as ctx:
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32, %arg1: f64, %arg2: index) {
        return
      }
    """,
            ctx,
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        block_args = entry_block.arguments
        assert len(block_args) == 3
        # CHECK: Argument 0, type i32
        # CHECK: Argument 1, type f64
        # CHECK: Argument 2, type index
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")
            # Retype argument N to a signless integer of width 8 * (N + 1).
            bit_width = 8 * (arg.arg_number + 1)
            arg.set_type(IntegerType.get_signless(bit_width))

        # CHECK: Argument 0, type i8
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that slicing works for block argument lists.
        # CHECK: Argument 1, type i16
        # CHECK: Argument 2, type i24
        for arg in entry_block.arguments[1:]:
            print(f"Argument {arg.arg_number}, type {arg.type}")

        # Check that we can concatenate slices of argument lists.
        # CHECK: Length: 4
        joined = entry_block.arguments[:2] + entry_block.arguments[1:]
        print("Length: ", len(joined))

        # CHECK: Type: i8
        # CHECK: Type: i16
        # CHECK: Type: i24
        for t in entry_block.arguments.types:
            print("Type: ", t)

        # Check that slicing and type access compose.
        # CHECK: Sliced type: i16
        # CHECK: Sliced type: i24
        for t in entry_block.arguments[1:].types:
            print("Sliced type: ", t)

        # Check that slice addition works as expected.
        # CHECK: Argument 2, type i24
        # CHECK: Argument 0, type i8
        restructured = entry_block.arguments[-1:] + entry_block.arguments[:1]
        for arg in restructured:
            print(f"Argument {arg.arg_number}, type {arg.type}")
219 # CHECK-LABEL: TEST: testOperationOperands
@run
def testOperationOperands():
    """Enumerate the operands of an operation and print their types."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32) {
        %0 = "test.producer"() : () -> i64
        "test.consumer"(%arg0, %0) : (i32, i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        body_block = func.regions[0].blocks[0]
        # Operation 1 of the entry block is the "test.consumer" op.
        consumer = body_block.operations[1]
        assert len(consumer.operands) == 2
        # CHECK: Operand 0, type i32
        # CHECK: Operand 1, type i64
        for i, operand in enumerate(consumer.operands):
            print(f"Operand {i}, type {operand.type}")
242 # CHECK-LABEL: TEST: testOperationOperandsSlice
@run
def testOperationOperandsSlice():
    """Slice an operand list with start/stop/step, including chained slices."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        %3 = "test.producer3"() : () -> i64
        %4 = "test.producer4"() : () -> i64
        "test.consumer"(%0, %1, %2, %3, %4) : (i64, i64, i64, i64, i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        body_block = func.regions[0].blocks[0]
        # Operation 5 is the "test.consumer" op, after the five producers.
        consumer = body_block.operations[5]
        assert len(consumer.operands) == 5
        # Reversing twice must give back the original sequence.
        double_reversed = consumer.operands[::-1][::-1]
        for left, right in zip(consumer.operands, double_reversed):
            assert left == right

        # CHECK: test.producer0
        # CHECK: test.producer1
        # CHECK: test.producer2
        # CHECK: test.producer3
        # CHECK: test.producer4
        for operand in consumer.operands[:]:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer1
        for operand in consumer.operands[0:2]:
            print(operand)

        # CHECK: test.producer3
        # CHECK: test.producer4
        for operand in consumer.operands[3:]:
            print(operand)

        # CHECK: test.producer0
        # CHECK: test.producer2
        # CHECK: test.producer4
        for operand in consumer.operands[::2]:
            print(operand)

        # CHECK: test.producer2
        for operand in consumer.operands[::2][1::2]:
            print(operand)
300 # CHECK-LABEL: TEST: testOperationOperandsSet
@run
def testOperationOperandsSet():
    """Replace an operand through a positive and a negative index.

    The consumer starts out using the result of producer0; the operand is
    then rebound to producer1 via index 0 and to producer2 via index -1.
    """
    with Context() as ctx, Location.unknown(ctx):
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        %0 = "test.producer0"() : () -> i64
        %1 = "test.producer1"() : () -> i64
        %2 = "test.producer2"() : () -> i64
        "test.consumer"(%0) : (i64) -> ()
        return
      }"""
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        producer1 = entry_block.operations[1]
        producer2 = entry_block.operations[2]
        consumer = entry_block.operations[3]
        assert len(consumer.operands) == 1
        # Still read the operand's type, but do not bind it to a name that
        # shadows the builtin `type`; the value itself was never used.
        _ = consumer.operands[0].type

        # CHECK: test.producer1
        consumer.operands[0] = producer1.result
        print(consumer.operands[0])

        # CHECK: test.producer2
        consumer.operands[-1] = producer2.result
        print(consumer.operands[0])
332 # CHECK-LABEL: TEST: testDetachedOperation
@run
def testDetachedOperation():
    """Create an operation that is not attached to any block and print it."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        si32 = IntegerType.get_signed(32)
        op_attributes = {
            "foo": StringAttr.get("foo_value"),
            "bar": StringAttr.get("bar_value"),
        }
        detached = Operation.create(
            "custom.op1",
            results=[si32, si32],
            regions=1,
            attributes=op_attributes,
        )
        # CHECK: %0:2 = "custom.op1"() ({
        # CHECK: }) {bar = "bar_value", foo = "foo_value"} : () -> (si32, si32)
        print(detached)

    # TODO: Check successors once enough infra exists to do it properly.
355 # CHECK-LABEL: TEST: testOperationInsertionPoint
@run
def testOperationInsertionPoint():
    """Insert detached ops at the start of a block; re-inserting must fail."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
      return %1 : i32
    }
  """,
        context,
    )

    # Create test op.
    with Location.unknown(context):
        first_op = Operation.create("custom.op1")
        second_op = Operation.create("custom.op2")

        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        ip = InsertionPoint.at_block_begin(entry_block)
        ip.insert(first_op)
        ip.insert(second_op)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.op2"()
        # CHECK: %0 = "custom.addi"
        print(module)

    # Trying to add a previously added op should raise.
    try:
        ip.insert(first_op)
    except ValueError:
        pass
    else:
        assert False, "expected insert of attached op to raise"
395 # CHECK-LABEL: TEST: testOperationWithRegion
@run
def testOperationWithRegion():
    """Build an op with a populated region, then move it into a parsed module."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        si32 = IntegerType.get_signed(32)
        region_op = Operation.create("custom.op1", regions=1)
        # Append a block taking two si32 arguments to the op's only region.
        block = region_op.regions[0].blocks.append(si32, si32)
        # CHECK: "custom.op1"() ({
        # CHECK: ^bb0(%arg0: si32, %arg1: si32):
        # CHECK: "custom.terminator"() : () -> ()
        # CHECK: }) : () -> ()
        terminator = Operation.create("custom.terminator")
        InsertionPoint(block).insert(terminator)
        print(region_op)

        # Now add the whole operation to another op.
        # TODO: Verify lifetime hazard by nulling out the new owning module and
        # accessing op1.
        # TODO: Also verify accessing the terminator once both parents are nulled
        # out.
        module = Module.parse(
            r"""
      func.func @f1(%arg0: i32) -> i32 {
        %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
        return %1 : i32
      }
    """
        )
        func = module.body.operations[0]
        entry_block = func.regions[0].blocks[0]
        InsertionPoint.at_block_begin(entry_block).insert(region_op)
        # CHECK: func @f1
        # CHECK: "custom.op1"()
        # CHECK: "custom.terminator"
        # CHECK: %0 = "custom.addi"
        print(module)
437 # CHECK-LABEL: TEST: testOperationResultList
@run
def testOperationResultList():
    """Inspect an op's result list: length, iteration, types, bounds, owner."""
    context = Context()
    module = Module.parse(
        r"""
    func.func @f1() {
      %0:3 = call @f2() : () -> (i32, f64, index)
      call @f3() : () -> ()
      return
    }
    func.func private @f2() -> (i32, f64, index)
    func.func private @f3() -> ()
  """,
        context,
    )
    caller = module.body.operations[0]
    caller_block = caller.regions[0].blocks[0]
    call = caller_block.operations[0]
    assert len(call.results) == 3
    # CHECK: Result 0, type i32
    # CHECK: Result 1, type f64
    # CHECK: Result 2, type index
    for res in call.results:
        print(f"Result {res.result_number}, type {res.type}")

    # CHECK: Result type i32
    # CHECK: Result type f64
    # CHECK: Result type index
    for t in call.results.types:
        print(f"Result type {t}")

    # Out of range
    expect_index_error(lambda: call.results[3])
    expect_index_error(lambda: call.results[-4])

    # The second call produces nothing; its empty list still knows its owner.
    no_results_call = caller.regions[0].blocks[0].operations[1]
    assert len(no_results_call.results) == 0
    assert no_results_call.results.owner == no_results_call
477 # CHECK-LABEL: TEST: testOperationResultListSlice
@run
def testOperationResultListSlice():
    """Slice a result list with positive, stepped and negative-step slices."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      func.func @f1() {
        "some.op"() : () -> (i1, i2, i3, i4, i5)
        return
      }
    """
        )
        func = module.body.operations[0]
        body_block = func.regions[0].blocks[0]
        producer = body_block.operations[0]

        assert len(producer.results) == 5
        # Reversing twice must give back the same results in the same order.
        double_reversed = producer.results[::-1][::-1]
        for left, right in zip(producer.results, double_reversed):
            assert left == right
            assert left.result_number == right.result_number

        # CHECK: Result 0, type i1
        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        # CHECK: Result 4, type i5
        for res in producer.results[:]:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 2, type i3
        # CHECK: Result 3, type i4
        for res in producer.results[1:4]:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 1, type i2
        # CHECK: Result 3, type i4
        for res in producer.results[1::2]:
            print(f"Result {res.result_number}, type {res.type}")

        # CHECK: Result 3, type i4
        # CHECK: Result 1, type i2
        for res in producer.results[-2:0:-2]:
            print(f"Result {res.result_number}, type {res.type}")
528 # CHECK-LABEL: TEST: testOperationAttributes
@run
def testOperationAttributes():
    """Look up op attributes by name, iterate them, and probe bad keys."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    "some.op"() { some.attribute = 1 : i8,
                  other.attribute = 3.0,
                  dependent = "text" } : () -> ()
  """,
        context,
    )
    op = module.body.operations[0]
    attrs = op.attributes
    assert len(attrs) == 3
    iattr = attrs["some.attribute"]
    fattr = attrs["other.attribute"]
    sattr = attrs["dependent"]
    # CHECK: Attribute type i8, value 1
    print(f"Attribute type {iattr.type}, value {iattr.value}")
    # CHECK: Attribute type f64, value 3.0
    print(f"Attribute type {fattr.type}, value {fattr.value}")
    # CHECK: Attribute value text
    print(f"Attribute value {sattr.value}")
    # CHECK: Attribute value b'text'
    print(f"Attribute value {sattr.value_bytes}")

    # We don't know in which order the attributes are stored.
    # CHECK-DAG: NamedAttribute(dependent="text")
    # CHECK-DAG: NamedAttribute(other.attribute=3.000000e+00 : f64)
    # CHECK-DAG: NamedAttribute(some.attribute=1 : i8)
    for named_attr in op.attributes:
        print(str(named_attr))

    # Check that exceptions are raised as expected.
    # A string key that is absent must raise KeyError.
    try:
        op.attributes["does_not_exist"]
    except KeyError:
        pass
    else:
        assert False, "expected KeyError on accessing a non-existent attribute"

    # An integer index past the end must raise IndexError.
    try:
        op.attributes[42]
    except IndexError:
        pass
    else:
        assert False, "expected IndexError on accessing an out-of-bounds attribute"
578 # CHECK-LABEL: TEST: testOperationPrint
@run
def testOperationPrint():
    """Print an operation to stdout, text/binary files, bytecode, and with options."""
    context = Context()
    # NOTE: the layout of this literal matters. The pretty debug-info
    # expectation further down pins `func.return` at line 4, column 7.
    module = Module.parse(
        r"""
    func.func @f1(%arg0: i32) -> i32 {
      %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
      return %arg0 : i32
    }
  """,
        context,
    )

    # Test print to stdout.
    # CHECK: return %arg0 : i32
    module.operation.print()

    # Test print to text file.
    text_file = io.StringIO()
    # CHECK: <class 'str'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=text_file)
    str_value = text_file.getvalue()
    print(str_value.__class__)
    print(text_file.getvalue())

    # Test roundtrip to bytecode.
    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream, desired_version=1)
    bytecode = bytecode_stream.getvalue()
    assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
    module_roundtrip = Module.parse(bytecode, context)
    roundtrip_file = io.StringIO()
    module_roundtrip.operation.print(file=roundtrip_file)
    roundtrip_value = roundtrip_file.getvalue()
    assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode"

    # Test print to binary file.
    binary_file = io.BytesIO()
    # CHECK: <class 'bytes'>
    # CHECK: return %arg0 : i32
    module.operation.print(file=binary_file, binary=True)
    bytes_value = binary_file.getvalue()
    print(bytes_value.__class__)
    print(bytes_value)

    # Test print local_scope.
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32> loc("nom")
    module.operation.print(enable_debug_info=True, use_local_scope=True)

    # Test printing using state.
    state = AsmState(module.operation)
    # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
    module.operation.print(state)

    # Test print with options.
    # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
    # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
    module.operation.print(
        large_elements_limit=2,
        enable_debug_info=True,
        pretty_debug_info=True,
        print_generic_op_form=True,
        use_local_scope=True,
    )

    # Test print with skip_regions option
    # CHECK: func.func @f1(%arg0: i32) -> i32
    # CHECK-NOT: func.return
    first_func = module.body.operations[0]
    first_func.print(
        skip_regions=True,
    )
653 # CHECK-LABEL: TEST: testKnownOpView
@run
def testKnownOpView():
    """Ops resolve to the registered OpView class, including late replacements."""
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      %1 = "custom.f32"() : () -> f32
      %2 = "custom.f32"() : () -> f32
      %3 = arith.addf %1, %2 : f32
      %4 = arith.constant 0 : i32
    """
        )
        print(module)

        # addf should map to a known OpView class in the arithmetic dialect.
        # We know the OpView for it defines an 'lhs' attribute.
        addf = module.body.operations[2]
        # CHECK: <mlir.dialects._arith_ops_gen.AddFOp object
        print(repr(addf))
        # CHECK: "custom.f32"()
        print(addf.lhs)

        # One of the custom ops should resolve to the default OpView.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))

        # Check again to make sure negative caching works.
        custom = module.body.operations[0]
        # CHECK: OpView object
        print(repr(custom))

        # constant should map to an extension OpView class in the arithmetic dialect.
        constant = module.body.operations[3]
        # CHECK: <mlir.dialects.arith.ConstantOp object
        print(repr(constant))
        # Checks that the arith extension is being registered successfully
        # (literal_value is a property on the extension class but not on the default OpView).
        # CHECK: literal value 0
        print("literal value", constant.literal_value)

        # Checks that "late" registration/replacement (i.e., post all module loading/initialization)
        # is working correctly.
        @_cext.register_operation(arith._Dialect, replace=True)
        class ConstantOp(arith.ConstantOp):
            def __init__(self, result, value, *, loc=None, ip=None):
                # Wrap plain Python numbers in an attribute of type `result`;
                # anything else is forwarded to the base class unchanged.
                if isinstance(value, int):
                    value = IntegerAttr.get(result, value)
                elif isinstance(value, float):
                    value = FloatAttr.get(result, value)
                super().__init__(value, loc=loc, ip=ip)

        # Fetching the same op again must now yield the replacement class.
        constant = module.body.operations[3]
        # CHECK: <__main__.testKnownOpView.<locals>.ConstantOp object
        print(repr(constant))
712 # CHECK-LABEL: TEST: testSingleResultProperty
@run
def testSingleResultProperty():
    """`.result` raises ValueError unless the op has exactly one result."""
    with Context(), Location.unknown():
        Context.current.allow_unregistered_dialects = True
        module = Module.parse(
            r"""
      "custom.no_result"() : () -> ()
      %0:2 = "custom.two_result"() : () -> (f32, f32)
      %1 = "custom.one_result"() : () -> f32
    """
        )
        print(module)

        # Zero results.
        try:
            module.body.operations[0].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.no_result which has 0 results
            print(e)
        else:
            assert False, "Expected exception"

        # Two results.
        try:
            module.body.operations[1].result
        except ValueError as e:
            # CHECK: Cannot call .result on operation custom.two_result which has 2 results
            print(e)
        else:
            assert False, "Expected exception"

        # CHECK: %1 = "custom.one_result"() : () -> f32
        single_result_op = module.body.operations[2]
        print(single_result_op)
def create_invalid_operation():
    """Return a `builtin.module` op created with two regions.

    The op is intentionally invalid so that callers can verify that printing
    falls back to the generic printer for safety. An empty block is appended
    to the first region only.
    """
    invalid_module = Operation.create("builtin.module", regions=2)
    first_region = invalid_module.regions[0]
    first_region.blocks.append()
    return invalid_module
754 # CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
@run
def testInvalidOperationStrSoftFails():
    """Printing an invalid op must not crash; verifying it reports an MLIRError."""
    context = Context()
    with Location.unknown(context):
        broken_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(broken_op)
        try:
            broken_op.verify()
        except MLIRError as e:
            # CHECK: Exception: <
            # CHECK: Verification failed:
            # CHECK: error: unknown: 'builtin.module' op requires one region
            # CHECK: note: unknown: see current operation:
            # CHECK: "builtin.module"() ({
            # CHECK: ^bb0:
            # CHECK: }, {
            # CHECK: }) : () -> ()
            # CHECK: >
            print(f"Exception: <{e}>")
779 # CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@run
def testInvalidModuleStrSoftFails():
    """Printing a module that contains an invalid op must not crash."""
    context = Context()
    with Location.unknown(context):
        outer_module = Module.create()
        # Create the invalid op inside the module body.
        with InsertionPoint(outer_module.body):
            broken_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: "builtin.module"() ({
        # CHECK: }) : () -> ()
        print(outer_module)
793 # CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
@run
def testInvalidOperationGetAsmBinarySoftFails():
    """`get_asm(binary=True)` on an invalid op yields generic-form bytes."""
    context = Context()
    with Location.unknown(context):
        broken_op = create_invalid_operation()
        # Verify that we fallback to the generic printer for safety.
        # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
        asm_bytes = broken_op.get_asm(binary=True)
        print(asm_bytes)
804 # CHECK-LABEL: TEST: testCreateWithInvalidAttributes
@run
def testCreateWithInvalidAttributes():
    """`Operation.create` reports bad attribute keys and values with a message."""
    ctx = Context()

    def create_and_report(attributes):
        # Try to build a module op with `attributes`; print any error raised.
        try:
            Operation.create("builtin.module", attributes=attributes)
        except Exception as e:
            print(e)

    with Location.unknown(ctx):
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        create_and_report({None: StringAttr.get("name")})
        # CHECK: Invalid attribute key (not a string) when attempting to create the operation "builtin.module"
        create_and_report({42: StringAttr.get("name")})
        # CHECK: Invalid attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        create_and_report({"some_key": ctx})
        # CHECK: Found an invalid (`None`?) attribute value for the key "some_key" when attempting to create the operation "builtin.module"
        create_and_report({"some_key": None})
833 # CHECK-LABEL: TEST: testOperationName
@run
def testOperationName():
    """Print the name of every top-level operation in a module."""
    context = Context()
    context.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    %0 = "custom.op1"() : () -> f32
    %1 = "custom.op2"() : () -> i32
    %2 = "custom.op1"() : () -> f32
  """,
        context,
    )

    # CHECK: custom.op1
    # CHECK: custom.op2
    # CHECK: custom.op1
    for top_level_op in module.body.operations:
        print(top_level_op.operation.name)
854 # CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """An operation round-trips through its C-API capsule to the same object."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        original = Operation.create("custom.op1").operation
        capsule = original._CAPIPtr
        assert '"mlir.ir.Operation._CAPIPtr"' in repr(capsule)
        # Re-creating from the capsule must return the identical Python object.
        restored = Operation._CAPICreate(capsule)
        assert restored is original
867 # CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
    """Erasing an op removes it from its module; new ops can still be created."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.create()
        with InsertionPoint(module.body):
            doomed = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(module)

            doomed.operation.erase()

            # CHECK-NOT: "custom.op1"
            print(module)

            # Ensure we can create another operation
            Operation.create("custom.op2")
889 # CHECK-LABEL: TEST: testOperationClone
@run
def testOperationClone():
    """Clone an op, erase the original, and print the module before and after."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        module = Module.create()
        with InsertionPoint(module.body):
            original = Operation.create("custom.op1")

            # CHECK: "custom.op1"
            print(module)

            # Clone first, then erase the op the clone was made from.
            clone = original.operation.clone()
            original.operation.erase()

            # CHECK: "custom.op1"
            print(module)
909 # CHECK-LABEL: TEST: testOperationLoc
@run
def testOperationLoc():
    """An op created with an explicit location reports that same location."""
    context = Context()
    context.allow_unregistered_dialects = True
    with context:
        named_loc = Location.name("loc")
        op = Operation.create("custom.op", loc=named_loc)
        # Both the op and its `.operation` accessor expose the location.
        assert op.location == named_loc
        assert op.operation.location == named_loc
921 # CHECK-LABEL: TEST: testModuleMerge
@run
def testModuleMerge():
    """Move functions from one module into another with move_before/move_after."""
    with Context():
        target = Module.parse("func.func private @foo()")
        donor = Module.parse(
            """
      func.func private @bar()
      func.func private @qux()
    """
        )
        foo = target.body.operations[0]
        bar = donor.body.operations[0]
        qux = donor.body.operations[1]
        # Place @bar ahead of @foo and @qux behind it.
        bar.move_before(foo)
        qux.move_after(foo)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo
        # CHECK: func private @qux
        print(target)

        # CHECK: module {
        # CHECK-NEXT: }
        print(donor)
949 # CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock
@run
def testAppendMoveFromAnotherBlock():
    """Appending an op that lives in another block moves it across."""
    with Context():
        source = Module.parse("func.func private @foo()")
        destination = Module.parse("func.func private @bar()")
        moved_func = source.body.operations[0]
        destination.body.append(moved_func)

        # CHECK: module
        # CHECK: func private @bar
        # CHECK: func private @foo

        print(destination)
        # CHECK: module {
        # CHECK-NEXT: }
        print(source)
968 # CHECK-LABEL: TEST: testDetachFromParent
@run
def testDetachFromParent():
    """Detaching works once; a second detach raises a "has no parent" ValueError."""
    with Context():
        module = Module.parse("func.func private @foo()")
        detached_func = module.body.operations[0].detach_from_parent()

        try:
            detached_func.detach_from_parent()
        except ValueError as e:
            # Only the "has no parent" error is expected; re-raise others.
            message = str(e)
            if "has no parent" not in message:
                raise
        else:
            assert False, "expected ValueError when detaching a detached operation"

        print(module)
        # CHECK-NOT: func private @foo
987 # CHECK-LABEL: TEST: testOperationHash
@run
def testOperationHash():
    """An op and its `.operation` accessor must hash to the same value."""
    context = Context()
    context.allow_unregistered_dialects = True
    with context, Location.unknown():
        created = Operation.create("custom.op1")
        view_hash = hash(created)
        operation_hash = hash(created.operation)
        assert view_hash == operation_hash
997 # CHECK-LABEL: TEST: testOperationParse
@run
def testOperationParse():
    """Parse single ops generically, via a specific class, and with a source name."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True

        # Generic operation parsing.
        parsed_module = Operation.parse("module {}")
        parsed_generic = Operation.parse('"test.foo"() : () -> ()')
        assert isinstance(parsed_module, ModuleOp)
        assert type(parsed_generic) is OpView

        # Parsing specific operation.
        parsed_module = ModuleOp.parse("module {}")
        assert isinstance(parsed_module, ModuleOp)
        # Parsing a different op through ModuleOp must be rejected.
        try:
            ModuleOp.parse('"test.foo"() : () -> ()')
        except MLIRError as e:
            # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
            print(f"error: {e}")
        else:
            assert False, "expected error"

        o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string")
        # CHECK: op_with_source_name: "test.foo"() : () -> () loc("my-source-string":1:1)
        print(
            f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}"
        )
1027 # CHECK-LABEL: TEST: testOpWalk
@run
def testOpWalk():
    """Exercise `Operation.walk` in both orders and with every callback outcome.

    Each callback prints the visited op's name, so the printed sequence shows
    the traversal order and where the walk stopped. The cases covered are
    advance (post- and pre-order), interrupt, skip, and a callback that raises.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    builtin.module {
      func.func @f() {
        func.return
      }
    }
  """,
        ctx,
    )

    def advance(op):
        print(op.name)
        return WalkResult.ADVANCE

    # Test post-order walk (default).
    # CHECK-NEXT: Post-order
    # CHECK-NEXT: func.return
    # CHECK-NEXT: func.func
    # CHECK-NEXT: builtin.module
    print("Post-order")
    module.operation.walk(advance)

    # Test pre-order walk.
    # CHECK-NEXT: Pre-order
    # CHECK-NEXT: builtin.module
    # CHECK-NEXT: func.func
    # CHECK-NEXT: func.return
    print("Pre-order")
    module.operation.walk(advance, WalkOrder.PRE_ORDER)

    # Test interrupt.
    # CHECK-NEXT: Interrupt post-order
    # CHECK-NEXT: func.return
    print("Interrupt post-order")

    def interrupt(op):
        print(op.name)
        return WalkResult.INTERRUPT

    module.operation.walk(interrupt)

    # Test skip.
    # CHECK-NEXT: Skip pre-order
    # CHECK-NEXT: builtin.module
    print("Skip pre-order")

    def skip(op):
        print(op.name)
        return WalkResult.SKIP

    module.operation.walk(skip, WalkOrder.PRE_ORDER)

    # Test exception.
    # CHECK: Exception
    # CHECK-NEXT: func.return
    # CHECK-NEXT: Exception raised
    print("Exception")

    def fail(op):
        # Raising from the callback is expected to reach the caller of
        # `walk` as a RuntimeError (see the handler below).
        print(op.name)
        raise ValueError

    try:
        module.operation.walk(fail)
    except RuntimeError:
        print("Exception raised")