2 * Copyright (C) 2024 Mikulas Patocka
4 * This file is part of Ajla.
6 * Ajla is free software: you can redistribute it and/or modify it under the
7 * terms of the GNU General Public License as published by the Free Software
8 * Foundation, either version 3 of the License, or (at your option) any later
11 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License along with
16 * Ajla. If not, see <https://www.gnu.org/licenses/>.
/*
 * Limit used by the size checks in this file: a value is examined more
 * closely once its bit length reaches MPINT_MAX_BITS.
 */
25 #define MPINT_MAX_BITS 0x80000000UL
/*
 * mp_bitcnt_t is declared locally when MPINT_GMP is defined and
 * __GNU_MP_VERSION is below 5.
 */
28 #if defined(MPINT_GMP) && __GNU_MP_VERSION+0 < 5
29 typedef unsigned long mp_bitcnt_t
;
/*
 * Fallbacks for the mpz_limbs_* accessors, defined only when no macro of
 * the same name already exists.  They work on the _mp_d, _mp_alloc and
 * _mp_size fields directly.  The mpz_limbs_write fallback reports an
 * internal error when _mp_alloc is smaller than idx and otherwise yields
 * the limb pointer.
 */
32 #ifndef mpz_limbs_read
33 #define mpz_limbs_read(t) ((t)->_mp_d)
36 #ifndef mpz_limbs_write
37 #define mpz_limbs_write(t, idx) (((size_t)(t)->_mp_alloc < (idx) ? internal(file_line, "mpz_limbs_write: not enough entries: %"PRIuMAX" < %"PRIuMAX"", (uintmax_t)(t)->_mp_alloc, (uintmax_t)(idx)), 0 : 0), (t)->_mp_d)
40 #ifndef mpz_limbs_finish
41 #define mpz_limbs_finish(t, idx) ((t)->_mp_size = idx)
/*
 * Out-of-line (noinline, cold) part of the size check.
 *
 * Measures t in bits with mpz_sizeinbase(t, 2).  When the bit length
 * reaches MPINT_MAX_BITS, one case is singled out: a negative value of
 * exactly MPINT_MAX_BITS bits whose lowest set bit is at position
 * MPINT_MAX_BITS - 1, i.e. -2^(MPINT_MAX_BITS - 1).
 * NOTE(review): presumably that value is still accepted - confirm.
 * The AJLA_ERROR_INT_TOO_LARGE error ("integer too large") is reported
 * through fatal_mayfail with the caller's err.
 */
44 static attr_noinline
bool attr_fastcall attr_cold
mpint_size_ok_slow(const mpint_t
*t
, ajla_error_t
*err
)
46 size_t sz
= mpz_sizeinbase(t
, 2);
47 if (sz
>= MPINT_MAX_BITS
) {
48 if (sz
== MPINT_MAX_BITS
&& mpz_sgn(t
) < 0 &&
49 mpz_scan1(t
, 0) == MPINT_MAX_BITS
- 1)
51 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
/*
 * Inline front end of the size check.
 *
 * Compares the limb count of t (mpz_size) with
 * (MPINT_MAX_BITS - 1) / GMP_NUMB_BITS; staying within that bound is the
 * expected (likely) case.  Larger values are handed to
 * mpint_size_ok_slow(t, err), whose result is returned.
 */
57 static inline bool mpint_size_ok(const mpint_t attr_unused
*t
, ajla_error_t
*err
)
59 size_t size
= mpz_size(t
);
60 if (likely(size
<= (MPINT_MAX_BITS
- 1) / GMP_NUMB_BITS
))
62 return mpint_size_ok_slow(t
, err
);
/*
 * Two-operand arithmetic entry point taking s1, s2, the result r and err.
 * The result r is validated with mpint_size_ok(r, err); a failed check is
 * the unlikely branch.
 * NOTE(review): presumably r receives s1 + s2 before the check - confirm.
 */
65 bool attr_fastcall
mpint_add(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
68 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Two-operand arithmetic entry point taking s1, s2, the result r and err.
 * The result r is validated with mpint_size_ok(r, err); a failed check is
 * the unlikely branch.
 * NOTE(review): presumably r receives s1 - s2 before the check - confirm.
 */
74 bool attr_fastcall
mpint_subtract(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
77 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Pre-check applied before an operation whose result size is bounded by
 * the sum of two operand sizes (both given in limbs).
 *
 * If size1 + size2 exceeds 1 + ceil(MPINT_MAX_BITS / GMP_NUMB_BITS) the
 * AJLA_ERROR_INT_TOO_LARGE error ("integer too large") is reported through
 * fatal_mayfail with the caller's err.  This is the unlikely branch.
 */
83 static inline bool mpint_multiply_early_check(size_t size1
, size_t size2
, ajla_error_t
*err
)
85 if (unlikely(size1
+ size2
> 1 + (MPINT_MAX_BITS
+ GMP_NUMB_BITS
- 1) / GMP_NUMB_BITS
)) {
86 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
/*
 * Multiplication entry point taking s1, s2, the result r and err.
 *
 * The limb counts of both operands are first passed to
 * mpint_multiply_early_check, so an obviously oversized product is
 * rejected up front.  The result r is afterwards validated with
 * mpint_size_ok(r, err).
 */
92 bool attr_fastcall
mpint_multiply(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
94 if (unlikely(!mpint_multiply_early_check(mpz_size(s1
), mpz_size(s2
), err
))) {
98 if (unlikely(!mpint_size_ok(r
, err
)))
/*
 * Division entry point.
 *
 * A zero divisor (mpz_sgn(s2) == 0) is reported as
 * AJLA_ERROR_INVALID_OPERATION with the message "divide by zero".
 * Otherwise the quotient is computed into r with mpz_tdiv_q (GMP's
 * division that rounds toward zero) and r is validated with
 * mpint_size_ok(r, err).
 */
103 bool attr_fastcall
mpint_divide(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
105 if (unlikely(!mpz_sgn(s2
))) {
106 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "divide by zero");
109 mpz_tdiv_q(r
, s1
, s2
);
110 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Remainder entry point.
 *
 * A zero divisor (mpz_sgn(s2) == 0) is reported as
 * AJLA_ERROR_INVALID_OPERATION with the message "modulo by zero".
 * Otherwise the remainder is computed into r with mpz_tdiv_r, the
 * remainder matching GMP's division that rounds toward zero.
 */
116 bool attr_fastcall
mpint_modulo(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
118 if (unlikely(!mpz_sgn(s2
))) {
119 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "modulo by zero");
122 mpz_tdiv_r(r
, s1
, s2
);
/*
 * Exponentiation entry point (base s1, exponent s2, result r).
 *
 * Both operands are copied into the temporaries x1 and x2 with
 * mpz_init_set.  A negative exponent is reported as
 * AJLA_ERROR_INVALID_OPERATION ("power by negative number").
 *
 * The steps that follow are those of square-and-multiply: when the lowest
 * bit of x2 is set, r is multiplied by x1; x2 is then halved with
 * mpz_tdiv_q_2exp and x1 is squared.  Both products go through
 * mpint_multiply, so its size checks apply at every step.
 */
126 bool attr_fastcall
mpint_power(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
129 mpz_init_set(&x1
, s1
);
130 mpz_init_set(&x2
, s2
);
131 if (unlikely(mpz_sgn(&x2
) < 0)) {
132 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "power by negative number");
140 if (mpz_tstbit(&x2
, 0)) {
141 if (unlikely(!mpint_multiply(r
, &x1
, r
, err
)))
146 mpz_tdiv_q_2exp(&x2
, &x2
, 1);
147 if (unlikely(!mpint_multiply(&x1
, &x1
, &x1
, err
)))
/*
 * Two-operand entry point taking s1, s2 and the result r.  The err
 * parameter is marked attr_unused.
 * NOTE(review): presumably stores the bitwise AND of s1 and s2 in r -
 * confirm.
 */
155 bool attr_fastcall
mpint_and(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * Two-operand entry point taking s1, s2 and the result r.  The err
 * parameter is marked attr_unused.
 * NOTE(review): presumably stores the bitwise OR of s1 and s2 in r -
 * confirm.
 */
161 bool attr_fastcall
mpint_or(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * Two-operand entry point taking s1, s2 and the result r.  The err
 * parameter is marked attr_unused.
 * NOTE(review): presumably stores the bitwise XOR of s1 and s2 in r -
 * confirm.
 */
167 bool attr_fastcall
mpint_xor(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * Left shift of s1 by the count held in s2, result in r.
 *
 * A count that does not fit in an unsigned long is handled separately:
 * a negative count is AJLA_ERROR_INVALID_OPERATION ("shift left with
 * negative count"), and that path can also report
 * AJLA_ERROR_INT_TOO_LARGE ("integer too large").
 *
 * For a usable count sh the function checks that sh survives a round
 * trip through mp_bitcnt_t, then runs mpint_multiply_early_check with the
 * limb count of s1 and 1 + sh / GMP_NUMB_BITS.  The shift itself is
 * mpz_mul_2exp, and r is validated afterwards with mpint_size_ok.
 */
173 bool attr_fastcall
mpint_shl(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
177 if (unlikely(!mpz_fits_ulong_p(s2
))) {
179 if (unlikely((mpz_sgn(s2
) < 0))) {
180 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "shift left with negative count");
187 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
193 if (unlikely((mp_bitcnt_t
)sh
!= sh
))
195 size1
= mpz_size(s1
);
196 size2
= 1 + sh
/ GMP_NUMB_BITS
;
197 if (unlikely(!mpint_multiply_early_check(size1
, size2
, err
)))
199 mpz_mul_2exp(r
, s1
, sh
);
200 if (unlikely(!mpint_size_ok(r
, err
)))
/*
 * Right shift of s1 by the count held in s2, result in r.
 *
 * A count that does not fit in an unsigned long is handled separately:
 * a negative count is AJLA_ERROR_INVALID_OPERATION ("shift right with
 * negative count"); otherwise the outcome is chosen by the sign of s1.
 *
 * For a usable count sh the function checks that sh survives a round
 * trip through mp_bitcnt_t and shifts with mpz_fdiv_q_2exp, GMP's
 * division by 2^sh that rounds toward minus infinity.
 */
205 bool attr_fastcall
mpint_shr(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
208 if (unlikely(!mpz_fits_ulong_p(s2
))) {
210 if (unlikely((mpz_sgn(s2
) < 0))) {
211 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "shift right with negative count");
214 if (mpz_sgn(s1
) >= 0) {
224 if (unlikely((mp_bitcnt_t
)sh
!= sh
))
226 mpz_fdiv_q_2exp(r
, s1
, sh
);
/*
 * Shared implementation of the single-bit modifying operations.
 *
 * fn_bit is the GMP routine that changes one bit of an mpint_t.  mode
 * selects the operation name used in diagnostics: 0 is "set", 1 is
 * "reset" and any other value is "complement".
 *
 * A bit position s2 that does not fit in an unsigned long is handled
 * separately:
 *   - a negative position is AJLA_ERROR_INVALID_OPERATION
 *     ("bit <op> with negative position");
 *   - setting a bit of a negative s1 (mode 0) and resetting a bit of a
 *     non-negative s1 (mode 1) have dedicated branches;
 *   - the remaining combinations report AJLA_ERROR_INT_TOO_LARGE
 *     ("integer too large").
 *
 * For a usable position sh the function checks that sh survives a round
 * trip through mp_bitcnt_t.
 */
230 static inline bool attr_fastcall
mpint_btx_(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
, void (*fn_bit
)(mpint_t
*, mp_bitcnt_t
), int mode
)
233 if (unlikely(!mpz_fits_ulong_p(s2
))) {
234 if (unlikely(mpz_sgn(s2
) < 0)) {
235 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit %s with negative position", mode
== 0 ? "set" : mode
== 1 ? "reset" : "complement");
239 if (mode
== 0 && mpz_sgn(s1
) < 0) {
243 if (mode
== 1 && mpz_sgn(s1
) >= 0) {
247 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
251 if (unlikely(sh
!= (unsigned long)(mp_bitcnt_t
)sh
))
/*
 * Bit set: forwards to mpint_btx_ with mpz_setbit as the bit routine and
 * mode 0, and returns its result.
 */
258 bool attr_fastcall
mpint_bts(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
260 return mpint_btx_(s1
, s2
, r
, err
, mpz_setbit
, 0);
/*
 * Bit reset: forwards to mpint_btx_ with mpz_clrbit as the bit routine
 * and mode 1, and returns its result.
 */
263 bool attr_fastcall
mpint_btr(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
265 return mpint_btx_(s1
, s2
, r
, err
, mpz_clrbit
, 1);
/*
 * Bit complement.
 *
 * When mpz_combit is not available as a macro, or UNUSUAL_ARITHMETICS is
 * defined, the bit is first read with mpint_bt into o and the work is
 * delegated to mpint_bts or mpint_btr.  Otherwise the call is forwarded
 * to mpint_btx_ with mpz_combit as the bit routine and mode 2.
 */
268 bool attr_fastcall
mpint_btc(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
)
270 #if !defined(mpz_combit) || defined(UNUSUAL_ARITHMETICS)
271 ajla_flat_option_t o
;
272 if (unlikely(!mpint_bt(s1
, s2
, &o
, err
)))
275 return mpint_bts(s1
, s2
, r
, err
);
277 return mpint_btr(s1
, s2
, r
, err
);
279 return mpint_btx_(s1
, s2
, r
, err
, mpz_combit
, 2);
/*
 * Equality test: *r is set to true when mpz_cmp(s1, s2) is zero.
 * The err parameter is marked attr_unused.
 */
283 bool attr_fastcall
mpint_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
285 *r
= !mpz_cmp(s1
, s2
);
/*
 * Inequality test: *r is set to true when mpz_cmp(s1, s2) is non-zero.
 * The err parameter is marked attr_unused.
 */
289 bool attr_fastcall
mpint_not_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
291 *r
= !!mpz_cmp(s1
, s2
);
/*
 * Strict ordering test: *r is set to true when s1 < s2, i.e. when
 * mpz_cmp(s1, s2) is negative.  The err parameter is marked attr_unused.
 */
295 bool attr_fastcall
mpint_less(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
297 *r
= mpz_cmp(s1
, s2
) < 0;
/*
 * Non-strict ordering test: *r is set to true when s1 <= s2.  The
 * comparison is written with swapped operands, mpz_cmp(s2, s1) >= 0,
 * which is the same condition.  The err parameter is marked attr_unused.
 */
301 bool attr_fastcall
mpint_less_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t attr_unused
*err
)
303 *r
= mpz_cmp(s2
, s1
) >= 0;
/*
 * Bit test: reports in *r whether bit s2 of s1 is set.
 *
 * A position that does not fit in an unsigned long is handled separately:
 * a negative position is AJLA_ERROR_INVALID_OPERATION ("bit test with
 * negative position"); otherwise *r is set from the sign of s1 (true for
 * a negative value).
 *
 * For a usable position sh the function checks that sh survives a round
 * trip through mp_bitcnt_t and reads the bit with mpz_tstbit.
 */
307 bool attr_fastcall
mpint_bt(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t
*err
)
310 if (unlikely(!mpz_fits_ulong_p(s2
))) {
311 if (unlikely(mpz_sgn(s2
) < 0)) {
312 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit test with negative position");
316 *r
= mpz_sgn(s1
) < 0;
320 if (unlikely(sh
!= (unsigned long)(mp_bitcnt_t
)sh
))
322 *r
= mpz_tstbit(s1
, sh
);
/*
 * One-operand entry point taking s and the result r.  The err parameter
 * is marked attr_unused.
 * NOTE(review): presumably stores the bitwise complement of s in r -
 * confirm.
 */
326 bool attr_fastcall
mpint_not(const mpint_t
*s
, mpint_t
*r
, ajla_error_t attr_unused
*err
)
/*
 * One-operand entry point taking s, the result r and err.  The result r
 * is validated with mpint_size_ok(r, err); a failed check is the unlikely
 * branch.
 * NOTE(review): presumably r receives -s before the check - confirm.
 */
332 bool attr_fastcall
mpint_neg(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
335 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * One-operand entry point taking s, the result r and err.  The result r
 * is validated with mpint_size_ok(r, err); a failed check is the unlikely
 * branch.
 * NOTE(review): presumably r receives s + 1 before the check - confirm.
 */
341 bool attr_fastcall
mpint_inc(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
344 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * One-operand entry point taking s, the result r and err.  The result r
 * is validated with mpint_size_ok(r, err); a failed check is the unlikely
 * branch.
 * NOTE(review): presumably r receives s - 1 before the check - confirm.
 */
350 bool attr_fastcall
mpint_dec(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
353 if (unlikely(!mpint_size_ok(r
, err
))) {
/*
 * Bit scan forward.
 *
 * A zero argument is AJLA_ERROR_INVALID_OPERATION ("bit scan forward with
 * zero argument").  The bit index b is stored into r; unless
 * UNUSUAL_ARITHMETICS is defined there is an expected (likely) path for
 * an index that fits in an unsigned long.  The general path imports b
 * into r as a single word of sizeof(b) bytes with mpz_import.
 */
359 bool attr_fastcall
mpint_bsf(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
362 if (unlikely(!mpz_sgn(s
))) {
363 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit scan forward with zero argument");
367 #ifndef UNUSUAL_ARITHMETICS
368 if (likely(b
== (unsigned long)b
))
372 mpz_import(r
, 1, 1, sizeof(b
), 0, 0, &b
);
/*
 * Bit scan reverse.
 *
 * A zero or negative argument is AJLA_ERROR_INVALID_OPERATION ("bit scan
 * reverse with non-positive argument").  The index of the highest set bit
 * is the bit length of s minus one (mpz_sizeinbase(s, 2) - 1).  It is
 * stored into r; unless UNUSUAL_ARITHMETICS is defined there is an
 * expected (likely) path for an index that fits in an unsigned long.  The
 * general path imports sz into r as a single word of sizeof(sz) bytes
 * with mpz_import.
 */
376 bool attr_fastcall
mpint_bsr(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
379 if (unlikely(mpz_sgn(s
) <= 0)) {
380 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "bit scan reverse with non-positive argument");
383 sz
= mpz_sizeinbase(s
, 2) - 1;
384 #ifndef UNUSUAL_ARITHMETICS
385 if (likely(sz
== (unsigned long)sz
))
389 mpz_import(r
, 1, 1, sizeof(sz
), 0, 0, &sz
);
/*
 * Population count.
 *
 * A negative argument is AJLA_ERROR_INVALID_OPERATION ("population count
 * with negative argument").  The count b is stored into r; unless
 * UNUSUAL_ARITHMETICS is defined there is an expected (likely) path for a
 * count that fits in an unsigned long.  The general path imports b into r
 * as a single word of sizeof(b) bytes with mpz_import.
 */
393 bool attr_fastcall
mpint_popcnt(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
)
396 if (unlikely(mpz_sgn(s
) < 0)) {
397 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INVALID_OPERATION
), err
, "population count with negative argument");
401 #ifndef UNUSUAL_ARITHMETICS
402 if (likely(b
== (unsigned long)b
))
406 mpz_import(r
, 1, 1, sizeof(b
), 0, 0, &b
);
/*
 * Tests one bit of a raw limb array.
 *
 * bit / GMP_NUMB_BITS selects the limb and bit % GMP_NUMB_BITS the
 * position inside it; the result is that bit as 0 or 1.  No bounds check
 * is made, so the caller must make sure the limb exists.
 */
410 static inline bool mpint_raw_test_bit(const mp_limb_t
*ptr
, size_t bit
)
412 return (ptr
[bit
/ GMP_NUMB_BITS
] >> (bit
% GMP_NUMB_BITS
)) & 1;
/*
 * Template for the conversions between mpint_t and the real types.  It is
 * expanded once per type by for_all_real and removed again with #undef.
 * Each expansion defines two functions.
 *
 * mpint_init_from_<type>(t, valp, err):
 *   Unpacks *valp, remembers whether it is negative and continues with
 *   its absolute value.  A non-finite value is AJLA_ERROR_INFINITY
 *   ("attempting to convert infinity to integer").  frexp splits the
 *   value into a normalized fraction and a binary exponent ex; t is
 *   allocated with mpint_alloc_mayfail for ex bits (0 when ex is
 *   negative), and ex <= 0 is treated as a special case.  The limbs are
 *   filled starting at index (ex - 1) / GMP_NUMB_BITS: the integer part
 *   of the scaled fraction becomes a limb and is subtracted, and the
 *   fraction is then rescaled by 2^GMP_NUMB_BITS.  Limbs that are not
 *   produced this way are zeroed with memset.  mpz_limbs_finish sets the
 *   limb count, and a negative input is handled with mpint_neg.
 *
 * mpint_export_to_<type>(s, result):
 *   Zero is treated as a special case.  Otherwise sz is the bit length
 *   of s and last_bit is the lowest bit position that still fits into
 *   bits_<ntype> bits counted from the top (0 when s is shorter).  Limbs
 *   are taken from the most significant one; bits below last_bit are
 *   masked off, and each limb is scaled by a power of two that starts at
 *   2^base_pos and is divided by 2^GMP_NUMB_BITS per limb.  The final
 *   step rounds to nearest with ties to even: 2^last_bit is added when
 *   bit last_bit - 1 is set and either bit last_bit is set or a lower
 *   bit is set as well.  The sign of s is applied last.
 */
415 #define mpint_conv_real(n, type, ntype, pack, unpack) \
416 bool attr_fastcall cat(mpint_init_from_,type)(mpint_t *t, type *valp, ajla_error_t *err)\
418 ntype norm, mult, val = unpack(*valp); \
422 bool neg = unlikely(val < (ntype)0); \
423 val = cat(mathfunc_,ntype)(fabs)(val); \
424 if (unlikely(!cat(isfinite_, ntype)(val))) { \
425 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_INFINITY), err, "attempting to convert infinity to integer");\
428 norm = cat(mathfunc_,ntype)(frexp)(val, &ex); \
429 if (unlikely(!mpint_alloc_mayfail(t, likely(ex >= 0) ? ex : 0, err)))\
431 if (unlikely(ex <= 0)) \
433 idx = (unsigned)(ex - 1) / GMP_NUMB_BITS; \
435 ptr = mpz_limbs_write(t, limbs); \
436 shift = ((unsigned)(ex - 1) % GMP_NUMB_BITS) + 1; \
437 mult = (ntype)((mp_limb_t)1 << (shift - 1)); \
441 mp_limb_t limb = (mp_limb_t)norm; \
442 norm -= (ntype)limb; \
447 memset(ptr, 0, idx * sizeof(mp_limb_t)); \
451 norm = norm * ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.);\
453 mpz_limbs_finish(t, limbs); \
454 if (unlikely(neg)) { \
455 if (unlikely(!mpint_neg(t, t, err))) \
464 void attr_fastcall cat(mpint_export_to_,type)(const mpint_t *s, type *result)\
466 size_t sz, last_bit, idx, base_pos; \
467 const mp_limb_t *limbs; \
470 if (unlikely(!mpz_sgn(s))) \
473 limbs = mpz_limbs_read(s); \
475 sz = mpz_sizeinbase(s, 2); \
476 last_bit = sz < cat(bits_,ntype) ? 0 : sz - cat(bits_,ntype); \
478 idx = (sz - 1) / GMP_NUMB_BITS; \
479 base_pos = idx * GMP_NUMB_BITS; \
480 mult = cat(mathfunc_,ntype)(ldexp)(1., base_pos); \
484 mp_limb_t limb = limbs[idx]; \
485 mp_limb_t mask = (mp_limb_t)-1; \
487 mask <<= base_pos <= last_bit ? last_bit - base_pos : 0;\
489 l = (ntype)(limb & mask); \
494 if (base_pos <= last_bit) \
498 idx = (sz - 1) / GMP_NUMB_BITS; \
499 base_pos = idx * GMP_NUMB_BITS; \
501 mult = mult * (ntype)(1. / ((ntype)((mp_limb_t)1 << (GMP_NUMB_BITS - 1)) * 2.));\
504 if (last_bit >= 1 && mpint_raw_test_bit(limbs, last_bit - 1)) { \
505 if (mpint_raw_test_bit(limbs, last_bit) || mpn_scan1(limbs, 0) != last_bit - 1) {\
506 r += cat(mathfunc_,ntype)(ldexp)(1., last_bit); \
510 if (unlikely(mpz_sgn(s) < 0)) \
516 for_all_real(mpint_conv_real
, for_all_empty
)
517 #undef mpint_conv_real
/*
 * Serializes s into a newly allocated byte array.
 *
 * The buffer holds ceil(bits / 8) + 1 bytes, where bits is
 * mpz_sizeinbase(s, 2), and is obtained with mem_alloc_mayfail (which
 * receives err).  mpz_export writes the magnitude as single bytes, least
 * significant byte first (order -1, size 1).  Writing more bytes than
 * were allocated is treated as an internal error.
 *
 * A negative s gets an extra pass over all sz bytes.
 * NOTE(review): presumably this converts the magnitude to two's
 * complement - confirm.
 *
 * The length is then reduced: the top byte is examined together with the
 * byte below it, both for 0xff above a byte >= 0x80 and for 0x00 above a
 * byte < 0x80, i.e. a top byte that only repeats the sign.  A single
 * zero byte (sz == 1 && !ptr[0]) is handled as a special case.
 */
520 bool mpint_export_to_blob(const mpint_t
*s
, uint8_t **blob
, size_t *blob_size
, ajla_error_t
*err
)
523 size_t sz
= (mpz_sizeinbase(s
, 2) + 7) / 8 + 1;
526 ptr
= mem_alloc_mayfail(uint8_t *, sz
, err
);
530 mpz_export(ptr
, &count
, -1, 1, 0, 0, s
);
532 internal(file_line
, "mpint_export_to_blob: mpz_export ran over the end of allocated memory: %"PRIuMAX
" > %"PRIuMAX
"", (uintmax_t)count
, (uintmax_t)sz
);
537 if (unlikely(mpz_sgn(s
) < 0)) {
540 for (i
= 0; i
< sz
; i
++) {
551 if (ptr
[sz
- 1] == 0xff && ptr
[sz
- 2] >= 0x80)
558 if (ptr
[sz
- 1] == 0x00 && ptr
[sz
- 2] < 0x80)
563 if (sz
== 1 && !ptr
[0])
/*
 * Allocation hook handed to mp_set_memory_functions: returns size bytes
 * obtained from mem_alloc.
 */
574 static void *gmp_alloc(size_t size
)
576 return mem_alloc(void *, size
);
/*
 * Reallocation hook handed to mp_set_memory_functions: resizes ptr to
 * new_size bytes with mem_realloc.  old_size is not used.
 */
579 static void *gmp_realloc(void *ptr
, size_t attr_unused old_size
, size_t new_size
)
581 return mem_realloc(void *, ptr
, new_size
);
/*
 * Release hook handed to mp_set_memory_functions.  The size argument is
 * marked attr_unused.
 */
584 static void gmp_free(void *ptr
, size_t attr_unused size
)
/*
 * Module initialization: registers gmp_alloc, gmp_realloc and gmp_free as
 * GMP's memory functions through mp_set_memory_functions.
 */
590 void mpint_init(void)
593 mp_set_memory_functions(gmp_alloc
, gmp_realloc
, gmp_free
);
/*
 * NOTE(review): presumably the shutdown counterpart of mpint_init -
 * confirm.
 */
596 void mpint_done(void)