/* Copyright (C) 2023 CZ.NIC, z.s.p.o. This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . */ #include #include #include #include #include #include #include #include #include #include "libknot/errcode.h" #include "contrib/net.h" #include "contrib/sockaddr.h" const int TIMEOUT = 5000; const int TIMEOUT_SHORT = 500; /*! * \brief Get loopback socket address with unset port. */ static struct sockaddr_storage addr_local(void) { struct sockaddr_storage addr = { 0 }; struct sockaddr_in *addr4 = (struct sockaddr_in *)&addr; addr4->sin_family = AF_INET; addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK); return addr; } /*! * \brief Get address of a socket. */ static struct sockaddr_storage addr_from_socket(int sock) { struct sockaddr_storage addr = { 0 }; socklen_t len = sizeof(addr); int ret = getsockname(sock, (struct sockaddr *)&addr, &len); is_int(0, ret, "check getsockname return"); return addr; } static const char *socktype_name(int type) { switch (type) { case SOCK_STREAM: return "TCP"; case SOCK_DGRAM: return "UDP"; default: return "unknown"; } } static bool socktype_is_stream(int type) { return type == SOCK_STREAM; } /* -- mock server ---------------------------------------------------------- */ #define LISTEN_BACKLOG 5 struct server_ctx; typedef struct server_ctx server_ctx_t; typedef void (*server_cb)(int sock, void *data); /*! * \brief Server context. */ struct server_ctx { int sock; int type; bool terminate; server_cb handler; void *handler_data; pthread_t thr; pthread_mutex_t mx; }; static int poll_read(int sock) { struct pollfd pfd = { .fd = sock, .events = POLLIN }; return poll(&pfd, 1, TIMEOUT); } static void server_handle(server_ctx_t *ctx) { int remote = ctx->sock; assert(ctx->type == SOCK_STREAM || ctx->type == SOCK_DGRAM); if (socktype_is_stream(ctx->type)) { remote = accept(ctx->sock, 0, 0); if (remote < 0) { return; } } pthread_mutex_lock(&ctx->mx); server_cb handler = ctx->handler; pthread_mutex_unlock(&ctx->mx); handler(remote, ctx->handler_data); if (socktype_is_stream(ctx->type)) { close(remote); } } /*! * \brief Simple server. * * Terminated when a one-byte message is delivered. */ static void *server_main(void *_ctx) { server_ctx_t *ctx = _ctx; for (;;) { pthread_mutex_lock(&ctx->mx); bool terminate = ctx->terminate; pthread_mutex_unlock(&ctx->mx); if (terminate) { break; } int r = poll_read(ctx->sock); if (r == -1) { if (errno == EINTR) { continue; } else { break; } } else if (r == 0) { continue; } assert(r == 1); server_handle(ctx); } return NULL; } static bool server_start(server_ctx_t *ctx, int sock, int type, server_cb handler, void *handler_data) { memset(ctx, 0, sizeof(*ctx)); ctx->sock = sock; ctx->type = type; ctx->handler = handler; ctx->handler_data = handler_data; ctx->terminate = false; pthread_mutex_init(&ctx->mx, NULL); return (pthread_create(&ctx->thr, NULL, server_main, ctx) == 0); } static void server_stop(server_ctx_t *ctx) { pthread_mutex_lock(&ctx->mx); ctx->terminate = true; pthread_mutex_unlock(&ctx->mx); pthread_kill(ctx->thr, SIGUSR1); pthread_join(ctx->thr, NULL); } /* -- tests ---------------------------------------------------------------- */ static void handler_echo(int sock, void *_server) { server_ctx_t *server = _server; uint8_t buffer[16] = { 0 }; struct sockaddr_storage remote = { 0 }; struct sockaddr_storage *addr = NULL; if (!socktype_is_stream(server->type)) { addr = &remote; } int in = net_base_recv(sock, buffer, sizeof(buffer), addr, TIMEOUT); if (in <= 0) { return; } net_base_send(sock, buffer, in, addr, TIMEOUT); } static void test_connected_one(const struct sockaddr_storage *server_addr, const struct sockaddr_storage *source_addr, int type, const char *name, const char *addr_name) { int r; int client = net_connected_socket(type, server_addr, source_addr, false); ok(client >= 0, "%s, %s: client, create connected socket", name, addr_name); const uint8_t out[] = "test message"; const size_t out_len = sizeof(out); if (socktype_is_stream(type)) { r = net_stream_send(client, out, out_len, TIMEOUT); } else { r = net_dgram_send(client, out, out_len, NULL); } is_int(out_len, r, "%s, %s: client, send message", name, addr_name); r = net_is_connected(client); ok(r, "%s, %s: client, is connected", name, addr_name); uint8_t in[128] = { 0 }; if (socktype_is_stream(type)) { r = net_stream_recv(client, in, sizeof(in), TIMEOUT); } else { r = net_dgram_recv(client, in, sizeof(in), TIMEOUT); } is_int(out_len, r, "%s, %s: client, receive message length", name, addr_name); ok(memcmp(out, in, out_len) == 0, "%s, %s: client, receive message", name, addr_name); close(client); } static void test_connected(int type) { const char *name = socktype_name(type); const struct sockaddr_storage empty_addr = { 0 }; const struct sockaddr_storage local_addr = addr_local(); int r; // setup server int server = net_bound_socket(type, &local_addr, 0, 0); ok(server >= 0, "%s: server, create bound socket", name); if (socktype_is_stream(type)) { r = listen(server, LISTEN_BACKLOG); is_int(0, r, "%s: server, start listening", name); } server_ctx_t server_ctx = { 0 }; r = server_start(&server_ctx, server, type, handler_echo, &server_ctx); ok(r, "%s: server, start", name); const struct sockaddr_storage server_addr = addr_from_socket(server); // connected socket, send and receive test_connected_one(&server_addr, NULL, type, name, "NULL source"); test_connected_one(&server_addr, &empty_addr, type, name, "zero source"); test_connected_one(&server_addr, &local_addr, type, name, "valid source"); // cleanup server_stop(&server_ctx); close(server); } static void handler_noop(int sock, void *data) { } static void test_unconnected(void) { int r = 0; int sock = -1; const struct sockaddr_storage local = addr_local(); uint8_t buffer[] = { 'k', 'n', 'o', 't' }; ssize_t buffer_len = sizeof(buffer); // server int server = net_bound_socket(SOCK_DGRAM, &local, 0, 0); ok(server >= 0, "UDP, create server socket"); server_ctx_t server_ctx = { 0 }; r = server_start(&server_ctx, server, SOCK_DGRAM, handler_noop, NULL); ok(r, "UDP, start server"); // UDP sock = net_unbound_socket(SOCK_DGRAM, &local); ok(sock >= 0, "UDP, create unbound socket"); ok(!net_is_connected(sock), "UDP, is not connected"); r = net_dgram_send(sock, buffer, buffer_len, NULL); is_int(KNOT_ECONN, r, "UDP, send failure on unconnected socket"); r = net_dgram_recv(sock, buffer, buffer_len, TIMEOUT_SHORT); is_int(KNOT_ETIMEOUT, r, "UDP, receive timeout on unconnected socket"); struct sockaddr_storage server_addr = addr_from_socket(server); r = net_dgram_send(sock, buffer, buffer_len, &server_addr); is_int(buffer_len, r, "UDP, send on defined address"); close(sock); // TCP sock = net_unbound_socket(SOCK_STREAM, &local); ok(sock >= 0, "TCP, create unbound socket"); ok(!net_is_connected(sock), "TCP, is not connected"); #ifdef __linux__ const int expected = KNOT_ECONN; const char *expected_msg = "failure"; const int expected_timeout = TIMEOUT; #else const int expected = KNOT_ETIMEOUT; const char *expected_msg = "timeout"; const int expected_timeout = TIMEOUT_SHORT; #endif r = net_stream_send(sock, buffer, buffer_len, expected_timeout); is_int(expected, r, "TCP, send %s on unconnected socket", expected_msg); r = net_stream_recv(sock, buffer, sizeof(buffer), expected_timeout); is_int(expected, r, "TCP, receive %s on unconnected socket", expected_msg); close(sock); // server termination server_stop(&server_ctx); close(server); } static void test_refused(void) { int r = -1; struct sockaddr_storage addr = { 0 }; uint8_t buffer[1] = { 0 }; int server, client; // listening, not accepting addr = addr_local(); server = net_bound_socket(SOCK_STREAM, &addr, 0, 0); ok(server >= 0, "server, create server"); addr = addr_from_socket(server); r = listen(server, LISTEN_BACKLOG); is_int(0, r, "server, start listening"); client = net_connected_socket(SOCK_STREAM, &addr, NULL, false); ok(client >= 0, "client, connect"); r = net_stream_send(client, (uint8_t *)"", 1, TIMEOUT); is_int(1, r, "client, successful write"); r = net_stream_recv(client, buffer, sizeof(buffer), TIMEOUT_SHORT); is_int(KNOT_ETIMEOUT, r, "client, timeout on read"); close(client); // listening, closed immediately client = net_connected_socket(SOCK_STREAM, &addr, NULL, false); ok(client >= 0, "client, connect"); r = close(server); is_int(0, r, "server, close socket"); usleep(50000); r = net_stream_send(client, (uint8_t *)"", 1, TIMEOUT); is_int(KNOT_ECONN, r, "client, refused on write"); close(client); } struct dns_handler_ctx { const uint8_t *expected; int len; bool raw; bool success; }; static bool _sync(int remote, int send) { uint8_t buf[1] = { 0 }; int r; if (send) { r = net_stream_send(remote, buf, sizeof(buf), TIMEOUT); } else { r = net_stream_recv(remote, buf, sizeof(buf), TIMEOUT); } return r == sizeof(buf); } static bool sync_signal(int remote) { return _sync(remote, true); } static bool sync_wait(int remote) { return _sync(remote, false); } static void handler_dns(int sock, void *_ctx) { struct dns_handler_ctx *ctx = _ctx; uint8_t in[16] = { 0 }; int in_len = 0; sync_signal(sock); if (ctx->raw) { in_len = net_stream_recv(sock, in, sizeof(in), TIMEOUT); } else { in_len = net_dns_tcp_recv(sock, in, sizeof(in), TIMEOUT); } ctx->success = in_len == ctx->len && (ctx->len < 0 || memcmp(in, ctx->expected, in_len) == 0); } static void dns_send_hello(int sock) { net_dns_tcp_send(sock, (uint8_t *)"wimbgunts", 9, TIMEOUT, NULL); } static void dns_send_fragmented(int sock) { struct fragment { const uint8_t *data; size_t len; }; const struct fragment fragments[] = { { (uint8_t *)"\x00", 1 }, { (uint8_t *)"\x08""qu", 3 }, { (uint8_t *)"oopisk", 6 }, { NULL } }; for (const struct fragment *f = fragments; f->len > 0; f++) { net_stream_send(sock, f->data, f->len, TIMEOUT); } } static void dns_send_incomplete(int sock) { net_stream_send(sock, (uint8_t *)"\x00\x08""korm", 6, TIMEOUT); } static void dns_send_trailing(int sock) { net_stream_send(sock, (uint8_t *)"\x00\x05""bloitxx", 9, TIMEOUT); } static void test_dns_tcp(void) { struct testcase { const char *name; const uint8_t *expected; size_t expected_len; bool expected_raw; void (*send_callback)(int sock); }; const struct testcase testcases[] = { { "single DNS", (uint8_t *)"wimbgunts", 9, false, dns_send_hello }, { "single RAW", (uint8_t *)"\x00\x09""wimbgunts", 11, true, dns_send_hello }, { "fragmented", (uint8_t *)"quoopisk", 8, false, dns_send_fragmented }, { "incomplete", NULL, KNOT_ECONN, false, dns_send_incomplete }, { "trailing garbage", (uint8_t *)"bloit", 5, false, dns_send_trailing }, { NULL } }; for (const struct testcase *t = testcases; t->name != NULL; t++) { struct dns_handler_ctx handler_ctx = { .expected = t->expected, .len = t->expected_len, .raw = t->expected_raw, .success = false }; struct sockaddr_storage addr = addr_local(); int server = net_bound_socket(SOCK_STREAM, &addr, 0, 0); ok(server >= 0, "%s, server, create socket", t->name); int r = listen(server, LISTEN_BACKLOG); is_int(0, r, "%s, server, start listening", t->name); server_ctx_t server_ctx = { 0 }; r = server_start(&server_ctx, server, SOCK_STREAM, handler_dns, &handler_ctx); ok(r, "%s, server, start handler", t->name); addr = addr_from_socket(server); int client = net_connected_socket(SOCK_STREAM, &addr, NULL, false); ok(client >= 0, "%s, client, create connected socket", t->name); r = sync_wait(client); ok(r, "%s, client, wait for stream read", t->name); t->send_callback(client); close(client); server_stop(&server_ctx); close(server); ok(handler_ctx.success, "%s, expected result", t->name); } } static bool socket_is_blocking(int sock) { return fcntl(sock, F_GETFL, O_NONBLOCK) == 0; } static void test_nonblocking_mode(int type) { const char *name = socktype_name(type); const struct sockaddr_storage addr = addr_local(); int client = net_unbound_socket(type, &addr); ok(client >= 0, "%s: unbound, create", name); ok(!socket_is_blocking(client), "%s: unbound, nonblocking mode", name); close(client); int server = net_bound_socket(type, &addr, 0, 0); ok(server >= 0, "%s: bound, create", name); ok(!socket_is_blocking(server), "%s: bound, nonblocking mode", name); if (socktype_is_stream(type)) { int r = listen(server, LISTEN_BACKLOG); is_int(0, r, "%s: bound, start listening", name); } struct sockaddr_storage server_addr = addr_from_socket(server); client = net_connected_socket(type, &server_addr, NULL, false); ok(client >= 0, "%s: connected, create", name); ok(!socket_is_blocking(client), "%s: connected, nonblocking mode", name); close(client); close(server); } static void test_nonblocking_accept(void) { int r; // create server struct sockaddr_storage addr_server = addr_local(); int server = net_bound_socket(SOCK_STREAM, &addr_server, 0, 0); ok(server >= 0, "server, create socket"); r = listen(server, LISTEN_BACKLOG); is_int(0, r, "server, start listening"); addr_server = addr_from_socket(server); // create client int client = net_connected_socket(SOCK_STREAM, &addr_server, NULL, false); ok(client >= 0, "client, create connected socket"); struct sockaddr_storage addr_client = addr_from_socket(client); // accept connection r = poll_read(server); is_int(1, r, "server, pending connection"); struct sockaddr_storage addr_accepted = { 0 }; int accepted = net_accept(server, &addr_accepted); ok(accepted >= 0, "server, accept connection"); ok(!socket_is_blocking(accepted), "accepted, nonblocking mode"); ok(sockaddr_cmp(&addr_client, &addr_accepted, false) == 0, "accepted, correct address"); close(client); // client reconnect close(client); client = net_connected_socket(SOCK_STREAM, &addr_server, NULL, false); ok(client >= 0, "client, reconnect"); r = poll_read(server); is_int(1, r, "server, pending connection"); accepted = net_accept(server, NULL); ok(accepted >= 0, "server, accept connection (no remote address)"); ok(!socket_is_blocking(accepted), "accepted, nonblocking mode"); // cleanup close(client); close(server); } static void test_socket_types(void) { struct sockaddr_storage addr = addr_local(); struct testcase { const char *name; int type; bool is_stream; }; const struct testcase testcases[] = { { "UDP", SOCK_DGRAM, false }, { "TCP", SOCK_STREAM, true }, { NULL } }; for (const struct testcase *t = testcases; t->name != NULL; t++) { int sock = net_unbound_socket(t->type, &addr); ok(sock >= 0, "%s, create socket", t->name); is_int(t->type, net_socktype(sock), "%s, socket type", t->name); ok(net_is_stream(sock) == t->is_stream, "%s, is stream", t->name); close(sock); } is_int(AF_UNSPEC, net_socktype(-1), "invalid, socket type"); ok(!net_is_stream(-1), "invalid, is stream"); } static void test_bind_multiple(void) { const struct sockaddr_storage addr = addr_local(); // bind first socket int sock_one = net_bound_socket(SOCK_DGRAM, &addr, NET_BIND_MULTIPLE, 0); if (sock_one == KNOT_ENOTSUP) { skip("not supported on this system"); return; } ok(sock_one >= 0, "bind first socket"); // bind second socket to the same address const struct sockaddr_storage addr_one = addr_from_socket(sock_one); int sock_two = net_bound_socket(SOCK_DGRAM, &addr_one, NET_BIND_MULTIPLE, 0); ok(sock_two >= 0, "bind second socket"); // compare sockets ok(sock_one != sock_two, "descriptors are different"); const struct sockaddr_storage addr_two = addr_from_socket(sock_two); ok(sockaddr_cmp(&addr_one, &addr_two, false) == 0, "addresses are the same"); close(sock_one); close(sock_two); } static void signal_noop(int sig) { } int main(int argc, char *argv[]) { plan_lazy(); signal(SIGUSR1, signal_noop); diag("nonblocking mode"); test_nonblocking_mode(SOCK_DGRAM); test_nonblocking_mode(SOCK_STREAM); test_nonblocking_accept(); diag("socket types"); test_socket_types(); diag("connected sockets"); test_connected(SOCK_DGRAM); test_connected(SOCK_STREAM); diag("unconnected sockets"); test_unconnected(); diag("refused connections"); test_refused(); diag("DNS messages over TCP"); test_dns_tcp(); diag("flag NET_BIND_MULTIPLE"); test_bind_multiple(); return 0; }