[clang-tidy][NFC]remove deps of clang in clang tidy test (#116588)
[llvm-project.git] / mlir / test / python / ir / value.py
blob9a8146bd9350bc7f0c33b429e7f1d31fe39d7724
1 # RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
3 import gc
4 from mlir.ir import *
5 from mlir.dialects import func
def run(f):
    """Run the test function ``f`` right away and hand it back (decorator).

    A ``TEST: <name>`` banner is printed before the call so FileCheck can
    anchor on it. After the call, garbage is collected and the number of live
    ``Context`` objects must be zero.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-trip a Value through its C-API capsule and expect an equal Value."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        producer = Operation.create("custom.op1", results=[int_type])
        original = producer.result
        capsule = original._CAPIPtr
        # The capsule's repr has to mention the Value capsule name.
        assert '"mlir.ir.Value._CAPIPtr"' in repr(capsule)
        restored = Value._CAPICreate(capsule)
        assert restored == original
# CHECK-LABEL: TEST: testOpResultOwner
@run
def testOpResultOwner():
    """The ``owner`` of an operation result must compare equal to that operation."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        result_type = IntegerType.get_signless(32)
        op = Operation.create("custom.op1", results=[result_type])
        produced = op.result
        assert produced.owner == op
# CHECK-LABEL: TEST: testBlockArgOwner
@run
def testBlockArgOwner():
    """The ``owner`` of a block argument must compare equal to its block.

    Parses a one-argument ``func.func`` and inspects the first argument of the
    first block in the function's first region.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      return
    }""",
        ctx,
    )
    # Named `func_op` so the local does not shadow the imported `func` dialect
    # module.
    func_op = module.body.operations[0]
    block = func_op.regions[0].blocks[0]
    assert block.arguments[0].owner == block
# CHECK-LABEL: TEST: testValueIsInstance
@run
def testValueIsInstance():
    """Check ``BlockArgument.isinstance`` / ``OpResult.isinstance`` on both value kinds.

    A block argument must be recognised only by ``BlockArgument``, and an
    operation result only by ``OpResult``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      %0 = "some_dialect.some_op"() : () -> f64
      return
    }""",
        ctx,
    )
    # Named `func_op` so the local does not shadow the imported `func` dialect
    # module.
    func_op = module.body.operations[0]
    block = func_op.regions[0].blocks[0]

    arg = block.arguments[0]
    assert BlockArgument.isinstance(arg)
    assert not OpResult.isinstance(arg)

    op = block.operations[0]
    assert not BlockArgument.isinstance(op.results[0])
    assert OpResult.isinstance(op.results[0])
# CHECK-LABEL: TEST: testValueHash
@run
def testValueHash():
    """Two Python wrappers of the same SSA value must hash identically.

    The block argument is compared with the operand that consumes it, and the
    op result with the operand of the ``return`` that consumes it.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) -> f32 {
      %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
      return %0 : f32
    }""",
        ctx,
    )

    # Named `func_op` so the local does not shadow the imported `func` dialect
    # module. The destructuring also requires exactly one top-level op.
    [func_op] = module.body.operations
    block = func_op.entry_block
    # Destructuring requires exactly two ops in the block: the custom op and
    # the return.
    op, ret = block.operations
    assert hash(block.arguments[0]) == hash(op.operands[0])
    assert hash(op.result) == hash(ret.operands[0])
# CHECK-LABEL: TEST: testValueUses
@run
def testValueUses():
    """Iterate ``Value.uses`` and print the owner and operand index of each use."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            producer = Operation.create("custom.op1", results=[int_type])
            value = producer.results[0]
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])

    users = [op1, op2]
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in users
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run
def testValueReplaceAllUsesWith():
    """After ``replace_all_uses_with`` the old value has no uses; the new one has them."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            old_producer = Operation.create("custom.op1", results=[int_type])
            value = old_producer.results[0]
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            new_producer = Operation.create("custom.op3", results=[int_type])
            value2 = new_producer.results[0]
            value.replace_all_uses_with(value2)

    # Nothing may reference the original value any more.
    assert not list(value.uses)

    users = [op1, op2]
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in users
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
@run
def testValueReplaceAllUsesWithExcept():
    """``replace_all_uses_except`` with a single excepted operation.

    ``op1`` is passed as the exception, so it must keep using ``value`` while
    the use in ``op2`` is redirected to ``value2``.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            op1 = Operation.create("custom.op1", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            value2 = Operation.create("custom.op3", results=[i32]).results[0]
            value.replace_all_uses_except(value2, op1)

    # Exactly one use stays on the old value (op1) and exactly one moves to the
    # new value (op2); pin both counts explicitly.
    assert len(list(value.uses)) == 1
    assert len(list(value2.uses)) == 1

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in [op2]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in [op1]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
@run
def testValueReplaceAllUsesWithMultipleExceptions():
    """``replace_all_uses_except`` with a list of excepted operations."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            old_producer = Operation.create("custom.op1", results=[int_type])
            value = old_producer.results[0]
            op1 = Operation.create("custom.op1", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            op3 = Operation.create("custom.op3", operands=[value])
            new_producer = Operation.create("custom.op4", results=[int_type])
            value2 = new_producer.results[0]

            # `op1` and `op2` are exempt; every other use of `value` is
            # redirected to `value2`.
            exempt = [op1, op2]
            value.replace_all_uses_except(value2, exempt)

    # `op3` is now the sole user of `value2`; the two exempt ops still use
    # `value`.
    old_use_count = len(list(value.uses))
    new_use_count = len(list(value2.uses))
    assert old_use_count == 2
    assert new_use_count == 1

    # CHECK: Use owner: "custom.op3"
    # CHECK: Use operand_number: 0
    for use in value2.uses:
        assert use.owner in [op3]
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")

    # CHECK: Use owner: "custom.op2"
    # CHECK: Use operand_number: 0
    # CHECK: Use owner: "custom.op1"
    # CHECK: Use operand_number: 0
    for use in value.uses:
        assert use.owner in exempt
        print(f"Use owner: {use.owner}")
        print(f"Use operand_number: {use.operand_number}")
# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
    """Exercise ``Value.get_name`` with default arguments, an ``AsmState``, and ``use_local_scope``.

    Two values are created at module level and two inside a ``func.func``
    body. The names returned by ``get_name`` are matched against the names
    FileCheck captured when the values themselves were printed. The order of
    the ``print`` calls is significant for those captures.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(value)

            value2 = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
            print(value2)

            # Function with two i32 inputs and no results; its entry block is
            # created explicitly with matching argument types.
            topFn = func.FuncOp("test", ([i32, i32], []))
            entry_block = Block.create_at_start(topFn.operation.regions[0], [i32, i32])

            with InsertionPoint(entry_block):
                value3 = Operation.create("custom.op3", results=[i32]).results[0]
                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
                print(value3)
                value4 = Operation.create("custom.op4", results=[i32]).results[0]
                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
                print(value4)
                func.ReturnOp([])

        # Default get_name(): each name must equal the one captured when the
        # corresponding value was printed above.
        # CHECK: %[[VAL1]]
        print(value.get_name())
        # CHECK: %[[VAL2]]
        print(value2.get_name())
        # CHECK: %[[VAL3]]
        print(value3.get_name())
        # CHECK: %[[VAL4]]
        print(value4.get_name())

        print("With AsmState")
        # CHECK-LABEL: With AsmState
        # The state is built on the function op with local scope; the two
        # values inside the function are expected to be named %0 and %1.
        state = AsmState(topFn.operation, use_local_scope=True)
        # CHECK: %0
        print(value3.get_name(state=state))
        # CHECK: %1
        print(value4.get_name(state=state))

        print("With use_local_scope")
        # CHECK-LABEL: With use_local_scope
        # Same expectation when local scope is requested directly on get_name.
        # CHECK: %0
        print(value3.get_name(use_local_scope=True))
        # CHECK: %1
        print(value4.get_name(use_local_scope=True))

        # Block-argument names are captured here and reused in the function
        # signature of the module dump below.
        # CHECK: %[[ARG0:.*]]
        print(entry_block.arguments[0].get_name())
        # CHECK: %[[ARG1:.*]]
        print(entry_block.arguments[1].get_name())

        # CHECK: module {
        # CHECK:   %[[VAL1]] = "custom.op1"() : () -> i32
        # CHECK:   %[[VAL2]] = "custom.op2"() : () -> i32
        # CHECK:   func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
        # CHECK:     %[[VAL3]] = "custom.op3"() : () -> i32
        # CHECK:     %[[VAL4]] = "custom.op4"() : () -> i32
        # CHECK:     return
        # CHECK:   }
        # CHECK: }
        print(module)

        # Once its owning op is detached from the parent, value2 is expected
        # to be named %0.
        value2.owner.detach_from_parent()
        # CHECK: %0
        print(value2.get_name())
# CHECK-LABEL: TEST: testValueSetType
@run
def testValueSetType():
    """``Value.set_type`` must change the type shown for the value and for its owner."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        narrow_type = IntegerType.get_signless(32)
        wide_type = IntegerType.get_signless(64)
        module = Module.create()
        with InsertionPoint(module.body):
            producer = Operation.create("custom.op1", results=[narrow_type])
            value = producer.results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(value)

            value.set_type(wide_type)
            # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
            print(value)

            # The owning operation is printed too and must show the new type.
            # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
            print(value.owner)
# CHECK-LABEL: TEST: testValueCasters
@run
def testValueCasters():
    """Exercise ``register_value_caster`` for values of integer type.

    Phase 1 registers ``cast_int``, which wraps values in one of three no-op
    subclasses depending on their kind, and checks the wrappers seen through
    results, result slices, unpacking, operands and function arguments.
    Phase 2 checks that registering a second caster for the same type id
    raises ``RuntimeError`` unless ``replace=True`` is passed.
    Phase 3 replaces the caster with one that returns its input unchanged.

    The order of the ``print`` calls is significant: each caster invocation
    prints a line that FileCheck matches in sequence.
    """

    # The three wrappers only rename the class shown by __str__, so the
    # printed output reveals which wrapper the caster chose.
    class NOPResult(OpResult):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPResult.__name__)

    class NOPValue(Value):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPValue.__name__)

    class NOPBlockArg(BlockArgument):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)

    @register_value_caster(IntegerType.static_typeid)
    def cast_int(v) -> Value:
        # Print the incoming class so the test output records every call.
        print("in caster", v.__class__.__name__)
        # Most specific kinds first; plain Value is the fallback.
        if isinstance(v, OpResult):
            return NOPResult(v)
        if isinstance(v, BlockArgument):
            return NOPBlockArg(v)
        if isinstance(v, Value):
            return NOPValue(v)

    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            values = Operation.create("custom.op1", results=[i32, i32]).results
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", values[0].result_number, values[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", values[1].result_number, values[1])

            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("results slice", values[:1][0].result_number, values[:1][0])

            value0, value1 = values
            # NOTE(review): the next two prints index `values` again rather
            # than printing value0/value1; the expected caster output below
            # depends on that, so it is left as is.
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", value0.result_number, values[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", value1.result_number, values[1])

            op1 = Operation.create("custom.op2", operands=[value0, value1])
            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
            print(op1)

            # CHECK: in caster Value
            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 0", op1.operands[0])
            # CHECK: in caster Value
            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 1", op1.operands[1])

            # CHECK: in caster BlockArgument
            # CHECK: in caster BlockArgument
            @func.FuncOp.from_py_func(i32, i32)
            def reduction(arg0, arg1):
                # CHECK: as func arg 0 NOPBlockArg
                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
                # CHECK: as func arg 1 NOPBlockArg
                print("as func arg", arg1.arg_number, arg1.__class__.__name__)

            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
            print(
                "args slice",
                reduction.func_op.arguments[:1][0].arg_number,
                reduction.func_op.arguments[:1][0],
            )

    # A second registration for the same type id without replace=True is
    # expected to raise; the error message is printed and matched below.
    try:

        @register_value_caster(IntegerType.static_typeid)
        def dont_cast_int_shouldnt_register(v):
            ...

    except RuntimeError as e:
        # CHECK: Value caster is already registered: {{.*}}cast_int
        print(e)

    # replace=True swaps in a caster that returns its input unchanged.
    @register_value_caster(IntegerType.static_typeid, replace=True)
    def dont_cast_int(v) -> OpResult:
        assert isinstance(v, OpResult)
        print("don't cast", v.result_number, v)
        return v

    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
            new_value = Operation.create("custom.op1", results=[i32]).result
            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
            print("result", new_value.result_number, new_value)

            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
            new_value = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
            print("result", new_value.result_number, new_value)