diff options
Diffstat (limited to 'src/basic/socket-util.c')
-rw-r--r-- | src/basic/socket-util.c | 138 |
1 files changed, 106 insertions, 32 deletions
diff --git a/src/basic/socket-util.c b/src/basic/socket-util.c index beb64d8..6e304e8 100644 --- a/src/basic/socket-util.c +++ b/src/basic/socket-util.c @@ -1,9 +1,10 @@ /* SPDX-License-Identifier: LGPL-2.1-or-later */ +/* Make sure the net/if.h header is included before any linux/ one */ +#include <net/if.h> #include <arpa/inet.h> #include <errno.h> #include <limits.h> -#include <net/if.h> #include <netdb.h> #include <netinet/ip.h> #include <poll.h> @@ -453,6 +454,7 @@ int sockaddr_pretty( assert(sa); assert(salen >= sizeof(sa->sa.sa_family)); + assert(ret); switch (sa->sa.sa_family) { @@ -547,7 +549,7 @@ int sockaddr_pretty( } else { if (path[path_len - 1] == '\0') /* We expect a terminating NUL and don't print it */ - path_len --; + path_len--; p = cescape_length(path, path_len); } @@ -628,29 +630,27 @@ int getsockname_pretty(int fd, char **ret) { return sockaddr_pretty(&sa.sa, salen, false, true, ret); } -int socknameinfo_pretty(union sockaddr_union *sa, socklen_t salen, char **_ret) { +int socknameinfo_pretty(const struct sockaddr *sa, socklen_t salen, char **ret) { + char host[NI_MAXHOST]; int r; - char host[NI_MAXHOST], *ret; - assert(_ret); + assert(sa); + assert(salen >= sizeof(sa_family_t)); + assert(ret); - r = getnameinfo(&sa->sa, salen, host, sizeof(host), NULL, 0, IDN_FLAGS); + r = getnameinfo(sa, salen, host, sizeof(host), /* service= */ NULL, /* service_len= */ 0, IDN_FLAGS); if (r != 0) { - int saved_errno = errno; - - r = sockaddr_pretty(&sa->sa, salen, true, true, &ret); - if (r < 0) - return r; + if (r == EAI_MEMORY) + return log_oom_debug(); + if (r == EAI_SYSTEM) + log_debug_errno(errno, "getnameinfo() failed, ignoring: %m"); + else + log_debug("getnameinfo() failed, ignoring: %s", gai_strerror(r)); - log_debug_errno(saved_errno, "getnameinfo(%s) failed: %m", ret); - } else { - ret = strdup(host); - if (!ret) - return -ENOMEM; + return sockaddr_pretty(sa, salen, /* translate_ipv6= */ true, /* include_port= */ true, ret); } - *_ret = ret; - return 0; + return strdup_to(ret, host); } static const char* const netlink_family_table[] = { @@ -872,13 +872,11 @@ bool address_label_valid(const char *p) { int getpeercred(int fd, struct ucred *ucred) { socklen_t n = sizeof(struct ucred); struct ucred u; - int r; assert(fd >= 0); assert(ucred); - r = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &u, &n); - if (r < 0) + if (getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &u, &n) < 0) return -errno; if (n != sizeof(struct ucred)) @@ -907,8 +905,10 @@ int getpeersec(int fd, char **ret) { if (!s) return -ENOMEM; - if (getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n) >= 0) + if (getsockopt(fd, SOL_SOCKET, SO_PEERSEC, s, &n) >= 0) { + s[n] = 0; break; + } if (errno != ERANGE) return -errno; @@ -925,12 +925,16 @@ int getpeersec(int fd, char **ret) { } int getpeergroups(int fd, gid_t **ret) { - socklen_t n = sizeof(gid_t) * 64; + socklen_t n = sizeof(gid_t) * 64U; _cleanup_free_ gid_t *d = NULL; assert(fd >= 0); assert(ret); + long ngroups_max = sysconf(_SC_NGROUPS_MAX); + if (ngroups_max > 0) + n = MAX(n, sizeof(gid_t) * (socklen_t) ngroups_max); + for (;;) { d = malloc(n); if (!d) @@ -948,7 +952,7 @@ int getpeergroups(int fd, gid_t **ret) { assert_se(n % sizeof(gid_t) == 0); n /= sizeof(gid_t); - if ((socklen_t) (int) n != n) + if (n > INT_MAX) return -E2BIG; *ret = TAKE_PTR(d); @@ -956,6 +960,21 @@ int getpeergroups(int fd, gid_t **ret) { return (int) n; } +int getpeerpidfd(int fd) { + socklen_t n = sizeof(int); + int pidfd = -EBADF; + + assert(fd >= 0); + + if (getsockopt(fd, SOL_SOCKET, SO_PEERPIDFD, &pidfd, &n) < 0) + return -errno; + + if (n != sizeof(int)) + return -EIO; + + return pidfd; +} + ssize_t send_many_fds_iov_sa( int transport_fd, int *fds_array, size_t n_fds_array, @@ -1093,14 +1112,10 @@ ssize_t receive_many_fds_iov( if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { size_t n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); - fds_array = GREEDY_REALLOC(fds_array, n_fds_array + n); - if (!fds_array) { + if (!GREEDY_REALLOC_APPEND(fds_array, n_fds_array, CMSG_TYPED_DATA(cmsg, int), n)) { cmsg_close_all(&mh); return -ENOMEM; } - - memcpy(fds_array + n_fds_array, CMSG_TYPED_DATA(cmsg, int), sizeof(int) * n); - n_fds_array += n; } if (n_fds_array == 0) { @@ -1641,6 +1656,50 @@ int socket_address_parse_unix(SocketAddress *ret_address, const char *s) { return 0; } +int vsock_parse_port(const char *s, unsigned *ret) { + int r; + + assert(ret); + + if (!s) + return -EINVAL; + + unsigned u; + r = safe_atou(s, &u); + if (r < 0) + return r; + + /* Port 0 is apparently valid and not special in AF_VSOCK (unlike on IP). But VMADDR_PORT_ANY + * (UINT32_MAX) is. Hence refuse that. */ + + if (u == VMADDR_PORT_ANY) + return -EINVAL; + + *ret = u; + return 0; +} + +int vsock_parse_cid(const char *s, unsigned *ret) { + assert(ret); + + if (!s) + return -EINVAL; + + /* Parsed an AF_VSOCK "CID". This is a 32bit entity, and the usual type is "unsigned". We recognize + * the three special CIDs as strings, and otherwise parse the numeric CIDs. */ + + if (streq(s, "hypervisor")) + *ret = VMADDR_CID_HYPERVISOR; + else if (streq(s, "local")) + *ret = VMADDR_CID_LOCAL; + else if (streq(s, "host")) + *ret = VMADDR_CID_HOST; + else + return safe_atou(s, ret); + + return 0; +} + int socket_address_parse_vsock(SocketAddress *ret_address, const char *s) { /* AF_VSOCK socket in vsock:cid:port notation */ _cleanup_free_ char *n = NULL; @@ -1666,7 +1725,7 @@ int socket_address_parse_vsock(SocketAddress *ret_address, const char *s) { if (!e) return -EINVAL; - r = safe_atou(e+1, &port); + r = vsock_parse_port(e+1, &port); if (r < 0) return r; @@ -1677,15 +1736,15 @@ int socket_address_parse_vsock(SocketAddress *ret_address, const char *s) { if (isempty(n)) cid = VMADDR_CID_ANY; else { - r = safe_atou(n, &cid); + r = vsock_parse_cid(n, &cid); if (r < 0) return r; } *ret_address = (SocketAddress) { .sockaddr.vm = { - .svm_cid = cid, .svm_family = AF_VSOCK, + .svm_cid = cid, .svm_port = port, }, .type = type, @@ -1694,3 +1753,18 @@ int socket_address_parse_vsock(SocketAddress *ret_address, const char *s) { return 0; } + +int vsock_get_local_cid(unsigned *ret) { + _cleanup_close_ int vsock_fd = -EBADF; + + assert(ret); + + vsock_fd = open("/dev/vsock", O_RDONLY|O_CLOEXEC); + if (vsock_fd < 0) + return log_debug_errno(errno, "Failed to open /dev/vsock: %m"); + + if (ioctl(vsock_fd, IOCTL_VM_SOCKETS_GET_LOCAL_CID, ret) < 0) + return log_debug_errno(errno, "Failed to query local AF_VSOCK CID: %m"); + + return 0; +} |