diff options
Diffstat (limited to 'drivers/virt/coco/sev-guest/sev-guest.c')
-rw-r--r-- | drivers/virt/coco/sev-guest/sev-guest.c | 179 |
1 files changed, 167 insertions, 12 deletions
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c index 5bee58ef5f..bc564adcf4 100644 --- a/drivers/virt/coco/sev-guest/sev-guest.c +++ b/drivers/virt/coco/sev-guest/sev-guest.c @@ -16,9 +16,13 @@ #include <linux/miscdevice.h> #include <linux/set_memory.h> #include <linux/fs.h> +#include <linux/tsm.h> #include <crypto/aead.h> #include <linux/scatterlist.h> #include <linux/psp-sev.h> +#include <linux/sockptr.h> +#include <linux/cleanup.h> +#include <linux/uuid.h> #include <uapi/linux/sev-guest.h> #include <uapi/linux/psp-sev.h> @@ -475,6 +479,11 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, return 0; } +struct snp_req_resp { + sockptr_t req_data; + sockptr_t resp_data; +}; + static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg) { struct snp_guest_crypto *crypto = snp_dev->crypto; @@ -555,22 +564,25 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque return rc; } -static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg) +static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg, + struct snp_req_resp *io) + { struct snp_ext_report_req *req = &snp_dev->req.ext_report; struct snp_guest_crypto *crypto = snp_dev->crypto; struct snp_report_resp *resp; int ret, npages = 0, resp_len; + sockptr_t certs_address; lockdep_assert_held(&snp_cmd_mutex); - if (!arg->req_data || !arg->resp_data) + if (sockptr_is_null(io->req_data) || sockptr_is_null(io->resp_data)) return -EINVAL; - if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req))) + if (copy_from_sockptr(req, io->req_data, sizeof(*req))) return -EFAULT; - /* userspace does not want certificate data */ + /* caller does not want certificate data */ if (!req->certs_len || !req->certs_address) goto cmd; @@ -578,8 +590,13 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques !IS_ALIGNED(req->certs_len, PAGE_SIZE)) return -EINVAL; - if (!access_ok((const void __user *)req->certs_address, req->certs_len)) - return -EFAULT; + if (sockptr_is_kernel(io->resp_data)) { + certs_address = KERNEL_SOCKPTR((void *)req->certs_address); + } else { + certs_address = USER_SOCKPTR((void __user *)req->certs_address); + if (!access_ok(certs_address.user, req->certs_len)) + return -EFAULT; + } /* * Initialize the intermediate buffer with all zeros. This buffer @@ -609,21 +626,19 @@ cmd: if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) { req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT; - if (copy_to_user((void __user *)arg->req_data, req, sizeof(*req))) + if (copy_to_sockptr(io->req_data, req, sizeof(*req))) ret = -EFAULT; } if (ret) goto e_free; - if (npages && - copy_to_user((void __user *)req->certs_address, snp_dev->certs_data, - req->certs_len)) { + if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, req->certs_len)) { ret = -EFAULT; goto e_free; } - if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp))) + if (copy_to_sockptr(io->resp_data, resp, sizeof(*resp))) ret = -EFAULT; e_free: @@ -636,6 +651,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long struct snp_guest_dev *snp_dev = to_snp_dev(file); void __user *argp = (void __user *)arg; struct snp_guest_request_ioctl input; + struct snp_req_resp io; int ret = -ENOTTY; if (copy_from_user(&input, argp, sizeof(input))) @@ -664,7 +680,14 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long ret = get_derived_key(snp_dev, &input); break; case SNP_GET_EXT_REPORT: - ret = get_ext_report(snp_dev, &input); + /* + * As get_ext_report() may be called from the ioctl() path and a + * kernel internal path (configfs-tsm), decorate the passed + * buffers as user pointers. + */ + io.req_data = USER_SOCKPTR((void __user *)input.req_data); + io.resp_data = USER_SOCKPTR((void __user *)input.resp_data); + ret = get_ext_report(snp_dev, &input, &io); break; default: break; @@ -748,6 +771,130 @@ static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno return key; } +struct snp_msg_report_resp_hdr { + u32 status; + u32 report_size; + u8 rsvd[24]; +}; + +struct snp_msg_cert_entry { + guid_t guid; + u32 offset; + u32 length; +}; + +static int sev_report_new(struct tsm_report *report, void *data) +{ + struct snp_msg_cert_entry *cert_table; + struct tsm_desc *desc = &report->desc; + struct snp_guest_dev *snp_dev = data; + struct snp_msg_report_resp_hdr hdr; + const u32 report_size = SZ_4K; + const u32 ext_size = SEV_FW_BLOB_MAX_SIZE; + u32 certs_size, i, size = report_size + ext_size; + int ret; + + if (desc->inblob_len != SNP_REPORT_USER_DATA_SIZE) + return -EINVAL; + + void *buf __free(kvfree) = kvzalloc(size, GFP_KERNEL); + if (!buf) + return -ENOMEM; + + guard(mutex)(&snp_cmd_mutex); + + /* Check if the VMPCK is not empty */ + if (is_vmpck_empty(snp_dev)) { + dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n"); + return -ENOTTY; + } + + cert_table = buf + report_size; + struct snp_ext_report_req ext_req = { + .data = { .vmpl = desc->privlevel }, + .certs_address = (__u64)cert_table, + .certs_len = ext_size, + }; + memcpy(&ext_req.data.user_data, desc->inblob, desc->inblob_len); + + struct snp_guest_request_ioctl input = { + .msg_version = 1, + .req_data = (__u64)&ext_req, + .resp_data = (__u64)buf, + .exitinfo2 = 0xff, + }; + struct snp_req_resp io = { + .req_data = KERNEL_SOCKPTR(&ext_req), + .resp_data = KERNEL_SOCKPTR(buf), + }; + + ret = get_ext_report(snp_dev, &input, &io); + if (ret) + return ret; + + memcpy(&hdr, buf, sizeof(hdr)); + if (hdr.status == SEV_RET_INVALID_PARAM) + return -EINVAL; + if (hdr.status == SEV_RET_INVALID_KEY) + return -EINVAL; + if (hdr.status) + return -ENXIO; + if ((hdr.report_size + sizeof(hdr)) > report_size) + return -ENOMEM; + + void *rbuf __free(kvfree) = kvzalloc(hdr.report_size, GFP_KERNEL); + if (!rbuf) + return -ENOMEM; + + memcpy(rbuf, buf + sizeof(hdr), hdr.report_size); + report->outblob = no_free_ptr(rbuf); + report->outblob_len = hdr.report_size; + + certs_size = 0; + for (i = 0; i < ext_size / sizeof(struct snp_msg_cert_entry); i++) { + struct snp_msg_cert_entry *ent = &cert_table[i]; + + if (guid_is_null(&ent->guid) && !ent->offset && !ent->length) + break; + certs_size = max(certs_size, ent->offset + ent->length); + } + + /* Suspicious that the response populated entries without populating size */ + if (!certs_size && i) + dev_warn_ratelimited(snp_dev->dev, "certificate slots conveyed without size\n"); + + /* No certs to report */ + if (!certs_size) + return 0; + + /* Suspicious that the certificate blob size contract was violated + */ + if (certs_size > ext_size) { + dev_warn_ratelimited(snp_dev->dev, "certificate data truncated\n"); + certs_size = ext_size; + } + + void *cbuf __free(kvfree) = kvzalloc(certs_size, GFP_KERNEL); + if (!cbuf) + return -ENOMEM; + + memcpy(cbuf, cert_table, certs_size); + report->auxblob = no_free_ptr(cbuf); + report->auxblob_len = certs_size; + + return 0; +} + +static const struct tsm_ops sev_tsm_ops = { + .name = KBUILD_MODNAME, + .report_new = sev_report_new, +}; + +static void unregister_sev_tsm(void *data) +{ + tsm_unregister(&sev_tsm_ops); +} + static int __init sev_guest_probe(struct platform_device *pdev) { struct snp_secrets_page_layout *layout; @@ -821,6 +968,14 @@ static int __init sev_guest_probe(struct platform_device *pdev) snp_dev->input.resp_gpa = __pa(snp_dev->response); snp_dev->input.data_gpa = __pa(snp_dev->certs_data); + ret = tsm_register(&sev_tsm_ops, snp_dev, &tsm_report_extra_type); + if (ret) + goto e_free_cert_data; + + ret = devm_add_action_or_reset(&pdev->dev, unregister_sev_tsm, NULL); + if (ret) + goto e_free_cert_data; + ret = misc_register(misc); if (ret) goto e_free_cert_data; |