/*
 * Armv7-A specific checksum implementation using NEON
 *
 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 * See https://llvm.org/LICENSE.txt for license information.
 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 */

#include "networking.h"
#include "../chksum_common.h"

#ifndef __ARM_NEON
#pragma GCC target("+simd")
#endif

#include <arm_neon.h>
unsigned short
__chksum_arm_simd(const void *ptr, unsigned int nbytes)
{
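    /*
     * Strategy: accumulate the data as 32-bit words into 64-bit NEON lanes,
     * then fold the wide sums down by repeatedly adding the halves of each
     * lane inside a still-wide lane.  No carries are ever dropped, and adding
     * the halves of a 2N-bit value preserves it modulo 2^N - 1, so the final
     * 16-bit ones'-complement sum is preserved throughout (2^16 - 1 divides
     * both 2^32 - 1 and 2^64 - 1).
     */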
    bool swap = (uintptr_t) ptr & 1;
    uint64x1_t vsum = { 0 };

    if (unlikely(nbytes < 40))
    {
        uint64_t sum = slurp_small(ptr, nbytes);
        return fold_and_swap(sum, false);
    }
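    /* Buffers shorter than 40 bytes take the scalar slurp_small() path above;
       the NEON setup below (head/tail masking plus the final folds) is
       presumably not worth it for them. */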
    /* 8-byte align pointer */
    /* Inline slurp_head-like code since we use NEON here */
    uint32_t off = (uintptr_t) ptr & 7;
    if (likely(off != 0))
    {
        const uint64_t *may_alias ptr64 = align_ptr(ptr, 8);
        uint64x1_t vword64 = vld1_u64(ptr64);
        /* Get rid of bytes 0..off-1 */
        uint64x1_t vmask = vdup_n_u64(ALL_ONES);
        int64x1_t vshiftl = vdup_n_s64(CHAR_BIT * off);
        vmask = vshl_u64(vmask, vshiftl);
        vword64 = vand_u64(vword64, vmask);
        uint32x2_t vtmp = vreinterpret_u32_u64(vword64);
        vsum = vpaddl_u32(vtmp);
        /* Update pointer and remaining size */
        ptr = (char *) ptr64 + 8;
        nbytes -= 8 - off;
    }
    Assert(((uintptr_t) ptr & 7) == 0);
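    /*
     * The head handling above assumes a little-endian configuration: the off
     * bytes that precede the buffer sit in the low-order bits of the loaded
     * 64-bit word, so shifting the all-ones mask left by CHAR_BIT * off
     * clears exactly those bytes.  The tail below uses the mirrored trick
     * with a negative (right) shift.
     */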
    /* Sum groups of 64 bytes */
    uint64x2_t vsum0 = { 0, 0 };
    uint64x2_t vsum1 = { 0, 0 };
    uint64x2_t vsum2 = { 0, 0 };
    uint64x2_t vsum3 = { 0, 0 };
    const uint32_t *may_alias ptr32 = ptr;
    for (uint32_t i = 0; i < nbytes / 64; i++)
    {
        uint32x4_t vtmp0 = vld1q_u32(ptr32);
        uint32x4_t vtmp1 = vld1q_u32(ptr32 + 4);
        uint32x4_t vtmp2 = vld1q_u32(ptr32 + 8);
        uint32x4_t vtmp3 = vld1q_u32(ptr32 + 12);
        vsum0 = vpadalq_u32(vsum0, vtmp0);
        vsum1 = vpadalq_u32(vsum1, vtmp1);
        vsum2 = vpadalq_u32(vsum2, vtmp2);
        vsum3 = vpadalq_u32(vsum3, vtmp3);
        ptr32 += 16;
    }
    nbytes %= 64;
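    /*
     * Four independent accumulators keep the vpadalq_u32 operations in the
     * loop above from serializing on a single register.  Each 64-bit lane
     * grows by less than 2^33 per iteration, so it cannot overflow for any
     * length representable in an unsigned int.
     */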
    /* Fold vsum1/vsum2/vsum3 into vsum0 */
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum2));
    vsum1 = vpadalq_u32(vsum1, vreinterpretq_u32_u64(vsum3));
    vsum0 = vpadalq_u32(vsum0, vreinterpretq_u32_u64(vsum1));
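    /*
     * Reinterpreting a 64-bit partial sum as two 32-bit halves and
     * pairwise-add-accumulating them is carry-safe: the additions still
     * happen in 64-bit lanes, and replacing a value by the sum of its halves
     * preserves it modulo 2^32 - 1 (and hence modulo 2^16 - 1).
     */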
    /* Add any trailing 16-byte groups */
    while (likely(nbytes >= 16))
    {
        uint32x4_t vtmp0 = vld1q_u32(ptr32);
        vsum0 = vpadalq_u32(vsum0, vtmp0);
        ptr32 += 4;
        nbytes -= 16;
    }
    /* Fold vsum0 into vsum */
    {
        /* 4xu32 (4x32b) -> 2xu64 (2x33b) */
        vsum0 = vpaddlq_u32(vreinterpretq_u32_u64(vsum0));
        /* 4xu32 (2x(1b+32b)) -> 2xu64 (2x(0b+32b)) */
        vsum0 = vpaddlq_u32(vreinterpretq_u32_u64(vsum0));
        Assert((vgetq_lane_u64(vsum0, 0) >> 32) == 0);
        Assert((vgetq_lane_u64(vsum0, 1) >> 32) == 0);
        uint32x2_t vtmp = vmovn_u64(vsum0);
        /* Add to accumulator */
        vsum = vpadal_u32(vsum, vtmp);
    }
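    /* vmovn_u64 above keeps only the low 32 bits of each lane; the Asserts
       show the high halves are already zero, so the narrowing is lossless. */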
    /* Add any trailing group of 8 bytes */
    if (nbytes & 8)
    {
        uint32x2_t vtmp = vld1_u32(ptr32);
        /* Add to accumulator */
        vsum = vpadal_u32(vsum, vtmp);
        ptr32 += 2;
        nbytes -= 8;
    }
    /* Handle any trailing 1..7 bytes */
    if (likely(nbytes != 0))
    {
        Assert(((uintptr_t) ptr32 & 7) == 0);
        uint64x1_t vword64 = vld1_u64((const uint64_t *) ptr32);
        /* Get rid of bytes 7..nbytes */
        uint64x1_t vmask = vdup_n_u64(ALL_ONES);
        int64x1_t vshiftr = vdup_n_s64(-CHAR_BIT * (8 - nbytes));
        vmask = vshl_u64(vmask, vshiftr); /* Shift right */
        vword64 = vand_u64(vword64, vmask);
        /* Fold 64-bit sum to 33 bits */
        vword64 = vpaddl_u32(vreinterpret_u32_u64(vword64));
        /* Add to accumulator */
        vsum = vpadal_u32(vsum, vreinterpret_u32_u64(vword64));
    }
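    /* The tail masking above mirrors the head handling: vshl_u64 with a
       negative per-lane count shifts right, clearing the bytes that lie past
       the end of the buffer before they are added in. */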
    /* Fold 64-bit vsum to 32 bits */
    vsum = vpaddl_u32(vreinterpret_u32_u64(vsum));
    vsum = vpaddl_u32(vreinterpret_u32_u64(vsum));
    Assert(vget_lane_u32(vreinterpret_u32_u64(vsum), 1) == 0);
    /* Fold 32-bit vsum to 16 bits */
    uint32x2_t vsum32 = vreinterpret_u32_u64(vsum);
    vsum32 = vpaddl_u16(vreinterpret_u16_u32(vsum32));
    vsum32 = vpaddl_u16(vreinterpret_u16_u32(vsum32));
    Assert(vget_lane_u16(vreinterpret_u16_u32(vsum32), 1) == 0);
    Assert(vget_lane_u16(vreinterpret_u16_u32(vsum32), 2) == 0);
    Assert(vget_lane_u16(vreinterpret_u16_u32(vsum32), 3) == 0);
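    /*
     * If the buffer started on an odd address, every 16-bit word was summed
     * with its two bytes exchanged.  Ones'-complement addition is byte-order
     * independent (RFC 1071), so swapping the two bytes of the final sum
     * below is enough to correct for it.
     */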
    /* Convert to 16-bit scalar */
    uint16_t sum = vget_lane_u16(vreinterpret_u16_u32(vsum32), 0);

    if (unlikely(swap)) /* Odd base pointer is unexpected */
    {
        /* Byte-swap the 16-bit result */
        sum = __builtin_bswap16(sum);
    }
    return sum;
}