// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
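/*
 * The aesbs_* routines above are implemented in the accompanying bit sliced
 * NEON assembly (aes-neonbs-core.S). They may only be called with the NEON
 * unit enabled, i.e. between kernel_neon_begin() and kernel_neon_end(), which
 * is how the glue code below invokes them.
 */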
/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};
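/*
 * Key setup: the key is expanded with the generic AES code and the resulting
 * round keys are converted into the bit sliced layout consumed by the NEON
 * asm. rounds = 6 + key_len / 4 yields 10/12/14 rounds for 128/192/256 bit
 * keys.
 */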
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}
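/*
 * The bit sliced asm operates on runs of blocks within a single walk step.
 * Since walksize is 8 * AES_BLOCK_SIZE, intermediate steps are rounded down
 * to a whole number of 8-block strides; only the final step may carry a
 * shorter run of blocks.
 */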
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}
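/*
 * CBC encryption is inherently sequential, so it cannot use the 8-way bit
 * sliced code; it falls back to the plain NEON implementation, using the
 * unconverted key schedule kept in ctx->enc. CBC decryption can be
 * parallelised across blocks and uses the bit sliced routine.
 */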
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}
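/*
 * CTR encryption: whole blocks are handled by the bit sliced asm. When the
 * request ends in a partial block, 'final' points at a stack buffer that
 * receives the keystream for the trailing block, which is then XORed into
 * the remaining bytes with crypto_xor_cpy().
 */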
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}
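/*
 * The XTS key consists of two halves: the first half is expanded for the
 * ciphertext stealing tail (ctx->cts) and also converted to the bit sliced
 * format via aesbs_setkey(); the second half becomes the tweak key kept in
 * ctx->twkey.
 */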
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}
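/*
 * Synchronous fallback for the non-internal "ctr(aes)" algorithm: when the
 * NEON unit cannot be used in the current context, each block is handled by
 * the generic aes_encrypt() via crypto_ctr_encrypt_walk().
 */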
static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	unsigned long flags;

	/*
	 * Temporarily disable interrupts to avoid races where
	 * cachelines are evicted when the CPU is interrupted
	 * to do something else.
	 */
	local_irq_save(flags);
	aes_encrypt(&ctx->fallback, dst, src);
	local_irq_restore(flags);
}
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}
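/*
 * XTS ciphertext stealing: when the request length is not a multiple of the
 * block size, the last full block plus the partial tail are handed to the
 * plain NEON neon_aes_xts_{en,de}crypt routines in a single call, either
 * directly at the end of the main walk or via the separate tail pass set up
 * below. Runs of more than 6 blocks use the bit sliced code.
 */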
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
			if (first)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 0;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}

		if (walk.nbytes == walk.total && nbytes > 0)
			goto xts_tail;

		kernel_neon_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
xts_tail:
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
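/*
 * The __-prefixed algorithms above are internal (CRYPTO_ALG_INTERNAL) and may
 * only run where the NEON unit is usable; aes_init() wraps each of them in a
 * simd skcipher instance (dropping the "__" prefix) so they can be used from
 * any context.
 */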
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}
);
566 module_exit(aes_exit
);