diff options
Diffstat (limited to 'examples/shared.cc')
-rw-r--r-- | examples/shared.cc | 385 |
1 files changed, 385 insertions, 0 deletions
diff --git a/examples/shared.cc b/examples/shared.cc new file mode 100644 index 0000000..d65819a --- /dev/null +++ b/examples/shared.cc @@ -0,0 +1,385 @@ +/* + * ngtcp2 + * + * Copyright (c) 2019 ngtcp2 contributors + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "shared.h" + +#include <nghttp3/nghttp3.h> + +#include <cstring> +#include <cassert> +#include <iostream> + +#include <unistd.h> +#ifdef HAVE_NETINET_IN_H +# include <netinet/in.h> +#endif // HAVE_NETINET_IN_H +#ifdef HAVE_ASM_TYPES_H +# include <asm/types.h> +#endif // HAVE_ASM_TYPES_H +#ifdef HAVE_LINUX_NETLINK_H +# include <linux/netlink.h> +#endif // HAVE_LINUX_NETLINK_H +#ifdef HAVE_LINUX_RTNETLINK_H +# include <linux/rtnetlink.h> +#endif // HAVE_LINUX_RTNETLINK_H + +#include "template.h" + +namespace ngtcp2 { + +unsigned int msghdr_get_ecn(msghdr *msg, int family) { + switch (family) { + case AF_INET: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_TOS && + cmsg->cmsg_len) { + return *reinterpret_cast<uint8_t *>(CMSG_DATA(cmsg)); + } + } + break; + case AF_INET6: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_TCLASS && + cmsg->cmsg_len) { + return *reinterpret_cast<uint8_t *>(CMSG_DATA(cmsg)); + } + } + break; + } + + return 0; +} + +void fd_set_ecn(int fd, int family, unsigned int ecn) { + switch (family) { + case AF_INET: + if (setsockopt(fd, IPPROTO_IP, IP_TOS, &ecn, + static_cast<socklen_t>(sizeof(ecn))) == -1) { + std::cerr << "setsockopt: " << strerror(errno) << std::endl; + } + break; + case AF_INET6: + if (setsockopt(fd, IPPROTO_IPV6, IPV6_TCLASS, &ecn, + static_cast<socklen_t>(sizeof(ecn))) == -1) { + std::cerr << "setsockopt: " << strerror(errno) << std::endl; + } + break; + } +} + +void fd_set_recv_ecn(int fd, int family) { + unsigned int tos = 1; + switch (family) { + case AF_INET: + if (setsockopt(fd, IPPROTO_IP, IP_RECVTOS, &tos, + static_cast<socklen_t>(sizeof(tos))) == -1) { + std::cerr << "setsockopt: " << strerror(errno) << std::endl; + } + break; + case AF_INET6: + if (setsockopt(fd, IPPROTO_IPV6, IPV6_RECVTCLASS, &tos, + static_cast<socklen_t>(sizeof(tos))) == -1) { + std::cerr << "setsockopt: " << strerror(errno) << std::endl; + } + break; + } +} + +void fd_set_ip_mtu_discover(int fd, int family) { +#if defined(IP_MTU_DISCOVER) && defined(IPV6_MTU_DISCOVER) + int val; + + switch (family) { + case AF_INET: + val = IP_PMTUDISC_DO; + if (setsockopt(fd, IPPROTO_IP, IP_MTU_DISCOVER, &val, + static_cast<socklen_t>(sizeof(val))) == -1) { + std::cerr << "setsockopt: IP_MTU_DISCOVER: " << strerror(errno) + << std::endl; + } + break; + case AF_INET6: + val = IPV6_PMTUDISC_DO; + if (setsockopt(fd, IPPROTO_IPV6, IPV6_MTU_DISCOVER, &val, + static_cast<socklen_t>(sizeof(val))) == -1) { + std::cerr << "setsockopt: IPV6_MTU_DISCOVER: " << strerror(errno) + << std::endl; + } + break; + } +#endif // defined(IP_MTU_DISCOVER) && defined(IPV6_MTU_DISCOVER) +} + +void fd_set_ip_dontfrag(int fd, int family) { +#if defined(IP_DONTFRAG) && defined(IPV6_DONTFRAG) + int val = 1; + + switch (family) { + case AF_INET: + if (setsockopt(fd, IPPROTO_IP, IP_DONTFRAG, &val, + static_cast<socklen_t>(sizeof(val))) == -1) { + std::cerr << "setsockopt: IP_DONTFRAG: " << strerror(errno) << std::endl; + } + break; + case AF_INET6: + if (setsockopt(fd, IPPROTO_IPV6, IPV6_DONTFRAG, &val, + static_cast<socklen_t>(sizeof(val))) == -1) { + std::cerr << "setsockopt: IPV6_DONTFRAG: " << strerror(errno) + << std::endl; + } + break; + } +#endif // defined(IP_DONTFRAG) && defined(IPV6_DONTFRAG) +} + +std::optional<Address> msghdr_get_local_addr(msghdr *msg, int family) { + switch (family) { + case AF_INET: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) { + auto pktinfo = reinterpret_cast<in_pktinfo *>(CMSG_DATA(cmsg)); + Address res{}; + res.ifindex = pktinfo->ipi_ifindex; + res.len = sizeof(res.su.in); + auto &sa = res.su.in; + sa.sin_family = AF_INET; + sa.sin_addr = pktinfo->ipi_addr; + return res; + } + } + return {}; + case AF_INET6: + for (auto cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { + if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) { + auto pktinfo = reinterpret_cast<in6_pktinfo *>(CMSG_DATA(cmsg)); + Address res{}; + res.ifindex = pktinfo->ipi6_ifindex; + res.len = sizeof(res.su.in6); + auto &sa = res.su.in6; + sa.sin6_family = AF_INET6; + sa.sin6_addr = pktinfo->ipi6_addr; + return res; + } + } + return {}; + } + return {}; +} + +void set_port(Address &dst, Address &src) { + switch (dst.su.storage.ss_family) { + case AF_INET: + assert(AF_INET == src.su.storage.ss_family); + dst.su.in.sin_port = src.su.in.sin_port; + return; + case AF_INET6: + assert(AF_INET6 == src.su.storage.ss_family); + dst.su.in6.sin6_port = src.su.in6.sin6_port; + return; + default: + assert(0); + } +} + +#ifdef HAVE_LINUX_RTNETLINK_H + +struct nlmsg { + nlmsghdr hdr; + rtmsg msg; + rtattr dst; + in_addr_union dst_addr; +}; + +namespace { +int send_netlink_msg(int fd, const Address &remote_addr) { + nlmsg nlmsg{}; + nlmsg.hdr.nlmsg_type = RTM_GETROUTE; + nlmsg.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + + nlmsg.msg.rtm_family = remote_addr.su.sa.sa_family; + + nlmsg.dst.rta_type = RTA_DST; + + switch (remote_addr.su.sa.sa_family) { + case AF_INET: + nlmsg.dst.rta_len = RTA_LENGTH(sizeof(remote_addr.su.in.sin_addr)); + memcpy(RTA_DATA(&nlmsg.dst), &remote_addr.su.in.sin_addr, + sizeof(remote_addr.su.in.sin_addr)); + break; + case AF_INET6: + nlmsg.dst.rta_len = RTA_LENGTH(sizeof(remote_addr.su.in6.sin6_addr)); + memcpy(RTA_DATA(&nlmsg.dst), &remote_addr.su.in6.sin6_addr, + sizeof(remote_addr.su.in6.sin6_addr)); + break; + default: + assert(0); + } + + nlmsg.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(nlmsg.msg) + nlmsg.dst.rta_len); + + sockaddr_nl sa{}; + sa.nl_family = AF_NETLINK; + + iovec iov{&nlmsg, nlmsg.hdr.nlmsg_len}; + msghdr msg{}; + msg.msg_name = &sa; + msg.msg_namelen = sizeof(sa); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ssize_t nwrite; + + do { + nwrite = sendmsg(fd, &msg, 0); + } while (nwrite == -1 && errno == EINTR); + + if (nwrite == -1) { + std::cerr << "sendmsg: Could not write netlink message: " << strerror(errno) + << std::endl; + return -1; + } + + return 0; +} +} // namespace + +namespace { +int recv_netlink_msg(in_addr_union &iau, int fd) { + std::array<uint8_t, 8192> buf; + iovec iov = {buf.data(), buf.size()}; + sockaddr_nl sa{}; + msghdr msg{}; + + msg.msg_name = &sa; + msg.msg_namelen = sizeof(sa); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + + ssize_t nread; + + do { + nread = recvmsg(fd, &msg, 0); + } while (nread == -1 && errno == EINTR); + + if (nread == -1) { + std::cerr << "recvmsg: Could not receive netlink message: " + << strerror(errno) << std::endl; + return -1; + } + + for (auto hdr = reinterpret_cast<nlmsghdr *>(buf.data()); + NLMSG_OK(hdr, nread); hdr = NLMSG_NEXT(hdr, nread)) { + switch (hdr->nlmsg_type) { + case NLMSG_DONE: + std::cerr << "netlink: no info returned from kernel" << std::endl; + return -1; + case NLMSG_NOOP: + continue; + case NLMSG_ERROR: + std::cerr << "netlink: " + << strerror(-static_cast<nlmsgerr *>(NLMSG_DATA(hdr))->error) + << std::endl; + return -1; + } + + auto attrlen = hdr->nlmsg_len - NLMSG_SPACE(sizeof(rtmsg)); + + for (auto rta = reinterpret_cast<rtattr *>( + static_cast<uint8_t *>(NLMSG_DATA(hdr)) + sizeof(rtmsg)); + RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) { + if (rta->rta_type != RTA_PREFSRC) { + continue; + } + + size_t in_addrlen; + + switch (static_cast<rtmsg *>(NLMSG_DATA(hdr))->rtm_family) { + case AF_INET: + in_addrlen = sizeof(in_addr); + break; + case AF_INET6: + in_addrlen = sizeof(in6_addr); + break; + default: + assert(0); + abort(); + } + + if (RTA_LENGTH(in_addrlen) != rta->rta_len) { + return -1; + } + + memcpy(&iau, RTA_DATA(rta), in_addrlen); + + return 0; + } + } + + return -1; +} +} // namespace + +int get_local_addr(in_addr_union &iau, const Address &remote_addr) { + sockaddr_nl sa{}; + sa.nl_family = AF_NETLINK; + + auto fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + if (fd == -1) { + std::cerr << "socket: Could not create netlink socket: " << strerror(errno) + << std::endl; + return -1; + } + + auto fd_d = defer(close, fd); + + if (bind(fd, reinterpret_cast<sockaddr *>(&sa), sizeof(sa)) == -1) { + std::cerr << "bind: Could not bind netlink socket: " << strerror(errno) + << std::endl; + return -1; + } + + if (send_netlink_msg(fd, remote_addr) != 0) { + return -1; + } + + return recv_netlink_msg(iau, fd); +} + +#endif // HAVE_LINUX_NETLINK_H + +bool addreq(const sockaddr *sa, const in_addr_union &iau) { + switch (sa->sa_family) { + case AF_INET: + return memcmp(&reinterpret_cast<const sockaddr_in *>(sa)->sin_addr, &iau.in, + sizeof(iau.in)) == 0; + case AF_INET6: + return memcmp(&reinterpret_cast<const sockaddr_in6 *>(sa)->sin6_addr, + &iau.in6, sizeof(iau.in6)) == 0; + default: + assert(0); + abort(); + } +} + +} // namespace ngtcp2 |