#include "blake3_impl.h"

#include <arm_neon.h>

#ifdef __ARM_BIG_ENDIAN
#error "This implementation only supports little-endian ARM."
// It might be that all we need for big-endian support here is to get the loads
// and stores right, but step zero would be finding a way to test it in CI.
#endif

INLINE uint32x4_t loadu_128(const uint8_t src[16]) {
  // vld1q_u32 has alignment requirements. Don't use it.
  uint32x4_t x;
  memcpy(&x, src, 16);
  return x;
}

INLINE void storeu_128(uint32x4_t src, uint8_t dest[16]) {
  // vst1q_u32 has alignment requirements. Don't use it.
  memcpy(dest, &src, 16);
}

INLINE uint32x4_t add_128(uint32x4_t a, uint32x4_t b) {
  return vaddq_u32(a, b);
}

INLINE uint32x4_t xor_128(uint32x4_t a, uint32x4_t b) {
  return veorq_u32(a, b);
}

INLINE uint32x4_t set1_128(uint32_t x) { return vld1q_dup_u32(&x); }

INLINE uint32x4_t set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
  uint32_t array[4] = {a, b, c, d};
  return vld1q_u32(array);
}

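// Rotation helpers for the four BLAKE3 rotation amounts (16, 12, 8, and 7
// bits). NEON has no 32-bit rotate instruction, so each rotation is composed
// from a logical right shift OR'd with the complementary left shift.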
INLINE uint32x4_t rot16_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 16), vshlq_n_u32(x, 32 - 16));
}

INLINE uint32x4_t rot12_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 12), vshlq_n_u32(x, 32 - 12));
}

INLINE uint32x4_t rot8_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 8), vshlq_n_u32(x, 32 - 8));
}

INLINE uint32x4_t rot7_128(uint32x4_t x) {
  return vorrq_u32(vshrq_n_u32(x, 7), vshlq_n_u32(x, 32 - 7));
}

// TODO: compress_neon

/*
 * ----------------------------------------------------------------------------
 * hash4_neon
 * ----------------------------------------------------------------------------
 */

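// round_fn4 performs one round of the compression function on four states in
// parallel: v[i] holds state word i for all four inputs, one input per lane.
// The first half of the round mixes the columns of the 4x4 state and the
// second half mixes the diagonals, using the message words selected by
// MSG_SCHEDULE[r].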
INLINE void round_fn4(uint32x4_t v[16], uint32x4_t m[16], size_t r) {
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[15] = rot16_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot12_128(v[4]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
  v[0] = add_128(v[0], v[4]);
  v[1] = add_128(v[1], v[5]);
  v[2] = add_128(v[2], v[6]);
  v[3] = add_128(v[3], v[7]);
  v[12] = xor_128(v[12], v[0]);
  v[13] = xor_128(v[13], v[1]);
  v[14] = xor_128(v[14], v[2]);
  v[15] = xor_128(v[15], v[3]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[15] = rot8_128(v[15]);
  v[8] = add_128(v[8], v[12]);
  v[9] = add_128(v[9], v[13]);
  v[10] = add_128(v[10], v[14]);
  v[11] = add_128(v[11], v[15]);
  v[4] = xor_128(v[4], v[8]);
  v[5] = xor_128(v[5], v[9]);
  v[6] = xor_128(v[6], v[10]);
  v[7] = xor_128(v[7], v[11]);
  v[4] = rot7_128(v[4]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);

  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot16_128(v[15]);
  v[12] = rot16_128(v[12]);
  v[13] = rot16_128(v[13]);
  v[14] = rot16_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot12_128(v[5]);
  v[6] = rot12_128(v[6]);
  v[7] = rot12_128(v[7]);
  v[4] = rot12_128(v[4]);
  v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
  v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
  v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
  v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
  v[0] = add_128(v[0], v[5]);
  v[1] = add_128(v[1], v[6]);
  v[2] = add_128(v[2], v[7]);
  v[3] = add_128(v[3], v[4]);
  v[15] = xor_128(v[15], v[0]);
  v[12] = xor_128(v[12], v[1]);
  v[13] = xor_128(v[13], v[2]);
  v[14] = xor_128(v[14], v[3]);
  v[15] = rot8_128(v[15]);
  v[12] = rot8_128(v[12]);
  v[13] = rot8_128(v[13]);
  v[14] = rot8_128(v[14]);
  v[10] = add_128(v[10], v[15]);
  v[11] = add_128(v[11], v[12]);
  v[8] = add_128(v[8], v[13]);
  v[9] = add_128(v[9], v[14]);
  v[5] = xor_128(v[5], v[10]);
  v[6] = xor_128(v[6], v[11]);
  v[7] = xor_128(v[7], v[8]);
  v[4] = xor_128(v[4], v[9]);
  v[5] = rot7_128(v[5]);
  v[6] = rot7_128(v[6]);
  v[7] = rot7_128(v[7]);
  v[4] = rot7_128(v[4]);
}

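// Transposes a 4x4 matrix of 32-bit words held in four NEON registers, so that
// vecs[i] switches from holding row i to holding column i.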
INLINE void transpose_vecs_128(uint32x4_t vecs[4]) {
  // Individually transpose the four 2x2 sub-matrices in each corner.
  uint32x4x2_t rows01 = vtrnq_u32(vecs[0], vecs[1]);
  uint32x4x2_t rows23 = vtrnq_u32(vecs[2], vecs[3]);

  // Swap the top-right and bottom-left 2x2s (which just got transposed).
  vecs[0] =
      vcombine_u32(vget_low_u32(rows01.val[0]), vget_low_u32(rows23.val[0]));
  vecs[1] =
      vcombine_u32(vget_low_u32(rows01.val[1]), vget_low_u32(rows23.val[1]));
  vecs[2] =
      vcombine_u32(vget_high_u32(rows01.val[0]), vget_high_u32(rows23.val[0]));
  vecs[3] =
      vcombine_u32(vget_high_u32(rows01.val[1]), vget_high_u32(rows23.val[1]));
}

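// Loads one 64-byte block from each of the four inputs and transposes the
// words so that out[i] holds message word i from all four inputs.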
INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
                                size_t block_offset, uint32x4_t out[16]) {
  out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(uint32x4_t)]);
  out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(uint32x4_t)]);
  out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(uint32x4_t)]);
  out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(uint32x4_t)]);
  out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(uint32x4_t)]);
  out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(uint32x4_t)]);
  out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(uint32x4_t)]);
  out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(uint32x4_t)]);
  out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(uint32x4_t)]);
  out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(uint32x4_t)]);
  out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(uint32x4_t)]);
  out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(uint32x4_t)]);
  out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(uint32x4_t)]);
  out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(uint32x4_t)]);
  out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(uint32x4_t)]);
  out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(uint32x4_t)]);
  transpose_vecs_128(&out[0]);
  transpose_vecs_128(&out[4]);
  transpose_vecs_128(&out[8]);
  transpose_vecs_128(&out[12]);
}

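// Builds the low and high 32-bit halves of the four per-lane 64-bit counters.
// When increment_counter is false, the mask zeroes the per-lane offsets so
// every lane gets the same counter value.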
INLINE void load_counters4(uint64_t counter, bool increment_counter,
                           uint32x4_t *out_low, uint32x4_t *out_high) {
  uint64_t mask = (increment_counter ? ~0 : 0);
  *out_low = set4(
      counter_low(counter + (mask & 0)), counter_low(counter + (mask & 1)),
      counter_low(counter + (mask & 2)), counter_low(counter + (mask & 3)));
  *out_high = set4(
      counter_high(counter + (mask & 0)), counter_high(counter + (mask & 1)),
      counter_high(counter + (mask & 2)), counter_high(counter + (mask & 3)));
}

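// Hashes four equal-length inputs in parallel, one 64-byte block at a time,
// and writes their four 32-byte chaining values to out.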
void blake3_hash4_neon(const uint8_t *const *inputs, size_t blocks,
                       const uint32_t key[8], uint64_t counter,
                       bool increment_counter, uint8_t flags,
                       uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
  uint32x4_t h_vecs[8] = {
      set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
      set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
  };
  uint32x4_t counter_low_vec, counter_high_vec;
  load_counters4(counter, increment_counter, &counter_low_vec,
                 &counter_high_vec);
  uint8_t block_flags = flags | flags_start;

  for (size_t block = 0; block < blocks; block++) {
    if (block + 1 == blocks) {
      block_flags |= flags_end;
    }
    uint32x4_t block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
    uint32x4_t block_flags_vec = set1_128(block_flags);
    uint32x4_t msg_vecs[16];
    transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);

    uint32x4_t v[16] = {
        h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
        h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
        set1_128(IV[0]), set1_128(IV[1]), set1_128(IV[2]), set1_128(IV[3]),
        counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
    };
    round_fn4(v, msg_vecs, 0);
    round_fn4(v, msg_vecs, 1);
    round_fn4(v, msg_vecs, 2);
    round_fn4(v, msg_vecs, 3);
    round_fn4(v, msg_vecs, 4);
    round_fn4(v, msg_vecs, 5);
    round_fn4(v, msg_vecs, 6);
    h_vecs[0] = xor_128(v[0], v[8]);
    h_vecs[1] = xor_128(v[1], v[9]);
    h_vecs[2] = xor_128(v[2], v[10]);
    h_vecs[3] = xor_128(v[3], v[11]);
    h_vecs[4] = xor_128(v[4], v[12]);
    h_vecs[5] = xor_128(v[5], v[13]);
    h_vecs[6] = xor_128(v[6], v[14]);
    h_vecs[7] = xor_128(v[7], v[15]);

    block_flags = flags;
  }

  transpose_vecs_128(&h_vecs[0]);
  transpose_vecs_128(&h_vecs[4]);
  // The first four vecs now contain the first half of each output, and the
  // second four vecs contain the second half of each output.
  storeu_128(h_vecs[0], &out[0 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[4], &out[1 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[1], &out[2 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[5], &out[3 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[2], &out[4 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[6], &out[5 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[3], &out[6 * sizeof(uint32x4_t)]);
  storeu_128(h_vecs[7], &out[7 * sizeof(uint32x4_t)]);
}

/*
 * ----------------------------------------------------------------------------
 * hash_many_neon
 * ----------------------------------------------------------------------------
 */

void blake3_compress_in_place_portable(uint32_t cv[8],
                                        const uint8_t block[BLAKE3_BLOCK_LEN],
                                        uint8_t block_len, uint64_t counter,
                                        uint8_t flags);

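// Hashes all the blocks of a single input, carrying the chaining value
// forward, and writes the final 32-byte chaining value to out.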
INLINE void hash_one_neon(const uint8_t *input, size_t blocks,
                          const uint32_t key[8], uint64_t counter,
                          uint8_t flags, uint8_t flags_start, uint8_t flags_end,
                          uint8_t out[BLAKE3_OUT_LEN]) {
  uint32_t cv[8];
  memcpy(cv, key, BLAKE3_KEY_LEN);
  uint8_t block_flags = flags | flags_start;
  while (blocks > 0) {
    if (blocks == 1) {
      block_flags |= flags_end;
    }
    // TODO: Implement compress_neon. However, note that according to
    // https://github.com/BLAKE2/BLAKE2/commit/7965d3e6e1b4193438b8d3a656787587d2579227,
    // compress_neon might not be any faster than compress_portable.
    blake3_compress_in_place_portable(cv, input, BLAKE3_BLOCK_LEN, counter,
                                      block_flags);
    input = &input[BLAKE3_BLOCK_LEN];
    blocks -= 1;
    block_flags = flags;
  }
  memcpy(out, cv, BLAKE3_OUT_LEN);
}

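// Public entry point for this file: hashes num_inputs inputs, four at a time
// with blake3_hash4_neon, then any remainder one at a time with hash_one_neon.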
void blake3_hash_many_neon(const uint8_t *const *inputs, size_t num_inputs,
                           size_t blocks, const uint32_t key[8],
                           uint64_t counter, bool increment_counter,
                           uint8_t flags, uint8_t flags_start,
                           uint8_t flags_end, uint8_t *out) {
  while (num_inputs >= 4) {
    blake3_hash4_neon(inputs, blocks, key, counter, increment_counter, flags,
                      flags_start, flags_end, out);
    if (increment_counter) {
      counter += 4;
    }
    inputs += 4;
    num_inputs -= 4;
    out = &out[4 * BLAKE3_OUT_LEN];
  }
  while (num_inputs > 0) {
    hash_one_neon(inputs[0], blocks, key, counter, flags, flags_start,
                  flags_end, out);
    if (increment_counter) {
      counter += 1;
    }
    inputs += 1;
    num_inputs -= 1;
    out = &out[BLAKE3_OUT_LEN];
  }
}

#endif // BLAKE3_USE_NEON