/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */
#include <asm/neon.h>		/* kernel_neon_begin()/kernel_neon_end() */
#include <crypto/aes.h>
#include <crypto/cbc.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>
MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void __aes_arm_encrypt(const u32 rk[], int rounds, const u8 in[],
				  u8 out[]);

struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
};
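
/*
 * Note that aesbs_cbc_ctx and aesbs_xts_ctx carry a conventionally expanded
 * key alongside the bit-sliced round keys in key.rk: CBC encryption and XTS
 * tweak generation run one block at a time through the scalar
 * __aes_arm_encrypt() helper, while the bulk (parallelisable) work uses the
 * bit-sliced NEON routines declared above.
 */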
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}
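
/*
 * The bit-sliced NEON code is geared to operate on 8 blocks at a time (hence
 * .walksize = 8 * AES_BLOCK_SIZE in the algorithm definitions below), so the
 * walk loops above and below round the block count down to a multiple of
 * walk.stride / AES_BLOCK_SIZE whenever more input is still pending, leaving
 * the remainder for the final, possibly shorter, pass.
 */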
static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}
static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}
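
/*
 * CBC encryption is inherently sequential (each block depends on the previous
 * ciphertext block), so it cannot benefit from the 8-way bit-sliced code.
 * Encryption therefore goes through the generic crypto_cbc_encrypt_walk()
 * helper using the scalar AES core one block at a time; only CBC decryption
 * uses the NEON implementation.
 */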
static void cbc_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);

	__aes_arm_encrypt(ctx->enc, ctx->key.rounds, src, dst);
}
static int cbc_encrypt(struct skcipher_request *req)
{
	return crypto_cbc_encrypt_walk(req, cbc_encrypt_one);
}
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}
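
/*
 * For CTR, the asm routine takes an extra 'final' buffer: when the request is
 * not a whole number of blocks and this is the last chunk of the walk, the
 * keystream block for the partial tail is returned in 'final', and the tail
 * is completed below by copying the remaining input bytes and XORing them
 * with that keystream block.
 */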
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			if (dst != src)
				memcpy(dst, src, walk.total % AES_BLOCK_SIZE);
			crypto_xor(dst, final, walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}
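
/*
 * XTS takes a double-length key: the second half is expanded into twkey and
 * is only used to encrypt the IV into the initial tweak; the first half is
 * handed to aesbs_setkey() for the bulk data processing.
 */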
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}
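
/*
 * The initial tweak is generated up front by encrypting the IV in place with
 * the tweak key via the scalar helper; the NEON routines receive it through
 * the iv[] argument and carry it forward from one call to the next.
 */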
static int __xts_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	__aes_arm_encrypt(ctx->twkey, ctx->key.rounds, walk.iv, walk.iv);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}
static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_decrypt);
}
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };
static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
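
/*
 * The algorithms above are registered with CRYPTO_ALG_INTERNAL and a "__"
 * name prefix, so they are never handed out to users directly. aes_init()
 * wraps each of them in a simd skcipher (which defers to an asynchronous
 * cryptd path when the NEON unit cannot be used in the current context) and
 * registers the wrapper under the plain name with the "__" prefix stripped.
 */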
static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}
module_init(aes_init);
module_exit(aes_exit);