// SPDX-License-Identifier: GPL-2.0
/*
 * Amazon Nitro Secure Module driver.
 *
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * The Nitro Secure Module implements commands via CBOR over virtio.
 * This driver exposes a raw-message ioctl on /dev/nsm that user
 * space can use to issue these commands.
 */
#include <linux/err.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <linux/interrupt.h>
#include <linux/hw_random.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/uaccess.h>
#include <linux/uio.h>
#include <linux/virtio_config.h>
#include <linux/virtio_ids.h>
#include <linux/virtio.h>
#include <linux/wait.h>
#include <uapi/linux/nsm.h>
/* Timeout for an NSM virtqueue response, in milliseconds (2 minutes). */
#define NSM_DEFAULT_TIMEOUT_MSECS (120000)
32 /* Maximum length input data */
35 u8 data
[NSM_REQUEST_MAX_SIZE
];
38 /* Maximum length output data */
39 struct nsm_data_resp
{
41 u8 data
[NSM_RESPONSE_MAX_SIZE
];
44 /* Full NSM request/response message */
46 struct nsm_data_req req
;
47 struct nsm_data_resp resp
;
51 struct virtio_device
*vdev
;
54 struct completion cmd_done
;
55 struct miscdevice misc
;
57 struct work_struct misc_init
;
62 static const struct virtio_device_id id_table
[] = {
63 { VIRTIO_ID_NITRO_SEC_MOD
, VIRTIO_DEV_ANY_ID
},
67 static struct nsm
*file_to_nsm(struct file
*file
)
69 return container_of(file
->private_data
, struct nsm
, misc
);
72 static struct nsm
*hwrng_to_nsm(struct hwrng
*rng
)
74 return container_of(rng
, struct nsm
, hwrng
);
/*
 * Minimal CBOR (RFC 8949) constants. The initial byte of a data item holds
 * the major type in its top three bits and the "additional information"
 * (an inline length, or the size of a following length field) in its low
 * five bits.
 */
#define CBOR_TYPE_MASK		0xE0	/* major-type bits of the initial byte */
#define CBOR_TYPE_MAP		0xA0	/* major type 5: map */
#define CBOR_TYPE_TEXT		0x60	/* major type 3: text string */
/* Major type 2 is a CBOR byte string; this driver calls it an "array". */
#define CBOR_TYPE_ARRAY		0x40
#define CBOR_HEADER_SIZE_SHORT	1	/* size of the initial byte */

/* Additional-information values describing how a length is encoded. */
#define CBOR_SHORT_SIZE_MAX_VALUE 23	/* 0..23: length stored inline */
#define CBOR_LONG_SIZE_U8	24	/* 1-byte length follows */
#define CBOR_LONG_SIZE_U16	25	/* 2-byte big-endian length follows */
#define CBOR_LONG_SIZE_U32	26	/* 4-byte big-endian length follows */
#define CBOR_LONG_SIZE_U64	27	/* 8-byte big-endian length follows */
89 static bool cbor_object_is_array(const u8
*cbor_object
, size_t cbor_object_size
)
91 if (cbor_object_size
== 0 || cbor_object
== NULL
)
94 return (cbor_object
[0] & CBOR_TYPE_MASK
) == CBOR_TYPE_ARRAY
;
97 static int cbor_object_get_array(u8
*cbor_object
, size_t cbor_object_size
, u8
**cbor_array
)
104 if (!cbor_object_is_array(cbor_object
, cbor_object_size
))
107 cbor_short_size
= (cbor_object
[0] & 0x1F);
109 /* Decoding byte array length */
110 array_offset
= CBOR_HEADER_SIZE_SHORT
;
111 if (cbor_short_size
>= CBOR_LONG_SIZE_U8
)
112 array_offset
+= BIT(cbor_short_size
- CBOR_LONG_SIZE_U8
);
114 if (cbor_object_size
< array_offset
)
117 array_len_p
= &cbor_object
[1];
119 switch (cbor_short_size
) {
120 case CBOR_SHORT_SIZE_MAX_VALUE
: /* short encoding */
121 array_len
= cbor_short_size
;
123 case CBOR_LONG_SIZE_U8
:
124 array_len
= *(u8
*)array_len_p
;
126 case CBOR_LONG_SIZE_U16
:
127 array_len
= be16_to_cpup((__be16
*)array_len_p
);
129 case CBOR_LONG_SIZE_U32
:
130 array_len
= be32_to_cpup((__be32
*)array_len_p
);
132 case CBOR_LONG_SIZE_U64
:
133 array_len
= be64_to_cpup((__be64
*)array_len_p
);
137 if (cbor_object_size
< array_offset
)
140 if (cbor_object_size
- array_offset
< array_len
)
143 if (array_len
> INT_MAX
)
146 *cbor_array
= cbor_object
+ array_offset
;
150 /* Copy the request of a raw message to kernel space */
151 static int fill_req_raw(struct nsm
*nsm
, struct nsm_data_req
*req
,
154 /* Verify the user input size. */
155 if (raw
->request
.len
> sizeof(req
->data
))
158 /* Copy the request payload */
159 if (copy_from_user(req
->data
, u64_to_user_ptr(raw
->request
.addr
),
163 req
->len
= raw
->request
.len
;
168 /* Copy the response of a raw message back to user-space */
169 static int parse_resp_raw(struct nsm
*nsm
, struct nsm_data_resp
*resp
,
172 /* Truncate any message that does not fit. */
173 raw
->response
.len
= min_t(u64
, raw
->response
.len
, resp
->len
);
175 /* Copy the response content to user space */
176 if (copy_to_user(u64_to_user_ptr(raw
->response
.addr
),
177 resp
->data
, raw
->response
.len
))
183 /* Virtqueue interrupt handler */
184 static void nsm_vq_callback(struct virtqueue
*vq
)
186 struct nsm
*nsm
= vq
->vdev
->priv
;
188 complete(&nsm
->cmd_done
);
191 /* Forward a message to the NSM device and wait for the response from it */
192 static int nsm_sendrecv_msg_locked(struct nsm
*nsm
)
194 struct device
*dev
= &nsm
->vdev
->dev
;
195 struct scatterlist sg_in
, sg_out
;
196 struct nsm_msg
*msg
= &nsm
->msg
;
197 struct virtqueue
*vq
= nsm
->vq
;
203 /* Initialize scatter-gather lists with request and response buffers. */
204 sg_init_one(&sg_out
, msg
->req
.data
, msg
->req
.len
);
205 sg_init_one(&sg_in
, msg
->resp
.data
, sizeof(msg
->resp
.data
));
207 init_completion(&nsm
->cmd_done
);
208 /* Add the request buffer (read by the device). */
209 rc
= virtqueue_add_outbuf(vq
, &sg_out
, 1, msg
->req
.data
, GFP_KERNEL
);
213 /* Add the response buffer (written by the device). */
214 rc
= virtqueue_add_inbuf(vq
, &sg_in
, 1, msg
->resp
.data
, GFP_KERNEL
);
218 kicked
= virtqueue_kick(vq
);
220 /* Cannot kick the virtqueue. */
225 /* If the kick succeeded, wait for the device's response. */
226 if (!wait_for_completion_io_timeout(&nsm
->cmd_done
,
227 msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS
))) {
232 queue_buf
= virtqueue_get_buf(vq
, &len
);
233 if (!queue_buf
|| (queue_buf
!= msg
->req
.data
)) {
234 dev_err(dev
, "wrong request buffer.");
239 queue_buf
= virtqueue_get_buf(vq
, &len
);
240 if (!queue_buf
|| (queue_buf
!= msg
->resp
.data
)) {
241 dev_err(dev
, "wrong response buffer.");
252 /* Clean the virtqueue. */
253 while (virtqueue_get_buf(vq
, &len
) != NULL
)
260 static int fill_req_get_random(struct nsm
*nsm
, struct nsm_data_req
*req
)
264 * 47657452616E646F6D # "GetRandom"
266 const u8 request
[] = { CBOR_TYPE_TEXT
+ strlen("GetRandom"),
267 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' };
269 memcpy(req
->data
, request
, sizeof(request
));
270 req
->len
= sizeof(request
);
275 static int parse_resp_get_random(struct nsm
*nsm
, struct nsm_data_resp
*resp
,
276 void *out
, size_t max
)
280 * 69 # text(9) - Name of field
281 * 47657452616E646F6D # "GetRandom"
282 * A1 # map(1) - The field itself
284 * 72616E646F6D # "random"
285 * # The rest of the response is random data
287 const u8 response
[] = { CBOR_TYPE_MAP
+ 1,
288 CBOR_TYPE_TEXT
+ strlen("GetRandom"),
289 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm',
291 CBOR_TYPE_TEXT
+ strlen("random"),
292 'r', 'a', 'n', 'd', 'o', 'm' };
293 struct device
*dev
= &nsm
->vdev
->dev
;
294 u8
*rand_data
= NULL
;
295 u8
*resp_ptr
= resp
->data
;
296 u64 resp_len
= resp
->len
;
299 if ((resp
->len
< sizeof(response
) + 1) ||
300 (memcmp(resp_ptr
, response
, sizeof(response
)) != 0)) {
301 dev_err(dev
, "Invalid response for GetRandom");
305 resp_ptr
+= sizeof(response
);
306 resp_len
-= sizeof(response
);
308 rc
= cbor_object_get_array(resp_ptr
, resp_len
, &rand_data
);
310 dev_err(dev
, "GetRandom: Invalid CBOR encoding\n");
314 rc
= min_t(size_t, rc
, max
);
315 memcpy(out
, rand_data
, rc
);
321 * HwRNG implementation
323 static int nsm_rng_read(struct hwrng
*rng
, void *data
, size_t max
, bool wait
)
325 struct nsm
*nsm
= hwrng_to_nsm(rng
);
326 struct device
*dev
= &nsm
->vdev
->dev
;
329 /* NSM always needs to wait for a response */
333 mutex_lock(&nsm
->lock
);
335 rc
= fill_req_get_random(nsm
, &nsm
->msg
.req
);
339 rc
= nsm_sendrecv_msg_locked(nsm
);
343 rc
= parse_resp_get_random(nsm
, &nsm
->msg
.resp
, data
, max
);
347 dev_dbg(dev
, "RNG: returning rand bytes = %d", rc
);
349 mutex_unlock(&nsm
->lock
);
353 static long nsm_dev_ioctl(struct file
*file
, unsigned int cmd
,
356 void __user
*argp
= u64_to_user_ptr((u64
)arg
);
357 struct nsm
*nsm
= file_to_nsm(file
);
361 if (cmd
!= NSM_IOCTL_RAW
)
364 if (_IOC_SIZE(cmd
) != sizeof(raw
))
367 /* Copy user argument struct to kernel argument struct */
369 if (copy_from_user(&raw
, argp
, _IOC_SIZE(cmd
)))
372 mutex_lock(&nsm
->lock
);
374 /* Convert kernel argument struct to device request */
375 r
= fill_req_raw(nsm
, &nsm
->msg
.req
, &raw
);
379 /* Send message to NSM and read reply */
380 r
= nsm_sendrecv_msg_locked(nsm
);
384 /* Parse device response into kernel argument struct */
385 r
= parse_resp_raw(nsm
, &nsm
->msg
.resp
, &raw
);
389 /* Copy kernel argument struct back to user argument struct */
391 if (copy_to_user(argp
, &raw
, sizeof(raw
)))
397 mutex_unlock(&nsm
->lock
);
401 static int nsm_device_init_vq(struct virtio_device
*vdev
)
403 struct virtqueue
*vq
= virtio_find_single_vq(vdev
,
404 nsm_vq_callback
, "nsm.vq.0");
405 struct nsm
*nsm
= vdev
->priv
;
415 static const struct file_operations nsm_dev_fops
= {
416 .unlocked_ioctl
= nsm_dev_ioctl
,
417 .compat_ioctl
= compat_ptr_ioctl
,
420 /* Handler for probing the NSM device */
421 static int nsm_device_probe(struct virtio_device
*vdev
)
423 struct device
*dev
= &vdev
->dev
;
427 nsm
= devm_kzalloc(&vdev
->dev
, sizeof(*nsm
), GFP_KERNEL
);
434 rc
= nsm_device_init_vq(vdev
);
436 dev_err(dev
, "queue failed to initialize: %d.\n", rc
);
440 mutex_init(&nsm
->lock
);
442 /* Register as hwrng provider */
443 nsm
->hwrng
= (struct hwrng
) {
444 .read
= nsm_rng_read
,
449 rc
= hwrng_register(&nsm
->hwrng
);
451 dev_err(dev
, "RNG initialization error: %d.\n", rc
);
455 /* Register /dev/nsm device node */
456 nsm
->misc
= (struct miscdevice
) {
457 .minor
= MISC_DYNAMIC_MINOR
,
459 .fops
= &nsm_dev_fops
,
463 rc
= misc_register(&nsm
->misc
);
465 dev_err(dev
, "misc device registration error: %d.\n", rc
);
472 hwrng_unregister(&nsm
->hwrng
);
474 vdev
->config
->del_vqs(vdev
);
479 /* Handler for removing the NSM device */
480 static void nsm_device_remove(struct virtio_device
*vdev
)
482 struct nsm
*nsm
= vdev
->priv
;
484 hwrng_unregister(&nsm
->hwrng
);
486 vdev
->config
->del_vqs(vdev
);
487 misc_deregister(&nsm
->misc
);
490 /* NSM device configuration structure */
491 static struct virtio_driver virtio_nsm_driver
= {
493 .feature_table_size
= 0,
494 .feature_table_legacy
= 0,
495 .feature_table_size_legacy
= 0,
496 .driver
.name
= KBUILD_MODNAME
,
497 .id_table
= id_table
,
498 .probe
= nsm_device_probe
,
499 .remove
= nsm_device_remove
,
502 module_virtio_driver(virtio_nsm_driver
);
503 MODULE_DEVICE_TABLE(virtio
, id_table
);
504 MODULE_DESCRIPTION("Virtio NSM driver");
505 MODULE_LICENSE("GPL");