codegen: add a 'size' argument to ALU_WRITES_FLAGS
[ajla.git] / mpint.c
blobaec41bbe57cec4a2cee4c5fb12898426e19ea2c0
1 /*
2 * Copyright (C) 2024 Mikulas Patocka
4 * This file is part of Ajla.
6 * Ajla is free software: you can redistribute it and/or modify it under the
7 * terms of the GNU General Public License as published by the Free Software
8 * Foundation, either version 3 of the License, or (at your option) any later
9 * version.
11 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License along with
16 * Ajla. If not, see <https://www.gnu.org/licenses/>.
19 #include "ajla.h"
21 #include "mem_al.h"
23 #include "mpint.h"
/*
 * Integers are limited to MPINT_MAX_BITS bits in two's complement, i.e. the
 * accepted range is [-2^(MPINT_MAX_BITS-1), 2^(MPINT_MAX_BITS-1) - 1].
 */
#define MPINT_MAX_BITS		0x80000000UL

/* old GMP releases (< 5) do not declare mp_bitcnt_t */
#if defined(MPINT_GMP) && __GNU_MP_VERSION+0 < 5
typedef unsigned long mp_bitcnt_t;
#endif

/*
 * Fallbacks for the limb-access API when the GMP in use does not provide it;
 * they access the mpz internals (_mp_d, _mp_alloc, _mp_size) directly.
 */
#ifndef mpz_limbs_read
#define mpz_limbs_read(t)		((t)->_mp_d)
#endif

#ifndef mpz_limbs_write
/* never grows the allocation: reports an internal error if fewer than idx limbs are allocated */
#define mpz_limbs_write(t, idx)		(((size_t)(t)->_mp_alloc < (idx) ? internal(file_line, "mpz_limbs_write: not enough entries: %"PRIuMAX" < %"PRIuMAX"", (uintmax_t)(t)->_mp_alloc, (uintmax_t)(idx)), 0 : 0), (t)->_mp_d)
#endif

#ifndef mpz_limbs_finish
/* stores the limb count verbatim (no normalization of high zero limbs) */
#define mpz_limbs_finish(t, idx)	((t)->_mp_size = (idx))
#endif
44 static attr_noinline bool attr_fastcall attr_cold mpint_size_ok_slow(const mpint_t *t, ajla_error_t *err)
46 size_t sz = mpz_sizeinbase(t, 2);
47 if (sz >= MPINT_MAX_BITS) {
48 if (sz == MPINT_MAX_BITS && mpz_sgn(t) < 0 &&
49 mpz_scan1(t, 0) == MPINT_MAX_BITS - 1)
50 return true;
51 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
52 return false;
54 return true;
57 static inline bool mpint_size_ok(const mpint_t attr_unused *t, ajla_error_t *err)
59 size_t size = mpz_size(t);
60 if (likely(size <= (MPINT_MAX_BITS - 1) / GMP_NUMB_BITS))
61 return true;
62 return mpint_size_ok_slow(t, err);
65 bool attr_fastcall mpint_add(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
67 mpz_add(r, s1, s2);
68 if (unlikely(!mpint_size_ok(r, err))) {
69 return false;
71 return true;
74 bool attr_fastcall mpint_subtract(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
76 mpz_sub(r, s1, s2);
77 if (unlikely(!mpint_size_ok(r, err))) {
78 return false;
80 return true;
83 static inline bool mpint_multiply_early_check(size_t size1, size_t size2, ajla_error_t *err)
85 if (unlikely(size1 + size2 > 1 + (MPINT_MAX_BITS + GMP_NUMB_BITS - 1) / GMP_NUMB_BITS)) {
86 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
87 return false;
89 return true;
92 bool attr_fastcall mpint_multiply(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
94 if (unlikely(!mpint_multiply_early_check(mpz_size(s1), mpz_size(s2), err))) {
95 return false;
97 mpz_mul(r, s1, s2);
98 if (unlikely(!mpint_size_ok(r, err)))
99 return false;
100 return true;
103 bool attr_fastcall mpint_divide(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
105 if (unlikely(!mpz_sgn(s2))) {
106 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "divide by zero");
107 return false;
109 mpz_tdiv_q(r, s1, s2);
110 if (unlikely(!mpint_size_ok(r, err))) {
111 return false;
113 return true;
116 bool attr_fastcall mpint_modulo(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
118 if (unlikely(!mpz_sgn(s2))) {
119 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "modulo by zero");
120 return false;
122 mpz_tdiv_r(r, s1, s2);
123 return true;
126 bool attr_fastcall mpint_power(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
128 mpint_t x1, x2;
129 mpz_init_set(&x1, s1);
130 mpz_init_set(&x2, s2);
131 if (unlikely(mpz_sgn(&x2) < 0)) {
132 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "power by negative number");
133 ret_err:
134 mpz_clear(&x1);
135 mpz_clear(&x2);
136 return false;
138 mpz_set_ui(r, 1);
139 while (1) {
140 if (mpz_tstbit(&x2, 0)) {
141 if (unlikely(!mpint_multiply(r, &x1, r, err)))
142 goto ret_err;
144 if (!mpz_sgn(&x2))
145 break;
146 mpz_tdiv_q_2exp(&x2, &x2, 1);
147 if (unlikely(!mpint_multiply(&x1, &x1, &x1, err)))
148 goto ret_err;
150 mpz_clear(&x1);
151 mpz_clear(&x2);
152 return true;
155 bool attr_fastcall mpint_and(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t attr_unused *err)
157 mpz_and(r, s1, s2);
158 return true;
161 bool attr_fastcall mpint_or(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t attr_unused *err)
163 mpz_ior(r, s1, s2);
164 return true;
167 bool attr_fastcall mpint_xor(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t attr_unused *err)
169 mpz_xor(r, s1, s2);
170 return true;
173 bool attr_fastcall mpint_shl(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
175 unsigned long sh;
176 size_t size1, size2;
177 if (unlikely(!mpz_fits_ulong_p(s2))) {
178 overflow:
179 if (unlikely((mpz_sgn(s2) < 0))) {
180 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "shift left with negative count");
181 return false;
182 } else {
183 if (!mpz_sgn(s1)) {
184 mpz_set_ui(r, 0);
185 return true;
186 } else {
187 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
188 return false;
192 sh = mpz_get_ui(s2);
193 if (unlikely((mp_bitcnt_t)sh != sh))
194 goto overflow;
195 size1 = mpz_size(s1);
196 size2 = 1 + sh / GMP_NUMB_BITS;
197 if (unlikely(!mpint_multiply_early_check(size1, size2, err)))
198 return false;
199 mpz_mul_2exp(r, s1, sh);
200 if (unlikely(!mpint_size_ok(r, err)))
201 return false;
202 return true;
205 bool attr_fastcall mpint_shr(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
207 unsigned long sh;
208 if (unlikely(!mpz_fits_ulong_p(s2))) {
209 overflow:
210 if (unlikely((mpz_sgn(s2) < 0))) {
211 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "shift right with negative count");
212 return false;
213 } else {
214 if (mpz_sgn(s1) >= 0) {
215 mpz_set_ui(r, 0);
216 return true;
217 } else {
218 mpz_set_si(r, -1);
219 return true;
223 sh = mpz_get_ui(s2);
224 if (unlikely((mp_bitcnt_t)sh != sh))
225 goto overflow;
226 mpz_fdiv_q_2exp(r, s1, sh);
227 return true;
230 static inline bool attr_fastcall mpint_btx_(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err, void (*fn_bit)(mpint_t *, mp_bitcnt_t), int mode)
232 unsigned long sh;
233 if (unlikely(!mpz_fits_ulong_p(s2))) {
234 if (unlikely(mpz_sgn(s2) < 0)) {
235 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit %s with negative position", mode == 0 ? "set" : mode == 1 ? "reset" : "complement");
236 return false;
238 overflow:
239 if (mode == 0 && mpz_sgn(s1) < 0) {
240 mpz_set(r, s1);
241 return true;
243 if (mode == 1 && mpz_sgn(s1) >= 0) {
244 mpz_set(r, s1);
245 return true;
247 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
248 return false;
250 sh = mpz_get_ui(s2);
251 if (unlikely(sh != (unsigned long)(mp_bitcnt_t)sh))
252 goto overflow;
253 mpz_set(r, s1);
254 fn_bit(r, sh);
255 return true;
258 bool attr_fastcall mpint_bts(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
260 return mpint_btx_(s1, s2, r, err, mpz_setbit, 0);
263 bool attr_fastcall mpint_btr(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
265 return mpint_btx_(s1, s2, r, err, mpz_clrbit, 1);
268 bool attr_fastcall mpint_btc(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
270 #if !defined(mpz_combit) || defined(UNUSUAL_ARITHMETICS)
271 ajla_flat_option_t o;
272 if (unlikely(!mpint_bt(s1, s2, &o, err)))
273 return false;
274 if (!o)
275 return mpint_bts(s1, s2, r, err);
276 else
277 return mpint_btr(s1, s2, r, err);
278 #else
279 return mpint_btx_(s1, s2, r, err, mpz_combit, 2);
280 #endif
283 bool attr_fastcall mpint_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
285 *r = !mpz_cmp(s1, s2);
286 return true;
289 bool attr_fastcall mpint_not_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
291 *r = !!mpz_cmp(s1, s2);
292 return true;
295 bool attr_fastcall mpint_less(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
297 *r = mpz_cmp(s1, s2) < 0;
298 return true;
301 bool attr_fastcall mpint_less_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
303 *r = mpz_cmp(s1, s2) <= 0;
304 return true;
307 bool attr_fastcall mpint_greater(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
309 *r = mpz_cmp(s1, s2) > 0;
310 return true;
313 bool attr_fastcall mpint_greater_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
315 *r = mpz_cmp(s1, s2) >= 0;
316 return true;
319 bool attr_fastcall mpint_bt(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t *err)
321 unsigned long sh;
322 if (unlikely(!mpz_fits_ulong_p(s2))) {
323 if (unlikely(mpz_sgn(s2) < 0)) {
324 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit test with negative position");
325 return false;
327 overflow:
328 *r = mpz_sgn(s1) < 0;
329 return true;
331 sh = mpz_get_ui(s2);
332 if (unlikely(sh != (unsigned long)(mp_bitcnt_t)sh))
333 goto overflow;
334 *r = mpz_tstbit(s1, sh);
335 return true;
338 bool attr_fastcall mpint_not(const mpint_t *s, mpint_t *r, ajla_error_t attr_unused *err)
340 mpz_com(r, s);
341 return true;
344 bool attr_fastcall mpint_neg(const mpint_t *s, mpint_t *r, ajla_error_t *err)
346 mpz_neg(r, s);
347 if (unlikely(!mpint_size_ok(r, err))) {
348 return false;
350 return true;
353 bool attr_fastcall mpint_inc(const mpint_t *s, mpint_t *r, ajla_error_t *err)
355 mpz_add_ui(r, s, 1);
356 if (unlikely(!mpint_size_ok(r, err))) {
357 return false;
359 return true;
362 bool attr_fastcall mpint_dec(const mpint_t *s, mpint_t *r, ajla_error_t *err)
364 mpz_sub_ui(r, s, 1);
365 if (unlikely(!mpint_size_ok(r, err))) {
366 return false;
368 return true;
371 bool attr_fastcall mpint_bsf(const mpint_t *s, mpint_t *r, ajla_error_t *err)
373 mp_bitcnt_t b;
374 if (unlikely(!mpz_sgn(s))) {
375 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit scan forward with zero argument");
376 return false;
378 b = mpz_scan1(s, 0);
379 #ifndef UNUSUAL_ARITHMETICS
380 if (likely(b == (unsigned long)b))
381 mpz_set_ui(r, b);
382 else
383 #endif
384 mpz_import(r, 1, 1, sizeof(b), 0, 0, &b);
385 return true;
388 bool attr_fastcall mpint_bsr(const mpint_t *s, mpint_t *r, ajla_error_t *err)
390 size_t sz;
391 if (unlikely(mpz_sgn(s) <= 0)) {
392 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit scan reverse with non-positive argument");
393 return false;
395 sz = mpz_sizeinbase(s, 2) - 1;
396 #ifndef UNUSUAL_ARITHMETICS
397 if (likely(sz == (unsigned long)sz))
398 mpz_set_ui(r, sz);
399 else
400 #endif
401 mpz_import(r, 1, 1, sizeof(sz), 0, 0, &sz);
402 return true;
405 bool attr_fastcall mpint_popcnt(const mpint_t *s, mpint_t *r, ajla_error_t *err)
407 mp_bitcnt_t b;
408 if (unlikely(mpz_sgn(s) < 0)) {
409 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "population count with negative argument");
410 return false;
412 b = mpz_popcount(s);
413 #ifndef UNUSUAL_ARITHMETICS
414 if (likely(b == (unsigned long)b))
415 mpz_set_ui(r, b);
416 else
417 #endif
418 mpz_import(r, 1, 1, sizeof(b), 0, 0, &b);
419 return true;
422 static inline bool mpint_raw_test_bit(const mp_limb_t *ptr, size_t bit)
424 return (ptr[bit / GMP_NUMB_BITS] >> (bit % GMP_NUMB_BITS)) & 1;
427 #define mpint_conv_real(n, type, ntype, pack, unpack) \
428 bool attr_fastcall cat(mpint_init_from_,type)(mpint_t *t, type *valp, ajla_error_t *err)\
430 ntype norm, mult, val = unpack(*valp); \
431 int ex, shift; \
432 size_t limbs, idx; \
433 mp_limb_t *ptr; \
434 bool neg = unlikely(val < (ntype)0); \
435 val = cat(mathfunc_,ntype)(fabs)(val); \
436 if (unlikely(!cat(isfinite_, ntype)(val))) { \
437 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INFINITY), err, "attempting to convert infinity to integer");\
438 return false; \
440 norm = cat(mathfunc_,ntype)(frexp)(val, &ex); \
441 if (unlikely(!mpint_alloc_mayfail(t, likely(ex >= 0) ? ex : 0, err)))\
442 return false; \
443 if (unlikely(ex <= 0)) \
444 goto skip; \
445 idx = (unsigned)(ex - 1) / GMP_NUMB_BITS; \
446 limbs = idx + 1; \
447 ptr = mpz_limbs_write(t, limbs); \
448 shift = ((unsigned)(ex - 1) % GMP_NUMB_BITS) + 1; \
449 mult = (ntype)((mp_limb_t)1 << (shift - 1)); \
450 mult += mult; \
451 norm *= mult; \
452 while (1) { \
453 mp_limb_t limb = (mp_limb_t)norm; \
454 norm -= (ntype)limb; \
455 ptr[idx] = limb; \
456 if (!idx) \
457 break; \
458 if (!norm) { \
459 memset(ptr, 0, idx * sizeof(mp_limb_t)); \
460 break; \
462 idx--; \
463 norm = norm * ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.);\
465 mpz_limbs_finish(t, limbs); \
466 if (unlikely(neg)) { \
467 if (unlikely(!mpint_neg(t, t, err))) \
468 goto fail_free; \
470 skip: \
471 return true; \
472 fail_free: \
473 mpint_free(t); \
474 return false; \
476 void attr_fastcall cat(mpint_export_to_,type)(const mpint_t *s, type *result)\
478 size_t sz, last_bit, idx, base_pos; \
479 const mp_limb_t *limbs; \
480 ntype mult, r = 0; \
482 if (unlikely(!mpz_sgn(s))) \
483 goto skip; \
485 limbs = mpz_limbs_read(s); \
487 sz = mpz_sizeinbase(s, 2); \
488 last_bit = sz < cat(bits_,ntype) ? 0 : sz - cat(bits_,ntype); \
490 idx = (sz - 1) / GMP_NUMB_BITS; \
491 base_pos = idx * GMP_NUMB_BITS; \
492 mult = cat(mathfunc_,ntype)(ldexp)(1., base_pos); \
494 while (1) { \
495 ntype l; \
496 mp_limb_t limb = limbs[idx]; \
497 mp_limb_t mask = (mp_limb_t)-1; \
499 mask <<= base_pos <= last_bit ? last_bit - base_pos : 0;\
501 l = (ntype)(limb & mask); \
503 if (l) \
504 r += mult * l; \
506 if (base_pos <= last_bit) \
507 break; \
509 sz = base_pos; \
510 idx = (sz - 1) / GMP_NUMB_BITS; \
511 base_pos = idx * GMP_NUMB_BITS; \
513 mult = mult * (ntype)(1. / ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.));\
516 if (last_bit >= 1 && mpint_raw_test_bit(limbs, last_bit - 1)) { \
517 if (mpint_raw_test_bit(limbs, last_bit) || mpn_scan1(limbs, 0) != last_bit - 1) {\
518 r += cat(mathfunc_,ntype)(ldexp)(1., last_bit); \
522 if (unlikely(mpz_sgn(s) < 0)) \
523 r = -r; \
525 skip: \
526 *result = pack(r); \
528 for_all_real(mpint_conv_real, for_all_empty)
529 #undef mpint_conv_real
532 bool mpint_export_to_blob(const mpint_t *s, uint8_t **blob, size_t *blob_size, ajla_error_t *err)
534 uint8_t *ptr;
535 size_t sz = (mpz_sizeinbase(s, 2) + 7) / 8 + 1;
536 size_t count;
538 ptr = mem_alloc_mayfail(uint8_t *, sz, err);
539 if (unlikely(!ptr))
540 return false;
542 mpz_export(ptr, &count, -1, 1, 0, 0, s);
543 if (count > sz)
544 internal(file_line, "mpint_export_to_blob: mpz_export ran over the end of allocated memory: %"PRIuMAX" > %"PRIuMAX"", (uintmax_t)count, (uintmax_t)sz);
545 while (count < sz) {
546 ptr[count++] = 0;
549 if (unlikely(mpz_sgn(s) < 0)) {
550 size_t i;
551 bool do_not = false;
552 for (i = 0; i < sz; i++) {
553 if (!do_not) {
554 if (ptr[i] != 0) {
555 ptr[i] = -ptr[i];
556 do_not = true;
558 } else {
559 ptr[i] = ~ptr[i];
562 while (sz >= 2) {
563 if (ptr[sz - 1] == 0xff && ptr[sz - 2] >= 0x80)
564 sz--;
565 else
566 break;
568 } else {
569 while (sz >= 2) {
570 if (ptr[sz - 1] == 0x00 && ptr[sz - 2] < 0x80)
571 sz--;
572 else
573 break;
575 if (sz == 1 && !ptr[0])
576 sz = 0;
579 *blob = ptr;
580 *blob_size = sz;
582 return true;
586 static void *gmp_alloc(size_t size)
588 return mem_alloc(void *, size);
591 static void *gmp_realloc(void *ptr, size_t attr_unused old_size, size_t new_size)
593 return mem_realloc(void *, ptr, new_size);
596 static void gmp_free(void *ptr, size_t attr_unused size)
598 mem_free(ptr);
602 void mpint_init(void)
604 if (!dll)
605 mp_set_memory_functions(gmp_alloc, gmp_realloc, gmp_free);
608 void mpint_done(void)