// SPDX-License-Identifier: GPL-2.0 OR MIT
/*
 * Copyright (C) 2016-2017 INRIA and Microsoft Corporation.
 * Copyright (C) 2018-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
 *
 * This is a machine-generated formally verified implementation of Curve25519
 * ECDH from: <https://github.com/mitls/hacl-star>. Though originally machine
 * generated, it has been tweaked to be suitable for use in the kernel. It is
 * optimized for 64-bit machines that can efficiently work with 128-bit
 * integer types.
 */

#include <asm/unaligned.h>
#include <crypto/curve25519.h>
#include <linux/string.h>

typedef __uint128_t u128;

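/*
 * Constant-time comparison helpers: each returns an all-ones u64 mask when
 * the condition holds (a == b, respectively a >= b) and 0 otherwise, without
 * any secret-dependent branches.
 */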
static __always_inline u64 u64_eq_mask(u64 a, u64 b)
{
	u64 x = a ^ b;
	u64 minus_x = ~x + (u64)1U;
	u64 x_or_minus_x = x | minus_x;
	u64 xnx = x_or_minus_x >> (u32)63U;
	u64 c = xnx - (u64)1U;
	return c;
}

static __always_inline u64 u64_gte_mask(u64 a, u64 b)
{
	u64 x = a;
	u64 y = b;
	u64 x_xor_y = x ^ y;
	u64 x_sub_y = x - y;
	u64 x_sub_y_xor_y = x_sub_y ^ y;
	u64 q = x_xor_y | x_sub_y_xor_y;
	u64 x_xor_q = x ^ q;
	u64 x_xor_q_ = x_xor_q >> (u32)63U;
	u64 c = x_xor_q_ - (u64)1U;
	return c;
}

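/*
 * Field elements are stored as five 51-bit limbs in u64s (radix 2^51), with
 * 128-bit intermediates for products. Since 2^255 is congruent to 19 modulo
 * p = 2^255 - 19, any overflow out of the top limb is folded back into the
 * lowest limb multiplied by 19.
 */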
static __always_inline void modulo_carry_top(u64 *b)
{
	u64 b4 = b[4];
	u64 b0 = b[0];
	u64 b4_ = b4 & 0x7ffffffffffffLLU;
	u64 b0_ = b0 + 19 * (b4 >> 51);
	b[4] = b4_;
	b[0] = b0_;
}

static __always_inline void fproduct_copy_from_wide_(u64 *output, u128 *input)
{
	u32 i;

	for (i = 0; i < 5; ++i) {
		u128 xi = input[i];
		output[i] = ((u64)(xi));
	}
}

static __always_inline void
fproduct_sum_scalar_multiplication_(u128 *output, u64 *input, u64 s)
{
	output[0] += (u128)input[0] * s;
	output[1] += (u128)input[1] * s;
	output[2] += (u128)input[2] * s;
	output[3] += (u128)input[3] * s;
	output[4] += (u128)input[4] * s;
}

static __always_inline void fproduct_carry_wide_(u128 *tmp)
{
	u32 ctr;

	for (ctr = 0; ctr < 4; ++ctr) {
		u128 tctr = tmp[ctr];
		u128 tctrp1 = tmp[ctr + 1];
		u64 r0 = ((u64)(tctr)) & 0x7ffffffffffffLLU;
		u128 c = ((tctr) >> (51));
		tmp[ctr] = ((u128)(r0));
		tmp[ctr + 1] = ((tctrp1) + (c));
	}
}

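/*
 * fmul_shift_reduce() multiplies a field element by 2^51 in place: the limbs
 * are rotated up by one and the limb that wraps around picks up the factor 19
 * from the reduction modulo 2^255 - 19. fmul_mul_shift_reduce_() uses it to
 * accumulate the schoolbook product of two field elements, one scalar limb at
 * a time, into 128-bit wide limbs.
 */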
static __always_inline void fmul_shift_reduce(u64 *output)
{
	u64 tmp = output[4];
	u32 ctr;

	for (ctr = 4; ctr > 0; --ctr) {
		u64 z = output[ctr - 1];
		output[ctr] = z;
	}
	output[0] = 19 * tmp;
}

static __always_inline void fmul_mul_shift_reduce_(u128 *output, u64 *input,
						   u64 *input21)
{
	u32 i;
	u64 input2i;
	{
		u64 input2i = input21[0];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	{
		u64 input2i = input21[1];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	{
		u64 input2i = input21[2];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	{
		u64 input2i = input21[3];
		fproduct_sum_scalar_multiplication_(output, input, input2i);
		fmul_shift_reduce(input);
	}
	i = 4;
	input2i = input21[i];
	fproduct_sum_scalar_multiplication_(output, input, input2i);
}

static __always_inline void fmul_fmul(u64 *output, u64 *input, u64 *input21)
{
	u64 tmp[5] = { input[0], input[1], input[2], input[3], input[4] };
	u128 t[5] = { 0 };
	u128 b4, b0, b4_, b0_;
	u64 i0, i1, i0_, i1_;

	fmul_mul_shift_reduce_(t, tmp, input21);
	fproduct_carry_wide_(t);
	b4 = t[4];
	b0 = t[0];
	b4_ = ((b4) & (((u128)(0x7ffffffffffffLLU))));
	b0_ = ((b0) + (((u128)(19) * (((u64)(((b4) >> (51))))))));
	t[4] = b4_;
	t[0] = b0_;
	fproduct_copy_from_wide_(output, t);
	i0 = output[0];
	i1 = output[1];
	i0_ = i0 & 0x7ffffffffffffLLU;
	i1_ = i1 + (i0 >> 51);
	output[0] = i0_;
	output[1] = i1_;
}

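/*
 * Dedicated squaring: with r = (r0, ..., r4), the cross terms are computed
 * once and doubled (d0, d1, d2, d4), and products that wrap past limb 4 pick
 * up the factor 19 from the reduction modulo 2^255 - 19.
 */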
static __always_inline void fsquare_fsquare__(u128 *tmp, u64 *output)
{
	u64 r0 = output[0];
	u64 r1 = output[1];
	u64 r2 = output[2];
	u64 r3 = output[3];
	u64 r4 = output[4];
	u64 d0 = r0 * 2;
	u64 d1 = r1 * 2;
	u64 d2 = r2 * 2 * 19;
	u64 d419 = r4 * 19;
	u64 d4 = d419 * 2;
	u128 s0 = ((((((u128)(r0) * (r0))) + (((u128)(d4) * (r1))))) +
		   (((u128)(d2) * (r3))));
	u128 s1 = ((((((u128)(d0) * (r1))) + (((u128)(d4) * (r2))))) +
		   (((u128)(r3 * 19) * (r3))));
	u128 s2 = ((((((u128)(d0) * (r2))) + (((u128)(r1) * (r1))))) +
		   (((u128)(d4) * (r3))));
	u128 s3 = ((((((u128)(d0) * (r3))) + (((u128)(d1) * (r2))))) +
		   (((u128)(r4) * (d419))));
	u128 s4 = ((((((u128)(d0) * (r4))) + (((u128)(d1) * (r3))))) +
		   (((u128)(r2) * (r2))));
	tmp[0] = s0;
	tmp[1] = s1;
	tmp[2] = s2;
	tmp[3] = s3;
	tmp[4] = s4;
}

static __always_inline void fsquare_fsquare_(u128 *tmp, u64 *output)
{
	u128 b4, b0, b4_, b0_;
	u64 i0, i1, i0_, i1_;

	fsquare_fsquare__(tmp, output);
	fproduct_carry_wide_(tmp);
	b4 = tmp[4];
	b0 = tmp[0];
	b4_ = ((b4) & (((u128)(0x7ffffffffffffLLU))));
	b0_ = ((b0) + (((u128)(19) * (((u64)(((b4) >> (51))))))));
	tmp[4] = b4_;
	tmp[0] = b0_;
	fproduct_copy_from_wide_(output, tmp);
	i0 = output[0];
	i1 = output[1];
	i0_ = i0 & 0x7ffffffffffffLLU;
	i1_ = i1 + (i0 >> 51);
	output[0] = i0_;
	output[1] = i1_;
}

static __always_inline void fsquare_fsquare_times_(u64 *output, u128 *tmp,
						   u32 count1)
{
	u32 i;

	fsquare_fsquare_(tmp, output);
	for (i = 1; i < count1; ++i)
		fsquare_fsquare_(tmp, output);
}

static __always_inline void fsquare_fsquare_times(u64 *output, u64 *input,
						  u32 count1)
{
	u128 t[5];

	memcpy(output, input, 5 * sizeof(*input));
	fsquare_fsquare_times_(output, t, count1);
}

static __always_inline void fsquare_fsquare_times_inplace(u64 *output,
							   u32 count1)
{
	u128 t[5];

	fsquare_fsquare_times_(output, t, count1);
}

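/*
 * Modular inversion via Fermat's little theorem: out = z^(p - 2) mod p,
 * computed with a fixed square-and-multiply addition chain so the sequence
 * of operations does not depend on the secret value.
 */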
static __always_inline void crecip_crecip(u64 *out, u64 *z)
{
	u64 buf[20] = { 0 };
	u64 *a0 = buf;
	u64 *t00 = buf + 5;
	u64 *b0 = buf + 10;
	u64 *t01;
	u64 *b1;
	u64 *c0;
	u64 *a;
	u64 *t0;
	u64 *b;
	u64 *c;

	fsquare_fsquare_times(a0, z, 1);
	fsquare_fsquare_times(t00, a0, 2);
	fmul_fmul(b0, t00, z);
	fmul_fmul(a0, b0, a0);
	fsquare_fsquare_times(t00, a0, 1);
	fmul_fmul(b0, t00, b0);
	fsquare_fsquare_times(t00, b0, 5);
	t01 = buf + 5;
	b1 = buf + 10;
	c0 = buf + 15;
	fmul_fmul(b1, t01, b1);
	fsquare_fsquare_times(t01, b1, 10);
	fmul_fmul(c0, t01, b1);
	fsquare_fsquare_times(t01, c0, 20);
	fmul_fmul(t01, t01, c0);
	fsquare_fsquare_times_inplace(t01, 10);
	fmul_fmul(b1, t01, b1);
	fsquare_fsquare_times(t01, b1, 50);
	a = buf;
	t0 = buf + 5;
	b = buf + 10;
	c = buf + 15;
	fmul_fmul(c, t0, b);
	fsquare_fsquare_times(t0, c, 100);
	fmul_fmul(t0, t0, c);
	fsquare_fsquare_times_inplace(t0, 50);
	fmul_fmul(t0, t0, b);
	fsquare_fsquare_times_inplace(t0, 5);
	fmul_fmul(out, t0, a);
}

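/*
 * fsum() adds limb-wise; fdifference(a, b) computes a := b - a, first biasing
 * b by eight times the prime (expressed per limb) so no limb can underflow.
 */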
static __always_inline void fsum(u64 *a, u64 *b)
{
	a[0] += b[0];
	a[1] += b[1];
	a[2] += b[2];
	a[3] += b[3];
	a[4] += b[4];
}

static __always_inline void fdifference(u64 *a, u64 *b)
{
	u64 tmp[5] = { 0 };
	u64 b0, b1, b2, b3, b4;
	u32 i;

	memcpy(tmp, b, 5 * sizeof(*b));
	b0 = tmp[0];
	b1 = tmp[1];
	b2 = tmp[2];
	b3 = tmp[3];
	b4 = tmp[4];
	tmp[0] = b0 + 0x3fffffffffff68LLU;
	tmp[1] = b1 + 0x3ffffffffffff8LLU;
	tmp[2] = b2 + 0x3ffffffffffff8LLU;
	tmp[3] = b3 + 0x3ffffffffffff8LLU;
	tmp[4] = b4 + 0x3ffffffffffff8LLU;
	for (i = 0; i < 5; ++i)
		a[i] = tmp[i] - a[i];
}

static __always_inline void fscalar(u64 *output, u64 *b, u64 s)
{
	u128 tmp[5];
	u128 b4, b0, b4_, b0_;
	u32 i;

	for (i = 0; i < 5; ++i) {
		u64 xi = b[i];
		tmp[i] = ((u128)(xi) * (s));
	}
	fproduct_carry_wide_(tmp);
	b4 = tmp[4];
	b0 = tmp[0];
	b4_ = ((b4) & (((u128)(0x7ffffffffffffLLU))));
	b0_ = ((b0) + (((u128)(19) * (((u64)(((b4) >> (51))))))));
	tmp[4] = b4_;
	tmp[0] = b0_;
	fproduct_copy_from_wide_(output, tmp);
}

static __always_inline void fmul(u64 *output, u64 *a, u64 *b)
{
	fmul_fmul(output, a, b);
}

static __always_inline void crecip(u64 *output, u64 *input)
{
	crecip_crecip(output, input);
}

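/*
 * Constant-time conditional swap of two field elements (and, below, of whole
 * X/Z points): swap1 is either 0 or all-ones, so the XOR mask either leaves
 * both operands untouched or exchanges them, with no data-dependent branch.
 */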
static __always_inline void point_swap_conditional_step(u64 *a, u64 *b,
							u64 swap1, u32 ctr)
{
	u32 i = ctr - 1;
	u64 ai = a[i];
	u64 bi = b[i];
	u64 x = swap1 & (ai ^ bi);
	u64 ai1 = ai ^ x;
	u64 bi1 = bi ^ x;
	a[i] = ai1;
	b[i] = bi1;
}

static __always_inline void point_swap_conditional5(u64 *a, u64 *b, u64 swap1)
{
	point_swap_conditional_step(a, b, swap1, 5);
	point_swap_conditional_step(a, b, swap1, 4);
	point_swap_conditional_step(a, b, swap1, 3);
	point_swap_conditional_step(a, b, swap1, 2);
	point_swap_conditional_step(a, b, swap1, 1);
}

static __always_inline void point_swap_conditional(u64 *a, u64 *b, u64 iswap)
{
	u64 swap1 = 0 - iswap;
	point_swap_conditional5(a, b, swap1);
	point_swap_conditional5(a + 5, b + 5, swap1);
}

static __always_inline void point_copy(u64 *output, u64 *input)
{
	memcpy(output, input, 5 * sizeof(*input));
	memcpy(output + 5, input + 5, 5 * sizeof(*input));
}

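/*
 * One Montgomery-ladder iteration in projective X/Z coordinates: given the
 * running point p, the companion point pq and the fixed difference qmqp (the
 * base point's x-coordinate), this produces the doubled point pp and the
 * differential sum ppq, using the standard add-and-double formulas with the
 * curve constant (A - 2) / 4 = 121665.
 */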
static __always_inline void addanddouble_fmonty(u64 *pp, u64 *ppq, u64 *p,
						u64 *pq, u64 *qmqp)
{
	u64 *qx = qmqp;
	u64 *x2 = pp;
	u64 *z2 = pp + 5;
	u64 *x3 = ppq;
	u64 *z3 = ppq + 5;
	u64 *x = p;
	u64 *z = p + 5;
	u64 *xprime = pq;
	u64 *zprime = pq + 5;
	u64 buf[40] = { 0 };
	u64 *origx = buf;
	u64 *origxprime0 = buf + 5;
	u64 *xxprime0 = buf + 25;
	u64 *zzprime0 = buf + 30;
	u64 *origxprime;

	memcpy(origx, x, 5 * sizeof(*x));
	fsum(x, z);
	fdifference(z, origx);
	memcpy(origxprime0, xprime, 5 * sizeof(*xprime));
	fsum(xprime, zprime);
	fdifference(zprime, origxprime0);
	fmul(xxprime0, xprime, z);
	fmul(zzprime0, x, zprime);
	origxprime = buf + 5;
	{
		u64 *xx0 = buf + 15;
		u64 *zz0 = buf + 20;
		u64 *xxprime = buf + 25;
		u64 *zzprime = buf + 30;
		u64 *zzzprime = buf + 35;

		memcpy(origxprime, xxprime, 5 * sizeof(*xxprime));
		fsum(xxprime, zzprime);
		fdifference(zzprime, origxprime);
		fsquare_fsquare_times(x3, xxprime, 1);
		fsquare_fsquare_times(zzzprime, zzprime, 1);
		fmul(z3, zzzprime, qx);
		fsquare_fsquare_times(xx0, x, 1);
		fsquare_fsquare_times(zz0, z, 1);
		{
			u64 *xx = buf + 15;
			u64 *zz = buf + 20;
			u64 *zzz = buf + 25;
			u64 scalar = 121665;

			fmul(x2, xx, zz);
			fdifference(zz, xx);
			fscalar(zzz, zz, scalar);
			fsum(zzz, xx);
			fmul(z2, zzz, zz);
		}
	}
}

static __always_inline void
ladder_smallloop_cmult_small_loop_step(u64 *nq, u64 *nqpq, u64 *nq2, u64 *nqpq2,
				       u64 *q, u8 byt)
{
	u64 bit0 = (u64)(byt >> 7);
	u64 bit;

	point_swap_conditional(nq, nqpq, bit0);
	addanddouble_fmonty(nq2, nqpq2, nq, nqpq, q);
	bit = (u64)(byt >> 7);
	point_swap_conditional(nq2, nqpq2, bit);
}

static __always_inline void
ladder_smallloop_cmult_small_loop_double_step(u64 *nq, u64 *nqpq, u64 *nq2,
					      u64 *nqpq2, u64 *q, u8 byt)
{
	u8 byt1;

	ladder_smallloop_cmult_small_loop_step(nq, nqpq, nq2, nqpq2, q, byt);
	byt1 = byt << 1;
	ladder_smallloop_cmult_small_loop_step(nq2, nqpq2, nq, nqpq, q, byt1);
}

static __always_inline void
ladder_smallloop_cmult_small_loop(u64 *nq, u64 *nqpq, u64 *nq2, u64 *nqpq2,
				  u64 *q, u8 byt, u32 i)
{
	while (i--) {
		ladder_smallloop_cmult_small_loop_double_step(nq, nqpq, nq2,
							      nqpq2, q, byt);
		byt <<= 2;
	}
}

static __always_inline void ladder_bigloop_cmult_big_loop(u8 *n1, u64 *nq,
							  u64 *nqpq, u64 *nq2,
							  u64 *nqpq2, u64 *q,
							  u32 i)
{
	while (i--) {
		u8 byte = n1[i];
		ladder_smallloop_cmult_small_loop(nq, nqpq, nq2, nqpq2, q,
						  byte, 4);
	}
}

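/*
 * Montgomery ladder: the 32 scalar bytes are processed from the most
 * significant byte down, two ladder steps per small-loop iteration, so every
 * scalar bit drives exactly one conditional swap and one add-and-double,
 * independent of the bit's value.
 */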
static void ladder_cmult(u64 *result, u8 *n1, u64 *q)
{
	u64 point_buf[40] = { 0 };
	u64 *nq = point_buf;
	u64 *nqpq = point_buf + 10;
	u64 *nq2 = point_buf + 20;
	u64 *nqpq2 = point_buf + 30;

	point_copy(nqpq, q);
	nq[0] = 1;
	ladder_bigloop_cmult_big_loop(n1, nq, nqpq, nq2, nqpq2, q, 32);
	point_copy(result, nq);
}

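/*
 * format_fexpand() unpacks a 32-byte little-endian value into five 51-bit
 * limbs by reading five overlapping 64-bit words and masking off the bits
 * that belong to neighbouring limbs.
 */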
static __always_inline void format_fexpand(u64 *output, const u8 *input)
{
	const u8 *x00 = input + 6;
	const u8 *x01 = input + 12;
	const u8 *x02 = input + 19;
	const u8 *x0 = input + 24;
	u64 i0, i1, i2, i3, i4, output0, output1, output2, output3, output4;
	i0 = get_unaligned_le64(input);
	i1 = get_unaligned_le64(x00);
	i2 = get_unaligned_le64(x01);
	i3 = get_unaligned_le64(x02);
	i4 = get_unaligned_le64(x0);
	output0 = i0 & 0x7ffffffffffffLLU;
	output1 = i1 >> 3 & 0x7ffffffffffffLLU;
	output2 = i2 >> 6 & 0x7ffffffffffffLLU;
	output3 = i3 >> 1 & 0x7ffffffffffffLLU;
	output4 = i4 >> 12 & 0x7ffffffffffffLLU;
	output[0] = output0;
	output[1] = output1;
	output[2] = output2;
	output[3] = output3;
	output[4] = output4;
}

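/*
 * Serialization back to bytes: the two carry passes normalize the limbs,
 * format_fcontract_trim() conditionally subtracts p to reach the canonical
 * representative in [0, p), and format_fcontract_store() packs the result
 * into four little-endian 64-bit words.
 */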
static __always_inline void format_fcontract_first_carry_pass(u64 *input)
{
	u64 t0 = input[0];
	u64 t1 = input[1];
	u64 t2 = input[2];
	u64 t3 = input[3];
	u64 t4 = input[4];
	u64 t1_ = t1 + (t0 >> 51);
	u64 t0_ = t0 & 0x7ffffffffffffLLU;
	u64 t2_ = t2 + (t1_ >> 51);
	u64 t1__ = t1_ & 0x7ffffffffffffLLU;
	u64 t3_ = t3 + (t2_ >> 51);
	u64 t2__ = t2_ & 0x7ffffffffffffLLU;
	u64 t4_ = t4 + (t3_ >> 51);
	u64 t3__ = t3_ & 0x7ffffffffffffLLU;
	input[0] = t0_;
	input[1] = t1__;
	input[2] = t2__;
	input[3] = t3__;
	input[4] = t4_;
}

static __always_inline void format_fcontract_first_carry_full(u64 *input)
{
	format_fcontract_first_carry_pass(input);
	modulo_carry_top(input);
}

static __always_inline void format_fcontract_second_carry_pass(u64 *input)
{
	u64 t0 = input[0];
	u64 t1 = input[1];
	u64 t2 = input[2];
	u64 t3 = input[3];
	u64 t4 = input[4];
	u64 t1_ = t1 + (t0 >> 51);
	u64 t0_ = t0 & 0x7ffffffffffffLLU;
	u64 t2_ = t2 + (t1_ >> 51);
	u64 t1__ = t1_ & 0x7ffffffffffffLLU;
	u64 t3_ = t3 + (t2_ >> 51);
	u64 t2__ = t2_ & 0x7ffffffffffffLLU;
	u64 t4_ = t4 + (t3_ >> 51);
	u64 t3__ = t3_ & 0x7ffffffffffffLLU;
	input[0] = t0_;
	input[1] = t1__;
	input[2] = t2__;
	input[3] = t3__;
	input[4] = t4_;
}

static __always_inline void format_fcontract_second_carry_full(u64 *input)
{
	u64 i0;
	u64 i1;
	u64 i0_;
	u64 i1_;

	format_fcontract_second_carry_pass(input);
	modulo_carry_top(input);
	i0 = input[0];
	i1 = input[1];
	i0_ = i0 & 0x7ffffffffffffLLU;
	i1_ = i1 + (i0 >> 51);
	input[0] = i0_;
	input[1] = i1_;
}

static __always_inline void format_fcontract_trim(u64 *input)
{
	u64 a0 = input[0];
	u64 a1 = input[1];
	u64 a2 = input[2];
	u64 a3 = input[3];
	u64 a4 = input[4];
	u64 mask0 = u64_gte_mask(a0, 0x7ffffffffffedLLU);
	u64 mask1 = u64_eq_mask(a1, 0x7ffffffffffffLLU);
	u64 mask2 = u64_eq_mask(a2, 0x7ffffffffffffLLU);
	u64 mask3 = u64_eq_mask(a3, 0x7ffffffffffffLLU);
	u64 mask4 = u64_eq_mask(a4, 0x7ffffffffffffLLU);
	u64 mask = (((mask0 & mask1) & mask2) & mask3) & mask4;
	u64 a0_ = a0 - (0x7ffffffffffedLLU & mask);
	u64 a1_ = a1 - (0x7ffffffffffffLLU & mask);
	u64 a2_ = a2 - (0x7ffffffffffffLLU & mask);
	u64 a3_ = a3 - (0x7ffffffffffffLLU & mask);
	u64 a4_ = a4 - (0x7ffffffffffffLLU & mask);
	input[0] = a0_;
	input[1] = a1_;
	input[2] = a2_;
	input[3] = a3_;
	input[4] = a4_;
}

static __always_inline void format_fcontract_store(u8 *output, u64 *input)
{
	u64 t0 = input[0];
	u64 t1 = input[1];
	u64 t2 = input[2];
	u64 t3 = input[3];
	u64 t4 = input[4];
	u64 o0 = t1 << 51 | t0;
	u64 o1 = t2 << 38 | t1 >> 13;
	u64 o2 = t3 << 25 | t2 >> 26;
	u64 o3 = t4 << 12 | t3 >> 39;
	u8 *b0 = output;
	u8 *b1 = output + 8;
	u8 *b2 = output + 16;
	u8 *b3 = output + 24;
	put_unaligned_le64(o0, b0);
	put_unaligned_le64(o1, b1);
	put_unaligned_le64(o2, b2);
	put_unaligned_le64(o3, b3);
}

static __always_inline void format_fcontract(u8 *output, u64 *input)
{
	format_fcontract_first_carry_full(input);
	format_fcontract_second_carry_full(input);
	format_fcontract_trim(input);
	format_fcontract_store(output, input);
}

static __always_inline void format_scalar_of_point(u8 *scalar, u64 *point)
{
	u64 *x = point;
	u64 *z = point + 5;
	u64 buf[10] __aligned(32) = { 0 };
	u64 *zmone = buf;
	u64 *sc = buf + 5;

	crecip(zmone, z);
	fmul(sc, x, zmone);
	format_fcontract(scalar, sc);
}

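/*
 * Public entry point: expand the peer's public value (the x-coordinate of a
 * point), clamp a copy of the secret scalar, run the Montgomery ladder and
 * convert the resulting projective point back to a 32-byte affine
 * x-coordinate, wiping all intermediate buffers afterwards.
 */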
void curve25519_generic(u8 mypublic[CURVE25519_KEY_SIZE],
			const u8 secret[CURVE25519_KEY_SIZE],
			const u8 basepoint[CURVE25519_KEY_SIZE])
{
	u64 buf0[10] __aligned(32) = { 0 };
	u64 *x0 = buf0;
	u64 *z = buf0 + 5;
	u64 *q;

	format_fexpand(x0, basepoint);
	z[0] = 1;
	q = buf0;
	{
		u8 e[32] __aligned(32) = { 0 };
		u8 *scalar;

		memcpy(e, secret, 32);
		curve25519_clamp_secret(e);
		scalar = e;
		{
			u64 buf[15] = { 0 };
			u64 *nq = buf;
			u64 *x = nq;

			x[0] = 1;
			ladder_cmult(nq, scalar, q);
			format_scalar_of_point(mypublic, nq);
			memzero_explicit(buf, sizeof(buf));
		}
		memzero_explicit(e, sizeof(e));
	}
	memzero_explicit(buf0, sizeof(buf0));
}
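
/*
 * Usage sketch (illustrative only): a caller derives a raw shared secret by
 * feeding its 32-byte private key and the peer's 32-byte public key to this
 * routine, e.g.
 *
 *	u8 shared[CURVE25519_KEY_SIZE];
 *
 *	curve25519_generic(shared, my_secret, their_public);
 *
 * In-kernel users normally go through the curve25519() wrapper declared in
 * <crypto/curve25519.h>, which prefers an arch-specific implementation when
 * one is available, rather than calling curve25519_generic() directly.
 */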