summaryrefslogtreecommitdiffstats
path: root/tests/pytests/proxy
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 15:26:00 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 15:26:00 +0000
commit830407e88f9d40d954356c3754f2647f91d5c06a (patch)
treed6a0ece6feea91f3c656166dbaa884ef8a29740e /tests/pytests/proxy
parentInitial commit. (diff)
downloadknot-resolver-upstream/5.6.0.tar.xz
knot-resolver-upstream/5.6.0.zip
Adding upstream version 5.6.0.upstream/5.6.0upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r--tests/pytests/proxy.py161
-rw-r--r--tests/pytests/proxy/tls-proxy.c1038
-rw-r--r--tests/pytests/proxy/tls-proxy.h34
-rw-r--r--tests/pytests/proxy/tlsproxy.c198
4 files changed, 1431 insertions, 0 deletions
diff --git a/tests/pytests/proxy.py b/tests/pytests/proxy.py
new file mode 100644
index 0000000..b8a53cd
--- /dev/null
+++ b/tests/pytests/proxy.py
@@ -0,0 +1,161 @@
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from contextlib import contextmanager, ContextDecorator
+import os
+import subprocess
+from typing import Any, Dict, Optional
+
+import dns
+import dns.rcode
+import pytest
+
+from kresd import CERTS_DIR, Forward, Kresd, make_kresd, make_port
+import utils
+
+
+HINTS = {
+ '0.foo.': '127.0.0.1',
+ '1.foo.': '127.0.0.1',
+ '2.foo.': '127.0.0.1',
+ '3.foo.': '127.0.0.1',
+}
+
+
+def resolve_hint(sock, qname):
+ buff, msgid = utils.get_msgbuff(qname)
+ sock.sendall(buff)
+ answer = utils.receive_parse_answer(sock)
+ assert answer.id == msgid
+ assert answer.rcode() == dns.rcode.NOERROR
+ assert answer.answer[0][0].address == HINTS[qname]
+
+
+class Proxy(ContextDecorator):
+ EXECUTABLE = ''
+
+ def __init__(
+ self,
+ local_ip: str = '127.0.0.1',
+ local_port: Optional[int] = None,
+ upstream_ip: str = '127.0.0.1',
+ upstream_port: Optional[int] = None
+ ) -> None:
+ self.local_ip = local_ip
+ self.local_port = local_port
+ self.upstream_ip = upstream_ip
+ self.upstream_port = upstream_port
+ self.proxy = None
+
+ def get_args(self):
+ args = []
+ args.append('--local')
+ args.append(self.local_ip)
+ if self.local_port is not None:
+ args.append('--lport')
+ args.append(str(self.local_port))
+ args.append('--upstream')
+ args.append(self.upstream_ip)
+ if self.upstream_port is not None:
+ args.append('--uport')
+ args.append(str(self.upstream_port))
+ return args
+
+ def __enter__(self):
+ args = [self.EXECUTABLE] + self.get_args()
+ print(' '.join(args))
+
+ try:
+ self.proxy = subprocess.Popen(
+ args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ except subprocess.CalledProcessError:
+ pytest.skip("proxy '{}' failed to run (did you compile it?)"
+ .format(self.EXECUTABLE))
+
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if self.proxy is not None:
+ self.proxy.terminate()
+ self.proxy = None
+
+
+class TLSProxy(Proxy):
+ EXECUTABLE = 'tlsproxy'
+
+ def __init__(
+ self,
+ local_ip: str = '127.0.0.1',
+ local_port: Optional[int] = None,
+ upstream_ip: str = '127.0.0.1',
+ upstream_port: Optional[int] = None,
+ certname: Optional[str] = 'tt',
+ close: Optional[int] = None,
+ rehandshake: bool = False,
+ force_tls13: bool = False
+ ) -> None:
+ super().__init__(local_ip, local_port, upstream_ip, upstream_port)
+ if certname is not None:
+ self.cert_path = os.path.join(CERTS_DIR, certname + '.cert.pem')
+ self.key_path = os.path.join(CERTS_DIR, certname + '.key.pem')
+ else:
+ self.cert_path = None
+ self.key_path = None
+ self.close = close
+ self.rehandshake = rehandshake
+ self.force_tls13 = force_tls13
+
+ def get_args(self):
+ args = super().get_args()
+ if self.cert_path is not None:
+ args.append('--cert')
+ args.append(self.cert_path)
+ if self.key_path is not None:
+ args.append('--key')
+ args.append(self.key_path)
+ if self.close is not None:
+ args.append('--close')
+ args.append(str(self.close))
+ if self.rehandshake:
+ args.append('--rehandshake')
+ if self.force_tls13:
+ args.append('--tls13')
+ return args
+
+
+@contextmanager
+def kresd_tls_client(
+ workdir: str,
+ proxy: TLSProxy,
+ kresd_tls_client_kwargs: Optional[Dict[Any, Any]] = None,
+ kresd_fwd_target_kwargs: Optional[Dict[Any, Any]] = None
+ ) -> Kresd:
+ """kresd_tls_client --(tls)--> tlsproxy --(tcp)--> kresd_fwd_target"""
+ ALLOWED_IPS = {'127.0.0.1', '::1'}
+ assert proxy.local_ip in ALLOWED_IPS, "only localhost IPs supported for proxy"
+ assert proxy.upstream_ip in ALLOWED_IPS, "only localhost IPs are supported for proxy"
+
+ if kresd_tls_client_kwargs is None:
+ kresd_tls_client_kwargs = {}
+ if kresd_fwd_target_kwargs is None:
+ kresd_fwd_target_kwargs = {}
+
+ # run forward target instance
+ dir1 = os.path.join(workdir, 'kresd_fwd_target')
+ os.makedirs(dir1)
+
+ with make_kresd(dir1, hints=HINTS, **kresd_fwd_target_kwargs) as kresd_fwd_target:
+ sock = kresd_fwd_target.ip_tcp_socket()
+ resolve_hint(sock, list(HINTS.keys())[0])
+
+ proxy.local_port = make_port('127.0.0.1', '::1')
+ proxy.upstream_port = kresd_fwd_target.port
+
+ with proxy:
+ # run test kresd instance
+ dir2 = os.path.join(workdir, 'kresd_tls_client')
+ os.makedirs(dir2)
+ forward = Forward(
+ proto='tls', ip=proxy.local_ip, port=proxy.local_port,
+ hostname='transport-test-server.com', ca_file=proxy.cert_path)
+ with make_kresd(dir2, forward=forward, **kresd_tls_client_kwargs) as kresd:
+ yield kresd
diff --git a/tests/pytests/proxy/tls-proxy.c b/tests/pytests/proxy/tls-proxy.c
new file mode 100644
index 0000000..986716c
--- /dev/null
+++ b/tests/pytests/proxy/tls-proxy.c
@@ -0,0 +1,1038 @@
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+
+#include <assert.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <string.h>
+#include <stdlib.h>
+#include <stdbool.h>
+#include <gnutls/gnutls.h>
+#include <uv.h>
+#include "lib/generic/array.h"
+#include "tls-proxy.h"
+
+#define TLS_MAX_SEND_RETRIES 100
+#define CLIENT_ANSWER_CHUNK_SIZE 8
+
+#define MAX_CLIENT_PENDING_SIZE 4096
+
+struct buf {
+ size_t size;
+ char buf[];
+};
+
+enum peer_state {
+ STATE_NOT_CONNECTED,
+ STATE_LISTENING,
+ STATE_CONNECTED,
+ STATE_CONNECT_IN_PROGRESS,
+ STATE_CLOSING_IN_PROGRESS
+};
+
+enum handshake_state {
+ TLS_HS_NOT_STARTED = 0,
+ TLS_HS_EXPECTED,
+ TLS_HS_REAUTH_EXPECTED,
+ TLS_HS_IN_PROGRESS,
+ TLS_HS_DONE,
+ TLS_HS_CLOSING,
+ TLS_HS_LAST
+};
+
+struct tls_ctx {
+ gnutls_session_t session;
+ enum handshake_state handshake_state;
+ /* for reading from the network */
+ const uint8_t *buf;
+ ssize_t nread;
+ ssize_t consumed;
+ uint8_t recv_buf[4096];
+};
+
+struct peer {
+ uv_tcp_t handle;
+ enum peer_state state;
+ struct sockaddr_storage addr;
+ array_t(struct buf *) pending_buf;
+ uint64_t connection_timestamp;
+ struct tls_ctx *tls;
+ struct peer *peer;
+ int active_requests;
+};
+
+struct tls_proxy_ctx {
+ const struct args *a;
+ uv_loop_t *loop;
+ gnutls_certificate_credentials_t tls_credentials;
+ gnutls_priority_t tls_priority_cache;
+ struct {
+ uv_tcp_t handle;
+ struct sockaddr_storage addr;
+ } server;
+ struct sockaddr_storage upstream_addr;
+ array_t(struct peer *) client_list;
+ char uv_wire_buf[65535 * 2];
+ int conn_sequence;
+};
+
+static void read_from_upstream_cb(uv_stream_t *upstream, ssize_t nread, const uv_buf_t *buf);
+static void read_from_client_cb(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf);
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len);
+static ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len);
+static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t nread);
+static int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread);
+static int write_to_upstream_pending(struct peer *peer);
+static int write_to_client_pending(struct peer *peer);
+static void on_client_close(uv_handle_t *handle);
+static void on_upstream_close(uv_handle_t *handle);
+
+static int gnutls_references = 0;
+
+static const char * const tlsv12_priorities =
+ "NORMAL:" /* GnuTLS defaults */
+ "-VERS-TLS1.0:-VERS-TLS1.1:+VERS-TLS1.2:-VERS-TLS1.3:" /* TLS 1.2 only */
+ "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
+
+static const char * const tlsv13_priorities =
+ "NORMAL:" /* GnuTLS defaults */
+ "-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2:+VERS-TLS1.3:" /* TLS 1.3 only */
+ "-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
+
+static struct tls_proxy_ctx *get_proxy(struct peer *peer)
+{
+ return (struct tls_proxy_ctx *)peer->handle.loop->data;
+}
+
+const void *ip_addr(const struct sockaddr *addr)
+{
+ if (!addr) {
+ return NULL;
+ }
+ switch (addr->sa_family) {
+ case AF_INET: return (const void *)&(((const struct sockaddr_in *)addr)->sin_addr);
+ case AF_INET6: return (const void *)&(((const struct sockaddr_in6 *)addr)->sin6_addr);
+ default: return NULL;
+ }
+}
+
+uint16_t ip_addr_port(const struct sockaddr *addr)
+{
+ if (!addr) {
+ return 0;
+ }
+ switch (addr->sa_family) {
+ case AF_INET: return ntohs(((const struct sockaddr_in *)addr)->sin_port);
+ case AF_INET6: return ntohs(((const struct sockaddr_in6 *)addr)->sin6_port);
+ default: return 0;
+ }
+}
+
+static int ip_addr_str(const struct sockaddr *addr, char *buf, size_t *buflen)
+{
+ int ret = 0;
+ if (!addr || !buf || !buflen) {
+ return EINVAL;
+ }
+
+ char str[INET6_ADDRSTRLEN + 6];
+ if (!inet_ntop(addr->sa_family, ip_addr(addr), str, sizeof(str))) {
+ return errno;
+ }
+ int len = strlen(str);
+ str[len] = '#';
+ snprintf(&str[len + 1], 6, "%hu", ip_addr_port(addr));
+ len += 6;
+ str[len] = 0;
+ if (len >= *buflen) {
+ ret = ENOSPC;
+ } else {
+ memcpy(buf, str, len + 1);
+ }
+ *buflen = len;
+ return ret;
+}
+
+static inline char *ip_straddr(const struct sockaddr_storage *saddr_storage)
+{
+ assert(saddr_storage != NULL);
+ const struct sockaddr *addr = (const struct sockaddr *)saddr_storage;
+ /* We are the single-threaded application */
+ static char str[INET6_ADDRSTRLEN + 6];
+ size_t len = sizeof(str);
+ int ret = ip_addr_str(addr, str, &len);
+ return ret != 0 || len == 0 ? NULL : str;
+}
+
+static struct buf *alloc_io_buffer(size_t size)
+{
+ struct buf *buf = calloc(1, sizeof (struct buf) + size);
+ buf->size = size;
+ return buf;
+}
+
+static void free_io_buffer(struct buf *buf)
+{
+ if (!buf) {
+ return;
+ }
+ free(buf);
+}
+
+static struct buf *get_first_pending_buf(struct peer *peer)
+{
+ struct buf *buf = NULL;
+ if (peer->pending_buf.len > 0) {
+ buf = peer->pending_buf.at[0];
+ }
+ return buf;
+}
+
+static struct buf *remove_first_pending_buf(struct peer *peer)
+{
+ if (peer->pending_buf.len == 0) {
+ return NULL;
+ }
+ struct buf * buf = peer->pending_buf.at[0];
+ for (int i = 1; i < peer->pending_buf.len; ++i) {
+ peer->pending_buf.at[i - 1] = peer->pending_buf.at[i];
+ }
+ peer->pending_buf.len -= 1;
+ return buf;
+}
+
+static void clear_pending_bufs(struct peer *peer)
+{
+ for (int i = 0; i < peer->pending_buf.len; ++i) {
+ struct buf *b = peer->pending_buf.at[i];
+ free_io_buffer(b);
+ }
+ peer->pending_buf.len = 0;
+}
+
+static void alloc_uv_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)handle->loop->data;
+ buf->base = proxy->uv_wire_buf;
+ buf->len = sizeof(proxy->uv_wire_buf);
+}
+
+static void on_client_close(uv_handle_t *handle)
+{
+ struct peer *client = (struct peer *)handle->data;
+ struct peer *upstream = client->peer;
+ fprintf(stdout, "[client] connection with '%s' closed\n", ip_straddr(&client->addr));
+ assert(client->tls);
+ gnutls_deinit(client->tls->session);
+ client->tls->handshake_state = TLS_HS_NOT_STARTED;
+ client->state = STATE_NOT_CONNECTED;
+ if (upstream->state != STATE_NOT_CONNECTED) {
+ if (upstream->state == STATE_CONNECTED) {
+ fprintf(stdout, "[client] closing connection with upstream for '%s'\n",
+ ip_straddr(&client->addr));
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ }
+ return;
+ }
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ for (size_t i = 0; i < proxy->client_list.len; ++i) {
+ struct peer *client_i = proxy->client_list.at[i];
+ if (client_i == client) {
+ fprintf(stdout, "[client] connection structures deallocated for '%s'\n",
+ ip_straddr(&client->addr));
+ array_del(proxy->client_list, i);
+ free(client->tls);
+ free(client);
+ break;
+ }
+ }
+}
+
+static void on_upstream_close(uv_handle_t *handle)
+{
+ struct peer *upstream = (struct peer *)handle->data;
+ struct peer *client = upstream->peer;
+ assert(upstream->tls == NULL);
+ upstream->state = STATE_NOT_CONNECTED;
+ fprintf(stdout, "[upstream] connection with upstream closed for client '%s'\n", ip_straddr(&client->addr));
+ if (client->state != STATE_NOT_CONNECTED) {
+ if (client->state == STATE_CONNECTED) {
+ fprintf(stdout, "[upstream] closing connection to client '%s'\n",
+ ip_straddr(&client->addr));
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ }
+ return;
+ }
+ struct tls_proxy_ctx *proxy = get_proxy(upstream);
+ for (size_t i = 0; i < proxy->client_list.len; ++i) {
+ struct peer *client_i = proxy->client_list.at[i];
+ if (client_i == client) {
+ fprintf(stdout, "[upstream] connection structures deallocated for '%s'\n",
+ ip_straddr(&client->addr));
+ array_del(proxy->client_list, i);
+ free(upstream);
+ free(client->tls);
+ free(client);
+ break;
+ }
+ }
+}
+
+static void write_to_client_cb(uv_write_t *req, int status)
+{
+ struct peer *client = (struct peer *)req->handle->data;
+ free(req);
+ client->active_requests -= 1;
+ if (status) {
+ fprintf(stdout, "[client] error writing to client '%s': %s\n",
+ ip_straddr(&client->addr), uv_strerror(status));
+ clear_pending_bufs(client);
+ if (client->state == STATE_CONNECTED) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ return;
+ }
+ }
+ fprintf(stdout, "[client] successfully wrote to client '%s', pending len is %zd, active requests %i\n",
+ ip_straddr(&client->addr), client->pending_buf.len, client->active_requests);
+ if (client->state == STATE_CONNECTED &&
+ client->tls->handshake_state == TLS_HS_DONE) {
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ uint64_t elapsed = uv_now(proxy->loop) - client->connection_timestamp;
+ if (!proxy->a->close_connection || elapsed < proxy->a->close_timeout) {
+ write_to_client_pending(client);
+ } else {
+ clear_pending_bufs(client);
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ fprintf(stdout, "[client] closing connection to client '%s'\n", ip_straddr(&client->addr));
+ }
+ }
+}
+
+static void write_to_upstream_cb(uv_write_t *req, int status)
+{
+ struct peer *upstream = (struct peer *)req->handle->data;
+ void *data = req->data;
+ free(req);
+ if (status) {
+ fprintf(stdout, "[upstream] error writing to upstream: %s\n", uv_strerror(status));
+ clear_pending_bufs(upstream);
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ return;
+ }
+ if (data != NULL) {
+ assert(upstream->pending_buf.len > 0);
+ struct buf *buf = get_first_pending_buf(upstream);
+ assert(data == (void *)buf->buf);
+ fprintf(stdout, "[upstream] successfully wrote %zi bytes to upstream, pending len is %zd\n",
+ buf->size, upstream->pending_buf.len);
+ remove_first_pending_buf(upstream);
+ free_io_buffer(buf);
+ } else {
+ fprintf(stdout, "[upstream] successfully wrote to upstream, pending len is %zd\n",
+ upstream->pending_buf.len);
+ }
+ if (upstream->peer == NULL || upstream->peer->state != STATE_CONNECTED) {
+ clear_pending_bufs(upstream);
+ } else if (upstream->state == STATE_CONNECTED && upstream->pending_buf.len > 0) {
+ write_to_upstream_pending(upstream);
+ }
+}
+
+static void accept_connection_from_client(uv_stream_t *server)
+{
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data;
+ struct peer *client = calloc(1, sizeof(struct peer));
+ uv_tcp_init(proxy->loop, &client->handle);
+ uv_tcp_nodelay((uv_tcp_t *)&client->handle, 1);
+
+ int err = uv_accept(server, (uv_stream_t*)&client->handle);
+ if (err != 0) {
+ fprintf(stdout, "[client] incoming connection - uv_accept() failed: (%d) %s\n",
+ err, uv_strerror(err));
+ proxy->conn_sequence = 0;
+ return;
+ }
+
+ client->state = STATE_CONNECTED;
+ array_init(client->pending_buf);
+ client->handle.data = client;
+
+ struct peer *upstream = calloc(1, sizeof(struct peer));
+ uv_tcp_init(proxy->loop, &upstream->handle);
+ uv_tcp_nodelay((uv_tcp_t *)&upstream->handle, 1);
+
+ client->peer = upstream;
+
+ array_init(upstream->pending_buf);
+ upstream->state = STATE_NOT_CONNECTED;
+ upstream->peer = client;
+ upstream->handle.data = upstream;
+
+ struct sockaddr *addr = (struct sockaddr *)&(client->addr);
+ int addr_len = sizeof(client->addr);
+ int ret = uv_tcp_getpeername(&client->handle, addr, &addr_len);
+ if (ret || addr->sa_family == AF_UNSPEC) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ fprintf(stdout, "[client] incoming connection - uv_tcp_getpeername() failed: (%d) %s\n",
+ err, uv_strerror(err));
+ proxy->conn_sequence = 0;
+ return;
+ }
+ memcpy(&upstream->addr, &proxy->upstream_addr, sizeof(struct sockaddr_storage));
+
+ struct tls_ctx *tls = calloc(1, sizeof(struct tls_ctx));
+ tls->handshake_state = TLS_HS_NOT_STARTED;
+
+ client->tls = tls;
+ const char *errpos = NULL;
+ unsigned int gnutls_flags = GNUTLS_SERVER | GNUTLS_NONBLOCK;
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ if (proxy->a->tls_13) {
+ gnutls_flags |= GNUTLS_POST_HANDSHAKE_AUTH;
+ }
+#endif
+ err = gnutls_init(&tls->session, gnutls_flags);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ err = gnutls_priority_set(tls->session, proxy->tls_priority_cache);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_priority_set() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+
+ const char *direct_priorities = proxy->a->tls_13 ? tlsv13_priorities : tlsv12_priorities;
+ err = gnutls_priority_set_direct(tls->session, direct_priorities, &errpos);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] setting priority '%s' failed at character %zd (...'%s') with %s (%d)\n",
+ direct_priorities, errpos - direct_priorities, errpos,
+ gnutls_strerror_name(err), err);
+ }
+ err = gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, proxy->tls_credentials);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_credentials_set() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ }
+ if (proxy->a->tls_13) {
+ gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_REQUEST);
+ } else {
+ gnutls_certificate_server_set_request(tls->session, GNUTLS_CERT_IGNORE);
+ }
+ gnutls_handshake_set_timeout(tls->session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
+ gnutls_transport_set_pull_function(tls->session, proxy_gnutls_pull);
+ gnutls_transport_set_push_function(tls->session, proxy_gnutls_push);
+ gnutls_transport_set_ptr(tls->session, client);
+
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+
+ client->connection_timestamp = uv_now(proxy->loop);
+ proxy->conn_sequence += 1;
+ array_push(proxy->client_list, client);
+
+ fprintf(stdout, "[client] incoming connection from '%s'\n", ip_straddr(&client->addr));
+ uv_read_start((uv_stream_t*)&client->handle, alloc_uv_buffer, read_from_client_cb);
+}
+
+static void dynamic_handle_close_cb(uv_handle_t *handle)
+{
+ free(handle);
+}
+
+static void delayed_accept_timer_cb(uv_timer_t *timer)
+{
+ uv_stream_t *server = (uv_stream_t *)timer->data;
+ fprintf(stdout, "[client] delayed connection processing\n");
+ accept_connection_from_client(server);
+ uv_close((uv_handle_t *)timer, dynamic_handle_close_cb);
+}
+
+static void on_client_connection(uv_stream_t *server, int status)
+{
+ if (status < 0) {
+ fprintf(stdout, "[client] incoming connection error: %s\n", uv_strerror(status));
+ return;
+ }
+ struct tls_proxy_ctx *proxy = (struct tls_proxy_ctx *)server->loop->data;
+ proxy->conn_sequence += 1;
+ if (proxy->a->max_conn_sequence > 0 &&
+ proxy->conn_sequence > proxy->a->max_conn_sequence) {
+ fprintf(stdout, "[client] incoming connection, delaying\n");
+ uv_timer_t *timer = (uv_timer_t*)malloc(sizeof *timer);
+ uv_timer_init(uv_default_loop(), timer);
+ timer->data = server;
+ uv_timer_start(timer, delayed_accept_timer_cb, 10000, 0);
+ proxy->conn_sequence = 0;
+ } else {
+ accept_connection_from_client(server);
+ }
+}
+
+static void on_connect_to_upstream(uv_connect_t *req, int status)
+{
+ struct peer *upstream = (struct peer *)req->handle->data;
+ free(req);
+ if (status < 0) {
+ fprintf(stdout, "[upstream] error connecting to upstream (%s): %s\n",
+ ip_straddr(&upstream->addr),
+ uv_strerror(status));
+ clear_pending_bufs(upstream);
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ return;
+ }
+ fprintf(stdout, "[upstream] connected to %s\n", ip_straddr(&upstream->addr));
+
+ upstream->state = STATE_CONNECTED;
+ uv_read_start((uv_stream_t*)&upstream->handle, alloc_uv_buffer, read_from_upstream_cb);
+ if (upstream->pending_buf.len > 0) {
+ write_to_upstream_pending(upstream);
+ }
+}
+
+static void read_from_client_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
+{
+ if (nread == 0) {
+ fprintf(stdout, "[client] reading %zd bytes\n", nread);
+ return;
+ }
+ struct peer *client = (struct peer *)handle->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stdout, "[client] error reading from '%s': %s\n",
+ ip_straddr(&client->addr),
+ uv_err_name(nread));
+ } else {
+ fprintf(stdout, "[client] closing connection with '%s'\n",
+ ip_straddr(&client->addr));
+ }
+ if (client->state == STATE_CONNECTED) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)handle, on_client_close);
+ }
+ return;
+ }
+
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ if (proxy->a->accept_only) {
+ fprintf(stdout, "[client] ignoring %zd bytes from '%s'\n", nread, ip_straddr(&client->addr));
+ return;
+ }
+ fprintf(stdout, "[client] reading %zd bytes from '%s'\n", nread, ip_straddr(&client->addr));
+
+ int res = tls_process_from_client(client, (const uint8_t *)buf->base, nread);
+ if (res < 0) {
+ if (client->state == STATE_CONNECTED) {
+ client->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&client->handle, on_client_close);
+ }
+ }
+}
+
+static void read_from_upstream_cb(uv_stream_t *handle, ssize_t nread, const uv_buf_t *buf)
+{
+ fprintf(stdout, "[upstream] reading %zd bytes\n", nread);
+ if (nread == 0) {
+ return;
+ }
+ struct peer *upstream = (struct peer *)handle->data;
+ if (nread < 0) {
+ if (nread != UV_EOF) {
+ fprintf(stdout, "[upstream] error reading from upstream: %s\n", uv_err_name(nread));
+ } else {
+ fprintf(stdout, "[upstream] closing connection\n");
+ }
+ clear_pending_bufs(upstream);
+ if (upstream->state == STATE_CONNECTED) {
+ upstream->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->handle, on_upstream_close);
+ }
+ return;
+ }
+ int res = tls_process_from_upstream(upstream, (const uint8_t *)buf->base, nread);
+ if (res < 0) {
+ fprintf(stdout, "[upstream] error processing tls data to client\n");
+ if (upstream->peer->state == STATE_CONNECTED) {
+ upstream->peer->state = STATE_CLOSING_IN_PROGRESS;
+ uv_close((uv_handle_t*)&upstream->peer->handle, on_client_close);
+ }
+ }
+}
+
+static void push_to_upstream_pending(struct peer *upstream, const char *buf, size_t size)
+{
+ struct buf *b = alloc_io_buffer(size);
+ memcpy(b->buf, buf, b->size);
+ array_push(upstream->pending_buf, b);
+}
+
+static void push_to_client_pending(struct peer *client, const char *buf, size_t size)
+{
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ while (size > 0) {
+ int temp_size = size;
+ if (proxy->a->rehandshake && temp_size > CLIENT_ANSWER_CHUNK_SIZE) {
+ temp_size = CLIENT_ANSWER_CHUNK_SIZE;
+ }
+ struct buf *b = alloc_io_buffer(temp_size);
+ memcpy(b->buf, buf, b->size);
+ array_push(client->pending_buf, b);
+ size -= temp_size;
+ buf += temp_size;
+ }
+}
+
+static int write_to_upstream_pending(struct peer *upstream)
+{
+ struct buf *buf = get_first_pending_buf(upstream);
+ uv_write_t *req = (uv_write_t *) malloc(sizeof(uv_write_t));
+ uv_buf_t wrbuf = uv_buf_init(buf->buf, buf->size);
+ req->data = buf->buf;
+ fprintf(stdout, "[upstream] writing %zd bytes\n", buf->size);
+ return uv_write(req, (uv_stream_t *)&upstream->handle, &wrbuf, 1, write_to_upstream_cb);
+}
+
+static ssize_t proxy_gnutls_pull(gnutls_transport_ptr_t h, void *buf, size_t len)
+{
+ struct peer *peer = (struct peer *)h;
+ struct tls_ctx *t = peer->tls;
+
+ fprintf(stdout, "[gnutls_pull] pulling %zd bytes\n", len);
+
+ if (t->nread <= t->consumed) {
+ errno = EAGAIN;
+ fprintf(stdout, "[gnutls_pull] return EAGAIN\n");
+ return -1;
+ }
+
+ ssize_t avail = t->nread - t->consumed;
+ ssize_t transfer = (avail <= len ? avail : len);
+ memcpy(buf, t->buf + t->consumed, transfer);
+ t->consumed += transfer;
+ return transfer;
+}
+
+ssize_t proxy_gnutls_push(gnutls_transport_ptr_t h, const void *buf, size_t len)
+{
+ struct peer *client = (struct peer *)h;
+ fprintf(stdout, "[gnutls_push] writing %zd bytes\n", len);
+
+ ssize_t ret = -1;
+ const size_t req_size_aligned = ((sizeof(uv_write_t) / 16) + 1) * 16;
+ char *common_buf = malloc(req_size_aligned + len);
+ uv_write_t *req = (uv_write_t *) common_buf;
+ char *data = common_buf + req_size_aligned;
+ const uv_buf_t uv_buf[1] = {
+ { data, len }
+ };
+ memcpy(data, buf, len);
+ int res = uv_write(req, (uv_stream_t *)&client->handle, uv_buf, 1, write_to_client_cb);
+ if (res == 0) {
+ ret = len;
+ client->active_requests += 1;
+ } else {
+ free(common_buf);
+ errno = EIO;
+ }
+ return ret;
+}
+
+static int write_to_client_pending(struct peer *client)
+{
+ if (client->pending_buf.len == 0) {
+ return 0;
+ }
+
+ struct tls_proxy_ctx *proxy = get_proxy(client);
+ struct buf *buf = get_first_pending_buf(client);
+ fprintf(stdout, "[client] writing %zd bytes\n", buf->size);
+
+ gnutls_session_t tls_session = client->tls->session;
+ assert(client->tls->handshake_state != TLS_HS_IN_PROGRESS);
+
+ char *data = buf->buf;
+ size_t len = buf->size;
+
+ ssize_t count = 0;
+ ssize_t submitted = len;
+ ssize_t retries = 0;
+ do {
+ count = gnutls_record_send(tls_session, data, len);
+ if (count < 0) {
+ if (gnutls_error_is_fatal(count)) {
+ fprintf(stdout, "[client] gnutls_record_send failed: %s (%zd)\n",
+ gnutls_strerror_name(count), count);
+ return -1;
+ }
+ if (++retries > TLS_MAX_SEND_RETRIES) {
+ fprintf(stdout, "[client] gnutls_record_send: too many sequential non-fatal errors (%zd), last error is: %s (%zd)\n",
+ retries, gnutls_strerror_name(count), count);
+ return -1;
+ }
+ } else if (count != 0) {
+ data += count;
+ len -= count;
+ retries = 0;
+ } else {
+ if (++retries < TLS_MAX_SEND_RETRIES) {
+ continue;
+ }
+ fprintf(stdout, "[client] gnutls_record_send: too many retries (%zd)\n",
+ retries);
+ fprintf(stdout, "[client] tls_push_to_client didn't send all data(%zd of %zd)\n",
+ len, submitted);
+ return -1;
+ }
+ } while (len > 0);
+
+ remove_first_pending_buf(client);
+ free_io_buffer(buf);
+
+ fprintf(stdout, "[client] submitted %zd bytes\n", submitted);
+ if (proxy->a->rehandshake) {
+ int err = GNUTLS_E_SUCCESS;
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ if (proxy->a->tls_13) {
+ int flags = gnutls_session_get_flags(tls_session);
+ if ((flags & GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH) == 0) {
+ /* Client doesn't support post-handshake re-authentication,
+ * nothing to test here */
+ fprintf(stdout, "[client] GNUTLS_SFLAGS_POST_HANDSHAKE_AUTH flag not detected\n");
+ assert(false);
+ }
+ err = gnutls_reauth(tls_session, 0);
+ if (err != GNUTLS_E_INTERRUPTED &&
+ err != GNUTLS_E_AGAIN &&
+ err != GNUTLS_E_GOT_APPLICATION_DATA) {
+ fprintf(stdout, "[client] gnutls_reauth() failed: %s (%i)\n",
+ gnutls_strerror_name(err), err);
+ } else {
+ fprintf(stdout, "[client] post-handshake authentication initiated\n");
+ }
+ client->tls->handshake_state = TLS_HS_REAUTH_EXPECTED;
+ } else {
+ assert (gnutls_safe_renegotiation_status(tls_session) != 0);
+ err = gnutls_rehandshake(tls_session);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n",
+ gnutls_strerror_name(err), err);
+ assert(false);
+ } else {
+ fprintf(stdout, "[client] rehandshake started\n");
+ }
+ client->tls->handshake_state = TLS_HS_EXPECTED;
+ }
+#else
+ assert (gnutls_safe_renegotiation_status(tls_session) != 0);
+ err = gnutls_rehandshake(tls_session);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[client] gnutls_rehandshake() failed: %s (%i)\n",
+ gnutls_strerror_name(err), err);
+ assert(false);
+ } else {
+ fprintf(stdout, "[client] rehandshake started\n");
+ }
+ /* Prevent write-to-client callback from sending next pending chunk.
+ * At the same time tls_process_from_client() must not call gnutls_handshake()
+ * as there can be application data in this direction. */
+ client->tls->handshake_state = TLS_HS_EXPECTED;
+#endif
+ }
+ return submitted;
+}
+
+static int tls_process_from_upstream(struct peer *upstream, const uint8_t *buf, ssize_t len)
+{
+ struct peer *client = upstream->peer;
+
+ fprintf(stdout, "[upstream] pushing %zd bytes to client\n", len);
+
+ ssize_t submitted = 0;
+ if (client->state != STATE_CONNECTED) {
+ return submitted;
+ }
+
+ bool list_was_empty = (client->pending_buf.len == 0);
+ push_to_client_pending(client, (const char *)buf, len);
+ submitted = len;
+ if (client->tls->handshake_state == TLS_HS_DONE) {
+ if (list_was_empty && client->pending_buf.len > 0) {
+ int ret = write_to_client_pending(client);
+ if (ret < 0) {
+ submitted = -1;
+ }
+ }
+ }
+
+ return submitted;
+}
+
+int tls_process_handshake(struct peer *peer)
+{
+ struct tls_ctx *tls = peer->tls;
+ int ret = 1;
+ while (tls->handshake_state == TLS_HS_IN_PROGRESS) {
+ fprintf(stdout, "[tls] TLS handshake in progress...\n");
+ int err = gnutls_handshake(tls->session);
+ if (err == GNUTLS_E_SUCCESS) {
+ tls->handshake_state = TLS_HS_DONE;
+ fprintf(stdout, "[tls] TLS handshake has completed\n");
+ ret = 1;
+ if (peer->pending_buf.len != 0) {
+ write_to_client_pending(peer);
+ }
+ } else if (gnutls_error_is_fatal(err)) {
+ fprintf(stdout, "[tls] gnutls_handshake failed: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = -1;
+ break;
+ } else {
+ fprintf(stdout, "[tls] gnutls_handshake nonfatal error: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = 0;
+ break;
+ }
+ }
+ return ret;
+}
+
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+int tls_process_reauth(struct peer *peer)
+{
+ struct tls_ctx *tls = peer->tls;
+ int ret = 1;
+ while (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) {
+ fprintf(stdout, "[tls] TLS re-authentication in progress...\n");
+ int err = gnutls_reauth(tls->session, 0);
+ if (err == GNUTLS_E_SUCCESS) {
+ tls->handshake_state = TLS_HS_DONE;
+ fprintf(stdout, "[tls] TLS re-authentication has completed\n");
+ ret = 1;
+ if (peer->pending_buf.len != 0) {
+ write_to_client_pending(peer);
+ }
+ } else if (err != GNUTLS_E_INTERRUPTED &&
+ err != GNUTLS_E_AGAIN &&
+ err != GNUTLS_E_GOT_APPLICATION_DATA) {
+ /* these are listed as nonfatal errors there
+ * https://www.gnutls.org/manual/gnutls.html#gnutls_005freauth */
+ fprintf(stdout, "[tls] gnutls_reauth failed: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = -1;
+ break;
+ } else {
+ fprintf(stdout, "[tls] gnutls_reauth nonfatal error: %s (%d)\n",
+ gnutls_strerror_name(err), err);
+ ret = 0;
+ break;
+ }
+ }
+ return ret;
+}
+#endif
+
+int tls_process_from_client(struct peer *client, const uint8_t *buf, ssize_t nread)
+{
+ struct tls_ctx *tls = client->tls;
+
+ tls->buf = buf;
+ tls->nread = nread >= 0 ? nread : 0;
+ tls->consumed = 0;
+
+ fprintf(stdout, "[client] tls_process: reading %zd bytes from client\n", nread);
+
+ int ret = 0;
+ if (tls->handshake_state == TLS_HS_REAUTH_EXPECTED) {
+ ret = tls_process_reauth(client);
+ } else {
+ ret = tls_process_handshake(client);
+ }
+ if (ret <= 0) {
+ return ret;
+ }
+
+ int submitted = 0;
+ while (true) {
+ ssize_t count = gnutls_record_recv(tls->session, tls->recv_buf, sizeof(tls->recv_buf));
+ if (count == GNUTLS_E_AGAIN) {
+ break; /* No data available */
+ } else if (count == GNUTLS_E_INTERRUPTED) {
+ continue; /* Try reading again */
+ } else if (count == GNUTLS_E_REHANDSHAKE) {
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+ ret = tls_process_handshake(client);
+ if (ret < 0) { /* Critical error */
+ return ret;
+ }
+ if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */
+ break;
+ }
+ continue;
+ }
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ else if (count == GNUTLS_E_REAUTH_REQUEST) {
+ assert(false);
+ tls->handshake_state = TLS_HS_IN_PROGRESS;
+ ret = tls_process_reauth(client);
+ if (ret < 0) { /* Critical error */
+ return ret;
+ }
+ if (ret == 0) { /* Non fatal, most likely GNUTLS_E_AGAIN */
+ break;
+ }
+ continue;
+ }
+#endif
+ else if (count < 0) {
+ fprintf(stdout, "[client] gnutls_record_recv failed: %s (%zd)\n",
+ gnutls_strerror_name(count), count);
+ assert(false);
+ return -1;
+ } else if (count == 0) {
+ break;
+ }
+ struct peer *upstream = client->peer;
+ if (upstream->state == STATE_CONNECTED) {
+ bool upstream_pending_is_empty = (upstream->pending_buf.len == 0);
+ push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count);
+ if (upstream_pending_is_empty) {
+ write_to_upstream_pending(upstream);
+ }
+ } else if (upstream->state == STATE_NOT_CONNECTED) {
+ uv_connect_t *conn = (uv_connect_t *) malloc(sizeof(uv_connect_t));
+ upstream->state = STATE_CONNECT_IN_PROGRESS;
+ fprintf(stdout, "[client] connecting to upstream '%s'\n", ip_straddr(&upstream->addr));
+ uv_tcp_connect(conn, &upstream->handle, (struct sockaddr *)&upstream->addr,
+ on_connect_to_upstream);
+ push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count);
+ } else if (upstream->state == STATE_CONNECT_IN_PROGRESS) {
+ push_to_upstream_pending(upstream, (const char *)tls->recv_buf, count);
+ }
+ submitted += count;
+ }
+ return submitted;
+}
+
+struct tls_proxy_ctx *tls_proxy_allocate()
+{
+ return malloc(sizeof(struct tls_proxy_ctx));
+}
+
+int tls_proxy_init(struct tls_proxy_ctx *proxy, const struct args *a)
+{
+ const char *server_addr = a->local_addr;
+ int server_port = a->local_port;
+ const char *upstream_addr = a->upstream;
+ int upstream_port = a->upstream_port;
+ const char *cert_file = a->cert_file;
+ const char *key_file = a->key_file;
+ proxy->a = a;
+ proxy->loop = uv_default_loop();
+ uv_tcp_init(proxy->loop, &proxy->server.handle);
+ int res = uv_ip4_addr(server_addr, server_port, (struct sockaddr_in *)&proxy->server.addr);
+ if (res != 0) {
+ res = uv_ip6_addr(server_addr, server_port, (struct sockaddr_in6 *)&proxy->server.addr);
+ if (res != 0) {
+ fprintf(stdout, "[proxy] tls_proxy_init: can't parse local address '%s'\n", server_addr);
+ return -1;
+ }
+ }
+ res = uv_ip4_addr(upstream_addr, upstream_port, (struct sockaddr_in *)&proxy->upstream_addr);
+ if (res != 0) {
+ res = uv_ip6_addr(upstream_addr, upstream_port, (struct sockaddr_in6 *)&proxy->upstream_addr);
+ if (res != 0) {
+ fprintf(stdout, "[proxy] tls_proxy_init: can't parse upstream address '%s'\n", upstream_addr);
+ return -1;
+ }
+ }
+ array_init(proxy->client_list);
+ proxy->conn_sequence = 0;
+
+ proxy->loop->data = proxy;
+
+ int err = 0;
+ if (gnutls_references == 0) {
+ err = gnutls_global_init();
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_global_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+ }
+ gnutls_references += 1;
+
+ err = gnutls_certificate_allocate_credentials(&proxy->tls_credentials);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_certificate_allocate_credentials() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ err = gnutls_certificate_set_x509_system_trust(proxy->tls_credentials);
+ if (err <= 0) {
+ fprintf(stdout, "[proxy] gnutls_certificate_set_x509_system_trust() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ if (cert_file && key_file) {
+ err = gnutls_certificate_set_x509_key_file(proxy->tls_credentials,
+ cert_file, key_file, GNUTLS_X509_FMT_PEM);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_certificate_set_x509_key_file() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+ }
+
+ err = gnutls_priority_init(&proxy->tls_priority_cache, NULL, NULL);
+ if (err != GNUTLS_E_SUCCESS) {
+ fprintf(stdout, "[proxy] gnutls_priority_init() failed: (%d) %s\n",
+ err, gnutls_strerror_name(err));
+ return -1;
+ }
+
+ return 0;
+}
+
+void tls_proxy_free(struct tls_proxy_ctx *proxy)
+{
+ if (!proxy) {
+ return;
+ }
+ while (proxy->client_list.len > 0) {
+ size_t last_index = proxy->client_list.len - 1;
+ struct peer *client = proxy->client_list.at[last_index];
+ clear_pending_bufs(client);
+ clear_pending_bufs(client->peer);
+ /* TODO correctly close all the uv_tcp_t */
+ free(client->peer);
+ free(client);
+ array_del(proxy->client_list, last_index);
+ }
+ gnutls_certificate_free_credentials(proxy->tls_credentials);
+ gnutls_priority_deinit(proxy->tls_priority_cache);
+ free(proxy);
+
+ gnutls_references -= 1;
+ if (gnutls_references == 0) {
+ gnutls_global_deinit();
+ }
+}
+
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy)
+{
+ uv_tcp_bind(&proxy->server.handle, (const struct sockaddr*)&proxy->server.addr, 0);
+ int ret = uv_listen((uv_stream_t*)&proxy->server.handle, 128, on_client_connection);
+ return ret;
+}
+
+int tls_proxy_run(struct tls_proxy_ctx *proxy)
+{
+ return uv_run(proxy->loop, UV_RUN_DEFAULT);
+}
diff --git a/tests/pytests/proxy/tls-proxy.h b/tests/pytests/proxy/tls-proxy.h
new file mode 100644
index 0000000..2dbd9fd
--- /dev/null
+++ b/tests/pytests/proxy/tls-proxy.h
@@ -0,0 +1,34 @@
+#pragma once
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+
+#define __STDC_FORMAT_MACROS
+#include <inttypes.h>
+#include <stdint.h>
+#include <stdbool.h>
+#include <netinet/in.h>
+
+struct args {
+ const char *local_addr;
+ uint16_t local_port;
+ const char *upstream;
+ uint16_t upstream_port;
+
+ bool rehandshake;
+ bool close_connection;
+ bool accept_only;
+ bool tls_13;
+
+ uint64_t close_timeout;
+ uint32_t max_conn_sequence;
+
+ const char *cert_file;
+ const char *key_file;
+};
+
+struct tls_proxy_ctx;
+
+struct tls_proxy_ctx *tls_proxy_allocate();
+void tls_proxy_free(struct tls_proxy_ctx *proxy);
+int tls_proxy_init(struct tls_proxy_ctx *proxy, const struct args *a);
+int tls_proxy_start_listen(struct tls_proxy_ctx *proxy);
+int tls_proxy_run(struct tls_proxy_ctx *proxy);
diff --git a/tests/pytests/proxy/tlsproxy.c b/tests/pytests/proxy/tlsproxy.c
new file mode 100644
index 0000000..bcdffb0
--- /dev/null
+++ b/tests/pytests/proxy/tlsproxy.c
@@ -0,0 +1,198 @@
+/* SPDX-License-Identifier: GPL-3.0-or-later */
+
+#include <stdio.h>
+#include <getopt.h>
+#include <stdlib.h>
+#include <signal.h>
+#include <errno.h>
+#include <string.h>
+#include <gnutls/gnutls.h>
+#include "tls-proxy.h"
+
+static char default_local_addr[] = "127.0.0.1";
+static char default_upstream_addr[] = "127.0.0.1";
+static char default_cert_path[] = "../certs/tt.cert.pem";
+static char default_key_path[] = "../certs/tt.key.pem";
+
+void help(char *argv[], struct args *a)
+{
+ printf("Usage: %s [parameters] [rundir]\n", argv[0]);
+ printf("\nParameters:\n"
+ " -l, --local=[addr] Server address to bind to (default: %s).\n"
+ " -p, --lport=[port] Server port to bind to (default: %u).\n"
+ " -u, --upstream=[addr] Upstream address (default: %s).\n"
+ " -d, --uport=[port] Upstream port (default: %u).\n"
+ " -t, --cert=[path] Path to certificate file (default: %s).\n"
+ " -k, --key=[path] Path to key file (default: %s).\n"
+ " -c, --close=[N] Close connection to client after\n"
+ " every N ms (default: %" PRIu64 ").\n"
+ " -f, --fail=[N] Delay every Nth incoming connection by 10 sec,\n"
+ " 0 disables delaying (default: 0).\n"
+ " -r, --rehandshake Do TLS rehandshake after every 8 bytes\n"
+ " sent to the client (default: no).\n"
+ " -a, --acceptonly Accept incoming connections, but don't\n"
+ " connect to upstream (default: no).\n"
+ " -v, --tls13 Force use of TLSv1.3. If not turned on,\n"
+ " TLSv1.2 will be used (default: no).\n"
+ ,
+ a->local_addr, a->local_port,
+ a->upstream, a->upstream_port,
+ a->cert_file, a->key_file,
+ a->close_timeout);
+}
+
+void init_args(struct args *a)
+{
+ a->local_addr = default_local_addr;
+ a->local_port = 54000;
+ a->upstream = default_upstream_addr;
+ a->upstream_port = 53000;
+ a->cert_file = default_cert_path;
+ a->key_file = default_key_path;
+ a->rehandshake = false;
+ a->accept_only = false;
+ a->tls_13 = false;
+ a->close_connection = false;
+ a->close_timeout = 1000;
+ a->max_conn_sequence = 0; /* disabled */
+}
+
+int main(int argc, char **argv)
+{
+ long int li_value = 0;
+ int c = 0, li = 0;
+ struct option opts[] = {
+ {"local", required_argument, 0, 'l'},
+ {"lport", required_argument, 0, 'p'},
+ {"upstream", required_argument, 0, 'u'},
+ {"uport", required_argument, 0, 'd'},
+ {"cert", required_argument, 0, 't'},
+ {"key", required_argument, 0, 'k'},
+ {"close", required_argument, 0, 'c'},
+ {"fail", required_argument, 0, 'f'},
+ {"rehandshake", no_argument, 0, 'r'},
+ {"acceptonly", no_argument, 0, 'a'},
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ {"tls13", no_argument, 0, 'v'},
+#endif
+ {0, 0, 0, 0}
+ };
+ struct args args;
+ init_args(&args);
+ while ((c = getopt_long(argc, argv, "l:p:u:d:t:k:c:f:rav", opts, &li)) != -1) {
+ switch (c)
+ {
+ case 'l':
+ args.local_addr = optarg;
+ break;
+ case 'u':
+ args.upstream = optarg;
+ break;
+ case 't':
+ args.cert_file = optarg;
+ break;
+ case 'k':
+ args.key_file = optarg;
+ break;
+ case 'p':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0 || li_value > UINT16_MAX) {
+ printf("error: '-p' requires a positive"
+ " number less or equal to 65535, not '%s'\n", optarg);
+ return -1;
+ }
+ args.local_port = (uint16_t)li_value;
+ break;
+ case 'd':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0 || li_value > UINT16_MAX) {
+ printf("error: '-d' requires a positive"
+ " number less or equal to 65535, not '%s'\n", optarg);
+ return -1;
+ }
+ args.upstream_port = (uint16_t)li_value;
+ break;
+ case 'c':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0) {
+ printf("[system] error '-c' requires a positive"
+ " number, not '%s'\n", optarg);
+ return -1;
+ }
+ args.close_connection = true;
+ args.close_timeout = li_value;
+ break;
+ case 'f':
+ li_value = strtol(optarg, NULL, 10);
+ if (li_value <= 0 || li_value > UINT32_MAX) {
+ printf("error: '-f' requires a positive"
+ " number less or equal to %i, not '%s'\n",
+ UINT32_MAX, optarg);
+ return -1;
+ }
+ args.max_conn_sequence = (uint32_t)li_value;
+ break;
+ case 'r':
+ args.rehandshake = true;
+ break;
+ case 'a':
+ args.accept_only = true;
+ break;
+ case 'v':
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ args.tls_13 = true;
+#endif
+ break;
+ default:
+ init_args(&args);
+ help(argv, &args);
+ return -1;
+ }
+ }
+ if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
+ fprintf(stderr, "failed to set up SIGPIPE handler to ignore(%s)\n",
+ strerror(errno));
+ }
+ struct tls_proxy_ctx *proxy = tls_proxy_allocate();
+ if (!proxy) {
+ fprintf(stderr, "can't allocate tls_proxy structure\n");
+ return 1;
+ }
+ int res = tls_proxy_init(proxy, &args);
+ if (res) {
+ fprintf(stderr, "can't initialize tls_proxy structure\n");
+ return res;
+ }
+ res = tls_proxy_start_listen(proxy);
+ if (res) {
+ fprintf(stderr, "error starting listen, error code: %i\n", res);
+ return res;
+ }
+ fprintf(stdout, "Listen on %s#%u\n"
+ "Upstream is expected on %s#%u\n"
+ "Certificate file %s\n"
+ "Key file %s\n"
+ "Rehandshake %s\n"
+ "Close %s\n"
+ "Refuse incoming connections every %ith%s\n"
+ "Only accept, don't forward %s\n"
+ "Force TLSv1.3 %s\n"
+ ,
+ args.local_addr, args.local_port,
+ args.upstream, args.upstream_port,
+ args.cert_file, args.key_file,
+ args.rehandshake ? "yes" : "no",
+ args.close_connection ? "yes" : "no",
+ args.max_conn_sequence, args.max_conn_sequence ? "" : " (disabled)",
+ args.accept_only ? "yes" : "no",
+#if GNUTLS_VERSION_NUMBER >= 0x030604
+ args.tls_13 ? "yes" : "no"
+#else
+ "Not supported"
+#endif
+ );
+ res = tls_proxy_run(proxy);
+ tls_proxy_free(proxy);
+ return res;
+}
+