Merge tag 'trace-printf-v6.13' of git://git.kernel.org/pub/scm/linux/kernel/git/trace...
[drm/drm-misc.git] / arch / arm64 / lib / csum.c
blob2432683e48a61f2186a329db893bced8157a3cf1
1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
4 #include <linux/compiler.h>
5 #include <linux/kasan-checks.h>
6 #include <linux/kernel.h>
8 #include <net/checksum.h>
10 /* Looks dumb, but generates nice-ish code */
11 static u64 accumulate(u64 sum, u64 data)
13 __uint128_t tmp = (__uint128_t)sum + data;
14 return tmp + (tmp >> 64);
18 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
19 * instrumentation and call kasan explicitly.
21 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
23 unsigned int offset, shift, sum;
24 const u64 *ptr;
25 u64 data, sum64 = 0;
27 if (unlikely(len <= 0))
28 return 0;
30 offset = (unsigned long)buff & 7;
32 * This is to all intents and purposes safe, since rounding down cannot
33 * result in a different page or cache line being accessed, and @buff
34 * should absolutely not be pointing to anything read-sensitive. We do,
35 * however, have to be careful not to piss off KASAN, which means using
36 * unchecked reads to accommodate the head and tail, for which we'll
37 * compensate with an explicit check up-front.
39 kasan_check_read(buff, len);
40 ptr = (u64 *)(buff - offset);
41 len = len + offset - 8;
44 * Head: zero out any excess leading bytes. Shifting back by the same
45 * amount should be at least as fast as any other way of handling the
46 * odd/even alignment, and means we can ignore it until the very end.
48 shift = offset * 8;
49 data = *ptr++;
50 #ifdef __LITTLE_ENDIAN
51 data = (data >> shift) << shift;
52 #else
53 data = (data << shift) >> shift;
54 #endif
57 * Body: straightforward aligned loads from here on (the paired loads
58 * underlying the quadword type still only need dword alignment). The
59 * main loop strictly excludes the tail, so the second loop will always
60 * run at least once.
62 while (unlikely(len > 64)) {
63 __uint128_t tmp1, tmp2, tmp3, tmp4;
65 tmp1 = *(__uint128_t *)ptr;
66 tmp2 = *(__uint128_t *)(ptr + 2);
67 tmp3 = *(__uint128_t *)(ptr + 4);
68 tmp4 = *(__uint128_t *)(ptr + 6);
70 len -= 64;
71 ptr += 8;
73 /* This is the "don't dump the carry flag into a GPR" idiom */
74 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
75 tmp2 += (tmp2 >> 64) | (tmp2 << 64);
76 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
77 tmp4 += (tmp4 >> 64) | (tmp4 << 64);
78 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
79 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
80 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
81 tmp3 += (tmp3 >> 64) | (tmp3 << 64);
82 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
83 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
84 tmp1 = ((tmp1 >> 64) << 64) | sum64;
85 tmp1 += (tmp1 >> 64) | (tmp1 << 64);
86 sum64 = tmp1 >> 64;
88 while (len > 8) {
89 __uint128_t tmp;
91 sum64 = accumulate(sum64, data);
92 tmp = *(__uint128_t *)ptr;
94 len -= 16;
95 ptr += 2;
97 #ifdef __LITTLE_ENDIAN
98 data = tmp >> 64;
99 sum64 = accumulate(sum64, tmp);
100 #else
101 data = tmp;
102 sum64 = accumulate(sum64, tmp >> 64);
103 #endif
105 if (len > 0) {
106 sum64 = accumulate(sum64, data);
107 data = *ptr;
108 len -= 8;
111 * Tail: zero any over-read bytes similarly to the head, again
112 * preserving odd/even alignment.
114 shift = len * -8;
115 #ifdef __LITTLE_ENDIAN
116 data = (data << shift) >> shift;
117 #else
118 data = (data >> shift) << shift;
119 #endif
120 sum64 = accumulate(sum64, data);
122 /* Finally, folding */
123 sum64 += (sum64 >> 32) | (sum64 << 32);
124 sum = sum64 >> 32;
125 sum += (sum >> 16) | (sum << 16);
126 if (offset & 1)
127 return (u16)swab32(sum);
129 return sum >> 16;
132 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
133 const struct in6_addr *daddr,
134 __u32 len, __u8 proto, __wsum csum)
136 __uint128_t src, dst;
137 u64 sum = (__force u64)csum;
139 src = *(const __uint128_t *)saddr->s6_addr;
140 dst = *(const __uint128_t *)daddr->s6_addr;
142 sum += (__force u32)htonl(len);
143 #ifdef __LITTLE_ENDIAN
144 sum += (u32)proto << 24;
145 #else
146 sum += proto;
147 #endif
148 src += (src >> 64) | (src << 64);
149 dst += (dst >> 64) | (dst << 64);
151 sum = accumulate(sum, src >> 64);
152 sum = accumulate(sum, dst >> 64);
154 sum += ((sum >> 32) | (sum << 32));
155 return csum_fold((__force __wsum)(sum >> 32));
157 EXPORT_SYMBOL(csum_ipv6_magic);