diff options
Diffstat (limited to '')
-rw-r--r-- | tests/pytests/proxy/tls-proxy.c | 1038 |
1 files changed, 1038 insertions, 0 deletions
diff --git a/tests/pytests/proxy/tls-proxy.c b/tests/pytests/proxy/tls-proxy.c new file mode 100644 index 0000000..986716c --- /dev/null +++ b/tests/pytests/proxy/tls-proxy.c @@ -0,0 +1,1038 @@ +/* SPDX-License-Identifier: GPL-3.0-or-later */ + +#include <assert.h> +#include <stdio.h> +#include <unistd.h> +#include <string.h> +#include <stdlib.h> +#include <stdbool.h> +#include <gnutls/gnutls.h> +#include <uv.h> +#include "lib/generic/array.h" +#include "tls-proxy.h" + +#define TLS_MAX_SEND_RETRIES 100 +#define CLIENT_ANSWER_CHUNK_SIZE 8 + +#define MAX_CLIENT_PENDING_SIZE 4096 + +struct buf { + size_t size; + char buf[]; +}; + +enum peer_state { + STATE_NOT_CONNECTED, + STATE_LISTENING, + STATE_CONNECTED, + STATE_CONNECT_IN_PROGRESS, + STATE_CLOSING_IN_PROGRESS +}; + +enum handshake_state { + TLS_HS_NOT_STARTED = 0, + TLS_HS_EXPECTED, + TLS_HS_REAUTH_EXPECTED, + TLS_HS_IN_PROGRESS, + TLS_HS_DONE, + TLS_HS_CLOSING, + TLS_HS_LAST +}; + +struct tls_ctx { + gnutls_session_t session; + enum handshake_state handshake_state; + /* for reading from the network */ + const uint8_t *buf; + ssize_t nread; + ssize_t consumed; + uint8_t recv_buf[4096]; +}; + +struct peer { + uv_tcp_t handle; + enum peer_state state; + struct sockaddr_storage addr; + array_t(struct buf *) pending_buf; + uint64_t connection_timestamp; + struct tls_ctx *tls; + struct peer *peer; + int active_requests; +}; + +struct tls_proxy_ctx { + const struct args *a; + uv_loop_t *loop; + gnutls_certificate_credentials_t tls_credentials; + gnutls_priority_t tls_priority_cache; + struct { + uv_tcp_t handle; + struct sockaddr_storage addr; + } server; + struct sockaddr_storage upstream_addr; + array_t(struct peer *) client_list; + char uv_wire_buf[65535 * 2]; + int conn_sequence; +}; + +static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf); +static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf); +static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len); +static ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len); +static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t nread); +static int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread); +static int write_to_upstream_pending(struct peer *peer); +static int write_to_client_pending(struct peer *peer); +static void on_client_close(uv_handle_t *handle); +static void on_upstream_close(uv_handle_t *handle); + +static int gnutls_references = 0; + +static const char * const tlsv12_priorities = + "NORMAL:" /* GnuTLS defaults */ + "-VERS-TLS1.0:-VERS-TLS1.1:+VERS-TLS1.2:-VERS-TLS1.3:" /* TLS 1.2 only */ + "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL"; + +static const char * const tlsv13_priorities = + "NORMAL:" /* GnuTLS defaults */ + "-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2:+VERS-TLS1.3:" /* TLS 1.3 only */ + "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL"; + +static struct tls_proxy_ctx *get_proxy(struct peer *peer) +{ + return (struct tls_proxy_ctx *)peer->handle.loop->data; +} + +const void *ip_addr(const struct sockaddr *addr) +{ + if (!addr) { + return NULL; + } + switch (addr->sa_family) { + case AF_INET: return (const void *)&(((const struct sockaddr_in *)addr)->sin_addr); + case AF_INET6: return (const void *)&(((const struct sockaddr_in6 *)addr)->sin6_addr); + default: return NULL; + } +} + +uint16_t ip_addr_port(const struct sockaddr *addr) +{ + if (!addr) { + return 0; + } + switch (addr->sa_family) { + case AF_INET: return ntohs(((const struct sockaddr_in *)addr)->sin_port); + case AF_INET6: return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port); + default: return 0; + } +} + +static int ip_addr_str(const struct sockaddr *addr, char *buf, size_t *buflen) +{ + int ret = 0; + if (!addr || !buf || !buflen) { + return EINVAL; + } + + char str[INET6_ADDRSTRLEN + 6]; + if (!inet_ntop(addr->sa_family, ip_addr(addr), str, sizeof(str))) { + return errno; + } + int len = strlen(str); + str[len] = '#'; + snprintf(&str[len + 1], 6, "%hu", ip_addr_port(addr)); + len += 6; + str[len] = 0; + if (len >= *buflen) { + ret = ENOSPC; + } else { + memcpy(buf, str, len + 1); + } + *buflen = len; + return ret; +} + +static inline char *ip_straddr(const struct sockaddr_storage *saddr_storage) +{ + assert(saddr_storage != NULL); + const struct sockaddr *addr = (const struct sockaddr *)saddr_storage; + /* We are the single-threaded application */ + static char str[INET6_ADDRSTRLEN + 6]; + size_t len = sizeof(str); + int ret = ip_addr_str(addr, str, &len); + return ret != 0 || len == 0 ? NULL : str; +} + +static struct buf *alloc_io_buffer(size_t size) +{ + struct buf *buf = calloc(1, sizeof (struct buf) + size); + buf->size = size; + return buf; +} + +static void free_io_buffer(struct buf *buf) +{ + if (!buf) { + return; + } + free(buf); +} + +static struct buf *get_first_pending_buf(struct peer *peer) +{ + struct buf *buf = NULL; + if (peer->pending_buf.len > 0) { + buf = peer->pending_buf.at[0]; + } + return buf; +} + +static struct buf *remove_first_pending_buf(struct peer *peer) +{ + if (peer->pending_buf.len == 0) { + return NULL; + } + struct buf * buf = peer->pending_buf.at[0]; + for (int i = 1; i < peer->pending_buf.len; ++i) { + peer->pending_buf.at[i - 1] = peer->pending_buf.at[i]; + } + peer->pending_buf.len -= 1; + return buf; +} + +static void clear_pending_bufs(struct peer *peer) +{ + for (int i = 0; i < peer->pending_buf.len; ++i) { + struct buf *b = peer->pending_buf.at[i]; + free_io_buffer(b); + } + peer->pending_buf.len = 0; +} + +static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data; + buf->base = proxy->uv_wire_buf; + buf->len = sizeof(proxy->uv_wire_buf); +} + +static void on_client_close(uv_handle_t *handle) +{ + struct peer *client = (struct peer *)handle->data; + struct peer *upstream = client->peer; + fprintf(stdout, "[client] connection with '%s' closed\n", ip_straddr(&client->addr)); + assert(client->tls); + gnutls_deinit(client->tls->session); + client->tls->handshake_state = TLS_HS_NOT_STARTED; + client->state = STATE_NOT_CONNECTED; + if (upstream->state != STATE_NOT_CONNECTED) { + if (upstream->state == STATE_CONNECTED) { + fprintf(stdout, "[client] closing connection with upstream for '%s'\n", + ip_straddr(&client->addr)); + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + } + return; + } + struct tls_proxy_ctx *proxy = get_proxy(client); + for (size_t i = 0; i < proxy->client_list.len; ++i) { + struct peer *client_i = proxy->client_list.at[i]; + if (client_i == client) { + fprintf(stdout, "[client] connection structures deallocated for '%s'\n", + ip_straddr(&client->addr)); + array_del(proxy->client_list, i); + free(client->tls); + free(client); + break; + } + } +} + +static void on_upstream_close(uv_handle_t *handle) +{ + struct peer *upstream = (struct peer *)handle->data; + struct peer *client = upstream->peer; + assert(upstream->tls == NULL); + upstream->state = STATE_NOT_CONNECTED; + fprintf(stdout, "[upstream] connection with upstream closed for client '%s'\n", ip_straddr(&client->addr)); + if (client->state != STATE_NOT_CONNECTED) { + if (client->state == STATE_CONNECTED) { + fprintf(stdout, "[upstream] closing connection to client '%s'\n", + ip_straddr(&client->addr)); + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + } + return; + } + struct tls_proxy_ctx *proxy = get_proxy(upstream); + for (size_t i = 0; i < proxy->client_list.len; ++i) { + struct peer *client_i = proxy->client_list.at[i]; + if (client_i == client) { + fprintf(stdout, "[upstream] connection structures deallocated for '%s'\n", + ip_straddr(&client->addr)); + array_del(proxy->client_list, i); + free(upstream); + free(client->tls); + free(client); + break; + } + } +} + +static void write_to_client_cb(uv_write_t *req, int status) +{ + struct peer *client = (struct peer *)req->handle->data; + free(req); + client->active_requests -= 1; + if (status) { + fprintf(stdout, "[client] error writing to client '%s': %s\n", + ip_straddr(&client->addr), uv_strerror(status)); + clear_pending_bufs(client); + if (client->state == STATE_CONNECTED) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + return; + } + } + fprintf(stdout, "[client] successfully wrote to client '%s', pending len is %zd, active requests %i\n", + ip_straddr(&client->addr), client->pending_buf.len, client->active_requests); + if (client->state == STATE_CONNECTED && + client->tls->handshake_state == TLS_HS_DONE) { + struct tls_proxy_ctx *proxy = get_proxy(client); + uint64_t elapsed = uv_now(proxy->loop) - client->connection_timestamp; + if (!proxy->a->close_connection || elapsed < proxy->a->close_timeout) { + write_to_client_pending(client); + } else { + clear_pending_bufs(client); + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + fprintf(stdout, "[client] closing connection to client '%s'\n", ip_straddr(&client->addr)); + } + } +} + +static void write_to_upstream_cb(uv_write_t *req, int status) +{ + struct peer *upstream = (struct peer *)req->handle->data; + void *data = req->data; + free(req); + if (status) { + fprintf(stdout, "[upstream] error writing to upstream: %s\n", uv_strerror(status)); + clear_pending_bufs(upstream); + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + return; + } + if (data != NULL) { + assert(upstream->pending_buf.len > 0); + struct buf *buf = get_first_pending_buf(upstream); + assert(data == (void *)buf->buf); + fprintf(stdout, "[upstream] successfully wrote %zi bytes to upstream, pending len is %zd\n", + buf->size, upstream->pending_buf.len); + remove_first_pending_buf(upstream); + free_io_buffer(buf); + } else { + fprintf(stdout, "[upstream] successfully wrote to upstream, pending len is %zd\n", + upstream->pending_buf.len); + } + if (upstream->peer == NULL || upstream->peer->state != STATE_CONNECTED) { + clear_pending_bufs(upstream); + } else if (upstream->state == STATE_CONNECTED && upstream->pending_buf.len > 0) { + write_to_upstream_pending(upstream); + } +} + +static void accept_connection_from_client(uv_stream_t *server) +{ + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data; + struct peer *client = calloc(1, sizeof(struct peer)); + uv_tcp_init(proxy->loop, &client->handle); + uv_tcp_nodelay((uv_tcp_t *)&client->handle, 1); + + int err = uv_accept(server, (uv_stream_t*)&client->handle); + if (err != 0) { + fprintf(stdout, "[client] incoming connection - uv_accept() failed: (%d) %s\n", + err, uv_strerror(err)); + proxy->conn_sequence = 0; + return; + } + + client->state = STATE_CONNECTED; + array_init(client->pending_buf); + client->handle.data = client; + + struct peer *upstream = calloc(1, sizeof(struct peer)); + uv_tcp_init(proxy->loop, &upstream->handle); + uv_tcp_nodelay((uv_tcp_t *)&upstream->handle, 1); + + client->peer = upstream; + + array_init(upstream->pending_buf); + upstream->state = STATE_NOT_CONNECTED; + upstream->peer = client; + upstream->handle.data = upstream; + + struct sockaddr *addr = (struct sockaddr *)&(client->addr); + int addr_len = sizeof(client->addr); + int ret = uv_tcp_getpeername(&client->handle, addr, &addr_len); + if (ret || addr->sa_family == AF_UNSPEC) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + fprintf(stdout, "[client] incoming connection - uv_tcp_getpeername() failed: (%d) %s\n", + err, uv_strerror(err)); + proxy->conn_sequence = 0; + return; + } + memcpy(&upstream->addr, &proxy->upstream_addr, sizeof(struct sockaddr_storage)); + + struct tls_ctx *tls = calloc(1, sizeof(struct tls_ctx)); + tls->handshake_state = TLS_HS_NOT_STARTED; + + client->tls = tls; + const char *errpos = NULL; + unsigned int gnutls_flags = GNUTLS_SERVER | GNUTLS_NONBLOCK; +#if GNUTLS_VERSION_NUMBER >= 0x030604 + if (proxy->a->tls_13) { + gnutls_flags |= GNUTLS_POST_HANDSHAKE_AUTH; + } +#endif + err = gnutls_init(&tls->session, gnutls_flags); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + err = gnutls_priority_set(tls->session, proxy->tls_priority_cache); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_priority_set() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + + const char *direct_priorities = proxy->a->tls_13 ? tlsv13_priorities : tlsv12_priorities; + err = gnutls_priority_set_direct(tls->session, direct_priorities, &errpos); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n", + direct_priorities, errpos - direct_priorities, errpos, + gnutls_strerror_name(err), err); + } + err = gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, proxy->tls_credentials); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_credentials_set() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + } + if (proxy->a->tls_13) { + gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST); + } else { + gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_IGNORE); + } + gnutls_handshake_set_timeout(tls->session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); + gnutls_transport_set_pull_function(tls->session, proxy_gnutls_pull); + gnutls_transport_set_push_function(tls->session, proxy_gnutls_push); + gnutls_transport_set_ptr(tls->session, client); + + tls->handshake_state = TLS_HS_IN_PROGRESS; + + client->connection_timestamp = uv_now(proxy->loop); + proxy->conn_sequence += 1; + array_push(proxy->client_list, client); + + fprintf(stdout, "[client] incoming connection from '%s'\n", ip_straddr(&client->addr)); + uv_read_start((uv_stream_t*)&client->handle, alloc_uv_buffer, read_from_client_cb); +} + +static void dynamic_handle_close_cb(uv_handle_t *handle) +{ + free(handle); +} + +static void delayed_accept_timer_cb(uv_timer_t *timer) +{ + uv_stream_t *server = (uv_stream_t *)timer->data; + fprintf(stdout, "[client] delayed connection processing\n"); + accept_connection_from_client(server); + uv_close((uv_handle_t *)timer, dynamic_handle_close_cb); +} + +static void on_client_connection(uv_stream_t *server, int status) +{ + if (status < 0) { + fprintf(stdout, "[client] incoming connection error: %s\n", uv_strerror(status)); + return; + } + struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data; + proxy->conn_sequence += 1; + if (proxy->a->max_conn_sequence > 0 && + proxy->conn_sequence > proxy->a->max_conn_sequence) { + fprintf(stdout, "[client] incoming connection, delaying\n"); + uv_timer_t *timer = (uv_timer_t*)malloc(sizeof *timer); + uv_timer_init(uv_default_loop(), timer); + timer->data = server; + uv_timer_start(timer, delayed_accept_timer_cb, 10000, 0); + proxy->conn_sequence = 0; + } else { + accept_connection_from_client(server); + } +} + +static void on_connect_to_upstream(uv_connect_t *req, int status) +{ + struct peer *upstream = (struct peer *)req->handle->data; + free(req); + if (status < 0) { + fprintf(stdout, "[upstream] error connecting to upstream (%s): %s\n", + ip_straddr(&upstream->addr), + uv_strerror(status)); + clear_pending_bufs(upstream); + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + return; + } + fprintf(stdout, "[upstream] connected to %s\n", ip_straddr(&upstream->addr)); + + upstream->state = STATE_CONNECTED; + uv_read_start((uv_stream_t*)&upstream->handle, alloc_uv_buffer, read_from_upstream_cb); + if (upstream->pending_buf.len > 0) { + write_to_upstream_pending(upstream); + } +} + +static void read_from_client_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf) +{ + if (nread == 0) { + fprintf(stdout, "[client] reading %zd bytes\n", nread); + return; + } + struct peer *client = (struct peer *)handle->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stdout, "[client] error reading from '%s': %s\n", + ip_straddr(&client->addr), + uv_err_name(nread)); + } else { + fprintf(stdout, "[client] closing connection with '%s'\n", + ip_straddr(&client->addr)); + } + if (client->state == STATE_CONNECTED) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)handle, on_client_close); + } + return; + } + + struct tls_proxy_ctx *proxy = get_proxy(client); + if (proxy->a->accept_only) { + fprintf(stdout, "[client] ignoring %zd bytes from '%s'\n", nread, ip_straddr(&client->addr)); + return; + } + fprintf(stdout, "[client] reading %zd bytes from '%s'\n", nread, ip_straddr(&client->addr)); + + int res = tls_process_from_client(client, (const uint8_t *)buf->base, nread); + if (res < 0) { + if (client->state == STATE_CONNECTED) { + client->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&client->handle, on_client_close); + } + } +} + +static void read_from_upstream_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf) +{ + fprintf(stdout, "[upstream] reading %zd bytes\n", nread); + if (nread == 0) { + return; + } + struct peer *upstream = (struct peer *)handle->data; + if (nread < 0) { + if (nread != UV_EOF) { + fprintf(stdout, "[upstream] error reading from upstream: %s\n", uv_err_name(nread)); + } else { + fprintf(stdout, "[upstream] closing connection\n"); + } + clear_pending_bufs(upstream); + if (upstream->state == STATE_CONNECTED) { + upstream->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->handle, on_upstream_close); + } + return; + } + int res = tls_process_from_upstream(upstream, (const uint8_t *)buf->base, nread); + if (res < 0) { + fprintf(stdout, "[upstream] error processing tls data to client\n"); + if (upstream->peer->state == STATE_CONNECTED) { + upstream->peer->state = STATE_CLOSING_IN_PROGRESS; + uv_close((uv_handle_t*)&upstream->peer->handle, on_client_close); + } + } +} + +static void push_to_upstream_pending(struct peer *upstream, const char *buf, size_t size) +{ + struct buf *b = alloc_io_buffer(size); + memcpy(b->buf, buf, b->size); + array_push(upstream->pending_buf, b); +} + +static void push_to_client_pending(struct peer *client, const char *buf, size_t size) +{ + struct tls_proxy_ctx *proxy = get_proxy(client); + while (size > 0) { + int temp_size = size; + if (proxy->a->rehandshake && temp_size > CLIENT_ANSWER_CHUNK_SIZE) { + temp_size = CLIENT_ANSWER_CHUNK_SIZE; + } + struct buf *b = alloc_io_buffer(temp_size); + memcpy(b->buf, buf, b->size); + array_push(client->pending_buf, b); + size -= temp_size; + buf += temp_size; + } +} + +static int write_to_upstream_pending(struct peer *upstream) +{ + struct buf *buf = get_first_pending_buf(upstream); + uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t)); + uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size); + req->data = buf->buf; + fprintf(stdout, "[upstream] writing %zd bytes\n", buf->size); + return uv_write(req, (uv_stream_t *)&upstream->handle, &wrbuf, 1, write_to_upstream_cb); +} + +static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len) +{ + struct peer *peer = (struct peer *)h; + struct tls_ctx *t = peer->tls; + + fprintf(stdout, "[gnutls_pull] pulling %zd bytes\n", len); + + if (t->nread <= t->consumed) { + errno = EAGAIN; + fprintf(stdout, "[gnutls_pull] return EAGAIN\n"); + return -1; + } + + ssize_t avail = t->nread - t->consumed; + ssize_t transfer = (avail <= len ? avail : len); + memcpy(buf, t->buf + t->consumed, transfer); + t->consumed += transfer; + return transfer; +} + +ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len) +{ + struct peer *client = (struct peer *)h; + fprintf(stdout, "[gnutls_push] writing %zd bytes\n", len); + + ssize_t ret = -1; + const size_t req_size_aligned = ((sizeof(uv_write_t) / 16) + 1) * 16; + char *common_buf = malloc(req_size_aligned + len); + uv_write_t *req = (uv_write_t *) common_buf; + char *data = common_buf + req_size_aligned; + const uv_buf_t uv_buf[1] = { + { data, len } + }; + memcpy(data, buf, len); + int res = uv_write(req, (uv_stream_t *)&client->handle, uv_buf, 1, write_to_client_cb); + if (res == 0) { + ret = len; + client->active_requests += 1; + } else { + free(common_buf); + errno = EIO; + } + return ret; +} + +static int write_to_client_pending(struct peer *client) +{ + if (client->pending_buf.len == 0) { + return 0; + } + + struct tls_proxy_ctx *proxy = get_proxy(client); + struct buf *buf = get_first_pending_buf(client); + fprintf(stdout, "[client] writing %zd bytes\n", buf->size); + + gnutls_session_t tls_session = client->tls->session; + assert(client->tls->handshake_state != TLS_HS_IN_PROGRESS); + + char *data = buf->buf; + size_t len = buf->size; + + ssize_t count = 0; + ssize_t submitted = len; + ssize_t retries = 0; + do { + count = gnutls_record_send(tls_session, data, len); + if (count < 0) { + if (gnutls_error_is_fatal(count)) { + fprintf(stdout, "[client] gnutls_record_send failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + return -1; + } + if (++retries > TLS_MAX_SEND_RETRIES) { + fprintf(stdout, "[client] gnutls_record_send: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n", + retries, gnutls_strerror_name(count), count); + return -1; + } + } else if (count != 0) { + data += count; + len -= count; + retries = 0; + } else { + if (++retries < TLS_MAX_SEND_RETRIES) { + continue; + } + fprintf(stdout, "[client] gnutls_record_send: too many retries (%zd)\n", + retries); + fprintf(stdout, "[client] tls_push_to_client didn't send all data(%zd of %zd)\n", + len, submitted); + return -1; + } + } while (len > 0); + + remove_first_pending_buf(client); + free_io_buffer(buf); + + fprintf(stdout, "[client] submitted %zd bytes\n", submitted); + if (proxy->a->rehandshake) { + int err = GNUTLS_E_SUCCESS; +#if GNUTLS_VERSION_NUMBER >= 0x030604 + if (proxy->a->tls_13) { + int flags = gnutls_session_get_flags(tls_session); + if ((flags & GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH) == 0) { + /* Client doesn't support post-handshake re-authentication, + * nothing to test here */ + fprintf(stdout, "[client] GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH flag not detected\n"); + assert(false); + } + err = gnutls_reauth(tls_session, 0); + if (err != GNUTLS_E_INTERRUPTED && + err != GNUTLS_E_AGAIN && + err != GNUTLS_E_GOT_APPLICATION_DATA) { + fprintf(stdout, "[client] gnutls_reauth() failed: %s (%i)\n", + gnutls_strerror_name(err), err); + } else { + fprintf(stdout, "[client] post-handshake authentication initiated\n"); + } + client->tls->handshake_state = TLS_HS_REAUTH_EXPECTED; + } else { + assert (gnutls_safe_renegotiation_status(tls_session) != 0); + err = gnutls_rehandshake(tls_session); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n", + gnutls_strerror_name(err), err); + assert(false); + } else { + fprintf(stdout, "[client] rehandshake started\n"); + } + client->tls->handshake_state = TLS_HS_EXPECTED; + } +#else + assert (gnutls_safe_renegotiation_status(tls_session) != 0); + err = gnutls_rehandshake(tls_session); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n", + gnutls_strerror_name(err), err); + assert(false); + } else { + fprintf(stdout, "[client] rehandshake started\n"); + } + /* Prevent write-to-client callback from sending next pending chunk. + * At the same time tls_process_from_client() must not call gnutls_handshake() + * as there can be application data in this direction. */ + client->tls->handshake_state = TLS_HS_EXPECTED; +#endif + } + return submitted; +} + +static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t len) +{ + struct peer *client = upstream->peer; + + fprintf(stdout, "[upstream] pushing %zd bytes to client\n", len); + + ssize_t submitted = 0; + if (client->state != STATE_CONNECTED) { + return submitted; + } + + bool list_was_empty = (client->pending_buf.len == 0); + push_to_client_pending(client, (const char *)buf, len); + submitted = len; + if (client->tls->handshake_state == TLS_HS_DONE) { + if (list_was_empty && client->pending_buf.len > 0) { + int ret = write_to_client_pending(client); + if (ret < 0) { + submitted = -1; + } + } + } + + return submitted; +} + +int tls_process_handshake(struct peer *peer) +{ + struct tls_ctx *tls = peer->tls; + int ret = 1; + while (tls->handshake_state == TLS_HS_IN_PROGRESS) { + fprintf(stdout, "[tls] TLS handshake in progress...\n"); + int err = gnutls_handshake(tls->session); + if (err == GNUTLS_E_SUCCESS) { + tls->handshake_state = TLS_HS_DONE; + fprintf(stdout, "[tls] TLS handshake has completed\n"); + ret = 1; + if (peer->pending_buf.len != 0) { + write_to_client_pending(peer); + } + } else if (gnutls_error_is_fatal(err)) { + fprintf(stdout, "[tls] gnutls_handshake failed: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = -1; + break; + } else { + fprintf(stdout, "[tls] gnutls_handshake nonfatal error: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = 0; + break; + } + } + return ret; +} + +#if GNUTLS_VERSION_NUMBER >= 0x030604 +int tls_process_reauth(struct peer *peer) +{ + struct tls_ctx *tls = peer->tls; + int ret = 1; + while (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) { + fprintf(stdout, "[tls] TLS re-authentication in progress...\n"); + int err = gnutls_reauth(tls->session, 0); + if (err == GNUTLS_E_SUCCESS) { + tls->handshake_state = TLS_HS_DONE; + fprintf(stdout, "[tls] TLS re-authentication has completed\n"); + ret = 1; + if (peer->pending_buf.len != 0) { + write_to_client_pending(peer); + } + } else if (err != GNUTLS_E_INTERRUPTED && + err != GNUTLS_E_AGAIN && + err != GNUTLS_E_GOT_APPLICATION_DATA) { + /* these are listed as nonfatal errors there + * https://www.gnutls.org/manual/gnutls.html#gnutls_005freauth */ + fprintf(stdout, "[tls] gnutls_reauth failed: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = -1; + break; + } else { + fprintf(stdout, "[tls] gnutls_reauth nonfatal error: %s (%d)\n", + gnutls_strerror_name(err), err); + ret = 0; + break; + } + } + return ret; +} +#endif + +int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread) +{ + struct tls_ctx *tls = client->tls; + + tls->buf = buf; + tls->nread = nread >= 0 ? nread : 0; + tls->consumed = 0; + + fprintf(stdout, "[client] tls_process: reading %zd bytes from client\n", nread); + + int ret = 0; + if (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) { + ret = tls_process_reauth(client); + } else { + ret = tls_process_handshake(client); + } + if (ret <= 0) { + return ret; + } + + int submitted = 0; + while (true) { + ssize_t count = gnutls_record_recv(tls->session, tls->recv_buf, sizeof(tls->recv_buf)); + if (count == GNUTLS_E_AGAIN) { + break; /* No data available */ + } else if (count == GNUTLS_E_INTERRUPTED) { + continue; /* Try reading again */ + } else if (count == GNUTLS_E_REHANDSHAKE) { + tls->handshake_state = TLS_HS_IN_PROGRESS; + ret = tls_process_handshake(client); + if (ret < 0) { /* Critical error */ + return ret; + } + if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */ + break; + } + continue; + } +#if GNUTLS_VERSION_NUMBER >= 0x030604 + else if (count == GNUTLS_E_REAUTH_REQUEST) { + assert(false); + tls->handshake_state = TLS_HS_IN_PROGRESS; + ret = tls_process_reauth(client); + if (ret < 0) { /* Critical error */ + return ret; + } + if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */ + break; + } + continue; + } +#endif + else if (count < 0) { + fprintf(stdout, "[client] gnutls_record_recv failed: %s (%zd)\n", + gnutls_strerror_name(count), count); + assert(false); + return -1; + } else if (count == 0) { + break; + } + struct peer *upstream = client->peer; + if (upstream->state == STATE_CONNECTED) { + bool upstream_pending_is_empty = (upstream->pending_buf.len == 0); + push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count); + if (upstream_pending_is_empty) { + write_to_upstream_pending(upstream); + } + } else if (upstream->state == STATE_NOT_CONNECTED) { + uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t)); + upstream->state = STATE_CONNECT_IN_PROGRESS; + fprintf(stdout, "[client] connecting to upstream '%s'\n", ip_straddr(&upstream->addr)); + uv_tcp_connect(conn, &upstream->handle, (struct sockaddr *)&upstream->addr, + on_connect_to_upstream); + push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count); + } else if (upstream->state == STATE_CONNECT_IN_PROGRESS) { + push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count); + } + submitted += count; + } + return submitted; +} + +struct tls_proxy_ctx *tls_proxy_allocate() +{ + return malloc(sizeof(struct tls_proxy_ctx)); +} + +int tls_proxy_init(struct tls_proxy_ctx *proxy, const struct args *a) +{ + const char *server_addr = a->local_addr; + int server_port = a->local_port; + const char *upstream_addr = a->upstream; + int upstream_port = a->upstream_port; + const char *cert_file = a->cert_file; + const char *key_file = a->key_file; + proxy->a = a; + proxy->loop = uv_default_loop(); + uv_tcp_init(proxy->loop, &proxy->server.handle); + int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server.addr); + if (res != 0) { + res = uv_ip6_addr(server_addr, server_port, (struct sockaddr_in6 *)&proxy->server.addr); + if (res != 0) { + fprintf(stdout, "[proxy] tls_proxy_init: can't parse local address '%s'\n", server_addr); + return -1; + } + } + res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr); + if (res != 0) { + res = uv_ip6_addr(upstream_addr, upstream_port, (struct sockaddr_in6 *)&proxy->upstream_addr); + if (res != 0) { + fprintf(stdout, "[proxy] tls_proxy_init: can't parse upstream address '%s'\n", upstream_addr); + return -1; + } + } + array_init(proxy->client_list); + proxy->conn_sequence = 0; + + proxy->loop->data = proxy; + + int err = 0; + if (gnutls_references == 0) { + err = gnutls_global_init(); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_global_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + } + gnutls_references += 1; + + err = gnutls_certificate_allocate_credentials(&proxy->tls_credentials); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_certificate_allocate_credentials() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + err = gnutls_certificate_set_x509_system_trust(proxy->tls_credentials); + if (err <= 0) { + fprintf(stdout, "[proxy] gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + if (cert_file && key_file) { + err = gnutls_certificate_set_x509_key_file(proxy->tls_credentials, + cert_file, key_file, GNUTLS_X509_FMT_PEM); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_certificate_set_x509_key_file() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + } + + err = gnutls_priority_init(&proxy->tls_priority_cache, NULL, NULL); + if (err != GNUTLS_E_SUCCESS) { + fprintf(stdout, "[proxy] gnutls_priority_init() failed: (%d) %s\n", + err, gnutls_strerror_name(err)); + return -1; + } + + return 0; +} + +void tls_proxy_free(struct tls_proxy_ctx *proxy) +{ + if (!proxy) { + return; + } + while (proxy->client_list.len > 0) { + size_t last_index = proxy->client_list.len - 1; + struct peer *client = proxy->client_list.at[last_index]; + clear_pending_bufs(client); + clear_pending_bufs(client->peer); + /* TODO correctly close all the uv_tcp_t */ + free(client->peer); + free(client); + array_del(proxy->client_list, last_index); + } + gnutls_certificate_free_credentials(proxy->tls_credentials); + gnutls_priority_deinit(proxy->tls_priority_cache); + free(proxy); + + gnutls_references -= 1; + if (gnutls_references == 0) { + gnutls_global_deinit(); + } +} + +int tls_proxy_start_listen(struct tls_proxy_ctx *proxy) +{ + uv_tcp_bind(&proxy->server.handle, (const struct sockaddr*)&proxy->server.addr, 0); + int ret = uv_listen((uv_stream_t*)&proxy->server.handle, 128, on_client_connection); + return ret; +} + +int tls_proxy_run(struct tls_proxy_ctx *proxy) +{ + return uv_run(proxy->loop, UV_RUN_DEFAULT); +} |