2 * Copyright (C) 2024 Mikulas Patocka
4 * This file is part of Ajla.
6 * Ajla is free software: you can redistribute it and/or modify it under the
7 * terms of the GNU General Public License as published by the Free Software
8 * Foundation, either version 3 of the License, or (at your option) any later
11 * Ajla is distributed in the hope that it will be useful, but WITHOUT ANY
12 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
13 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
15 * You should have received a copy of the GNU General Public License along with
16 * Ajla. If not, see <https://www.gnu.org/licenses/>.
26 #if !defined(MPINT_GMP)
28 #elif defined(HAVE_GMP_H)
30 #elif defined(HAVE_GMP_GMP_H)
/*
 * mpint_t is the multi-precision integer type used by every function in
 * this header; in this backend it is an alias for MP_INT.
 */
34 typedef MP_INT mpint_t
;
/*
 * Native integer types that the conversion helpers below cast to when a
 * value is handed straight to mpz_set_si / mpz_set_ui / mpz_init_set_si.
 */
36 #define mp_direct long
37 #define mp_udirect unsigned long
/*
 * Returns mpz_size(t) multiplied by GMP_NUMB_BITS, i.e. a bit-count
 * estimate expressed in whole limbs rather than the exact bit length.
 * The cast to unsigned long is applied to mpz_size(t) before multiplying.
 */
39 static inline unsigned long mpint_estimate_bits(const mpint_t
*t
)
41 return (unsigned long)mpz_size(t
) * GMP_NUMB_BITS
;
/*
 * Allocation helper taking the target t, a size in bits and an error
 * pointer; returns bool.
 * NOTE(review): the body is not visible in this excerpt. Presumably it
 * initializes t with room for `bits` bits; err is tagged attr_unused, so
 * it is apparently not consulted in this backend - confirm.
 */
44 static inline bool mpint_alloc_mayfail(mpint_t
*t
, unsigned long bits
, ajla_error_t attr_unused
*err
)
/*
 * Allocation helper taking a destination t, a read-only source src and an
 * error pointer; returns bool.
 * NOTE(review): the body is not visible in this excerpt. Presumably it
 * initializes t as a copy of src; err is tagged attr_unused - confirm.
 */
50 static inline bool mpint_alloc_copy_mayfail(mpint_t
*t
, const mpint_t
*src
, ajla_error_t attr_unused
*err
)
/*
 * Counterpart of the allocation helpers; takes the mpint to dispose of.
 * NOTE(review): the body is not visible in this excerpt - presumably it
 * releases the storage owned by t; confirm.
 */
56 static inline void mpint_free(mpint_t
*t
)
/* Returns true when t is strictly negative, i.e. when mpz_sgn(t) < 0. */
61 static inline bool mpint_negative(const mpint_t
*t
)
63 return mpz_sgn(t
) < 0;
/*
 * mpint_conv_int generates, for one fixed-width integer type, the
 * conversions between that type and mpint_t. It is expanded for every
 * integer type through for_all_int() and undefined again afterwards.
 *
 * Functions produced per type (only part of each body is visible here):
 *
 *   mpint_set_from_<type>(t, val, uns, err)
 *     When the type is no wider than mp_direct, the value is stored with
 *     mpz_set_si or mpz_set_ui. Otherwise a sign flag is computed as
 *     (val < 0 && !uns) and the value is brought in as a single word of
 *     sizeof(type) bytes with mpz_import.
 *     NOTE(review): the condition choosing between the _si and _ui calls,
 *     and what happens with the sign flag, are not visible - presumably
 *     keyed on `uns` and followed by a negation; confirm.
 *
 *   mpint_init_from_<type>(t, val, err)
 *     Narrow types go through mpz_init_set_si. Wider types first call
 *     mpint_alloc_mayfail with sizeof(type) * 8 bits and then
 *     mpint_set_from_<type> with uns == false.
 *
 *   mpint_export_to_<type>(t, result, err)
 *     If the value fits a signed long, it is additionally checked that it
 *     survives a cast to <type>. Otherwise, when <type> is wider than long,
 *     the bit length from mpz_sizeinbase(t, 2) is compared against
 *     8 * sizeof(type) and the magnitude is fetched with mpz_export, with
 *     separate handling depending on mpz_sgn(t) >= 0. Failure is reported
 *     through fatal_mayfail with AJLA_ERROR_DOESNT_FIT.
 *
 *   mpint_export_to_<utype>(t, result, err)
 *     Unsigned analogue: uses mpz_fits_ulong_p, tests mpz_sgn(t) < 0 on the
 *     wide path, and applies the same bit-length check before mpz_export
 *     writes directly into *result. Failure is reported the same way.
 */
66 #define mpint_conv_int(n, type, utype, sz, bits) \
67 static inline bool cat(mpint_set_from_,type)(mpint_t *t, type val, bool uns, ajla_error_t attr_unused *err)\
69 if (sizeof(type) <= sizeof(mp_direct)) { \
71 mpz_set_si(t, (mp_direct)val); \
73 mpz_set_ui(t, (mp_udirect)(cat(u,type))val); \
75 bool sign = val < 0 && !uns; \
79 mpz_import(t, 1, 1, sizeof(type), 0, 0, &val); \
86 static inline bool cat(mpint_init_from_,type)(mpint_t *t, type val, ajla_error_t *err)\
88 if (sizeof(type) <= sizeof(mp_direct)) { \
89 mpz_init_set_si(t, (mp_direct)val); \
91 if (unlikely(!mpint_alloc_mayfail(t, sizeof(type) * 8, err)))\
93 if (unlikely(!cat(mpint_set_from_,type)(t, val, false, err))) {\
101 static attr_always_inline bool cat(mpint_export_to_,type)(const mpint_t *t, type *result, ajla_error_t *err)\
103 if (mpz_fits_slong_p(t)) { \
106 if (unlikely(l != (type)l)) \
110 } else if (sizeof(type) > sizeof(long)) { \
112 size_t bit = mpz_sizeinbase(t, 2); \
113 if (bit > 8 * sizeof(type)) \
115 (void)mpz_export(&ui, NULL, 1, sizeof(utype), 0, 0, t); \
116 if (mpz_sgn(t) >= 0) { \
124 *result = (type)ui; \
129 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_DOESNT_FIT), err, "integer too large for the target type");\
133 static attr_always_inline bool cat(mpint_export_to_,utype)(const mpint_t *t, utype *result, ajla_error_t *err)\
135 if (mpz_fits_ulong_p(t)) { \
138 if (unlikely(l != (utype)l)) \
140 *result = (utype)l; \
142 } else if (sizeof(utype) > sizeof(unsigned long)) { \
144 if (unlikely(mpz_sgn(t) < 0)) \
146 bit = mpz_sizeinbase(t, 2); \
147 if (bit > 8 * sizeof(utype)) \
149 (void)mpz_export(result, NULL, 1, sizeof(utype), 0, 0, t);\
154 fatal_mayfail(error_ajla(EC_SYNC, AJLA_ERROR_DOESNT_FIT), err, "integer too large for the target type");\
157 for_all_int(mpint_conv_int
, for_all_empty
)
158 #undef mpint_conv_int
/*
 * Builds the mpint m from n_words code_t words starting at code.
 *
 * Visible behaviour:
 *  - a word count at or above size_t_limit is rejected through
 *    fatal_mayfail with AJLA_ERROR_INT_TOO_LARGE;
 *  - when n_words is non-zero and the most significant word (the last one
 *    if !CODE_ENDIAN, otherwise the first) has its sign bit set, a scratch
 *    array of n_words words is allocated, filled in a loop over all words
 *    and imported from that copy;
 *  - otherwise the words are imported directly from code.
 * Both mpz_import calls pass word order -1 when !CODE_ENDIAN and 1
 * otherwise, with native byte order inside each word and no nail bits.
 *
 * NOTE(review): the loop body, the handling of a failed allocation, the
 * release of the copy and the return statements are not visible in this
 * excerpt. Presumably the copy holds the transformed (negated) words of a
 * negative input - confirm against the full source.
 */
161 static inline bool mpint_import_from_code(mpint_t
*m
, const code_t
*code
, ip_t n_words
, ajla_error_t
*err
)
163 if (unlikely(n_words
>= size_t_limit
+ uzero
)) {
164 fatal_mayfail(error_ajla(EC_SYNC
, AJLA_ERROR_INT_TOO_LARGE
), err
, "integer too large");
/* Sign-bit test on the most significant word selects the copy-based path. */
167 if (likely(n_words
!= 0) && unlikely((code
[!CODE_ENDIAN
? n_words
- 1 : 0] & sign_bit(code_t
)) != 0)) {
169 code_t
*copy
= mem_alloc_array_mayfail(mem_alloc_mayfail
, code_t
*, 0, 0, n_words
, sizeof(code_t
), err
);
172 for (i
= 0; i
< n_words
; i
++)
174 mpz_import(m
, n_words
, !CODE_ENDIAN
? -1 : 1, sizeof(code_t
), 0, 0, copy
);
/* Sign bit clear (or no words at all): import the caller's words as they are. */
178 mpz_import(m
, n_words
, !CODE_ENDIAN
? -1 : 1, sizeof(code_t
), 0, 0, code
);
/*
 * mpint_import_from_variable(m, type, var): loads the native variable var
 * of the given type into the mpint m as one word of sizeof(var) bytes.
 * For a signed type holding a negative value, the negated value is stored
 * in a temporary (var2) and that is imported instead; otherwise var is
 * imported directly.
 * NOTE(review): the step that restores the sign after importing var2 is
 * not visible in this excerpt - presumably m is negated; confirm.
 * var is evaluated more than once, so it should be free of side effects.
 */
183 #define mpint_import_from_variable(m, type, var) \
185 if (!is_unsigned(type) && unlikely((var) < (type)zero)) { \
186 type var2 = -(var); \
187 mpz_import((m), 1, 1, sizeof(var), 0, 0, &var2); \
190 mpz_import((m), 1, 1, sizeof(var), 0, 0, &(var)); \
/*
 * mpint_export_to_variable(m, type, var, success): stores the mpint m into
 * the native variable var of the given type.
 * Visible steps: the bit length from mpz_sizeinbase(m, 2) is compared
 * against 8 * sizeof(type); the magnitude is written into var with
 * mpz_export; then, depending on mpz_sgn(m) >= 0, var is checked against
 * zero, and on the negative path is_unsigned(type), sign_bit(type) and the
 * sign of var are examined.
 * NOTE(review): the statements executed in each branch - in particular how
 * `success` is assigned and how a negative value is produced - are not
 * visible in this excerpt; confirm against the full source.
 */
194 #define mpint_export_to_variable(m, type, var, success) \
199 bit = mpz_sizeinbase(m, 2); \
200 if (unlikely(bit > 8 * sizeof(type))) { \
204 mpz_export(&(var), NULL, 1, sizeof(type), 0, 0, (m)); \
205 if (likely(mpz_sgn(m) >= 0)) { \
206 if (unlikely((var) < (type)zero)) \
209 if (is_unsigned(type)) \
211 if (likely((var) != sign_bit(type))) \
213 if (unlikely((var) >= (type)zero)) \
/*
 * Two-operand operations. Each takes read-only operands s1 and s2, a
 * result r and an error pointer, and returns bool. Only the declarations
 * appear here; the definitions are elsewhere.
 * str_add_mpint in this header passes NULL for err and uses the same
 * object as both an operand and the result of mpint_divide.
 * NOTE(review): the arithmetic meaning of each function (sum, difference,
 * product, quotient, remainder, power, bitwise and/or/xor, shifts, and
 * bit set/reset/complement for bts/btr/btc) is inferred from the names
 * only - confirm against the definitions.
 */
219 bool attr_fastcall
mpint_add(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
220 bool attr_fastcall
mpint_subtract(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
221 bool attr_fastcall
mpint_multiply(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
222 bool attr_fastcall
mpint_divide(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
223 bool attr_fastcall
mpint_modulo(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
224 bool attr_fastcall
mpint_power(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
225 bool attr_fastcall
mpint_and(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
226 bool attr_fastcall
mpint_or(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
227 bool attr_fastcall
mpint_xor(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
228 bool attr_fastcall
mpint_shl(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
229 bool attr_fastcall
mpint_shr(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
230 bool attr_fastcall
mpint_bts(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
231 bool attr_fastcall
mpint_btr(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
232 bool attr_fastcall
mpint_btc(const mpint_t
*s1
, const mpint_t
*s2
, mpint_t
*r
, ajla_error_t
*err
);
/*
 * Two-operand predicates. Each takes read-only operands s1 and s2, writes
 * its answer to the ajla_flat_option_t pointed to by r, takes an error
 * pointer and returns bool. Only the declarations appear here.
 * str_add_mpint in this header calls mpint_less with err == NULL and then
 * tests the written flag as a truth value.
 * NOTE(review): the exact relation each function evaluates (equality,
 * inequality, <, <=, and a bit test for mpint_bt) is inferred from the
 * names only - confirm against the definitions.
 */
234 bool attr_fastcall
mpint_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t
*err
);
235 bool attr_fastcall
mpint_not_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t
*err
);
236 bool attr_fastcall
mpint_less(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t
*err
);
237 bool attr_fastcall
mpint_less_equal(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t
*err
);
238 bool attr_fastcall
mpint_bt(const mpint_t
*s1
, const mpint_t
*s2
, ajla_flat_option_t
*r
, ajla_error_t
*err
);
/*
 * One-operand operations. Each takes a read-only operand s, a result r and
 * an error pointer, and returns bool. Only the declarations appear here.
 * str_add_mpint in this header calls mpint_neg with the same object as
 * operand and result and with err == NULL.
 * NOTE(review): the meaning of each function (bitwise not, negation,
 * increment, decrement, bit scan forward/reverse, population count) is
 * inferred from the names only - confirm against the definitions.
 */
240 bool attr_fastcall
mpint_not(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
);
241 bool attr_fastcall
mpint_neg(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
);
242 bool attr_fastcall
mpint_inc(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
);
243 bool attr_fastcall
mpint_dec(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
);
244 bool attr_fastcall
mpint_bsf(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
);
245 bool attr_fastcall
mpint_bsr(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
);
246 bool attr_fastcall
mpint_popcnt(const mpint_t
*s
, mpint_t
*r
, ajla_error_t
*err
);
/*
 * The *_alt1 names map straight onto the regular functions: this backend
 * provides no separate alternative implementation of divide, modulo or
 * popcnt.
 */
248 #define mpint_divide_alt1 mpint_divide
249 #define mpint_modulo_alt1 mpint_modulo
250 #define mpint_popcnt_alt1 mpint_popcnt
/*
 * mpint_conv_real declares, for one real type, a pair of conversion
 * functions: mpint_init_from_<type>(t, valp, err), which takes the source
 * value by pointer and returns bool, and mpint_export_to_<type>(t, result),
 * which returns void and therefore reports no failure to the caller.
 * The macro is expanded for every real type through for_all_real() and
 * undefined again afterwards. Only declarations are produced here.
 */
252 #define mpint_conv_real(n, type, ntype, pack, unpack) \
253 bool attr_fastcall cat(mpint_init_from_,type)(mpint_t *t, type *valp, ajla_error_t *err);\
254 void attr_fastcall cat(mpint_export_to_,type)(const mpint_t *t, type *result);
255 for_all_real(mpint_conv_real
, for_all_empty
)
256 #undef mpint_conv_real
/*
 * Serializes the mpint s into a byte buffer, returning the buffer through
 * *blob and its length through *blob_size; returns bool and takes an error
 * pointer. Only the declaration appears here.
 * NOTE(review): the byte layout and who owns/frees *blob are not visible
 * in this header - confirm against the definition.
 */
258 bool mpint_export_to_blob(const mpint_t
*s
, uint8_t **blob
, size_t *blob_size
, ajla_error_t
*err
);
/*
 * Appends the textual representation of mp in base base_n to the growable
 * string (*s, *l).
 *
 * Visible algorithm: a working copy of mp is made, its sign is recorded and
 * the copy is negated; digits are then produced by repeatedly taking the
 * remainder modulo the base, appending the digit character ('0'..'9', then
 * 'a' onwards for values above 9) and dividing by the base while the
 * working value is not smaller than the base. A '-' is appended, and the
 * characters between pos and the end of the string are reversed in place,
 * since the digits were generated least significant first.
 *
 * Every mpint call passes NULL as the error pointer and its return value
 * is ignored.
 * NOTE(review): the declarations of digit, pos and i, the guards around the
 * negation and the '-' (presumably `if (neg)`), the loop construct and the
 * release of mod/base/num are not visible in this excerpt - confirm
 * against the full source.
 */
260 static inline void str_add_mpint(char **s
, size_t *l
, const mpint_t
*mp
, uint16_t base_n
)
262 mpint_t mod
, base
, num
;
264 int16_t base_m
= (int16_t)base_n
;
265 ajla_flat_option_t neg
, bo
;
270 mpint_init_from_int8_t(&mod
, 0, NULL
);
271 mpint_init_from_int16_t(&base
, base_m
, NULL
);
272 mpint_alloc_mayfail(&num
, 0, NULL
);
/* mod holds 0 at this point, so this effectively copies mp into num. */
273 mpint_add(mp
, &mod
, &num
, NULL
);
/* Compare against the zero in mod to record whether the value is negative. */
274 mpint_less(&num
, &mod
, &neg
, NULL
);
276 mpint_neg(&num
, &num
, NULL
);
/* Extract the lowest digit of num in the requested base. */
280 mpint_modulo(&num
, &base
, &mod
, NULL
);
281 digit
= 0; /* avoid warning */
282 mpint_export_to_int8_t(&mod
, &digit
, NULL
);
283 str_add_char(s
, l
, digit
<= 9 ? '0' + (char)digit
: 'a' - 10 + (char)digit
);
/* bo is set when num < base, i.e. the digit just emitted was the last one. */
284 mpint_less(&num
, &base
, &bo
, NULL
);
285 if (!bo
) mpint_divide(&num
, &base
, &num
, NULL
);
289 str_add_char(s
, l
, '-');
/* Reverse the characters from pos to the end of the string in place. */
295 for (i
= 0; i
< (*l
- pos
) / 2; i
++) {
296 char c
= (*s
)[pos
+ i
];
297 (*s
)[pos
+ i
] = (*s
)[*l
- 1 - i
];
298 (*s
)[*l
- 1 - i
] = c
;
/*
 * Module-level entry points taking no arguments and returning nothing.
 * NOTE(review): only the declarations appear here - presumably they set
 * up and tear down the mpint subsystem once per process; confirm.
 */
303 void mpint_init(void);
304 void mpint_done(void);