//===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir-c/RegisterEverything.h"

#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
23 // CHECK-LABEL: testRoundtripEncoding()
24 static int testRoundtripEncoding(MlirContext ctx
) {
25 fprintf(stderr
, "testRoundtripEncoding()\n");
27 const char *originalAsm
=
28 "#sparse_tensor.encoding<{ "
29 "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
30 "posWidth = 32, crdWidth = 64, explicitVal = 1 : i64}>";
32 MlirAttribute originalAttr
=
33 mlirAttributeParseGet(ctx
, mlirStringRefCreateFromCString(originalAsm
));
35 fprintf(stderr
, "isa: %d\n",
36 mlirAttributeIsASparseTensorEncodingAttr(originalAttr
));
37 MlirAffineMap dimToLvl
=
38 mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr
);
39 // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
40 mlirAffineMapDump(dimToLvl
);
41 // CHECK: level_type: 65536
42 // CHECK: level_type: 262144
43 // CHECK: level_type: 262144
44 MlirAffineMap lvlToDim
=
45 mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr
);
46 int lvlRank
= mlirSparseTensorEncodingGetLvlRank(originalAttr
);
47 MlirSparseTensorLevelType
*lvlTypes
=
48 malloc(sizeof(MlirSparseTensorLevelType
) * lvlRank
);
49 for (int l
= 0; l
< lvlRank
; ++l
) {
50 lvlTypes
[l
] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr
, l
);
51 fprintf(stderr
, "level_type: %" PRIu64
"\n", lvlTypes
[l
]);
53 // CHECK: posWidth: 32
54 int posWidth
= mlirSparseTensorEncodingAttrGetPosWidth(originalAttr
);
55 fprintf(stderr
, "posWidth: %d\n", posWidth
);
56 // CHECK: crdWidth: 64
57 int crdWidth
= mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr
);
58 fprintf(stderr
, "crdWidth: %d\n", crdWidth
);
60 // CHECK: explicitVal: 1 : i64
61 MlirAttribute explicitVal
=
62 mlirSparseTensorEncodingAttrGetExplicitVal(originalAttr
);
63 fprintf(stderr
, "explicitVal: ");
64 mlirAttributeDump(explicitVal
);
65 // CHECK: implicitVal: <<NULL ATTRIBUTE>>
66 MlirAttribute implicitVal
=
67 mlirSparseTensorEncodingAttrGetImplicitVal(originalAttr
);
68 fprintf(stderr
, "implicitVal: ");
69 mlirAttributeDump(implicitVal
);
71 MlirAttribute newAttr
= mlirSparseTensorEncodingAttrGet(
72 ctx
, lvlRank
, lvlTypes
, dimToLvl
, lvlToDim
, posWidth
, crdWidth
,
73 explicitVal
, implicitVal
);
74 mlirAttributeDump(newAttr
); // For debugging filecheck output.
76 fprintf(stderr
, "equal: %d\n", mlirAttributeEqual(originalAttr
, newAttr
));
82 MlirContext ctx
= mlirContextCreate();
83 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(),
85 if (testRoundtripEncoding(ctx
))
88 mlirContextDestroy(ctx
);