// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>
#include "aes-cipher.h"
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
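
/*
 * The bit sliced implementation processes up to eight AES blocks in
 * parallel per NEON call (hence walksize = 8 * AES_BLOCK_SIZE in the
 * algorithm definitions below). The rk[] buffer holds the key schedule
 * after conversion into the bit sliced layout expected by the
 * assembler code in aes-neonbs-core.S.
 */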
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};
struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
};
struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
	struct crypto_aes_ctx	tweak_key;
};
struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}
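
/*
 * When more data is queued than the current walk chunk covers, round the
 * block count down to a multiple of the walk stride so that the NEON code
 * keeps operating at its full eight-block width; the remainder is handled
 * in a later iteration of the walk.
 */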
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}
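
/*
 * CBC encryption is inherently serial, as each block is chained into the
 * next one, so the eight-way bit sliced NEON code cannot help here.
 * Encrypt with the scalar routines from aes-cipher.h instead; only CBC
 * decryption, which is parallelizable, goes through the NEON path.
 */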
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		u8 *prev = walk.iv;

		do {
			crypto_xor_cpy(dst, src, prev, AES_BLOCK_SIZE);
			__aes_arm_encrypt(ctx->fallback.key_enc,
					  ctx->key.rounds, dst, dst);
			prev = dst;
			src += AES_BLOCK_SIZE;
			dst += AES_BLOCK_SIZE;
			nbytes -= AES_BLOCK_SIZE;
		} while (nbytes >= AES_BLOCK_SIZE);
		memcpy(walk.iv, prev, AES_BLOCK_SIZE);
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}
static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}
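
/*
 * A final partial block is bounced through a block-sized stack buffer:
 * the input is copied to the tail of buf, the NEON code processes it in
 * place, and only the number of bytes actually requested is copied back
 * out to the destination.
 */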
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		int bytes = walk.nbytes;

		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}
static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);

	__aes_arm_encrypt(ctx->fallback.key_enc, ctx->key.rounds, src, dst);
}
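
/*
 * NEON registers may not be touched in every context. If the SIMD unit
 * is not usable here, walk the request one block at a time using the
 * scalar ctr_encrypt_one() helper above.
 */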
static int ctr_encrypt_sync(struct skcipher_request *req)
{
	if (!crypto_simd_usable())
		return crypto_ctr_encrypt_walk(req, ctr_encrypt_one);

	return ctr_encrypt(req);
}
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&ctx->tweak_key, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}
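
/*
 * For message sizes that are not a multiple of the block size, XTS uses
 * ciphertext stealing: all complete blocks are processed via a
 * subrequest first, and the final partial block is then stitched to the
 * last complete block using the scalar AES routines.
 */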
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	const int rounds = ctx->key.rounds;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	__aes_arm_encrypt(ctx->tweak_key.key_enc, rounds, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		__aes_arm_encrypt(ctx->fallback.key_enc, rounds, buf, buf);
	else
		__aes_arm_decrypt(ctx->fallback.key_dec, rounds, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}
static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}
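
/*
 * The "__" prefixed algorithms are marked CRYPTO_ALG_INTERNAL and are
 * only reachable through the asynchronous SIMD wrappers created in
 * aes_init(). The ctr-aes-neonbs-sync variant is registered as-is at a
 * slightly lower priority and copes with !crypto_simd_usable() contexts
 * by itself.
 */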
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs-sync",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
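
/*
 * Skipping the first two characters of cra_name/cra_driver_name strips
 * the "__" prefix, so each internal algorithm is exposed to users under
 * its conventional name (e.g. "ecb(aes)" backed by "__ecb-aes-neonbs").
 */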
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(aes_algs + i, algname,
						   drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}
);
509 module_exit(aes_exit
);