[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / python / ir / symbol_table.py
blob577721ab2111f553fadcdd828629feac7ca0ba32
1 # RUN: %PYTHON %s | FileCheck %s
3 import gc
4 import io
5 import itertools
6 from mlir.ir import *
def run(f):
    """Run test ``f`` immediately and verify it leaked no MLIR contexts.

    Meant to be used as a decorator. It prints a ``TEST: <name>`` header
    (which the FileCheck labels match against), calls ``f``, forces a garbage
    collection, asserts that ``Context._get_live_count()`` is zero, and
    returns ``f`` unchanged.
    """
    header = f.__name__
    print("\nTEST:", header)
    f()
    # Collect first so objects kept alive only by reference cycles are
    # released before the live contexts are counted.
    gc.collect()
    leaked = Context._get_live_count()
    assert leaked == 0
    return f
# CHECK-LABEL: TEST: testSymbolTableInsert
@run
def testSymbolTableInsert():
    """Exercise SymbolTable lookup, deletion, erasure and insertion.

    Covers ``in`` / ``[]`` lookup, ``del``, the KeyError raised for a symbol
    that is gone, invalidation of the Python handle to an erased op, renaming
    on a name collision during ``insert``, and the ValueError raised when
    inserting an op that has no symbol name.
    """
    with Context() as ctx:
        # m2 contains "foo.bar", an op from an unregistered dialect.
        ctx.allow_unregistered_dialects = True
        m1 = Module.parse(
            """
      func.func private @foo()
      func.func private @bar()"""
        )
        m2 = Module.parse(
            """
      func.func private @qux()
      func.func private @foo()
      "foo.bar"() : () -> ()"""
        )

        symbol_table = SymbolTable(m1.operation)

        # CHECK: func private @foo
        # CHECK: func private @bar
        assert "foo" in symbol_table
        print(symbol_table["foo"])
        assert "bar" in symbol_table
        # Keep a handle to @bar so its invalidation can be checked below.
        bar = symbol_table["bar"]
        print(symbol_table["bar"])

        assert "qux" not in symbol_table

        del symbol_table["bar"]
        # After the delete, looking up / erasing "bar" is expected to raise
        # KeyError.
        try:
            symbol_table.erase(symbol_table["bar"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError"

        # CHECK: module
        # CHECK: func private @foo()
        print(m1)
        assert "bar" not in symbol_table

        # The Python handle to the deleted @bar must no longer be usable.
        try:
            print(bar)
        except RuntimeError as e:
            if "the operation has been invalidated" not in str(e):
                raise
        else:
            assert False, "expected RuntimeError due to invalidated operation"

        # Append m2's first op (@qux) to m1 and register it in the table.
        qux = m2.body.operations[0]
        m1.body.append(qux)
        symbol_table.insert(qux)
        assert "qux" in symbol_table

        # Check that insertion actually renames this symbol in the symbol table.
        foo2 = m2.body.operations[0]
        m1.body.append(foo2)
        updated_name = symbol_table.insert(foo2)
        assert foo2.name.value != "foo"
        assert foo2.name == updated_name
        assert isinstance(updated_name, StringAttr)

        # CHECK: module
        # CHECK: func private @foo()
        # CHECK: func private @qux()
        # CHECK: func private @foo{{.*}}
        print(m1)

        # The op left at the front of m2 ("foo.bar") has no symbol name, so
        # insert must reject it.
        try:
            symbol_table.insert(m2.body.operations[0])
        except ValueError as e:
            if "Expected operation to have a symbol name" not in str(e):
                raise
        else:
            assert False, "expected ValueError when adding a non-symbol"
# CHECK-LABEL: testSymbolTableRAUW
@run
def testSymbolTableRAUW():
    """Rename @bar and rewrite references to it, first locally then module-wide.

    The first rename rewrites uses only inside @foo; the second rewrites uses
    across the whole module operation. After each step the module and both
    symbol names are printed for FileCheck.
    """
    with Context():
        module = Module.parse(
            """
      func.func private @foo() {
        call @bar() : () -> ()
        return
      }
      func.func private @bar()
      """
        )
        top_level_ops = list(module.operation.regions[0].blocks[0].operations)
        foo_op, bar_op = top_level_ops[0:2]

        def report_symbols():
            # Print the current symbol name of each function.
            print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo_op))}")
            print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar_op))}")

        # Do renaming just within `foo`.
        SymbolTable.set_symbol_name(bar_op, "bam")
        SymbolTable.replace_all_symbol_uses("bar", "bam", foo_op)
        # CHECK: call @bam()
        # CHECK: func private @bam
        print(module)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("bam")
        report_symbols()

        # Do renaming within the module.
        SymbolTable.set_symbol_name(bar_op, "baz")
        SymbolTable.replace_all_symbol_uses("bam", "baz", module.operation)
        # CHECK: call @baz()
        # CHECK: func private @baz
        print(module)
        # CHECK: Foo symbol: StringAttr("foo")
        # CHECK: Bar symbol: StringAttr("baz")
        report_symbols()
# CHECK-LABEL: testSymbolTableVisibility
@run
def testSymbolTableVisibility():
    """Read a function's symbol visibility, then switch it to public."""
    with Context():
        module = Module.parse(
            """
      func.func private @foo() {
        return
      }
      """
        )
        func_op = module.operation.regions[0].blocks[0].operations[0]
        current_visibility = SymbolTable.get_visibility(func_op)
        # CHECK: Existing visibility: StringAttr("private")
        print(f"Existing visibility: {repr(current_visibility)}")
        SymbolTable.set_visibility(func_op, "public")
        # CHECK: func public @foo
        print(module)
# CHECK: testWalkSymbolTables
@run
def testWalkSymbolTables():
    """Walk nested symbol tables and check that callback errors surface.

    The first walk prints every symbol-table op it visits (expected order per
    the CHECK lines: @inner, then @outer). The second walk uses a callback
    that raises, and expects the binding to report it as a RuntimeError.
    """
    with Context():
        m = Module.parse(
            """
      module @outer {
        module @inner{
        }
      }
      """
        )

        def callback(symbol_table_op, uses_visible):
            print(f"SYMBOL TABLE: {uses_visible}: {symbol_table_op}")

        # CHECK: SYMBOL TABLE: True: module @inner
        # CHECK: SYMBOL TABLE: True: module @outer
        SymbolTable.walk_symbol_tables(m.operation, True, callback)

        # Make sure exceptions in the callback are handled.
        def error_callback(symbol_table_op, uses_visible):
            # Raise explicitly instead of `assert False`: assert statements are
            # stripped under `python -O`, and then nothing would be raised.
            raise AssertionError("Raised from python")

        try:
            SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
        except RuntimeError as e:
            # CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
            print(f"GOT EXCEPTION: {e}")
        else:
            assert False, "expected RuntimeError from the raising callback"