drm/tests: hdmi: Fix memory leaks in drm_display_mode_from_cea_vic()
[drm/drm-misc.git] / drivers / misc / nsm.c
blobef7b32742340999add6533472f0fe00a1a812742
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * Amazon Nitro Secure Module driver.
5 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
7 * The Nitro Secure Module implements commands via CBOR over virtio.
8 * This driver exposes a raw message ioctl on /dev/nsm that user
9 * space can use to issue these commands.
12 #include <linux/file.h>
13 #include <linux/fs.h>
14 #include <linux/interrupt.h>
15 #include <linux/hw_random.h>
16 #include <linux/miscdevice.h>
17 #include <linux/module.h>
18 #include <linux/mutex.h>
19 #include <linux/slab.h>
20 #include <linux/string.h>
21 #include <linux/uaccess.h>
22 #include <linux/uio.h>
23 #include <linux/virtio_config.h>
24 #include <linux/virtio_ids.h>
25 #include <linux/virtio.h>
26 #include <linux/wait.h>
27 #include <uapi/linux/nsm.h>
/* How long to wait for a response on the NSM virtqueue, in milliseconds. */
#define NSM_DEFAULT_TIMEOUT_MSECS	(120000)	/* 2 minutes */
32 /* Maximum length input data */
33 struct nsm_data_req {
34 u32 len;
35 u8 data[NSM_REQUEST_MAX_SIZE];
38 /* Maximum length output data */
39 struct nsm_data_resp {
40 u32 len;
41 u8 data[NSM_RESPONSE_MAX_SIZE];
44 /* Full NSM request/response message */
45 struct nsm_msg {
46 struct nsm_data_req req;
47 struct nsm_data_resp resp;
50 struct nsm {
51 struct virtio_device *vdev;
52 struct virtqueue *vq;
53 struct mutex lock;
54 struct completion cmd_done;
55 struct miscdevice misc;
56 struct hwrng hwrng;
57 struct work_struct misc_init;
58 struct nsm_msg msg;
61 /* NSM device ID */
62 static const struct virtio_device_id id_table[] = {
63 { VIRTIO_ID_NITRO_SEC_MOD, VIRTIO_DEV_ANY_ID },
64 { 0 },
67 static struct nsm *file_to_nsm(struct file *file)
69 return container_of(file->private_data, struct nsm, misc);
72 static struct nsm *hwrng_to_nsm(struct hwrng *rng)
74 return container_of(rng, struct nsm, hwrng);
/* The top three bits of a CBOR initial byte select the item type */
#define CBOR_TYPE_MASK			0xE0
#define CBOR_TYPE_MAP			0xA0
#define CBOR_TYPE_TEXT			0x60
/* "Array" here means a byte array, as decoded by cbor_object_get_array() */
#define CBOR_TYPE_ARRAY			0x40
/* Size of the initial byte that precedes any extended length field */
#define CBOR_HEADER_SIZE_SHORT		1

/* Values of the low five bits of the initial byte */
#define CBOR_SHORT_SIZE_MAX_VALUE	23	/* largest inline length */
#define CBOR_LONG_SIZE_U8		24	/* 1-byte length follows */
#define CBOR_LONG_SIZE_U16		25	/* 2-byte length follows */
#define CBOR_LONG_SIZE_U32		26	/* 4-byte length follows */
#define CBOR_LONG_SIZE_U64		27	/* 8-byte length follows */
89 static bool cbor_object_is_array(const u8 *cbor_object, size_t cbor_object_size)
91 if (cbor_object_size == 0 || cbor_object == NULL)
92 return false;
94 return (cbor_object[0] & CBOR_TYPE_MASK) == CBOR_TYPE_ARRAY;
97 static int cbor_object_get_array(u8 *cbor_object, size_t cbor_object_size, u8 **cbor_array)
99 u8 cbor_short_size;
100 void *array_len_p;
101 u64 array_len;
102 u64 array_offset;
104 if (!cbor_object_is_array(cbor_object, cbor_object_size))
105 return -EFAULT;
107 cbor_short_size = (cbor_object[0] & 0x1F);
109 /* Decoding byte array length */
110 array_offset = CBOR_HEADER_SIZE_SHORT;
111 if (cbor_short_size >= CBOR_LONG_SIZE_U8)
112 array_offset += BIT(cbor_short_size - CBOR_LONG_SIZE_U8);
114 if (cbor_object_size < array_offset)
115 return -EFAULT;
117 array_len_p = &cbor_object[1];
119 switch (cbor_short_size) {
120 case CBOR_SHORT_SIZE_MAX_VALUE: /* short encoding */
121 array_len = cbor_short_size;
122 break;
123 case CBOR_LONG_SIZE_U8:
124 array_len = *(u8 *)array_len_p;
125 break;
126 case CBOR_LONG_SIZE_U16:
127 array_len = be16_to_cpup((__be16 *)array_len_p);
128 break;
129 case CBOR_LONG_SIZE_U32:
130 array_len = be32_to_cpup((__be32 *)array_len_p);
131 break;
132 case CBOR_LONG_SIZE_U64:
133 array_len = be64_to_cpup((__be64 *)array_len_p);
134 break;
137 if (cbor_object_size < array_offset)
138 return -EFAULT;
140 if (cbor_object_size - array_offset < array_len)
141 return -EFAULT;
143 if (array_len > INT_MAX)
144 return -EFAULT;
146 *cbor_array = cbor_object + array_offset;
147 return array_len;
150 /* Copy the request of a raw message to kernel space */
151 static int fill_req_raw(struct nsm *nsm, struct nsm_data_req *req,
152 struct nsm_raw *raw)
154 /* Verify the user input size. */
155 if (raw->request.len > sizeof(req->data))
156 return -EMSGSIZE;
158 /* Copy the request payload */
159 if (copy_from_user(req->data, u64_to_user_ptr(raw->request.addr),
160 raw->request.len))
161 return -EFAULT;
163 req->len = raw->request.len;
165 return 0;
168 /* Copy the response of a raw message back to user-space */
169 static int parse_resp_raw(struct nsm *nsm, struct nsm_data_resp *resp,
170 struct nsm_raw *raw)
172 /* Truncate any message that does not fit. */
173 raw->response.len = min_t(u64, raw->response.len, resp->len);
175 /* Copy the response content to user space */
176 if (copy_to_user(u64_to_user_ptr(raw->response.addr),
177 resp->data, raw->response.len))
178 return -EFAULT;
180 return 0;
183 /* Virtqueue interrupt handler */
184 static void nsm_vq_callback(struct virtqueue *vq)
186 struct nsm *nsm = vq->vdev->priv;
188 complete(&nsm->cmd_done);
191 /* Forward a message to the NSM device and wait for the response from it */
192 static int nsm_sendrecv_msg_locked(struct nsm *nsm)
194 struct device *dev = &nsm->vdev->dev;
195 struct scatterlist sg_in, sg_out;
196 struct nsm_msg *msg = &nsm->msg;
197 struct virtqueue *vq = nsm->vq;
198 unsigned int len;
199 void *queue_buf;
200 bool kicked;
201 int rc;
203 /* Initialize scatter-gather lists with request and response buffers. */
204 sg_init_one(&sg_out, msg->req.data, msg->req.len);
205 sg_init_one(&sg_in, msg->resp.data, sizeof(msg->resp.data));
207 init_completion(&nsm->cmd_done);
208 /* Add the request buffer (read by the device). */
209 rc = virtqueue_add_outbuf(vq, &sg_out, 1, msg->req.data, GFP_KERNEL);
210 if (rc)
211 return rc;
213 /* Add the response buffer (written by the device). */
214 rc = virtqueue_add_inbuf(vq, &sg_in, 1, msg->resp.data, GFP_KERNEL);
215 if (rc)
216 goto cleanup;
218 kicked = virtqueue_kick(vq);
219 if (!kicked) {
220 /* Cannot kick the virtqueue. */
221 rc = -EIO;
222 goto cleanup;
225 /* If the kick succeeded, wait for the device's response. */
226 if (!wait_for_completion_io_timeout(&nsm->cmd_done,
227 msecs_to_jiffies(NSM_DEFAULT_TIMEOUT_MSECS))) {
228 rc = -ETIMEDOUT;
229 goto cleanup;
232 queue_buf = virtqueue_get_buf(vq, &len);
233 if (!queue_buf || (queue_buf != msg->req.data)) {
234 dev_err(dev, "wrong request buffer.");
235 rc = -ENODATA;
236 goto cleanup;
239 queue_buf = virtqueue_get_buf(vq, &len);
240 if (!queue_buf || (queue_buf != msg->resp.data)) {
241 dev_err(dev, "wrong response buffer.");
242 rc = -ENODATA;
243 goto cleanup;
246 msg->resp.len = len;
248 rc = 0;
250 cleanup:
251 if (rc) {
252 /* Clean the virtqueue. */
253 while (virtqueue_get_buf(vq, &len) != NULL)
257 return rc;
260 static int fill_req_get_random(struct nsm *nsm, struct nsm_data_req *req)
263 * 69 # text(9)
264 * 47657452616E646F6D # "GetRandom"
266 const u8 request[] = { CBOR_TYPE_TEXT + strlen("GetRandom"),
267 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm' };
269 memcpy(req->data, request, sizeof(request));
270 req->len = sizeof(request);
272 return 0;
275 static int parse_resp_get_random(struct nsm *nsm, struct nsm_data_resp *resp,
276 void *out, size_t max)
279 * A1 # map(1)
280 * 69 # text(9) - Name of field
281 * 47657452616E646F6D # "GetRandom"
282 * A1 # map(1) - The field itself
283 * 66 # text(6)
284 * 72616E646F6D # "random"
285 * # The rest of the response is random data
287 const u8 response[] = { CBOR_TYPE_MAP + 1,
288 CBOR_TYPE_TEXT + strlen("GetRandom"),
289 'G', 'e', 't', 'R', 'a', 'n', 'd', 'o', 'm',
290 CBOR_TYPE_MAP + 1,
291 CBOR_TYPE_TEXT + strlen("random"),
292 'r', 'a', 'n', 'd', 'o', 'm' };
293 struct device *dev = &nsm->vdev->dev;
294 u8 *rand_data = NULL;
295 u8 *resp_ptr = resp->data;
296 u64 resp_len = resp->len;
297 int rc;
299 if ((resp->len < sizeof(response) + 1) ||
300 (memcmp(resp_ptr, response, sizeof(response)) != 0)) {
301 dev_err(dev, "Invalid response for GetRandom");
302 return -EFAULT;
305 resp_ptr += sizeof(response);
306 resp_len -= sizeof(response);
308 rc = cbor_object_get_array(resp_ptr, resp_len, &rand_data);
309 if (rc < 0) {
310 dev_err(dev, "GetRandom: Invalid CBOR encoding\n");
311 return rc;
314 rc = min_t(size_t, rc, max);
315 memcpy(out, rand_data, rc);
317 return rc;
321 * HwRNG implementation
323 static int nsm_rng_read(struct hwrng *rng, void *data, size_t max, bool wait)
325 struct nsm *nsm = hwrng_to_nsm(rng);
326 struct device *dev = &nsm->vdev->dev;
327 int rc = 0;
329 /* NSM always needs to wait for a response */
330 if (!wait)
331 return 0;
333 mutex_lock(&nsm->lock);
335 rc = fill_req_get_random(nsm, &nsm->msg.req);
336 if (rc != 0)
337 goto out;
339 rc = nsm_sendrecv_msg_locked(nsm);
340 if (rc != 0)
341 goto out;
343 rc = parse_resp_get_random(nsm, &nsm->msg.resp, data, max);
344 if (rc < 0)
345 goto out;
347 dev_dbg(dev, "RNG: returning rand bytes = %d", rc);
348 out:
349 mutex_unlock(&nsm->lock);
350 return rc;
353 static long nsm_dev_ioctl(struct file *file, unsigned int cmd,
354 unsigned long arg)
356 void __user *argp = u64_to_user_ptr((u64)arg);
357 struct nsm *nsm = file_to_nsm(file);
358 struct nsm_raw raw;
359 int r = 0;
361 if (cmd != NSM_IOCTL_RAW)
362 return -EINVAL;
364 if (_IOC_SIZE(cmd) != sizeof(raw))
365 return -EINVAL;
367 /* Copy user argument struct to kernel argument struct */
368 r = -EFAULT;
369 if (copy_from_user(&raw, argp, _IOC_SIZE(cmd)))
370 goto out;
372 mutex_lock(&nsm->lock);
374 /* Convert kernel argument struct to device request */
375 r = fill_req_raw(nsm, &nsm->msg.req, &raw);
376 if (r)
377 goto out;
379 /* Send message to NSM and read reply */
380 r = nsm_sendrecv_msg_locked(nsm);
381 if (r)
382 goto out;
384 /* Parse device response into kernel argument struct */
385 r = parse_resp_raw(nsm, &nsm->msg.resp, &raw);
386 if (r)
387 goto out;
389 /* Copy kernel argument struct back to user argument struct */
390 r = -EFAULT;
391 if (copy_to_user(argp, &raw, sizeof(raw)))
392 goto out;
394 r = 0;
396 out:
397 mutex_unlock(&nsm->lock);
398 return r;
401 static int nsm_device_init_vq(struct virtio_device *vdev)
403 struct virtqueue *vq = virtio_find_single_vq(vdev,
404 nsm_vq_callback, "nsm.vq.0");
405 struct nsm *nsm = vdev->priv;
407 if (IS_ERR(vq))
408 return PTR_ERR(vq);
410 nsm->vq = vq;
412 return 0;
415 static const struct file_operations nsm_dev_fops = {
416 .unlocked_ioctl = nsm_dev_ioctl,
417 .compat_ioctl = compat_ptr_ioctl,
420 /* Handler for probing the NSM device */
421 static int nsm_device_probe(struct virtio_device *vdev)
423 struct device *dev = &vdev->dev;
424 struct nsm *nsm;
425 int rc;
427 nsm = devm_kzalloc(&vdev->dev, sizeof(*nsm), GFP_KERNEL);
428 if (!nsm)
429 return -ENOMEM;
431 vdev->priv = nsm;
432 nsm->vdev = vdev;
434 rc = nsm_device_init_vq(vdev);
435 if (rc) {
436 dev_err(dev, "queue failed to initialize: %d.\n", rc);
437 goto err_init_vq;
440 mutex_init(&nsm->lock);
442 /* Register as hwrng provider */
443 nsm->hwrng = (struct hwrng) {
444 .read = nsm_rng_read,
445 .name = "nsm-hwrng",
446 .quality = 1000,
449 rc = hwrng_register(&nsm->hwrng);
450 if (rc) {
451 dev_err(dev, "RNG initialization error: %d.\n", rc);
452 goto err_hwrng;
455 /* Register /dev/nsm device node */
456 nsm->misc = (struct miscdevice) {
457 .minor = MISC_DYNAMIC_MINOR,
458 .name = "nsm",
459 .fops = &nsm_dev_fops,
460 .mode = 0666,
463 rc = misc_register(&nsm->misc);
464 if (rc) {
465 dev_err(dev, "misc device registration error: %d.\n", rc);
466 goto err_misc;
469 return 0;
471 err_misc:
472 hwrng_unregister(&nsm->hwrng);
473 err_hwrng:
474 vdev->config->del_vqs(vdev);
475 err_init_vq:
476 return rc;
479 /* Handler for removing the NSM device */
480 static void nsm_device_remove(struct virtio_device *vdev)
482 struct nsm *nsm = vdev->priv;
484 hwrng_unregister(&nsm->hwrng);
486 vdev->config->del_vqs(vdev);
487 misc_deregister(&nsm->misc);
490 /* NSM device configuration structure */
491 static struct virtio_driver virtio_nsm_driver = {
492 .feature_table = 0,
493 .feature_table_size = 0,
494 .feature_table_legacy = 0,
495 .feature_table_size_legacy = 0,
496 .driver.name = KBUILD_MODNAME,
497 .id_table = id_table,
498 .probe = nsm_device_probe,
499 .remove = nsm_device_remove,
502 module_virtio_driver(virtio_nsm_driver);
503 MODULE_DEVICE_TABLE(virtio, id_table);
504 MODULE_DESCRIPTION("Virtio NSM driver");
505 MODULE_LICENSE("GPL");