//===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s
#include "mlir-c/Dialect/SparseTensor.h"
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
22 // CHECK-LABEL: testRoundtripEncoding()
23 static int testRoundtripEncoding(MlirContext ctx
) {
24 fprintf(stderr
, "testRoundtripEncoding()\n");
26 const char *originalAsm
=
27 "#sparse_tensor.encoding<{ "
28 "dimLevelType = [ \"dense\", \"compressed\", \"singleton\"], "
29 "dimOrdering = affine_map<(d0, d1, d2) -> (d0, d1, d2)>, "
30 "pointerBitWidth = 32, indexBitWidth = 64 }>";
32 MlirAttribute originalAttr
=
33 mlirAttributeParseGet(ctx
, mlirStringRefCreateFromCString(originalAsm
));
35 fprintf(stderr
, "isa: %d\n",
36 mlirAttributeIsASparseTensorEncodingAttr(originalAttr
));
37 MlirAffineMap dimOrdering
=
38 mlirSparseTensorEncodingAttrGetDimOrdering(originalAttr
);
39 // CHECK: (d0, d1, d2) -> (d0, d1, d2)
40 mlirAffineMapDump(dimOrdering
);
41 // CHECK: level_type: 0
42 // CHECK: level_type: 1
43 // CHECK: level_type: 2
44 int numLevelTypes
= mlirSparseTensorEncodingGetNumDimLevelTypes(originalAttr
);
45 enum MlirSparseTensorDimLevelType
*levelTypes
=
46 malloc(sizeof(enum MlirSparseTensorDimLevelType
) * numLevelTypes
);
47 for (int i
= 0; i
< numLevelTypes
; ++i
) {
49 mlirSparseTensorEncodingAttrGetDimLevelType(originalAttr
, i
);
50 fprintf(stderr
, "level_type: %d\n", levelTypes
[i
]);
54 mlirSparseTensorEncodingAttrGetPointerBitWidth(originalAttr
);
55 fprintf(stderr
, "pointer: %d\n", pointerBitWidth
);
58 mlirSparseTensorEncodingAttrGetIndexBitWidth(originalAttr
);
59 fprintf(stderr
, "index: %d\n", indexBitWidth
);
61 MlirAttribute newAttr
= mlirSparseTensorEncodingAttrGet(
62 ctx
, numLevelTypes
, levelTypes
, dimOrdering
, pointerBitWidth
,
64 mlirAttributeDump(newAttr
); // For debugging filecheck output.
66 fprintf(stderr
, "equal: %d\n", mlirAttributeEqual(originalAttr
, newAttr
));
73 MlirContext ctx
= mlirContextCreate();
74 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(),
76 if (testRoundtripEncoding(ctx
))
79 mlirContextDestroy(ctx
);