diff options
Diffstat (limited to 'drivers/nvme/host/tcp.c')
-rw-r--r-- | drivers/nvme/host/tcp.c | 189 |
1 files changed, 173 insertions, 16 deletions
diff --git a/drivers/nvme/host/tcp.c b/drivers/nvme/host/tcp.c index f1d62d7442..08805f0278 100644 --- a/drivers/nvme/host/tcp.c +++ b/drivers/nvme/host/tcp.c @@ -8,9 +8,14 @@ #include <linux/init.h> #include <linux/slab.h> #include <linux/err.h> +#include <linux/key.h> #include <linux/nvme-tcp.h> +#include <linux/nvme-keyring.h> #include <net/sock.h> #include <net/tcp.h> +#include <net/tls.h> +#include <net/tls_prot.h> +#include <net/handshake.h> #include <linux/blk-mq.h> #include <crypto/hash.h> #include <net/busy_poll.h> @@ -31,6 +36,16 @@ static int so_priority; module_param(so_priority, int, 0644); MODULE_PARM_DESC(so_priority, "nvme tcp socket optimize priority"); +/* + * TLS handshake timeout + */ +static int tls_handshake_timeout = 10; +#ifdef CONFIG_NVME_TCP_TLS +module_param(tls_handshake_timeout, int, 0644); +MODULE_PARM_DESC(tls_handshake_timeout, + "nvme TLS handshake timeout in seconds (default 10)"); +#endif + #ifdef CONFIG_DEBUG_LOCK_ALLOC /* lockdep can detect a circular dependency of the form * sk_lock -> mmap_lock (page fault) -> fs locks -> sk_lock @@ -146,7 +161,8 @@ struct nvme_tcp_queue { struct ahash_request *snd_hash; __le32 exp_ddgst; __le32 recv_ddgst; - + struct completion tls_complete; + int tls_err; struct page_frag_cache pf_cache; void (*state_change)(struct sock *); @@ -189,6 +205,14 @@ static inline int nvme_tcp_queue_id(struct nvme_tcp_queue *queue) return queue - queue->ctrl->queues; } +static inline bool nvme_tcp_tls(struct nvme_ctrl *ctrl) +{ + if (!IS_ENABLED(CONFIG_NVME_TCP_TLS)) + return 0; + + return ctrl->opts->tls; +} + static inline struct blk_mq_tags *nvme_tcp_tagset(struct nvme_tcp_queue *queue) { u32 queue_idx = nvme_tcp_queue_id(queue); @@ -1338,7 +1362,9 @@ static void nvme_tcp_free_queue(struct nvme_ctrl *nctrl, int qid) } noreclaim_flag = memalloc_noreclaim_save(); - sock_release(queue->sock); + /* ->sock will be released by fput() */ + fput(queue->sock->file); + queue->sock = NULL; memalloc_noreclaim_restore(noreclaim_flag); kfree(queue->pdu); @@ -1350,6 +1376,8 @@ static int nvme_tcp_init_connection(struct nvme_tcp_queue *queue) { struct nvme_tcp_icreq_pdu *icreq; struct nvme_tcp_icresp_pdu *icresp; + char cbuf[CMSG_LEN(sizeof(char))] = {}; + u8 ctype; struct msghdr msg = {}; struct kvec iov; bool ctrl_hdgst, ctrl_ddgst; @@ -1381,17 +1409,36 @@ static int nvme_tcp_init_connection(struct nvme_tcp_queue *queue) iov.iov_base = icreq; iov.iov_len = sizeof(*icreq); ret = kernel_sendmsg(queue->sock, &msg, &iov, 1, iov.iov_len); - if (ret < 0) + if (ret < 0) { + pr_warn("queue %d: failed to send icreq, error %d\n", + nvme_tcp_queue_id(queue), ret); goto free_icresp; + } memset(&msg, 0, sizeof(msg)); iov.iov_base = icresp; iov.iov_len = sizeof(*icresp); + if (nvme_tcp_tls(&queue->ctrl->ctrl)) { + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + } ret = kernel_recvmsg(queue->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); - if (ret < 0) + if (ret < 0) { + pr_warn("queue %d: failed to receive icresp, error %d\n", + nvme_tcp_queue_id(queue), ret); goto free_icresp; - + } + ret = -ENOTCONN; + if (nvme_tcp_tls(&queue->ctrl->ctrl)) { + ctype = tls_get_record_type(queue->sock->sk, + (struct cmsghdr *)cbuf); + if (ctype != TLS_RECORD_TYPE_DATA) { + pr_err("queue %d: unhandled TLS record %d\n", + nvme_tcp_queue_id(queue), ctype); + goto free_icresp; + } + } ret = -EINVAL; if (icresp->hdr.type != nvme_tcp_icresp) { pr_err("queue %d: bad type returned %d\n", @@ -1507,11 +1554,90 @@ static void nvme_tcp_set_queue_io_cpu(struct nvme_tcp_queue *queue) queue->io_cpu = cpumask_next_wrap(n - 1, cpu_online_mask, -1, false); } -static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid) +static void nvme_tcp_tls_done(void *data, int status, key_serial_t pskid) +{ + struct nvme_tcp_queue *queue = data; + struct nvme_tcp_ctrl *ctrl = queue->ctrl; + int qid = nvme_tcp_queue_id(queue); + struct key *tls_key; + + dev_dbg(ctrl->ctrl.device, "queue %d: TLS handshake done, key %x, status %d\n", + qid, pskid, status); + + if (status) { + queue->tls_err = -status; + goto out_complete; + } + + tls_key = key_lookup(pskid); + if (IS_ERR(tls_key)) { + dev_warn(ctrl->ctrl.device, "queue %d: Invalid key %x\n", + qid, pskid); + queue->tls_err = -ENOKEY; + } else { + ctrl->ctrl.tls_key = tls_key; + queue->tls_err = 0; + } + +out_complete: + complete(&queue->tls_complete); +} + +static int nvme_tcp_start_tls(struct nvme_ctrl *nctrl, + struct nvme_tcp_queue *queue, + key_serial_t pskid) +{ + int qid = nvme_tcp_queue_id(queue); + int ret; + struct tls_handshake_args args; + unsigned long tmo = tls_handshake_timeout * HZ; + key_serial_t keyring = nvme_keyring_id(); + + dev_dbg(nctrl->device, "queue %d: start TLS with key %x\n", + qid, pskid); + memset(&args, 0, sizeof(args)); + args.ta_sock = queue->sock; + args.ta_done = nvme_tcp_tls_done; + args.ta_data = queue; + args.ta_my_peerids[0] = pskid; + args.ta_num_peerids = 1; + if (nctrl->opts->keyring) + keyring = key_serial(nctrl->opts->keyring); + args.ta_keyring = keyring; + args.ta_timeout_ms = tls_handshake_timeout * 1000; + queue->tls_err = -EOPNOTSUPP; + init_completion(&queue->tls_complete); + ret = tls_client_hello_psk(&args, GFP_KERNEL); + if (ret) { + dev_err(nctrl->device, "queue %d: failed to start TLS: %d\n", + qid, ret); + return ret; + } + ret = wait_for_completion_interruptible_timeout(&queue->tls_complete, tmo); + if (ret <= 0) { + if (ret == 0) + ret = -ETIMEDOUT; + + dev_err(nctrl->device, + "queue %d: TLS handshake failed, error %d\n", + qid, ret); + tls_handshake_cancel(queue->sock->sk); + } else { + dev_dbg(nctrl->device, + "queue %d: TLS handshake complete, error %d\n", + qid, queue->tls_err); + ret = queue->tls_err; + } + return ret; +} + +static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid, + key_serial_t pskid) { struct nvme_tcp_ctrl *ctrl = to_tcp_ctrl(nctrl); struct nvme_tcp_queue *queue = &ctrl->queues[qid]; int ret, rcv_pdu_size; + struct file *sock_file; mutex_init(&queue->queue_lock); queue->ctrl = ctrl; @@ -1534,6 +1660,11 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid) goto err_destroy_mutex; } + sock_file = sock_alloc_file(queue->sock, O_CLOEXEC, NULL); + if (IS_ERR(sock_file)) { + ret = PTR_ERR(sock_file); + goto err_destroy_mutex; + } nvme_tcp_reclassify_socket(queue->sock); /* Single syn retry */ @@ -1624,6 +1755,13 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid) goto err_rcv_pdu; } + /* If PSKs are configured try to start TLS */ + if (IS_ENABLED(CONFIG_NVME_TCP_TLS) && pskid) { + ret = nvme_tcp_start_tls(nctrl, queue, pskid); + if (ret) + goto err_init_connect; + } + ret = nvme_tcp_init_connection(queue); if (ret) goto err_init_connect; @@ -1640,7 +1778,8 @@ err_crypto: if (queue->hdr_digest || queue->data_digest) nvme_tcp_free_crypto(queue); err_sock: - sock_release(queue->sock); + /* ->sock will be released by fput() */ + fput(queue->sock->file); queue->sock = NULL; err_destroy_mutex: mutex_destroy(&queue->send_mutex); @@ -1772,10 +1911,25 @@ out_stop_queues: static int nvme_tcp_alloc_admin_queue(struct nvme_ctrl *ctrl) { int ret; + key_serial_t pskid = 0; + + if (nvme_tcp_tls(ctrl)) { + if (ctrl->opts->tls_key) + pskid = key_serial(ctrl->opts->tls_key); + else + pskid = nvme_tls_psk_default(ctrl->opts->keyring, + ctrl->opts->host->nqn, + ctrl->opts->subsysnqn); + if (!pskid) { + dev_err(ctrl->device, "no valid PSK found\n"); + ret = -ENOKEY; + goto out_free_queue; + } + } - ret = nvme_tcp_alloc_queue(ctrl, 0); + ret = nvme_tcp_alloc_queue(ctrl, 0, pskid); if (ret) - return ret; + goto out_free_queue; ret = nvme_tcp_alloc_async_req(to_tcp_ctrl(ctrl)); if (ret) @@ -1792,8 +1946,13 @@ static int __nvme_tcp_alloc_io_queues(struct nvme_ctrl *ctrl) { int i, ret; + if (nvme_tcp_tls(ctrl) && !ctrl->tls_key) { + dev_err(ctrl->device, "no PSK negotiated\n"); + return -ENOKEY; + } for (i = 1; i < ctrl->queue_count; i++) { - ret = nvme_tcp_alloc_queue(ctrl, i); + ret = nvme_tcp_alloc_queue(ctrl, i, + key_serial(ctrl->tls_key)); if (ret) goto out_free_queues; } @@ -2078,11 +2237,8 @@ destroy_io: nvme_tcp_destroy_io_queues(ctrl, new); } destroy_admin: - nvme_quiesce_admin_queue(ctrl); - blk_sync_queue(ctrl->admin_q); - nvme_tcp_stop_queue(ctrl, 0); - nvme_cancel_admin_tagset(ctrl); - nvme_tcp_destroy_admin_queue(ctrl, new); + nvme_stop_keep_alive(ctrl); + nvme_tcp_teardown_admin_queue(ctrl, false); return ret; } @@ -2628,7 +2784,8 @@ static struct nvmf_transport_ops nvme_tcp_transport = { NVMF_OPT_HOST_TRADDR | NVMF_OPT_CTRL_LOSS_TMO | NVMF_OPT_HDR_DIGEST | NVMF_OPT_DATA_DIGEST | NVMF_OPT_NR_WRITE_QUEUES | NVMF_OPT_NR_POLL_QUEUES | - NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE, + NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE | NVMF_OPT_TLS | + NVMF_OPT_KEYRING | NVMF_OPT_TLS_KEY, .create_ctrl = nvme_tcp_create_ctrl, }; |