summaryrefslogtreecommitdiffstats
path: root/net/tls/tls_sw.c
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--net/tls/tls_sw.c259
1 files changed, 168 insertions, 91 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 0323040d3..2bd27b777 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -62,6 +62,7 @@ struct tls_decrypt_ctx {
u8 iv[MAX_IV_SIZE];
u8 aad[TLS_MAX_AAD_SIZE];
u8 tail;
+ bool free_sgout;
struct scatterlist sg[];
};
@@ -186,7 +187,6 @@ static void tls_decrypt_done(crypto_completion_data_t *data, int err)
struct aead_request *aead_req = crypto_get_completion_data(data);
struct crypto_aead *aead = crypto_aead_reqtfm(aead_req);
struct scatterlist *sgout = aead_req->dst;
- struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx;
struct tls_decrypt_ctx *dctx;
struct tls_context *tls_ctx;
@@ -212,7 +212,7 @@ static void tls_decrypt_done(crypto_completion_data_t *data, int err)
}
/* Free the destination pages if skb was not decrypted inplace */
- if (sgout != sgin) {
+ if (dctx->free_sgout) {
/* Skip the first S/G entry as it points to AAD */
for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
if (!sg)
@@ -223,10 +223,17 @@ static void tls_decrypt_done(crypto_completion_data_t *data, int err)
kfree(aead_req);
- spin_lock_bh(&ctx->decrypt_compl_lock);
- if (!atomic_dec_return(&ctx->decrypt_pending))
+ if (atomic_dec_and_test(&ctx->decrypt_pending))
complete(&ctx->async_wait.completion);
- spin_unlock_bh(&ctx->decrypt_compl_lock);
+}
+
+static int tls_decrypt_async_wait(struct tls_sw_context_rx *ctx)
+{
+ if (!atomic_dec_and_test(&ctx->decrypt_pending))
+ crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+ atomic_inc(&ctx->decrypt_pending);
+
+ return ctx->async_wait.err;
}
static int tls_do_decryption(struct sock *sk,
@@ -252,6 +259,7 @@ static int tls_do_decryption(struct sock *sk,
aead_request_set_callback(aead_req,
CRYPTO_TFM_REQ_MAY_BACKLOG,
tls_decrypt_done, aead_req);
+ DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->decrypt_pending) < 1);
atomic_inc(&ctx->decrypt_pending);
} else {
aead_request_set_callback(aead_req,
@@ -265,6 +273,8 @@ static int tls_do_decryption(struct sock *sk,
return 0;
ret = crypto_wait_req(ret, &ctx->async_wait);
+ } else if (darg->async) {
+ atomic_dec(&ctx->decrypt_pending);
}
darg->async = false;
@@ -441,7 +451,6 @@ static void tls_encrypt_done(crypto_completion_data_t *data, int err)
struct tls_rec *rec;
bool ready = false;
struct sock *sk;
- int pending;
rec = container_of(aead_req, struct tls_rec, aead_req);
msg_en = &rec->msg_encrypted;
@@ -481,12 +490,8 @@ static void tls_encrypt_done(crypto_completion_data_t *data, int err)
ready = true;
}
- spin_lock_bh(&ctx->encrypt_compl_lock);
- pending = atomic_dec_return(&ctx->encrypt_pending);
-
- if (!pending && ctx->async_notify)
+ if (atomic_dec_and_test(&ctx->encrypt_pending))
complete(&ctx->async_wait.completion);
- spin_unlock_bh(&ctx->encrypt_compl_lock);
if (!ready)
return;
@@ -496,6 +501,15 @@ static void tls_encrypt_done(crypto_completion_data_t *data, int err)
schedule_delayed_work(&ctx->tx_work.work, 1);
}
+static int tls_encrypt_async_wait(struct tls_sw_context_tx *ctx)
+{
+ if (!atomic_dec_and_test(&ctx->encrypt_pending))
+ crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+ atomic_inc(&ctx->encrypt_pending);
+
+ return ctx->async_wait.err;
+}
+
static int tls_do_encryption(struct sock *sk,
struct tls_context *tls_ctx,
struct tls_sw_context_tx *ctx,
@@ -542,6 +556,7 @@ static int tls_do_encryption(struct sock *sk,
/* Add the record in tx_list */
list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
+ DEBUG_NET_WARN_ON_ONCE(atomic_read(&ctx->encrypt_pending) < 1);
atomic_inc(&ctx->encrypt_pending);
rc = crypto_aead_encrypt(aead_req);
@@ -953,7 +968,6 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
int num_zc = 0;
int orig_size;
int ret = 0;
- int pending;
if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
MSG_CMSG_COMPAT))
@@ -1122,24 +1136,12 @@ trim_sgl:
if (!num_async) {
goto send_end;
} else if (num_zc) {
- /* Wait for pending encryptions to get completed */
- spin_lock_bh(&ctx->encrypt_compl_lock);
- ctx->async_notify = true;
-
- pending = atomic_read(&ctx->encrypt_pending);
- spin_unlock_bh(&ctx->encrypt_compl_lock);
- if (pending)
- crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
- else
- reinit_completion(&ctx->async_wait.completion);
-
- /* There can be no concurrent accesses, since we have no
- * pending encrypt operations
- */
- WRITE_ONCE(ctx->async_notify, false);
+ int err;
- if (ctx->async_wait.err) {
- ret = ctx->async_wait.err;
+ /* Wait for pending encryptions to get completed */
+ err = tls_encrypt_async_wait(ctx);
+ if (err) {
+ ret = err;
copied = 0;
}
}
@@ -1158,6 +1160,67 @@ send_end:
return copied > 0 ? copied : ret;
}
+/*
+ * Handle unexpected EOF during splice without SPLICE_F_MORE set.
+ */
+void tls_sw_splice_eof(struct socket *sock)
+{
+ struct sock *sk = sock->sk;
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+ struct tls_rec *rec;
+ struct sk_msg *msg_pl;
+ ssize_t copied = 0;
+ bool retrying = false;
+ int ret = 0;
+
+ if (!ctx->open_rec)
+ return;
+
+ mutex_lock(&tls_ctx->tx_lock);
+ lock_sock(sk);
+
+retry:
+ /* same checks as in tls_sw_push_pending_record() */
+ rec = ctx->open_rec;
+ if (!rec)
+ goto unlock;
+
+ msg_pl = &rec->msg_plaintext;
+ if (msg_pl->sg.size == 0)
+ goto unlock;
+
+ /* Check the BPF advisor and perform transmission. */
+ ret = bpf_exec_tx_verdict(msg_pl, sk, false, TLS_RECORD_TYPE_DATA,
+ &copied, 0);
+ switch (ret) {
+ case 0:
+ case -EAGAIN:
+ if (retrying)
+ goto unlock;
+ retrying = true;
+ goto retry;
+ case -EINPROGRESS:
+ break;
+ default:
+ goto unlock;
+ }
+
+ /* Wait for pending encryptions to get completed */
+ if (tls_encrypt_async_wait(ctx))
+ goto unlock;
+
+ /* Transmit if any encryptions have completed */
+ if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+ cancel_delayed_work(&ctx->tx_work.work);
+ tls_tx_records(sk, 0);
+ }
+
+unlock:
+ release_sock(sk);
+ mutex_unlock(&tls_ctx->tx_lock);
+}
+
static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags)
{
@@ -1595,6 +1658,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
} else if (out_sg) {
memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
}
+ dctx->free_sgout = !!pages;
/* Prepare and submit AEAD request */
err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
@@ -1783,7 +1847,8 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
u8 *control,
size_t skip,
size_t len,
- bool is_peek)
+ bool is_peek,
+ bool *more)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
struct tls_msg *tlm;
@@ -1796,7 +1861,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
err = tls_record_content_type(msg, tlm, control);
if (err <= 0)
- goto out;
+ goto more;
if (skip < rxm->full_len)
break;
@@ -1814,12 +1879,12 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
err = tls_record_content_type(msg, tlm, control);
if (err <= 0)
- goto out;
+ goto more;
err = skb_copy_datagram_msg(skb, rxm->offset + skip,
msg, chunk);
if (err < 0)
- goto out;
+ goto more;
len = len - chunk;
copied = copied + chunk;
@@ -1855,6 +1920,10 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
out:
return copied ? : err;
+more:
+ if (more)
+ *more = true;
+ goto out;
}
static bool
@@ -1954,10 +2023,12 @@ int tls_sw_recvmsg(struct sock *sk,
struct strp_msg *rxm;
struct tls_msg *tlm;
ssize_t copied = 0;
+ ssize_t peeked = 0;
bool async = false;
int target, err;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK;
+ bool rx_more = false;
bool released = true;
bool bpf_strp_enabled;
bool zc_capable;
@@ -1977,12 +2048,12 @@ int tls_sw_recvmsg(struct sock *sk,
goto end;
/* Process pending decrypted records. It must be non-zero-copy */
- err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
+ err = process_rx_list(ctx, msg, &control, 0, len, is_peek, &rx_more);
if (err < 0)
goto end;
copied = err;
- if (len <= copied)
+ if (len <= copied || (copied && control != TLS_RECORD_TYPE_DATA) || rx_more)
goto end;
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
@@ -2075,6 +2146,8 @@ put_on_rx_list:
decrypted += chunk;
len -= chunk;
__skb_queue_tail(&ctx->rx_list, skb);
+ if (unlikely(control != TLS_RECORD_TYPE_DATA))
+ break;
continue;
}
@@ -2098,8 +2171,10 @@ put_on_rx_list:
if (err < 0)
goto put_on_rx_list_err;
- if (is_peek)
+ if (is_peek) {
+ peeked += chunk;
goto put_on_rx_list;
+ }
if (partially_consumed) {
rxm->offset += chunk;
@@ -2123,16 +2198,10 @@ put_on_rx_list:
recv_end:
if (async) {
- int ret, pending;
+ int ret;
/* Wait for all previously submitted records to be decrypted */
- spin_lock_bh(&ctx->decrypt_compl_lock);
- reinit_completion(&ctx->async_wait.completion);
- pending = atomic_read(&ctx->decrypt_pending);
- spin_unlock_bh(&ctx->decrypt_compl_lock);
- ret = 0;
- if (pending)
- ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+ ret = tls_decrypt_async_wait(ctx);
__skb_queue_purge(&ctx->async_hold);
if (ret) {
@@ -2144,12 +2213,11 @@ recv_end:
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
- err = process_rx_list(ctx, msg, &control, copied,
- decrypted, is_peek);
+ err = process_rx_list(ctx, msg, &control, copied + peeked,
+ decrypted - peeked, is_peek, NULL);
else
err = process_rx_list(ctx, msg, &control, 0,
- async_copy_bytes, is_peek);
- decrypted += max(err, 0);
+ async_copy_bytes, is_peek, NULL);
}
copied += decrypted;
@@ -2351,16 +2419,9 @@ void tls_sw_release_resources_tx(struct sock *sk)
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
struct tls_rec *rec, *tmp;
- int pending;
/* Wait for any pending async encryptions to complete */
- spin_lock_bh(&ctx->encrypt_compl_lock);
- ctx->async_notify = true;
- pending = atomic_read(&ctx->encrypt_pending);
- spin_unlock_bh(&ctx->encrypt_compl_lock);
-
- if (pending)
- crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+ tls_encrypt_async_wait(ctx);
tls_tx_records(sk, -1);
@@ -2513,6 +2574,48 @@ void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
tls_ctx->prot_info.version != TLS_1_3_VERSION;
}
+static struct tls_sw_context_tx *init_ctx_tx(struct tls_context *ctx, struct sock *sk)
+{
+ struct tls_sw_context_tx *sw_ctx_tx;
+
+ if (!ctx->priv_ctx_tx) {
+ sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
+ if (!sw_ctx_tx)
+ return NULL;
+ } else {
+ sw_ctx_tx = ctx->priv_ctx_tx;
+ }
+
+ crypto_init_wait(&sw_ctx_tx->async_wait);
+ atomic_set(&sw_ctx_tx->encrypt_pending, 1);
+ INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
+ INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
+ sw_ctx_tx->tx_work.sk = sk;
+
+ return sw_ctx_tx;
+}
+
+static struct tls_sw_context_rx *init_ctx_rx(struct tls_context *ctx)
+{
+ struct tls_sw_context_rx *sw_ctx_rx;
+
+ if (!ctx->priv_ctx_rx) {
+ sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
+ if (!sw_ctx_rx)
+ return NULL;
+ } else {
+ sw_ctx_rx = ctx->priv_ctx_rx;
+ }
+
+ crypto_init_wait(&sw_ctx_rx->async_wait);
+ atomic_set(&sw_ctx_rx->decrypt_pending, 1);
+ init_waitqueue_head(&sw_ctx_rx->wq);
+ skb_queue_head_init(&sw_ctx_rx->rx_list);
+ skb_queue_head_init(&sw_ctx_rx->async_hold);
+
+ return sw_ctx_rx;
+}
+
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -2534,48 +2637,22 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
}
if (tx) {
- if (!ctx->priv_ctx_tx) {
- sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
- if (!sw_ctx_tx) {
- rc = -ENOMEM;
- goto out;
- }
- ctx->priv_ctx_tx = sw_ctx_tx;
- } else {
- sw_ctx_tx =
- (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
- }
- } else {
- if (!ctx->priv_ctx_rx) {
- sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
- if (!sw_ctx_rx) {
- rc = -ENOMEM;
- goto out;
- }
- ctx->priv_ctx_rx = sw_ctx_rx;
- } else {
- sw_ctx_rx =
- (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
- }
- }
+ ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
+ if (!ctx->priv_ctx_tx)
+ return -ENOMEM;
- if (tx) {
- crypto_init_wait(&sw_ctx_tx->async_wait);
- spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
+ sw_ctx_tx = ctx->priv_ctx_tx;
crypto_info = &ctx->crypto_send.info;
cctx = &ctx->tx;
aead = &sw_ctx_tx->aead_send;
- INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
- INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
- sw_ctx_tx->tx_work.sk = sk;
} else {
- crypto_init_wait(&sw_ctx_rx->async_wait);
- spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
- init_waitqueue_head(&sw_ctx_rx->wq);
+ ctx->priv_ctx_rx = init_ctx_rx(ctx);
+ if (!ctx->priv_ctx_rx)
+ return -ENOMEM;
+
+ sw_ctx_rx = ctx->priv_ctx_rx;
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
- skb_queue_head_init(&sw_ctx_rx->rx_list);
- skb_queue_head_init(&sw_ctx_rx->async_hold);
aead = &sw_ctx_rx->aead_recv;
}