1 # RUN: %PYTHON %s | FileCheck %s
9 print("\nTEST:", f
.__name
__)
12 assert Context
._get
_live
_count
() == 0
16 # CHECK-LABEL: TEST: testParsePrint
19 with
Context() as ctx
:
20 t
= Attribute
.parse('"hello"')
21 assert t
.context
is ctx
26 # CHECK: StringAttr("hello")
30 # CHECK-LABEL: TEST: testParseError
35 t
= Attribute
.parse("BAD_ATTR_DOES_NOT_EXIST")
36 except MLIRError
as e
:
37 # CHECK: testParseError: <
38 # CHECK: Unable to parse attribute:
39 # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
41 print(f
"testParseError: <{e}>")
43 print("Exception not produced")
46 # CHECK-LABEL: TEST: testAttrEq
50 a1
= Attribute
.parse('"attr1"')
51 a2
= Attribute
.parse('"attr2"')
52 a3
= Attribute
.parse('"attr1"')
53 # CHECK: a1 == a1: True
54 print("a1 == a1:", a1
== a1
)
55 # CHECK: a1 == a2: False
56 print("a1 == a2:", a1
== a2
)
57 # CHECK: a1 == a3: True
58 print("a1 == a3:", a1
== a3
)
59 # CHECK: a1 is None: False
60 print("a1 is None:", a1
is None)
63 # CHECK-LABEL: TEST: testAttrHash
67 a1
= Attribute
.parse('"attr1"')
68 a2
= Attribute
.parse('"attr2"')
69 a3
= Attribute
.parse('"attr1"')
70 # CHECK: hash(a1) == hash(a3): True
71 print("hash(a1) == hash(a3):", a1
.__hash__() == a3
.__hash__())
78 print("len(s): ", len(s
))
81 # CHECK-LABEL: TEST: testAttrCast
85 a1
= Attribute
.parse('"attr1"')
87 # CHECK: a1 == a2: True
88 print("a1 == a2:", a1
== a2
)
91 # CHECK-LABEL: TEST: testAttrIsInstance
93 def testAttrIsInstance():
95 a1
= Attribute
.parse("42")
96 a2
= Attribute
.parse("[42]")
97 assert IntegerAttr
.isinstance(a1
)
98 assert not IntegerAttr
.isinstance(a2
)
99 assert not ArrayAttr
.isinstance(a1
)
100 assert ArrayAttr
.isinstance(a2
)
103 # CHECK-LABEL: TEST: testAttrEqDoesNotRaise
105 def testAttrEqDoesNotRaise():
107 a1
= Attribute
.parse('"attr1"')
110 print(a1
== not_an_attr
)
114 print(a1
is not None)
117 # CHECK-LABEL: TEST: testAttrCapsule
119 def testAttrCapsule():
120 with
Context() as ctx
:
121 a1
= Attribute
.parse('"attr1"')
122 # CHECK: mlir.ir.Attribute._CAPIPtr
123 attr_capsule
= a1
._CAPIPtr
125 a2
= Attribute
._CAPICreate
(attr_capsule
)
127 assert a2
.context
is ctx
130 # CHECK-LABEL: TEST: testStandardAttrCasts
132 def testStandardAttrCasts():
134 a1
= Attribute
.parse('"attr1"')
135 astr
= StringAttr(a1
)
136 aself
= StringAttr(astr
)
137 # CHECK: StringAttr("attr1")
140 tillegal
= StringAttr(Attribute
.parse("1.0"))
141 except ValueError as e
:
142 # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
143 print("ValueError:", e
)
145 print("Exception not produced")
148 # CHECK-LABEL: TEST: testAffineMapAttr
150 def testAffineMapAttr():
151 with
Context() as ctx
:
152 d0
= AffineDimExpr
.get(0)
153 d1
= AffineDimExpr
.get(1)
154 c2
= AffineConstantExpr
.get(2)
155 map0
= AffineMap
.get(2, 3, [])
157 # CHECK: affine_map<(d0, d1)[s0, s1, s2] -> ()>
158 attr_built
= AffineMapAttr
.get(map0
)
159 print(str(attr_built
))
160 assert attr_built
.value
== map0
161 attr_parsed
= Attribute
.parse(str(attr_built
))
162 assert attr_built
== attr_parsed
165 # CHECK-LABEL: TEST: testIntegerSetAttr
167 def testIntegerSetAttr():
168 with
Context() as ctx
:
169 d0
= AffineDimExpr
.get(0)
170 d1
= AffineDimExpr
.get(1)
171 s0
= AffineSymbolExpr
.get(0)
172 c42
= AffineConstantExpr
.get(42)
173 set0
= IntegerSet
.get(2, 1, [d0
- d1
, s0
- c42
], [True, False])
175 # CHECK: affine_set<(d0, d1)[s0] : (d0 - d1 == 0, s0 - 42 >= 0)>
176 attr_built
= IntegerSetAttr
.get(set0
)
177 print(str(attr_built
))
179 attr_parsed
= Attribute
.parse(str(attr_built
))
180 assert attr_built
== attr_parsed
183 # CHECK-LABEL: TEST: testFloatAttr
186 with
Context(), Location
.unknown():
187 fattr
= FloatAttr(Attribute
.parse("42.0 : f32"))
188 # CHECK: fattr value: 42.0
189 print("fattr value:", fattr
.value
)
190 # CHECK: fattr float: 42.0 <class 'float'>
191 print("fattr float:", float(fattr
), type(float(fattr
)))
193 # Test factory methods.
194 # CHECK: default_get: 4.200000e+01 : f32
195 print("default_get:", FloatAttr
.get(F32Type
.get(), 42.0))
196 # CHECK: f32_get: 4.200000e+01 : f32
197 print("f32_get:", FloatAttr
.get_f32(42.0))
198 # CHECK: f64_get: 4.200000e+01 : f64
199 print("f64_get:", FloatAttr
.get_f64(42.0))
201 fattr_invalid
= FloatAttr
.get(IntegerType
.get_signless(32), 42)
202 except MLIRError
as e
:
203 # CHECK: Invalid attribute:
204 # CHECK: error: unknown: expected floating point type
207 print("Exception not produced")
210 # CHECK-LABEL: TEST: testIntegerAttr
212 def testIntegerAttr():
213 with
Context() as ctx
:
214 i_attr
= IntegerAttr(Attribute
.parse("42"))
215 # CHECK: i_attr value: 42
216 print("i_attr value:", i_attr
.value
)
217 # CHECK: i_attr type: i64
218 print("i_attr type:", i_attr
.type)
219 # CHECK: i_attr int: 42 <class 'int'>
220 print("i_attr int:", int(i_attr
), type(int(i_attr
)))
221 si_attr
= IntegerAttr(Attribute
.parse("-1 : si8"))
222 # CHECK: si_attr value: -1
223 print("si_attr value:", si_attr
.value
)
224 ui_attr
= IntegerAttr(Attribute
.parse("255 : ui8"))
225 # CHECK: i_attr int: -1 <class 'int'>
226 print("si_attr int:", int(si_attr
), type(int(si_attr
)))
227 # CHECK: ui_attr value: 255
228 print("ui_attr value:", ui_attr
.value
)
229 # CHECK: i_attr int: 255 <class 'int'>
230 print("ui_attr int:", int(ui_attr
), type(int(ui_attr
)))
231 idx_attr
= IntegerAttr(Attribute
.parse("-1 : index"))
232 # CHECK: idx_attr value: -1
233 print("idx_attr value:", idx_attr
.value
)
234 # CHECK: idx_attr int: -1 <class 'int'>
235 print("idx_attr int:", int(idx_attr
), type(int(idx_attr
)))
237 # Test factory methods.
238 # CHECK: default_get: 42 : i32
239 print("default_get:", IntegerAttr
.get(IntegerType
.get_signless(32), 42))
242 # CHECK-LABEL: TEST: testBoolAttr
245 with
Context() as ctx
:
246 battr
= BoolAttr(Attribute
.parse("true"))
247 # CHECK: iattr value: True
248 print("iattr value:", battr
.value
)
249 # CHECK: iattr bool: True <class 'bool'>
250 print("iattr bool:", bool(battr
), type(bool(battr
)))
252 # Test factory methods.
253 # CHECK: default_get: true
254 print("default_get:", BoolAttr
.get(True))
257 # CHECK-LABEL: TEST: testFlatSymbolRefAttr
259 def testFlatSymbolRefAttr():
260 with
Context() as ctx
:
261 sattr
= Attribute
.parse("@symbol")
262 # CHECK: symattr value: symbol
263 print("symattr value:", sattr
.value
)
265 # Test factory methods.
266 # CHECK: default_get: @foobar
267 print("default_get:", FlatSymbolRefAttr
.get("foobar"))
270 # CHECK-LABEL: TEST: testSymbolRefAttr
272 def testSymbolRefAttr():
273 with
Context() as ctx
:
274 sattr
= Attribute
.parse("@symbol1::@symbol2")
275 # CHECK: symattr value: ['symbol1', 'symbol2']
276 print("symattr value:", sattr
.value
)
278 # CHECK: default_get: @symbol1::@symbol2
279 print("default_get:", SymbolRefAttr
.get(["symbol1", "symbol2"]))
281 # CHECK: default_get: @"@symbol1"::@"@symbol2"
282 print("default_get:", SymbolRefAttr
.get(["@symbol1", "@symbol2"]))
285 # CHECK-LABEL: TEST: testOpaqueAttr
287 def testOpaqueAttr():
288 with
Context() as ctx
:
289 ctx
.allow_unregistered_dialects
= True
290 oattr
= OpaqueAttr(Attribute
.parse("#pytest_dummy.dummyattr<>"))
291 # CHECK: oattr value: pytest_dummy
292 print("oattr value:", oattr
.dialect_namespace
)
293 # CHECK: oattr value: b'dummyattr<>'
294 print("oattr value:", oattr
.data
)
296 # Test factory methods.
297 # CHECK: default_get: #foobar<123>
300 OpaqueAttr
.get("foobar", bytes("123", "utf-8"), NoneType
.get()),
304 # CHECK-LABEL: TEST: testStringAttr
306 def testStringAttr():
307 with
Context() as ctx
:
308 sattr
= StringAttr(Attribute
.parse('"stringattr"'))
309 # CHECK: sattr value: stringattr
310 print("sattr value:", sattr
.value
)
311 # CHECK: sattr value: b'stringattr'
312 print("sattr value:", sattr
.value_bytes
)
314 # Test factory methods.
315 # CHECK: default_get: "foobar"
316 print("default_get:", StringAttr
.get("foobar"))
317 # CHECK: typed_get: "12345" : i32
318 print("typed_get:", StringAttr
.get_typed(IntegerType
.get_signless(32), "12345"))
321 # CHECK-LABEL: TEST: testNamedAttr
325 a
= Attribute
.parse('"stringattr"')
326 named
= a
.get_named("foobar") # Note: under the small object threshold
327 # CHECK: attr: "stringattr"
328 print("attr:", named
.attr
)
329 # CHECK: name: foobar
330 print("name:", named
.name
)
331 # CHECK: named: NamedAttribute(foobar="stringattr")
332 print("named:", named
)
335 # CHECK-LABEL: TEST: testDenseIntAttr
337 def testDenseIntAttr():
339 raw
= Attribute
.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
340 # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
343 a
= DenseIntElementsAttr(raw
)
348 print(value
, end
=" ")
352 print(ShapedType(a
.type).element_type
)
354 raw
= Attribute
.parse("dense<[true,false,true,false]> : vector<4xi1>")
355 # CHECK: attr: dense<[true, false, true, false]>
358 a
= DenseIntElementsAttr(raw
)
363 print(value
, end
=" ")
367 print(ShapedType(a
.type).element_type
)
371 def testDenseArrayGetItem():
372 def print_item(attr_asm
):
373 attr
= Attribute
.parse(attr_asm
)
374 print(f
"{len(attr)}: {attr[0]}, {attr[1]}")
377 # CHECK: 2: False, True
378 print_item("array<i1: false, true>")
380 print_item("array<i8: 2, 3>")
382 print_item("array<i16: 4, 5>")
384 print_item("array<i32: 6, 7>")
386 print_item("array<i64: 8, 9>")
387 # CHECK: 2: 1.{{0+}}, 2.{{0+}}
388 print_item("array<f32: 1.0, 2.0>")
389 # CHECK: 2: 3.{{0+}}, 4.{{0+}}
390 print_item("array<f64: 3.0, 4.0>")
396 # CHECK: myboolarray: array<i1: true>
397 print("myboolarray:", DenseBoolArrayAttr
.get([MyBool()]))
400 # CHECK-LABEL: TEST: testDenseArrayAttrConstruction
402 def testDenseArrayAttrConstruction():
403 with
Context(), Location
.unknown():
405 def create_and_print(cls
, x
):
408 print(f
"input: {x} ({type(x)}), result: {darr}")
409 except Exception as ex
:
410 print(f
"input: {x} ({type(x)}), error: {ex}")
412 # CHECK: input: [4, 2] (<class 'list'>),
413 # CHECK-SAME: result: array<i8: 4, 2>
414 create_and_print(DenseI8ArrayAttr
, [4, 2])
416 # CHECK: input: [4, 2.0] (<class 'list'>),
417 # CHECK-SAME: error: get(): incompatible function arguments
418 create_and_print(DenseI8ArrayAttr
, [4, 2.0])
420 # CHECK: input: [40000, 2] (<class 'list'>),
421 # CHECK-SAME: error: get(): incompatible function arguments
422 create_and_print(DenseI8ArrayAttr
, [40000, 2])
424 # CHECK: input: range(0, 4) (<class 'range'>),
425 # CHECK-SAME: result: array<i8: 0, 1, 2, 3>
426 create_and_print(DenseI8ArrayAttr
, range(4))
428 # CHECK: input: [IntegerAttr(4 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
429 # CHECK-SAME: result: array<i8: 4, 2>
430 create_and_print(DenseI8ArrayAttr
, [Attribute
.parse(f
"{x}") for x
in [4, 2]])
432 # CHECK: input: [IntegerAttr(4000 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
433 # CHECK-SAME: error: get(): incompatible function arguments
434 create_and_print(DenseI8ArrayAttr
, [Attribute
.parse(f
"{x}") for x
in [4000, 2]])
436 # CHECK: input: [IntegerAttr(4 : i64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>),
437 # CHECK-SAME: error: get(): incompatible function arguments
438 create_and_print(DenseI8ArrayAttr
, [Attribute
.parse(f
"{x}") for x
in [4, 2.0]])
440 # CHECK: input: [IntegerAttr(4 : i8), IntegerAttr(2 : ui16)] (<class 'list'>),
441 # CHECK-SAME: result: array<i8: 4, 2>
443 DenseI8ArrayAttr
, [Attribute
.parse(s
) for s
in ["4 : i8", "2 : ui16"]]
446 # CHECK: input: [FloatAttr(4.000000e+00 : f64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>)
447 # CHECK-SAME: result: array<f32: 4.000000e+00, 2.000000e+00>
449 DenseF32ArrayAttr
, [Attribute
.parse(f
"{x}") for x
in [4.0, 2.0]]
452 # CHECK: [BoolAttr(true), BoolAttr(false)] (<class 'list'>),
453 # CHECK-SAME: result: array<i1: true, false>
455 DenseBoolArrayAttr
, [Attribute
.parse(f
"{x}") for x
in ["true", "false"]]
459 # CHECK-LABEL: TEST: testDenseIntAttrGetItem
461 def testDenseIntAttrGetItem():
462 def print_item(attr_asm
):
463 attr
= Attribute
.parse(attr_asm
)
464 dtype
= ShapedType(attr
.type).element_type
467 print(f
"{dtype}:", item
)
468 except TypeError as e
:
469 print(f
"{dtype}:", e
)
473 print_item("dense<true> : tensor<i1>")
475 print_item("dense<123> : tensor<i8>")
477 print_item("dense<123> : tensor<i16>")
479 print_item("dense<123> : tensor<i32>")
481 print_item("dense<123> : tensor<i64>")
483 print_item("dense<123> : tensor<ui8>")
485 print_item("dense<123> : tensor<ui16>")
487 print_item("dense<123> : tensor<ui32>")
489 print_item("dense<123> : tensor<ui64>")
491 print_item("dense<-123> : tensor<si8>")
493 print_item("dense<-123> : tensor<si16>")
495 print_item("dense<-123> : tensor<si32>")
497 print_item("dense<-123> : tensor<si64>")
499 # CHECK: i7: Unsupported integer type
500 print_item("dense<123> : tensor<i7>")
503 # CHECK-LABEL: TEST: testDenseFPAttr
505 def testDenseFPAttr():
507 raw
= Attribute
.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
508 # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
512 a
= DenseFPElementsAttr(raw
)
515 # CHECK: 0.0 1.0 2.0 3.0
517 print(value
, end
=" ")
521 print(ShapedType(a
.type).element_type
)
524 # CHECK-LABEL: TEST: testDictAttr
529 "stringattr": StringAttr
.get("string"),
530 "integerattr": IntegerAttr
.get(IntegerType
.get_signless(32), 42),
533 a
= DictAttr
.get(dict_attr
)
535 # CHECK: attr: {integerattr = 42 : i32, stringattr = "string"}
540 # CHECK: integerattr: IntegerAttr(42 : i32)
541 print("integerattr:", repr(a
["integerattr"]))
543 # CHECK: stringattr: StringAttr("string")
544 print("stringattr:", repr(a
["stringattr"]))
547 print("stringattr" in a
)
550 print("not_in_dict" in a
)
552 # Check that exceptions are raised as expected.
554 _
= a
["does_not_exist"]
558 assert False, "Exception not produced"
565 assert False, "expected IndexError on accessing an out-of-bounds attribute"
568 print("empty: ", DictAttr
.get())
571 # CHECK-LABEL: TEST: testTypeAttr
575 raw
= Attribute
.parse("vector<4xf32>")
576 # CHECK: attr: vector<4xf32>
578 type_attr
= TypeAttr(raw
)
580 print(ShapedType(type_attr
.value
).element_type
)
583 # CHECK-LABEL: TEST: testArrayAttr
587 arr
= Attribute
.parse("[42, true, vector<4xf32>]")
588 # CHECK: arr: [42, true, vector<4xf32>]
590 # CHECK: - IntegerAttr(42 : i64)
591 # CHECK: - BoolAttr(true)
592 # CHECK: - TypeAttr(vector<4xf32>)
594 print("- ", repr(attr
))
597 intAttr
= Attribute
.parse("42")
598 vecAttr
= Attribute
.parse("vector<4xf32>")
599 boolAttr
= BoolAttr
.get(True)
600 raw
= ArrayAttr
.get([vecAttr
, boolAttr
, intAttr
])
601 # CHECK: attr: [vector<4xf32>, true, 42]
602 print("raw attr:", raw
)
603 # CHECK: - TypeAttr(vector<4xf32>)
604 # CHECK: - BoolAttr(true
605 # CHECK: - IntegerAttr(42 : i64)
608 print("- ", repr(attr
))
609 # CHECK: attr[0]: TypeAttr(vector<4xf32>)
610 print("attr[0]:", repr(arr
[0]))
611 # CHECK: attr[1]: BoolAttr(true)
612 print("attr[1]:", repr(arr
[1]))
613 # CHECK: attr[2]: IntegerAttr(42 : i64)
614 print("attr[2]:", repr(arr
[2]))
616 print("attr[3]:", arr
[3])
617 except IndexError as e
:
618 # CHECK: Error: ArrayAttribute index out of range
622 ArrayAttr
.get([None])
623 except RuntimeError as e
:
624 # CHECK: Error: Invalid attribute (None?) when attempting to create an ArrayAttribute
628 except RuntimeError as e
:
629 # CHECK: Error: Invalid attribute when attempting to create an ArrayAttribute
633 array
= ArrayAttr
.get([StringAttr
.get("a"), StringAttr
.get("b")])
634 array
= array
+ [StringAttr
.get("c")]
635 # CHECK: concat: ["a", "b", "c"]
636 print("concat: ", array
)
639 # CHECK-LABEL: TEST: testStridedLayoutAttr
641 def testStridedLayoutAttr():
643 attr
= StridedLayoutAttr
.get(42, [5, 7, 13])
644 # CHECK: strided<[5, 7, 13], offset: 42>
649 print(len(attr
.strides
))
651 print(attr
.strides
[0])
653 print(attr
.strides
[1])
655 print(attr
.strides
[2])
657 attr
= StridedLayoutAttr
.get_fully_dynamic(3)
658 dynamic
= ShapedType
.get_dynamic_stride_or_offset()
659 # CHECK: strided<[?, ?, ?], offset: ?>
661 # CHECK: offset is dynamic: True
662 print(f
"offset is dynamic: {attr.offset == dynamic}")
664 print(f
"rank: {len(attr.strides)}")
665 # CHECK: strides are dynamic: [True, True, True]
666 print(f
"strides are dynamic: {[s == dynamic for s in attr.strides]}")
669 # CHECK-LABEL: TEST: testConcreteTypesRoundTrip
671 def testConcreteTypesRoundTrip():
672 with
Context(), Location
.unknown():
674 def print_item(attr
):
675 print(repr(attr
.type))
677 # CHECK: F32Type(f32)
678 print_item(Attribute
.parse("42.0 : f32"))
679 # CHECK: F32Type(f32)
680 print_item(FloatAttr
.get_f32(42.0))
681 # CHECK: IntegerType(i64)
682 print_item(IntegerAttr
.get(IntegerType
.get_signless(64), 42))
684 def print_container_item(attr_asm
):
685 attr
= DenseElementsAttr(Attribute
.parse(attr_asm
))
686 print(repr(attr
.type))
687 print(repr(attr
.type.element_type
))
689 # CHECK: RankedTensorType(tensor<i16>)
690 # CHECK: IntegerType(i16)
691 print_container_item("dense<123> : tensor<i16>")
693 # CHECK: RankedTensorType(tensor<f64>)
694 # CHECK: F64Type(f64)
695 print_container_item("dense<1.0> : tensor<f64>")
697 raw
= Attribute
.parse("vector<4xf32>")
698 # CHECK: attr: vector<4xf32>
700 type_attr
= TypeAttr(raw
)
702 # CHECK: VectorType(vector<4xf32>)
703 print(repr(type_attr
.value
))
704 # CHECK: F32Type(f32)
705 print(repr(type_attr
.value
.element_type
))
708 # CHECK-LABEL: TEST: testConcreteAttributesRoundTrip
710 def testConcreteAttributesRoundTrip():
711 with
Context(), Location
.unknown():
712 # CHECK: FloatAttr(4.200000e+01 : f32)
713 print(repr(Attribute
.parse("42.0 : f32")))
715 assert IntegerAttr
.static_typeid
is not None