/*
 * AArch64-specific checksum implementation using NEON
 *
 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 * See https://llvm.org/LICENSE.txt for license information.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 */
#include "networking.h"
#include "../chksum_common.h"

#ifndef __ARM_NEON
#pragma GCC target("+simd")
#endif

#include <arm_neon.h>
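
/*
 * Consume the 0..7 leading bytes needed to 8-byte align *pptr and return
 * their contribution to the checksum, advancing *pptr to the first aligned
 * 64-bit word and shrinking *nbytes accordingly.  align_ptr(), load64(),
 * ALL_ONES and Assert() are helpers expected from ../chksum_common.h.
 */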
static inline uint64_t
slurp_head64(const void **pptr, uint32_t *nbytes)
{
    Assert(*nbytes >= 8);
    uint64_t sum = 0;
    uint32_t off = (uintptr_t) *pptr % 8;
    if (likely(off != 0))
    {
        /* Get rid of bytes 0..off-1 */
        const unsigned char *ptr64 = align_ptr(*pptr, 8);
        uint64_t mask = ALL_ONES << (CHAR_BIT * off);
        uint64_t val = load64(ptr64) & mask;
        /* Fold 64-bit sum to 33 bits */
        sum = val >> 32;
        sum += (uint32_t) val;
        *pptr = ptr64 + 8;
        *nbytes -= 8 - off;
    }
    return sum;
}
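
/*
 * Add the final 0..7 bytes at ptr into sum.  A full 64-bit word is loaded
 * and the bytes beyond nbytes are masked off; ptr is 8-byte aligned by the
 * caller, so the load cannot cross a page boundary even though it may read
 * past the last valid byte.
 */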
static inline uint64_t
slurp_tail64(uint64_t sum, const void *ptr, uint32_t nbytes)
{
    Assert(nbytes < 8);
    if (likely(nbytes != 0))
    {
        /* Get rid of bytes 7..nbytes */
        uint64_t mask = ALL_ONES >> (CHAR_BIT * (8 - nbytes));
        Assert(__builtin_popcountl(mask) / CHAR_BIT == nbytes);
        uint64_t val = load64(ptr) & mask;
        /* Fold 64-bit sum to 33 bits */
        sum += val >> 32;
        sum += (uint32_t) val;
    }
    return sum;
}
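
/*
 * Compute the 16-bit ones'-complement (Internet) checksum of nbytes bytes
 * at ptr using NEON.  Small buffers fall back to the scalar slurp_small();
 * larger buffers are 8-byte aligned, summed 64 bytes at a time into four
 * vector accumulators, then folded down and byte-swapped as needed for odd
 * starting addresses.
 */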
unsigned short
__chksum_aarch64_simd(const void *ptr, unsigned int nbytes)
{
    bool swap = (uintptr_t) ptr & 1;
    uint64_t sum;

    if (unlikely(nbytes < 50))
    {
        sum = slurp_small(ptr, nbytes);
        swap = false;
        goto fold;
    }

    /* 8-byte align pointer */
    sum = slurp_head64(&ptr, &nbytes);
    Assert(((uintptr_t) ptr & 7) == 0);

    const uint32_t *may_alias ptr32 = ptr;

    uint64x2_t vsum0 = { 0, 0 };
    uint64x2_t vsum1 = { 0, 0 };
    uint64x2_t vsum2 = { 0, 0 };
    uint64x2_t vsum3 = { 0, 0 };
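
    /* Four independent accumulators keep successive vpadalq operations off
       the same dependency chain; each 64-bit lane absorbs pairwise 32-bit
       sums, so it cannot overflow for any nbytes representable here. */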

    /* Sum groups of 64 bytes */
    for (uint32_t i = 0; i < nbytes / 64; i++)
    {
        uint32x4_t vtmp0 = vld1q_u32(ptr32);
        uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
        uint32x4_t vtmp2 = vld1q_u32(ptr32 + 8);
        uint32x4_t vtmp3 = vld1q_u32(ptr32 + 12);
        vsum0 = vpadalq_u32(vsum0, vtmp0);
        vsum1 = vpadalq_u32(vsum1, vtmp1);
        vsum2 = vpadalq_u32(vsum2, vtmp2);
        vsum3 = vpadalq_u32(vsum3, vtmp3);
        ptr32 += 16;
    }
    nbytes %= 64;

    /* Fold vsum2 and vsum3 into vsum0 and vsum1 */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum2));
    vsum1 = vpadalq_u32(vsum1, vreinterpretq_u32_u64(vsum3));

    /* Add any trailing group of 32 bytes */
    if (nbytes & 32)
    {
        uint32x4_t vtmp0 = vld1q_u32(ptr32);
        uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
        vsum0 = vpadalq_u32(vsum0, vtmp0);
        vsum1 = vpadalq_u32(vsum1, vtmp1);
        ptr32 += 8;
        nbytes -= 32;
    }
    Assert(nbytes < 32);

    /* Fold vsum1 into vsum0 */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum1));

    /* Add any trailing group of 16 bytes */
    if (nbytes & 16)
    {
        uint32x4_t vtmp = vld1q_u32(ptr32);
        vsum0 = vpadalq_u32(vsum0, vtmp);
        ptr32 += 4;
        nbytes -= 16;
    }
    Assert(nbytes < 16);

    /* Add any trailing group of 8 bytes */
    if (nbytes & 8)
    {
        uint32x2_t vtmp = vld1_u32(ptr32);
        vsum0 = vaddw_u32(vsum0, vtmp);
        ptr32 += 2;
        nbytes -= 8;
    }
    Assert(nbytes < 8);
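
    /* vaddlvq_u32 horizontally adds the four 32-bit lanes of vsum0 into a
       single 64-bit value (at most 34 bits), which is then folded into the
       scalar sum. */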
    uint64_t val = vaddlvq_u32(vreinterpretq_u32_u64(vsum0));
    /* Fold 64-bit sum to 33 bits */
    sum += val >> 32;
    sum += (uint32_t) val;

    /* Handle any trailing 0..7 bytes */
    sum = slurp_tail64(sum, ptr32, nbytes);
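
    /* fold_and_swap() (expected from ../chksum_common.h) folds the wide sum
       down to 16 bits with end-around carries and byte-swaps the result when
       the buffer started on an odd address. */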
fold:
    return fold_and_swap(sum, swap);
}