1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /* Asymmetric algorithms supported by virtio crypto device
4 * Authors: zhenwei pi <pizhenwei@bytedance.com>
5 * lei he <helei.sig11@bytedance.com>
7 * Copyright 2022 Bytedance CO., LTD.
10 #include <crypto/engine.h>
11 #include <crypto/internal/akcipher.h>
12 #include <crypto/internal/rsa.h>
13 #include <crypto/scatterwalk.h>
14 #include <linux/err.h>
15 #include <linux/kernel.h>
16 #include <linux/mpi.h>
17 #include <linux/scatterlist.h>
18 #include <linux/slab.h>
19 #include <linux/string.h>
20 #include <uapi/linux/virtio_crypto.h>
21 #include "virtio_crypto_common.h"
/*
 * Per-transform RSA key state. The code in this file uses the modulus n,
 * an MPI built from the parsed key in virtio_crypto_rsa_set_key(), freed on
 * re-key and in virtio_crypto_rsa_exit_tfm(), and reported by
 * virtio_crypto_rsa_max_size().
 */
23 struct virtio_crypto_rsa_ctx
{
/*
 * Per-tfm context (allocated by the crypto API via .cra_ctxsize). Holds the
 * virtio crypto device serving the transform, the owning tfm, and the RSA
 * key state. The device session handle (session_id / session_valid, set by
 * virtio_crypto_alg_akcipher_init_session() and cleared by
 * virtio_crypto_alg_akcipher_close_session()) is also kept in this context.
 */
27 struct virtio_crypto_akcipher_ctx
{
/* device bound to this tfm; assigned in virtio_crypto_rsa_set_key() */
28 struct virtio_crypto
*vcrypto
;
/* owning transform */
29 struct crypto_akcipher
*tfm
;
/* RSA key state (modulus) */
33 struct virtio_crypto_rsa_ctx rsa_ctx
;
/*
 * Per-request driver context, stored in the akcipher request's private area
 * (its size is declared with akcipher_set_reqsize() in
 * virtio_crypto_rsa_init_tfm() and it is fetched with akcipher_request_ctx()).
 * Besides the generic virtio request it records the tfm context, the
 * originating akcipher request, the virtio opcode, and the src/dst bounce
 * buffers that virtio_crypto_akcipher_finalize_req() frees.
 */
37 struct virtio_crypto_akcipher_request
{
/* generic request; the data-queue callback recovers this struct from it via container_of() */
38 struct virtio_crypto_request base
;
/* tfm context the request was issued on */
39 struct virtio_crypto_akcipher_ctx
*akcipher_ctx
;
/* crypto API request being served */
40 struct akcipher_request
*akcipher_req
;
/*
 * One entry of the algorithm registration table. Each entry holds:
 * - the virtio service/algorithm numbers a device must support (read as
 *   .service / .algonum by the register/unregister loops);
 * - the number of devices currently backing the algorithm;
 * - the crypto-engine akcipher definition, which is registered when the
 *   first device appears and unregistered when the last one goes away.
 */
46 struct virtio_crypto_akcipher_algo
{
/* number of devices currently providing this algorithm; protected by algs_lock */
49 unsigned int active_devs
;
/* crypto engine akcipher definition (ops, names, do_one_request) */
50 struct akcipher_engine_alg algo
;
/*
 * Serializes active_devs updates and crypto engine (un)registration in
 * virtio_crypto_akcipher_algs_register()/_unregister() across devices.
 */
53 static DEFINE_MUTEX(algs_lock
);
/**
 * virtio_crypto_akcipher_finalize_req - release per-request resources and
 * complete the request
 * @vc_akcipher_req: driver request context
 * @req: the akcipher request to complete
 * @err: completion status handed to the crypto engine
 *
 * Frees the src/dst bounce buffers and resets the pointers to NULL so they
 * cannot be freed twice. It then clears the underlying virtio request and
 * finalizes @req on the crypto engine of the data queue it was submitted on.
 */
55 static void virtio_crypto_akcipher_finalize_req(
56 struct virtio_crypto_akcipher_request
*vc_akcipher_req
,
57 struct akcipher_request
*req
, int err
)
59 kfree(vc_akcipher_req
->src_buf
);
60 kfree(vc_akcipher_req
->dst_buf
);
61 vc_akcipher_req
->src_buf
= NULL
;
62 vc_akcipher_req
->dst_buf
= NULL
;
63 virtcrypto_clear_request(&vc_akcipher_req
->base
);
65 crypto_finalize_akcipher_request(vc_akcipher_req
->base
.dataq
->engine
, req
, err
);
/**
 * virtio_crypto_dataq_akcipher_callback - data-queue completion handler
 * @vc_req: the completed virtio request, embedded in a
 *          struct virtio_crypto_akcipher_request
 * @len: length reported by the device; presumably it covers the
 *       device-writable buffers (dst data plus the status field) — confirm
 *
 * Translates the device status into an errno; for example,
 * VIRTIO_CRYPTO_KEY_REJECTED becomes -EKEYREJECTED. For every opcode except
 * VERIFY it sets req->dst_len from @len and copies the dst bounce buffer into
 * req->dst. It then completes the request through
 * virtio_crypto_akcipher_finalize_req().
 */
68 static void virtio_crypto_dataq_akcipher_callback(struct virtio_crypto_request
*vc_req
, int len
)
70 struct virtio_crypto_akcipher_request
*vc_akcipher_req
=
71 container_of(vc_req
, struct virtio_crypto_akcipher_request
, base
);
72 struct akcipher_request
*akcipher_req
;
/* map the device status to a kernel error code */
75 switch (vc_req
->status
) {
76 case VIRTIO_CRYPTO_OK
:
79 case VIRTIO_CRYPTO_INVSESS
:
80 case VIRTIO_CRYPTO_ERR
:
83 case VIRTIO_CRYPTO_BADMSG
:
87 case VIRTIO_CRYPTO_KEY_REJECTED
:
88 error
= -EKEYREJECTED
;
96 akcipher_req
= vc_akcipher_req
->akcipher_req
;
/* VERIFY produces no output data, only a status */
97 if (vc_akcipher_req
->opcode
!= VIRTIO_CRYPTO_AKCIPHER_VERIFY
) {
98 /*
 * The actual output length may be less than the dst buffer.
 * NOTE(review): len is reported by the device and is not checked here
 * against sizeof(vc_req->status) or against the size of dst_buf (the
 * original req->dst_len). A misbehaving device could make dst_len
 * underflow, or make sg_copy_from_buffer() read past dst_buf.
 * Consider clamping — confirm.
 */
99 akcipher_req
->dst_len
= len
- sizeof(vc_req
->status
);
100 sg_copy_from_buffer(akcipher_req
->dst
, sg_nents(akcipher_req
->dst
),
101 vc_akcipher_req
->dst_buf
, akcipher_req
->dst_len
);
103 virtio_crypto_akcipher_finalize_req(vc_akcipher_req
, akcipher_req
, error
);
/**
 * virtio_crypto_alg_akcipher_init_session - create a device akcipher session
 * @ctx: tfm context; receives session_id and session_valid on success
 * @header: control header (opcode/algo), copied into the request
 * @para: session parameters, copied into the request
 * @key: key blob sent to the device
 * @keylen: length of @key in bytes
 *
 * Sends a CREATE_SESSION control request in two parts:
 * - device-readable: the control request, followed by a private kmemdup()
 *   copy of the key;
 * - device-writable: the session input (status and session id).
 *
 * When the device answers VIRTIO_CRYPTO_OK, the returned session id is
 * stored in @ctx and the session is marked valid. The key copy is released
 * with kfree_sensitive() so the key material is wiped.
 *
 * Return: 0 on success, otherwise an error code. NOTE(review): confirm the
 * specific values used for allocation, transport and device failures.
 */
106 static int virtio_crypto_alg_akcipher_init_session(struct virtio_crypto_akcipher_ctx
*ctx
,
107 struct virtio_crypto_ctrl_header
*header
,
108 struct virtio_crypto_akcipher_session_para
*para
,
109 const uint8_t *key
, unsigned int keylen
)
111 struct scatterlist outhdr_sg
, key_sg
, inhdr_sg
, *sgs
[3];
112 struct virtio_crypto
*vcrypto
= ctx
->vcrypto
;
115 unsigned int num_out
= 0, num_in
= 0;
116 struct virtio_crypto_op_ctrl_req
*ctrl
;
117 struct virtio_crypto_session_input
*input
;
118 struct virtio_crypto_ctrl_request
*vc_ctrl_req
;
/* private copy so the key can be mapped with sg_init_one() and wiped afterwards */
120 pkey
= kmemdup(key
, keylen
, GFP_KERNEL
);
124 vc_ctrl_req
= kzalloc(sizeof(*vc_ctrl_req
), GFP_KERNEL
);
130 ctrl
= &vc_ctrl_req
->ctrl
;
131 memcpy(&ctrl
->header
, header
, sizeof(ctrl
->header
));
132 memcpy(&ctrl
->u
.akcipher_create_session
.para
, para
, sizeof(*para
));
133 input
= &vc_ctrl_req
->input
;
/* pre-set to ERR so a reply that never updates status is not taken as success */
134 input
->status
= cpu_to_le32(VIRTIO_CRYPTO_ERR
);
136 sg_init_one(&outhdr_sg
, ctrl
, sizeof(*ctrl
));
137 sgs
[num_out
++] = &outhdr_sg
;
139 sg_init_one(&key_sg
, pkey
, keylen
);
140 sgs
[num_out
++] = &key_sg
;
/* device-writable entries follow the num_out device-readable ones */
142 sg_init_one(&inhdr_sg
, input
, sizeof(*input
));
143 sgs
[num_out
+ num_in
++] = &inhdr_sg
;
145 err
= virtio_crypto_ctrl_vq_request(vcrypto
, sgs
, num_out
, num_in
, vc_ctrl_req
);
149 if (le32_to_cpu(input
->status
) != VIRTIO_CRYPTO_OK
) {
150 pr_err("virtio_crypto: Create session failed status: %u\n",
151 le32_to_cpu(input
->status
));
156 ctx
->session_id
= le64_to_cpu(input
->session_id
);
157 ctx
->session_valid
= true;
/* wipe the key copy before freeing it */
162 kfree_sensitive(pkey
);
/**
 * virtio_crypto_alg_akcipher_close_session - destroy the device-side session
 * @ctx: tfm context whose session (ctx->session_id) is torn down
 *
 * Does nothing when ctx->session_valid is not set. Otherwise sends an
 * AKCIPHER_DESTROY_SESSION control request on queue 0:
 * - device-readable: the control request;
 * - device-writable: the status field, pre-set to VIRTIO_CRYPTO_ERR so an
 *   unanswered request is not mistaken for success.
 *
 * A non-OK status is logged together with the session id. ctx->session_valid
 * is then cleared. NOTE(review): confirm whether the failure path skips
 * clearing it.
 *
 * Return: 0 on success, otherwise an error code (confirm specific values).
 */
167 static int virtio_crypto_alg_akcipher_close_session(struct virtio_crypto_akcipher_ctx
*ctx
)
169 struct scatterlist outhdr_sg
, inhdr_sg
, *sgs
[2];
170 struct virtio_crypto_destroy_session_req
*destroy_session
;
171 struct virtio_crypto
*vcrypto
= ctx
->vcrypto
;
172 unsigned int num_out
= 0, num_in
= 0;
174 struct virtio_crypto_op_ctrl_req
*ctrl
;
175 struct virtio_crypto_inhdr
*ctrl_status
;
176 struct virtio_crypto_ctrl_request
*vc_ctrl_req
;
178 if (!ctx
->session_valid
)
181 vc_ctrl_req
= kzalloc(sizeof(*vc_ctrl_req
), GFP_KERNEL
);
185 ctrl_status
= &vc_ctrl_req
->ctrl_status
;
186 ctrl_status
->status
= VIRTIO_CRYPTO_ERR
;
187 ctrl
= &vc_ctrl_req
->ctrl
;
188 ctrl
->header
.opcode
= cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION
);
189 ctrl
->header
.queue_id
= 0;
191 destroy_session
= &ctrl
->u
.destroy_session
;
192 destroy_session
->session_id
= cpu_to_le64(ctx
->session_id
);
194 sg_init_one(&outhdr_sg
, ctrl
, sizeof(*ctrl
));
195 sgs
[num_out
++] = &outhdr_sg
;
/* device-writable status follows the device-readable request */
197 sg_init_one(&inhdr_sg
, &ctrl_status
->status
, sizeof(ctrl_status
->status
));
198 sgs
[num_out
+ num_in
++] = &inhdr_sg
;
200 err
= virtio_crypto_ctrl_vq_request(vcrypto
, sgs
, num_out
, num_in
, vc_ctrl_req
);
204 if (ctrl_status
->status
!= VIRTIO_CRYPTO_OK
) {
/* the request's session_id is __le64 wire data: convert before printing with %llx */
205 pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
206 ctrl_status
->status
, le64_to_cpu(destroy_session
->session_id)
);
212 ctx
->session_valid
= false;
/**
 * __virtio_crypto_akcipher_do_req - build the scatterlist and submit a request
 * @vc_akcipher_req: driver request context; takes ownership of the bounce buffers
 * @req: the akcipher request being processed
 * @data_vq: data queue to submit on
 *
 * Device-readable entries:
 * - the op request header (vc_req->req_data);
 * - a zeroed bounce buffer, allocated on the device's NUMA node, holding the
 *   source data copied from req->src.
 *
 * For VERIFY, src_len is req->src_len + req->dst_len and all of that data is
 * device-readable. Otherwise a zeroed req->dst_len-byte bounce buffer is added
 * as device-writable for the result. The status field is always the last,
 * device-writable entry.
 *
 * The bounce buffers are stored in @vc_akcipher_req so the completion path
 * can copy results out and free them. Submission happens under data_vq->lock
 * and is followed by a kick.
 */
220 static int __virtio_crypto_akcipher_do_req(struct virtio_crypto_akcipher_request
*vc_akcipher_req
,
221 struct akcipher_request
*req
, struct data_queue
*data_vq
)
223 struct virtio_crypto_akcipher_ctx
*ctx
= vc_akcipher_req
->akcipher_ctx
;
224 struct virtio_crypto_request
*vc_req
= &vc_akcipher_req
->base
;
225 struct virtio_crypto
*vcrypto
= ctx
->vcrypto
;
226 struct virtio_crypto_op_data_req
*req_data
= vc_req
->req_data
;
227 struct scatterlist
*sgs
[4], outhdr_sg
, inhdr_sg
, srcdata_sg
, dstdata_sg
;
228 void *src_buf
, *dst_buf
= NULL
;
229 unsigned int num_out
= 0, num_in
= 0;
/* allocate bounce buffers on the device's NUMA node */
230 int node
= dev_to_node(&vcrypto
->vdev
->dev
);
233 bool verify
= vc_akcipher_req
->opcode
== VIRTIO_CRYPTO_AKCIPHER_VERIFY
;
234 unsigned int src_len
= verify
? req
->src_len
+ req
->dst_len
: req
->src_len
;
237 sg_init_one(&outhdr_sg
, req_data
, sizeof(*req_data
));
238 sgs
[num_out
++] = &outhdr_sg
;
241 src_buf
= kcalloc_node(src_len
, 1, GFP_KERNEL
, node
);
246 /* For VERIFY, both the src and dst data are device-readable (OUT direction). */
247 sg_copy_to_buffer(req
->src
, sg_nents(req
->src
), src_buf
, src_len
);
248 sg_init_one(&srcdata_sg
, src_buf
, src_len
);
249 sgs
[num_out
++] = &srcdata_sg
;
251 sg_copy_to_buffer(req
->src
, sg_nents(req
->src
), src_buf
, src_len
);
252 sg_init_one(&srcdata_sg
, src_buf
, src_len
);
253 sgs
[num_out
++] = &srcdata_sg
;
/* device-writable buffer that receives the operation's output */
256 dst_buf
= kcalloc_node(req
->dst_len
, 1, GFP_KERNEL
, node
);
260 sg_init_one(&dstdata_sg
, dst_buf
, req
->dst_len
);
261 sgs
[num_out
+ num_in
++] = &dstdata_sg
;
/* kept for the completion callback, which copies out and frees them */
264 vc_akcipher_req
->src_buf
= src_buf
;
265 vc_akcipher_req
->dst_buf
= dst_buf
;
268 sg_init_one(&inhdr_sg
, &vc_req
->status
, sizeof(vc_req
->status
));
269 sgs
[num_out
+ num_in
++] = &inhdr_sg
;
/* GFP_ATOMIC: virtqueue_add_sgs() runs with the queue spinlock held and IRQs off */
271 spin_lock_irqsave(&data_vq
->lock
, flags
);
272 ret
= virtqueue_add_sgs(data_vq
->vq
, sgs
, num_out
, num_in
, vc_req
, GFP_ATOMIC
);
273 virtqueue_kick(data_vq
->vq
);
274 spin_unlock_irqrestore(&data_vq
->lock
, flags
);
/**
 * virtio_crypto_rsa_do_req - crypto engine .do_one_request handler for RSA
 * @engine: the crypto engine running this request
 * @vreq: the crypto_async_request embedded (as .base) in the akcipher request
 *
 * Allocates the op request on the device's NUMA node and fills two parts:
 * - the op header: opcode, RSA algorithm, session id;
 * - the akcipher parameters: src/dst data lengths.
 *
 * It then submits the request through __virtio_crypto_akcipher_do_req().
 *
 * Return: int status of the submission.
 */
287 static int virtio_crypto_rsa_do_req(struct crypto_engine
*engine
, void *vreq
)
289 struct akcipher_request
*req
= container_of(vreq
, struct akcipher_request
, base
);
290 struct virtio_crypto_akcipher_request
*vc_akcipher_req
= akcipher_request_ctx(req
);
291 struct virtio_crypto_request
*vc_req
= &vc_akcipher_req
->base
;
292 struct virtio_crypto_akcipher_ctx
*ctx
= vc_akcipher_req
->akcipher_ctx
;
293 struct virtio_crypto
*vcrypto
= ctx
->vcrypto
;
294 struct data_queue
*data_vq
= vc_req
->dataq
;
295 struct virtio_crypto_op_header
*header
;
296 struct virtio_crypto_akcipher_data_req
*akcipher_req
;
300 vc_req
->req_data
= kzalloc_node(sizeof(*vc_req
->req_data
),
301 GFP_KERNEL
, dev_to_node(&vcrypto
->vdev
->dev
));
302 if (!vc_req
->req_data
)
305 /* build request header */
306 header
= &vc_req
->req_data
->header
;
307 header
->opcode
= cpu_to_le32(vc_akcipher_req
->opcode
);
308 header
->algo
= cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA
);
309 header
->session_id
= cpu_to_le64(ctx
->session_id
);
311 /* build request akcipher data */
312 akcipher_req
= &vc_req
->req_data
->u
.akcipher_req
;
/* for VERIFY this is still req->src_len, although src_len + dst_len bytes are sent */
313 akcipher_req
->para
.src_data_len
= cpu_to_le32(req
->src_len
);
314 akcipher_req
->para
.dst_data_len
= cpu_to_le32(req
->dst_len
);
316 ret
= __virtio_crypto_akcipher_do_req(vc_akcipher_req
, req
, data_vq
);
/*
 * NOTE(review): req_data is mapped into the queued scatterlist and read by
 * the device after a successful submission. This release must therefore
 * only happen when submission failed — confirm the guarding condition.
 */
318 kfree_sensitive(vc_req
->req_data
);
319 vc_req
->req_data
= NULL
;
/**
 * virtio_crypto_rsa_req - common entry for encrypt/decrypt/sign/verify
 * @req: the akcipher request
 * @opcode: virtio akcipher opcode to perform
 *
 * Fills the per-request context: data queue 0, the completion callback, the
 * tfm context, the request back-pointer and the opcode. It then queues @req
 * on that data queue's crypto engine. The work itself runs later in the
 * engine's .do_one_request handler (virtio_crypto_rsa_do_req()).
 *
 * NOTE(review): ctx->vcrypto is assigned in virtio_crypto_rsa_set_key().
 * This presumes the crypto API guarantees a key is set before any request —
 * confirm.
 *
 * Return: result of crypto_transfer_akcipher_request_to_engine().
 */
326 static int virtio_crypto_rsa_req(struct akcipher_request
*req
, uint32_t opcode
)
328 struct crypto_akcipher
*atfm
= crypto_akcipher_reqtfm(req
);
329 struct virtio_crypto_akcipher_ctx
*ctx
= akcipher_tfm_ctx(atfm
);
330 struct virtio_crypto_akcipher_request
*vc_akcipher_req
= akcipher_request_ctx(req
);
331 struct virtio_crypto_request
*vc_req
= &vc_akcipher_req
->base
;
332 struct virtio_crypto
*vcrypto
= ctx
->vcrypto
;
333 /* Use the first data virtqueue as default */
334 struct data_queue
*data_vq
= &vcrypto
->data_vq
[0];
336 vc_req
->dataq
= data_vq
;
337 vc_req
->alg_cb
= virtio_crypto_dataq_akcipher_callback
;
338 vc_akcipher_req
->akcipher_ctx
= ctx
;
339 vc_akcipher_req
->akcipher_req
= req
;
340 vc_akcipher_req
->opcode
= opcode
;
342 return crypto_transfer_akcipher_request_to_engine(data_vq
->engine
, req
);
/* akcipher .encrypt: queue @req as a VIRTIO_CRYPTO_AKCIPHER_ENCRYPT operation. */
345 static int virtio_crypto_rsa_encrypt(struct akcipher_request
*req
)
347 return virtio_crypto_rsa_req(req
, VIRTIO_CRYPTO_AKCIPHER_ENCRYPT
);
/* akcipher .decrypt: queue @req as a VIRTIO_CRYPTO_AKCIPHER_DECRYPT operation. */
350 static int virtio_crypto_rsa_decrypt(struct akcipher_request
*req
)
352 return virtio_crypto_rsa_req(req
, VIRTIO_CRYPTO_AKCIPHER_DECRYPT
);
/* akcipher .sign: queue @req as a VIRTIO_CRYPTO_AKCIPHER_SIGN operation. */
355 static int virtio_crypto_rsa_sign(struct akcipher_request
*req
)
357 return virtio_crypto_rsa_req(req
, VIRTIO_CRYPTO_AKCIPHER_SIGN
);
/*
 * akcipher .verify: queue @req as a VIRTIO_CRYPTO_AKCIPHER_VERIFY operation.
 * No output data is copied back for this opcode; only the status matters.
 */
360 static int virtio_crypto_rsa_verify(struct akcipher_request
*req
)
362 return virtio_crypto_rsa_req(req
, VIRTIO_CRYPTO_AKCIPHER_VERIFY
);
/**
 * virtio_crypto_rsa_set_key - parse an RSA key and (re)create the device session
 * @tfm: the akcipher transform
 * @key: ASN.1 key blob, as accepted by rsa_parse_priv_key()/rsa_parse_pub_key()
 * @keylen: length of @key in bytes
 * @padding_algo: VIRTIO_CRYPTO_RSA_* padding selector sent to the device
 * @hash_algo: VIRTIO_CRYPTO_RSA_* hash selector sent to the device
 *
 * The setkey wrappers pass 1 (private key) or 0 (public key) as the fourth
 * argument. That flag selects the parser and the VIRTIO key type.
 *
 * Steps:
 * 1. Free any previous modulus.
 * 2. Parse the key and keep its modulus n as an MPI for
 *    virtio_crypto_rsa_max_size().
 * 3. Bind the tfm to a device on the current NUMA node that supports RSA
 *    akcipher.
 * 4. Close any existing device session.
 * 5. Create a new session. The whole key blob, not the parsed components, is
 *    sent to the device.
 *
 * Return: result of virtio_crypto_alg_akcipher_init_session() on the final path.
 */
365 static int virtio_crypto_rsa_set_key(struct crypto_akcipher
*tfm
,
372 struct virtio_crypto_akcipher_ctx
*ctx
= akcipher_tfm_ctx(tfm
);
373 struct virtio_crypto_rsa_ctx
*rsa_ctx
= &ctx
->rsa_ctx
;
374 struct virtio_crypto
*vcrypto
;
375 struct virtio_crypto_ctrl_header header
;
376 struct virtio_crypto_akcipher_session_para para
;
377 struct rsa_key rsa_key
= {0};
378 int node
= virtio_crypto_get_current_node();
382 /* mpi_free() checks n itself, so free it unconditionally. */
383 mpi_free(rsa_ctx
->n
);
387 keytype
= VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PRIVATE
;
388 ret
= rsa_parse_priv_key(&rsa_key
, key
, keylen
);
390 keytype
= VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PUBLIC
;
391 ret
= rsa_parse_pub_key(&rsa_key
, key
, keylen
);
/* keep the modulus so max_size() can report the key size */
397 rsa_ctx
->n
= mpi_read_raw_data(rsa_key
.n
, rsa_key
.n_sz
);
402 vcrypto
= virtcrypto_get_dev_node(node
, VIRTIO_CRYPTO_SERVICE_AKCIPHER
,
403 VIRTIO_CRYPTO_AKCIPHER_RSA
);
405 pr_err("virtio_crypto: Could not find a virtio device in the system or unsupported algo\n");
409 ctx
->vcrypto
= vcrypto
;
/* drop the previous session before creating one for the new key */
411 virtio_crypto_alg_akcipher_close_session(ctx
);
414 /* set ctrl header */
415 header
.opcode
= cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_CREATE_SESSION
);
416 header
.algo
= cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA
);
420 para
.algo
= cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA
);
421 para
.keytype
= cpu_to_le32(keytype
);
422 para
.keylen
= cpu_to_le32(keylen
);
423 para
.u
.rsa
.padding_algo
= cpu_to_le32(padding_algo
);
424 para
.u
.rsa
.hash_algo
= cpu_to_le32(hash_algo
);
/* pass the address of the local session parameters built above */
426 return virtio_crypto_alg_akcipher_init_session(ctx
, &header
, &para
, key
, keylen
);
/* setkey for the raw RSA transform: private key, raw padding, no hash. */
429 static int virtio_crypto_rsa_raw_set_priv_key(struct crypto_akcipher
*tfm
,
433 return virtio_crypto_rsa_set_key(tfm
, key
, keylen
, 1,
434 VIRTIO_CRYPTO_RSA_RAW_PADDING
,
435 VIRTIO_CRYPTO_RSA_NO_HASH
);
/* setkey for "pkcs1pad(rsa,sha1)": private key, PKCS#1 padding, SHA-1 hash. */
439 static int virtio_crypto_p1pad_rsa_sha1_set_priv_key(struct crypto_akcipher
*tfm
,
443 return virtio_crypto_rsa_set_key(tfm
, key
, keylen
, 1,
444 VIRTIO_CRYPTO_RSA_PKCS1_PADDING
,
445 VIRTIO_CRYPTO_RSA_SHA1
);
/* setkey for the raw RSA transform: public key, raw padding, no hash. */
448 static int virtio_crypto_rsa_raw_set_pub_key(struct crypto_akcipher
*tfm
,
452 return virtio_crypto_rsa_set_key(tfm
, key
, keylen
, 0,
453 VIRTIO_CRYPTO_RSA_RAW_PADDING
,
454 VIRTIO_CRYPTO_RSA_NO_HASH
);
/* setkey for "pkcs1pad(rsa,sha1)": public key, PKCS#1 padding, SHA-1 hash. */
457 static int virtio_crypto_p1pad_rsa_sha1_set_pub_key(struct crypto_akcipher
*tfm
,
461 return virtio_crypto_rsa_set_key(tfm
, key
, keylen
, 0,
462 VIRTIO_CRYPTO_RSA_PKCS1_PADDING
,
463 VIRTIO_CRYPTO_RSA_SHA1
);
/*
 * akcipher .max_size: report the size of the current key's modulus as
 * returned by mpi_get_size(). The modulus is set in virtio_crypto_rsa_set_key().
 */
466 static unsigned int virtio_crypto_rsa_max_size(struct crypto_akcipher
*tfm
)
468 struct virtio_crypto_akcipher_ctx
*ctx
= akcipher_tfm_ctx(tfm
);
469 struct virtio_crypto_rsa_ctx
*rsa_ctx
= &ctx
->rsa_ctx
;
471 return mpi_get_size(rsa_ctx
->n
);
/*
 * akcipher .init: declare the per-request context size. Every akcipher_request
 * on this tfm then carries a struct virtio_crypto_akcipher_request, which is
 * retrieved with akcipher_request_ctx().
 */
474 static int virtio_crypto_rsa_init_tfm(struct crypto_akcipher
*tfm
)
476 struct virtio_crypto_akcipher_ctx
*ctx
= akcipher_tfm_ctx(tfm
);
480 akcipher_set_reqsize(tfm
,
481 sizeof(struct virtio_crypto_akcipher_request
));
/*
 * akcipher .exit: close any device session, drop the device reference, and
 * free the modulus. The reference is presumably the one taken by
 * virtcrypto_get_dev_node() in virtio_crypto_rsa_set_key().
 *
 * NOTE(review): if no key was ever set, ctx->vcrypto was never assigned.
 * Confirm that virtcrypto_dev_put() tolerates a NULL device, or guard the call.
 */
486 static void virtio_crypto_rsa_exit_tfm(struct crypto_akcipher
*tfm
)
488 struct virtio_crypto_akcipher_ctx
*ctx
= akcipher_tfm_ctx(tfm
);
489 struct virtio_crypto_rsa_ctx
*rsa_ctx
= &ctx
->rsa_ctx
;
491 virtio_crypto_alg_akcipher_close_session(ctx
);
492 virtcrypto_dev_put(ctx
->vcrypto
);
493 mpi_free(rsa_ctx
->n
);
/*
 * Algorithms this driver can expose. Both entries require the device to
 * support VIRTIO_CRYPTO_SERVICE_AKCIPHER with VIRTIO_CRYPTO_AKCIPHER_RSA. Both
 * run through the crypto engine with virtio_crypto_rsa_do_req(). They differ
 * only in their setkey callbacks (padding/hash selection) and in which
 * operations they expose.
 */
497 static struct virtio_crypto_akcipher_algo virtio_crypto_akcipher_algs
[] = {
/* raw RSA ("virtio-crypto-rsa"): encrypt/decrypt only */
499 .algonum
= VIRTIO_CRYPTO_AKCIPHER_RSA
,
500 .service
= VIRTIO_CRYPTO_SERVICE_AKCIPHER
,
502 .encrypt
= virtio_crypto_rsa_encrypt
,
503 .decrypt
= virtio_crypto_rsa_decrypt
,
504 .set_pub_key
= virtio_crypto_rsa_raw_set_pub_key
,
505 .set_priv_key
= virtio_crypto_rsa_raw_set_priv_key
,
506 .max_size
= virtio_crypto_rsa_max_size
,
507 .init
= virtio_crypto_rsa_init_tfm
,
508 .exit
= virtio_crypto_rsa_exit_tfm
,
511 .cra_driver_name
= "virtio-crypto-rsa",
513 .cra_module
= THIS_MODULE
,
514 .cra_ctxsize
= sizeof(struct virtio_crypto_akcipher_ctx
),
518 .do_one_request
= virtio_crypto_rsa_do_req
,
/* "pkcs1pad(rsa,sha1)": also exposes sign/verify */
522 .algonum
= VIRTIO_CRYPTO_AKCIPHER_RSA
,
523 .service
= VIRTIO_CRYPTO_SERVICE_AKCIPHER
,
525 .encrypt
= virtio_crypto_rsa_encrypt
,
526 .decrypt
= virtio_crypto_rsa_decrypt
,
527 .sign
= virtio_crypto_rsa_sign
,
528 .verify
= virtio_crypto_rsa_verify
,
529 .set_pub_key
= virtio_crypto_p1pad_rsa_sha1_set_pub_key
,
530 .set_priv_key
= virtio_crypto_p1pad_rsa_sha1_set_priv_key
,
531 .max_size
= virtio_crypto_rsa_max_size
,
532 .init
= virtio_crypto_rsa_init_tfm
,
533 .exit
= virtio_crypto_rsa_exit_tfm
,
535 .cra_name
= "pkcs1pad(rsa,sha1)",
536 .cra_driver_name
= "virtio-pkcs1-rsa-with-sha1",
538 .cra_module
= THIS_MODULE
,
539 .cra_ctxsize
= sizeof(struct virtio_crypto_akcipher_ctx
),
543 .do_one_request
= virtio_crypto_rsa_do_req
,
/**
 * virtio_crypto_akcipher_algs_register - register the algorithms a device supports
 * @vcrypto: the virtio crypto device
 *
 * Under algs_lock, for each table entry whose service/algorithm the device
 * supports:
 * - register the entry with the crypto engine if no device backs it yet
 *   (active_devs == 0);
 * - count this device in active_devs;
 * - log the algorithm name.
 *
 * Return: int status. On failure it is presumably the error from
 * crypto_engine_register_akcipher() — confirm the error path.
 */
548 int virtio_crypto_akcipher_algs_register(struct virtio_crypto
*vcrypto
)
553 mutex_lock(&algs_lock
);
555 for (i
= 0; i
< ARRAY_SIZE(virtio_crypto_akcipher_algs
); i
++) {
556 uint32_t service
= virtio_crypto_akcipher_algs
[i
].service
;
557 uint32_t algonum
= virtio_crypto_akcipher_algs
[i
].algonum
;
559 if (!virtcrypto_algo_is_supported(vcrypto
, service
, algonum
))
/* only the first device backing an algorithm registers it */
562 if (virtio_crypto_akcipher_algs
[i
].active_devs
== 0) {
563 ret
= crypto_engine_register_akcipher(&virtio_crypto_akcipher_algs
[i
].algo
);
568 virtio_crypto_akcipher_algs
[i
].active_devs
++;
569 dev_info(&vcrypto
->vdev
->dev
, "Registered akcipher algo %s\n",
570 virtio_crypto_akcipher_algs
[i
].algo
.base
.base
.cra_name
);
574 mutex_unlock(&algs_lock
);
/**
 * virtio_crypto_akcipher_algs_unregister - drop a device from the algorithms it backs
 * @vcrypto: the virtio crypto device
 *
 * Under algs_lock, entries that have no active devices or that @vcrypto does
 * not support are skipped. For every other entry, this device is removed
 * from active_devs. The algorithm is unregistered from the crypto engine
 * when this was the last backing device (active_devs == 1 before the
 * decrement).
 */
578 void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto
*vcrypto
)
582 mutex_lock(&algs_lock
);
584 for (i
= 0; i
< ARRAY_SIZE(virtio_crypto_akcipher_algs
); i
++) {
585 uint32_t service
= virtio_crypto_akcipher_algs
[i
].service
;
586 uint32_t algonum
= virtio_crypto_akcipher_algs
[i
].algonum
;
588 if (virtio_crypto_akcipher_algs
[i
].active_devs
== 0 ||
589 !virtcrypto_algo_is_supported(vcrypto
, service
, algonum
))
/* the last backing device unregisters the algorithm */
592 if (virtio_crypto_akcipher_algs
[i
].active_devs
== 1)
593 crypto_engine_unregister_akcipher(&virtio_crypto_akcipher_algs
[i
].algo
);
595 virtio_crypto_akcipher_algs
[i
].active_devs
--;
598 mutex_unlock(&algs_lock
);