diff options
Diffstat (limited to '')
-rw-r--r-- | src/libknot/quic/quic_conn.c | 577 |
1 files changed, 577 insertions, 0 deletions
diff --git a/src/libknot/quic/quic_conn.c b/src/libknot/quic/quic_conn.c new file mode 100644 index 0000000..6616573 --- /dev/null +++ b/src/libknot/quic/quic_conn.c @@ -0,0 +1,577 @@ +/* Copyright (C) 2023 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +#include <assert.h> +#include <gnutls/gnutls.h> +#include <ngtcp2/ngtcp2.h> +#include <stdio.h> +#include <string.h> + +#include "libknot/quic/quic_conn.h" + +#include "contrib/macros.h" +#include "contrib/openbsd/siphash.h" +#include "contrib/ucw/heap.h" +#include "contrib/ucw/lists.h" +#include "libdnssec/random.h" +#include "libknot/attribute.h" +#include "libknot/error.h" +#include "libknot/quic/quic.h" +#include "libknot/xdp/tcp_iobuf.h" +#include "libknot/wire.h" + +#define STREAM_INCR 4 // DoQ only uses client-initiated bi-directional streams, so stream IDs increment by four +#define BUCKETS_PER_CONNS 8 // Each connecion has several dCIDs, and each CID takes one hash table bucket. + +static int cmp_expiry_heap_nodes(void *c1, void *c2) +{ + if (((knot_quic_conn_t *)c1)->next_expiry < ((knot_quic_conn_t *)c2)->next_expiry) return -1; + if (((knot_quic_conn_t *)c1)->next_expiry > ((knot_quic_conn_t *)c2)->next_expiry) return 1; + return 0; +} + +_public_ +knot_quic_table_t *knot_quic_table_new(size_t max_conns, size_t max_ibufs, size_t max_obufs, + size_t udp_payload, struct knot_quic_creds *creds) +{ + size_t table_size = max_conns * BUCKETS_PER_CONNS; + + knot_quic_table_t *res = calloc(1, sizeof(*res) + table_size * sizeof(res->conns[0])); + if (res == NULL || creds == NULL) { + free(res); + return NULL; + } + + res->size = table_size; + res->max_conns = max_conns; + res->ibufs_max = max_ibufs; + res->obufs_max = max_obufs; + res->udp_payload_limit = udp_payload; + + res->expiry_heap = malloc(sizeof(struct heap)); + if (res->expiry_heap == NULL || !heap_init(res->expiry_heap, cmp_expiry_heap_nodes, 0)) { + free(res->expiry_heap); + free(res); + return NULL; + } + + res->creds = creds; + + res->hash_secret[0] = dnssec_random_uint64_t(); + res->hash_secret[1] = dnssec_random_uint64_t(); + res->hash_secret[2] = dnssec_random_uint64_t(); + res->hash_secret[3] = dnssec_random_uint64_t(); + + return res; +} + +_public_ +void knot_quic_table_free(knot_quic_table_t *table) +{ + if (table != NULL) { + while (!EMPTY_HEAP(table->expiry_heap)) { + knot_quic_conn_t *c = *(knot_quic_conn_t **)HHEAD(table->expiry_heap); + knot_quic_table_rem(c, table); + knot_quic_cleanup(&c, 1); + } + assert(table->usage == 0); + assert(table->pointers == 0); + assert(table->ibufs_size == 0); + assert(table->obufs_size == 0); + + heap_deinit(table->expiry_heap); + free(table->expiry_heap); + free(table); + } +} + +static void send_excessive_load(knot_quic_conn_t *conn, struct knot_quic_reply *reply, + knot_quic_table_t *table) +{ + if (reply != NULL) { + reply->handle_ret = KNOT_QUIC_ERR_EXCESSIVE_LOAD; + (void)knot_quic_send(table, conn, reply, 0, 0); + } +} + +_public_ +void knot_quic_table_sweep(knot_quic_table_t *table, struct knot_quic_reply *sweep_reply, + struct knot_sweep_stats *stats) +{ + uint64_t now = 0; + if (table == NULL || stats == NULL) { + return; + } + + while (!EMPTY_HEAP(table->expiry_heap)) { + knot_quic_conn_t *c = *(knot_quic_conn_t **)HHEAD(table->expiry_heap); + if (table->usage > table->max_conns) { + knot_sweep_stats_incr(stats, KNOT_SWEEP_CTR_LIMIT_CONN); + send_excessive_load(c, sweep_reply, table); + knot_quic_table_rem(c, table); + } else if (table->obufs_size > table->obufs_max) { + knot_sweep_stats_incr(stats, KNOT_SWEEP_CTR_LIMIT_OBUF); + send_excessive_load(c, sweep_reply, table); + knot_quic_table_rem(c, table); + } else if (table->ibufs_size > table->ibufs_max) { + knot_sweep_stats_incr(stats, KNOT_SWEEP_CTR_LIMIT_IBUF); + send_excessive_load(c, sweep_reply, table); + knot_quic_table_rem(c, table); + } else if (quic_conn_timeout(c, &now)) { + int ret = ngtcp2_conn_handle_expiry(c->conn, now); + if (ret != NGTCP2_NO_ERROR) { // usually NGTCP2_ERR_IDLE_CLOSE or NGTCP2_ERR_HANDSHAKE_TIMEOUT + knot_sweep_stats_incr(stats, KNOT_SWEEP_CTR_TIMEOUT); + knot_quic_table_rem(c, table); + } else { + if (sweep_reply != NULL) { + sweep_reply->handle_ret = KNOT_EOK; + (void)knot_quic_send(table, c, sweep_reply, 0, 0); + } + quic_conn_mark_used(c, table); + } + } + knot_quic_cleanup(&c, 1); + + if (*(knot_quic_conn_t **)HHEAD(table->expiry_heap) == c) { // HHEAD already handled, NOOP, avoid infinite loop + break; + } + } +} + +static uint64_t cid2hash(const ngtcp2_cid *cid, knot_quic_table_t *table) +{ + SIPHASH_CTX ctx; + SipHash24_Init(&ctx, (const SIPHASH_KEY *)(table->hash_secret)); + SipHash24_Update(&ctx, cid->data, MIN(cid->datalen, 8)); + uint64_t ret = SipHash24_End(&ctx); + return ret; +} + +knot_quic_cid_t **quic_table_insert(knot_quic_conn_t *conn, const ngtcp2_cid *cid, + knot_quic_table_t *table) +{ + uint64_t hash = cid2hash(cid, table); + + knot_quic_cid_t *cidobj = malloc(sizeof(*cidobj)); + if (cidobj == NULL) { + return NULL; + } + _Static_assert(sizeof(*cid) <= sizeof(cidobj->cid_placeholder), "insufficient placeholder for CID struct"); + memcpy(cidobj->cid_placeholder, cid, sizeof(*cid)); + cidobj->conn = conn; + + knot_quic_cid_t **addto = table->conns + (hash % table->size); + + cidobj->next = *addto; + *addto = cidobj; + table->pointers++; + + return addto; +} + +knot_quic_conn_t *quic_table_add(ngtcp2_conn *ngconn, const ngtcp2_cid *cid, + knot_quic_table_t *table) +{ + knot_quic_conn_t *conn = calloc(1, sizeof(*conn)); + if (conn == NULL) { + return NULL; + } + + conn->conn = ngconn; + conn->quic_table = table; + conn->stream_inprocess = -1; + conn->qlog_fd = -1; + + conn->next_expiry = UINT64_MAX; + if (!heap_insert(table->expiry_heap, (heap_val_t *)conn)) { + free(conn); + return NULL; + } + + knot_quic_cid_t **addto = quic_table_insert(conn, cid, table); + if (addto == NULL) { + heap_delete(table->expiry_heap, heap_find(table->expiry_heap, (heap_val_t *)conn)); + free(conn); + return NULL; + } + table->usage++; + + return conn; +} + +knot_quic_cid_t **quic_table_lookup2(const ngtcp2_cid *cid, knot_quic_table_t *table) +{ + uint64_t hash = cid2hash(cid, table); + + knot_quic_cid_t **res = table->conns + (hash % table->size); + while (*res != NULL && !ngtcp2_cid_eq(cid, (const ngtcp2_cid *)(*res)->cid_placeholder)) { + res = &(*res)->next; + } + return res; +} + +knot_quic_conn_t *quic_table_lookup(const ngtcp2_cid *cid, knot_quic_table_t *table) +{ + knot_quic_cid_t **pcid = quic_table_lookup2(cid, table); + assert(pcid != NULL); + return *pcid == NULL ? NULL : (*pcid)->conn; +} + +static void conn_heap_reschedule(knot_quic_conn_t *conn, knot_quic_table_t *table) +{ + heap_replace(table->expiry_heap, heap_find(table->expiry_heap, (heap_val_t *)conn), (heap_val_t *)conn); +} + +void quic_conn_mark_used(knot_quic_conn_t *conn, knot_quic_table_t *table) +{ + conn->next_expiry = quic_conn_get_timeout(conn); + conn_heap_reschedule(conn, table); +} + +void quic_table_rem2(knot_quic_cid_t **pcid, knot_quic_table_t *table) +{ + knot_quic_cid_t *cid = *pcid; + *pcid = cid->next; + free(cid); + table->pointers--; +} + +_public_ +void knot_quic_conn_stream_free(knot_quic_conn_t *conn, int64_t stream_id) +{ + knot_quic_stream_t *s = knot_quic_conn_get_stream(conn, stream_id, false); + if (s != NULL && s->inbuf.iov_len > 0) { + free(s->inbuf.iov_base); + conn->ibufs_size -= buffer_alloc_size(s->inbuf.iov_len); + conn->quic_table->ibufs_size -= buffer_alloc_size(s->inbuf.iov_len); + memset(&s->inbuf, 0, sizeof(s->inbuf)); + } + while (s != NULL && s->inbufs != NULL) { + void *tofree = s->inbufs; + s->inbufs = s->inbufs->next; + free(tofree); + } + knot_quic_stream_ack_data(conn, stream_id, SIZE_MAX, false); +} + +_public_ +void knot_quic_table_rem(knot_quic_conn_t *conn, knot_quic_table_t *table) +{ + if (conn == NULL || conn->conn == NULL || table == NULL) { + return; + } + + if (conn->streams_count == -1) { // kxdpgun special + conn->streams_count = 1; + } + for (ssize_t i = conn->streams_count - 1; i >= 0; i--) { + knot_quic_conn_stream_free(conn, (i + conn->streams_first) * 4); + } + assert(conn->streams_count <= 0); + assert(conn->obufs_size == 0); + + size_t num_scid = ngtcp2_conn_get_scid(conn->conn, NULL); + ngtcp2_cid *scids = calloc(num_scid, sizeof(*scids)); + ngtcp2_conn_get_scid(conn->conn, scids); + + for (size_t i = 0; i < num_scid; i++) { + knot_quic_cid_t **pcid = quic_table_lookup2(&scids[i], table); + assert(pcid != NULL); + if (*pcid == NULL) { + continue; + } + assert((*pcid)->conn == conn); + quic_table_rem2(pcid, table); + } + + int pos = heap_find(table->expiry_heap, (heap_val_t *)conn); + heap_delete(table->expiry_heap, pos); + + free(scids); + + gnutls_deinit(conn->tls_session); + ngtcp2_conn_del(conn->conn); + conn->conn = NULL; + + table->usage--; +} + +_public_ +knot_quic_stream_t *knot_quic_conn_get_stream(knot_quic_conn_t *conn, + int64_t stream_id, bool create) +{ + if (stream_id % 4 != 0 || conn == NULL) { + return NULL; + } + stream_id /= 4; + + if (conn->streams_first > stream_id) { + return NULL; + } + if (conn->streams_count > stream_id - conn->streams_first) { + return &conn->streams[stream_id - conn->streams_first]; + } + + if (create) { + size_t new_streams_count; + knot_quic_stream_t *new_streams; + + if (conn->streams_count == 0) { + new_streams = malloc(sizeof(new_streams[0])); + if (new_streams == NULL) { + return NULL; + } + new_streams_count = 1; + conn->streams_first = stream_id; + } else { + new_streams_count = stream_id + 1 - conn->streams_first; + if (new_streams_count > MAX_STREAMS_PER_CONN) { + return NULL; + } + new_streams = realloc(conn->streams, new_streams_count * sizeof(*new_streams)); + if (new_streams == NULL) { + return NULL; + } + } + + for (knot_quic_stream_t *si = new_streams; + si < new_streams + conn->streams_count; si++) { + if (si->obufs_size == 0) { + init_list((list_t *)&si->outbufs); + } else { + fix_list((list_t *)&si->outbufs); + } + } + + for (knot_quic_stream_t *si = new_streams + conn->streams_count; + si < new_streams + new_streams_count; si++) { + memset(si, 0, sizeof(*si)); + init_list((list_t *)&si->outbufs); + } + conn->streams = new_streams; + conn->streams_count = new_streams_count; + + return &conn->streams[stream_id - conn->streams_first]; + } + return NULL; +} + +_public_ +knot_quic_stream_t *knot_quic_conn_new_stream(knot_quic_conn_t *conn) +{ + int64_t new_id = (conn->streams_first + conn->streams_count) * 4; + return knot_quic_conn_get_stream(conn, new_id, true); +} + +static void stream_inprocess(knot_quic_conn_t *conn, knot_quic_stream_t *stream) +{ + int16_t idx = stream - conn->streams; + assert(idx >= 0); + assert(idx < conn->streams_count); + if (conn->stream_inprocess < 0 || conn->stream_inprocess > idx) { + conn->stream_inprocess = idx; + } +} + +static void stream_outprocess(knot_quic_conn_t *conn, knot_quic_stream_t *stream) +{ + if (stream != &conn->streams[conn->stream_inprocess]) { + return; + } + + for (int16_t idx = conn->stream_inprocess + 1; idx < conn->streams_count; idx++) { + stream = &conn->streams[idx]; + if (stream->inbufs != NULL) { + conn->stream_inprocess = stream - conn->streams; + return; + } + } + conn->stream_inprocess = -1; +} + +int knot_quic_stream_recv_data(knot_quic_conn_t *conn, int64_t stream_id, + const uint8_t *data, size_t len, bool fin) +{ + if (len == 0 || conn == NULL || data == NULL) { + return KNOT_EINVAL; + } + + knot_quic_stream_t *stream = knot_quic_conn_get_stream(conn, stream_id, true); + if (stream == NULL) { + return KNOT_ENOENT; + } + + struct iovec in = { (void *)data, len }; + ssize_t prev_ibufs_size = conn->ibufs_size; + int ret = knot_tcp_inbufs_upd(&stream->inbuf, in, true, + &stream->inbufs, &conn->ibufs_size); + conn->quic_table->ibufs_size += (ssize_t)conn->ibufs_size - prev_ibufs_size; + if (ret != KNOT_EOK) { + return ret; + } + + if (fin && stream->inbufs == NULL) { + return KNOT_ESEMCHECK; + } + + if (stream->inbufs != NULL) { + stream_inprocess(conn, stream); + } + return KNOT_EOK; +} + +_public_ +knot_quic_stream_t *knot_quic_stream_get_process(knot_quic_conn_t *conn, + int64_t *stream_id) +{ + if (conn == NULL || conn->stream_inprocess < 0) { + return NULL; + } + + knot_quic_stream_t *stream = &conn->streams[conn->stream_inprocess]; + *stream_id = (conn->streams_first + conn->stream_inprocess) * 4; + stream_outprocess(conn, stream); + return stream; +} + +_public_ +uint8_t *knot_quic_stream_add_data(knot_quic_conn_t *conn, int64_t stream_id, + uint8_t *data, size_t len) +{ + knot_quic_stream_t *s = knot_quic_conn_get_stream(conn, stream_id, true); + if (s == NULL) { + return NULL; + } + + size_t prefix = sizeof(uint16_t); + + knot_quic_obuf_t *obuf = malloc(sizeof(*obuf) + prefix + len); + if (obuf == NULL) { + return NULL; + } + + obuf->len = len + prefix; + knot_wire_write_u16(obuf->buf, len); + if (data != NULL) { + memcpy(obuf->buf + prefix, data, len); + } + + list_t *list = (list_t *)&s->outbufs; + if (EMPTY_LIST(*list)) { + s->unsent_obuf = obuf; + } + add_tail((list_t *)&s->outbufs, (node_t *)obuf); + s->obufs_size += obuf->len; + conn->obufs_size += obuf->len; + conn->quic_table->obufs_size += obuf->len; + + return obuf->buf + prefix; +} + +void knot_quic_stream_ack_data(knot_quic_conn_t *conn, int64_t stream_id, + size_t end_acked, bool keep_stream) +{ + knot_quic_stream_t *s = knot_quic_conn_get_stream(conn, stream_id, false); + if (s == NULL) { + return; + } + + list_t *obs = (list_t *)&s->outbufs; + + knot_quic_obuf_t *first; + while (!EMPTY_LIST(*obs) && end_acked >= (first = HEAD(*obs))->len + s->first_offset) { + rem_node((node_t *)first); + assert(HEAD(*obs) != first); // help CLANG analyzer understand what rem_node did and that further usage of HEAD(*obs) is safe + s->obufs_size -= first->len; + conn->obufs_size -= first->len; + conn->quic_table->obufs_size -= first->len; + s->first_offset += first->len; + free(first); + if (s->unsent_obuf == first) { + s->unsent_obuf = EMPTY_LIST(*obs) ? NULL : HEAD(*obs); + s->unsent_offset = 0; + } + } + + if (EMPTY_LIST(*obs) && !keep_stream) { + stream_outprocess(conn, s); + memset(s, 0, sizeof(*s)); + init_list((list_t *)&s->outbufs); + while (s = &conn->streams[0], s->inbuf.iov_len == 0 && s->inbufs == NULL && s->obufs_size == 0) { + assert(conn->streams_count > 0); + conn->streams_count--; + + if (conn->streams_count == 0) { + free(conn->streams); + conn->streams = 0; + conn->streams_first = 0; + break; + } else { + conn->streams_first++; + conn->stream_inprocess--; + memmove(s, s + 1, sizeof(*s) * conn->streams_count); + // possible realloc to shrink allocated space, but probably useless + for (knot_quic_stream_t *si = s; si < s + conn->streams_count; si++) { + if (si->obufs_size == 0) { + init_list((list_t *)&si->outbufs); + } else { + fix_list((list_t *)&si->outbufs); + } + } + } + } + } +} + +void knot_quic_stream_mark_sent(knot_quic_conn_t *conn, int64_t stream_id, + size_t amount_sent) +{ + knot_quic_stream_t *s = knot_quic_conn_get_stream(conn, stream_id, false); + if (s == NULL) { + return; + } + + s->unsent_offset += amount_sent; + assert(s->unsent_offset <= s->unsent_obuf->len); + if (s->unsent_offset == s->unsent_obuf->len) { + s->unsent_offset = 0; + s->unsent_obuf = (knot_quic_obuf_t *)s->unsent_obuf->node.next; + if (s->unsent_obuf->node.next == NULL) { // already behind the tail of list + s->unsent_obuf = NULL; + } + } +} + +_public_ +void knot_quic_cleanup(knot_quic_conn_t *conns[], size_t n_conns) +{ + for (size_t i = 0; i < n_conns; i++) { + if (conns[i] != NULL && conns[i]->conn == NULL) { + free(conns[i]); + for (size_t j = i + 1; j < n_conns; j++) { + if (conns[j] == conns[i]) { + conns[j] = NULL; + } + } + } + } +} + +bool quic_require_retry(knot_quic_table_t *table) +{ + (void)table; + return false; +} |