summaryrefslogtreecommitdiffstats
path: root/tests/contrib/test_net.c
diff options
context:
space:
mode:
Diffstat (limited to 'tests/contrib/test_net.c')
-rw-r--r--tests/contrib/test_net.c718
1 files changed, 718 insertions, 0 deletions
diff --git a/tests/contrib/test_net.c b/tests/contrib/test_net.c
new file mode 100644
index 0000000..c0061cd
--- /dev/null
+++ b/tests/contrib/test_net.c
@@ -0,0 +1,718 @@
+/* Copyright (C) 2023 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+#include <tap/basic.h>
+
+#include <assert.h>
+#include <fcntl.h>
+#include <poll.h>
+#include <pthread.h>
+#include <signal.h>
+#include <stdbool.h>
+#include <string.h>
+#include <unistd.h>
+
+#include "libknot/errcode.h"
+#include "contrib/net.h"
+#include "contrib/sockaddr.h"
+
+const int TIMEOUT = 5000;
+const int TIMEOUT_SHORT = 500;
+
+/*!
+ * \brief Get loopback socket address with unset port.
+ */
+static struct sockaddr_storage addr_local(void)
+{
+ struct sockaddr_storage addr = { 0 };
+
+ struct sockaddr_in *addr4 = (struct sockaddr_in *)&addr;
+ addr4->sin_family = AF_INET;
+ addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+
+ return addr;
+}
+
+/*!
+ * \brief Get address of a socket.
+ */
+static struct sockaddr_storage addr_from_socket(int sock)
+{
+ struct sockaddr_storage addr = { 0 };
+ socklen_t len = sizeof(addr);
+ int ret = getsockname(sock, (struct sockaddr *)&addr, &len);
+ is_int(0, ret, "check getsockname return");
+
+ return addr;
+}
+
+static const char *socktype_name(int type)
+{
+ switch (type) {
+ case SOCK_STREAM: return "TCP";
+ case SOCK_DGRAM: return "UDP";
+ default: return "unknown";
+ }
+}
+
+static bool socktype_is_stream(int type)
+{
+ return type == SOCK_STREAM;
+}
+
+/* -- mock server ---------------------------------------------------------- */
+
+#define LISTEN_BACKLOG 5
+
+struct server_ctx;
+typedef struct server_ctx server_ctx_t;
+
+typedef void (*server_cb)(int sock, void *data);
+
+/*!
+ * \brief Server context.
+ */
+struct server_ctx {
+ int sock;
+ int type;
+ bool terminate;
+ server_cb handler;
+ void *handler_data;
+
+ pthread_t thr;
+ pthread_mutex_t mx;
+};
+
+static int poll_read(int sock)
+{
+ struct pollfd pfd = { .fd = sock, .events = POLLIN };
+ return poll(&pfd, 1, TIMEOUT);
+}
+
+static void server_handle(server_ctx_t *ctx)
+{
+ int remote = ctx->sock;
+
+ assert(ctx->type == SOCK_STREAM || ctx->type == SOCK_DGRAM);
+
+ if (socktype_is_stream(ctx->type)) {
+ remote = accept(ctx->sock, 0, 0);
+ if (remote < 0) {
+ return;
+ }
+ }
+
+ pthread_mutex_lock(&ctx->mx);
+ server_cb handler = ctx->handler;
+ pthread_mutex_unlock(&ctx->mx);
+ handler(remote, ctx->handler_data);
+
+ if (socktype_is_stream(ctx->type)) {
+ close(remote);
+ }
+}
+
+/*!
+ * \brief Simple server.
+ *
+ * Terminated when a one-byte message is delivered.
+ */
+static void *server_main(void *_ctx)
+{
+ server_ctx_t *ctx = _ctx;
+
+ for (;;) {
+ pthread_mutex_lock(&ctx->mx);
+ bool terminate = ctx->terminate;
+ pthread_mutex_unlock(&ctx->mx);
+ if (terminate) {
+ break;
+ }
+
+ int r = poll_read(ctx->sock);
+ if (r == -1) {
+ if (errno == EINTR) {
+ continue;
+ } else {
+ break;
+ }
+ } else if (r == 0) {
+ continue;
+ }
+
+ assert(r == 1);
+ server_handle(ctx);
+ }
+
+ return NULL;
+}
+
+static bool server_start(server_ctx_t *ctx, int sock, int type,
+ server_cb handler, void *handler_data)
+{
+ memset(ctx, 0, sizeof(*ctx));
+
+ ctx->sock = sock;
+ ctx->type = type;
+ ctx->handler = handler;
+ ctx->handler_data = handler_data;
+
+ ctx->terminate = false;
+
+ pthread_mutex_init(&ctx->mx, NULL);
+ return (pthread_create(&ctx->thr, NULL, server_main, ctx) == 0);
+}
+
+static void server_stop(server_ctx_t *ctx)
+{
+ pthread_mutex_lock(&ctx->mx);
+ ctx->terminate = true;
+ pthread_mutex_unlock(&ctx->mx);
+
+ pthread_kill(ctx->thr, SIGUSR1);
+ pthread_join(ctx->thr, NULL);
+}
+
+/* -- tests ---------------------------------------------------------------- */
+
+static void handler_echo(int sock, void *_server)
+{
+ server_ctx_t *server = _server;
+ uint8_t buffer[16] = { 0 };
+
+ struct sockaddr_storage remote = { 0 };
+ struct sockaddr_storage *addr = NULL;
+ if (!socktype_is_stream(server->type)) {
+ addr = &remote;
+ }
+
+ int in = net_base_recv(sock, buffer, sizeof(buffer), addr, TIMEOUT);
+ if (in <= 0) {
+ return;
+ }
+
+ net_base_send(sock, buffer, in, addr, TIMEOUT);
+}
+
+static void test_connected_one(const struct sockaddr_storage *server_addr,
+ const struct sockaddr_storage *source_addr,
+ int type, const char *name, const char *addr_name)
+{
+ int r;
+
+ int client = net_connected_socket(type, server_addr, source_addr, false);
+ ok(client >= 0, "%s, %s: client, create connected socket", name, addr_name);
+
+ const uint8_t out[] = "test message";
+ const size_t out_len = sizeof(out);
+ if (socktype_is_stream(type)) {
+ r = net_stream_send(client, out, out_len, TIMEOUT);
+ } else {
+ r = net_dgram_send(client, out, out_len, NULL);
+ }
+ is_int(out_len, r, "%s, %s: client, send message", name, addr_name);
+
+ r = net_is_connected(client);
+ ok(r, "%s, %s: client, is connected", name, addr_name);
+
+ uint8_t in[128] = { 0 };
+ if (socktype_is_stream(type)) {
+ r = net_stream_recv(client, in, sizeof(in), TIMEOUT);
+ } else {
+ r = net_dgram_recv(client, in, sizeof(in), TIMEOUT);
+ }
+ is_int(out_len, r, "%s, %s: client, receive message length", name, addr_name);
+ ok(memcmp(out, in, out_len) == 0, "%s, %s: client, receive message", name, addr_name);
+
+ close(client);
+}
+
+static void test_connected(int type)
+{
+ const char *name = socktype_name(type);
+ const struct sockaddr_storage empty_addr = { 0 };
+ const struct sockaddr_storage local_addr = addr_local();
+
+ int r;
+
+ // setup server
+
+ int server = net_bound_socket(type, &local_addr, 0, 0);
+ ok(server >= 0, "%s: server, create bound socket", name);
+
+ if (socktype_is_stream(type)) {
+ r = listen(server, LISTEN_BACKLOG);
+ is_int(0, r, "%s: server, start listening", name);
+ }
+
+ server_ctx_t server_ctx = { 0 };
+ r = server_start(&server_ctx, server, type, handler_echo, &server_ctx);
+ ok(r, "%s: server, start", name);
+
+ const struct sockaddr_storage server_addr = addr_from_socket(server);
+
+ // connected socket, send and receive
+
+ test_connected_one(&server_addr, NULL, type, name, "NULL source");
+ test_connected_one(&server_addr, &empty_addr, type, name, "zero source");
+ test_connected_one(&server_addr, &local_addr, type, name, "valid source");
+
+ // cleanup
+
+ server_stop(&server_ctx);
+ close(server);
+}
+
+static void handler_noop(int sock, void *data)
+{
+}
+
+static void test_unconnected(void)
+{
+ int r = 0;
+ int sock = -1;
+ const struct sockaddr_storage local = addr_local();
+
+ uint8_t buffer[] = { 'k', 'n', 'o', 't' };
+ ssize_t buffer_len = sizeof(buffer);
+
+ // server
+
+ int server = net_bound_socket(SOCK_DGRAM, &local, 0, 0);
+ ok(server >= 0, "UDP, create server socket");
+
+ server_ctx_t server_ctx = { 0 };
+ r = server_start(&server_ctx, server, SOCK_DGRAM, handler_noop, NULL);
+ ok(r, "UDP, start server");
+
+ // UDP
+
+ sock = net_unbound_socket(SOCK_DGRAM, &local);
+ ok(sock >= 0, "UDP, create unbound socket");
+
+ ok(!net_is_connected(sock), "UDP, is not connected");
+
+ r = net_dgram_send(sock, buffer, buffer_len, NULL);
+ is_int(KNOT_ECONN, r, "UDP, send failure on unconnected socket");
+
+ r = net_dgram_recv(sock, buffer, buffer_len, TIMEOUT_SHORT);
+ is_int(KNOT_ETIMEOUT, r, "UDP, receive timeout on unconnected socket");
+
+ struct sockaddr_storage server_addr = addr_from_socket(server);
+ r = net_dgram_send(sock, buffer, buffer_len, &server_addr);
+ is_int(buffer_len, r, "UDP, send on defined address");
+
+ close(sock);
+
+ // TCP
+
+ sock = net_unbound_socket(SOCK_STREAM, &local);
+ ok(sock >= 0, "TCP, create unbound socket");
+
+ ok(!net_is_connected(sock), "TCP, is not connected");
+
+#ifdef __linux__
+ const int expected = KNOT_ECONN;
+ const char *expected_msg = "failure";
+ const int expected_timeout = TIMEOUT;
+#else
+ const int expected = KNOT_ETIMEOUT;
+ const char *expected_msg = "timeout";
+ const int expected_timeout = TIMEOUT_SHORT;
+#endif
+
+ r = net_stream_send(sock, buffer, buffer_len, expected_timeout);
+ is_int(expected, r, "TCP, send %s on unconnected socket", expected_msg);
+
+ r = net_stream_recv(sock, buffer, sizeof(buffer), expected_timeout);
+ is_int(expected, r, "TCP, receive %s on unconnected socket", expected_msg);
+
+ close(sock);
+
+ // server termination
+
+ server_stop(&server_ctx);
+ close(server);
+}
+
+static void test_refused(void)
+{
+ int r = -1;
+
+ struct sockaddr_storage addr = { 0 };
+ uint8_t buffer[1] = { 0 };
+ int server, client;
+
+ // listening, not accepting
+
+ addr = addr_local();
+ server = net_bound_socket(SOCK_STREAM, &addr, 0, 0);
+ ok(server >= 0, "server, create server");
+ addr = addr_from_socket(server);
+
+ r = listen(server, LISTEN_BACKLOG);
+ is_int(0, r, "server, start listening");
+
+ client = net_connected_socket(SOCK_STREAM, &addr, NULL, false);
+ ok(client >= 0, "client, connect");
+
+ r = net_stream_send(client, (uint8_t *)"", 1, TIMEOUT);
+ is_int(1, r, "client, successful write");
+
+ r = net_stream_recv(client, buffer, sizeof(buffer), TIMEOUT_SHORT);
+ is_int(KNOT_ETIMEOUT, r, "client, timeout on read");
+
+ close(client);
+
+ // listening, closed immediately
+
+ client = net_connected_socket(SOCK_STREAM, &addr, NULL, false);
+ ok(client >= 0, "client, connect");
+
+ r = close(server);
+ is_int(0, r, "server, close socket");
+ usleep(50000);
+
+ r = net_stream_send(client, (uint8_t *)"", 1, TIMEOUT);
+ is_int(KNOT_ECONN, r, "client, refused on write");
+
+ close(client);
+}
+
+struct dns_handler_ctx {
+ const uint8_t *expected;
+ int len;
+ bool raw;
+ bool success;
+};
+
+static bool _sync(int remote, int send)
+{
+ uint8_t buf[1] = { 0 };
+ int r;
+ if (send) {
+ r = net_stream_send(remote, buf, sizeof(buf), TIMEOUT);
+ } else {
+ r = net_stream_recv(remote, buf, sizeof(buf), TIMEOUT);
+
+ }
+ return r == sizeof(buf);
+}
+
+static bool sync_signal(int remote)
+{
+ return _sync(remote, true);
+}
+
+static bool sync_wait(int remote)
+{
+ return _sync(remote, false);
+}
+
+static void handler_dns(int sock, void *_ctx)
+{
+ struct dns_handler_ctx *ctx = _ctx;
+
+ uint8_t in[16] = { 0 };
+ int in_len = 0;
+
+ sync_signal(sock);
+
+ if (ctx->raw) {
+ in_len = net_stream_recv(sock, in, sizeof(in), TIMEOUT);
+ } else {
+ in_len = net_dns_tcp_recv(sock, in, sizeof(in), TIMEOUT);
+ }
+
+ ctx->success = in_len == ctx->len &&
+ (ctx->len < 0 || memcmp(in, ctx->expected, in_len) == 0);
+}
+
+static void dns_send_hello(int sock)
+{
+ net_dns_tcp_send(sock, (uint8_t *)"wimbgunts", 9, TIMEOUT, NULL);
+}
+
+static void dns_send_fragmented(int sock)
+{
+ struct fragment { const uint8_t *data; size_t len; };
+
+ const struct fragment fragments[] = {
+ { (uint8_t *)"\x00", 1 },
+ { (uint8_t *)"\x08""qu", 3 },
+ { (uint8_t *)"oopisk", 6 },
+ { NULL }
+ };
+
+ for (const struct fragment *f = fragments; f->len > 0; f++) {
+ net_stream_send(sock, f->data, f->len, TIMEOUT);
+ }
+}
+
+static void dns_send_incomplete(int sock)
+{
+ net_stream_send(sock, (uint8_t *)"\x00\x08""korm", 6, TIMEOUT);
+}
+
+static void dns_send_trailing(int sock)
+{
+ net_stream_send(sock, (uint8_t *)"\x00\x05""bloitxx", 9, TIMEOUT);
+}
+
+static void test_dns_tcp(void)
+{
+ struct testcase {
+ const char *name;
+ const uint8_t *expected;
+ size_t expected_len;
+ bool expected_raw;
+ void (*send_callback)(int sock);
+ };
+
+ const struct testcase testcases[] = {
+ { "single DNS", (uint8_t *)"wimbgunts", 9, false, dns_send_hello },
+ { "single RAW", (uint8_t *)"\x00\x09""wimbgunts", 11, true, dns_send_hello },
+ { "fragmented", (uint8_t *)"quoopisk", 8, false, dns_send_fragmented },
+ { "incomplete", NULL, KNOT_ECONN, false, dns_send_incomplete },
+ { "trailing garbage", (uint8_t *)"bloit", 5, false, dns_send_trailing },
+ { NULL }
+ };
+
+ for (const struct testcase *t = testcases; t->name != NULL; t++) {
+ struct dns_handler_ctx handler_ctx = {
+ .expected = t->expected,
+ .len = t->expected_len,
+ .raw = t->expected_raw,
+ .success = false
+ };
+
+ struct sockaddr_storage addr = addr_local();
+ int server = net_bound_socket(SOCK_STREAM, &addr, 0, 0);
+ ok(server >= 0, "%s, server, create socket", t->name);
+
+ int r = listen(server, LISTEN_BACKLOG);
+ is_int(0, r, "%s, server, start listening", t->name);
+
+ server_ctx_t server_ctx = { 0 };
+ r = server_start(&server_ctx, server, SOCK_STREAM, handler_dns, &handler_ctx);
+ ok(r, "%s, server, start handler", t->name);
+
+ addr = addr_from_socket(server);
+ int client = net_connected_socket(SOCK_STREAM, &addr, NULL, false);
+ ok(client >= 0, "%s, client, create connected socket", t->name);
+
+ r = sync_wait(client);
+ ok(r, "%s, client, wait for stream read", t->name);
+ t->send_callback(client);
+
+ close(client);
+ server_stop(&server_ctx);
+ close(server);
+
+ ok(handler_ctx.success, "%s, expected result", t->name);
+ }
+}
+
+static bool socket_is_blocking(int sock)
+{
+ return fcntl(sock, F_GETFL, O_NONBLOCK) == 0;
+}
+
+static void test_nonblocking_mode(int type)
+{
+ const char *name = socktype_name(type);
+ const struct sockaddr_storage addr = addr_local();
+
+ int client = net_unbound_socket(type, &addr);
+ ok(client >= 0, "%s: unbound, create", name);
+ ok(!socket_is_blocking(client), "%s: unbound, nonblocking mode", name);
+ close(client);
+
+ int server = net_bound_socket(type, &addr, 0, 0);
+ ok(server >= 0, "%s: bound, create", name);
+ ok(!socket_is_blocking(server), "%s: bound, nonblocking mode", name);
+
+ if (socktype_is_stream(type)) {
+ int r = listen(server, LISTEN_BACKLOG);
+ is_int(0, r, "%s: bound, start listening", name);
+ }
+
+ struct sockaddr_storage server_addr = addr_from_socket(server);
+ client = net_connected_socket(type, &server_addr, NULL, false);
+ ok(client >= 0, "%s: connected, create", name);
+ ok(!socket_is_blocking(client), "%s: connected, nonblocking mode", name);
+
+ close(client);
+ close(server);
+}
+
+static void test_nonblocking_accept(void)
+{
+ int r;
+
+ // create server
+
+ struct sockaddr_storage addr_server = addr_local();
+
+ int server = net_bound_socket(SOCK_STREAM, &addr_server, 0, 0);
+ ok(server >= 0, "server, create socket");
+
+ r = listen(server, LISTEN_BACKLOG);
+ is_int(0, r, "server, start listening");
+
+ addr_server = addr_from_socket(server);
+
+ // create client
+
+ int client = net_connected_socket(SOCK_STREAM, &addr_server, NULL, false);
+ ok(client >= 0, "client, create connected socket");
+
+ struct sockaddr_storage addr_client = addr_from_socket(client);
+
+ // accept connection
+
+ r = poll_read(server);
+ is_int(1, r, "server, pending connection");
+
+ struct sockaddr_storage addr_accepted = { 0 };
+ int accepted = net_accept(server, &addr_accepted);
+ ok(accepted >= 0, "server, accept connection");
+
+ ok(!socket_is_blocking(accepted), "accepted, nonblocking mode");
+
+ ok(sockaddr_cmp(&addr_client, &addr_accepted, false) == 0,
+ "accepted, correct address");
+
+ close(client);
+
+ // client reconnect
+
+ close(client);
+ client = net_connected_socket(SOCK_STREAM, &addr_server, NULL, false);
+ ok(client >= 0, "client, reconnect");
+
+ r = poll_read(server);
+ is_int(1, r, "server, pending connection");
+
+ accepted = net_accept(server, NULL);
+ ok(accepted >= 0, "server, accept connection (no remote address)");
+
+ ok(!socket_is_blocking(accepted), "accepted, nonblocking mode");
+
+ // cleanup
+
+ close(client);
+ close(server);
+}
+
+static void test_socket_types(void)
+{
+ struct sockaddr_storage addr = addr_local();
+
+ struct testcase {
+ const char *name;
+ int type;
+ bool is_stream;
+ };
+
+ const struct testcase testcases[] = {
+ { "UDP", SOCK_DGRAM, false },
+ { "TCP", SOCK_STREAM, true },
+ { NULL }
+ };
+
+ for (const struct testcase *t = testcases; t->name != NULL; t++) {
+ int sock = net_unbound_socket(t->type, &addr);
+ ok(sock >= 0, "%s, create socket", t->name);
+
+ is_int(t->type, net_socktype(sock), "%s, socket type", t->name);
+
+ ok(net_is_stream(sock) == t->is_stream, "%s, is stream", t->name);
+
+ close(sock);
+ }
+
+ is_int(AF_UNSPEC, net_socktype(-1), "invalid, socket type");
+ ok(!net_is_stream(-1), "invalid, is stream");
+}
+
+static void test_bind_multiple(void)
+{
+ const struct sockaddr_storage addr = addr_local();
+
+ // bind first socket
+
+ int sock_one = net_bound_socket(SOCK_DGRAM, &addr, NET_BIND_MULTIPLE, 0);
+ if (sock_one == KNOT_ENOTSUP) {
+ skip("not supported on this system");
+ return;
+ }
+ ok(sock_one >= 0, "bind first socket");
+
+ // bind second socket to the same address
+
+ const struct sockaddr_storage addr_one = addr_from_socket(sock_one);
+ int sock_two = net_bound_socket(SOCK_DGRAM, &addr_one, NET_BIND_MULTIPLE, 0);
+ ok(sock_two >= 0, "bind second socket");
+
+ // compare sockets
+
+ ok(sock_one != sock_two, "descriptors are different");
+
+ const struct sockaddr_storage addr_two = addr_from_socket(sock_two);
+ ok(sockaddr_cmp(&addr_one, &addr_two, false) == 0,
+ "addresses are the same");
+
+ close(sock_one);
+ close(sock_two);
+}
+
+static void signal_noop(int sig)
+{
+}
+
+int main(int argc, char *argv[])
+{
+ plan_lazy();
+
+ signal(SIGUSR1, signal_noop);
+
+ diag("nonblocking mode");
+ test_nonblocking_mode(SOCK_DGRAM);
+ test_nonblocking_mode(SOCK_STREAM);
+ test_nonblocking_accept();
+
+ diag("socket types");
+ test_socket_types();
+
+ diag("connected sockets");
+ test_connected(SOCK_DGRAM);
+ test_connected(SOCK_STREAM);
+
+ diag("unconnected sockets");
+ test_unconnected();
+
+ diag("refused connections");
+ test_refused();
+
+ diag("DNS messages over TCP");
+ test_dns_tcp();
+
+ diag("flag NET_BIND_MULTIPLE");
+ test_bind_multiple();
+
+ return 0;
+}