1 # RUN: %PYTHON %s | FileCheck %s
4 from mlir
.dialects
import sparse_tensor
as st
7 print("\nTEST:", f
.__name
__)
# CHECK-LABEL: TEST: testEncodingAttr1D
def testEncodingAttr1D():
  """Round-trip a 1-D sparse tensor encoding through parse, cast and get.

  Parses a `#sparse_tensor.encoding` attribute with a single compressed level
  and explicit bit widths, casts it to `st.EncodingAttr`, prints each property
  for FileCheck, then checks that `st.EncodingAttr.get` rebuilds an equal
  attribute of the proper Python type.
  """
  with Context():
    parsed = Attribute.parse(
        '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
        'pointerBitWidth = 16, indexBitWidth = 32 }>')

    casted = st.EncodingAttr(parsed)
    # The cast must compare equal to the generic attribute it wraps.
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # CHECK: dim_level_types: [<DimLevelType.compressed: 1>]
    print(f"dim_level_types: {casted.dim_level_types}")
    # CHECK: dim_ordering: None
    # Note that for 1D, the ordering is None, which exercises several special
    # cases (the parsed attribute above specifies no dimOrdering).
    print(f"dim_ordering: {casted.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {casted.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {casted.index_bit_width}")

    # Rebuild the attribute from its components. None is passed for the
    # ordering, matching the None reported by casted.dim_ordering above.
    created = st.EncodingAttr.get(casted.dim_level_types, None, 16, 32)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")

    # Verify that the factory creates an instance of the proper type.
    # CHECK: is_proper_instance: True
    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
    # CHECK: created_pointer_bit_width: 16
    print(f"created_pointer_bit_width: {created.pointer_bit_width}")
# CHECK-LABEL: TEST: testEncodingAttr2D
def testEncodingAttr2D():
  """Round-trip a 2-D sparse tensor encoding with an explicit dim ordering.

  Parses a dense/compressed encoding with an identity `dimOrdering` and
  explicit bit widths, casts it to `st.EncodingAttr`, prints each property
  for FileCheck, then checks that `st.EncodingAttr.get` rebuilds an equal
  attribute.
  """
  with Context():
    parsed = Attribute.parse(
        '#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], '
        'dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, '
        'pointerBitWidth = 16, indexBitWidth = 32 }>')

    casted = st.EncodingAttr(parsed)
    # The cast must compare equal to the generic attribute it wraps.
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # CHECK: dim_level_types: [<DimLevelType.dense: 0>, <DimLevelType.compressed: 1>]
    print(f"dim_level_types: {casted.dim_level_types}")
    # CHECK: dim_ordering: (d0, d1) -> (d0, d1)
    print(f"dim_ordering: {casted.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {casted.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {casted.index_bit_width}")

    # Bit widths are passed as literals (the values in the parsed text) so
    # the equality check below does not depend on the width getters.
    created = st.EncodingAttr.get(casted.dim_level_types, casted.dim_ordering,
                                  16, 32)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")
# CHECK-LABEL: TEST: testEncodingAttrOnTensor
def testEncodingAttrOnTensor():
  """Attach a sparse encoding to a RankedTensorType and read it back.

  Builds a 1-D compressed `st.EncodingAttr`, uses it as the `encoding` of a
  `tensor<1024xf32>` type, prints the type and its encoding for FileCheck,
  and asserts that the encoding read back from the type equals the original.
  """
  with Context(), Location.unknown():
    encoding = st.EncodingAttr(Attribute.parse(
        '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
        'pointerBitWidth = 16, indexBitWidth = 32 }>'))
    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
    # CHECK: tensor<1024xf32, #sparse_tensor
    print(tt)
    # CHECK: #sparse_tensor.encoding
    print(tt.encoding)
    assert tt.encoding == encoding