diff options
Diffstat (limited to '')
-rw-r--r-- | src/socket.c | 693 |
1 files changed, 693 insertions, 0 deletions
diff --git a/src/socket.c b/src/socket.c new file mode 100644 index 0000000..36ac292 --- /dev/null +++ b/src/socket.c @@ -0,0 +1,693 @@ +/* + * Copyright (C) 2000-2016 Free Software Foundation, Inc. + * Copyright (C) 2015-2016 Red Hat, Inc. + * + * This file is part of GnuTLS. + * + * GnuTLS is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * GnuTLS is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +#include <config.h> + +#if HAVE_SYS_SOCKET_H +#include <sys/socket.h> +#elif HAVE_WS2TCPIP_H +#include <ws2tcpip.h> +#endif +#include <netdb.h> +#include <string.h> +#include <errno.h> +#include <sys/select.h> +#include <sys/types.h> +#include <stdio.h> +#include <stdlib.h> +#include <unistd.h> +#include <arpa/inet.h> +#include <socket.h> +#include <c-ctype.h> +#include "sockets.h" +#include "common.h" + +#ifdef _WIN32 +# undef endservent +# define endservent() +#endif + +#define MAX_BUF 4096 + +/* Functions to manipulate sockets + */ + +ssize_t +socket_recv(const socket_st * socket, void *buffer, int buffer_size) +{ + int ret; + + if (socket->secure) { + do { + ret = + gnutls_record_recv(socket->session, buffer, + buffer_size); + if (ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED) + gnutls_heartbeat_pong(socket->session, 0); + } + while (ret == GNUTLS_E_INTERRUPTED + || ret == GNUTLS_E_HEARTBEAT_PING_RECEIVED); + + } else + do { + ret = recv(socket->fd, buffer, buffer_size, 0); + } + while (ret == -1 && errno == EINTR); + + return ret; +} + +ssize_t +socket_recv_timeout(const socket_st * socket, void *buffer, int buffer_size, unsigned ms) +{ + int ret; + + if (socket->secure) + gnutls_record_set_timeout(socket->session, ms); + ret = socket_recv(socket, buffer, buffer_size); + + if (socket->secure) + gnutls_record_set_timeout(socket->session, 0); + + return ret; +} + +ssize_t +socket_send(const socket_st * socket, const void *buffer, int buffer_size) +{ + return socket_send_range(socket, buffer, buffer_size, NULL); +} + + +ssize_t +socket_send_range(const socket_st * socket, const void *buffer, + int buffer_size, gnutls_range_st * range) +{ + int ret; + + if (socket->secure) + do { + if (range == NULL) + ret = + gnutls_record_send(socket->session, + buffer, + buffer_size); + else + ret = + gnutls_record_send_range(socket-> + session, + buffer, + buffer_size, + range); + } + while (ret == GNUTLS_E_AGAIN + || ret == GNUTLS_E_INTERRUPTED); + else + do { + ret = send(socket->fd, buffer, buffer_size, 0); + } + while (ret == -1 && errno == EINTR); + + if (ret > 0 && ret != buffer_size && socket->verbose) + fprintf(stderr, + "*** Only sent %d bytes instead of %d.\n", ret, + buffer_size); + + return ret; +} + +static +ssize_t send_line(socket_st * socket, const char *txt) +{ + int len = strlen(txt); + int ret; + + if (socket->verbose) + fprintf(stderr, "starttls: sending: %s\n", txt); + + ret = send(socket->fd, txt, len, 0); + + if (ret == -1) { + fprintf(stderr, "error sending \"%s\"\n", txt); + exit(2); + } + + return ret; +} + +static +ssize_t wait_for_text(socket_st * socket, const char *txt, unsigned txt_size) +{ + char buf[1024]; + char *pbuf, *p; + int ret; + fd_set read_fds; + struct timeval tv; + size_t left, got; + + if (txt_size > sizeof(buf)) + abort(); + + if (socket->verbose && txt != NULL) + fprintf(stderr, "starttls: waiting for: \"%.*s\"\n", txt_size, txt); + + pbuf = buf; + left = sizeof(buf)-1; + got = 0; + + do { + FD_ZERO(&read_fds); + FD_SET(socket->fd, &read_fds); + tv.tv_sec = 10; + tv.tv_usec = 0; + ret = select(socket->fd + 1, &read_fds, NULL, NULL, &tv); + if (ret > 0) + ret = recv(socket->fd, pbuf, left, 0); + if (ret == -1) { + fprintf(stderr, "error receiving '%s': %s\n", txt, strerror(errno)); + exit(2); + } else if (ret == 0) { + fprintf(stderr, "error receiving '%s': Timeout\n", txt); + exit(2); + } + pbuf[ret] = 0; + + if (txt == NULL) + break; + + if (socket->verbose) + fprintf(stderr, "starttls: received: %s\n", pbuf); + + pbuf += ret; + left -= ret; + got += ret; + + + /* check for text after a newline in buffer */ + if (got > txt_size) { + p = memmem(buf, got, txt, txt_size); + if (p != NULL && p != buf) { + p--; + if (*p == '\n' || *p == '\r' || (*txt == '<' && *p == '>')) // XMPP is not line oriented, uses XML format + break; + } + } + } while(got < txt_size || strncmp(buf, txt, txt_size) != 0); + + return got; +} + +static void +socket_starttls(socket_st * socket) +{ + char buf[512]; + + if (socket->secure) + return; + + if (socket->app_proto == NULL || strcasecmp(socket->app_proto, "https") == 0) + return; + + if (strcasecmp(socket->app_proto, "smtp") == 0 || strcasecmp(socket->app_proto, "submission") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating SMTP STARTTLS\n"); + + wait_for_text(socket, "220 ", 4); + snprintf(buf, sizeof(buf), "EHLO %s\r\n", socket->hostname); + send_line(socket, buf); + wait_for_text(socket, "250 ", 4); + send_line(socket, "STARTTLS\r\n"); + wait_for_text(socket, "220 ", 4); + } else if (strcasecmp(socket->app_proto, "imap") == 0 || strcasecmp(socket->app_proto, "imap2") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating IMAP STARTTLS\n"); + + send_line(socket, "a CAPABILITY\r\n"); + wait_for_text(socket, "a OK", 4); + send_line(socket, "a STARTTLS\r\n"); + wait_for_text(socket, "a OK", 4); + } else if (strcasecmp(socket->app_proto, "xmpp") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating XMPP STARTTLS\n"); + + snprintf(buf, sizeof(buf), "<stream:stream xmlns:stream='http://etherx.jabber.org/streams' xmlns='jabber:client' to='%s' version='1.0'>\n", socket->hostname); + send_line(socket, buf); + wait_for_text(socket, "<?", 2); + send_line(socket, "<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"); + wait_for_text(socket, "<proceed", 8); + } else if (strcasecmp(socket->app_proto, "ldap") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating LDAP STARTTLS\n"); +#define LDAP_STR "\x30\x1d\x02\x01\x01\x77\x18\x80\x16\x31\x2e\x33\x2e\x36\x2e\x31\x2e\x34\x2e\x31\x2e\x31\x34\x36\x36\x2e\x32\x30\x30\x33\x37" + send(socket->fd, LDAP_STR, sizeof(LDAP_STR)-1, 0); + wait_for_text(socket, NULL, 0); + } else if (strcasecmp(socket->app_proto, "ftp") == 0 || strcasecmp(socket->app_proto, "ftps") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating FTP STARTTLS\n"); + + send_line(socket, "FEAT\r\n"); + wait_for_text(socket, "211 ", 4); + send_line(socket, "AUTH TLS\r\n"); + wait_for_text(socket, "234", 3); + } else if (strcasecmp(socket->app_proto, "lmtp") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating LMTP STARTTLS\n"); + + wait_for_text(socket, "220 ", 4); + snprintf(buf, sizeof(buf), "LHLO %s\r\n", socket->hostname); + send_line(socket, buf); + wait_for_text(socket, "250 ", 4); + send_line(socket, "STARTTLS\r\n"); + wait_for_text(socket, "220 ", 4); + } else if (strcasecmp(socket->app_proto, "pop3") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating POP3 STARTTLS\n"); + + wait_for_text(socket, "+OK", 3); + send_line(socket, "STLS\r\n"); + wait_for_text(socket, "+OK", 3); + } else if (strcasecmp(socket->app_proto, "nntp") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating NNTP STARTTLS\n"); + + wait_for_text(socket, "200 ", 4); + send_line(socket, "STARTTLS\r\n"); + wait_for_text(socket, "382 ", 4); + } else if (strcasecmp(socket->app_proto, "sieve") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating Sieve STARTTLS\n"); + + wait_for_text(socket, "OK ", 3); + send_line(socket, "STARTTLS\r\n"); + wait_for_text(socket, "OK ", 3); + } else if (strcasecmp(socket->app_proto, "postgres") == 0 || strcasecmp(socket->app_proto, "postgresql") == 0) { + if (socket->verbose) + log_msg(stdout, "Negotiating PostgreSQL STARTTLS\n"); + +#define POSTGRES_STR "\x00\x00\x00\x08\x04\xD2\x16\x2F" + send(socket->fd, POSTGRES_STR, sizeof(POSTGRES_STR)-1, 0); + wait_for_text(socket, NULL, 0); + } else { + if (!c_isdigit(socket->app_proto[0])) { + static int warned = 0; + if (warned == 0) { + fprintf(stderr, "unknown protocol '%s'\n", socket->app_proto); + warned = 1; + } + } + } + + return; +} + +#define CANON_SERVICE(app_proto) \ + if (strcasecmp(app_proto, "xmpp") == 0) \ + app_proto = "xmpp-server"; \ + +int +starttls_proto_to_port(const char *app_proto) +{ + struct servent *s; + + CANON_SERVICE(app_proto); + + s = getservbyname(app_proto, NULL); + if (s != NULL) { + return ntohs(s->s_port); + } + + endservent(); + + return 443; +} + +const char *starttls_proto_to_service(const char *app_proto) +{ + struct servent *s; + + CANON_SERVICE(app_proto); + + s = getservbyname(app_proto, NULL); + if (s != NULL) { + return s->s_name; + } + endservent(); + + return "443"; +} + +void socket_bye(socket_st * socket, unsigned polite) +{ + int ret; + + if (socket->secure && socket->session) { + if (polite) { + do + ret = gnutls_bye(socket->session, GNUTLS_SHUT_WR); + while (ret == GNUTLS_E_INTERRUPTED + || ret == GNUTLS_E_AGAIN); + if (socket->verbose && ret < 0) + fprintf(stderr, "*** gnutls_bye() error: %s\n", + gnutls_strerror(ret)); + } + } + + if (socket->session) { + gnutls_deinit(socket->session); + socket->session = NULL; + } + + freeaddrinfo(socket->addr_info); + socket->addr_info = socket->ptr = NULL; + socket->connect_addrlen = 0; + + free(socket->ip); + free(socket->hostname); + free(socket->service); + + shutdown(socket->fd, SHUT_RDWR); /* no more receptions */ + close(socket->fd); + + gnutls_free(socket->rdata.data); + socket->rdata.data = NULL; + + if (socket->server_trace) + fclose(socket->server_trace); + if (socket->client_trace) + fclose(socket->client_trace); + + socket->fd = -1; + socket->secure = 0; +} + +/* Handle host:port format. + */ +void canonicalize_host(char *hostname, char *service, unsigned service_size) +{ + char *p; + + if ((p = strchr(hostname, ':'))) { + unsigned char buf[64]; + + if (inet_pton(AF_INET6, hostname, buf) == 1) + return; + + *p = 0; + + if (service && service_size) + snprintf(service, service_size, "%s", p+1); + } else + p = hostname + strlen(hostname); + + if (p > hostname && p[-1] == '.') + p[-1] = 0; // remove trailing dot on FQDN +} + +static ssize_t +wrap_pull(gnutls_transport_ptr_t ptr, void *data, size_t len) +{ + socket_st *hd = ptr; + ssize_t r; + + r = recv(hd->fd, data, len, 0); + if (r > 0 && hd->server_trace) { + fwrite(data, 1, r, hd->server_trace); + } + return r; +} + +static ssize_t +wrap_push(gnutls_transport_ptr_t ptr, const void *data, size_t len) +{ + socket_st *hd = ptr; + + if (hd->client_trace) { + fwrite(data, 1, len, hd->client_trace); + } + + return send(hd->fd, data, len, 0); +} + +/* inline is used to avoid a gcc warning if used in mini-eagain */ +inline static int wrap_pull_timeout_func(gnutls_transport_ptr_t ptr, + unsigned int ms) +{ + socket_st *hd = ptr; + + return gnutls_system_recv_timeout((gnutls_transport_ptr_t)(long)hd->fd, ms); +} + + +void +socket_open2(socket_st * hd, const char *hostname, const char *service, + const char *app_proto, int flags, const char *msg, gnutls_datum_t *rdata, gnutls_datum_t *edata, + FILE *server_trace, FILE *client_trace) +{ + struct addrinfo hints, *res, *ptr; + int sd, err = 0; + int udp = flags & SOCKET_FLAG_UDP; + int ret; + int fastopen = flags & SOCKET_FLAG_FASTOPEN; + char buffer[MAX_BUF + 1]; + char portname[16] = { 0 }; + gnutls_datum_t idna; + char *a_hostname; + + memset(hd, 0, sizeof(*hd)); + + if (flags & SOCKET_FLAG_VERBOSE) + hd->verbose = 1; + + if (rdata) { + hd->rdata.data = rdata->data; + hd->rdata.size = rdata->size; + } + + if (edata) { + hd->edata.data = edata->data; + hd->edata.size = edata->size; + } + + ret = gnutls_idna_map(hostname, strlen(hostname), &idna, 0); + if (ret < 0) { + fprintf(stderr, "Cannot convert %s to IDNA: %s\n", hostname, gnutls_strerror(ret)); + exit(1); + } + + hd->hostname = strdup(hostname); + a_hostname = (char*)idna.data; + + if (msg != NULL) + log_msg(stdout, "Resolving '%s:%s'...\n", a_hostname, service); + + /* get server name */ + memset(&hints, 0, sizeof(hints)); + hints.ai_socktype = udp ? SOCK_DGRAM : SOCK_STREAM; + if ((err = getaddrinfo(a_hostname, service, &hints, &res))) { + fprintf(stderr, "Cannot resolve %s:%s: %s\n", hostname, + service, gai_strerror(err)); + exit(1); + } + + sd = -1; + for (ptr = res; ptr != NULL; ptr = ptr->ai_next) { + sd = socket(ptr->ai_family, ptr->ai_socktype, + ptr->ai_protocol); + if (sd == -1) + continue; + + if ((err = + getnameinfo(ptr->ai_addr, ptr->ai_addrlen, buffer, + MAX_BUF, portname, sizeof(portname), + NI_NUMERICHOST | NI_NUMERICSERV)) != 0) { + fprintf(stderr, "getnameinfo(): %s\n", + gai_strerror(err)); + close(sd); + continue; + } + + if (hints.ai_socktype == SOCK_DGRAM) { +#if defined(IP_DONTFRAG) + int yes = 1; + if (setsockopt(sd, IPPROTO_IP, IP_DONTFRAG, + (const void *) &yes, + sizeof(yes)) < 0) + perror("setsockopt(IP_DF) failed"); +#elif defined(IP_MTU_DISCOVER) + int yes = IP_PMTUDISC_DO; + if (setsockopt(sd, IPPROTO_IP, IP_MTU_DISCOVER, + (const void *) &yes, + sizeof(yes)) < 0) + perror("setsockopt(IP_DF) failed"); +#endif + } + + if (fastopen && ptr->ai_socktype == SOCK_STREAM + && (ptr->ai_family == AF_INET || ptr->ai_family == AF_INET6)) { + memcpy(&hd->connect_addr, ptr->ai_addr, ptr->ai_addrlen); + hd->connect_addrlen = ptr->ai_addrlen; + + if (msg) + log_msg(stdout, "%s '%s:%s' (TFO)...\n", msg, buffer, portname); + + } else { + if (msg) + log_msg(stdout, "%s '%s:%s'...\n", msg, buffer, portname); + + if ((err = connect(sd, ptr->ai_addr, ptr->ai_addrlen)) < 0) { + close(sd); + continue; + } + } + + hd->fd = sd; + if (flags & SOCKET_FLAG_STARTTLS) { + hd->app_proto = app_proto; + socket_starttls(hd); + hd->app_proto = NULL; + } + + if (!(flags & SOCKET_FLAG_SKIP_INIT)) { + hd->session = init_tls_session(hostname); + if (hd->session == NULL) { + fprintf(stderr, "error initializing session\n"); + close(sd); + exit(1); + } + } + + if (hd->session) { + if (hd->edata.data) { + ret = gnutls_record_send_early_data(hd->session, hd->edata.data, hd->edata.size); + if (ret < 0) { + fprintf(stderr, "error sending early data\n"); + close(sd); + exit(1); + } + } + if (hd->rdata.data) { + gnutls_session_set_data(hd->session, hd->rdata.data, hd->rdata.size); + } + + if (client_trace || server_trace) { + hd->server_trace = server_trace; + hd->client_trace = client_trace; + gnutls_transport_set_push_function(hd->session, wrap_push); + gnutls_transport_set_pull_function(hd->session, wrap_pull); + gnutls_transport_set_pull_timeout_function(hd->session, wrap_pull_timeout_func); + gnutls_transport_set_ptr(hd->session, hd); + } else { + gnutls_transport_set_int(hd->session, hd->fd); + } + } + + if (!(flags & SOCKET_FLAG_RAW) && !(flags & SOCKET_FLAG_SKIP_INIT)) { + err = do_handshake(hd); + if (err == GNUTLS_E_PUSH_ERROR) { /* failed connecting */ + gnutls_deinit(hd->session); + hd->session = NULL; + close(sd); + continue; + } + else if (err < 0) { + if (!(flags & SOCKET_FLAG_DONT_PRINT_ERRORS)) + fprintf(stderr, "*** handshake has failed: %s\n", gnutls_strerror(err)); + close(sd); + exit(1); + } + } + + break; + } + + if (err != 0) { + int e = errno; + fprintf(stderr, "Could not connect to %s:%s: %s\n", + buffer, portname, strerror(e)); + exit(1); + } + + if (sd == -1) { + fprintf(stderr, "Could not find a supported socket\n"); + exit(1); + } + + if ((flags & SOCKET_FLAG_RAW) || (flags & SOCKET_FLAG_SKIP_INIT)) + hd->secure = 0; + else + hd->secure = 1; + + hd->fd = sd; + hd->ip = strdup(buffer); + hd->service = strdup(portname); + hd->ptr = ptr; + hd->addr_info = res; + gnutls_free(hd->rdata.data); + hd->rdata.data = NULL; + gnutls_free(hd->edata.data); + hd->edata.data = NULL; + gnutls_free(idna.data); + return; +} + +/* converts a textual service or port to + * a service. + */ +const char *port_to_service(const char *sport, const char *proto) +{ + unsigned int port; + struct servent *sr; + + if (!c_isdigit(sport[0])) + return sport; + + port = atoi(sport); + if (port == 0) + return sport; + + port = htons(port); + + sr = getservbyport(port, proto); + if (sr == NULL) { + fprintf(stderr, + "Warning: getservbyport(%s) failed. Using port number as service.\n", sport); + return sport; + } + + return sr->s_name; +} + +int service_to_port(const char *service, const char *proto) +{ + unsigned int port; + struct servent *sr; + + port = atoi(service); + if (port != 0) + return port; + + sr = getservbyname(service, proto); + if (sr == NULL) { + fprintf(stderr, "Warning: getservbyname() failed for '%s/%s'.\n", service, proto); + exit(1); + } + + return ntohs(sr->s_port); +} |