summaryrefslogtreecommitdiffstats
path: root/drivers/nvme/host/tcp.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/nvme/host/tcp.c')
-rw-r--r--drivers/nvme/host/tcp.c189
1 files changed, 173 insertions, 16 deletions
diff --git a/drivers/nvme/host/tcp.c b/drivers/nvme/host/tcp.c
index f1d62d7442..08805f0278 100644
--- a/drivers/nvme/host/tcp.c
+++ b/drivers/nvme/host/tcp.c
@@ -8,9 +8,14 @@
#include <linux/init.h>
#include <linux/slab.h>
#include <linux/err.h>
+#include <linux/key.h>
#include <linux/nvme-tcp.h>
+#include <linux/nvme-keyring.h>
#include <net/sock.h>
#include <net/tcp.h>
+#include <net/tls.h>
+#include <net/tls_prot.h>
+#include <net/handshake.h>
#include <linux/blk-mq.h>
#include <crypto/hash.h>
#include <net/busy_poll.h>
@@ -31,6 +36,16 @@ static int so_priority;
module_param(so_priority, int, 0644);
MODULE_PARM_DESC(so_priority, "nvme tcp socket optimize priority");
+/*
+ * TLS handshake timeout
+ */
+static int tls_handshake_timeout = 10;
+#ifdef CONFIG_NVME_TCP_TLS
+module_param(tls_handshake_timeout, int, 0644);
+MODULE_PARM_DESC(tls_handshake_timeout,
+ "nvme TLS handshake timeout in seconds (default 10)");
+#endif
+
#ifdef CONFIG_DEBUG_LOCK_ALLOC
/* lockdep can detect a circular dependency of the form
* sk_lock -> mmap_lock (page fault) -> fs locks -> sk_lock
@@ -146,7 +161,8 @@ struct nvme_tcp_queue {
struct ahash_request *snd_hash;
__le32 exp_ddgst;
__le32 recv_ddgst;
-
+ struct completion tls_complete;
+ int tls_err;
struct page_frag_cache pf_cache;
void (*state_change)(struct sock *);
@@ -189,6 +205,14 @@ static inline int nvme_tcp_queue_id(struct nvme_tcp_queue *queue)
return queue - queue->ctrl->queues;
}
+static inline bool nvme_tcp_tls(struct nvme_ctrl *ctrl)
+{
+ if (!IS_ENABLED(CONFIG_NVME_TCP_TLS))
+ return 0;
+
+ return ctrl->opts->tls;
+}
+
static inline struct blk_mq_tags *nvme_tcp_tagset(struct nvme_tcp_queue *queue)
{
u32 queue_idx = nvme_tcp_queue_id(queue);
@@ -1338,7 +1362,9 @@ static void nvme_tcp_free_queue(struct nvme_ctrl *nctrl, int qid)
}
noreclaim_flag = memalloc_noreclaim_save();
- sock_release(queue->sock);
+ /* ->sock will be released by fput() */
+ fput(queue->sock->file);
+ queue->sock = NULL;
memalloc_noreclaim_restore(noreclaim_flag);
kfree(queue->pdu);
@@ -1350,6 +1376,8 @@ static int nvme_tcp_init_connection(struct nvme_tcp_queue *queue)
{
struct nvme_tcp_icreq_pdu *icreq;
struct nvme_tcp_icresp_pdu *icresp;
+ char cbuf[CMSG_LEN(sizeof(char))] = {};
+ u8 ctype;
struct msghdr msg = {};
struct kvec iov;
bool ctrl_hdgst, ctrl_ddgst;
@@ -1381,17 +1409,36 @@ static int nvme_tcp_init_connection(struct nvme_tcp_queue *queue)
iov.iov_base = icreq;
iov.iov_len = sizeof(*icreq);
ret = kernel_sendmsg(queue->sock, &msg, &iov, 1, iov.iov_len);
- if (ret < 0)
+ if (ret < 0) {
+ pr_warn("queue %d: failed to send icreq, error %d\n",
+ nvme_tcp_queue_id(queue), ret);
goto free_icresp;
+ }
memset(&msg, 0, sizeof(msg));
iov.iov_base = icresp;
iov.iov_len = sizeof(*icresp);
+ if (nvme_tcp_tls(&queue->ctrl->ctrl)) {
+ msg.msg_control = cbuf;
+ msg.msg_controllen = sizeof(cbuf);
+ }
ret = kernel_recvmsg(queue->sock, &msg, &iov, 1,
iov.iov_len, msg.msg_flags);
- if (ret < 0)
+ if (ret < 0) {
+ pr_warn("queue %d: failed to receive icresp, error %d\n",
+ nvme_tcp_queue_id(queue), ret);
goto free_icresp;
-
+ }
+ ret = -ENOTCONN;
+ if (nvme_tcp_tls(&queue->ctrl->ctrl)) {
+ ctype = tls_get_record_type(queue->sock->sk,
+ (struct cmsghdr *)cbuf);
+ if (ctype != TLS_RECORD_TYPE_DATA) {
+ pr_err("queue %d: unhandled TLS record %d\n",
+ nvme_tcp_queue_id(queue), ctype);
+ goto free_icresp;
+ }
+ }
ret = -EINVAL;
if (icresp->hdr.type != nvme_tcp_icresp) {
pr_err("queue %d: bad type returned %d\n",
@@ -1507,11 +1554,90 @@ static void nvme_tcp_set_queue_io_cpu(struct nvme_tcp_queue *queue)
queue->io_cpu = cpumask_next_wrap(n - 1, cpu_online_mask, -1, false);
}
-static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid)
+static void nvme_tcp_tls_done(void *data, int status, key_serial_t pskid)
+{
+ struct nvme_tcp_queue *queue = data;
+ struct nvme_tcp_ctrl *ctrl = queue->ctrl;
+ int qid = nvme_tcp_queue_id(queue);
+ struct key *tls_key;
+
+ dev_dbg(ctrl->ctrl.device, "queue %d: TLS handshake done, key %x, status %d\n",
+ qid, pskid, status);
+
+ if (status) {
+ queue->tls_err = -status;
+ goto out_complete;
+ }
+
+ tls_key = key_lookup(pskid);
+ if (IS_ERR(tls_key)) {
+ dev_warn(ctrl->ctrl.device, "queue %d: Invalid key %x\n",
+ qid, pskid);
+ queue->tls_err = -ENOKEY;
+ } else {
+ ctrl->ctrl.tls_key = tls_key;
+ queue->tls_err = 0;
+ }
+
+out_complete:
+ complete(&queue->tls_complete);
+}
+
+static int nvme_tcp_start_tls(struct nvme_ctrl *nctrl,
+ struct nvme_tcp_queue *queue,
+ key_serial_t pskid)
+{
+ int qid = nvme_tcp_queue_id(queue);
+ int ret;
+ struct tls_handshake_args args;
+ unsigned long tmo = tls_handshake_timeout * HZ;
+ key_serial_t keyring = nvme_keyring_id();
+
+ dev_dbg(nctrl->device, "queue %d: start TLS with key %x\n",
+ qid, pskid);
+ memset(&args, 0, sizeof(args));
+ args.ta_sock = queue->sock;
+ args.ta_done = nvme_tcp_tls_done;
+ args.ta_data = queue;
+ args.ta_my_peerids[0] = pskid;
+ args.ta_num_peerids = 1;
+ if (nctrl->opts->keyring)
+ keyring = key_serial(nctrl->opts->keyring);
+ args.ta_keyring = keyring;
+ args.ta_timeout_ms = tls_handshake_timeout * 1000;
+ queue->tls_err = -EOPNOTSUPP;
+ init_completion(&queue->tls_complete);
+ ret = tls_client_hello_psk(&args, GFP_KERNEL);
+ if (ret) {
+ dev_err(nctrl->device, "queue %d: failed to start TLS: %d\n",
+ qid, ret);
+ return ret;
+ }
+ ret = wait_for_completion_interruptible_timeout(&queue->tls_complete, tmo);
+ if (ret <= 0) {
+ if (ret == 0)
+ ret = -ETIMEDOUT;
+
+ dev_err(nctrl->device,
+ "queue %d: TLS handshake failed, error %d\n",
+ qid, ret);
+ tls_handshake_cancel(queue->sock->sk);
+ } else {
+ dev_dbg(nctrl->device,
+ "queue %d: TLS handshake complete, error %d\n",
+ qid, queue->tls_err);
+ ret = queue->tls_err;
+ }
+ return ret;
+}
+
+static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid,
+ key_serial_t pskid)
{
struct nvme_tcp_ctrl *ctrl = to_tcp_ctrl(nctrl);
struct nvme_tcp_queue *queue = &ctrl->queues[qid];
int ret, rcv_pdu_size;
+ struct file *sock_file;
mutex_init(&queue->queue_lock);
queue->ctrl = ctrl;
@@ -1534,6 +1660,11 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid)
goto err_destroy_mutex;
}
+ sock_file = sock_alloc_file(queue->sock, O_CLOEXEC, NULL);
+ if (IS_ERR(sock_file)) {
+ ret = PTR_ERR(sock_file);
+ goto err_destroy_mutex;
+ }
nvme_tcp_reclassify_socket(queue->sock);
/* Single syn retry */
@@ -1624,6 +1755,13 @@ static int nvme_tcp_alloc_queue(struct nvme_ctrl *nctrl, int qid)
goto err_rcv_pdu;
}
+ /* If PSKs are configured try to start TLS */
+ if (IS_ENABLED(CONFIG_NVME_TCP_TLS) && pskid) {
+ ret = nvme_tcp_start_tls(nctrl, queue, pskid);
+ if (ret)
+ goto err_init_connect;
+ }
+
ret = nvme_tcp_init_connection(queue);
if (ret)
goto err_init_connect;
@@ -1640,7 +1778,8 @@ err_crypto:
if (queue->hdr_digest || queue->data_digest)
nvme_tcp_free_crypto(queue);
err_sock:
- sock_release(queue->sock);
+ /* ->sock will be released by fput() */
+ fput(queue->sock->file);
queue->sock = NULL;
err_destroy_mutex:
mutex_destroy(&queue->send_mutex);
@@ -1772,10 +1911,25 @@ out_stop_queues:
static int nvme_tcp_alloc_admin_queue(struct nvme_ctrl *ctrl)
{
int ret;
+ key_serial_t pskid = 0;
+
+ if (nvme_tcp_tls(ctrl)) {
+ if (ctrl->opts->tls_key)
+ pskid = key_serial(ctrl->opts->tls_key);
+ else
+ pskid = nvme_tls_psk_default(ctrl->opts->keyring,
+ ctrl->opts->host->nqn,
+ ctrl->opts->subsysnqn);
+ if (!pskid) {
+ dev_err(ctrl->device, "no valid PSK found\n");
+ ret = -ENOKEY;
+ goto out_free_queue;
+ }
+ }
- ret = nvme_tcp_alloc_queue(ctrl, 0);
+ ret = nvme_tcp_alloc_queue(ctrl, 0, pskid);
if (ret)
- return ret;
+ goto out_free_queue;
ret = nvme_tcp_alloc_async_req(to_tcp_ctrl(ctrl));
if (ret)
@@ -1792,8 +1946,13 @@ static int __nvme_tcp_alloc_io_queues(struct nvme_ctrl *ctrl)
{
int i, ret;
+ if (nvme_tcp_tls(ctrl) && !ctrl->tls_key) {
+ dev_err(ctrl->device, "no PSK negotiated\n");
+ return -ENOKEY;
+ }
for (i = 1; i < ctrl->queue_count; i++) {
- ret = nvme_tcp_alloc_queue(ctrl, i);
+ ret = nvme_tcp_alloc_queue(ctrl, i,
+ key_serial(ctrl->tls_key));
if (ret)
goto out_free_queues;
}
@@ -2078,11 +2237,8 @@ destroy_io:
nvme_tcp_destroy_io_queues(ctrl, new);
}
destroy_admin:
- nvme_quiesce_admin_queue(ctrl);
- blk_sync_queue(ctrl->admin_q);
- nvme_tcp_stop_queue(ctrl, 0);
- nvme_cancel_admin_tagset(ctrl);
- nvme_tcp_destroy_admin_queue(ctrl, new);
+ nvme_stop_keep_alive(ctrl);
+ nvme_tcp_teardown_admin_queue(ctrl, false);
return ret;
}
@@ -2628,7 +2784,8 @@ static struct nvmf_transport_ops nvme_tcp_transport = {
NVMF_OPT_HOST_TRADDR | NVMF_OPT_CTRL_LOSS_TMO |
NVMF_OPT_HDR_DIGEST | NVMF_OPT_DATA_DIGEST |
NVMF_OPT_NR_WRITE_QUEUES | NVMF_OPT_NR_POLL_QUEUES |
- NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE,
+ NVMF_OPT_TOS | NVMF_OPT_HOST_IFACE | NVMF_OPT_TLS |
+ NVMF_OPT_KEYRING | NVMF_OPT_TLS_KEY,
.create_ctrl = nvme_tcp_create_ctrl,
};