// SPDX-License-Identifier: GPL-2.0 OR MIT
/*
 * Copyright (C) 2016-2017 INRIA and Microsoft Corporation.
 * Copyright (C) 2018-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 *
 * This is a machine-generated formally verified implementation of Curve25519
 * ECDH from: <https://github.com/mitls/hacl-star>. Though originally machine
 * generated, it has been tweaked to be suitable for use in the kernel. It is
 * optimized for 64-bit machines that can efficiently work with 128-bit
 * integer types.
 */

#include <linux/unaligned.h>
#include <crypto/curve25519.h>
#include <linux/string.h>

typedef __uint128_t u128;
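/*
 * Field elements of GF(2^255 - 19) are held in radix 2^51: five u64 limbs,
 * each normally below 2^51 (mask 0x7ffffffffffff). Because 2^255 = 19 mod p,
 * any carry out of the top limb is folded back into the bottom limb after
 * multiplying it by 19, which is what the carry helpers below do.
 */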
/* Branchless comparison: returns all-ones if a == b, all-zeroes otherwise. */
static __always_inline u64 u64_eq_mask(u64 a, u64 b)
{
	u64 x = a ^ b;
	u64 minus_x = ~x + (u64)1U;
	u64 x_or_minus_x = x | minus_x;
	u64 xnx = x_or_minus_x >> (u32)63U;
	u64 c = xnx - (u64)1U;
	return c;
}
/* Branchless comparison: returns all-ones if a >= b, all-zeroes otherwise. */
static __always_inline u64 u64_gte_mask(u64 a, u64 b)
{
	u64 x = a;
	u64 y = b;
	u64 x_xor_y = x ^ y;
	u64 x_sub_y = x - y;
	u64 x_sub_y_xor_y = x_sub_y ^ y;
	u64 q = x_xor_y | x_sub_y_xor_y;
	u64 x_xor_q = x ^ q;
	u64 x_xor_q_ = x_xor_q >> (u32)63U;
	u64 c = x_xor_q_ - (u64)1U;
	return c;
}
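/*
 * Reduce the carry out of the top limb: keep the low 51 bits of b[4] and add
 * 19 * (b[4] >> 51) into b[0], using 2^255 = 19 (mod 2^255 - 19).
 */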
static __always_inline void modulo_carry_top(u64 *b)
{
	u64 b4 = b[4];
	u64 b0 = b[0];
	u64 b4_ = b4 & 0x7ffffffffffffLLU;
	u64 b0_ = b0 + 19 * (b4 >> 51);
	b[4] = b4_;
	b[0] = b0_;
}
/* Narrow the five 128-bit accumulators back down to 64-bit limbs. */
static __always_inline void fproduct_copy_from_wide_(u64 *output, u128 *input)
{
	output[0] = (u64)input[0];
	output[1] = (u64)input[1];
	output[2] = (u64)input[2];
	output[3] = (u64)input[3];
	output[4] = (u64)input[4];
}
static __always_inline void
fproduct_sum_scalar_multiplication_(u128 *output, u64 *input, u64 s)
{
	output[0] += (u128)input[0] * s;
	output[1] += (u128)input[1] * s;
	output[2] += (u128)input[2] * s;
	output[3] += (u128)input[3] * s;
	output[4] += (u128)input[4] * s;
}
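/*
 * Propagate carries across the wide accumulators: for limbs 0..3, keep the
 * low 51 bits and add the excess into the next limb. The carry out of limb 4
 * is handled separately by the callers (with the 19-fold described above).
 */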
static __always_inline void fproduct_carry_wide_(u128 *tmp)
{
	{
		u32 ctr = 0;
		u128 tctr = tmp[ctr];
		u128 tctrp1 = tmp[ctr + 1];
		u64 r0 = ((u64)(tctr)) & 0x7ffffffffffffLLU;
		u128 c = ((tctr) >> (51));
		tmp[ctr] = ((u128)(r0));
		tmp[ctr + 1] = ((tctrp1) + (c));
	}
	{
		u32 ctr = 1;
		u128 tctr = tmp[ctr];
		u128 tctrp1 = tmp[ctr + 1];
		u64 r0 = ((u64)(tctr)) & 0x7ffffffffffffLLU;
		u128 c = ((tctr) >> (51));
		tmp[ctr] = ((u128)(r0));
		tmp[ctr + 1] = ((tctrp1) + (c));
	}
	{
		u32 ctr = 2;
		u128 tctr = tmp[ctr];
		u128 tctrp1 = tmp[ctr + 1];
		u64 r0 = ((u64)(tctr)) & 0x7ffffffffffffLLU;
		u128 c = ((tctr) >> (51));
		tmp[ctr] = ((u128)(r0));
		tmp[ctr + 1] = ((tctrp1) + (c));
	}
	{
		u32 ctr = 3;
		u128 tctr = tmp[ctr];
		u128 tctrp1 = tmp[ctr + 1];
		u64 r0 = ((u64)(tctr)) & 0x7ffffffffffffLLU;
		u128 c = ((tctr) >> (51));
		tmp[ctr] = ((u128)(r0));
		tmp[ctr + 1] = ((tctrp1) + (c));
	}
}
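/*
 * Multiply a field element by 2^51 in place: rotate the limbs up by one
 * position and multiply the limb that wraps around by 19. Used to shift the
 * multiplicand between the per-limb passes of fmul_mul_shift_reduce_().
 */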
static __always_inline void fmul_shift_reduce(u64 *output)
{
	u64 tmp = output[4];
	output[4] = output[3];
	output[3] = output[2];
	output[2] = output[1];
	output[1] = output[0];
	output[0] = 19 * tmp;
}
static __always_inline void fmul_mul_shift_reduce_(u128 *output, u64 *input,
						   u64 *input21)
{
	u32 i;
	u64 input2i;
	{
		u64 input2i = input21[0];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	{
		u64 input2i = input21[1];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	{
		u64 input2i = input21[2];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	{
		u64 input2i = input21[3];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	i = 4;
	input2i = input21[i];
	fproduct_sum_scalar_multiplication_(output, input, input2i);
}
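/*
 * Full multiplication mod 2^255 - 19: accumulate input * input21[i] into the
 * 128-bit columns while shifting input by one limb (times 2^51) between
 * steps, then carry and fold the top limb back by 19.
 */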
static __always_inline void fmul_fmul(u64 *output, u64 *input, u64 *input21)
{
	u64 tmp[5] = { input[0], input[1], input[2], input[3], input[4] };
	u128 t[5] = { 0 };
	u128 b4, b0, b4_, b0_;
	u64 i0, i1, i0_, i1_;

	fmul_mul_shift_reduce_(t, tmp, input21);
	fproduct_carry_wide_(t);
	b4 = t[4];
	b0 = t[0];
	b4_ = ((b4) & (((u128)(0x7ffffffffffffLLU))));
	b0_ = ((b0) + (((u128)(19) * (((u64)(((b4) >> (51))))))));
	t[4] = b4_;
	t[0] = b0_;
	fproduct_copy_from_wide_(output, t);
	i0 = output[0];
	i1 = output[1];
	i0_ = i0 & 0x7ffffffffffffLLU;
	i1_ = i1 + (i0 >> 51);
	output[0] = i0_;
	output[1] = i1_;
}
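/*
 * Squaring specialisation of fmul_fmul(): with r = output, precompute the
 * doubled and 19-scaled cross terms so each column needs only three wide
 * multiplications instead of five.
 */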
static __always_inline void fsquare_fsquare__(u128 *tmp, u64 *output)
{
	u64 r0 = output[0];
	u64 r1 = output[1];
	u64 r2 = output[2];
	u64 r3 = output[3];
	u64 r4 = output[4];
	u64 d0 = r0 * 2;
	u64 d1 = r1 * 2;
	u64 d2 = r2 * 2 * 19;
	u64 d419 = r4 * 19;
	u64 d4 = d419 * 2;
	u128 s0 = ((((((u128)(r0) * (r0))) + (((u128)(d4) * (r1))))) +
		   (((u128)(d2) * (r3))));
	u128 s1 = ((((((u128)(d0) * (r1))) + (((u128)(d4) * (r2))))) +
		   (((u128)(r3 * 19) * (r3))));
	u128 s2 = ((((((u128)(d0) * (r2))) + (((u128)(r1) * (r1))))) +
		   (((u128)(d4) * (r3))));
	u128 s3 = ((((((u128)(d0) * (r3))) + (((u128)(d1) * (r2))))) +
		   (((u128)(r4) * (d419))));
	u128 s4 = ((((((u128)(d0) * (r4))) + (((u128)(d1) * (r3))))) +
		   (((u128)(r2) * (r2))));
	tmp[0] = s0;
	tmp[1] = s1;
	tmp[2] = s2;
	tmp[3] = s3;
	tmp[4] = s4;
}
static __always_inline void fsquare_fsquare_(u128 *tmp, u64 *output)
{
	u128 b4, b0, b4_, b0_;
	u64 i0, i1, i0_, i1_;

	fsquare_fsquare__(tmp, output);
	fproduct_carry_wide_(tmp);
	b4 = tmp[4];
	b0 = tmp[0];
	b4_ = ((b4) & (((u128)(0x7ffffffffffffLLU))));
	b0_ = ((b0) + (((u128)(19) * (((u64)(((b4) >> (51))))))));
	tmp[4] = b4_;
	tmp[0] = b0_;
	fproduct_copy_from_wide_(output, tmp);
	i0 = output[0];
	i1 = output[1];
	i0_ = i0 & 0x7ffffffffffffLLU;
	i1_ = i1 + (i0 >> 51);
	output[0] = i0_;
	output[1] = i1_;
}
static __always_inline void fsquare_fsquare_times_(u64 *output, u128 *tmp,
						   u32 count1)
{
	u32 i;
	fsquare_fsquare_(tmp, output);
	for (i = 1; i < count1; ++i)
		fsquare_fsquare_(tmp, output);
}
static __always_inline void fsquare_fsquare_times(u64 *output, u64 *input,
						  u32 count1)
{
	u128 t[5];
	memcpy(output, input, 5 * sizeof(*input));
	fsquare_fsquare_times_(output, t, count1);
}
static __always_inline void fsquare_fsquare_times_inplace(u64 *output,
							   u32 count1)
{
	u128 t[5];
	fsquare_fsquare_times_(output, t, count1);
}
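/*
 * Modular inversion via Fermat's little theorem: out = z^(p - 2) with
 * p = 2^255 - 19, computed with a fixed square-and-multiply addition chain,
 * so the operation count is independent of the value being inverted.
 */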
static __always_inline void crecip_crecip(u64 *out, u64 *z)
{
	u64 buf[20] = { 0 };
	u64 *a0 = buf;
	u64 *t00 = buf + 5;
	u64 *b0 = buf + 10;
	u64 *t01, *b1, *c0;
	u64 *a, *t0, *b, *c;

	/* 2 */ fsquare_fsquare_times(a0, z, 1);
	/* 8 */ fsquare_fsquare_times(t00, a0, 2);
	/* 9 */ fmul_fmul(b0, t00, z);
	/* 11 */ fmul_fmul(a0, b0, a0);
	/* 22 */ fsquare_fsquare_times(t00, a0, 1);
	/* 2^5 - 2^0 */ fmul_fmul(b0, t00, b0);
	/* 2^10 - 2^5 */ fsquare_fsquare_times(t00, b0, 5);
	t01 = buf + 5;
	b1 = buf + 10;
	c0 = buf + 15;
	/* 2^10 - 2^0 */ fmul_fmul(b1, t01, b1);
	/* 2^20 - 2^10 */ fsquare_fsquare_times(t01, b1, 10);
	/* 2^20 - 2^0 */ fmul_fmul(c0, t01, b1);
	/* 2^40 - 2^20 */ fsquare_fsquare_times(t01, c0, 20);
	/* 2^40 - 2^0 */ fmul_fmul(t01, t01, c0);
	/* 2^50 - 2^10 */ fsquare_fsquare_times_inplace(t01, 10);
	/* 2^50 - 2^0 */ fmul_fmul(b1, t01, b1);
	/* 2^100 - 2^50 */ fsquare_fsquare_times(t01, b1, 50);
	a = buf;
	t0 = buf + 5;
	b = buf + 10;
	c = buf + 15;
	/* 2^100 - 2^0 */ fmul_fmul(c, t0, b);
	/* 2^200 - 2^100 */ fsquare_fsquare_times(t0, c, 100);
	/* 2^200 - 2^0 */ fmul_fmul(t0, t0, c);
	/* 2^250 - 2^50 */ fsquare_fsquare_times_inplace(t0, 50);
	/* 2^250 - 2^0 */ fmul_fmul(t0, t0, b);
	/* 2^255 - 2^5 */ fsquare_fsquare_times_inplace(t0, 5);
	/* 2^255 - 21 */ fmul_fmul(out, t0, a);
}
/* a := a + b, limb-wise; the limbs have enough headroom that no carry is needed here. */
static __always_inline void fsum(u64 *a, u64 *b)
{
	a[0] += b[0];
	a[1] += b[1];
	a[2] += b[2];
	a[3] += b[3];
	a[4] += b[4];
}
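/*
 * a := b - a. A multiple of p (8*p spread across the five limbs, the
 * constants below) is added to b first so that every per-limb subtraction
 * stays non-negative.
 */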
static __always_inline void fdifference(u64 *a, u64 *b)
{
	u64 tmp[5] = { 0 };
	u64 b0, b1, b2, b3, b4;

	memcpy(tmp, b, 5 * sizeof(*b));
	b0 = tmp[0];
	b1 = tmp[1];
	b2 = tmp[2];
	b3 = tmp[3];
	b4 = tmp[4];
	tmp[0] = b0 + 0x3fffffffffff68LLU;
	tmp[1] = b1 + 0x3ffffffffffff8LLU;
	tmp[2] = b2 + 0x3ffffffffffff8LLU;
	tmp[3] = b3 + 0x3ffffffffffff8LLU;
	tmp[4] = b4 + 0x3ffffffffffff8LLU;
	a[0] = tmp[0] - a[0];
	a[1] = tmp[1] - a[1];
	a[2] = tmp[2] - a[2];
	a[3] = tmp[3] - a[3];
	a[4] = tmp[4] - a[4];
}
/* output := b * s for a small scalar s, with the usual carry and 19-fold. */
static __always_inline void fscalar(u64 *output, u64 *b, u64 s)
{
	u128 tmp[5];
	u128 b4, b0, b4_, b0_;

	tmp[0] = (u128)b[0] * s;
	tmp[1] = (u128)b[1] * s;
	tmp[2] = (u128)b[2] * s;
	tmp[3] = (u128)b[3] * s;
	tmp[4] = (u128)b[4] * s;
	fproduct_carry_wide_(tmp);
	b4 = tmp[4];
	b0 = tmp[0];
	b4_ = ((b4) & (((u128)(0x7ffffffffffffLLU))));
	b0_ = ((b0) + (((u128)(19) * (((u64)(((b4) >> (51))))))));
	tmp[4] = b4_;
	tmp[0] = b0_;
	fproduct_copy_from_wide_(output, tmp);
}
static __always_inline void fmul(u64 *output, u64 *a, u64 *b)
{
	fmul_fmul(output, a, b);
}
static __always_inline void crecip(u64 *output, u64 *input)
{
	crecip_crecip(output, input);
}
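/*
 * Constant-time conditional swap of limb ctr-1 of a and b: swap1 is either
 * all-ones or all-zeroes, so the same loads, XORs and stores happen whether
 * or not the swap takes effect.
 */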
static __always_inline void point_swap_conditional_step(u64 *a, u64 *b,
							 u64 swap1, u32 ctr)
{
	u32 i = ctr - 1;
	u64 ai = a[i];
	u64 bi = b[i];
	u64 x = swap1 & (ai ^ bi);
	u64 ai1 = ai ^ x;
	u64 bi1 = bi ^ x;
	a[i] = ai1;
	b[i] = bi1;
}
static __always_inline void point_swap_conditional5(u64 *a, u64 *b, u64 swap1)
{
	point_swap_conditional_step(a, b, swap1, 5);
	point_swap_conditional_step(a, b, swap1, 4);
	point_swap_conditional_step(a, b, swap1, 3);
	point_swap_conditional_step(a, b, swap1, 2);
	point_swap_conditional_step(a, b, swap1, 1);
}
static __always_inline void point_swap_conditional(u64 *a, u64 *b, u64 iswap)
{
	u64 swap1 = 0 - iswap;
	point_swap_conditional5(a, b, swap1);
	point_swap_conditional5(a + 5, b + 5, swap1);
}
static __always_inline void point_copy(u64 *output, u64 *input)
{
	memcpy(output, input, 5 * sizeof(*input));
	memcpy(output + 5, input + 5, 5 * sizeof(*input));
}
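/*
 * One combined Montgomery "add and double" step on projective x/z
 * coordinates: given P (in p), P' (in pq) and the fixed difference P - P'
 * (in qmqp), compute 2P into pp and P + P' into ppq.
 */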
static __always_inline void addanddouble_fmonty(u64 *pp, u64 *ppq, u64 *p,
						u64 *pq, u64 *qmqp)
{
	u64 *qx = qmqp;
	u64 *x2 = pp;
	u64 *z2 = pp + 5;
	u64 *x3 = ppq;
	u64 *z3 = ppq + 5;
	u64 *x = p;
	u64 *z = p + 5;
	u64 *xprime = pq;
	u64 *zprime = pq + 5;
	u64 buf[40] = { 0 };
	u64 *origx = buf;
	u64 *origxprime0 = buf + 5;
	u64 *xxprime0 = buf + 25;
	u64 *zzprime0 = buf + 30;
	u64 *origxprime;

	memcpy(origx, x, 5 * sizeof(*x));
	fsum(x, z);
	fdifference(z, origx);
	memcpy(origxprime0, xprime, 5 * sizeof(*xprime));
	fsum(xprime, zprime);
	fdifference(zprime, origxprime0);
	fmul(xxprime0, xprime, z);
	fmul(zzprime0, x, zprime);
	origxprime = buf + 5;
	{
		u64 *xx0 = buf + 15;
		u64 *zz0 = buf + 20;
		u64 *xxprime = buf + 25;
		u64 *zzprime = buf + 30;
		u64 *zzzprime = buf + 35;

		memcpy(origxprime, xxprime, 5 * sizeof(*xxprime));
		fsum(xxprime, zzprime);
		fdifference(zzprime, origxprime);
		fsquare_fsquare_times(x3, xxprime, 1);
		fsquare_fsquare_times(zzzprime, zzprime, 1);
		fmul(z3, zzzprime, qx);
		fsquare_fsquare_times(xx0, x, 1);
		fsquare_fsquare_times(zz0, z, 1);
		{
			u64 *xx = buf + 15;
			u64 *zz = buf + 20;
			u64 *zzz = buf + 25;
			u64 scalar = 121665;

			fmul(x2, xx, zz);
			fdifference(zz, xx);
			fscalar(zzz, zz, scalar);
			fsum(zzz, xx);
			fmul(z2, zzz, zz);
		}
	}
}
static __always_inline void
ladder_smallloop_cmult_small_loop_step(u64 *nq, u64 *nqpq, u64 *nq2, u64 *nqpq2,
				       u64 *q, u8 byt)
{
	u64 bit0 = (u64)(byt >> 7);
	u64 bit;

	point_swap_conditional(nq, nqpq, bit0);
	addanddouble_fmonty(nq2, nqpq2, nq, nqpq, q);
	bit = (u64)(byt >> 7);
	point_swap_conditional(nq2, nqpq2, bit);
}
static __always_inline void
ladder_smallloop_cmult_small_loop_double_step(u64 *nq, u64 *nqpq, u64 *nq2,
					      u64 *nqpq2, u64 *q, u8 byt)
{
	u8 byt1;

	ladder_smallloop_cmult_small_loop_step(nq, nqpq, nq2, nqpq2, q, byt);
	byt1 = byt << 1;
	ladder_smallloop_cmult_small_loop_step(nq2, nqpq2, nq, nqpq, q, byt1);
}
static __always_inline void
ladder_smallloop_cmult_small_loop(u64 *nq, u64 *nqpq, u64 *nq2, u64 *nqpq2,
				  u64 *q, u8 byt, u32 i)
{
	while (i--) {
		ladder_smallloop_cmult_small_loop_double_step(nq, nqpq, nq2,
							      nqpq2, q, byt);
		byt <<= 2;
	}
}
static __always_inline void ladder_bigloop_cmult_big_loop(u8 *n1, u64 *nq,
							   u64 *nqpq, u64 *nq2,
							   u64 *nqpq2, u64 *q,
							   u32 i)
{
	while (i--) {
		u8 byte = n1[i];
		ladder_smallloop_cmult_small_loop(nq, nqpq, nq2, nqpq2, q,
						  byte, 4);
	}
}
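/*
 * Montgomery ladder: walk the 32 scalar bytes from the most significant end,
 * processing 8 bits per byte (4 double-steps of 2 bits each). The per-bit
 * conditional swaps keep the sequence of field operations independent of the
 * secret scalar.
 */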
static void ladder_cmult(u64 *result, u8 *n1, u64 *q)
{
	u64 point_buf[40] = { 0 };
	u64 *nq = point_buf;
	u64 *nqpq = point_buf + 10;
	u64 *nq2 = point_buf + 20;
	u64 *nqpq2 = point_buf + 30;

	point_copy(nqpq, q);
	nq[0] = 1;
	ladder_bigloop_cmult_big_loop(n1, nq, nqpq, nq2, nqpq2, q, 32);
	point_copy(result, nq);
}
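/*
 * Unpack a 32-byte little-endian field element into five 51-bit limbs using
 * overlapping unaligned 64-bit loads.
 */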
static __always_inline void format_fexpand(u64 *output, const u8 *input)
{
	const u8 *x00 = input + 6;
	const u8 *x01 = input + 12;
	const u8 *x02 = input + 19;
	const u8 *x0 = input + 24;
	u64 i0, i1, i2, i3, i4, output0, output1, output2, output3, output4;

	i0 = get_unaligned_le64(input);
	i1 = get_unaligned_le64(x00);
	i2 = get_unaligned_le64(x01);
	i3 = get_unaligned_le64(x02);
	i4 = get_unaligned_le64(x0);
	output0 = i0 & 0x7ffffffffffffLLU;
	output1 = i1 >> 3 & 0x7ffffffffffffLLU;
	output2 = i2 >> 6 & 0x7ffffffffffffLLU;
	output3 = i3 >> 1 & 0x7ffffffffffffLLU;
	output4 = i4 >> 12 & 0x7ffffffffffffLLU;
	output[0] = output0;
	output[1] = output1;
	output[2] = output2;
	output[3] = output3;
	output[4] = output4;
}
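/*
 * Two carry passes bring every limb back below 2^51 (with the top-limb
 * overflow folded in by modulo_carry_top()) before the final trim.
 */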
static __always_inline void format_fcontract_first_carry_pass(u64 *input)
{
	u64 t0 = input[0];
	u64 t1 = input[1];
	u64 t2 = input[2];
	u64 t3 = input[3];
	u64 t4 = input[4];
	u64 t1_ = t1 + (t0 >> 51);
	u64 t0_ = t0 & 0x7ffffffffffffLLU;
	u64 t2_ = t2 + (t1_ >> 51);
	u64 t1__ = t1_ & 0x7ffffffffffffLLU;
	u64 t3_ = t3 + (t2_ >> 51);
	u64 t2__ = t2_ & 0x7ffffffffffffLLU;
	u64 t4_ = t4 + (t3_ >> 51);
	u64 t3__ = t3_ & 0x7ffffffffffffLLU;

	input[0] = t0_;
	input[1] = t1__;
	input[2] = t2__;
	input[3] = t3__;
	input[4] = t4_;
}
static __always_inline void format_fcontract_first_carry_full(u64 *input)
{
	format_fcontract_first_carry_pass(input);
	modulo_carry_top(input);
}
static __always_inline void format_fcontract_second_carry_pass(u64 *input)
{
	u64 t0 = input[0];
	u64 t1 = input[1];
	u64 t2 = input[2];
	u64 t3 = input[3];
	u64 t4 = input[4];
	u64 t1_ = t1 + (t0 >> 51);
	u64 t0_ = t0 & 0x7ffffffffffffLLU;
	u64 t2_ = t2 + (t1_ >> 51);
	u64 t1__ = t1_ & 0x7ffffffffffffLLU;
	u64 t3_ = t3 + (t2_ >> 51);
	u64 t2__ = t2_ & 0x7ffffffffffffLLU;
	u64 t4_ = t4 + (t3_ >> 51);
	u64 t3__ = t3_ & 0x7ffffffffffffLLU;

	input[0] = t0_;
	input[1] = t1__;
	input[2] = t2__;
	input[3] = t3__;
	input[4] = t4_;
}
static __always_inline void format_fcontract_second_carry_full(u64 *input)
{
	u64 i0, i1, i0_, i1_;

	format_fcontract_second_carry_pass(input);
	modulo_carry_top(input);
	i0 = input[0];
	i1 = input[1];
	i0_ = i0 & 0x7ffffffffffffLLU;
	i1_ = i1 + (i0 >> 51);
	input[0] = i0_;
	input[1] = i1_;
}
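/*
 * Final reduction: if the value is >= p = 2^255 - 19 (top four limbs all
 * ones and bottom limb >= 2^51 - 19), subtract p once, using constant-time
 * masks so no branch depends on the value.
 */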
static __always_inline void format_fcontract_trim(u64 *input)
{
	u64 a0 = input[0];
	u64 a1 = input[1];
	u64 a2 = input[2];
	u64 a3 = input[3];
	u64 a4 = input[4];
	u64 mask0 = u64_gte_mask(a0, 0x7ffffffffffedLLU);
	u64 mask1 = u64_eq_mask(a1, 0x7ffffffffffffLLU);
	u64 mask2 = u64_eq_mask(a2, 0x7ffffffffffffLLU);
	u64 mask3 = u64_eq_mask(a3, 0x7ffffffffffffLLU);
	u64 mask4 = u64_eq_mask(a4, 0x7ffffffffffffLLU);
	u64 mask = (((mask0 & mask1) & mask2) & mask3) & mask4;
	u64 a0_ = a0 - (0x7ffffffffffedLLU & mask);
	u64 a1_ = a1 - (0x7ffffffffffffLLU & mask);
	u64 a2_ = a2 - (0x7ffffffffffffLLU & mask);
	u64 a3_ = a3 - (0x7ffffffffffffLLU & mask);
	u64 a4_ = a4 - (0x7ffffffffffffLLU & mask);

	input[0] = a0_;
	input[1] = a1_;
	input[2] = a2_;
	input[3] = a3_;
	input[4] = a4_;
}
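/*
 * Repack five 51-bit limbs into the canonical 32-byte little-endian form
 * with four unaligned 64-bit stores.
 */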
static __always_inline void format_fcontract_store(u8 *output, u64 *input)
{
	u64 t0 = input[0];
	u64 t1 = input[1];
	u64 t2 = input[2];
	u64 t3 = input[3];
	u64 t4 = input[4];
	u64 o0 = t1 << 51 | t0;
	u64 o1 = t2 << 38 | t1 >> 13;
	u64 o2 = t3 << 25 | t2 >> 26;
	u64 o3 = t4 << 12 | t3 >> 39;
	u8 *b0 = output;
	u8 *b1 = output + 8;
	u8 *b2 = output + 16;
	u8 *b3 = output + 24;

	put_unaligned_le64(o0, b0);
	put_unaligned_le64(o1, b1);
	put_unaligned_le64(o2, b2);
	put_unaligned_le64(o3, b3);
}
static __always_inline void format_fcontract(u8 *output, u64 *input)
{
	format_fcontract_first_carry_full(input);
	format_fcontract_second_carry_full(input);
	format_fcontract_trim(input);
	format_fcontract_store(output, input);
}
/* Convert a projective x/z point to its affine x coordinate: x * z^-1. */
static __always_inline void format_scalar_of_point(u8 *scalar, u64 *point)
{
	u64 *x = point;
	u64 *z = point + 5;
	u64 buf[10] __aligned(32) = { 0 };
	u64 *zmone = buf;
	u64 *sc = buf + 5;

	crecip(zmone, z);
	fmul(sc, x, zmone);
	format_fcontract(scalar, sc);
}
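/*
 * X25519 scalar multiplication entry point: expand the basepoint into limb
 * form, clamp the secret scalar, run the Montgomery ladder, convert the
 * result back to a 32-byte affine x coordinate, and wipe all scratch state.
 */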
void curve25519_generic(u8 mypublic[CURVE25519_KEY_SIZE],
			const u8 secret[CURVE25519_KEY_SIZE],
			const u8 basepoint[CURVE25519_KEY_SIZE])
{
	u64 buf0[10] __aligned(32) = { 0 };
	u64 *x0 = buf0;
	u64 *z = buf0 + 5;
	u64 *q;

	format_fexpand(x0, basepoint);
	z[0] = 1;
	q = buf0;
	{
		u8 e[32] __aligned(32) = { 0 };
		u8 *scalar;

		memcpy(e, secret, 32);
		curve25519_clamp_secret(e);
		scalar = e;
		{
			u64 buf[15] = { 0 };
			u64 *nq = buf;
			u64 *x = nq;

			x[0] = 1;
			ladder_cmult(nq, scalar, q);
			format_scalar_of_point(mypublic, nq);
			memzero_explicit(buf, sizeof(buf));
		}
		memzero_explicit(e, sizeof(e));
	}
	memzero_explicit(buf0, sizeof(buf0));
}