codegen: fix a bug introduced by 9e43b00888e59da6e5b0def6f3c49a823094f9d4
[ajla.git] / mpint.h
blob92998460b56bfeb5ae28638bde185a4e120ba5a2
/*
 * Copyright (C) 2024 Mikulas Patocka
 *
 * This file is part of Ajla.
 *
 * Ajla is free software: you can redistribute it and/or modify it under the
 * terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 *
 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * Ajla. If not, see <https://www.gnu.org/licenses/>.
 */

19 #ifndef AJLA_MPINT_H
20 #define AJLA_MPINT_H
22 #include "str.h"
23 #include "code-op.h"
26 #if !defined(MPINT_GMP)
27 #include "mini-gmp.h"
28 #elif defined(HAVE_GMP_H)
29 #include <gmp.h>
30 #elif defined(HAVE_GMP_GMP_H)
31 #include <gmp/gmp.h>
32 #endif
34 typedef MP_INT mpint_t;
36 #define mp_direct long
37 #define mp_udirect unsigned long
39 static inline unsigned long mpint_estimate_bits(const mpint_t *t)
41 return (unsigned long)mpz_size(t) * GMP_NUMB_BITS;
44 static inline bool mpint_alloc_mayfail(mpint_t *t, unsigned long bits, ajla_error_t attr_unused *err)
46 mpz_init2(t, bits);
47 return true;
50 static inline bool mpint_alloc_copy_mayfail(mpint_t *t, const mpint_t *src, ajla_error_t attr_unused *err)
52 mpz_init_set(t, src);
53 return true;
56 static inline void mpint_free(mpint_t *t)
58 mpz_clear(t);
61 static inline bool mpint_negative(const mpint_t *t)
63 return mpz_sgn(t) < 0;
66 #define mpint_conv_int(n, type, utype, sz, bits) \
67 static inline bool cat(mpint_set_from_,type)(mpint_t *t, type val, bool uns, ajla_error_t attr_unused *err)\
68 { \
69 if (sizeof(type) <= sizeof(mp_direct)) { \
70 if (!uns) \
71 mpz_set_si(t, (mp_direct)val); \
72 else \
73 mpz_set_ui(t, (mp_udirect)(cat(u,type))val); \
74 } else { \
75 bool sign = val < 0 && !uns; \
76 if (unlikely(sign)) \
77 val = -(utype)val; \
79 mpz_import(t, 1, 1, sizeof(type), 0, 0, &val); \
80 if (unlikely(sign)) \
81 mpz_neg(t, t); \
82 } \
83 return true; \
84 } \
86 static inline bool cat(mpint_init_from_,type)(mpint_t *t, type val, ajla_error_t *err)\
87 { \
88 if (sizeof(type) <= sizeof(mp_direct)) { \
89 mpz_init_set_si(t, (mp_direct)val); \
90 } else { \
91 if (unlikely(!mpint_alloc_mayfail(t, sizeof(type) * 8, err)))\
92 return false; \
93 if (unlikely(!cat(mpint_set_from_,type)(t, val, false, err))) {\
94 mpz_clear(t); \
95 return false; \
96 } \
97 } \
98 return true; \
99 } \
101 static attr_always_inline bool cat(mpint_export_to_,type)(const mpint_t *t, type *result, ajla_error_t *err)\
103 if (mpz_fits_slong_p(t)) { \
104 long l; \
105 l = mpz_get_si(t); \
106 if (unlikely(l != (type)l)) \
107 goto doesnt_fit; \
108 *result = (type)l; \
109 return true; \
110 } else if (sizeof(type) > sizeof(long)) { \
111 utype ui; \
112 size_t bit = mpz_sizeinbase(t, 2); \
113 if (bit > 8 * sizeof(type)) \
114 goto doesnt_fit; \
115 (void)mpz_export(&ui, NULL, 1, sizeof(utype), 0, 0, t); \
116 if (mpz_sgn(t) >= 0) { \
117 if ((type)ui < 0) \
118 goto doesnt_fit; \
119 } else { \
120 ui = -ui; \
121 if ((type)ui >= 0) \
122 goto doesnt_fit; \
124 *result = (type)ui; \
125 return true; \
128 doesnt_fit: \
129 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_DOESNT_FIT), err, "integer too large for the target type");\
130 return false; \
133 static attr_always_inline bool cat(mpint_export_to_,utype)(const mpint_t *t, utype *result, ajla_error_t *err)\
135 if (mpz_fits_ulong_p(t)) { \
136 unsigned long l; \
137 l = mpz_get_ui(t); \
138 if (unlikely(l != (utype)l)) \
139 goto doesnt_fit; \
140 *result = (utype)l; \
141 return true; \
142 } else if (sizeof(utype) > sizeof(unsigned long)) { \
143 size_t bit; \
144 if (unlikely(mpz_sgn(t) < 0)) \
145 goto doesnt_fit; \
146 bit = mpz_sizeinbase(t, 2); \
147 if (bit > 8 * sizeof(utype)) \
148 goto doesnt_fit; \
149 (void)mpz_export(result, NULL, 1, sizeof(utype), 0, 0, t);\
150 return true; \
153 doesnt_fit: \
154 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_DOESNT_FIT), err, "integer too large for the target type");\
155 return false; \
157 for_all_int(mpint_conv_int, for_all_empty)
158 #undef mpint_conv_int
161 static inline bool mpint_import_from_code(mpint_t *m, const code_t *code, ip_t n_words, ajla_error_t *err)
163 if (unlikely(n_words >= size_t_limit + uzero)) {
164 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
165 return false;
167 if (likely(n_words != 0) && unlikely((code[!CODE_ENDIAN ? n_words - 1 : 0] & sign_bit(code_t)) != 0)) {
168 size_t i;
169 code_t *copy = mem_alloc_array_mayfail(mem_alloc_mayfail, code_t *, 0, 0, n_words, sizeof(code_t), err);
170 if (unlikely(!copy))
171 return false;
172 for (i = 0; i < n_words; i++)
173 copy[i] = ~code[i];
174 mpz_import(m, n_words, !CODE_ENDIAN ? -1 : 1, sizeof(code_t), 0, 0, copy);
175 mem_free(copy);
176 mpz_com(m, m);
177 } else {
178 mpz_import(m, n_words, !CODE_ENDIAN ? -1 : 1, sizeof(code_t), 0, 0, code);
180 return true;
/*
 * mpint_import_from_variable(m, type, var)
 *
 * Store the native integer variable var (of integer type `type`) into the
 * mpint m. var must be an lvalue because its address is taken. A negative
 * value of a signed type is imported as its magnitude and then negated,
 * since mpz_import only deals with magnitudes.
 */
#define mpint_import_from_variable(m, type, var)			\
do {									\
	if (!is_unsigned(type) && unlikely((var) < (type)zero)) {	\
		type var2 = -(var);					\
		mpz_import((m), 1, 1, sizeof(var), 0, 0, &var2);	\
		mpz_neg((m), (m));					\
	} else {							\
		mpz_import((m), 1, 1, sizeof(var), 0, 0, &(var));	\
	}								\
} while (0)

/*
 * mpint_export_to_variable(m, type, var, success)
 *
 * Convert the mpint m to the native integer variable var of type `type`.
 * success is set to true when the value is representable and to false
 * otherwise; the content of var is not meaningful when success is false.
 *
 * var is zeroed first because mpz_export writes nothing for a zero value.
 * mpz_export yields the magnitude, so for a negative m the magnitude is
 * negated afterwards; the one exception is the bit pattern sign_bit(type),
 * which already represents the most negative value and is left as is.
 */
#define mpint_export_to_variable(m, type, var, success)			\
do {									\
	size_t bit;							\
	success = true;							\
	(var) = 0;							\
	bit = mpz_sizeinbase(m, 2);					\
	if (unlikely(bit > 8 * sizeof(type))) {				\
		success = false;					\
		break;							\
	}								\
	mpz_export(&(var), NULL, 1, sizeof(type), 0, 0, (m));		\
	if (likely(mpz_sgn(m) >= 0)) {					\
		/* magnitude spilled into the sign bit of a signed type */\
		if (unlikely((var) < (type)zero))			\
			success = false;				\
	} else {							\
		if (is_unsigned(type))					\
			success = false;				\
		if (likely((var) != sign_bit(type)))			\
			(var) = -(var);					\
		if (unlikely((var) >= (type)zero))			\
			success = false;				\
	}								\
} while (0)

219 bool attr_fastcall mpint_add(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
220 bool attr_fastcall mpint_subtract(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
221 bool attr_fastcall mpint_multiply(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
222 bool attr_fastcall mpint_divide(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
223 bool attr_fastcall mpint_modulo(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
224 bool attr_fastcall mpint_power(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
225 bool attr_fastcall mpint_and(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
226 bool attr_fastcall mpint_or(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
227 bool attr_fastcall mpint_xor(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
228 bool attr_fastcall mpint_shl(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
229 bool attr_fastcall mpint_shr(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
230 bool attr_fastcall mpint_bts(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
231 bool attr_fastcall mpint_btr(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
232 bool attr_fastcall mpint_btc(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err);
234 bool attr_fastcall mpint_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t *err);
235 bool attr_fastcall mpint_not_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t *err);
236 bool attr_fastcall mpint_less(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t *err);
237 bool attr_fastcall mpint_less_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t *err);
238 bool attr_fastcall mpint_bt(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t *err);
240 bool attr_fastcall mpint_not(const mpint_t *s, mpint_t *r, ajla_error_t *err);
241 bool attr_fastcall mpint_neg(const mpint_t *s, mpint_t *r, ajla_error_t *err);
242 bool attr_fastcall mpint_inc(const mpint_t *s, mpint_t *r, ajla_error_t *err);
243 bool attr_fastcall mpint_dec(const mpint_t *s, mpint_t *r, ajla_error_t *err);
244 bool attr_fastcall mpint_bsf(const mpint_t *s, mpint_t *r, ajla_error_t *err);
245 bool attr_fastcall mpint_bsr(const mpint_t *s, mpint_t *r, ajla_error_t *err);
246 bool attr_fastcall mpint_popcnt(const mpint_t *s, mpint_t *r, ajla_error_t *err);
/* The "_alt1" names are plain aliases of the regular implementations. */
#define mpint_divide_alt1		mpint_divide
#define mpint_modulo_alt1		mpint_modulo
#define mpint_popcnt_alt1		mpint_popcnt

252 #define mpint_conv_real(n, type, ntype, pack, unpack) \
253 bool attr_fastcall cat(mpint_init_from_,type)(mpint_t *t, type *valp, ajla_error_t *err);\
254 void attr_fastcall cat(mpint_export_to_,type)(const mpint_t *t, type *result);
255 for_all_real(mpint_conv_real, for_all_empty)
256 #undef mpint_conv_real
258 bool mpint_export_to_blob(const mpint_t *s, uint8_t **blob, size_t *blob_size, ajla_error_t *err);
260 static inline void str_add_mpint(char **s, size_t *l, const mpint_t *mp, uint16_t base_n)
262 mpint_t mod, base, num;
263 int8_t digit;
264 int16_t base_m = (int16_t)base_n;
265 ajla_flat_option_t neg, bo;
266 size_t pos, i;
268 pos = *l;
270 mpint_init_from_int8_t(&mod, 0, NULL);
271 mpint_init_from_int16_t(&base, base_m, NULL);
272 mpint_alloc_mayfail(&num, 0, NULL);
273 mpint_add(mp, &mod, &num, NULL);
274 mpint_less(&num, &mod, &neg, NULL);
275 if (neg) {
276 mpint_neg(&num, &num, NULL);
279 do {
280 mpint_modulo(&num, &base, &mod, NULL);
281 digit = 0; /* avoid warning */
282 mpint_export_to_int8_t(&mod, &digit, NULL);
283 str_add_char(s, l, digit <= 9 ? '0' + (char)digit : 'a' - 10 + (char)digit);
284 mpint_less(&num, &base, &bo, NULL);
285 if (!bo) mpint_divide(&num, &base, &num, NULL);
286 } while (!bo);
288 if (neg)
289 str_add_char(s, l, '-');
291 mpint_free(&mod);
292 mpint_free(&base);
293 mpint_free(&num);
295 for (i = 0; i < (*l - pos) / 2; i++) {
296 char c = (*s)[pos + i];
297 (*s)[pos + i] = (*s)[*l - 1 - i];
298 (*s)[*l - 1 - i] = c;
/*
 * Module set-up and tear-down entry points; both are defined outside this
 * header. NOTE(review): presumably called once at start-up and shutdown -
 * confirm against the callers.
 */
void mpint_init(void);
void mpint_done(void);

306 #endif