libtommath: Fix possible integer overflow CVE-2023-36328
[heimdal.git] / lib / hcrypto / libtommath / bn_s_mp_exptmod.c
blob c3bfa95e834460d00f89e557fbdd32e7f7150815
1 #include "tommath_private.h"
2 #ifdef BN_S_MP_EXPTMOD_C
3 /* LibTomMath, multiple-precision integer library -- Tom St Denis */
4 /* SPDX-License-Identifier: Unlicense */
/* Sizing of the precomputed power table used by s_mp_exptmod().
 *
 * TAB_SIZE is the number of mp_int slots in the table; MAX_WINSIZE is an
 * upper bound on the window width, where 0 means "no upper bound".
 * Defining MP_LOW_MEM selects the smaller table and caps the window at 5 bits.
 */
#if defined(MP_LOW_MEM)
#   define TAB_SIZE    32
#   define MAX_WINSIZE 5
#else
#   define TAB_SIZE    256
#   define MAX_WINSIZE 0
#endif
14 mp_err s_mp_exptmod(const mp_int *G, const mp_int *X, const mp_int *P, mp_int *Y, int redmode)
16 mp_int M[TAB_SIZE], res, mu;
17 mp_digit buf;
18 mp_err err;
19 int bitbuf, bitcpy, bitcnt, mode, digidx, x, y, winsize;
20 mp_err(*redux)(mp_int *x, const mp_int *m, const mp_int *mu);
22 /* find window size */
23 x = mp_count_bits(X);
24 if (x <= 7) {
25 winsize = 2;
26 } else if (x <= 36) {
27 winsize = 3;
28 } else if (x <= 140) {
29 winsize = 4;
30 } else if (x <= 450) {
31 winsize = 5;
32 } else if (x <= 1303) {
33 winsize = 6;
34 } else if (x <= 3529) {
35 winsize = 7;
36 } else {
37 winsize = 8;
40 winsize = MAX_WINSIZE ? MP_MIN(MAX_WINSIZE, winsize) : winsize;
42 /* init M array */
43 /* init first cell */
44 if ((err = mp_init(&M[1])) != MP_OKAY) {
45 return err;
48 /* now init the second half of the array */
49 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
50 if ((err = mp_init(&M[x])) != MP_OKAY) {
51 for (y = 1<<(winsize-1); y < x; y++) {
52 mp_clear(&M[y]);
54 mp_clear(&M[1]);
55 return err;
59 /* create mu, used for Barrett reduction */
60 if ((err = mp_init(&mu)) != MP_OKAY) goto LBL_M;
62 if (redmode == 0) {
63 if ((err = mp_reduce_setup(&mu, P)) != MP_OKAY) goto LBL_MU;
64 redux = mp_reduce;
65 } else {
66 if ((err = mp_reduce_2k_setup_l(P, &mu)) != MP_OKAY) goto LBL_MU;
67 redux = mp_reduce_2k_l;
70 /* create M table
72 * The M table contains powers of the base,
73 * e.g. M[x] = G**x mod P
75 * The first half of the table is not
76 * computed though accept for M[0] and M[1]
78 if ((err = mp_mod(G, P, &M[1])) != MP_OKAY) goto LBL_MU;
80 /* compute the value at M[1<<(winsize-1)] by squaring
81 * M[1] (winsize-1) times
83 if ((err = mp_copy(&M[1], &M[(size_t)1 << (winsize - 1)])) != MP_OKAY) goto LBL_MU;
85 for (x = 0; x < (winsize - 1); x++) {
86 /* square it */
87 if ((err = mp_sqr(&M[(size_t)1 << (winsize - 1)],
88 &M[(size_t)1 << (winsize - 1)])) != MP_OKAY) goto LBL_MU;
90 /* reduce modulo P */
91 if ((err = redux(&M[(size_t)1 << (winsize - 1)], P, &mu)) != MP_OKAY) goto LBL_MU;
94 /* create upper table, that is M[x] = M[x-1] * M[1] (mod P)
95 * for x = (2**(winsize - 1) + 1) to (2**winsize - 1)
97 for (x = (1 << (winsize - 1)) + 1; x < (1 << winsize); x++) {
98 if ((err = mp_mul(&M[x - 1], &M[1], &M[x])) != MP_OKAY) goto LBL_MU;
99 if ((err = redux(&M[x], P, &mu)) != MP_OKAY) goto LBL_MU;
102 /* setup result */
103 if ((err = mp_init(&res)) != MP_OKAY) goto LBL_MU;
104 mp_set(&res, 1uL);
106 /* set initial mode and bit cnt */
107 mode = 0;
108 bitcnt = 1;
109 buf = 0;
110 digidx = X->used - 1;
111 bitcpy = 0;
112 bitbuf = 0;
114 for (;;) {
115 /* grab next digit as required */
116 if (--bitcnt == 0) {
117 /* if digidx == -1 we are out of digits */
118 if (digidx == -1) {
119 break;
121 /* read next digit and reset the bitcnt */
122 buf = X->dp[digidx--];
123 bitcnt = (int)MP_DIGIT_BIT;
126 /* grab the next msb from the exponent */
127 y = (buf >> (mp_digit)(MP_DIGIT_BIT - 1)) & 1uL;
128 buf <<= (mp_digit)1;
130 /* if the bit is zero and mode == 0 then we ignore it
131 * These represent the leading zero bits before the first 1 bit
132 * in the exponent. Technically this opt is not required but it
133 * does lower the # of trivial squaring/reductions used
135 if ((mode == 0) && (y == 0)) {
136 continue;
139 /* if the bit is zero and mode == 1 then we square */
140 if ((mode == 1) && (y == 0)) {
141 if ((err = mp_sqr(&res, &res)) != MP_OKAY) goto LBL_RES;
142 if ((err = redux(&res, P, &mu)) != MP_OKAY) goto LBL_RES;
143 continue;
146 /* else we add it to the window */
147 bitbuf |= (y << (winsize - ++bitcpy));
148 mode = 2;
150 if (bitcpy == winsize) {
151 /* ok window is filled so square as required and multiply */
152 /* square first */
153 for (x = 0; x < winsize; x++) {
154 if ((err = mp_sqr(&res, &res)) != MP_OKAY) goto LBL_RES;
155 if ((err = redux(&res, P, &mu)) != MP_OKAY) goto LBL_RES;
158 /* then multiply */
159 if ((err = mp_mul(&res, &M[bitbuf], &res)) != MP_OKAY) goto LBL_RES;
160 if ((err = redux(&res, P, &mu)) != MP_OKAY) goto LBL_RES;
162 /* empty window and reset */
163 bitcpy = 0;
164 bitbuf = 0;
165 mode = 1;
169 /* if bits remain then square/multiply */
170 if ((mode == 2) && (bitcpy > 0)) {
171 /* square then multiply if the bit is set */
172 for (x = 0; x < bitcpy; x++) {
173 if ((err = mp_sqr(&res, &res)) != MP_OKAY) goto LBL_RES;
174 if ((err = redux(&res, P, &mu)) != MP_OKAY) goto LBL_RES;
176 bitbuf <<= 1;
177 if ((bitbuf & (1 << winsize)) != 0) {
178 /* then multiply */
179 if ((err = mp_mul(&res, &M[1], &res)) != MP_OKAY) goto LBL_RES;
180 if ((err = redux(&res, P, &mu)) != MP_OKAY) goto LBL_RES;
185 mp_exch(&res, Y);
186 err = MP_OKAY;
187 LBL_RES:
188 mp_clear(&res);
189 LBL_MU:
190 mp_clear(&mu);
191 LBL_M:
192 mp_clear(&M[1]);
193 for (x = 1<<(winsize-1); x < (1 << winsize); x++) {
194 mp_clear(&M[x]);
196 return err;
198 #endif