don't merge variables with different "must_be_flat" to the same slot
[ajla.git] / mpint.c
blob b4ce729557f1fac5c70b5de6314c6046ac149a79
1 /*
2 * Copyright (C) 2024 Mikulas Patocka
4 * This file is part of Ajla.
6 * Ajla is free software: you can redistribute it and/or modify it under the
7 * terms of the GNU General Public License as published by the Free Software
8 * Foundation, either version 3 of the License, or (at your option) any later
9 * version.
11 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License along with
16 * Ajla. If not, see <https://www.gnu.org/licenses/>.
19 #include "ajla.h"
21 #include "mem_al.h"
23 #include "mpint.h"
/*
 * Integers handled here must be representable as MPINT_MAX_BITS-bit two's
 * complement numbers (enforced by mpint_size_ok / mpint_size_ok_slow).
 */
#define MPINT_MAX_BITS 0x80000000UL

#if defined(MPINT_GMP) && __GNU_MP_VERSION+0 < 5
/* GMP older than 5 has no mp_bitcnt_t; bit counts are unsigned long there. */
typedef unsigned long mp_bitcnt_t;
#endif

/*
 * Fallbacks for the low-level limb accessors when the GMP in use does not
 * provide them: they access the fields of the mpz structure directly.
 */
#if !defined(mpz_limbs_read)
#define mpz_limbs_read(t)	((t)->_mp_d)
#endif

#if !defined(mpz_limbs_write)
/* Returns the limb array; complains via internal() if fewer than idx limbs are allocated. */
#define mpz_limbs_write(t, idx)						\
	(((size_t)(t)->_mp_alloc < (idx) ?				\
		internal(file_line, "mpz_limbs_write: not enough entries: %"PRIuMAX" < %"PRIuMAX"", (uintmax_t)(t)->_mp_alloc, (uintmax_t)(idx)), 0 :	\
		0),							\
	 (t)->_mp_d)
#endif

#if !defined(mpz_limbs_finish)
/* Commits idx limbs as the new size of t. */
#define mpz_limbs_finish(t, idx)	((t)->_mp_size = (idx))
#endif
44 static attr_noinline bool attr_fastcall attr_cold mpint_size_ok_slow(const mpint_t *t, ajla_error_t *err)
46 size_t sz = mpz_sizeinbase(t, 2);
47 if (sz >= MPINT_MAX_BITS) {
48 if (sz == MPINT_MAX_BITS && mpz_sgn(t) < 0 &&
49 mpz_scan1(t, 0) == MPINT_MAX_BITS - 1)
50 return true;
51 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
52 return false;
54 return true;
57 static inline bool mpint_size_ok(const mpint_t attr_unused *t, ajla_error_t *err)
59 size_t size = mpz_size(t);
60 if (likely(size <= (MPINT_MAX_BITS - 1) / GMP_NUMB_BITS))
61 return true;
62 return mpint_size_ok_slow(t, err);
65 bool attr_fastcall mpint_add(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
67 mpz_add(r, s1, s2);
68 if (unlikely(!mpint_size_ok(r, err))) {
69 return false;
71 return true;
74 bool attr_fastcall mpint_subtract(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
76 mpz_sub(r, s1, s2);
77 if (unlikely(!mpint_size_ok(r, err))) {
78 return false;
80 return true;
83 static inline bool mpint_multiply_early_check(size_t size1, size_t size2, ajla_error_t *err)
85 if (unlikely(size1 + size2 > 1 + (MPINT_MAX_BITS + GMP_NUMB_BITS - 1) / GMP_NUMB_BITS)) {
86 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
87 return false;
89 return true;
92 bool attr_fastcall mpint_multiply(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
94 if (unlikely(!mpint_multiply_early_check(mpz_size(s1), mpz_size(s2), err))) {
95 return false;
97 mpz_mul(r, s1, s2);
98 if (unlikely(!mpint_size_ok(r, err)))
99 return false;
100 return true;
103 bool attr_fastcall mpint_divide(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
105 if (unlikely(!mpz_sgn(s2))) {
106 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "divide by zero");
107 return false;
109 mpz_tdiv_q(r, s1, s2);
110 if (unlikely(!mpint_size_ok(r, err))) {
111 return false;
113 return true;
116 bool attr_fastcall mpint_modulo(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
118 if (unlikely(!mpz_sgn(s2))) {
119 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "modulo by zero");
120 return false;
122 mpz_tdiv_r(r, s1, s2);
123 return true;
126 bool attr_fastcall mpint_power(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
128 mpint_t x1, x2;
129 mpz_init_set(&x1, s1);
130 mpz_init_set(&x2, s2);
131 if (unlikely(mpz_sgn(&x2) < 0)) {
132 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "power by negative number");
133 ret_err:
134 mpz_clear(&x1);
135 mpz_clear(&x2);
136 return false;
138 mpz_set_ui(r, 1);
139 while (1) {
140 if (mpz_tstbit(&x2, 0)) {
141 if (unlikely(!mpint_multiply(r, &x1, r, err)))
142 goto ret_err;
144 if (!mpz_sgn(&x2))
145 break;
146 mpz_tdiv_q_2exp(&x2, &x2, 1);
147 if (unlikely(!mpint_multiply(&x1, &x1, &x1, err)))
148 goto ret_err;
150 mpz_clear(&x1);
151 mpz_clear(&x2);
152 return true;
155 bool attr_fastcall mpint_and(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t attr_unused *err)
157 mpz_and(r, s1, s2);
158 return true;
161 bool attr_fastcall mpint_or(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t attr_unused *err)
163 mpz_ior(r, s1, s2);
164 return true;
167 bool attr_fastcall mpint_xor(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t attr_unused *err)
169 mpz_xor(r, s1, s2);
170 return true;
173 bool attr_fastcall mpint_shl(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
175 unsigned long sh;
176 size_t size1, size2;
177 if (unlikely(!mpz_fits_ulong_p(s2))) {
178 overflow:
179 if (unlikely((mpz_sgn(s2) < 0))) {
180 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "shift left with negative count");
181 return false;
182 } else {
183 if (!mpz_sgn(s1)) {
184 mpz_set_ui(r, 0);
185 return true;
186 } else {
187 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
188 return false;
192 sh = mpz_get_ui(s2);
193 if (unlikely((mp_bitcnt_t)sh != sh))
194 goto overflow;
195 size1 = mpz_size(s1);
196 size2 = 1 + sh / GMP_NUMB_BITS;
197 if (unlikely(!mpint_multiply_early_check(size1, size2, err)))
198 return false;
199 mpz_mul_2exp(r, s1, sh);
200 if (unlikely(!mpint_size_ok(r, err)))
201 return false;
202 return true;
205 bool attr_fastcall mpint_shr(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
207 unsigned long sh;
208 if (unlikely(!mpz_fits_ulong_p(s2))) {
209 overflow:
210 if (unlikely((mpz_sgn(s2) < 0))) {
211 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "shift right with negative count");
212 return false;
213 } else {
214 if (mpz_sgn(s1) >= 0) {
215 mpz_set_ui(r, 0);
216 return true;
217 } else {
218 mpz_set_si(r, -1);
219 return true;
223 sh = mpz_get_ui(s2);
224 if (unlikely((mp_bitcnt_t)sh != sh))
225 goto overflow;
226 mpz_fdiv_q_2exp(r, s1, sh);
227 return true;
230 static inline bool attr_fastcall mpint_btx_(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err, void (*fn_bit)(mpint_t *, mp_bitcnt_t), int mode)
232 unsigned long sh;
233 if (unlikely(!mpz_fits_ulong_p(s2))) {
234 if (unlikely(mpz_sgn(s2) < 0)) {
235 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit %s with negative position", mode == 0 ? "set" : mode == 1 ? "reset" : "complement");
236 return false;
238 overflow:
239 if (mode == 0 && mpz_sgn(s1) < 0) {
240 mpz_set(r, s1);
241 return true;
243 if (mode == 1 && mpz_sgn(s1) >= 0) {
244 mpz_set(r, s1);
245 return true;
247 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INT_TOO_LARGE), err, "integer too large");
248 return false;
250 sh = mpz_get_ui(s2);
251 if (unlikely(sh != (unsigned long)(mp_bitcnt_t)sh))
252 goto overflow;
253 mpz_set(r, s1);
254 fn_bit(r, sh);
255 return true;
258 bool attr_fastcall mpint_bts(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
260 return mpint_btx_(s1, s2, r, err, mpz_setbit, 0);
263 bool attr_fastcall mpint_btr(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
265 return mpint_btx_(s1, s2, r, err, mpz_clrbit, 1);
268 bool attr_fastcall mpint_btc(const mpint_t *s1, const mpint_t *s2, mpint_t *r, ajla_error_t *err)
270 #if !defined(mpz_combit) || defined(UNUSUAL_ARITHMETICS)
271 ajla_flat_option_t o;
272 if (unlikely(!mpint_bt(s1, s2, &o, err)))
273 return false;
274 if (!o)
275 return mpint_bts(s1, s2, r, err);
276 else
277 return mpint_btr(s1, s2, r, err);
278 #else
279 return mpint_btx_(s1, s2, r, err, mpz_combit, 2);
280 #endif
283 bool attr_fastcall mpint_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
285 *r = !mpz_cmp(s1, s2);
286 return true;
289 bool attr_fastcall mpint_not_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
291 *r = !!mpz_cmp(s1, s2);
292 return true;
295 bool attr_fastcall mpint_less(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
297 *r = mpz_cmp(s1, s2) < 0;
298 return true;
301 bool attr_fastcall mpint_less_equal(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t attr_unused *err)
303 *r = mpz_cmp(s2, s1) >= 0;
304 return true;
307 bool attr_fastcall mpint_bt(const mpint_t *s1, const mpint_t *s2, ajla_flat_option_t *r, ajla_error_t *err)
309 unsigned long sh;
310 if (unlikely(!mpz_fits_ulong_p(s2))) {
311 if (unlikely(mpz_sgn(s2) < 0)) {
312 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit test with negative position");
313 return false;
315 overflow:
316 *r = mpz_sgn(s1) < 0;
317 return true;
319 sh = mpz_get_ui(s2);
320 if (unlikely(sh != (unsigned long)(mp_bitcnt_t)sh))
321 goto overflow;
322 *r = mpz_tstbit(s1, sh);
323 return true;
326 bool attr_fastcall mpint_not(const mpint_t *s, mpint_t *r, ajla_error_t attr_unused *err)
328 mpz_com(r, s);
329 return true;
332 bool attr_fastcall mpint_neg(const mpint_t *s, mpint_t *r, ajla_error_t *err)
334 mpz_neg(r, s);
335 if (unlikely(!mpint_size_ok(r, err))) {
336 return false;
338 return true;
341 bool attr_fastcall mpint_inc(const mpint_t *s, mpint_t *r, ajla_error_t *err)
343 mpz_add_ui(r, s, 1);
344 if (unlikely(!mpint_size_ok(r, err))) {
345 return false;
347 return true;
350 bool attr_fastcall mpint_dec(const mpint_t *s, mpint_t *r, ajla_error_t *err)
352 mpz_sub_ui(r, s, 1);
353 if (unlikely(!mpint_size_ok(r, err))) {
354 return false;
356 return true;
359 bool attr_fastcall mpint_bsf(const mpint_t *s, mpint_t *r, ajla_error_t *err)
361 mp_bitcnt_t b;
362 if (unlikely(!mpz_sgn(s))) {
363 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit scan forward with zero argument");
364 return false;
366 b = mpz_scan1(s, 0);
367 #ifndef UNUSUAL_ARITHMETICS
368 if (likely(b == (unsigned long)b))
369 mpz_set_ui(r, b);
370 else
371 #endif
372 mpz_import(r, 1, 1, sizeof(b), 0, 0, &b);
373 return true;
376 bool attr_fastcall mpint_bsr(const mpint_t *s, mpint_t *r, ajla_error_t *err)
378 size_t sz;
379 if (unlikely(mpz_sgn(s) <= 0)) {
380 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "bit scan reverse with non-positive argument");
381 return false;
383 sz = mpz_sizeinbase(s, 2) - 1;
384 #ifndef UNUSUAL_ARITHMETICS
385 if (likely(sz == (unsigned long)sz))
386 mpz_set_ui(r, sz);
387 else
388 #endif
389 mpz_import(r, 1, 1, sizeof(sz), 0, 0, &sz);
390 return true;
393 bool attr_fastcall mpint_popcnt(const mpint_t *s, mpint_t *r, ajla_error_t *err)
395 mp_bitcnt_t b;
396 if (unlikely(mpz_sgn(s) < 0)) {
397 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INVALID_OPERATION), err, "population count with negative argument");
398 return false;
400 b = mpz_popcount(s);
401 #ifndef UNUSUAL_ARITHMETICS
402 if (likely(b == (unsigned long)b))
403 mpz_set_ui(r, b);
404 else
405 #endif
406 mpz_import(r, 1, 1, sizeof(b), 0, 0, &b);
407 return true;
410 static inline bool mpint_raw_test_bit(const mp_limb_t *ptr, size_t bit)
412 return (ptr[bit / GMP_NUMB_BITS] >> (bit % GMP_NUMB_BITS)) & 1;
/*
 * mpint_conv_real(n, type, ntype, pack, unpack) generates, for one real type,
 * a pair of conversion functions.  It is instantiated for every real type by
 * for_all_real below.  The `n` argument is not used by this macro
 * (presumably for_all_real passes it for other users; TODO confirm).
 *
 * bool mpint_init_from_<type>(mpint_t *t, type *valp, ajla_error_t *err)
 *	Initializes t from *valp (converted to ntype with unpack), discarding
 *	the fractional part, i.e. truncating toward zero.  A value rejected by
 *	isfinite_<ntype> fails with AJLA_ERROR_INFINITY.  frexp splits |val|
 *	into norm in [0.5, 1) and the binary exponent ex, which determines how
 *	many limbs to write.  The limbs are then produced from the most
 *	significant one downwards, and the sign is applied at the end with
 *	mpint_neg.  If that negation fails, t is freed.  Values below 1 in
 *	magnitude leave t as allocated by mpint_alloc_mayfail (presumably 0).
 *
 * void mpint_export_to_<type>(const mpint_t *s, type *result)
 *	Converts s to the real type, rounding to nearest with ties to even.
 *	Only the top bits_<ntype> bits of |s| (presumably the mantissa width)
 *	are summed, limb by limb from the top.  2^last_bit is added when the
 *	first discarded bit is set and either the lowest kept bit or some lower
 *	discarded bit is set.  The sign is applied last and the value is stored
 *	through pack.  Huge values overflow to infinity via ldexp.
 */
415 #define mpint_conv_real(n, type, ntype, pack, unpack) \
416 bool attr_fastcall cat(mpint_init_from_,type)(mpint_t *t, type *valp, ajla_error_t *err)\
418 ntype norm, mult, val = unpack(*valp); \
419 int ex, shift; \
420 size_t limbs, idx; \
421 mp_limb_t *ptr; \
422 bool neg = unlikely(val < (ntype)0); \
423 val = cat(mathfunc_,ntype)(fabs)(val); \
424 if (unlikely(!cat(isfinite_, ntype)(val))) { \
425 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INFINITY), err, "attempting to convert infinity to integer");\
426 return false; \
428 norm = cat(mathfunc_,ntype)(frexp)(val, &ex); \
429 if (unlikely(!mpint_alloc_mayfail(t, likely(ex >= 0) ? ex : 0, err)))\
430 return false; \
/* |val| < 1 truncates to zero: nothing to write */ \
431 if (unlikely(ex <= 0)) \
432 goto skip; \
/* idx: index of the most significant limb; shift: number of value bits in that limb */ \
433 idx = (unsigned)(ex - 1) / GMP_NUMB_BITS; \
434 limbs = idx + 1; \
435 ptr = mpz_limbs_write(t, limbs); \
436 shift = ((unsigned)(ex - 1) % GMP_NUMB_BITS) + 1; \
/* mult = 2^shift, built in two steps so the shift count stays below GMP_NUMB_BITS */ \
437 mult = (ntype)((mp_limb_t)1 << (shift - 1)); \
438 mult += mult; \
439 norm *= mult; \
/* each pass stores the integer part of norm as one limb and scales the remainder up by 2^GMP_NUMB_BITS; the fraction left after limb 0 is dropped */ \
440 while (1) { \
441 mp_limb_t limb = (mp_limb_t)norm; \
442 norm -= (ntype)limb; \
443 ptr[idx] = limb; \
444 if (!idx) \
445 break; \
/* no bits left: the remaining lower limbs are zero */ \
446 if (!norm) { \
447 memset(ptr, 0, idx * sizeof(mp_limb_t)); \
448 break; \
450 idx--; \
451 norm = norm * ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.);\
453 mpz_limbs_finish(t, limbs); \
454 if (unlikely(neg)) { \
455 if (unlikely(!mpint_neg(t, t, err))) \
456 goto fail_free; \
458 skip: \
459 return true; \
460 fail_free: \
461 mpint_free(t); \
462 return false; \
464 void attr_fastcall cat(mpint_export_to_,type)(const mpint_t *s, type *result)\
466 size_t sz, last_bit, idx, base_pos; \
467 const mp_limb_t *limbs; \
468 ntype mult, r = 0; \
470 if (unlikely(!mpz_sgn(s))) \
471 goto skip; \
473 limbs = mpz_limbs_read(s); \
475 sz = mpz_sizeinbase(s, 2); \
/* keep the top bits_<ntype> bits of |s|; last_bit is the lowest kept bit */ \
476 last_bit = sz < cat(bits_,ntype) ? 0 : sz - cat(bits_,ntype); \
478 idx = (sz - 1) / GMP_NUMB_BITS; \
479 base_pos = idx * GMP_NUMB_BITS; \
/* mult = 2^base_pos, the weight of limb idx */ \
480 mult = cat(mathfunc_,ntype)(ldexp)(1., base_pos); \
482 while (1) { \
483 ntype l; \
484 mp_limb_t limb = limbs[idx]; \
485 mp_limb_t mask = (mp_limb_t)-1; \
/* in the limb that contains last_bit, mask off the bits below it */ \
487 mask <<= base_pos <= last_bit ? last_bit - base_pos : 0;\
489 l = (ntype)(limb & mask); \
/* skip zero limbs: mult may already be infinite for huge s, and inf * 0 would be NaN */ \
491 if (l) \
492 r += mult * l; \
494 if (base_pos <= last_bit) \
495 break; \
497 sz = base_pos; \
498 idx = (sz - 1) / GMP_NUMB_BITS; \
499 base_pos = idx * GMP_NUMB_BITS; \
/* step the weight down by one limb: mult /= 2^GMP_NUMB_BITS */ \
501 mult = mult * (ntype)(1. / ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.));\
/* round to nearest, ties to even: the first discarded bit is set, and either the kept value is odd or a lower discarded bit is set */ \
504 if (last_bit >= 1 && mpint_raw_test_bit(limbs, last_bit - 1)) { \
505 if (mpint_raw_test_bit(limbs, last_bit) || mpn_scan1(limbs, 0) != last_bit - 1) {\
506 r += cat(mathfunc_,ntype)(ldexp)(1., last_bit); \
/* the magnitude was converted; restore the sign */ \
510 if (unlikely(mpz_sgn(s) < 0)) \
511 r = -r; \
513 skip: \
514 *result = pack(r); \
516 for_all_real(mpint_conv_real, for_all_empty)
517 #undef mpint_conv_real
520 bool mpint_export_to_blob(const mpint_t *s, uint8_t **blob, size_t *blob_size, ajla_error_t *err)
522 uint8_t *ptr;
523 size_t sz = (mpz_sizeinbase(s, 2) + 7) / 8 + 1;
524 size_t count;
526 ptr = mem_alloc_mayfail(uint8_t *, sz, err);
527 if (unlikely(!ptr))
528 return false;
530 mpz_export(ptr, &count, -1, 1, 0, 0, s);
531 if (count > sz)
532 internal(file_line, "mpint_export_to_blob: mpz_export ran over the end of allocated memory: %"PRIuMAX" > %"PRIuMAX"", (uintmax_t)count, (uintmax_t)sz);
533 while (count < sz) {
534 ptr[count++] = 0;
537 if (unlikely(mpz_sgn(s) < 0)) {
538 size_t i;
539 bool do_not = false;
540 for (i = 0; i < sz; i++) {
541 if (!do_not) {
542 if (ptr[i] != 0) {
543 ptr[i] = -ptr[i];
544 do_not = true;
546 } else {
547 ptr[i] = ~ptr[i];
550 while (sz >= 2) {
551 if (ptr[sz - 1] == 0xff && ptr[sz - 2] >= 0x80)
552 sz--;
553 else
554 break;
556 } else {
557 while (sz >= 2) {
558 if (ptr[sz - 1] == 0x00 && ptr[sz - 2] < 0x80)
559 sz--;
560 else
561 break;
563 if (sz == 1 && !ptr[0])
564 sz = 0;
567 *blob = ptr;
568 *blob_size = sz;
570 return true;
574 static void *gmp_alloc(size_t size)
576 return mem_alloc(void *, size);
579 static void *gmp_realloc(void *ptr, size_t attr_unused old_size, size_t new_size)
581 return mem_realloc(void *, ptr, new_size);
584 static void gmp_free(void *ptr, size_t attr_unused size)
586 mem_free(ptr);
590 void mpint_init(void)
592 if (!dll)
593 mp_set_memory_functions(gmp_alloc, gmp_realloc, gmp_free);
596 void mpint_done(void)