// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

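/*
 * The aesbs_*() routines below are implemented in the accompanying bit
 * sliced NEON assembly and operate on up to eight AES blocks at a time.
 */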
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);

struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
        int     rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctr_ctx {
        struct aesbs_ctx        key;
        u32                     enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
        struct crypto_aes_ctx   cts;
};

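/*
 * Expand the user key and convert the round keys into the bit sliced
 * layout expected by the NEON code.
 */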
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}

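/*
 * Except for the final step, the number of blocks passed to the NEON code
 * is rounded down to a multiple of the walk stride (eight blocks), so the
 * bit sliced code always sees full bundles.
 */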
static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                                unsigned int key_len)
{
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();
        memzero_explicit(&rk, sizeof(rk));

        return 0;
}

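/*
 * CBC encryption is inherently sequential, so the eight-way bit sliced
 * code cannot be used here and the plain NEON implementation is called
 * instead.
 */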
static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                /* fall back to the non-bitsliced NEON implementation */
                kernel_neon_begin();
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                     ctx->enc, ctx->key.rounds, blocks,
                                     walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

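/*
 * CTR: full bundles of eight blocks go through the bit sliced code; any
 * tail (including a final partial block, staged in a stack buffer) is
 * handled by the plain NEON CTR routine.
 */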
static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
                int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                kernel_neon_begin();
                if (blocks >= 8) {
                        aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
                                          blocks, walk.iv);
                        dst += blocks * AES_BLOCK_SIZE;
                        src += blocks * AES_BLOCK_SIZE;
                }
                if (nbytes && walk.nbytes == walk.total) {
                        u8 buf[AES_BLOCK_SIZE];
                        u8 *d = dst;

                        if (unlikely(nbytes < AES_BLOCK_SIZE))
                                src = dst = memcpy(buf + sizeof(buf) - nbytes,
                                                   src, nbytes);

                        neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
                                             nbytes, walk.iv);

                        if (unlikely(nbytes < AES_BLOCK_SIZE))
                                memcpy(d, dst, nbytes);

                        nbytes = 0;
                }
                kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }
        return err;
}

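/*
 * XTS takes a double-length key: the first half drives the data path,
 * the second half is expanded separately to encrypt the tweak.
 */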
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = aes_expandkey(&ctx->cts, in_key, key_len);
        if (err)
                return err;

        err = aes_expandkey(&rk, in_key + key_len, key_len);
        if (err)
                return err;

        memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

        return aesbs_setkey(tfm, in_key, key_len);
}

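/*
 * Common XTS helper: bulk data is handled eight blocks at a time by the
 * bit sliced code, while the initial tweak encryption, sub-bundle tails
 * and the ciphertext stealing tail use the plain NEON helpers.
 */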
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        int nbytes, err;
        int first = 1;
        u8 *out, *in;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        /* ensure that the cts tail is covered by a single step */
        if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
        } else {
                tail = 0;
        }

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
                out = walk.dst.virt.addr;
                in = walk.src.virt.addr;
                nbytes = walk.nbytes;

                kernel_neon_begin();
                if (blocks >= 8) {
                        if (first == 1)
                                neon_aes_ecb_encrypt(walk.iv, walk.iv,
                                                     ctx->twkey,
                                                     ctx->key.rounds, 1);
                        first = 2;

                        fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
                           walk.iv);

                        out += blocks * AES_BLOCK_SIZE;
                        in += blocks * AES_BLOCK_SIZE;
                        nbytes -= blocks * AES_BLOCK_SIZE;
                }
                if (walk.nbytes == walk.total && nbytes > 0) {
                        if (encrypt)
                                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
                                                     ctx->key.rounds, nbytes,
                                                     ctx->twkey, walk.iv, first);
                        else
                                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
                                                     ctx->key.rounds, nbytes,
                                                     ctx->twkey, walk.iv, first);
                        nbytes = first = 0;
                }
                kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* handle ciphertext stealing */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        out = walk.dst.virt.addr;
        in = walk.src.virt.addr;
        nbytes = walk.nbytes;

        kernel_neon_begin();
        if (encrypt)
                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first);
        else
                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, false, aesbs_xts_decrypt);
}

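/*
 * All four modes advertise a walksize of 8 * AES_BLOCK_SIZE so that the
 * skcipher walker hands out full eight-block bundles whenever possible.
 */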
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "ecb(aes)",
        .base.cra_driver_name   = "ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "cbc(aes)",
        .base.cra_driver_name   = "cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_ctr_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_ctr_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "xts(aes)",
        .base.cra_driver_name   = "xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

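/*
 * Module plumbing: the skciphers are registered at init time only when
 * Advanced SIMD (NEON) is available, and unregistered again on exit.
 */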
static void aes_exit(void)
{
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
        if (!cpu_have_named_feature(ASIMD))
                return -ENODEV;

        return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);