1 # RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
5 from mlir
.dialects
import func
9 print("\nTEST:", f
.__name
__)
12 assert Context
._get
_live
_count
() == 0
16 # CHECK-LABEL: TEST: testCapsuleConversions
18 def testCapsuleConversions():
20 ctx
.allow_unregistered_dialects
= True
21 with Location
.unknown(ctx
):
22 i32
= IntegerType
.get_signless(32)
23 value
= Operation
.create("custom.op1", results
=[i32
]).result
24 value_capsule
= value
._CAPIPtr
25 assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule
)
26 value2
= Value
._CAPICreate
(value_capsule
)
27 assert value2
== value
30 # CHECK-LABEL: TEST: testOpResultOwner
32 def testOpResultOwner():
34 ctx
.allow_unregistered_dialects
= True
35 with Location
.unknown(ctx
):
36 i32
= IntegerType
.get_signless(32)
37 op
= Operation
.create("custom.op1", results
=[i32
])
38 assert op
.result
.owner
== op
41 # CHECK-LABEL: TEST: testBlockArgOwner
43 def testBlockArgOwner():
45 ctx
.allow_unregistered_dialects
= True
46 module
= Module
.parse(
48 func.func @foo(%arg0: f32) {
53 func
= module
.body
.operations
[0]
54 block
= func
.regions
[0].blocks
[0]
55 assert block
.arguments
[0].owner
== block
58 # CHECK-LABEL: TEST: testValueIsInstance
60 def testValueIsInstance():
62 ctx
.allow_unregistered_dialects
= True
63 module
= Module
.parse(
65 func.func @foo(%arg0: f32) {
66 %0 = "some_dialect.some_op"() : () -> f64
71 func
= module
.body
.operations
[0]
72 assert BlockArgument
.isinstance(func
.regions
[0].blocks
[0].arguments
[0])
73 assert not OpResult
.isinstance(func
.regions
[0].blocks
[0].arguments
[0])
75 op
= func
.regions
[0].blocks
[0].operations
[0]
76 assert not BlockArgument
.isinstance(op
.results
[0])
77 assert OpResult
.isinstance(op
.results
[0])
80 # CHECK-LABEL: TEST: testValueHash
84 ctx
.allow_unregistered_dialects
= True
85 module
= Module
.parse(
87 func.func @foo(%arg0: f32) -> f32 {
88 %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
94 [func
] = module
.body
.operations
95 block
= func
.entry_block
96 op
, ret
= block
.operations
97 assert hash(block
.arguments
[0]) == hash(op
.operands
[0])
98 assert hash(op
.result
) == hash(ret
.operands
[0])
101 # CHECK-LABEL: TEST: testValueUses
105 ctx
.allow_unregistered_dialects
= True
106 with Location
.unknown(ctx
):
107 i32
= IntegerType
.get_signless(32)
108 module
= Module
.create()
109 with
InsertionPoint(module
.body
):
110 value
= Operation
.create("custom.op1", results
=[i32
]).results
[0]
111 op1
= Operation
.create("custom.op2", operands
=[value
])
112 op2
= Operation
.create("custom.op2", operands
=[value
])
114 # CHECK: Use owner: "custom.op2"
115 # CHECK: Use operand_number: 0
116 # CHECK: Use owner: "custom.op2"
117 # CHECK: Use operand_number: 0
118 for use
in value
.uses
:
119 assert use
.owner
in [op1
, op2
]
120 print(f
"Use owner: {use.owner}")
121 print(f
"Use operand_number: {use.operand_number}")
124 # CHECK-LABEL: TEST: testValueReplaceAllUsesWith
126 def testValueReplaceAllUsesWith():
128 ctx
.allow_unregistered_dialects
= True
129 with Location
.unknown(ctx
):
130 i32
= IntegerType
.get_signless(32)
131 module
= Module
.create()
132 with
InsertionPoint(module
.body
):
133 value
= Operation
.create("custom.op1", results
=[i32
]).results
[0]
134 op1
= Operation
.create("custom.op2", operands
=[value
])
135 op2
= Operation
.create("custom.op2", operands
=[value
])
136 value2
= Operation
.create("custom.op3", results
=[i32
]).results
[0]
137 value
.replace_all_uses_with(value2
)
139 assert len(list(value
.uses
)) == 0
141 # CHECK: Use owner: "custom.op2"
142 # CHECK: Use operand_number: 0
143 # CHECK: Use owner: "custom.op2"
144 # CHECK: Use operand_number: 0
145 for use
in value2
.uses
:
146 assert use
.owner
in [op1
, op2
]
147 print(f
"Use owner: {use.owner}")
148 print(f
"Use operand_number: {use.operand_number}")
151 # CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
153 def testValueReplaceAllUsesWithExcept():
155 ctx
.allow_unregistered_dialects
= True
156 with Location
.unknown(ctx
):
157 i32
= IntegerType
.get_signless(32)
158 module
= Module
.create()
159 with
InsertionPoint(module
.body
):
160 value
= Operation
.create("custom.op1", results
=[i32
]).results
[0]
161 op1
= Operation
.create("custom.op1", operands
=[value
])
162 op2
= Operation
.create("custom.op2", operands
=[value
])
163 value2
= Operation
.create("custom.op3", results
=[i32
]).results
[0]
164 value
.replace_all_uses_except(value2
, op1
)
166 assert len(list(value
.uses
)) == 1
168 # CHECK: Use owner: "custom.op2"
169 # CHECK: Use operand_number: 0
170 for use
in value2
.uses
:
171 assert use
.owner
in [op2
]
172 print(f
"Use owner: {use.owner}")
173 print(f
"Use operand_number: {use.operand_number}")
175 # CHECK: Use owner: "custom.op1"
176 # CHECK: Use operand_number: 0
177 for use
in value
.uses
:
178 assert use
.owner
in [op1
]
179 print(f
"Use owner: {use.owner}")
180 print(f
"Use operand_number: {use.operand_number}")
183 # CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
185 def testValueReplaceAllUsesWithMultipleExceptions():
187 ctx
.allow_unregistered_dialects
= True
188 with Location
.unknown(ctx
):
189 i32
= IntegerType
.get_signless(32)
190 module
= Module
.create()
191 with
InsertionPoint(module
.body
):
192 value
= Operation
.create("custom.op1", results
=[i32
]).results
[0]
193 op1
= Operation
.create("custom.op1", operands
=[value
])
194 op2
= Operation
.create("custom.op2", operands
=[value
])
195 op3
= Operation
.create("custom.op3", operands
=[value
])
196 value2
= Operation
.create("custom.op4", results
=[i32
]).results
[0]
198 # Replace all uses of `value` with `value2`, except for `op1` and `op2`.
199 value
.replace_all_uses_except(value2
, [op1
, op2
])
201 # After replacement, only `op3` should use `value2`, while `op1` and `op2` should still use `value`.
202 assert len(list(value
.uses
)) == 2
203 assert len(list(value2
.uses
)) == 1
205 # CHECK: Use owner: "custom.op3"
206 # CHECK: Use operand_number: 0
207 for use
in value2
.uses
:
208 assert use
.owner
in [op3
]
209 print(f
"Use owner: {use.owner}")
210 print(f
"Use operand_number: {use.operand_number}")
212 # CHECK: Use owner: "custom.op2"
213 # CHECK: Use operand_number: 0
214 # CHECK: Use owner: "custom.op1"
215 # CHECK: Use operand_number: 0
216 for use
in value
.uses
:
217 assert use
.owner
in [op1
, op2
]
218 print(f
"Use owner: {use.owner}")
219 print(f
"Use operand_number: {use.operand_number}")
222 # CHECK-LABEL: TEST: testValuePrintAsOperand
224 def testValuePrintAsOperand():
226 ctx
.allow_unregistered_dialects
= True
227 with Location
.unknown(ctx
):
228 i32
= IntegerType
.get_signless(32)
229 module
= Module
.create()
230 with
InsertionPoint(module
.body
):
231 value
= Operation
.create("custom.op1", results
=[i32
]).results
[0]
232 # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
235 value2
= Operation
.create("custom.op2", results
=[i32
]).results
[0]
236 # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
239 topFn
= func
.FuncOp("test", ([i32
, i32
], []))
240 entry_block
= Block
.create_at_start(topFn
.operation
.regions
[0], [i32
, i32
])
242 with
InsertionPoint(entry_block
):
243 value3
= Operation
.create("custom.op3", results
=[i32
]).results
[0]
244 # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
246 value4
= Operation
.create("custom.op4", results
=[i32
]).results
[0]
247 # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
252 print(value
.get_name())
254 print(value2
.get_name())
256 print(value3
.get_name())
258 print(value4
.get_name())
260 print("With AsmState")
261 # CHECK-LABEL: With AsmState
262 state
= AsmState(topFn
.operation
, use_local_scope
=True)
264 print(value3
.get_name(state
=state
))
266 print(value4
.get_name(state
=state
))
268 print("With use_local_scope")
269 # CHECK-LABEL: With use_local_scope
271 print(value3
.get_name(use_local_scope
=True))
273 print(value4
.get_name(use_local_scope
=True))
275 # CHECK: %[[ARG0:.*]]
276 print(entry_block
.arguments
[0].get_name())
277 # CHECK: %[[ARG1:.*]]
278 print(entry_block
.arguments
[1].get_name())
281 # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32
282 # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32
283 # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
284 # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32
285 # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32
291 value2
.owner
.detach_from_parent()
293 print(value2
.get_name())
296 # CHECK-LABEL: TEST: testValueSetType
298 def testValueSetType():
300 ctx
.allow_unregistered_dialects
= True
301 with Location
.unknown(ctx
):
302 i32
= IntegerType
.get_signless(32)
303 i64
= IntegerType
.get_signless(64)
304 module
= Module
.create()
305 with
InsertionPoint(module
.body
):
306 value
= Operation
.create("custom.op1", results
=[i32
]).results
[0]
307 # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
311 # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
314 # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
318 # CHECK-LABEL: TEST: testValueCasters
320 def testValueCasters():
321 class NOPResult(OpResult
):
322 def __init__(self
, v
):
326 return super().__str
__().replace(Value
.__name
__, NOPResult
.__name
__)
328 class NOPValue(Value
):
329 def __init__(self
, v
):
333 return super().__str
__().replace(Value
.__name
__, NOPValue
.__name
__)
335 class NOPBlockArg(BlockArgument
):
336 def __init__(self
, v
):
340 return super().__str
__().replace(Value
.__name
__, NOPBlockArg
.__name
__)
342 @register_value_caster(IntegerType
.static_typeid
)
343 def cast_int(v
) -> Value
:
344 print("in caster", v
.__class
__.__name
__)
345 if isinstance(v
, OpResult
):
347 if isinstance(v
, BlockArgument
):
348 return NOPBlockArg(v
)
349 elif isinstance(v
, Value
):
353 ctx
.allow_unregistered_dialects
= True
354 with Location
.unknown(ctx
):
355 i32
= IntegerType
.get_signless(32)
356 module
= Module
.create()
357 with
InsertionPoint(module
.body
):
358 values
= Operation
.create("custom.op1", results
=[i32
, i32
]).results
359 # CHECK: in caster OpResult
360 # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
361 print("result", values
[0].result_number
, values
[0])
362 # CHECK: in caster OpResult
363 # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
364 print("result", values
[1].result_number
, values
[1])
366 # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
367 print("results slice", values
[:1][0].result_number
, values
[:1][0])
369 value0
, value1
= values
370 # CHECK: in caster OpResult
371 # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
372 print("result", value0
.result_number
, values
[0])
373 # CHECK: in caster OpResult
374 # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
375 print("result", value1
.result_number
, values
[1])
377 op1
= Operation
.create("custom.op2", operands
=[value0
, value1
])
378 # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
381 # CHECK: in caster Value
382 # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
383 print("operand 0", op1
.operands
[0])
384 # CHECK: in caster Value
385 # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
386 print("operand 1", op1
.operands
[1])
388 # CHECK: in caster BlockArgument
389 # CHECK: in caster BlockArgument
390 @func.FuncOp
.from_py_func(i32
, i32
)
391 def reduction(arg0
, arg1
):
392 # CHECK: as func arg 0 NOPBlockArg
393 print("as func arg", arg0
.arg_number
, arg0
.__class__
.__name
__)
394 # CHECK: as func arg 1 NOPBlockArg
395 print("as func arg", arg1
.arg_number
, arg1
.__class__
.__name
__)
397 # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
400 reduction
.func_op
.arguments
[:1][0].arg_number
,
401 reduction
.func_op
.arguments
[:1][0],
406 @register_value_caster(IntegerType
.static_typeid
)
407 def dont_cast_int_shouldnt_register(v
):
410 except RuntimeError as e
:
411 # CHECK: Value caster is already registered: {{.*}}cast_int
414 @register_value_caster(IntegerType
.static_typeid
, replace
=True)
415 def dont_cast_int(v
) -> OpResult
:
416 assert isinstance(v
, OpResult
)
417 print("don't cast", v
.result_number
, v
)
420 with Location
.unknown(ctx
):
421 i32
= IntegerType
.get_signless(32)
422 module
= Module
.create()
423 with
InsertionPoint(module
.body
):
424 # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
425 new_value
= Operation
.create("custom.op1", results
=[i32
]).result
426 # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
427 print("result", new_value
.result_number
, new_value
)
429 # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
430 new_value
= Operation
.create("custom.op2", results
=[i32
]).results
[0]
431 # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
432 print("result", new_value
.result_number
, new_value
)