2 * Copyright (C) 2024 Mikulas Patocka
4 * This file is part of Ajla.
6 * Ajla is free software: you can redistribute it and/or modify it under the
7 * terms of the GNU General Public License as published by the Free Software
8 * Foundation, either version 3 of the License, or (at your option) any later
11 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License along with
16 * Ajla. If not, see <https://www.gnu.org/licenses/>.
/*
 * Size limit for integers, in bits (0x80000000 = 2^31). mpint_size_ok_slow
 * rejects values whose magnitude needs MPINT_MAX_BITS or more bits, with a
 * special case for -2^(MPINT_MAX_BITS-1); mpint_size_ok and
 * mpint_multiply_early_check derive limb-count bounds from it.
 */
25 #define MPINT_MAX_BITS 0x80000000UL
/*
 * When built against GMP with major version below 5 (MPINT_GMP defined),
 * define mp_bitcnt_t locally as unsigned long.
 * NOTE(review): presumably because older gmp.h does not declare this
 * type — confirm.
 */
28 #if defined(MPINT_GMP) && __GNU_MP_VERSION+0 < 5
29 typedef unsigned long mp_bitcnt_t
;
/*
 * Fallback for mpz_limbs_read when it is not already defined: expose the
 * limb array pointer (_mp_d) of t directly.
 */
32 #ifndef mpz_limbs_read
33 #define mpz_limbs_read(t) ((t)->_mp_d)
/*
 * Fallback for mpz_limbs_write when it is not already defined. It never
 * grows the allocation: if fewer than idx limbs are allocated
 * (_mp_alloc < idx) it reports an internal error, and it then yields the
 * _mp_d limb pointer.
 * NOTE(review): GMP's own mpz_limbs_write reallocates as needed, so with
 * this fallback callers must allocate enough limbs beforehand (the
 * real-to-integer conversion calls mpint_alloc_mayfail before it) —
 * confirm all callers do so.
 */
36 #ifndef mpz_limbs_write
37 #define mpz_limbs_write(t, idx) (((size_t)(t)->_mp_alloc < (idx) ? internal(file_line, "mpz_limbs_write: not enough entries: %"PRIuMAX" < %"PRIuMAX"", (uintmax_t)(t)->_mp_alloc, (uintmax_t)(idx)), 0 : 0), (t)->_mp_d)
/*
 * Fallback for mpz_limbs_finish when it is not already defined: store idx
 * directly into the size field (_mp_size).
 * NOTE(review): GMP's mpz_limbs_finish also strips high zero limbs; this
 * fallback does not, so callers must pass an already-normalized count.
 */
40 #ifndef mpz_limbs_finish
41 #define mpz_limbs_finish(t, idx) ((t)->_mp_size = idx)
/*
 * Slow path of the result-size check. It is used when the limb count alone
 * cannot prove that t is small enough, and is marked cold and noinline.
 *
 * It computes the exact bit length of |t| with mpz_sizeinbase(t, 2). A
 * value whose bit length reaches MPINT_MAX_BITS is reported as
 * AJLA_ERROR_INT_TOO_LARGE ("integer too large") through err, apart from
 * the special case below.
 */
44 static attr_noinline
bool attr_fastcall attr_cold
mpint_size_ok_slow(const mpint_t
*t
, ajla_error_t
*err
)
46 size_t sz
= mpz_sizeinbase(t
, 2);
47 if (sz
>= MPINT_MAX_BITS
) {
/*
 * Special case: t negative, with |t| of bit length exactly MPINT_MAX_BITS
 * and lowest set bit MPINT_MAX_BITS - 1. That is exactly
 * -2^(MPINT_MAX_BITS-1), the most negative value of a two's-complement
 * MPINT_MAX_BITS-bit integer. mpz_scan1 works on the two's-complement
 * form, whose lowest set bit equals that of |t|.
 * Presumably this value is accepted — confirm.
 */
48 if (sz
== MPINT_MAX_BITS
&& mpz_sgn(t
) < 0 &&
49 mpz_scan1(t
, 0) == MPINT_MAX_BITS
- 1)
51 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
/*
 * Check that t is within the MPINT_MAX_BITS limit.
 *
 * Fast path: a value with at most (MPINT_MAX_BITS - 1) / GMP_NUMB_BITS
 * limbs has at most MPINT_MAX_BITS - 1 bits, so no exact bit count is
 * needed (presumably the likely branch returns true — confirm). Otherwise
 * the check defers to mpint_size_ok_slow, which reports
 * AJLA_ERROR_INT_TOO_LARGE through err.
 */
57 static inline bool mpint_size_ok(const mpint_t attr_unused
*t
, ajla_error_t
*err
)
59 size_t size
= mpz_size(t
);
60 if (likely(size
<= (MPINT_MAX_BITS
- 1) / GMP_NUMB_BITS
))
62 return mpint_size_ok_slow(t
, err
);
/*
 * Addition r = s1 + s2 (presumably via mpz_add — confirm).
 * The result is validated with mpint_size_ok; a result exceeding the
 * MPINT_MAX_BITS limit is reported as AJLA_ERROR_INT_TOO_LARGE through err.
 */
65 bool attr_fastcall
mpint_add(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
68 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Subtraction r = s1 - s2 (presumably via mpz_sub — confirm).
 * The result is validated with mpint_size_ok; a result exceeding the
 * MPINT_MAX_BITS limit is reported as AJLA_ERROR_INT_TOO_LARGE through err.
 */
74 bool attr_fastcall
mpint_subtract(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
77 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Cheap pre-check before a multiplication, using only operand limb counts.
 *
 * The product of a non-zero size1-limb value and a non-zero size2-limb
 * value has at least size1 + size2 - 1 limbs. If that exceeds the number
 * of limbs needed for MPINT_MAX_BITS bits (rounded up), the product is
 * certainly too large. AJLA_ERROR_INT_TOO_LARGE ("integer too large") is
 * then reported through err without doing the multiplication.
 *
 * Passing this check does not guarantee that the product fits; the callers
 * mpint_multiply and mpint_shl still run mpint_size_ok on the result.
 */
83 static inline bool mpint_multiply_early_check(size_t size1
, size_t size2
, ajla_error_t
*err
)
85 if (unlikely(size1
+ size2
> 1 + (MPINT_MAX_BITS
+ GMP_NUMB_BITS
- 1) / GMP_NUMB_BITS
)) {
86 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
/*
 * Multiplication r = s1 * s2 (presumably via mpz_mul — confirm).
 *
 * mpint_multiply_early_check first rejects operands whose product is
 * certainly too large, judging from their limb counts alone. The actual
 * product is then validated with mpint_size_ok. Oversized results are
 * reported as AJLA_ERROR_INT_TOO_LARGE through err.
 *
 * mpint_power calls this function with r aliasing an operand.
 */
92 bool attr_fastcall
mpint_multiply(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
94 if (unlikely(!mpint_multiply_early_check(mpz_size(s1
), mpz_size(s2
), err
))) {
98 if (unlikely(!mpint_size_ok(r
, err
)))
/*
 * Truncating division: r = s1 / s2 rounded toward zero (mpz_tdiv_q).
 *
 * Division by zero is reported as AJLA_ERROR_INVALID_OPERATION ("divide by
 * zero") through err.
 */
103 bool attr_fastcall
mpint_divide(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
105 if (unlikely(!mpz_sgn(s2
))) {
106 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "divide by zero");
109 mpz_tdiv_q(r
, s1
, s2
);
/*
 * The quotient can exceed the limit: -2^(MPINT_MAX_BITS-1) / -1 gives
 * 2^(MPINT_MAX_BITS-1), which mpint_size_ok rejects.
 */
110 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Truncating remainder: r = s1 - s2 * trunc(s1 / s2) (mpz_tdiv_r), so a
 * non-zero result has the sign of s1.
 *
 * Modulo by zero is reported as AJLA_ERROR_INVALID_OPERATION ("modulo by
 * zero") through err. Because |r| < |s2|, the remainder cannot exceed the
 * size limit of its operands.
 */
116 bool attr_fastcall
mpint_modulo(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
118 if (unlikely(!mpz_sgn(s2
))) {
119 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "modulo by zero");
122 mpz_tdiv_r(r
, s1
, s2
);
/*
 * Exponentiation r = s1 ** s2 by repeated squaring, working on local
 * copies x1 (base) and x2 (exponent).
 *
 * A negative exponent is reported as AJLA_ERROR_INVALID_OPERATION ("power
 * by negative number") through err. Every multiplication goes through
 * mpint_multiply, so the MPINT_MAX_BITS limit is enforced on each
 * intermediate value.
 *
 * NOTE(review): x1 is squared right after x2 is halved. If that squaring
 * also runs once x2 has become zero, a result that fits (e.g. 2 ** 2^30)
 * could be rejected because the unused final square is too large —
 * confirm the loop exits first.
 * NOTE(review): confirm x1 and x2 are released with mpz_clear on every
 * return path, including the error paths.
 */
126 bool attr_fastcall
mpint_power(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
129 mpz_init_set(&x1
, s1
);
130 mpz_init_set(&x2
, s2
);
131 if (unlikely(mpz_sgn(&x2
) < 0)) {
132 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "power by negative number");
/* Low exponent bit set: multiply the accumulated result by the current base power. */
140 if (mpz_tstbit(&x2
, 0)) {
141 if (unlikely(!mpint_multiply(r
, &x1
, r
, err
)))
/* Drop the processed exponent bit and square the base for the next bit. */
146 mpz_tdiv_q_2exp(&x2
, &x2
, 1);
147 if (unlikely(!mpint_multiply(&x1
, &x1
, &x1
, err
)))
/*
 * Bitwise AND of s1 and s2 into r. err is marked attr_unused, so this
 * operation is not expected to fail (presumably a direct mpz_and call,
 * which GMP defines with two's-complement semantics for negative values —
 * confirm).
 */
155 bool attr_fastcall
mpint_and(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * Bitwise OR of s1 and s2 into r. err is marked attr_unused, so this
 * operation is not expected to fail (presumably a direct mpz_ior call,
 * which GMP defines with two's-complement semantics for negative values —
 * confirm).
 */
161 bool attr_fastcall
mpint_or(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * Bitwise XOR of s1 and s2 into r. err is marked attr_unused, so this
 * operation is not expected to fail (presumably a direct mpz_xor call,
 * which GMP defines with two's-complement semantics for negative values —
 * confirm).
 */
167 bool attr_fastcall
mpint_xor(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * Left shift r = s1 << s2.
 *
 * If s2 does not fit in unsigned long, a negative count is reported as
 * AJLA_ERROR_INVALID_OPERATION ("shift left with negative count"). A huge
 * positive count is reported as AJLA_ERROR_INT_TOO_LARGE; presumably
 * s1 == 0 is special-cased before that — confirm.
 *
 * Otherwise the count sh must also be representable as mp_bitcnt_t.
 * Shifting by sh multiplies by 2^sh, a value of 1 + sh / GMP_NUMB_BITS
 * limbs. mpint_multiply_early_check can therefore reject hopeless shifts
 * before mpz_mul_2exp builds the result, which is then checked with
 * mpint_size_ok.
 *
 * NOTE(review): for s1 == 0 the early check still counts
 * 1 + sh / GMP_NUMB_BITS limbs. A large (but unsigned-long-sized) shift of
 * zero would be reported as too large unless it is handled earlier —
 * confirm.
 */
173 bool attr_fastcall
mpint_shl(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
177 if (unlikely(!mpz_fits_ulong_p(s2
))) {
179 if (unlikely((mpz_sgn(s2
) < 0))) {
180 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "shift left with negative count");
187 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
193 if (unlikely((mp_bitcnt_t
)sh
!= sh
))
195 size1
= mpz_size(s1
);
/* 2^sh occupies 1 + sh / GMP_NUMB_BITS limbs. */
196 size2
= 1 + sh
/ GMP_NUMB_BITS
;
197 if (unlikely(!mpint_multiply_early_check(size1
, size2
, err
)))
199 mpz_mul_2exp(r
, s1
, sh
);
200 if (unlikely(!mpint_size_ok(r
, err
)))
/*
 * Arithmetic right shift r = s1 >> s2, implemented with mpz_fdiv_q_2exp.
 * That function rounds toward minus infinity, which matches a
 * two's-complement shift for negative s1 (e.g. -1 >> n stays -1).
 *
 * A negative count is reported as AJLA_ERROR_INVALID_OPERATION ("shift
 * right with negative count") through err. A count too large for unsigned
 * long is handled according to the sign of s1; presumably the result is 0
 * for s1 >= 0 and -1 otherwise — confirm. The count must also be
 * representable as mp_bitcnt_t.
 */
205 bool attr_fastcall
mpint_shr(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
208 if (unlikely(!mpz_fits_ulong_p(s2
))) {
210 if (unlikely((mpz_sgn(s2
) < 0))) {
211 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "shift right with negative count");
214 if (mpz_sgn(s1
) >= 0) {
224 if (unlikely((mp_bitcnt_t
)sh
!= sh
))
226 mpz_fdiv_q_2exp(r
, s1
, sh
);
/*
 * Common implementation of mpint_bts, mpint_btr and mpint_btc. It produces
 * r = s1 with bit s2 set (mode 0, fn_bit = mpz_setbit), reset (mode 1,
 * fn_bit = mpz_clrbit) or complemented (mode 2, fn_bit = mpz_combit).
 * mode picks the error text and the out-of-range shortcuts below.
 *
 * When the position does not fit in unsigned long:
 * - A negative position is reported as AJLA_ERROR_INVALID_OPERATION.
 * - In two's complement, a negative s1 already has every high bit set and
 *   a non-negative s1 has every high bit clear. Setting a far bit of a
 *   negative value, or resetting one of a non-negative value, therefore
 *   changes nothing (presumably r = s1 in those branches — confirm).
 * - Any other case would need a value far beyond MPINT_MAX_BITS and is
 *   reported as AJLA_ERROR_INT_TOO_LARGE.
 *
 * The position must also be representable as mp_bitcnt_t.
 */
230 static inline bool attr_fastcall
mpint_btx_(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
, void (*fn_bit
)(mpint_t
*, mp_bitcnt_t
), int mode
)
233 if (unlikely(!mpz_fits_ulong_p(s2
))) {
234 if (unlikely(mpz_sgn(s2
) < 0)) {
235 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit %s with negative position", mode
== 0 ? "set" : mode
== 1 ? "reset" : "complement");
239 if (mode
== 0 && mpz_sgn(s1
) < 0) {
243 if (mode
== 1 && mpz_sgn(s1
) >= 0) {
247 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
251 if (unlikely(sh
!= (unsigned long)(mp_bitcnt_t
)sh
))
/* Bit set: r = s1 with bit s2 set, via mpint_btx_ with mpz_setbit (mode 0). */
258 bool attr_fastcall
mpint_bts(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
260 return mpint_btx_(s1
, s2
, r
, err
, mpz_setbit
, 0);
/* Bit reset: r = s1 with bit s2 cleared, via mpint_btx_ with mpz_clrbit (mode 1). */
263 bool attr_fastcall
mpint_btr(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
265 return mpint_btx_(s1
, s2
, r
, err
, mpz_clrbit
, 1);
/*
 * Bit complement: r = s1 with bit s2 flipped.
 *
 * Without mpz_combit, or with UNUSUAL_ARITHMETICS, this is emulated: the
 * bit is tested with mpint_bt and then set with mpint_bts or reset with
 * mpint_btr (presumably set when currently clear, and vice versa —
 * confirm). Otherwise mpint_btx_ is used with mpz_combit (mode 2).
 */
268 bool attr_fastcall
mpint_btc(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
270 #if !defined(mpz_combit) || defined(UNUSUAL_ARITHMETICS)
271 ajla_flat_option_t o
;
272 if (unlikely(!mpint_bt(s1
, s2
, &o
, err
)))
275 return mpint_bts(s1
, s2
, r
, err
);
277 return mpint_btr(s1
, s2
, r
, err
);
279 return mpint_btx_(s1
, s2
, r
, err
, mpz_combit
, 2);
/* *r = 1 if s1 == s2, else 0 (mpz_cmp returns 0 on equality). err is unused. */
283 bool attr_fastcall
mpint_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
285 *r
= !mpz_cmp(s1
, s2
);
/* *r = 1 if s1 != s2, else 0; !! collapses mpz_cmp's sign-valued result to 0/1. err is unused. */
289 bool attr_fastcall
mpint_not_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
291 *r
= !!mpz_cmp(s1
, s2
);
/* *r = 1 if s1 < s2, else 0. err is unused. */
295 bool attr_fastcall
mpint_less(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
297 *r
= mpz_cmp(s1
, s2
) < 0;
/* *r = 1 if s1 <= s2, else 0. err is unused. */
301 bool attr_fastcall
mpint_less_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
303 *r
= mpz_cmp(s1
, s2
) <= 0;
/* *r = 1 if s1 > s2, else 0. err is unused. */
307 bool attr_fastcall
mpint_greater(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
309 *r
= mpz_cmp(s1
, s2
) > 0;
/* *r = 1 if s1 >= s2, else 0. err is unused. */
313 bool attr_fastcall
mpint_greater_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
315 *r
= mpz_cmp(s1
, s2
) >= 0;
/*
 * Bit test: *r = bit s2 of s1, using two's-complement semantics for
 * negative s1, as mpz_tstbit does.
 *
 * A negative position is reported as AJLA_ERROR_INVALID_OPERATION ("bit
 * test with negative position") through err. A position too large for
 * unsigned long lies above every significant bit, where a two's-complement
 * value holds copies of its sign bit, so *r is 1 exactly when s1 is
 * negative. The position must also be representable as mp_bitcnt_t.
 */
319 bool attr_fastcall
mpint_bt(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t
*err
)
322 if (unlikely(!mpz_fits_ulong_p(s2
))) {
323 if (unlikely(mpz_sgn(s2
) < 0)) {
324 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit test with negative position");
328 *r
= mpz_sgn(s1
) < 0;
332 if (unlikely(sh
!= (unsigned long)(mp_bitcnt_t
)sh
))
334 *r
= mpz_tstbit(s1
, sh
);
/*
 * Bitwise NOT of s into r. err is marked attr_unused, so this is not
 * expected to fail. Presumably r = -s - 1 (mpz_com) — confirm. That
 * mapping keeps values inside the signed MPINT_MAX_BITS-bit range.
 */
338 bool attr_fastcall
mpint_not(const mpint_t
*s
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * Negation r = -s (presumably via mpz_neg — confirm).
 *
 * The result is size-checked because negating -2^(MPINT_MAX_BITS-1) gives
 * 2^(MPINT_MAX_BITS-1), which mpint_size_ok rejects. The real-to-integer
 * conversion calls this in place (t, t).
 */
344 bool attr_fastcall
mpint_neg(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
347 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Increment r = s + 1 (presumably via mpz_add_ui — confirm). The result is
 * size-checked because incrementing the largest permitted value exceeds
 * the MPINT_MAX_BITS limit.
 */
353 bool attr_fastcall
mpint_inc(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
356 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Decrement r = s - 1 (presumably via mpz_sub_ui — confirm). The result is
 * size-checked because decrementing -2^(MPINT_MAX_BITS-1) exceeds the
 * MPINT_MAX_BITS limit.
 */
362 bool attr_fastcall
mpint_dec(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
365 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Bit scan forward: r = index of the lowest set bit of s. The index is
 * presumably obtained from mpz_scan1 (confirm); for negative s the
 * two's-complement lowest set bit equals that of |s|.
 *
 * Zero is reported as AJLA_ERROR_INVALID_OPERATION ("bit scan forward with
 * zero argument") through err.
 *
 * The index b is stored directly when it fits in unsigned long. Otherwise,
 * or with UNUSUAL_ARITHMETICS, it is converted with mpz_import as one word
 * of sizeof(b) bytes (order 1, endian 0 = native, no nails).
 */
371 bool attr_fastcall
mpint_bsf(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
374 if (unlikely(!mpz_sgn(s
))) {
375 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit scan forward with zero argument");
379 #ifndef UNUSUAL_ARITHMETICS
380 if (likely(b
== (unsigned long)b
))
384 mpz_import(r
, 1, 1, sizeof(b
), 0, 0, &b
);
/*
 * Bit scan reverse: r = index of the highest set bit of s, i.e.
 * mpz_sizeinbase(s, 2) - 1 (exact for base 2).
 *
 * Zero and negative arguments are reported as AJLA_ERROR_INVALID_OPERATION
 * ("bit scan reverse with non-positive argument") through err.
 *
 * The index sz is stored directly when it fits in unsigned long. Otherwise,
 * or with UNUSUAL_ARITHMETICS, it is converted with mpz_import as one word
 * of sizeof(sz) bytes (order 1, endian 0 = native, no nails).
 */
388 bool attr_fastcall
mpint_bsr(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
391 if (unlikely(mpz_sgn(s
) <= 0)) {
392 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit scan reverse with non-positive argument");
395 sz
= mpz_sizeinbase(s
, 2) - 1;
396 #ifndef UNUSUAL_ARITHMETICS
397 if (likely(sz
== (unsigned long)sz
))
401 mpz_import(r
, 1, 1, sizeof(sz
), 0, 0, &sz
);
/*
 * Population count: r = number of set bits of s (presumably from
 * mpz_popcount — confirm).
 *
 * Negative arguments are reported as AJLA_ERROR_INVALID_OPERATION
 * ("population count with negative argument"). In two's complement a
 * negative value has infinitely many ones; for such input GMP's
 * mpz_popcount returns the largest mp_bitcnt_t instead of a count.
 *
 * The count b is stored directly when it fits in unsigned long. Otherwise,
 * or with UNUSUAL_ARITHMETICS, it is converted with mpz_import as one word
 * of sizeof(b) bytes (order 1, endian 0 = native, no nails).
 */
405 bool attr_fastcall
mpint_popcnt(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
408 if (unlikely(mpz_sgn(s
) < 0)) {
409 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "population count with negative argument");
413 #ifndef UNUSUAL_ARITHMETICS
414 if (likely(b
== (unsigned long)b
))
418 mpz_import(r
, 1, 1, sizeof(b
), 0, 0, &b
);
/*
 * Return bit number `bit` of a raw limb array in which limb 0 holds the
 * least significant GMP_NUMB_BITS bits, as GMP stores magnitudes. This is
 * magnitude-only, not two's complement.
 *
 * There is no bounds check: the caller must ensure that
 * bit < (number of limbs) * GMP_NUMB_BITS.
 */
422 static inline bool mpint_raw_test_bit(const mp_limb_t
*ptr
, size_t bit
)
424 return (ptr
[bit
/ GMP_NUMB_BITS
] >> (bit
% GMP_NUMB_BITS
)) & 1;
/*
 * Generator for conversions between mpint_t and a real type `type`.
 * Arithmetic is done in `ntype`, and pack/unpack convert between the two
 * (presumably — confirm against for_all_real).
 *
 * cat(mpint_init_from_,type):
 * - Takes |val| and rejects non-finite input with AJLA_ERROR_INFINITY.
 *   NaN presumably also fails isfinite_ and gets the same "infinity"
 *   message.
 * - frexp yields val = norm * 2^ex with norm in [0.5, 1) for non-zero val.
 *   When ex <= 0, |val| < 1 and there are no integer bits.
 * - Otherwise the value spans ex bits. The top limb (index
 *   (ex - 1) / GMP_NUMB_BITS) holds the leading `shift` bits. Following
 *   limbs come from repeatedly scaling the remaining fraction by
 *   2^GMP_NUMB_BITS. Remaining low limbs are zero-filled with memset.
 * - A negative input is negated at the end with mpint_neg.
 *
 * cat(mpint_export_to_,type):
 * - Walks the limbs from the most significant one, keeping the top
 *   cat(bits_,ntype) bits; bits below last_bit are masked off.
 * - Rounds to nearest, ties to even. If the bit just below last_bit is
 *   set, it rounds up when bit last_bit is set or when a lower bit is also
 *   set (mpn_scan1 finds a set bit below last_bit - 1).
 */
427 #define mpint_conv_real(n, type, ntype, pack, unpack) \
428 bool attr_fastcall cat(mpint_init_from_,type)(mpint_t *t, type *valp, ajla_error_t *err)\
430 ntype norm, mult, val = unpack(*valp); \
434 bool neg = unlikely(val < (ntype)0); \
435 val = cat(mathfunc_,ntype)(fabs)(val); \
436 if (unlikely(!cat(isfinite_, ntype)(val))) { \
437 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INFINITY), err, "attempting to convert infinity to integer");\
440 norm = cat(mathfunc_,ntype)(frexp)(val, &ex); \
441 if (unlikely(!mpint_alloc_mayfail(t, likely(ex >= 0) ? ex : 0, err)))\
443 if (unlikely(ex <= 0)) \
445 idx = (unsigned)(ex - 1) / GMP_NUMB_BITS; \
447 ptr = mpz_limbs_write(t, limbs); \
448 shift = ((unsigned)(ex - 1) % GMP_NUMB_BITS) + 1; \
449 mult = (ntype)((mp_limb_t)1 << (shift - 1)); \
453 mp_limb_t limb = (mp_limb_t)norm; \
454 norm -= (ntype)limb; \
459 memset(ptr, 0, idx * sizeof(mp_limb_t)); \
463 norm = norm * ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.);\
465 mpz_limbs_finish(t, limbs); \
466 if (unlikely(neg)) { \
467 if (unlikely(!mpint_neg(t, t, err))) \
476 void attr_fastcall cat(mpint_export_to_,type)(const mpint_t *s, type *result)\
478 size_t sz, last_bit, idx, base_pos; \
479 const mp_limb_t *limbs; \
482 if (unlikely(!mpz_sgn(s))) \
485 limbs = mpz_limbs_read(s); \
487 sz = mpz_sizeinbase(s, 2); \
488 last_bit = sz < cat(bits_,ntype) ? 0 : sz - cat(bits_,ntype); \
490 idx = (sz - 1) / GMP_NUMB_BITS; \
491 base_pos = idx * GMP_NUMB_BITS; \
492 mult = cat(mathfunc_,ntype)(ldexp)(1., base_pos); \
496 mp_limb_t limb = limbs[idx]; \
497 mp_limb_t mask = (mp_limb_t)-1; \
499 mask <<= base_pos <= last_bit ? last_bit - base_pos : 0;\
501 l = (ntype)(limb & mask); \
506 if (base_pos <= last_bit) \
510 idx = (sz - 1) / GMP_NUMB_BITS; \
511 base_pos = idx * GMP_NUMB_BITS; \
513 mult = mult * (ntype)(1. / ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.));\
516 if (last_bit >= 1 && mpint_raw_test_bit(limbs, last_bit - 1)) { \
517 if (mpint_raw_test_bit(limbs, last_bit) || mpn_scan1(limbs, 0) != last_bit - 1) {\
518 r += cat(mathfunc_,ntype)(ldexp)(1., last_bit); \
522 if (unlikely(mpz_sgn(s) < 0)) \
528 for_all_real(mpint_conv_real
, for_all_empty
)
/*
 * The call above instantiates the mpint_conv_real conversion pair
 * (presumably once per supported real type — confirm for_all_real); the
 * generator macro is no longer needed afterwards.
 */
529 #undef mpint_conv_real
/*
 * Serialize s into a newly allocated byte array in little-endian
 * two's-complement form (presumably returned through *blob / *blob_size —
 * confirm).
 *
 * - sz starts at ceil(bit length / 8) + 1 bytes, leaving room for a sign
 *   byte. An allocation failure is reported through err
 *   (mem_alloc_mayfail).
 * - mpz_export writes the magnitude with order -1 and word size 1, i.e.
 *   least significant byte first. Exporting more than sz bytes is an
 *   internal error.
 * - For negative s the bytes are then converted to two's complement by the
 *   loop over all sz bytes.
 * - Redundant high sign bytes are trimmed: a top 0xff above a byte with its
 *   high bit set (negative), or a top 0x00 above a byte with its high bit
 *   clear (non-negative). A single remaining zero byte is special-cased,
 *   presumably so that zero becomes an empty blob — confirm.
 *
 * NOTE(review): the trimming reads ptr[sz - 2]. sz starts at 2 or more,
 * but confirm the trimming stops before sz drops below 2.
 */
532 bool mpint_export_to_blob(const mpint_t
*s
, uint8_t **blob
, size_t *blob_size
, ajla_error_t
*err
)
535 size_t sz
= (mpz_sizeinbase(s
, 2) + 7) / 8 + 1;
538 ptr
= mem_alloc_mayfail(uint8_t *, sz
, err
);
542 mpz_export(ptr
, &count
, -1, 1, 0, 0, s
);
544 internal(file_line
, "mpint_export_to_blob: mpz_export ran over the end of allocated memory: %"PRIuMAX
" > %"PRIuMAX
"", (uintmax_t)count
, (uintmax_t)sz
);
549 if (unlikely(mpz_sgn(s
) < 0)) {
552 for (i
= 0; i
< sz
; i
++) {
563 if (ptr
[sz
- 1] == 0xff && ptr
[sz
- 2] >= 0x80)
570 if (ptr
[sz
- 1] == 0x00 && ptr
[sz
- 2] < 0x80)
575 if (sz
== 1 && !ptr
[0])
/*
 * GMP allocation hook, backed by the project allocator mem_alloc.
 * GMP's custom allocation functions have no way to report failure to GMP,
 * so the non-"mayfail" allocator is used here (presumably it terminates on
 * exhaustion — confirm).
 */
586 static void *gmp_alloc(size_t size
)
588 return mem_alloc(void *, size
);
/*
 * GMP reallocation hook, backed by mem_realloc. old_size, which GMP
 * supplies, is unused because mem_realloc only needs the new size.
 */
591 static void *gmp_realloc(void *ptr
, size_t attr_unused old_size
, size_t new_size
)
593 return mem_realloc(void *, ptr
, new_size
);
/*
 * GMP free hook. The block size supplied by GMP is unused; presumably the
 * block is released with the project allocator's free function — confirm.
 */
596 static void gmp_free(void *ptr
, size_t attr_unused size
)
/*
 * Module initialization: route all GMP allocations through gmp_alloc,
 * gmp_realloc and gmp_free via mp_set_memory_functions.
 */
602 void mpint_init(void)
605 mp_set_memory_functions(gmp_alloc
, gmp_realloc
, gmp_free
);
608 void mpint_done(void)