summaryrefslogtreecommitdiffstats
path: root/src/utils/common/netio.c
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/utils/common/netio.c593
1 files changed, 593 insertions, 0 deletions
diff --git a/src/utils/common/netio.c b/src/utils/common/netio.c
new file mode 100644
index 0000000..9a27d33
--- /dev/null
+++ b/src/utils/common/netio.c
@@ -0,0 +1,593 @@
+/* Copyright (C) 2018 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <poll.h>
+#include <stdlib.h>
+#include <netinet/in.h>
+#include <sys/socket.h>
+
+#ifdef HAVE_SYS_UIO_H
+#include <sys/uio.h>
+#endif
+
+#include "utils/common/netio.h"
+#include "utils/common/msg.h"
+#include "utils/common/tls.h"
+#include "libknot/libknot.h"
+#include "contrib/sockaddr.h"
+
+srv_info_t *srv_info_create(const char *name, const char *service)
+{
+ if (name == NULL || service == NULL) {
+ DBG_NULL;
+ return NULL;
+ }
+
+ // Create output structure.
+ srv_info_t *server = calloc(1, sizeof(srv_info_t));
+
+ // Check output.
+ if (server == NULL) {
+ return NULL;
+ }
+
+ // Fill output.
+ server->name = strdup(name);
+ server->service = strdup(service);
+
+ if (server->name == NULL || server->service == NULL) {
+ srv_info_free(server);
+ return NULL;
+ }
+
+ // Return result.
+ return server;
+}
+
+void srv_info_free(srv_info_t *server)
+{
+ if (server == NULL) {
+ DBG_NULL;
+ return;
+ }
+
+ free(server->name);
+ free(server->service);
+ free(server);
+}
+
+int get_iptype(const ip_t ip)
+{
+ switch (ip) {
+ case IP_4:
+ return AF_INET;
+ case IP_6:
+ return AF_INET6;
+ default:
+ return AF_UNSPEC;
+ }
+}
+
+int get_socktype(const protocol_t proto, const uint16_t type)
+{
+ switch (proto) {
+ case PROTO_TCP:
+ return SOCK_STREAM;
+ case PROTO_UDP:
+ return SOCK_DGRAM;
+ default:
+ if (type == KNOT_RRTYPE_AXFR || type == KNOT_RRTYPE_IXFR) {
+ return SOCK_STREAM;
+ } else {
+ return SOCK_DGRAM;
+ }
+ }
+}
+
+const char *get_sockname(const int socktype)
+{
+ switch (socktype) {
+ case SOCK_STREAM:
+ return "TCP";
+ case SOCK_DGRAM:
+ return "UDP";
+ default:
+ return "UNKNOWN";
+ }
+}
+
+static int get_addr(const srv_info_t *server,
+ const int iptype,
+ const int socktype,
+ struct addrinfo **info)
+{
+ struct addrinfo hints;
+
+ // Set connection hints.
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = iptype;
+ hints.ai_socktype = socktype;
+
+ // Get connection parameters.
+ if (getaddrinfo(server->name, server->service, &hints, info) != 0) {
+ ERR("can't resolve address %s@%s\n",
+ server->name, server->service);
+ return -1;
+ }
+
+ return 0;
+}
+
+void get_addr_str(const struct sockaddr_storage *ss,
+ const int socktype,
+ char **dst)
+{
+ char addr_str[SOCKADDR_STRLEN] = {0};
+
+ // Get network address string and port number.
+ sockaddr_tostr(addr_str, sizeof(addr_str), (struct sockaddr *)ss);
+
+ // Calculate needed buffer size
+ const char *sock_name = get_sockname(socktype);
+ size_t buflen = strlen(addr_str) + strlen(sock_name) + 3 /* () */;
+
+ // Free previous string if any and write result
+ free(*dst);
+ *dst = malloc(buflen);
+ if (*dst != NULL) {
+ int ret = snprintf(*dst, buflen, "%s(%s)", addr_str, sock_name);
+ if (ret <= 0 || ret >= buflen) {
+ **dst = '\0';
+ }
+ }
+}
+
+int net_init(const srv_info_t *local,
+ const srv_info_t *remote,
+ const int iptype,
+ const int socktype,
+ const int wait,
+ const net_flags_t flags,
+ const tls_params_t *tls_params,
+ net_t *net)
+{
+ if (remote == NULL || net == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ // Clean network structure.
+ memset(net, 0, sizeof(*net));
+
+ // Get remote address list.
+ if (get_addr(remote, iptype, socktype, &net->remote_info) != 0) {
+ net_clean(net);
+ return KNOT_NET_EADDR;
+ }
+
+ // Set current remote address.
+ net->srv = net->remote_info;
+
+ // Get local address if specified.
+ if (local != NULL) {
+ if (get_addr(local, iptype, socktype, &net->local_info) != 0) {
+ net_clean(net);
+ return KNOT_NET_EADDR;
+ }
+ }
+
+ // Store network parameters.
+ net->iptype = iptype;
+ net->socktype = socktype;
+ net->wait = wait;
+ net->local = local;
+ net->remote = remote;
+ net->flags = flags;
+
+ // Prepare for TLS.
+ if (tls_params != NULL && tls_params->enable) {
+ int ret = tls_ctx_init(&net->tls, tls_params, net->wait);
+ if (ret != KNOT_EOK) {
+ net_clean(net);
+ return ret;
+ }
+ }
+
+ return KNOT_EOK;
+}
+
+/**
+ * Connect with TCP Fast Open.
+ */
+static int fastopen_connect(int sockfd, const struct addrinfo *srv)
+{
+#if __APPLE__
+ // connection is performed lazily when first data are sent
+ struct sa_endpoints ep = {0};
+ ep.sae_dstaddr = srv->ai_addr;
+ ep.sae_dstaddrlen = srv->ai_addrlen;
+ int flags = CONNECT_DATA_IDEMPOTENT|CONNECT_RESUME_ON_READ_WRITE;
+
+ return connectx(sockfd, &ep, SAE_ASSOCID_ANY, flags, NULL, 0, NULL, NULL);
+#elif defined(MSG_FASTOPEN) // Linux with RFC 7413
+ // connect() will be called implicitly with sendto(), sendmsg()
+ return 0;
+#else
+ errno = ENOTSUP;
+ return -1;
+#endif
+}
+
+/**
+ * Sends data with TCP Fast Open.
+ */
+static int fastopen_send(int sockfd, const struct msghdr *msg, int timeout)
+{
+#if __APPLE__
+ return sendmsg(sockfd, msg, 0);
+#elif defined(MSG_FASTOPEN)
+ int ret = sendmsg(sockfd, msg, MSG_FASTOPEN);
+ if (ret == -1 && errno == EINPROGRESS) {
+ struct pollfd pfd = {
+ .fd = sockfd,
+ .events = POLLOUT,
+ .revents = 0,
+ };
+ if (poll(&pfd, 1, 1000 * timeout) != 1) {
+ errno = ETIMEDOUT;
+ return -1;
+ }
+ ret = sendmsg(sockfd, msg, 0);
+ }
+ return ret;
+#else
+ errno = ENOTSUP;
+ return -1;
+#endif
+}
+
+int net_connect(net_t *net)
+{
+ if (net == NULL || net->srv == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ // Set remote information string.
+ get_addr_str((struct sockaddr_storage *)net->srv->ai_addr,
+ net->socktype, &net->remote_str);
+
+ // Create socket.
+ int sockfd = socket(net->srv->ai_family, net->socktype, 0);
+ if (sockfd == -1) {
+ WARN("can't create socket for %s\n", net->remote_str);
+ return KNOT_NET_ESOCKET;
+ }
+
+ // Initialize poll descriptor structure.
+ struct pollfd pfd = {
+ .fd = sockfd,
+ .events = POLLOUT,
+ .revents = 0,
+ };
+
+ // Set non-blocking socket.
+ if (fcntl(sockfd, F_SETFL, O_NONBLOCK) == -1) {
+ WARN("can't set non-blocking socket for %s\n", net->remote_str);
+ return KNOT_NET_ESOCKET;
+ }
+
+ // Bind address to socket if specified.
+ if (net->local_info != NULL) {
+ if (bind(sockfd, net->local_info->ai_addr,
+ net->local_info->ai_addrlen) == -1) {
+ WARN("can't assign address %s\n", net->local->name);
+ return KNOT_NET_ESOCKET;
+ }
+ }
+
+ if (net->socktype == SOCK_STREAM) {
+ int cs, err, ret = 0;
+ socklen_t err_len = sizeof(err);
+ bool fastopen = net->flags & NET_FLAGS_FASTOPEN;
+
+ // Connect using socket.
+ if (fastopen) {
+ ret = fastopen_connect(sockfd, net->srv);
+ } else {
+ ret = connect(sockfd, net->srv->ai_addr, net->srv->ai_addrlen);
+ }
+ if (ret != 0 && errno != EINPROGRESS) {
+ WARN("can't connect to %s\n", net->remote_str);
+ close(sockfd);
+ return KNOT_NET_ECONNECT;
+ }
+
+ // Check for connection timeout.
+ if (!fastopen && poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("connection timeout for %s\n", net->remote_str);
+ close(sockfd);
+ return KNOT_NET_ECONNECT;
+ }
+
+ // Check if NB socket is writeable.
+ cs = getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &err, &err_len);
+ if (cs < 0 || err != 0) {
+ WARN("can't connect to %s\n", net->remote_str);
+ close(sockfd);
+ return KNOT_NET_ECONNECT;
+ }
+
+ // Establish TLS connection.
+ if (net->tls.params != NULL) {
+ int ret = tls_ctx_connect(&net->tls, sockfd, net->tls.params->sni);
+ if (ret != KNOT_EOK) {
+ close(sockfd);
+ return ret;
+ }
+ }
+ }
+
+ // Store socket descriptor.
+ net->sockfd = sockfd;
+
+ return KNOT_EOK;
+}
+
+int net_set_local_info(net_t *net)
+{
+ if (net == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ if (net->local_info != NULL) {
+ free(net->local_info->ai_addr);
+ freeaddrinfo(net->local_info);
+ }
+
+ socklen_t local_addr_len = sizeof(struct sockaddr_storage);
+ struct sockaddr_storage *local_addr = calloc(1, local_addr_len);
+
+ if (getsockname(net->sockfd, (struct sockaddr *)local_addr,
+ &local_addr_len) == -1) {
+ WARN("can't get local address\n");
+ free(local_addr);
+ return KNOT_NET_ESOCKET;
+ }
+
+ net->local_info = calloc(1, sizeof(struct addrinfo));
+ net->local_info->ai_family = net->srv->ai_family;
+ net->local_info->ai_socktype = net->srv->ai_socktype;
+ net->local_info->ai_protocol = net->srv->ai_protocol;
+ net->local_info->ai_addrlen = local_addr_len;
+ net->local_info->ai_addr = (struct sockaddr *)local_addr;
+
+ get_addr_str((struct sockaddr_storage *)net->local_info->ai_addr,
+ net->socktype, &net->local_str);
+
+ return KNOT_EOK;
+}
+
+int net_send(const net_t *net, const uint8_t *buf, const size_t buf_len)
+{
+ if (net == NULL || buf == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ // Send data over UDP.
+ if (net->socktype == SOCK_DGRAM) {
+ if (sendto(net->sockfd, buf, buf_len, 0, net->srv->ai_addr,
+ net->srv->ai_addrlen) != (ssize_t)buf_len) {
+ WARN("can't send query to %s\n", net->remote_str);
+ return KNOT_NET_ESEND;
+ }
+ // Send data over TLS.
+ } else if (net->tls.params != NULL) {
+ int ret = tls_ctx_send((tls_ctx_t *)&net->tls, buf, buf_len);
+ if (ret != KNOT_EOK) {
+ WARN("can't send query to %s\n", net->remote_str);
+ return KNOT_NET_ESEND;
+ }
+ // Send data over TCP.
+ } else {
+ bool fastopen = net->flags & NET_FLAGS_FASTOPEN;
+
+ // Leading packet length bytes.
+ uint16_t pktsize = htons(buf_len);
+
+ struct iovec iov[2];
+ iov[0].iov_base = &pktsize;
+ iov[0].iov_len = sizeof(pktsize);
+ iov[1].iov_base = (uint8_t *)buf;
+ iov[1].iov_len = buf_len;
+
+ // Compute packet total length.
+ ssize_t total = iov[0].iov_len + iov[1].iov_len;
+
+ struct msghdr msg = {0};
+ msg.msg_iov = iov;
+ msg.msg_iovlen = sizeof(iov) / sizeof(*iov);
+ msg.msg_name = net->srv->ai_addr;
+ msg.msg_namelen = net->srv->ai_addrlen;
+
+ int ret = 0;
+ if (fastopen) {
+ ret = fastopen_send(net->sockfd, &msg, net->wait);
+ } else {
+ ret = sendmsg(net->sockfd, &msg, 0);
+ }
+ if (ret != total) {
+ WARN("can't send query to %s\n", net->remote_str);
+ return KNOT_NET_ESEND;
+ }
+ }
+
+ return KNOT_EOK;
+}
+
+int net_receive(const net_t *net, uint8_t *buf, const size_t buf_len)
+{
+ if (net == NULL || buf == NULL) {
+ DBG_NULL;
+ return KNOT_EINVAL;
+ }
+
+ // Initialize poll descriptor structure.
+ struct pollfd pfd = {
+ .fd = net->sockfd,
+ .events = POLLIN,
+ .revents = 0,
+ };
+
+ // Receive data over UDP.
+ if (net->socktype == SOCK_DGRAM) {
+ struct sockaddr_storage from;
+ memset(&from, '\0', sizeof(from));
+
+ // Receive replies unless correct reply or timeout.
+ while (true) {
+ socklen_t from_len = sizeof(from);
+
+ // Wait for datagram data.
+ if (poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("response timeout for %s\n",
+ net->remote_str);
+ return KNOT_NET_ETIMEOUT;
+ }
+
+ // Receive whole UDP datagram.
+ ssize_t ret = recvfrom(net->sockfd, buf, buf_len, 0,
+ (struct sockaddr *)&from, &from_len);
+ if (ret <= 0) {
+ WARN("can't receive reply from %s\n",
+ net->remote_str);
+ return KNOT_NET_ERECV;
+ }
+
+ // Compare reply address with the remote one.
+ if (from_len > sizeof(from) ||
+ memcmp(&from, net->srv->ai_addr, from_len) != 0) {
+ char *src = NULL;
+ get_addr_str(&from, net->socktype, &src);
+ WARN("unexpected reply source %s\n", src);
+ free(src);
+ continue;
+ }
+
+ return ret;
+ }
+ // Receive data over TLS.
+ } else if (net->tls.params != NULL) {
+ int ret = tls_ctx_receive((tls_ctx_t *)&net->tls, buf, buf_len);
+ if (ret < 0) {
+ WARN("can't receive reply from %s\n", net->remote_str);
+ return KNOT_NET_ERECV;
+ }
+
+ return ret;
+ // Receive data over TCP.
+ } else {
+ uint32_t total = 0;
+
+ uint16_t msg_len = 0;
+ // Receive TCP message header.
+ while (total < sizeof(msg_len)) {
+ if (poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("response timeout for %s\n",
+ net->remote_str);
+ return KNOT_NET_ETIMEOUT;
+ }
+
+ // Receive piece of message.
+ ssize_t ret = recv(net->sockfd, (uint8_t *)&msg_len + total,
+ sizeof(msg_len) - total, 0);
+ if (ret <= 0) {
+ WARN("can't receive reply from %s\n",
+ net->remote_str);
+ return KNOT_NET_ERECV;
+ }
+ total += ret;
+ }
+
+ // Convert number to host format.
+ msg_len = ntohs(msg_len);
+ if (msg_len > buf_len) {
+ return KNOT_ESPACE;
+ }
+
+ total = 0;
+
+ // Receive whole answer message by parts.
+ while (total < msg_len) {
+ if (poll(&pfd, 1, 1000 * net->wait) != 1) {
+ WARN("response timeout for %s\n",
+ net->remote_str);
+ return KNOT_NET_ETIMEOUT;
+ }
+
+ // Receive piece of message.
+ ssize_t ret = recv(net->sockfd, buf + total, msg_len - total, 0);
+ if (ret <= 0) {
+ WARN("can't receive reply from %s\n",
+ net->remote_str);
+ return KNOT_NET_ERECV;
+ }
+ total += ret;
+ }
+
+ return total;
+ }
+
+ return KNOT_NET_ERECV;
+}
+
+void net_close(net_t *net)
+{
+ if (net == NULL) {
+ DBG_NULL;
+ return;
+ }
+
+ tls_ctx_close(&net->tls);
+ close(net->sockfd);
+ net->sockfd = -1;
+}
+
+void net_clean(net_t *net)
+{
+ if (net == NULL) {
+ DBG_NULL;
+ return;
+ }
+
+ free(net->local_str);
+ free(net->remote_str);
+
+ if (net->local_info != NULL) {
+ freeaddrinfo(net->local_info);
+ }
+
+ if (net->remote_info != NULL) {
+ freeaddrinfo(net->remote_info);
+ }
+
+ tls_ctx_deinit(&net->tls);
+}