// SPDX-License-Identifier: GPL-2.0-or-later
/* Asymmetric algorithms supported by virtio crypto device
 *
 * Authors: zhenwei pi <pizhenwei@bytedance.com>
 *          lei he <helei.sig11@bytedance.com>
 *
 * Copyright 2022 Bytedance CO., LTD.
 */

#include <crypto/engine.h>
#include <crypto/internal/akcipher.h>
#include <crypto/internal/rsa.h>
#include <crypto/scatterwalk.h>
#include <linux/err.h>
#include <linux/kernel.h>
#include <linux/mpi.h>
#include <linux/scatterlist.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <uapi/linux/virtio_crypto.h>
#include "virtio_crypto_common.h"

struct virtio_crypto_rsa_ctx {
	MPI n;
};

struct virtio_crypto_akcipher_ctx {
	struct virtio_crypto *vcrypto;
	struct crypto_akcipher *tfm;
	bool session_valid;
	__u64 session_id;
	union {
		struct virtio_crypto_rsa_ctx rsa_ctx;
	};
};

struct virtio_crypto_akcipher_request {
	struct virtio_crypto_request base;
	struct virtio_crypto_akcipher_ctx *akcipher_ctx;
	struct akcipher_request *akcipher_req;
	void *src_buf;
	void *dst_buf;
	uint32_t opcode;
};

struct virtio_crypto_akcipher_algo {
	uint32_t algonum;
	uint32_t service;
	unsigned int active_devs;
	struct akcipher_engine_alg algo;
};

static DEFINE_MUTEX(algs_lock);

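/*
 * Release the per-request bounce buffers and complete the request on
 * the crypto engine of the data queue it was submitted on.
 */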
static void virtio_crypto_akcipher_finalize_req(
	struct virtio_crypto_akcipher_request *vc_akcipher_req,
	struct akcipher_request *req, int err)
{
	kfree(vc_akcipher_req->src_buf);
	kfree(vc_akcipher_req->dst_buf);
	vc_akcipher_req->src_buf = NULL;
	vc_akcipher_req->dst_buf = NULL;
	virtcrypto_clear_request(&vc_akcipher_req->base);

	crypto_finalize_akcipher_request(vc_akcipher_req->base.dataq->engine, req, err);
}

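/*
 * Data virtqueue completion callback: translate the device status into
 * an errno, copy the device output from the bounce buffer back into
 * the caller's dst scatterlist, then finalize the request.
 */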
static void virtio_crypto_dataq_akcipher_callback(struct virtio_crypto_request *vc_req, int len)
{
	struct virtio_crypto_akcipher_request *vc_akcipher_req =
		container_of(vc_req, struct virtio_crypto_akcipher_request, base);
	struct akcipher_request *akcipher_req;
	int error;

	switch (vc_req->status) {
	case VIRTIO_CRYPTO_OK:
		error = 0;
		break;
	case VIRTIO_CRYPTO_INVSESS:
	case VIRTIO_CRYPTO_ERR:
		error = -EINVAL;
		break;
	case VIRTIO_CRYPTO_BADMSG:
		error = -EBADMSG;
		break;
	default:
		error = -EIO;
		break;
	}

	akcipher_req = vc_akcipher_req->akcipher_req;
	/* actual length may be less than dst buffer */
	akcipher_req->dst_len = len - sizeof(vc_req->status);
	sg_copy_from_buffer(akcipher_req->dst, sg_nents(akcipher_req->dst),
			    vc_akcipher_req->dst_buf, akcipher_req->dst_len);
	virtio_crypto_akcipher_finalize_req(vc_akcipher_req, akcipher_req, error);
}

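/*
 * Create an akcipher session over the control virtqueue. Three
 * descriptors are used: the control request header (out), a copy of
 * the key material (out) and the session_input status block (in).
 */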
static int virtio_crypto_alg_akcipher_init_session(struct virtio_crypto_akcipher_ctx *ctx,
		struct virtio_crypto_ctrl_header *header,
		struct virtio_crypto_akcipher_session_para *para,
		const uint8_t *key, unsigned int keylen)
{
	struct scatterlist outhdr_sg, key_sg, inhdr_sg, *sgs[3];
	struct virtio_crypto *vcrypto = ctx->vcrypto;
	uint8_t *pkey;
	int err;
	unsigned int num_out = 0, num_in = 0;
	struct virtio_crypto_op_ctrl_req *ctrl;
	struct virtio_crypto_session_input *input;
	struct virtio_crypto_ctrl_request *vc_ctrl_req;

	pkey = kmemdup(key, keylen, GFP_KERNEL);
	if (!pkey)
		return -ENOMEM;

	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
	if (!vc_ctrl_req) {
		err = -ENOMEM;
		goto out;
	}

	ctrl = &vc_ctrl_req->ctrl;
	memcpy(&ctrl->header, header, sizeof(ctrl->header));
	memcpy(&ctrl->u.akcipher_create_session.para, para, sizeof(*para));
	input = &vc_ctrl_req->input;
	input->status = cpu_to_le32(VIRTIO_CRYPTO_ERR);

	sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl));
	sgs[num_out++] = &outhdr_sg;

	sg_init_one(&key_sg, pkey, keylen);
	sgs[num_out++] = &key_sg;

	sg_init_one(&inhdr_sg, input, sizeof(*input));
	sgs[num_out + num_in++] = &inhdr_sg;

	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
	if (err < 0)
		goto out;

	if (le32_to_cpu(input->status) != VIRTIO_CRYPTO_OK) {
		pr_err("virtio_crypto: Create session failed status: %u\n",
		       le32_to_cpu(input->status));
		err = -EINVAL;
		goto out;
	}

	ctx->session_id = le64_to_cpu(input->session_id);
	ctx->session_valid = true;
	err = 0;

out:
	kfree(vc_ctrl_req);
	kfree_sensitive(pkey);

	return err;
}

static int virtio_crypto_alg_akcipher_close_session(struct virtio_crypto_akcipher_ctx *ctx)
{
	struct scatterlist outhdr_sg, inhdr_sg, *sgs[2];
	struct virtio_crypto_destroy_session_req *destroy_session;
	struct virtio_crypto *vcrypto = ctx->vcrypto;
	unsigned int num_out = 0, num_in = 0;
	int err;
	struct virtio_crypto_op_ctrl_req *ctrl;
	struct virtio_crypto_inhdr *ctrl_status;
	struct virtio_crypto_ctrl_request *vc_ctrl_req;

	if (!ctx->session_valid)
		return 0;

	vc_ctrl_req = kzalloc(sizeof(*vc_ctrl_req), GFP_KERNEL);
	if (!vc_ctrl_req)
		return -ENOMEM;

	ctrl_status = &vc_ctrl_req->ctrl_status;
	ctrl_status->status = VIRTIO_CRYPTO_ERR;
	ctrl = &vc_ctrl_req->ctrl;
	ctrl->header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_DESTROY_SESSION);
	ctrl->header.queue_id = 0;

	destroy_session = &ctrl->u.destroy_session;
	destroy_session->session_id = cpu_to_le64(ctx->session_id);

	sg_init_one(&outhdr_sg, ctrl, sizeof(*ctrl));
	sgs[num_out++] = &outhdr_sg;

	sg_init_one(&inhdr_sg, &ctrl_status->status, sizeof(ctrl_status->status));
	sgs[num_out + num_in++] = &inhdr_sg;

	err = virtio_crypto_ctrl_vq_request(vcrypto, sgs, num_out, num_in, vc_ctrl_req);
	if (err < 0)
		goto out;

	if (ctrl_status->status != VIRTIO_CRYPTO_OK) {
		pr_err("virtio_crypto: Close session failed status: %u, session_id: 0x%llx\n",
		       ctrl_status->status, destroy_session->session_id);
		err = -EINVAL;
		goto out;
	}

	err = 0;
	ctx->session_valid = false;

out:
	kfree(vc_ctrl_req);
	return err;
}

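/*
 * Map one akcipher operation onto a data virtqueue. Up to four
 * descriptors are used: the op header (out), a linear copy of the
 * source data (out), the destination bounce buffer (in) and the
 * status byte (in).
 */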
static int __virtio_crypto_akcipher_do_req(struct virtio_crypto_akcipher_request *vc_akcipher_req,
		struct akcipher_request *req, struct data_queue *data_vq)
{
	struct virtio_crypto_akcipher_ctx *ctx = vc_akcipher_req->akcipher_ctx;
	struct virtio_crypto_request *vc_req = &vc_akcipher_req->base;
	struct virtio_crypto *vcrypto = ctx->vcrypto;
	struct virtio_crypto_op_data_req *req_data = vc_req->req_data;
	struct scatterlist *sgs[4], outhdr_sg, inhdr_sg, srcdata_sg, dstdata_sg;
	void *src_buf, *dst_buf = NULL;
	unsigned int num_out = 0, num_in = 0;
	int node = dev_to_node(&vcrypto->vdev->dev);
	unsigned long flags;
	int ret;

	/* out header */
	sg_init_one(&outhdr_sg, req_data, sizeof(*req_data));
	sgs[num_out++] = &outhdr_sg;

	/* src data */
	src_buf = kcalloc_node(req->src_len, 1, GFP_KERNEL, node);
	if (!src_buf)
		return -ENOMEM;

	sg_copy_to_buffer(req->src, sg_nents(req->src), src_buf, req->src_len);
	sg_init_one(&srcdata_sg, src_buf, req->src_len);
	sgs[num_out++] = &srcdata_sg;

	/* dst data */
	dst_buf = kcalloc_node(req->dst_len, 1, GFP_KERNEL, node);
	if (!dst_buf)
		goto free_src;

	sg_init_one(&dstdata_sg, dst_buf, req->dst_len);
	sgs[num_out + num_in++] = &dstdata_sg;

	vc_akcipher_req->src_buf = src_buf;
	vc_akcipher_req->dst_buf = dst_buf;

	/* in header */
	sg_init_one(&inhdr_sg, &vc_req->status, sizeof(vc_req->status));
	sgs[num_out + num_in++] = &inhdr_sg;

	spin_lock_irqsave(&data_vq->lock, flags);
	ret = virtqueue_add_sgs(data_vq->vq, sgs, num_out, num_in, vc_req, GFP_ATOMIC);
	virtqueue_kick(data_vq->vq);
	spin_unlock_irqrestore(&data_vq->lock, flags);
	if (ret)
		goto err;

	return 0;

err:
	kfree(dst_buf);
free_src:
	kfree(src_buf);
	return -ENOMEM;
}

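/*
 * crypto engine do_one_request callback: build the session-scoped op
 * header and the RSA length parameters, then queue the request on the
 * data virtqueue.
 */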
static int virtio_crypto_rsa_do_req(struct crypto_engine *engine, void *vreq)
{
	struct akcipher_request *req = container_of(vreq, struct akcipher_request, base);
	struct virtio_crypto_akcipher_request *vc_akcipher_req = akcipher_request_ctx(req);
	struct virtio_crypto_request *vc_req = &vc_akcipher_req->base;
	struct virtio_crypto_akcipher_ctx *ctx = vc_akcipher_req->akcipher_ctx;
	struct virtio_crypto *vcrypto = ctx->vcrypto;
	struct data_queue *data_vq = vc_req->dataq;
	struct virtio_crypto_op_header *header;
	struct virtio_crypto_akcipher_data_req *akcipher_req;
	int ret;

	vc_req->sgs = NULL;
	vc_req->req_data = kzalloc_node(sizeof(*vc_req->req_data),
					GFP_KERNEL, dev_to_node(&vcrypto->vdev->dev));
	if (!vc_req->req_data)
		return -ENOMEM;

	/* build request header */
	header = &vc_req->req_data->header;
	header->opcode = cpu_to_le32(vc_akcipher_req->opcode);
	header->algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA);
	header->session_id = cpu_to_le64(ctx->session_id);

	/* build request akcipher data */
	akcipher_req = &vc_req->req_data->u.akcipher_req;
	akcipher_req->para.src_data_len = cpu_to_le32(req->src_len);
	akcipher_req->para.dst_data_len = cpu_to_le32(req->dst_len);

	ret = __virtio_crypto_akcipher_do_req(vc_akcipher_req, req, data_vq);
	if (ret < 0) {
		kfree_sensitive(vc_req->req_data);
		vc_req->req_data = NULL;
		return ret;
	}

	return 0;
}

static int virtio_crypto_rsa_req(struct akcipher_request *req, uint32_t opcode)
{
	struct crypto_akcipher *atfm = crypto_akcipher_reqtfm(req);
	struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(atfm);
	struct virtio_crypto_akcipher_request *vc_akcipher_req = akcipher_request_ctx(req);
	struct virtio_crypto_request *vc_req = &vc_akcipher_req->base;
	struct virtio_crypto *vcrypto = ctx->vcrypto;
	/* Use the first data virtqueue as default */
	struct data_queue *data_vq = &vcrypto->data_vq[0];

	vc_req->dataq = data_vq;
	vc_req->alg_cb = virtio_crypto_dataq_akcipher_callback;
	vc_akcipher_req->akcipher_ctx = ctx;
	vc_akcipher_req->akcipher_req = req;
	vc_akcipher_req->opcode = opcode;

	return crypto_transfer_akcipher_request_to_engine(data_vq->engine, req);
}

static int virtio_crypto_rsa_encrypt(struct akcipher_request *req)
{
	return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_ENCRYPT);
}

static int virtio_crypto_rsa_decrypt(struct akcipher_request *req)
{
	return virtio_crypto_rsa_req(req, VIRTIO_CRYPTO_AKCIPHER_DECRYPT);
}

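/*
 * The key is parsed in the driver mainly to validate it and extract
 * the modulus n (which backs max_size); the raw key bytes themselves
 * are forwarded to the device when the session is created.
 */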
static int virtio_crypto_rsa_set_key(struct crypto_akcipher *tfm,
				     const void *key,
				     unsigned int keylen,
				     bool private,
				     int padding_algo,
				     int hash_algo)
{
	struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx;
	struct virtio_crypto *vcrypto;
	struct virtio_crypto_ctrl_header header;
	struct virtio_crypto_akcipher_session_para para;
	struct rsa_key rsa_key = {0};
	int node = virtio_crypto_get_current_node();
	uint32_t keytype;
	int ret;

	/* mpi_free will test n, just free it. */
	mpi_free(rsa_ctx->n);
	rsa_ctx->n = NULL;

	if (private) {
		keytype = VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PRIVATE;
		ret = rsa_parse_priv_key(&rsa_key, key, keylen);
	} else {
		keytype = VIRTIO_CRYPTO_AKCIPHER_KEY_TYPE_PUBLIC;
		ret = rsa_parse_pub_key(&rsa_key, key, keylen);
	}

	if (ret)
		return ret;

	rsa_ctx->n = mpi_read_raw_data(rsa_key.n, rsa_key.n_sz);
	if (!rsa_ctx->n)
		return -ENOMEM;

	if (!ctx->vcrypto) {
		vcrypto = virtcrypto_get_dev_node(node, VIRTIO_CRYPTO_SERVICE_AKCIPHER,
						  VIRTIO_CRYPTO_AKCIPHER_RSA);
		if (!vcrypto) {
			pr_err("virtio_crypto: Could not find a virtio device in the system or unsupported algo\n");
			return -ENODEV;
		}

		ctx->vcrypto = vcrypto;
	} else {
		virtio_crypto_alg_akcipher_close_session(ctx);
	}

	/* set ctrl header */
	header.opcode = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_CREATE_SESSION);
	header.algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA);
	header.queue_id = 0;

	/* set RSA para */
	para.algo = cpu_to_le32(VIRTIO_CRYPTO_AKCIPHER_RSA);
	para.keytype = cpu_to_le32(keytype);
	para.keylen = cpu_to_le32(keylen);
	para.u.rsa.padding_algo = cpu_to_le32(padding_algo);
	para.u.rsa.hash_algo = cpu_to_le32(hash_algo);

	return virtio_crypto_alg_akcipher_init_session(ctx, &header, &para, key, keylen);
}

static int virtio_crypto_rsa_raw_set_priv_key(struct crypto_akcipher *tfm,
					      const void *key,
					      unsigned int keylen)
{
	return virtio_crypto_rsa_set_key(tfm, key, keylen, 1,
					 VIRTIO_CRYPTO_RSA_RAW_PADDING,
					 VIRTIO_CRYPTO_RSA_NO_HASH);
}

static int virtio_crypto_p1pad_rsa_sha1_set_priv_key(struct crypto_akcipher *tfm,
						     const void *key,
						     unsigned int keylen)
{
	return virtio_crypto_rsa_set_key(tfm, key, keylen, 1,
					 VIRTIO_CRYPTO_RSA_PKCS1_PADDING,
					 VIRTIO_CRYPTO_RSA_SHA1);
}

static int virtio_crypto_rsa_raw_set_pub_key(struct crypto_akcipher *tfm,
					     const void *key,
					     unsigned int keylen)
{
	return virtio_crypto_rsa_set_key(tfm, key, keylen, 0,
					 VIRTIO_CRYPTO_RSA_RAW_PADDING,
					 VIRTIO_CRYPTO_RSA_NO_HASH);
}

static int virtio_crypto_p1pad_rsa_sha1_set_pub_key(struct crypto_akcipher *tfm,
						    const void *key,
						    unsigned int keylen)
{
	return virtio_crypto_rsa_set_key(tfm, key, keylen, 0,
					 VIRTIO_CRYPTO_RSA_PKCS1_PADDING,
					 VIRTIO_CRYPTO_RSA_SHA1);
}

static unsigned int virtio_crypto_rsa_max_size(struct crypto_akcipher *tfm)
{
	struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx;

	return mpi_get_size(rsa_ctx->n);
}

static int virtio_crypto_rsa_init_tfm(struct crypto_akcipher *tfm)
{
	struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm);

	ctx->tfm = tfm;

	akcipher_set_reqsize(tfm,
			     sizeof(struct virtio_crypto_akcipher_request));

	return 0;
}

static void virtio_crypto_rsa_exit_tfm(struct crypto_akcipher *tfm)
{
	struct virtio_crypto_akcipher_ctx *ctx = akcipher_tfm_ctx(tfm);
	struct virtio_crypto_rsa_ctx *rsa_ctx = &ctx->rsa_ctx;

	virtio_crypto_alg_akcipher_close_session(ctx);
	virtcrypto_dev_put(ctx->vcrypto);
	mpi_free(rsa_ctx->n);
	rsa_ctx->n = NULL;
}

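/*
 * Two instances are exposed: raw RSA ("rsa") and PKCS#1 v1.5 with
 * SHA-1 ("pkcs1pad(rsa)"), both served by the same request path.
 */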
static struct virtio_crypto_akcipher_algo virtio_crypto_akcipher_algs[] = {
	{
		.algonum = VIRTIO_CRYPTO_AKCIPHER_RSA,
		.service = VIRTIO_CRYPTO_SERVICE_AKCIPHER,
		.algo.base = {
			.encrypt = virtio_crypto_rsa_encrypt,
			.decrypt = virtio_crypto_rsa_decrypt,
			.set_pub_key = virtio_crypto_rsa_raw_set_pub_key,
			.set_priv_key = virtio_crypto_rsa_raw_set_priv_key,
			.max_size = virtio_crypto_rsa_max_size,
			.init = virtio_crypto_rsa_init_tfm,
			.exit = virtio_crypto_rsa_exit_tfm,
			.base = {
				.cra_name = "rsa",
				.cra_driver_name = "virtio-crypto-rsa",
				.cra_priority = 150,
				.cra_module = THIS_MODULE,
				.cra_ctxsize = sizeof(struct virtio_crypto_akcipher_ctx),
			},
		},
		.algo.op = {
			.do_one_request = virtio_crypto_rsa_do_req,
		},
	},
	{
		.algonum = VIRTIO_CRYPTO_AKCIPHER_RSA,
		.service = VIRTIO_CRYPTO_SERVICE_AKCIPHER,
		.algo.base = {
			.encrypt = virtio_crypto_rsa_encrypt,
			.decrypt = virtio_crypto_rsa_decrypt,
			/*
			 * Must specify an arbitrary hash algorithm upon
			 * set_{pub,priv}_key (even though it's not used
			 * by encrypt/decrypt) because qemu checks for it.
			 */
			.set_pub_key = virtio_crypto_p1pad_rsa_sha1_set_pub_key,
			.set_priv_key = virtio_crypto_p1pad_rsa_sha1_set_priv_key,
			.max_size = virtio_crypto_rsa_max_size,
			.init = virtio_crypto_rsa_init_tfm,
			.exit = virtio_crypto_rsa_exit_tfm,
			.base = {
				.cra_name = "pkcs1pad(rsa)",
				.cra_driver_name = "virtio-pkcs1-rsa",
				.cra_priority = 150,
				.cra_module = THIS_MODULE,
				.cra_ctxsize = sizeof(struct virtio_crypto_akcipher_ctx),
			},
		},
		.algo.op = {
			.do_one_request = virtio_crypto_rsa_do_req,
		},
	},
};

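/*
 * active_devs counts how many virtio crypto devices currently back
 * each algorithm; the crypto engine registration is done on the first
 * device and undone on the last, under algs_lock.
 */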
int virtio_crypto_akcipher_algs_register(struct virtio_crypto *vcrypto)
{
	int ret = 0;
	int i = 0;

	mutex_lock(&algs_lock);

	for (i = 0; i < ARRAY_SIZE(virtio_crypto_akcipher_algs); i++) {
		uint32_t service = virtio_crypto_akcipher_algs[i].service;
		uint32_t algonum = virtio_crypto_akcipher_algs[i].algonum;

		if (!virtcrypto_algo_is_supported(vcrypto, service, algonum))
			continue;

		if (virtio_crypto_akcipher_algs[i].active_devs == 0) {
			ret = crypto_engine_register_akcipher(&virtio_crypto_akcipher_algs[i].algo);
			if (ret)
				goto unlock;
		}

		virtio_crypto_akcipher_algs[i].active_devs++;
		dev_info(&vcrypto->vdev->dev, "Registered akcipher algo %s\n",
			 virtio_crypto_akcipher_algs[i].algo.base.base.cra_name);
	}

unlock:
	mutex_unlock(&algs_lock);
	return ret;
}

void virtio_crypto_akcipher_algs_unregister(struct virtio_crypto *vcrypto)
{
	int i = 0;

	mutex_lock(&algs_lock);

	for (i = 0; i < ARRAY_SIZE(virtio_crypto_akcipher_algs); i++) {
		uint32_t service = virtio_crypto_akcipher_algs[i].service;
		uint32_t algonum = virtio_crypto_akcipher_algs[i].algonum;

		if (virtio_crypto_akcipher_algs[i].active_devs == 0 ||
		    !virtcrypto_algo_is_supported(vcrypto, service, algonum))
			continue;

		if (virtio_crypto_akcipher_algs[i].active_devs == 1)
			crypto_engine_unregister_akcipher(&virtio_crypto_akcipher_algs[i].algo);

		virtio_crypto_akcipher_algs[i].active_devs--;
	}

	mutex_unlock(&algs_lock);
}