//===- quant.c - Test of Quant dialect C API ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
#include "mlir-c/Dialect/Quant.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"

#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
21 // CHECK-LABEL: testTypeHierarchy
22 static void testTypeHierarchy(MlirContext ctx
) {
23 fprintf(stderr
, "testTypeHierarchy\n");
25 MlirType i8
= mlirIntegerTypeGet(ctx
, 8);
26 MlirType any
= mlirTypeParseGet(
27 ctx
, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
29 mlirTypeParseGet(ctx
, mlirStringRefCreateFromCString(
30 "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
31 MlirType perAxis
= mlirTypeParseGet(
32 ctx
, mlirStringRefCreateFromCString(
33 "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
34 MlirType calibrated
= mlirTypeParseGet(
36 mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
38 // The parser itself is checked in C++ dialect tests.
39 assert(!mlirTypeIsNull(any
) && "couldn't parse AnyQuantizedType");
40 assert(!mlirTypeIsNull(uniform
) && "couldn't parse UniformQuantizedType");
41 assert(!mlirTypeIsNull(perAxis
) &&
42 "couldn't parse UniformQuantizedPerAxisType");
43 assert(!mlirTypeIsNull(calibrated
) &&
44 "couldn't parse CalibratedQuantizedType");
46 // CHECK: i8 isa QuantizedType: 0
47 fprintf(stderr
, "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(i8
));
48 // CHECK: any isa QuantizedType: 1
49 fprintf(stderr
, "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(any
));
50 // CHECK: uniform isa QuantizedType: 1
51 fprintf(stderr
, "uniform isa QuantizedType: %d\n",
52 mlirTypeIsAQuantizedType(uniform
));
53 // CHECK: perAxis isa QuantizedType: 1
54 fprintf(stderr
, "perAxis isa QuantizedType: %d\n",
55 mlirTypeIsAQuantizedType(perAxis
));
56 // CHECK: calibrated isa QuantizedType: 1
57 fprintf(stderr
, "calibrated isa QuantizedType: %d\n",
58 mlirTypeIsAQuantizedType(calibrated
));
60 // CHECK: any isa AnyQuantizedType: 1
61 fprintf(stderr
, "any isa AnyQuantizedType: %d\n",
62 mlirTypeIsAAnyQuantizedType(any
));
63 // CHECK: uniform isa UniformQuantizedType: 1
64 fprintf(stderr
, "uniform isa UniformQuantizedType: %d\n",
65 mlirTypeIsAUniformQuantizedType(uniform
));
66 // CHECK: perAxis isa UniformQuantizedPerAxisType: 1
67 fprintf(stderr
, "perAxis isa UniformQuantizedPerAxisType: %d\n",
68 mlirTypeIsAUniformQuantizedPerAxisType(perAxis
));
69 // CHECK: calibrated isa CalibratedQuantizedType: 1
70 fprintf(stderr
, "calibrated isa CalibratedQuantizedType: %d\n",
71 mlirTypeIsACalibratedQuantizedType(calibrated
));
73 // CHECK: perAxis isa UniformQuantizedType: 0
74 fprintf(stderr
, "perAxis isa UniformQuantizedType: %d\n",
75 mlirTypeIsAUniformQuantizedType(perAxis
));
76 // CHECK: uniform isa CalibratedQuantizedType: 0
77 fprintf(stderr
, "uniform isa CalibratedQuantizedType: %d\n",
78 mlirTypeIsACalibratedQuantizedType(uniform
));
79 fprintf(stderr
, "\n");
82 // CHECK-LABEL: testAnyQuantizedType
83 void testAnyQuantizedType(MlirContext ctx
) {
84 fprintf(stderr
, "testAnyQuantizedType\n");
86 MlirType anyParsed
= mlirTypeParseGet(
87 ctx
, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
89 MlirType i8
= mlirIntegerTypeGet(ctx
, 8);
90 MlirType f32
= mlirF32TypeGet(ctx
);
92 mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8
, f32
, -8, 7);
95 fprintf(stderr
, "flags: %u\n", mlirQuantizedTypeGetFlags(any
));
97 fprintf(stderr
, "signed: %u\n", mlirQuantizedTypeIsSigned(any
));
98 // CHECK: storage type: i8
99 fprintf(stderr
, "storage type: ");
100 mlirTypeDump(mlirQuantizedTypeGetStorageType(any
));
101 fprintf(stderr
, "\n");
102 // CHECK: expressed type: f32
103 fprintf(stderr
, "expressed type: ");
104 mlirTypeDump(mlirQuantizedTypeGetExpressedType(any
));
105 fprintf(stderr
, "\n");
106 // CHECK: storage min: -8
107 fprintf(stderr
, "storage min: %" PRId64
"\n",
108 mlirQuantizedTypeGetStorageTypeMin(any
));
109 // CHECK: storage max: 7
110 fprintf(stderr
, "storage max: %" PRId64
"\n",
111 mlirQuantizedTypeGetStorageTypeMax(any
));
112 // CHECK: storage width: 8
113 fprintf(stderr
, "storage width: %u\n",
114 mlirQuantizedTypeGetStorageTypeIntegralWidth(any
));
115 // CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
116 fprintf(stderr
, "quantized element type: ");
117 mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any
));
118 fprintf(stderr
, "\n");
121 fprintf(stderr
, "equal: %d\n", mlirTypeEqual(anyParsed
, any
));
122 // CHECK: !quant.any<i8<-8:7>:f32>
124 fprintf(stderr
, "\n\n");
127 // CHECK-LABEL: testUniformType
128 void testUniformType(MlirContext ctx
) {
129 fprintf(stderr
, "testUniformType\n");
131 MlirType uniformParsed
=
132 mlirTypeParseGet(ctx
, mlirStringRefCreateFromCString(
133 "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
135 MlirType i8
= mlirIntegerTypeGet(ctx
, 8);
136 MlirType f32
= mlirF32TypeGet(ctx
);
137 MlirType uniform
= mlirUniformQuantizedTypeGet(
138 mlirQuantizedTypeGetSignedFlag(), i8
, f32
, 0.99872, 127, -8, 7);
140 // CHECK: scale: 0.998720
141 fprintf(stderr
, "scale: %lf\n", mlirUniformQuantizedTypeGetScale(uniform
));
142 // CHECK: zero point: 127
143 fprintf(stderr
, "zero point: %" PRId64
"\n",
144 mlirUniformQuantizedTypeGetZeroPoint(uniform
));
145 // CHECK: fixed point: 0
146 fprintf(stderr
, "fixed point: %d\n",
147 mlirUniformQuantizedTypeIsFixedPoint(uniform
));
150 fprintf(stderr
, "equal: %d\n", mlirTypeEqual(uniform
, uniformParsed
));
151 // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
152 mlirTypeDump(uniform
);
153 fprintf(stderr
, "\n\n");
156 // CHECK-LABEL: testUniformPerAxisType
157 void testUniformPerAxisType(MlirContext ctx
) {
158 fprintf(stderr
, "testUniformPerAxisType\n");
160 MlirType perAxisParsed
= mlirTypeParseGet(
161 ctx
, mlirStringRefCreateFromCString(
162 "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
164 MlirType i8
= mlirIntegerTypeGet(ctx
, 8);
165 MlirType f32
= mlirF32TypeGet(ctx
);
166 double scales
[] = {200.0, 0.99872};
167 int64_t zeroPoints
[] = {0, 120};
168 MlirType perAxis
= mlirUniformQuantizedPerAxisTypeGet(
169 mlirQuantizedTypeGetSignedFlag(), i8
, f32
,
170 /*nDims=*/2, scales
, zeroPoints
,
171 /*quantizedDimension=*/1,
172 mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
173 /*integralWidth=*/8),
174 mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
175 /*integralWidth=*/8));
177 // CHECK: num dims: 2
178 fprintf(stderr
, "num dims: %" PRIdPTR
"\n",
179 mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis
));
180 // CHECK: scale 0: 200.000000
181 fprintf(stderr
, "scale 0: %lf\n",
182 mlirUniformQuantizedPerAxisTypeGetScale(perAxis
, 0));
183 // CHECK: scale 1: 0.998720
184 fprintf(stderr
, "scale 1: %lf\n",
185 mlirUniformQuantizedPerAxisTypeGetScale(perAxis
, 1));
186 // CHECK: zero point 0: 0
187 fprintf(stderr
, "zero point 0: %" PRId64
"\n",
188 mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis
, 0));
189 // CHECK: zero point 1: 120
190 fprintf(stderr
, "zero point 1: %" PRId64
"\n",
191 mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis
, 1));
192 // CHECK: quantized dim: 1
193 fprintf(stderr
, "quantized dim: %" PRId32
"\n",
194 mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis
));
195 // CHECK: fixed point: 0
196 fprintf(stderr
, "fixed point: %d\n",
197 mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis
));
200 fprintf(stderr
, "equal: %d\n", mlirTypeEqual(perAxis
, perAxisParsed
));
201 // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
202 mlirTypeDump(perAxis
);
203 fprintf(stderr
, "\n\n");
206 // CHECK-LABEL: testCalibratedType
207 void testCalibratedType(MlirContext ctx
) {
208 fprintf(stderr
, "testCalibratedType\n");
210 MlirType calibratedParsed
= mlirTypeParseGet(
212 mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
214 MlirType f32
= mlirF32TypeGet(ctx
);
215 MlirType calibrated
= mlirCalibratedQuantizedTypeGet(f32
, -0.998, 1.2321);
217 // CHECK: min: -0.998000
218 fprintf(stderr
, "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(calibrated
));
219 // CHECK: max: 1.232100
220 fprintf(stderr
, "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(calibrated
));
223 fprintf(stderr
, "equal: %d\n", mlirTypeEqual(calibrated
, calibratedParsed
));
224 // CHECK: !quant.calibrated<f32<-0.998:1.232100e+00>>
225 mlirTypeDump(calibrated
);
226 fprintf(stderr
, "\n\n");
230 MlirContext ctx
= mlirContextCreate();
231 mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx
);
232 testTypeHierarchy(ctx
);
233 testAnyQuantizedType(ctx
);
234 testUniformType(ctx
);
235 testUniformPerAxisType(ctx
);
236 testCalibratedType(ctx
);
237 mlirContextDestroy(ctx
);