[RISCV] Fix the code alignment for GroupFloatVectors. NFC
[llvm-project.git] / mlir / test / python / dialects / sparse_tensor / dialect.py
blob 581f5eab250cf3c5fdeba1db1efbcd7177d02797
1 # RUN: %PYTHON %s | FileCheck %s
3 from mlir.ir import *
4 from mlir.dialects import sparse_tensor as st
def run(f):
  """Announce and immediately execute one test case.

  Emits a blank line followed by ``TEST: <function name>`` so the
  ``CHECK-LABEL`` directives in this file have a line to anchor on, then
  calls ``f`` with no arguments. ``f`` itself is returned unchanged, which
  lets this function be used as a decorator on each test.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
  """Round-trip a 1-D compressed encoding through parse, cast and get()."""
  with Context():
    asm = ('#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
           'pointerBitWidth = 16, indexBitWidth = 32 }>')
    generic = Attribute.parse(asm)
    print(generic)

    enc = st.EncodingAttr(generic)
    # CHECK: equal: True
    print(f"equal: {enc == generic}")

    # CHECK: dim_level_types: [<DimLevelType.compressed: 1>]
    print(f"dim_level_types: {enc.dim_level_types}")
    # CHECK: dim_ordering: None
    # A 1-D encoding here carries no dimension ordering; the None value
    # exercises the special-case handling for a missing ordering.
    print(f"dim_ordering: {enc.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {enc.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {enc.index_bit_width}")

    # Rebuild the same encoding via the factory (no ordering, 16/32 widths).
    rebuilt = st.EncodingAttr.get(enc.dim_level_types, None, 16, 32)
    print(rebuilt)
    # CHECK: created_equal: True
    print(f"created_equal: {rebuilt == enc}")

    # The factory must hand back the concrete EncodingAttr type, not a
    # plain Attribute.
    # CHECK: is_proper_instance: True
    print(f"is_proper_instance: {isinstance(rebuilt, st.EncodingAttr)}")
    # CHECK: created_pointer_bit_width: 16
    print(f"created_pointer_bit_width: {rebuilt.pointer_bit_width}")
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
  """Round-trip a 2-D dense/compressed encoding with an explicit ordering.

  Parses the encoding, casts it to ``st.EncodingAttr``, checks every
  accessor, then rebuilds it with ``st.EncodingAttr.get`` and verifies the
  rebuilt attribute the same way the 1-D test does (equality, concrete
  type, and pointer bit width).
  """
  with Context():
    parsed = Attribute.parse(
        '#sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], '
        'dimOrdering = affine_map<(d0, d1) -> (d0, d1)>, '
        'pointerBitWidth = 16, indexBitWidth = 32 }>')
    print(parsed)

    casted = st.EncodingAttr(parsed)
    # CHECK: equal: True
    print(f"equal: {casted == parsed}")

    # CHECK: dim_level_types: [<DimLevelType.dense: 0>, <DimLevelType.compressed: 1>]
    print(f"dim_level_types: {casted.dim_level_types}")
    # CHECK: dim_ordering: (d0, d1) -> (d0, d1)
    print(f"dim_ordering: {casted.dim_ordering}")
    # CHECK: pointer_bit_width: 16
    print(f"pointer_bit_width: {casted.pointer_bit_width}")
    # CHECK: index_bit_width: 32
    print(f"index_bit_width: {casted.index_bit_width}")

    created = st.EncodingAttr.get(casted.dim_level_types, casted.dim_ordering,
                                  16, 32)
    print(created)
    # CHECK: created_equal: True
    print(f"created_equal: {created == casted}")

    # Mirror the 1-D test: the factory path with a non-None ordering must
    # also return the concrete EncodingAttr type and keep the requested
    # pointer bit width.
    # CHECK: is_proper_instance: True
    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
    # CHECK: created_pointer_bit_width: 16
    print(f"created_pointer_bit_width: {created.pointer_bit_width}")
# CHECK-LABEL: TEST: testEncodingAttrOnTensor
@run
def testEncodingAttrOnTensor():
  """Attach a sparse encoding to a ranked tensor type and read it back."""
  with Context(), Location.unknown():
    sparse_enc = st.EncodingAttr(
        Attribute.parse(
            '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
            'pointerBitWidth = 16, indexBitWidth = 32 }>'))
    tensor_type = RankedTensorType.get((1024,), F32Type.get(),
                                       encoding=sparse_enc)
    # CHECK: tensor<1024xf32, #sparse_tensor
    print(tensor_type)
    # CHECK: #sparse_tensor.encoding
    print(tensor_type.encoding)
    # The encoding read back from the type must equal the one supplied.
    assert tensor_type.encoding == sparse_enc