// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
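/*
 * The aesbs_* routines above are implemented in bit-sliced NEON assembly,
 * which operates on up to eight AES blocks in parallel; the walksize of
 * 8 * AES_BLOCK_SIZE used by the skcipher algorithms below reflects this.
 */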
/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
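/*
 * The rk[] buffer below receives the bit-sliced round keys produced by
 * aesbs_convert_key(); it is sized for the largest case (an AES-256 key
 * schedule).
 */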
struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}
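/*
 * Process the bulk of the data eight blocks at a time: while more data
 * follows in the walk, round the block count down to a whole number of
 * walk strides so any remainder is handled in a later iteration.
 */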
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}
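/*
 * CBC encryption is inherently sequential (each block depends on the
 * previous ciphertext block), so it cannot exploit the eight-way
 * parallelism of the bit-sliced code and uses the plain NEON routine.
 */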
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}

	return err;
}
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}
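/*
 * If the request does not end on a block boundary, the core routine is
 * passed a spare buffer ('final') in which it returns the keystream for the
 * last counter block; the partial tail is then XORed in by hand with
 * crypto_xor_cpy().
 */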
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;

	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}
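/*
 * The synchronous ctr(aes) variant may be called from a context where the
 * NEON unit cannot be used; in that case, encryption falls back to the
 * generic AES library one block at a time via crypto_ctr_encrypt_walk().
 */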
static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}
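/*
 * XTS requests whose length is not a multiple of the block size are handled
 * with ciphertext stealing: the bulk of the data is processed by the
 * bit-sliced code where that pays off, while the final full block plus the
 * partial tail are handled by the plain NEON XTS routines.
 */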
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;
	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
			if (first)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 0;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}

		if (walk.nbytes == walk.total && nbytes > 0)
			goto xts_tail;

		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}
	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
xts_tail:
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}
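/*
 * The "__" prefixed algorithms are marked CRYPTO_ALG_INTERNAL and are only
 * reachable through the SIMD wrappers created in aes_init(), which defer to
 * an asynchronous path when the NEON unit is not usable. The plain
 * "ctr(aes)" entry carries its own non-SIMD fallback and can be used
 * synchronously.
 */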
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
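/*
 * Register all algorithms, then wrap each internal ("__" prefixed) one in a
 * crypto SIMD helper whose names drop the prefix, so users see the ordinary
 * ecb/cbc/ctr/xts algorithm names.
 */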
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}
);
565 module_exit(aes_exit
);