//===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s

#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir-c/RegisterEverything.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

22 // CHECK-LABEL: testRoundtripEncoding()
23 static int testRoundtripEncoding(MlirContext ctx
) {
24 fprintf(stderr
, "testRoundtripEncoding()\n");
26 const char *originalAsm
=
27 "#sparse_tensor.encoding<{ "
28 "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), "
29 "posWidth = 32, crdWidth = 64 }>";
31 MlirAttribute originalAttr
=
32 mlirAttributeParseGet(ctx
, mlirStringRefCreateFromCString(originalAsm
));
34 fprintf(stderr
, "isa: %d\n",
35 mlirAttributeIsASparseTensorEncodingAttr(originalAttr
));
36 MlirAffineMap dimToLvl
=
37 mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr
);
38 // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
39 mlirAffineMapDump(dimToLvl
);
40 // CHECK: level_type: 4
41 // CHECK: level_type: 8
42 // CHECK: level_type: 8
43 MlirAffineMap lvlToDim
=
44 mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr
);
45 int lvlRank
= mlirSparseTensorEncodingGetLvlRank(originalAttr
);
46 enum MlirSparseTensorDimLevelType
*lvlTypes
=
47 malloc(sizeof(enum MlirSparseTensorDimLevelType
) * lvlRank
);
48 for (int l
= 0; l
< lvlRank
; ++l
) {
49 lvlTypes
[l
] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr
, l
);
50 fprintf(stderr
, "level_type: %d\n", lvlTypes
[l
]);
52 // CHECK: posWidth: 32
53 int posWidth
= mlirSparseTensorEncodingAttrGetPosWidth(originalAttr
);
54 fprintf(stderr
, "posWidth: %d\n", posWidth
);
55 // CHECK: crdWidth: 64
56 int crdWidth
= mlirSparseTensorEncodingAttrGetCrdWidth(originalAttr
);
57 fprintf(stderr
, "crdWidth: %d\n", crdWidth
);
58 MlirAttribute newAttr
= mlirSparseTensorEncodingAttrGet(
59 ctx
, lvlRank
, lvlTypes
, dimToLvl
, lvlToDim
, posWidth
, crdWidth
);
60 mlirAttributeDump(newAttr
); // For debugging filecheck output.
62 fprintf(stderr
, "equal: %d\n", mlirAttributeEqual(originalAttr
, newAttr
));
68 MlirContext ctx
= mlirContextCreate();
69 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(),
71 if (testRoundtripEncoding(ctx
))
74 mlirContextDestroy(ctx
);