[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) (#123526)
[llvm-project.git] / polly / lib / External / isl / imath / imath.c
blob26dc91f35e43d02cd0ea2b54f3826b0b7722d857
1 /*
2 Name: imath.c
3 Purpose: Arbitrary precision integer arithmetic routines.
4 Author: M. J. Fromberger
6 Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.
8 Permission is hereby granted, free of charge, to any person obtaining a copy
9 of this software and associated documentation files (the "Software"), to deal
10 in the Software without restriction, including without limitation the rights
11 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 copies of the Software, and to permit persons to whom the Software is
13 furnished to do so, subject to the following conditions:
15 The above copyright notice and this permission notice shall be included in
16 all copies or substantial portions of the Software.
18 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 SOFTWARE.
27 #include "imath.h"
29 #include <assert.h>
30 #include <ctype.h>
31 #include <stdlib.h>
32 #include <string.h>
34 const mp_result MP_OK = 0; /* no error, all is well */
35 const mp_result MP_FALSE = 0; /* boolean false */
36 const mp_result MP_TRUE = -1; /* boolean true */
37 const mp_result MP_MEMORY = -2; /* out of memory */
38 const mp_result MP_RANGE = -3; /* argument out of range */
39 const mp_result MP_UNDEF = -4; /* result undefined */
40 const mp_result MP_TRUNC = -5; /* output truncated */
41 const mp_result MP_BADARG = -6; /* invalid null argument */
42 const mp_result MP_MINERR = -6;
44 const mp_sign MP_NEG = 1; /* value is strictly negative */
45 const mp_sign MP_ZPOS = 0; /* value is non-negative */
47 static const char *s_unknown_err = "unknown result code";
48 static const char *s_error_msg[] = {"error code 0", "boolean true",
49 "out of memory", "argument out of range",
50 "result undefined", "output truncated",
51 "invalid argument", NULL};
/* s_log2[r] holds log_r(2), the base-r logarithm of 2, for every radix r
   from 2 through 36; slots 0 and 1 are unused placeholders.

   The bit length of n gives lg(n) cheaply, and log_r(n) = lg(n) * log_r(2),
   so this table lets us estimate how many radix-r digits a value needs
   without linking against the standard math library.

   The table must be extended whenever MP_MAX_RADIX is raised. */
static const double s_log2[] = {
    /* radix 0..3 (0 and 1 unused) */
    0.000000000, 0.000000000, 1.000000000, 0.630929754,
    /* radix 4..7 */
    0.500000000, 0.430676558, 0.386852807, 0.356207187,
    /* radix 8..11 */
    0.333333333, 0.315464877, 0.301029996, 0.289064826,
    /* radix 12..15 */
    0.278942946, 0.270238154, 0.262649535, 0.255958025,
    /* radix 16..19 */
    0.250000000, 0.244650542, 0.239812467, 0.235408913,
    /* radix 20..23 */
    0.231378213, 0.227670249, 0.224243824, 0.221064729,
    /* radix 24..27 */
    0.218104292, 0.215338279, 0.212746054, 0.210309918,
    /* radix 28..31 */
    0.208014598, 0.205846832, 0.203795047, 0.201849087,
    /* radix 32..35 */
    0.200000000, 0.198239863, 0.196561632, 0.194959022,
    /* radix 36 */
    0.193426404,
};
/* How many mp_digit elements are needed to hold an object the size of V,
   i.e. sizeof(V) / sizeof(mp_digit) rounded up. Usable as an array bound. */
#define MP_VALUE_DIGITS(V) \
  ((sizeof(V) + sizeof(mp_digit) - 1) / sizeof(mp_digit))
81 /* Round precision P to nearest word boundary */
82 static inline mp_size s_round_prec(mp_size P) { return 2 * ((P + 1) / 2); }
84 /* Set array P of S digits to zero */
85 static inline void ZERO(mp_digit *P, mp_size S) {
86 mp_size i__ = S * sizeof(mp_digit);
87 mp_digit *p__ = P;
88 memset(p__, 0, i__);
91 /* Copy S digits from array P to array Q */
92 static inline void COPY(mp_digit *P, mp_digit *Q, mp_size S) {
93 mp_size i__ = S * sizeof(mp_digit);
94 mp_digit *p__ = P;
95 mp_digit *q__ = Q;
96 memcpy(q__, p__, i__);
/* Reverse the first N bytes of A in place.
   Buffers of fewer than two bytes are already "reversed"; returning early
   for them also avoids computing A + N - 1 when N == 0, which would point
   before the start of the array (undefined behavior). */
static inline void REV(unsigned char *A, int N) {
  if (N < 2) return;

  unsigned char *lo = A;
  unsigned char *hi = A + N - 1;
  while (lo < hi) {
    unsigned char tmp = *lo;
    *lo++ = *hi;
    *hi-- = tmp;
  }
}
110 /* Strip leading zeroes from z_ in-place. */
111 static inline void CLAMP(mp_int z_) {
112 mp_size uz_ = MP_USED(z_);
113 mp_digit *dz_ = MP_DIGITS(z_) + uz_ - 1;
114 while (uz_ > 1 && (*dz_-- == 0)) --uz_;
115 z_->used = uz_;
/* Smaller of two ints. */
static inline int MIN(int A, int B) {
  if (A <= B) return A;
  return B;
}
120 static inline mp_size MAX(mp_size A, mp_size B) { return (B > A ? B : A); }
/* Exchange the contents of two lvalues A and B of type T,
   e.g. SWAP(mp_int, x, y). */
#define SWAP(T, A, B)     \
  do {                    \
    T swap_tmp_ = (A);    \
    (A) = (B);            \
    (B) = swap_tmp_;      \
  } while (0)
/* DECLARE_TEMP(N) introduces N scratch mpz_t values in the current function,
   each initialized to zero. A function that uses it must also invoke
   CLEANUP_TEMP() before its final return. TEMP(i) yields a pointer to the
   ith scratch value. */
#define DECLARE_TEMP(N)                              \
  struct {                                           \
    mpz_t value[(N)];                                \
    int len;                                         \
    mp_result err;                                   \
  } temp_ = {.len = (N), .err = MP_OK};              \
  do {                                               \
    for (int tix_ = 0; tix_ < temp_.len; ++tix_) {   \
      mp_int_init(TEMP(tix_));                       \
    }                                                \
  } while (0)

/* CLEANUP_TEMP() defines the CLEANUP label targeted by REQUIRE(), clears
   every scratch value, and then returns the recorded error from the
   enclosing function if one was set; otherwise execution falls through. */
#define CLEANUP_TEMP()                               \
  CLEANUP:                                           \
  do {                                               \
    for (int tix_ = 0; tix_ < temp_.len; ++tix_) {   \
      mp_int_clear(TEMP(tix_));                      \
    }                                                \
    if (temp_.err != MP_OK) {                        \
      return temp_.err;                              \
    }                                                \
  } while (0)

/* Pointer to the Kth scratch value declared by DECLARE_TEMP. */
#define TEMP(K) (temp_.value + (K))

/* Evaluate E (an mp_result expected to be MP_OK) and record the result.
   Any other value skips the rest of the function by jumping to CLEANUP,
   where the scratch values are released and the error is returned. */
#define REQUIRE(E)                                   \
  do {                                               \
    if ((temp_.err = (E)) != MP_OK) goto CLEANUP;    \
  } while (0)
176 /* Compare value to zero. */
177 static inline int CMPZ(mp_int Z) {
178 if (Z->used == 1 && Z->digits[0] == 0) return 0;
179 return (Z->sign == MP_NEG) ? -1 : 1;
182 static inline mp_word UPPER_HALF(mp_word W) { return (W >> MP_DIGIT_BIT); }
183 static inline mp_digit LOWER_HALF(mp_word W) { return (mp_digit)(W); }
185 /* Report whether the highest-order bit of W is 1. */
186 static inline bool HIGH_BIT_SET(mp_word W) {
187 return (W >> (MP_WORD_BIT - 1)) != 0;
190 /* Report whether adding W + V will carry out. */
191 static inline bool ADD_WILL_OVERFLOW(mp_word W, mp_word V) {
192 return ((MP_WORD_MAX - V) < W);
195 /* Default number of digits allocated to a new mp_int */
196 static mp_size default_precision = 8;
198 void mp_int_default_precision(mp_size size) {
199 assert(size > 0);
200 default_precision = size;
203 /* Minimum number of digits to invoke recursive multiply */
204 static mp_size multiply_threshold = 32;
206 void mp_int_multiply_threshold(mp_size thresh) {
207 assert(thresh >= sizeof(mp_word));
208 multiply_threshold = thresh;
211 /* Allocate a buffer of (at least) num digits, or return
212 NULL if that couldn't be done. */
213 static mp_digit *s_alloc(mp_size num);
215 /* Release a buffer of digits allocated by s_alloc(). */
216 static void s_free(void *ptr);
218 /* Insure that z has at least min digits allocated, resizing if
219 necessary. Returns true if successful, false if out of memory. */
220 static bool s_pad(mp_int z, mp_size min);
222 /* Ensure Z has at least N digits allocated. */
223 static inline mp_result GROW(mp_int Z, mp_size N) {
224 return s_pad(Z, N) ? MP_OK : MP_MEMORY;
227 /* Fill in a "fake" mp_int on the stack with a given value */
228 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]);
229 static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]);
231 /* Compare two runs of digits of given length, returns <0, 0, >0 */
232 static int s_cdig(mp_digit *da, mp_digit *db, mp_size len);
234 /* Pack the unsigned digits of v into array t */
235 static int s_uvpack(mp_usmall v, mp_digit t[]);
237 /* Compare magnitudes of a and b, returns <0, 0, >0 */
238 static int s_ucmp(mp_int a, mp_int b);
240 /* Compare magnitudes of a and v, returns <0, 0, >0 */
241 static int s_vcmp(mp_int a, mp_small v);
242 static int s_uvcmp(mp_int a, mp_usmall uv);
244 /* Unsigned magnitude addition; assumes dc is big enough.
245 Carry out is returned (no memory allocated). */
246 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
247 mp_size size_b);
249 /* Unsigned magnitude subtraction. Assumes dc is big enough. */
250 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
251 mp_size size_b);
253 /* Unsigned recursive multiplication. Assumes dc is big enough. */
254 static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
255 mp_size size_b);
257 /* Unsigned magnitude multiplication. Assumes dc is big enough. */
258 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
259 mp_size size_b);
261 /* Unsigned recursive squaring. Assumes dc is big enough. */
262 static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a);
264 /* Unsigned magnitude squaring. Assumes dc is big enough. */
265 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a);
267 /* Single digit addition. Assumes a is big enough. */
268 static void s_dadd(mp_int a, mp_digit b);
270 /* Single digit multiplication. Assumes a is big enough. */
271 static void s_dmul(mp_int a, mp_digit b);
273 /* Single digit multiplication on buffers; assumes dc is big enough. */
274 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a);
276 /* Single digit division. Replaces a with the quotient,
277 returns the remainder. */
278 static mp_digit s_ddiv(mp_int a, mp_digit b);
280 /* Quick division by a power of 2, replaces z (no allocation) */
281 static void s_qdiv(mp_int z, mp_size p2);
283 /* Quick remainder by a power of 2, replaces z (no allocation) */
284 static void s_qmod(mp_int z, mp_size p2);
286 /* Quick multiplication by a power of 2, replaces z.
287 Allocates if necessary; returns false in case this fails. */
288 static int s_qmul(mp_int z, mp_size p2);
290 /* Quick subtraction from a power of 2, replaces z.
291 Allocates if necessary; returns false in case this fails. */
292 static int s_qsub(mp_int z, mp_size p2);
294 /* Return maximum k such that 2^k divides z. */
295 static int s_dp2k(mp_int z);
297 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
298 static int s_isp2(mp_int z);
300 /* Set z to 2^k. May allocate; returns false in case this fails. */
301 static int s_2expt(mp_int z, mp_small k);
303 /* Normalize a and b for division, returns normalization constant */
304 static int s_norm(mp_int a, mp_int b);
306 /* Compute constant mu for Barrett reduction, given modulus m, result
307 replaces z, m is untouched. */
308 static mp_result s_brmu(mp_int z, mp_int m);
310 /* Reduce a modulo m, using Barrett's algorithm. */
311 static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);
313 /* Modular exponentiation, using Barrett reduction */
314 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);
316 /* Unsigned magnitude division. Assumes |a| > |b|. Allocates temporaries;
317 overwrites a with quotient, b with remainder. */
318 static mp_result s_udiv_knuth(mp_int a, mp_int b);
320 /* Compute the number of digits in radix r required to represent the given
321 value. Does not account for sign flags, terminators, etc. */
322 static int s_outlen(mp_int z, mp_size r);
324 /* Guess how many digits of precision will be needed to represent a radix r
325 value of the specified number of digits. Returns a value guaranteed to be
326 no smaller than the actual number required. */
327 static mp_size s_inlen(int len, mp_size r);
329 /* Convert a character to a digit value in radix r, or
330 -1 if out of range */
331 static int s_ch2val(char c, int r);
333 /* Convert a digit value to a character */
334 static char s_val2ch(int v, int caps);
336 /* Take 2's complement of a buffer in place */
337 static void s_2comp(unsigned char *buf, int len);
339 /* Convert a value to binary, ignoring sign. On input, *limpos is the bound on
340 how many bytes should be written to buf; on output, *limpos is set to the
341 number of bytes actually written. */
342 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);
344 /* Multiply X by Y into Z, ignoring signs. Requires that Z have enough storage
345 preallocated to hold the result. */
346 static inline void UMUL(mp_int X, mp_int Y, mp_int Z) {
347 mp_size ua_ = MP_USED(X);
348 mp_size ub_ = MP_USED(Y);
349 mp_size o_ = ua_ + ub_;
350 ZERO(MP_DIGITS(Z), o_);
351 (void)s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), ua_, ub_);
352 Z->used = o_;
353 CLAMP(Z);
356 /* Square X into Z. Requires that Z have enough storage to hold the result. */
357 static inline void USQR(mp_int X, mp_int Z) {
358 mp_size ua_ = MP_USED(X);
359 mp_size o_ = ua_ + ua_;
360 ZERO(MP_DIGITS(Z), o_);
361 (void)s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), ua_);
362 Z->used = o_;
363 CLAMP(Z);
366 mp_result mp_int_init(mp_int z) {
367 if (z == NULL) return MP_BADARG;
369 z->single = 0;
370 z->digits = &(z->single);
371 z->alloc = 1;
372 z->used = 1;
373 z->sign = MP_ZPOS;
375 return MP_OK;
378 mp_int mp_int_alloc(void) {
379 mp_int out = malloc(sizeof(mpz_t));
381 if (out != NULL) mp_int_init(out);
383 return out;
386 mp_result mp_int_init_size(mp_int z, mp_size prec) {
387 assert(z != NULL);
389 if (prec == 0) {
390 prec = default_precision;
391 } else if (prec == 1) {
392 return mp_int_init(z);
393 } else {
394 prec = s_round_prec(prec);
397 z->digits = s_alloc(prec);
398 if (MP_DIGITS(z) == NULL) return MP_MEMORY;
400 z->digits[0] = 0;
401 z->used = 1;
402 z->alloc = prec;
403 z->sign = MP_ZPOS;
405 return MP_OK;
408 mp_result mp_int_init_copy(mp_int z, mp_int old) {
409 assert(z != NULL && old != NULL);
411 mp_size uold = MP_USED(old);
412 if (uold == 1) {
413 mp_int_init(z);
414 } else {
415 mp_size target = MAX(uold, default_precision);
416 mp_result res = mp_int_init_size(z, target);
417 if (res != MP_OK) return res;
420 z->used = uold;
421 z->sign = old->sign;
422 COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
424 return MP_OK;
427 mp_result mp_int_init_value(mp_int z, mp_small value) {
428 mpz_t vtmp;
429 mp_digit vbuf[MP_VALUE_DIGITS(value)];
431 s_fake(&vtmp, value, vbuf);
432 return mp_int_init_copy(z, &vtmp);
435 mp_result mp_int_init_uvalue(mp_int z, mp_usmall uvalue) {
436 mpz_t vtmp;
437 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
439 s_ufake(&vtmp, uvalue, vbuf);
440 return mp_int_init_copy(z, &vtmp);
443 mp_result mp_int_set_value(mp_int z, mp_small value) {
444 mpz_t vtmp;
445 mp_digit vbuf[MP_VALUE_DIGITS(value)];
447 s_fake(&vtmp, value, vbuf);
448 return mp_int_copy(&vtmp, z);
451 mp_result mp_int_set_uvalue(mp_int z, mp_usmall uvalue) {
452 mpz_t vtmp;
453 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
455 s_ufake(&vtmp, uvalue, vbuf);
456 return mp_int_copy(&vtmp, z);
459 void mp_int_clear(mp_int z) {
460 if (z == NULL) return;
462 if (MP_DIGITS(z) != NULL) {
463 if (MP_DIGITS(z) != &(z->single)) s_free(MP_DIGITS(z));
465 z->digits = NULL;
469 void mp_int_free(mp_int z) {
470 assert(z != NULL);
472 mp_int_clear(z);
473 free(z); /* note: NOT s_free() */
476 mp_result mp_int_copy(mp_int a, mp_int c) {
477 assert(a != NULL && c != NULL);
479 if (a != c) {
480 mp_size ua = MP_USED(a);
481 mp_digit *da, *dc;
483 if (!s_pad(c, ua)) return MP_MEMORY;
485 da = MP_DIGITS(a);
486 dc = MP_DIGITS(c);
487 COPY(da, dc, ua);
489 c->used = ua;
490 c->sign = a->sign;
493 return MP_OK;
496 void mp_int_swap(mp_int a, mp_int c) {
497 if (a != c) {
498 mpz_t tmp = *a;
500 *a = *c;
501 *c = tmp;
503 if (MP_DIGITS(a) == &(c->single)) a->digits = &(a->single);
504 if (MP_DIGITS(c) == &(a->single)) c->digits = &(c->single);
508 void mp_int_zero(mp_int z) {
509 assert(z != NULL);
511 z->digits[0] = 0;
512 z->used = 1;
513 z->sign = MP_ZPOS;
516 mp_result mp_int_abs(mp_int a, mp_int c) {
517 assert(a != NULL && c != NULL);
519 mp_result res;
520 if ((res = mp_int_copy(a, c)) != MP_OK) return res;
522 c->sign = MP_ZPOS;
523 return MP_OK;
526 mp_result mp_int_neg(mp_int a, mp_int c) {
527 assert(a != NULL && c != NULL);
529 mp_result res;
530 if ((res = mp_int_copy(a, c)) != MP_OK) return res;
532 if (CMPZ(c) != 0) c->sign = 1 - MP_SIGN(a);
534 return MP_OK;
537 mp_result mp_int_add(mp_int a, mp_int b, mp_int c) {
538 assert(a != NULL && b != NULL && c != NULL);
540 mp_size ua = MP_USED(a);
541 mp_size ub = MP_USED(b);
542 mp_size max = MAX(ua, ub);
544 if (MP_SIGN(a) == MP_SIGN(b)) {
545 /* Same sign -- add magnitudes, preserve sign of addends */
546 if (!s_pad(c, max)) return MP_MEMORY;
548 mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
549 mp_size uc = max;
551 if (carry) {
552 if (!s_pad(c, max + 1)) return MP_MEMORY;
554 c->digits[max] = carry;
555 ++uc;
558 c->used = uc;
559 c->sign = a->sign;
561 } else {
562 /* Different signs -- subtract magnitudes, preserve sign of greater */
563 int cmp = s_ucmp(a, b); /* magnitude comparison, sign ignored */
565 /* Set x to max(a, b), y to min(a, b) to simplify later code.
566 A special case yields zero for equal magnitudes.
568 mp_int x, y;
569 if (cmp == 0) {
570 mp_int_zero(c);
571 return MP_OK;
572 } else if (cmp < 0) {
573 x = b;
574 y = a;
575 } else {
576 x = a;
577 y = b;
580 if (!s_pad(c, MP_USED(x))) return MP_MEMORY;
582 /* Subtract smaller from larger */
583 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
584 c->used = x->used;
585 CLAMP(c);
587 /* Give result the sign of the larger */
588 c->sign = x->sign;
591 return MP_OK;
594 mp_result mp_int_add_value(mp_int a, mp_small value, mp_int c) {
595 mpz_t vtmp;
596 mp_digit vbuf[MP_VALUE_DIGITS(value)];
598 s_fake(&vtmp, value, vbuf);
600 return mp_int_add(a, &vtmp, c);
603 mp_result mp_int_sub(mp_int a, mp_int b, mp_int c) {
604 assert(a != NULL && b != NULL && c != NULL);
606 mp_size ua = MP_USED(a);
607 mp_size ub = MP_USED(b);
608 mp_size max = MAX(ua, ub);
610 if (MP_SIGN(a) != MP_SIGN(b)) {
611 /* Different signs -- add magnitudes and keep sign of a */
612 if (!s_pad(c, max)) return MP_MEMORY;
614 mp_digit carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
615 mp_size uc = max;
617 if (carry) {
618 if (!s_pad(c, max + 1)) return MP_MEMORY;
620 c->digits[max] = carry;
621 ++uc;
624 c->used = uc;
625 c->sign = a->sign;
627 } else {
628 /* Same signs -- subtract magnitudes */
629 if (!s_pad(c, max)) return MP_MEMORY;
630 mp_int x, y;
631 mp_sign osign;
633 int cmp = s_ucmp(a, b);
634 if (cmp >= 0) {
635 x = a;
636 y = b;
637 osign = MP_ZPOS;
638 } else {
639 x = b;
640 y = a;
641 osign = MP_NEG;
644 if (MP_SIGN(a) == MP_NEG && cmp != 0) osign = 1 - osign;
646 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
647 c->used = x->used;
648 CLAMP(c);
650 c->sign = osign;
653 return MP_OK;
656 mp_result mp_int_sub_value(mp_int a, mp_small value, mp_int c) {
657 mpz_t vtmp;
658 mp_digit vbuf[MP_VALUE_DIGITS(value)];
660 s_fake(&vtmp, value, vbuf);
662 return mp_int_sub(a, &vtmp, c);
665 mp_result mp_int_mul(mp_int a, mp_int b, mp_int c) {
666 assert(a != NULL && b != NULL && c != NULL);
668 /* If either input is zero, we can shortcut multiplication */
669 if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0) {
670 mp_int_zero(c);
671 return MP_OK;
674 /* Output is positive if inputs have same sign, otherwise negative */
675 mp_sign osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
677 /* If the output is not identical to any of the inputs, we'll write the
678 results directly; otherwise, allocate a temporary space. */
679 mp_size ua = MP_USED(a);
680 mp_size ub = MP_USED(b);
681 mp_size osize = MAX(ua, ub);
682 osize = 4 * ((osize + 1) / 2);
684 mp_digit *out;
685 mp_size p = 0;
686 if (c == a || c == b) {
687 p = MAX(s_round_prec(osize), default_precision);
689 if ((out = s_alloc(p)) == NULL) return MP_MEMORY;
690 } else {
691 if (!s_pad(c, osize)) return MP_MEMORY;
693 out = MP_DIGITS(c);
695 ZERO(out, osize);
697 if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub)) return MP_MEMORY;
699 /* If we allocated a new buffer, get rid of whatever memory c was already
700 using, and fix up its fields to reflect that.
702 if (out != MP_DIGITS(c)) {
703 if ((void *)MP_DIGITS(c) != (void *)c) s_free(MP_DIGITS(c));
704 c->digits = out;
705 c->alloc = p;
708 c->used = osize; /* might not be true, but we'll fix it ... */
709 CLAMP(c); /* ... right here */
710 c->sign = osign;
712 return MP_OK;
715 mp_result mp_int_mul_value(mp_int a, mp_small value, mp_int c) {
716 mpz_t vtmp;
717 mp_digit vbuf[MP_VALUE_DIGITS(value)];
719 s_fake(&vtmp, value, vbuf);
721 return mp_int_mul(a, &vtmp, c);
724 mp_result mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c) {
725 assert(a != NULL && c != NULL && p2 >= 0);
727 mp_result res = mp_int_copy(a, c);
728 if (res != MP_OK) return res;
730 if (s_qmul(c, (mp_size)p2)) {
731 return MP_OK;
732 } else {
733 return MP_MEMORY;
737 mp_result mp_int_sqr(mp_int a, mp_int c) {
738 assert(a != NULL && c != NULL);
740 /* Get a temporary buffer big enough to hold the result */
741 mp_size osize = (mp_size)4 * ((MP_USED(a) + 1) / 2);
742 mp_size p = 0;
743 mp_digit *out;
744 if (a == c) {
745 p = s_round_prec(osize);
746 p = MAX(p, default_precision);
748 if ((out = s_alloc(p)) == NULL) return MP_MEMORY;
749 } else {
750 if (!s_pad(c, osize)) return MP_MEMORY;
752 out = MP_DIGITS(c);
754 ZERO(out, osize);
756 s_ksqr(MP_DIGITS(a), out, MP_USED(a));
758 /* Get rid of whatever memory c was already using, and fix up its fields to
759 reflect the new digit array it's using
761 if (out != MP_DIGITS(c)) {
762 if ((void *)MP_DIGITS(c) != (void *)c) s_free(MP_DIGITS(c));
763 c->digits = out;
764 c->alloc = p;
767 c->used = osize; /* might not be true, but we'll fix it ... */
768 CLAMP(c); /* ... right here */
769 c->sign = MP_ZPOS;
771 return MP_OK;
774 mp_result mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r) {
775 assert(a != NULL && b != NULL && q != r);
777 int cmp;
778 mp_result res = MP_OK;
779 mp_int qout, rout;
780 mp_sign sa = MP_SIGN(a);
781 mp_sign sb = MP_SIGN(b);
782 if (CMPZ(b) == 0) {
783 return MP_UNDEF;
784 } else if ((cmp = s_ucmp(a, b)) < 0) {
785 /* If |a| < |b|, no division is required:
786 q = 0, r = a
788 if (r && (res = mp_int_copy(a, r)) != MP_OK) return res;
790 if (q) mp_int_zero(q);
792 return MP_OK;
793 } else if (cmp == 0) {
794 /* If |a| = |b|, no division is required:
795 q = 1 or -1, r = 0
797 if (r) mp_int_zero(r);
799 if (q) {
800 mp_int_zero(q);
801 q->digits[0] = 1;
803 if (sa != sb) q->sign = MP_NEG;
806 return MP_OK;
809 /* When |a| > |b|, real division is required. We need someplace to store
810 quotient and remainder, but q and r are allowed to be NULL or to overlap
811 with the inputs.
813 DECLARE_TEMP(2);
814 int lg;
815 if ((lg = s_isp2(b)) < 0) {
816 if (q && b != q) {
817 REQUIRE(mp_int_copy(a, q));
818 qout = q;
819 } else {
820 REQUIRE(mp_int_copy(a, TEMP(0)));
821 qout = TEMP(0);
824 if (r && a != r) {
825 REQUIRE(mp_int_copy(b, r));
826 rout = r;
827 } else {
828 REQUIRE(mp_int_copy(b, TEMP(1)));
829 rout = TEMP(1);
832 REQUIRE(s_udiv_knuth(qout, rout));
833 } else {
834 if (q) REQUIRE(mp_int_copy(a, q));
835 if (r) REQUIRE(mp_int_copy(a, r));
837 if (q) s_qdiv(q, (mp_size)lg);
838 qout = q;
839 if (r) s_qmod(r, (mp_size)lg);
840 rout = r;
843 /* Recompute signs for output */
844 if (rout) {
845 rout->sign = sa;
846 if (CMPZ(rout) == 0) rout->sign = MP_ZPOS;
848 if (qout) {
849 qout->sign = (sa == sb) ? MP_ZPOS : MP_NEG;
850 if (CMPZ(qout) == 0) qout->sign = MP_ZPOS;
853 if (q) REQUIRE(mp_int_copy(qout, q));
854 if (r) REQUIRE(mp_int_copy(rout, r));
855 CLEANUP_TEMP();
856 return res;
859 mp_result mp_int_mod(mp_int a, mp_int m, mp_int c) {
860 DECLARE_TEMP(1);
861 mp_int out = (m == c) ? TEMP(0) : c;
862 REQUIRE(mp_int_div(a, m, NULL, out));
863 if (CMPZ(out) < 0) {
864 REQUIRE(mp_int_add(out, m, c));
865 } else {
866 REQUIRE(mp_int_copy(out, c));
868 CLEANUP_TEMP();
869 return MP_OK;
872 mp_result mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r) {
873 mpz_t vtmp;
874 mp_digit vbuf[MP_VALUE_DIGITS(value)];
875 s_fake(&vtmp, value, vbuf);
877 DECLARE_TEMP(1);
878 REQUIRE(mp_int_div(a, &vtmp, q, TEMP(0)));
880 if (r) (void)mp_int_to_int(TEMP(0), r); /* can't fail */
882 CLEANUP_TEMP();
883 return MP_OK;
886 mp_result mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r) {
887 assert(a != NULL && p2 >= 0 && q != r);
889 mp_result res = MP_OK;
890 if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK) {
891 s_qdiv(q, (mp_size)p2);
894 if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK) {
895 s_qmod(r, (mp_size)p2);
898 return res;
901 mp_result mp_int_expt(mp_int a, mp_small b, mp_int c) {
902 assert(c != NULL);
903 if (b < 0) return MP_RANGE;
905 DECLARE_TEMP(1);
906 REQUIRE(mp_int_copy(a, TEMP(0)));
908 (void)mp_int_set_value(c, 1);
909 unsigned int v = labs(b);
910 while (v != 0) {
911 if (v & 1) {
912 REQUIRE(mp_int_mul(c, TEMP(0), c));
915 v >>= 1;
916 if (v == 0) break;
918 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
921 CLEANUP_TEMP();
922 return MP_OK;
925 mp_result mp_int_expt_value(mp_small a, mp_small b, mp_int c) {
926 assert(c != NULL);
927 if (b < 0) return MP_RANGE;
929 DECLARE_TEMP(1);
930 REQUIRE(mp_int_set_value(TEMP(0), a));
932 (void)mp_int_set_value(c, 1);
933 unsigned int v = labs(b);
934 while (v != 0) {
935 if (v & 1) {
936 REQUIRE(mp_int_mul(c, TEMP(0), c));
939 v >>= 1;
940 if (v == 0) break;
942 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
945 CLEANUP_TEMP();
946 return MP_OK;
949 mp_result mp_int_expt_full(mp_int a, mp_int b, mp_int c) {
950 assert(a != NULL && b != NULL && c != NULL);
951 if (MP_SIGN(b) == MP_NEG) return MP_RANGE;
953 DECLARE_TEMP(1);
954 REQUIRE(mp_int_copy(a, TEMP(0)));
956 (void)mp_int_set_value(c, 1);
957 for (unsigned ix = 0; ix < MP_USED(b); ++ix) {
958 mp_digit d = b->digits[ix];
960 for (unsigned jx = 0; jx < MP_DIGIT_BIT; ++jx) {
961 if (d & 1) {
962 REQUIRE(mp_int_mul(c, TEMP(0), c));
965 d >>= 1;
966 if (d == 0 && ix + 1 == MP_USED(b)) break;
967 REQUIRE(mp_int_sqr(TEMP(0), TEMP(0)));
971 CLEANUP_TEMP();
972 return MP_OK;
975 int mp_int_compare(mp_int a, mp_int b) {
976 assert(a != NULL && b != NULL);
978 mp_sign sa = MP_SIGN(a);
979 if (sa == MP_SIGN(b)) {
980 int cmp = s_ucmp(a, b);
982 /* If they're both zero or positive, the normal comparison applies; if both
983 negative, the sense is reversed. */
984 if (sa == MP_ZPOS) {
985 return cmp;
986 } else {
987 return -cmp;
989 } else if (sa == MP_ZPOS) {
990 return 1;
991 } else {
992 return -1;
996 int mp_int_compare_unsigned(mp_int a, mp_int b) {
997 assert(a != NULL && b != NULL);
999 return s_ucmp(a, b);
1002 int mp_int_compare_zero(mp_int z) {
1003 assert(z != NULL);
1005 if (MP_USED(z) == 1 && z->digits[0] == 0) {
1006 return 0;
1007 } else if (MP_SIGN(z) == MP_ZPOS) {
1008 return 1;
1009 } else {
1010 return -1;
1014 int mp_int_compare_value(mp_int z, mp_small value) {
1015 assert(z != NULL);
1017 mp_sign vsign = (value < 0) ? MP_NEG : MP_ZPOS;
1018 if (vsign == MP_SIGN(z)) {
1019 int cmp = s_vcmp(z, value);
1021 return (vsign == MP_ZPOS) ? cmp : -cmp;
1022 } else {
1023 return (value < 0) ? 1 : -1;
1027 int mp_int_compare_uvalue(mp_int z, mp_usmall uv) {
1028 assert(z != NULL);
1030 if (MP_SIGN(z) == MP_NEG) {
1031 return -1;
1032 } else {
1033 return s_uvcmp(z, uv);
1037 mp_result mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c) {
1038 assert(a != NULL && b != NULL && c != NULL && m != NULL);
1040 /* Zero moduli and negative exponents are not considered. */
1041 if (CMPZ(m) == 0) return MP_UNDEF;
1042 if (CMPZ(b) < 0) return MP_RANGE;
1044 mp_size um = MP_USED(m);
1045 DECLARE_TEMP(3);
1046 REQUIRE(GROW(TEMP(0), 2 * um));
1047 REQUIRE(GROW(TEMP(1), 2 * um));
1049 mp_int s;
1050 if (c == b || c == m) {
1051 REQUIRE(GROW(TEMP(2), 2 * um));
1052 s = TEMP(2);
1053 } else {
1054 s = c;
1057 REQUIRE(mp_int_mod(a, m, TEMP(0)));
1058 REQUIRE(s_brmu(TEMP(1), m));
1059 REQUIRE(s_embar(TEMP(0), b, m, TEMP(1), s));
1060 REQUIRE(mp_int_copy(s, c));
1062 CLEANUP_TEMP();
1063 return MP_OK;
1066 mp_result mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c) {
1067 mpz_t vtmp;
1068 mp_digit vbuf[MP_VALUE_DIGITS(value)];
1070 s_fake(&vtmp, value, vbuf);
1072 return mp_int_exptmod(a, &vtmp, m, c);
1075 mp_result mp_int_exptmod_bvalue(mp_small value, mp_int b, mp_int m, mp_int c) {
1076 mpz_t vtmp;
1077 mp_digit vbuf[MP_VALUE_DIGITS(value)];
1079 s_fake(&vtmp, value, vbuf);
1081 return mp_int_exptmod(&vtmp, b, m, c);
1084 mp_result mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu,
1085 mp_int c) {
1086 assert(a && b && m && c);
1088 /* Zero moduli and negative exponents are not considered. */
1089 if (CMPZ(m) == 0) return MP_UNDEF;
1090 if (CMPZ(b) < 0) return MP_RANGE;
1092 DECLARE_TEMP(2);
1093 mp_size um = MP_USED(m);
1094 REQUIRE(GROW(TEMP(0), 2 * um));
1096 mp_int s;
1097 if (c == b || c == m) {
1098 REQUIRE(GROW(TEMP(1), 2 * um));
1099 s = TEMP(1);
1100 } else {
1101 s = c;
1104 REQUIRE(mp_int_mod(a, m, TEMP(0)));
1105 REQUIRE(s_embar(TEMP(0), b, m, mu, s));
1106 REQUIRE(mp_int_copy(s, c));
1108 CLEANUP_TEMP();
1109 return MP_OK;
1112 mp_result mp_int_redux_const(mp_int m, mp_int c) {
1113 assert(m != NULL && c != NULL && m != c);
1115 return s_brmu(c, m);
1118 mp_result mp_int_invmod(mp_int a, mp_int m, mp_int c) {
1119 assert(a != NULL && m != NULL && c != NULL);
1121 if (CMPZ(a) == 0 || CMPZ(m) <= 0) return MP_RANGE;
1123 DECLARE_TEMP(2);
1125 REQUIRE(mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL));
1127 if (mp_int_compare_value(TEMP(0), 1) != 0) {
1128 REQUIRE(MP_UNDEF);
1131 /* It is first necessary to constrain the value to the proper range */
1132 REQUIRE(mp_int_mod(TEMP(1), m, TEMP(1)));
1134 /* Now, if 'a' was originally negative, the value we have is actually the
1135 magnitude of the negative representative; to get the positive value we
1136 have to subtract from the modulus. Otherwise, the value is okay as it
1137 stands.
1139 if (MP_SIGN(a) == MP_NEG) {
1140 REQUIRE(mp_int_sub(m, TEMP(1), c));
1141 } else {
1142 REQUIRE(mp_int_copy(TEMP(1), c));
1145 CLEANUP_TEMP();
1146 return MP_OK;
1149 /* Binary GCD algorithm due to Josef Stein, 1961 */
1150 mp_result mp_int_gcd(mp_int a, mp_int b, mp_int c) {
1151 assert(a != NULL && b != NULL && c != NULL);
1153 int ca = CMPZ(a);
1154 int cb = CMPZ(b);
1155 if (ca == 0 && cb == 0) {
1156 return MP_UNDEF;
1157 } else if (ca == 0) {
1158 return mp_int_abs(b, c);
1159 } else if (cb == 0) {
1160 return mp_int_abs(a, c);
1163 DECLARE_TEMP(3);
1164 REQUIRE(mp_int_copy(a, TEMP(0)));
1165 REQUIRE(mp_int_copy(b, TEMP(1)));
1167 TEMP(0)->sign = MP_ZPOS;
1168 TEMP(1)->sign = MP_ZPOS;
1170 int k = 0;
1171 { /* Divide out common factors of 2 from u and v */
1172 int div2_u = s_dp2k(TEMP(0));
1173 int div2_v = s_dp2k(TEMP(1));
1175 k = MIN(div2_u, div2_v);
1176 s_qdiv(TEMP(0), (mp_size)k);
1177 s_qdiv(TEMP(1), (mp_size)k);
1180 if (mp_int_is_odd(TEMP(0))) {
1181 REQUIRE(mp_int_neg(TEMP(1), TEMP(2)));
1182 } else {
1183 REQUIRE(mp_int_copy(TEMP(0), TEMP(2)));
1186 for (;;) {
1187 s_qdiv(TEMP(2), s_dp2k(TEMP(2)));
1189 if (CMPZ(TEMP(2)) > 0) {
1190 REQUIRE(mp_int_copy(TEMP(2), TEMP(0)));
1191 } else {
1192 REQUIRE(mp_int_neg(TEMP(2), TEMP(1)));
1195 REQUIRE(mp_int_sub(TEMP(0), TEMP(1), TEMP(2)));
1197 if (CMPZ(TEMP(2)) == 0) break;
1200 REQUIRE(mp_int_abs(TEMP(0), c));
1201 if (!s_qmul(c, (mp_size)k)) REQUIRE(MP_MEMORY);
1203 CLEANUP_TEMP();
1204 return MP_OK;
1207 /* This is the binary GCD algorithm again, but this time we keep track of the
1208 elementary matrix operations as we go, so we can get values x and y
1209 satisfying c = ax + by.
1211 mp_result mp_int_egcd(mp_int a, mp_int b, mp_int c, mp_int x, mp_int y) {
1212 assert(a != NULL && b != NULL && c != NULL && (x != NULL || y != NULL));
1214 mp_result res = MP_OK;
1215 int ca = CMPZ(a);
1216 int cb = CMPZ(b);
1217 if (ca == 0 && cb == 0) {
1218 return MP_UNDEF;
1219 } else if (ca == 0) {
1220 if ((res = mp_int_abs(b, c)) != MP_OK) return res;
1221 mp_int_zero(x);
1222 (void)mp_int_set_value(y, 1);
1223 return MP_OK;
1224 } else if (cb == 0) {
1225 if ((res = mp_int_abs(a, c)) != MP_OK) return res;
1226 (void)mp_int_set_value(x, 1);
1227 mp_int_zero(y);
1228 return MP_OK;
1231 /* Initialize temporaries:
1232 A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7 */
1233 DECLARE_TEMP(8);
1234 REQUIRE(mp_int_set_value(TEMP(0), 1));
1235 REQUIRE(mp_int_set_value(TEMP(3), 1));
1236 REQUIRE(mp_int_copy(a, TEMP(4)));
1237 REQUIRE(mp_int_copy(b, TEMP(5)));
1239 /* We will work with absolute values here */
1240 TEMP(4)->sign = MP_ZPOS;
1241 TEMP(5)->sign = MP_ZPOS;
1243 int k = 0;
1244 { /* Divide out common factors of 2 from u and v */
1245 int div2_u = s_dp2k(TEMP(4)), div2_v = s_dp2k(TEMP(5));
1247 k = MIN(div2_u, div2_v);
1248 s_qdiv(TEMP(4), k);
1249 s_qdiv(TEMP(5), k);
1252 REQUIRE(mp_int_copy(TEMP(4), TEMP(6)));
1253 REQUIRE(mp_int_copy(TEMP(5), TEMP(7)));
1255 for (;;) {
1256 while (mp_int_is_even(TEMP(4))) {
1257 s_qdiv(TEMP(4), 1);
1259 if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1))) {
1260 REQUIRE(mp_int_add(TEMP(0), TEMP(7), TEMP(0)));
1261 REQUIRE(mp_int_sub(TEMP(1), TEMP(6), TEMP(1)));
1264 s_qdiv(TEMP(0), 1);
1265 s_qdiv(TEMP(1), 1);
1268 while (mp_int_is_even(TEMP(5))) {
1269 s_qdiv(TEMP(5), 1);
1271 if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3))) {
1272 REQUIRE(mp_int_add(TEMP(2), TEMP(7), TEMP(2)));
1273 REQUIRE(mp_int_sub(TEMP(3), TEMP(6), TEMP(3)));
1276 s_qdiv(TEMP(2), 1);
1277 s_qdiv(TEMP(3), 1);
1280 if (mp_int_compare(TEMP(4), TEMP(5)) >= 0) {
1281 REQUIRE(mp_int_sub(TEMP(4), TEMP(5), TEMP(4)));
1282 REQUIRE(mp_int_sub(TEMP(0), TEMP(2), TEMP(0)));
1283 REQUIRE(mp_int_sub(TEMP(1), TEMP(3), TEMP(1)));
1284 } else {
1285 REQUIRE(mp_int_sub(TEMP(5), TEMP(4), TEMP(5)));
1286 REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1287 REQUIRE(mp_int_sub(TEMP(3), TEMP(1), TEMP(3)));
1290 if (CMPZ(TEMP(4)) == 0) {
1291 if (x) REQUIRE(mp_int_copy(TEMP(2), x));
1292 if (y) REQUIRE(mp_int_copy(TEMP(3), y));
1293 if (c) {
1294 if (!s_qmul(TEMP(5), k)) {
1295 REQUIRE(MP_MEMORY);
1297 REQUIRE(mp_int_copy(TEMP(5), c));
1300 break;
1304 CLEANUP_TEMP();
1305 return MP_OK;
1308 mp_result mp_int_lcm(mp_int a, mp_int b, mp_int c) {
1309 assert(a != NULL && b != NULL && c != NULL);
1311 /* Since a * b = gcd(a, b) * lcm(a, b), we can compute
1312 lcm(a, b) = (a / gcd(a, b)) * b.
1314 This formulation insures everything works even if the input
1315 variables share space.
1317 DECLARE_TEMP(1);
1318 REQUIRE(mp_int_gcd(a, b, TEMP(0)));
1319 REQUIRE(mp_int_div(a, TEMP(0), TEMP(0), NULL));
1320 REQUIRE(mp_int_mul(TEMP(0), b, TEMP(0)));
1321 REQUIRE(mp_int_copy(TEMP(0), c));
1323 CLEANUP_TEMP();
1324 return MP_OK;
1327 bool mp_int_divisible_value(mp_int a, mp_small v) {
1328 mp_small rem = 0;
1330 if (mp_int_div_value(a, v, NULL, &rem) != MP_OK) {
1331 return false;
1333 return rem == 0;
1336 int mp_int_is_pow2(mp_int z) {
1337 assert(z != NULL);
1339 return s_isp2(z);
1342 /* Implementation of Newton's root finding method, based loosely on a patch
1343 contributed by Hal Finkel <half@halssoftware.com>
1344 modified by M. J. Fromberger.
1346 mp_result mp_int_root(mp_int a, mp_small b, mp_int c) {
1347 assert(a != NULL && c != NULL && b > 0);
1349 if (b == 1) {
1350 return mp_int_copy(a, c);
1352 bool flips = false;
1353 if (MP_SIGN(a) == MP_NEG) {
1354 if (b % 2 == 0) {
1355 return MP_UNDEF; /* root does not exist for negative a with even b */
1356 } else {
1357 flips = true;
1361 DECLARE_TEMP(5);
1362 REQUIRE(mp_int_copy(a, TEMP(0)));
1363 REQUIRE(mp_int_copy(a, TEMP(1)));
1364 TEMP(0)->sign = MP_ZPOS;
1365 TEMP(1)->sign = MP_ZPOS;
1367 for (;;) {
1368 REQUIRE(mp_int_expt(TEMP(1), b, TEMP(2)));
1370 if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0) break;
1372 REQUIRE(mp_int_sub(TEMP(2), TEMP(0), TEMP(2)));
1373 REQUIRE(mp_int_expt(TEMP(1), b - 1, TEMP(3)));
1374 REQUIRE(mp_int_mul_value(TEMP(3), b, TEMP(3)));
1375 REQUIRE(mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL));
1376 REQUIRE(mp_int_sub(TEMP(1), TEMP(4), TEMP(4)));
1378 if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0) {
1379 REQUIRE(mp_int_sub_value(TEMP(4), 1, TEMP(4)));
1381 REQUIRE(mp_int_copy(TEMP(4), TEMP(1)));
1384 REQUIRE(mp_int_copy(TEMP(1), c));
1386 /* If the original value of a was negative, flip the output sign. */
1387 if (flips) (void)mp_int_neg(c, c); /* cannot fail */
1389 CLEANUP_TEMP();
1390 return MP_OK;
1393 mp_result mp_int_to_int(mp_int z, mp_small *out) {
1394 assert(z != NULL);
1396 /* Make sure the value is representable as a small integer */
1397 mp_sign sz = MP_SIGN(z);
1398 if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) ||
1399 mp_int_compare_value(z, MP_SMALL_MIN) < 0) {
1400 return MP_RANGE;
1403 mp_usmall uz = MP_USED(z);
1404 mp_digit *dz = MP_DIGITS(z) + uz - 1;
1405 mp_small uv = 0;
1406 while (uz > 0) {
1407 uv <<= MP_DIGIT_BIT / 2;
1408 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1409 --uz;
1412 if (out) *out = (mp_small)((sz == MP_NEG) ? -uv : uv);
1414 return MP_OK;
1417 mp_result mp_int_to_uint(mp_int z, mp_usmall *out) {
1418 assert(z != NULL);
1420 /* Make sure the value is representable as an unsigned small integer */
1421 mp_size sz = MP_SIGN(z);
1422 if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0) {
1423 return MP_RANGE;
1426 mp_size uz = MP_USED(z);
1427 mp_digit *dz = MP_DIGITS(z) + uz - 1;
1428 mp_usmall uv = 0;
1430 while (uz > 0) {
1431 uv <<= MP_DIGIT_BIT / 2;
1432 uv = (uv << (MP_DIGIT_BIT / 2)) | *dz--;
1433 --uz;
1436 if (out) *out = uv;
1438 return MP_OK;
1441 mp_result mp_int_to_string(mp_int z, mp_size radix, char *str, int limit) {
1442 assert(z != NULL && str != NULL && limit >= 2);
1443 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1445 int cmp = 0;
1446 if (CMPZ(z) == 0) {
1447 *str++ = s_val2ch(0, 1);
1448 } else {
1449 mp_result res;
1450 mpz_t tmp;
1451 char *h, *t;
1453 if ((res = mp_int_init_copy(&tmp, z)) != MP_OK) return res;
1455 if (MP_SIGN(z) == MP_NEG) {
1456 *str++ = '-';
1457 --limit;
1459 h = str;
1461 /* Generate digits in reverse order until finished or limit reached */
1462 for (/* */; limit > 0; --limit) {
1463 mp_digit d;
1465 if ((cmp = CMPZ(&tmp)) == 0) break;
1467 d = s_ddiv(&tmp, (mp_digit)radix);
1468 *str++ = s_val2ch(d, 1);
1470 t = str - 1;
1472 /* Put digits back in correct output order */
1473 while (h < t) {
1474 char tc = *h;
1475 *h++ = *t;
1476 *t-- = tc;
1479 mp_int_clear(&tmp);
1482 *str = '\0';
1483 if (cmp == 0) {
1484 return MP_OK;
1485 } else {
1486 return MP_TRUNC;
1490 mp_result mp_int_string_len(mp_int z, mp_size radix) {
1491 assert(z != NULL);
1492 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1494 int len = s_outlen(z, radix) + 1; /* for terminator */
1496 /* Allow for sign marker on negatives */
1497 if (MP_SIGN(z) == MP_NEG) len += 1;
1499 return len;
1502 /* Read zero-terminated string into z */
1503 mp_result mp_int_read_string(mp_int z, mp_size radix, const char *str) {
1504 return mp_int_read_cstring(z, radix, str, NULL);
1507 mp_result mp_int_read_cstring(mp_int z, mp_size radix, const char *str,
1508 char **end) {
1509 assert(z != NULL && str != NULL);
1510 assert(radix >= MP_MIN_RADIX && radix <= MP_MAX_RADIX);
1512 /* Skip leading whitespace */
1513 while (isspace((unsigned char)*str)) ++str;
1515 /* Handle leading sign tag (+/-, positive default) */
1516 switch (*str) {
1517 case '-':
1518 z->sign = MP_NEG;
1519 ++str;
1520 break;
1521 case '+':
1522 ++str; /* fallthrough */
1523 default:
1524 z->sign = MP_ZPOS;
1525 break;
1528 /* Skip leading zeroes */
1529 int ch;
1530 while ((ch = s_ch2val(*str, radix)) == 0) ++str;
1532 /* Make sure there is enough space for the value */
1533 if (!s_pad(z, s_inlen(strlen(str), radix))) return MP_MEMORY;
1535 z->used = 1;
1536 z->digits[0] = 0;
1538 while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0)) {
1539 s_dmul(z, (mp_digit)radix);
1540 s_dadd(z, (mp_digit)ch);
1541 ++str;
1544 CLAMP(z);
1546 /* Override sign for zero, even if negative specified. */
1547 if (CMPZ(z) == 0) z->sign = MP_ZPOS;
1549 if (end != NULL) *end = (char *)str;
1551 /* Return a truncation error if the string has unprocessed characters
1552 remaining, so the caller can tell if the whole string was done */
1553 if (*str != '\0') {
1554 return MP_TRUNC;
1555 } else {
1556 return MP_OK;
1560 mp_result mp_int_count_bits(mp_int z) {
1561 assert(z != NULL);
1563 mp_size uz = MP_USED(z);
1564 if (uz == 1 && z->digits[0] == 0) return 1;
1566 --uz;
1567 mp_size nbits = uz * MP_DIGIT_BIT;
1568 mp_digit d = z->digits[uz];
1570 while (d != 0) {
1571 d >>= 1;
1572 ++nbits;
1575 return nbits;
1578 mp_result mp_int_to_binary(mp_int z, unsigned char *buf, int limit) {
1579 static const int PAD_FOR_2C = 1;
1581 assert(z != NULL && buf != NULL);
1583 int limpos = limit;
1584 mp_result res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
1586 if (MP_SIGN(z) == MP_NEG) s_2comp(buf, limpos);
1588 return res;
1591 mp_result mp_int_read_binary(mp_int z, unsigned char *buf, int len) {
1592 assert(z != NULL && buf != NULL && len > 0);
1594 /* Figure out how many digits are needed to represent this value */
1595 mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1596 if (!s_pad(z, need)) return MP_MEMORY;
1598 mp_int_zero(z);
1600 /* If the high-order bit is set, take the 2's complement before reading the
1601 value (it will be restored afterward) */
1602 if (buf[0] >> (CHAR_BIT - 1)) {
1603 z->sign = MP_NEG;
1604 s_2comp(buf, len);
1607 mp_digit *dz = MP_DIGITS(z);
1608 unsigned char *tmp = buf;
1609 for (int i = len; i > 0; --i, ++tmp) {
1610 s_qmul(z, (mp_size)CHAR_BIT);
1611 *dz |= *tmp;
1614 /* Restore 2's complement if we took it before */
1615 if (MP_SIGN(z) == MP_NEG) s_2comp(buf, len);
1617 return MP_OK;
1620 mp_result mp_int_binary_len(mp_int z) {
1621 mp_result res = mp_int_count_bits(z);
1622 if (res <= 0) return res;
1624 int bytes = mp_int_unsigned_len(z);
1626 /* If the highest-order bit falls exactly on a byte boundary, we need to pad
1627 with an extra byte so that the sign will be read correctly when reading it
1628 back in. */
1629 if (bytes * CHAR_BIT == res) ++bytes;
1631 return bytes;
1634 mp_result mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit) {
1635 static const int NO_PADDING = 0;
1637 assert(z != NULL && buf != NULL);
1639 return s_tobin(z, buf, &limit, NO_PADDING);
1642 mp_result mp_int_read_unsigned(mp_int z, unsigned char *buf, int len) {
1643 assert(z != NULL && buf != NULL && len > 0);
1645 /* Figure out how many digits are needed to represent this value */
1646 mp_size need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1647 if (!s_pad(z, need)) return MP_MEMORY;
1649 mp_int_zero(z);
1651 unsigned char *tmp = buf;
1652 for (int i = len; i > 0; --i, ++tmp) {
1653 (void)s_qmul(z, CHAR_BIT);
1654 *MP_DIGITS(z) |= *tmp;
1657 return MP_OK;
1660 mp_result mp_int_unsigned_len(mp_int z) {
1661 mp_result res = mp_int_count_bits(z);
1662 if (res <= 0) return res;
1664 int bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
1665 return bytes;
1668 const char *mp_error_string(mp_result res) {
1669 if (res > 0) return s_unknown_err;
1671 res = -res;
1672 int ix;
1673 for (ix = 0; ix < res && s_error_msg[ix] != NULL; ++ix)
1676 if (s_error_msg[ix] != NULL) {
1677 return s_error_msg[ix];
1678 } else {
1679 return s_unknown_err;
1683 /*------------------------------------------------------------------------*/
1684 /* Private functions for internal use. These make assumptions. */
1686 #if DEBUG
1687 static const mp_digit fill = (mp_digit)0xdeadbeefabad1dea;
1688 #endif
1690 static mp_digit *s_alloc(mp_size num) {
1691 mp_digit *out = malloc(num * sizeof(mp_digit));
1692 assert(out != NULL);
1694 #if DEBUG
1695 for (mp_size ix = 0; ix < num; ++ix) out[ix] = fill;
1696 #endif
1697 return out;
1700 static mp_digit *s_realloc(mp_digit *old, mp_size osize, mp_size nsize) {
1701 #if DEBUG
1702 mp_digit *new = s_alloc(nsize);
1703 assert(new != NULL);
1705 for (mp_size ix = 0; ix < nsize; ++ix) new[ix] = fill;
1706 memcpy(new, old, osize * sizeof(mp_digit));
1707 #else
1708 mp_digit *new = realloc(old, nsize * sizeof(mp_digit));
1709 assert(new != NULL);
1710 #endif
1712 return new;
/* Releases a buffer obtained from s_alloc or s_realloc (NULL is allowed). */
static void s_free(void *ptr) {
  free(ptr);
}
1717 static bool s_pad(mp_int z, mp_size min) {
1718 if (MP_ALLOC(z) < min) {
1719 mp_size nsize = s_round_prec(min);
1720 mp_digit *tmp;
1722 if (z->digits == &(z->single)) {
1723 if ((tmp = s_alloc(nsize)) == NULL) return false;
1724 tmp[0] = z->single;
1725 } else if ((tmp = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize)) == NULL) {
1726 return false;
1729 z->digits = tmp;
1730 z->alloc = nsize;
1733 return true;
1736 /* Note: This will not work correctly when value == MP_SMALL_MIN */
1737 static void s_fake(mp_int z, mp_small value, mp_digit vbuf[]) {
1738 mp_usmall uv = (mp_usmall)(value < 0) ? -value : value;
1739 s_ufake(z, uv, vbuf);
1740 if (value < 0) z->sign = MP_NEG;
1743 static void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]) {
1744 mp_size ndig = (mp_size)s_uvpack(value, vbuf);
1746 z->used = ndig;
1747 z->alloc = MP_VALUE_DIGITS(value);
1748 z->sign = MP_ZPOS;
1749 z->digits = vbuf;
1752 static int s_cdig(mp_digit *da, mp_digit *db, mp_size len) {
1753 mp_digit *dat = da + len - 1, *dbt = db + len - 1;
1755 for (/* */; len != 0; --len, --dat, --dbt) {
1756 if (*dat > *dbt) {
1757 return 1;
1758 } else if (*dat < *dbt) {
1759 return -1;
1763 return 0;
1766 static int s_uvpack(mp_usmall uv, mp_digit t[]) {
1767 int ndig = 0;
1769 if (uv == 0)
1770 t[ndig++] = 0;
1771 else {
1772 while (uv != 0) {
1773 t[ndig++] = (mp_digit)uv;
1774 uv >>= MP_DIGIT_BIT / 2;
1775 uv >>= MP_DIGIT_BIT / 2;
1779 return ndig;
1782 static int s_ucmp(mp_int a, mp_int b) {
1783 mp_size ua = MP_USED(a), ub = MP_USED(b);
1785 if (ua > ub) {
1786 return 1;
1787 } else if (ub > ua) {
1788 return -1;
1789 } else {
1790 return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua);
1794 static int s_vcmp(mp_int a, mp_small v) {
1795 mp_usmall uv = (v < 0) ? -(mp_usmall)v : (mp_usmall)v;
1796 return s_uvcmp(a, uv);
1799 static int s_uvcmp(mp_int a, mp_usmall uv) {
1800 mpz_t vtmp;
1801 mp_digit vdig[MP_VALUE_DIGITS(uv)];
1803 s_ufake(&vtmp, uv, vdig);
1804 return s_ucmp(a, &vtmp);
1807 static mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
1808 mp_size size_b) {
1809 mp_size pos;
1810 mp_word w = 0;
1812 /* Insure that da is the longer of the two to simplify later code */
1813 if (size_b > size_a) {
1814 SWAP(mp_digit *, da, db);
1815 SWAP(mp_size, size_a, size_b);
1818 /* Add corresponding digits until the shorter number runs out */
1819 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) {
1820 w = w + (mp_word)*da + (mp_word)*db;
1821 *dc = LOWER_HALF(w);
1822 w = UPPER_HALF(w);
1825 /* Propagate carries as far as necessary */
1826 for (/* */; pos < size_a; ++pos, ++da, ++dc) {
1827 w = w + *da;
1829 *dc = LOWER_HALF(w);
1830 w = UPPER_HALF(w);
1833 /* Return carry out */
1834 return (mp_digit)w;
1837 static void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
1838 mp_size size_b) {
1839 mp_size pos;
1840 mp_word w = 0;
1842 /* We assume that |a| >= |b| so this should definitely hold */
1843 assert(size_a >= size_b);
1845 /* Subtract corresponding digits and propagate borrow */
1846 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) {
1847 w = ((mp_word)MP_DIGIT_MAX + 1 + /* MP_RADIX */
1848 (mp_word)*da) -
1849 w - (mp_word)*db;
1851 *dc = LOWER_HALF(w);
1852 w = (UPPER_HALF(w) == 0);
1855 /* Finish the subtraction for remaining upper digits of da */
1856 for (/* */; pos < size_a; ++pos, ++da, ++dc) {
1857 w = ((mp_word)MP_DIGIT_MAX + 1 + /* MP_RADIX */
1858 (mp_word)*da) -
1861 *dc = LOWER_HALF(w);
1862 w = (UPPER_HALF(w) == 0);
1865 /* If there is a borrow out at the end, it violates the precondition */
1866 assert(w == 0);
1869 static int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
1870 mp_size size_b) {
1871 mp_size bot_size;
1873 /* Make sure b is the smaller of the two input values */
1874 if (size_b > size_a) {
1875 SWAP(mp_digit *, da, db);
1876 SWAP(mp_size, size_a, size_b);
1879 /* Insure that the bottom is the larger half in an odd-length split; the code
1880 below relies on this being true.
1882 bot_size = (size_a + 1) / 2;
1884 /* If the values are big enough to bother with recursion, use the Karatsuba
1885 algorithm to compute the product; otherwise use the normal multiplication
1886 algorithm
1888 if (multiply_threshold && size_a >= multiply_threshold && size_b > bot_size) {
1889 mp_digit *t1, *t2, *t3, carry;
1891 mp_digit *a_top = da + bot_size;
1892 mp_digit *b_top = db + bot_size;
1894 mp_size at_size = size_a - bot_size;
1895 mp_size bt_size = size_b - bot_size;
1896 mp_size buf_size = 2 * bot_size;
1898 /* Do a single allocation for all three temporary buffers needed; each
1899 buffer must be big enough to hold the product of two bottom halves, and
1900 one buffer needs space for the completed product; twice the space is
1901 plenty.
1903 if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
1904 t2 = t1 + buf_size;
1905 t3 = t2 + buf_size;
1906 ZERO(t1, 4 * buf_size);
1908 /* t1 and t2 are initially used as temporaries to compute the inner product
1909 (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
1911 carry = s_uadd(da, a_top, t1, bot_size, at_size); /* t1 = a1 + a0 */
1912 t1[bot_size] = carry;
1914 carry = s_uadd(db, b_top, t2, bot_size, bt_size); /* t2 = b1 + b0 */
1915 t2[bot_size] = carry;
1917 (void)s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1); /* t3 = t1 * t2 */
1919 /* Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so that
1920 we're left with only the pieces we want: t3 = a1b0 + a0b1
1922 ZERO(t1, buf_size);
1923 ZERO(t2, buf_size);
1924 (void)s_kmul(da, db, t1, bot_size, bot_size); /* t1 = a0 * b0 */
1925 (void)s_kmul(a_top, b_top, t2, at_size, bt_size); /* t2 = a1 * b1 */
1927 /* Subtract out t1 and t2 to get the inner product */
1928 s_usub(t3, t1, t3, buf_size + 2, buf_size);
1929 s_usub(t3, t2, t3, buf_size + 2, buf_size);
1931 /* Assemble the output value */
1932 COPY(t1, dc, buf_size);
1933 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
1934 assert(carry == 0);
1936 carry =
1937 s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
1938 assert(carry == 0);
1940 s_free(t1); /* note t2 and t3 are just internal pointers to t1 */
1941 } else {
1942 s_umul(da, db, dc, size_a, size_b);
1945 return 1;
1948 static void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc, mp_size size_a,
1949 mp_size size_b) {
1950 mp_size a, b;
1951 mp_word w;
1953 for (a = 0; a < size_a; ++a, ++dc, ++da) {
1954 mp_digit *dct = dc;
1955 mp_digit *dbt = db;
1957 if (*da == 0) continue;
1959 w = 0;
1960 for (b = 0; b < size_b; ++b, ++dbt, ++dct) {
1961 w = (mp_word)*da * (mp_word)*dbt + w + (mp_word)*dct;
1963 *dct = LOWER_HALF(w);
1964 w = UPPER_HALF(w);
1967 *dct = (mp_digit)w;
1971 static int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a) {
1972 if (multiply_threshold && size_a > multiply_threshold) {
1973 mp_size bot_size = (size_a + 1) / 2;
1974 mp_digit *a_top = da + bot_size;
1975 mp_digit *t1, *t2, *t3, carry;
1976 mp_size at_size = size_a - bot_size;
1977 mp_size buf_size = 2 * bot_size;
1979 if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
1980 t2 = t1 + buf_size;
1981 t3 = t2 + buf_size;
1982 ZERO(t1, 4 * buf_size);
1984 (void)s_ksqr(da, t1, bot_size); /* t1 = a0 ^ 2 */
1985 (void)s_ksqr(a_top, t2, at_size); /* t2 = a1 ^ 2 */
1987 (void)s_kmul(da, a_top, t3, bot_size, at_size); /* t3 = a0 * a1 */
1989 /* Quick multiply t3 by 2, shifting left (can't overflow) */
1991 int i, top = bot_size + at_size;
1992 mp_word w, save = 0;
1994 for (i = 0; i < top; ++i) {
1995 w = t3[i];
1996 w = (w << 1) | save;
1997 t3[i] = LOWER_HALF(w);
1998 save = UPPER_HALF(w);
2000 t3[i] = LOWER_HALF(save);
2003 /* Assemble the output value */
2004 COPY(t1, dc, 2 * bot_size);
2005 carry = s_uadd(t3, dc + bot_size, dc + bot_size, buf_size + 1, buf_size);
2006 assert(carry == 0);
2008 carry =
2009 s_uadd(t2, dc + 2 * bot_size, dc + 2 * bot_size, buf_size, buf_size);
2010 assert(carry == 0);
2012 s_free(t1); /* note that t2 and t2 are internal pointers only */
2014 } else {
2015 s_usqr(da, dc, size_a);
2018 return 1;
2021 static void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a) {
2022 mp_size i, j;
2023 mp_word w;
2025 for (i = 0; i < size_a; ++i, dc += 2, ++da) {
2026 mp_digit *dct = dc, *dat = da;
2028 if (*da == 0) continue;
2030 /* Take care of the first digit, no rollover */
2031 w = (mp_word)*dat * (mp_word)*dat + (mp_word)*dct;
2032 *dct = LOWER_HALF(w);
2033 w = UPPER_HALF(w);
2034 ++dat;
2035 ++dct;
2037 for (j = i + 1; j < size_a; ++j, ++dat, ++dct) {
2038 mp_word t = (mp_word)*da * (mp_word)*dat;
2039 mp_word u = w + (mp_word)*dct, ov = 0;
2041 /* Check if doubling t will overflow a word */
2042 if (HIGH_BIT_SET(t)) ov = 1;
2044 w = t + t;
2046 /* Check if adding u to w will overflow a word */
2047 if (ADD_WILL_OVERFLOW(w, u)) ov = 1;
2049 w += u;
2051 *dct = LOWER_HALF(w);
2052 w = UPPER_HALF(w);
2053 if (ov) {
2054 w += MP_DIGIT_MAX; /* MP_RADIX */
2055 ++w;
2059 w = w + *dct;
2060 *dct = (mp_digit)w;
2061 while ((w = UPPER_HALF(w)) != 0) {
2062 ++dct;
2063 w = w + *dct;
2064 *dct = LOWER_HALF(w);
2067 assert(w == 0);
2071 static void s_dadd(mp_int a, mp_digit b) {
2072 mp_word w = 0;
2073 mp_digit *da = MP_DIGITS(a);
2074 mp_size ua = MP_USED(a);
2076 w = (mp_word)*da + b;
2077 *da++ = LOWER_HALF(w);
2078 w = UPPER_HALF(w);
2080 for (ua -= 1; ua > 0; --ua, ++da) {
2081 w = (mp_word)*da + w;
2083 *da = LOWER_HALF(w);
2084 w = UPPER_HALF(w);
2087 if (w) {
2088 *da = (mp_digit)w;
2089 a->used += 1;
2093 static void s_dmul(mp_int a, mp_digit b) {
2094 mp_word w = 0;
2095 mp_digit *da = MP_DIGITS(a);
2096 mp_size ua = MP_USED(a);
2098 while (ua > 0) {
2099 w = (mp_word)*da * b + w;
2100 *da++ = LOWER_HALF(w);
2101 w = UPPER_HALF(w);
2102 --ua;
2105 if (w) {
2106 *da = (mp_digit)w;
2107 a->used += 1;
2111 static void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a) {
2112 mp_word w = 0;
2114 while (size_a > 0) {
2115 w = (mp_word)*da++ * (mp_word)b + w;
2117 *dc++ = LOWER_HALF(w);
2118 w = UPPER_HALF(w);
2119 --size_a;
2122 if (w) *dc = LOWER_HALF(w);
2125 static mp_digit s_ddiv(mp_int a, mp_digit b) {
2126 mp_word w = 0, qdigit;
2127 mp_size ua = MP_USED(a);
2128 mp_digit *da = MP_DIGITS(a) + ua - 1;
2130 for (/* */; ua > 0; --ua, --da) {
2131 w = (w << MP_DIGIT_BIT) | *da;
2133 if (w >= b) {
2134 qdigit = w / b;
2135 w = w % b;
2136 } else {
2137 qdigit = 0;
2140 *da = (mp_digit)qdigit;
2143 CLAMP(a);
2144 return (mp_digit)w;
2147 static void s_qdiv(mp_int z, mp_size p2) {
2148 mp_size ndig = p2 / MP_DIGIT_BIT, nbits = p2 % MP_DIGIT_BIT;
2149 mp_size uz = MP_USED(z);
2151 if (ndig) {
2152 mp_size mark;
2153 mp_digit *to, *from;
2155 if (ndig >= uz) {
2156 mp_int_zero(z);
2157 return;
2160 to = MP_DIGITS(z);
2161 from = to + ndig;
2163 for (mark = ndig; mark < uz; ++mark) {
2164 *to++ = *from++;
2167 z->used = uz - ndig;
2170 if (nbits) {
2171 mp_digit d = 0, *dz, save;
2172 mp_size up = MP_DIGIT_BIT - nbits;
2174 uz = MP_USED(z);
2175 dz = MP_DIGITS(z) + uz - 1;
2177 for (/* */; uz > 0; --uz, --dz) {
2178 save = *dz;
2180 *dz = (*dz >> nbits) | (d << up);
2181 d = save;
2184 CLAMP(z);
2187 if (MP_USED(z) == 1 && z->digits[0] == 0) z->sign = MP_ZPOS;
2190 static void s_qmod(mp_int z, mp_size p2) {
2191 mp_size start = p2 / MP_DIGIT_BIT + 1, rest = p2 % MP_DIGIT_BIT;
2192 mp_size uz = MP_USED(z);
2193 mp_digit mask = (1u << rest) - 1;
2195 if (start <= uz) {
2196 z->used = start;
2197 z->digits[start - 1] &= mask;
2198 CLAMP(z);
2202 static int s_qmul(mp_int z, mp_size p2) {
2203 mp_size uz, need, rest, extra, i;
2204 mp_digit *from, *to, d;
2206 if (p2 == 0) return 1;
2208 uz = MP_USED(z);
2209 need = p2 / MP_DIGIT_BIT;
2210 rest = p2 % MP_DIGIT_BIT;
2212 /* Figure out if we need an extra digit at the top end; this occurs if the
2213 topmost `rest' bits of the high-order digit of z are not zero, meaning
2214 they will be shifted off the end if not preserved */
2215 extra = 0;
2216 if (rest != 0) {
2217 mp_digit *dz = MP_DIGITS(z) + uz - 1;
2219 if ((*dz >> (MP_DIGIT_BIT - rest)) != 0) extra = 1;
2222 if (!s_pad(z, uz + need + extra)) return 0;
2224 /* If we need to shift by whole digits, do that in one pass, then
2225 to back and shift by partial digits.
2227 if (need > 0) {
2228 from = MP_DIGITS(z) + uz - 1;
2229 to = from + need;
2231 for (i = 0; i < uz; ++i) *to-- = *from--;
2233 ZERO(MP_DIGITS(z), need);
2234 uz += need;
2237 if (rest) {
2238 d = 0;
2239 for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from) {
2240 mp_digit save = *from;
2242 *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest));
2243 d = save;
2246 d >>= (MP_DIGIT_BIT - rest);
2247 if (d != 0) {
2248 *from = d;
2249 uz += extra;
2253 z->used = uz;
2254 CLAMP(z);
2256 return 1;
2259 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z|
2260 The sign of the result is always zero/positive.
2262 static int s_qsub(mp_int z, mp_size p2) {
2263 mp_digit hi = (1u << (p2 % MP_DIGIT_BIT)), *zp;
2264 mp_size tdig = (p2 / MP_DIGIT_BIT), pos;
2265 mp_word w = 0;
2267 if (!s_pad(z, tdig + 1)) return 0;
2269 for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp) {
2270 w = ((mp_word)MP_DIGIT_MAX + 1) - w - (mp_word)*zp;
2272 *zp = LOWER_HALF(w);
2273 w = UPPER_HALF(w) ? 0 : 1;
2276 w = ((mp_word)MP_DIGIT_MAX + 1 + hi) - w - (mp_word)*zp;
2277 *zp = LOWER_HALF(w);
2279 assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
2281 z->sign = MP_ZPOS;
2282 CLAMP(z);
2284 return 1;
2287 static int s_dp2k(mp_int z) {
2288 int k = 0;
2289 mp_digit *dp = MP_DIGITS(z), d;
2291 if (MP_USED(z) == 1 && *dp == 0) return 1;
2293 while (*dp == 0) {
2294 k += MP_DIGIT_BIT;
2295 ++dp;
2298 d = *dp;
2299 while ((d & 1) == 0) {
2300 d >>= 1;
2301 ++k;
2304 return k;
2307 static int s_isp2(mp_int z) {
2308 mp_size uz = MP_USED(z), k = 0;
2309 mp_digit *dz = MP_DIGITS(z), d;
2311 while (uz > 1) {
2312 if (*dz++ != 0) return -1;
2313 k += MP_DIGIT_BIT;
2314 --uz;
2317 d = *dz;
2318 while (d > 1) {
2319 if (d & 1) return -1;
2320 ++k;
2321 d >>= 1;
2324 return (int)k;
2327 static int s_2expt(mp_int z, mp_small k) {
2328 mp_size ndig, rest;
2329 mp_digit *dz;
2331 ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
2332 rest = k % MP_DIGIT_BIT;
2334 if (!s_pad(z, ndig)) return 0;
2336 dz = MP_DIGITS(z);
2337 ZERO(dz, ndig);
2338 *(dz + ndig - 1) = (1u << rest);
2339 z->used = ndig;
2341 return 1;
2344 static int s_norm(mp_int a, mp_int b) {
2345 mp_digit d = b->digits[MP_USED(b) - 1];
2346 int k = 0;
2348 while (d < (1u << (mp_digit)(MP_DIGIT_BIT - 1))) { /* d < (MP_RADIX / 2) */
2349 d <<= 1;
2350 ++k;
2353 /* These multiplications can't fail */
2354 if (k != 0) {
2355 (void)s_qmul(a, (mp_size)k);
2356 (void)s_qmul(b, (mp_size)k);
2359 return k;
2362 static mp_result s_brmu(mp_int z, mp_int m) {
2363 mp_size um = MP_USED(m) * 2;
2365 if (!s_pad(z, um)) return MP_MEMORY;
2367 s_2expt(z, MP_DIGIT_BIT * um);
2368 return mp_int_div(z, m, z, NULL);
2371 static int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2) {
2372 mp_size um = MP_USED(m), umb_p1, umb_m1;
2374 umb_p1 = (um + 1) * MP_DIGIT_BIT;
2375 umb_m1 = (um - 1) * MP_DIGIT_BIT;
2377 if (mp_int_copy(x, q1) != MP_OK) return 0;
2379 /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
2380 s_qdiv(q1, umb_m1);
2381 UMUL(q1, mu, q2);
2382 s_qdiv(q2, umb_p1);
2384 /* Set x = x mod b^(k+1) */
2385 s_qmod(x, umb_p1);
2387 /* Now, q is a guess for the quotient a / m.
2388 Compute x - q * m mod b^(k+1), replacing x. This may be off
2389 by a factor of 2m, but no more than that.
2391 UMUL(q2, m, q1);
2392 s_qmod(q1, umb_p1);
2393 (void)mp_int_sub(x, q1, x); /* can't fail */
2395 /* The result may be < 0; if it is, add b^(k+1) to pin it in the proper
2396 range. */
2397 if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1)) return 0;
2399 /* If x > m, we need to back it off until it is in range. This will be
2400 required at most twice. */
2401 if (mp_int_compare(x, m) >= 0) {
2402 (void)mp_int_sub(x, m, x);
2403 if (mp_int_compare(x, m) >= 0) {
2404 (void)mp_int_sub(x, m, x);
2408 /* At this point, x has been properly reduced. */
2409 return 1;
2412 /* Perform modular exponentiation using Barrett's method, where mu is the
2413 reduction constant for m. Assumes a < m, b > 0. */
2414 static mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c) {
2415 mp_digit umu = MP_USED(mu);
2416 mp_digit *db = MP_DIGITS(b);
2417 mp_digit *dbt = db + MP_USED(b) - 1;
2419 DECLARE_TEMP(3);
2420 REQUIRE(GROW(TEMP(0), 4 * umu));
2421 REQUIRE(GROW(TEMP(1), 4 * umu));
2422 REQUIRE(GROW(TEMP(2), 4 * umu));
2423 ZERO(TEMP(0)->digits, TEMP(0)->alloc);
2424 ZERO(TEMP(1)->digits, TEMP(1)->alloc);
2425 ZERO(TEMP(2)->digits, TEMP(2)->alloc);
2427 (void)mp_int_set_value(c, 1);
2429 /* Take care of low-order digits */
2430 while (db < dbt) {
2431 mp_digit d = *db;
2433 for (int i = MP_DIGIT_BIT; i > 0; --i, d >>= 1) {
2434 if (d & 1) {
2435 /* The use of a second temporary avoids allocation */
2436 UMUL(c, a, TEMP(0));
2437 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2438 REQUIRE(MP_MEMORY);
2440 mp_int_copy(TEMP(0), c);
2443 USQR(a, TEMP(0));
2444 assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2445 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2446 REQUIRE(MP_MEMORY);
2448 assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2449 mp_int_copy(TEMP(0), a);
2452 ++db;
2455 /* Take care of highest-order digit */
2456 mp_digit d = *dbt;
2457 for (;;) {
2458 if (d & 1) {
2459 UMUL(c, a, TEMP(0));
2460 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2461 REQUIRE(MP_MEMORY);
2463 mp_int_copy(TEMP(0), c);
2466 d >>= 1;
2467 if (!d) break;
2469 USQR(a, TEMP(0));
2470 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2471 REQUIRE(MP_MEMORY);
2473 (void)mp_int_copy(TEMP(0), a);
2476 CLEANUP_TEMP();
2477 return MP_OK;
2480 /* Division of nonnegative integers
2482 This function implements division algorithm for unsigned multi-precision
2483 integers. The algorithm is based on Algorithm D from Knuth's "The Art of
2484 Computer Programming", 3rd ed. 1998, pg 272-273.
2486 We diverge from Knuth's algorithm in that we do not perform the subtraction
2487 from the remainder until we have determined that we have the correct
2488 quotient digit. This makes our algorithm less efficient that Knuth because
2489 we might have to perform multiple multiplication and comparison steps before
2490 the subtraction. The advantage is that it is easy to implement and ensure
2491 correctness without worrying about underflow from the subtraction.
2493 inputs: u a n+m digit integer in base b (b is 2^MP_DIGIT_BIT)
2494 v a n digit integer in base b (b is 2^MP_DIGIT_BIT)
2495 n >= 1
2496 m >= 0
2497 outputs: u / v stored in u
2498 u % v stored in v
2500 static mp_result s_udiv_knuth(mp_int u, mp_int v) {
2501 /* Force signs to positive */
2502 u->sign = MP_ZPOS;
2503 v->sign = MP_ZPOS;
2505 /* Use simple division algorithm when v is only one digit long */
2506 if (MP_USED(v) == 1) {
2507 mp_digit d, rem;
2508 d = v->digits[0];
2509 rem = s_ddiv(u, d);
2510 mp_int_set_value(v, rem);
2511 return MP_OK;
2514 /* Algorithm D
2516 The n and m variables are defined as used by Knuth.
2517 u is an n digit number with digits u_{n-1}..u_0.
2518 v is an n+m digit number with digits from v_{m+n-1}..v_0.
2519 We require that n > 1 and m >= 0
2521 mp_size n = MP_USED(v);
2522 mp_size m = MP_USED(u) - n;
2523 assert(n > 1);
2524 /* assert(m >= 0) follows because m is unsigned. */
2526 /* D1: Normalize.
2527 The normalization step provides the necessary condition for Theorem B,
2528 which states that the quotient estimate for q_j, call it qhat
2530 qhat = u_{j+n}u_{j+n-1} / v_{n-1}
2532 is bounded by
2534 qhat - 2 <= q_j <= qhat.
2536 That is, qhat is always greater than the actual quotient digit q,
2537 and it is never more than two larger than the actual quotient digit.
2539 int k = s_norm(u, v);
2541 /* Extend size of u by one if needed.
2543 The algorithm begins with a value of u that has one more digit of input.
2544 The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0. If the
2545 multiplication did not increase the number of digits of u, we need to add
2546 a leading zero here.
2548 if (k == 0 || MP_USED(u) != m + n + 1) {
2549 if (!s_pad(u, m + n + 1)) return MP_MEMORY;
2550 u->digits[m + n] = 0;
2551 u->used = m + n + 1;
2554 /* Add a leading 0 to v.
2556 The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0. We need to
2557 add the leading zero to v here to ensure that the multiplication will
2558 produce the full n+1 digit result.
2560 if (!s_pad(v, n + 1)) return MP_MEMORY;
2561 v->digits[n] = 0;
2563 /* Initialize temporary variables q and t.
2564 q allocates space for m+1 digits to store the quotient digits
2565 t allocates space for n+1 digits to hold the result of q_j*v
2567 DECLARE_TEMP(2);
2568 REQUIRE(GROW(TEMP(0), m + 1));
2569 REQUIRE(GROW(TEMP(1), n + 1));
2571 /* D2: Initialize j */
2572 int j = m;
2573 mpz_t r;
2574 r.digits = MP_DIGITS(u) + j; /* The contents of r are shared with u */
2575 r.used = n + 1;
2576 r.sign = MP_ZPOS;
2577 r.alloc = MP_ALLOC(u);
2578 ZERO(TEMP(1)->digits, TEMP(1)->alloc);
2580 /* Calculate the m+1 digits of the quotient result */
2581 for (; j >= 0; j--) {
2582 /* D3: Calculate q' */
2583 /* r->digits is aligned to position j of the number u */
2584 mp_word pfx, qhat;
2585 pfx = r.digits[n];
2586 pfx <<= MP_DIGIT_BIT / 2;
2587 pfx <<= MP_DIGIT_BIT / 2;
2588 pfx |= r.digits[n - 1]; /* pfx = u_{j+n}{j+n-1} */
2590 qhat = pfx / v->digits[n - 1];
2591 /* Check to see if qhat > b, and decrease qhat if so.
2592 Theorem B guarantess that qhat is at most 2 larger than the
2593 actual value, so it is possible that qhat is greater than
2594 the maximum value that will fit in a digit */
2595 if (qhat > MP_DIGIT_MAX) qhat = MP_DIGIT_MAX;
2597 /* D4,D5,D6: Multiply qhat * v and test for a correct value of q
2599 We proceed a bit different than the way described by Knuth. This way is
2600 simpler but less efficent. Instead of doing the multiply and subtract
2601 then checking for underflow, we first do the multiply of qhat * v and
2602 see if it is larger than the current remainder r. If it is larger, we
2603 decrease qhat by one and try again. We may need to decrease qhat one
2604 more time before we get a value that is smaller than r.
2606 This way is less efficent than Knuth because we do more multiplies, but
2607 we do not need to worry about underflow this way.
2609 /* t = qhat * v */
2610 s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1);
2611 TEMP(1)->used = n + 1;
2612 CLAMP(TEMP(1));
2614 /* Clamp r for the comparison. Comparisons do not like leading zeros. */
2615 CLAMP(&r);
2616 if (s_ucmp(TEMP(1), &r) > 0) { /* would the remainder be negative? */
2617 qhat -= 1; /* try a smaller q */
2618 s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1);
2619 TEMP(1)->used = n + 1;
2620 CLAMP(TEMP(1));
2621 if (s_ucmp(TEMP(1), &r) > 0) { /* would the remainder be negative? */
2622 assert(qhat > 0);
2623 qhat -= 1; /* try a smaller q */
2624 s_dbmul(MP_DIGITS(v), (mp_digit)qhat, TEMP(1)->digits, n + 1);
2625 TEMP(1)->used = n + 1;
2626 CLAMP(TEMP(1));
2628 assert(s_ucmp(TEMP(1), &r) <= 0 && "The mathematics failed us.");
2630 /* Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be n+1
2631 digits long. */
2632 r.used = n + 1;
2634 /* D4: Multiply and subtract
2636 Note: The multiply was completed above so we only need to subtract here.
2638 s_usub(r.digits, TEMP(1)->digits, r.digits, r.used, TEMP(1)->used);
2640 /* D5: Test remainder
2642 Note: Not needed because we always check that qhat is the correct value
2643 before performing the subtract. Value cast to mp_digit to prevent
2644 warning, qhat has been clamped to MP_DIGIT_MAX
2646 TEMP(0)->digits[j] = (mp_digit)qhat;
2648 /* D6: Add back
2649 Note: Not needed because we always check that qhat is the correct value
2650 before performing the subtract.
2653 /* D7: Loop on j */
2654 r.digits--;
2655 ZERO(TEMP(1)->digits, TEMP(1)->alloc);
2658 /* Get rid of leading zeros in q */
2659 TEMP(0)->used = m + 1;
2660 CLAMP(TEMP(0));
2662 /* Denormalize the remainder */
2663 CLAMP(u); /* use u here because the r.digits pointer is off-by-one */
2664 if (k != 0) s_qdiv(u, k);
2666 mp_int_copy(u, v); /* ok: 0 <= r < v */
2667 mp_int_copy(TEMP(0), u); /* ok: q <= u */
2669 CLEANUP_TEMP();
2670 return MP_OK;
2673 static int s_outlen(mp_int z, mp_size r) {
2674 assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX);
2676 mp_result bits = mp_int_count_bits(z);
2677 double raw = (double)bits * s_log2[r];
2679 return (int)(raw + 0.999999);
2682 static mp_size s_inlen(int len, mp_size r) {
2683 double raw = (double)len / s_log2[r];
2684 mp_size bits = (mp_size)(raw + 0.5);
2686 return (mp_size)((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT) + 1;
/* Convert character c to its digit value in radix r.

   '0'..'9' map to 0..9. When r > 10, letters of either case map to values
   from 10 upward ('A'/'a' is 10). Returns -1 if c is not a valid digit in
   radix r.

   In some locales, isalpha() accepts characters outside the range A-Z. The
   letter arithmetic below can then produce a value below 0 or at/above 36.
   Both are rejected by the final range check, so -1 is the only failure
   value a caller can see.
 */
static int s_ch2val(char c, int r) {
  int out;

  if (isdigit((unsigned char)c)) {
    out = c - '0';
  } else if (r > 10 && isalpha((unsigned char)c)) {
    out = toupper((unsigned char)c) - 'A' + 10;
  } else {
    return -1;
  }

  return (out < 0 || out >= r) ? -1 : out;
}
/* Convert a nonnegative digit value v to its character representation.

   Values 0..9 become '0'..'9'. Values of 10 and above become letters
   starting at 'a', or at 'A' when caps is nonzero.
 */
static char s_val2ch(int v, int caps) {
  assert(v >= 0);

  if (v >= 10) {
    const char letter = (char)('a' + (v - 10));
    return caps ? (char)toupper((unsigned char)letter) : letter;
  }

  return (char)('0' + v);
}
/* Replace the len-byte big-endian value in buf with its two's complement
   (bitwise NOT, then add one), in place. Bytes are processed from the last
   (least significant) to the first, propagating the carry. A carry out of
   the first byte is discarded.
 */
static void s_2comp(unsigned char *buf, int len) {
  unsigned int carry = 1; /* the "+1" of the two's complement */

  for (int pos = len; pos-- > 0;) {
    const unsigned int sum = (unsigned char)~buf[pos] + carry;
    buf[pos] = (unsigned char)(sum & UCHAR_MAX);
    carry = sum >> CHAR_BIT;
  }
}
2740 static mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad) {
2741 int pos = 0, limit = *limpos;
2742 mp_size uz = MP_USED(z);
2743 mp_digit *dz = MP_DIGITS(z);
2745 while (uz > 0 && pos < limit) {
2746 mp_digit d = *dz++;
2747 int i;
2749 for (i = sizeof(mp_digit); i > 0 && pos < limit; --i) {
2750 buf[pos++] = (unsigned char)d;
2751 d >>= CHAR_BIT;
2753 /* Don't write leading zeroes */
2754 if (d == 0 && uz == 1) i = 0; /* exit loop without signaling truncation */
2757 /* Detect truncation (loop exited with pos >= limit) */
2758 if (i > 0) break;
2760 --uz;
2763 if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1))) {
2764 if (pos < limit) {
2765 buf[pos++] = 0;
2766 } else {
2767 uz = 1;
2771 /* Digits are in reverse order, fix that */
2772 REV(buf, pos);
2774 /* Return the number of bytes actually written */
2775 *limpos = pos;
2777 return (uz == 0) ? MP_OK : MP_TRUNC;
2780 /* Here there be dragons */