summaryrefslogtreecommitdiffstats
path: root/src/lib/util/io/socketsession.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/util/io/socketsession.cc')
-rw-r--r--src/lib/util/io/socketsession.cc438
1 files changed, 438 insertions, 0 deletions
diff --git a/src/lib/util/io/socketsession.cc b/src/lib/util/io/socketsession.cc
new file mode 100644
index 0000000..88f2b1e
--- /dev/null
+++ b/src/lib/util/io/socketsession.cc
@@ -0,0 +1,438 @@
+// Copyright (C) 2011-2021 Internet Systems Consortium, Inc. ("ISC")
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#include <config.h>
+
+#include <unistd.h>
+
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <sys/uio.h>
+#include <sys/un.h>
+
+#include <netinet/in.h>
+
+#include <fcntl.h>
+#include <stdint.h>
+
+#include <cerrno>
+#include <csignal>
+#include <cstddef>
+#include <cstring>
+
+#include <string>
+#include <vector>
+
+#include <boost/noncopyable.hpp>
+
+#include <exceptions/exceptions.h>
+#include <exceptions/isc_assert.h>
+
+#include <util/buffer.h>
+
+#include <util/io/fd_share.h>
+#include <util/io/socketsession.h>
+#include <util/io/sockaddr_util.h>
+
+using namespace std;
+
+namespace isc {
+namespace util {
+namespace io {
+
+using namespace internal;
+
+// The expected max size of the session header: 2-byte header length,
+// 6 32-bit fields, and 2 sockaddr structure. (see the SocketSessionUtility
+// overview description in the header file). sizeof sockaddr_storage
+// should be the possible max of any sockaddr structure
+const size_t DEFAULT_HEADER_BUFLEN = sizeof(uint16_t) + sizeof(uint32_t) * 6 +
+ sizeof(struct sockaddr_storage) * 2;
+
+// The allowable maximum size of data passed with the socket FD. For now
+// we use a fixed value of 65535, the largest possible size of valid DNS
+// messages. We may enlarge it or make it configurable as we see the need
+// for more flexibility.
+const int MAX_DATASIZE = 65535;
+
+// The initial buffer size for receiving socket session data in the receiver.
+// This value is the maximum message size of DNS messages carried over UDP
+// (without EDNS). In our expected usage (at the moment) this should be
+// sufficiently large (the expected data is AXFR/IXFR query or an UPDATE
+// requests. The former should be generally quite small. While the latter
+// could be large, it would often be small enough for a single UDP message).
+// If it turns out that there are many exceptions, we may want to extend
+// the class so that this value can be customized. Note that the buffer
+// will be automatically extended for longer data and this is only about
+// efficiency.
+const size_t INITIAL_BUFSIZE = 512;
+
+// The (default) socket buffer size for the forwarder and receiver. This is
+// chosen to be sufficiently large to store two full-size DNS messages. We
+// may want to customize this value in future.
+const int SOCKSESSION_BUFSIZE = (DEFAULT_HEADER_BUFLEN + MAX_DATASIZE) * 2;
+
+struct SocketSessionForwarder::ForwarderImpl {
+ ForwarderImpl() : sock_un_len_(0), fd_(-1), buf_(DEFAULT_HEADER_BUFLEN) {
+ memset(&sock_un_, 0, sizeof(sock_un_));
+ }
+
+ struct sockaddr_un sock_un_;
+ socklen_t sock_un_len_;
+ int fd_;
+ OutputBuffer buf_;
+};
+
+SocketSessionForwarder::SocketSessionForwarder(const std::string& unix_file) :
+ impl_(NULL)
+{
+ // We need to filter SIGPIPE for subsequent push(). See the class
+ // description.
+ if (signal(SIGPIPE, SIG_IGN) == SIG_ERR) {
+ isc_throw(Unexpected, "Failed to filter SIGPIPE: " << strerror(errno));
+ }
+
+ ForwarderImpl impl;
+ if (sizeof(impl.sock_un_.sun_path) - 1 < unix_file.length()) {
+ isc_throw(SocketSessionError,
+ "File name for a UNIX domain socket is too long: " <<
+ unix_file);
+ }
+ impl.sock_un_.sun_family = AF_UNIX;
+ // the copy should be safe due to the above check, but we'd be rather
+ // paranoid about making it 100% sure even if the check has a bug (with
+ // triggering the assertion in the worse case)
+ memset(&impl.sock_un_.sun_path, 0, sizeof(impl.sock_un_.sun_path));
+ strncpy(impl.sock_un_.sun_path, unix_file.c_str(),
+ sizeof(impl.sock_un_.sun_path) - 1);
+ isc_throw_assert(impl.sock_un_.sun_path[sizeof(impl.sock_un_.sun_path) - 1] == '\0');
+ impl.sock_un_len_ = offsetof(struct sockaddr_un, sun_path) +
+ unix_file.length();
+#ifdef HAVE_SA_LEN
+ impl.sock_un_.sun_len = impl.sock_un_len_;
+#endif
+ impl.fd_ = -1;
+
+ impl_ = new ForwarderImpl;
+ *impl_ = impl;
+}
+
+SocketSessionForwarder::~SocketSessionForwarder() {
+ if (impl_->fd_ != -1) {
+ close();
+ }
+ delete impl_;
+}
+
+void
+SocketSessionForwarder::connectToReceiver() {
+ if (impl_->fd_ != -1) {
+ isc_throw(BadValue, "Duplicate connect to UNIX domain "
+ "endpoint " << impl_->sock_un_.sun_path);
+ }
+
+ impl_->fd_ = socket(AF_UNIX, SOCK_STREAM, 0);
+ if (impl_->fd_ == -1) {
+ isc_throw(SocketSessionError, "Failed to create a UNIX domain socket: "
+ << strerror(errno));
+ }
+ // Make the socket non blocking
+ int fcntl_flags = fcntl(impl_->fd_, F_GETFL, 0);
+ if (fcntl_flags != -1) {
+ fcntl_flags |= O_NONBLOCK;
+ fcntl_flags = fcntl(impl_->fd_, F_SETFL, fcntl_flags);
+ }
+ if (fcntl_flags == -1) {
+ close(); // note: this is the internal method, not ::close()
+ isc_throw(SocketSessionError,
+ "Failed to make UNIX domain socket non blocking: " <<
+ strerror(errno));
+ }
+ // Ensure the socket send buffer is large enough. If we can't get the
+ // current size, simply set the sufficient size.
+ int sndbuf_size;
+ socklen_t sndbuf_size_len = sizeof(sndbuf_size);
+ if (getsockopt(impl_->fd_, SOL_SOCKET, SO_SNDBUF, &sndbuf_size,
+ &sndbuf_size_len) == -1 ||
+ sndbuf_size < SOCKSESSION_BUFSIZE) {
+ if (setsockopt(impl_->fd_, SOL_SOCKET, SO_SNDBUF, &SOCKSESSION_BUFSIZE,
+ sizeof(SOCKSESSION_BUFSIZE)) == -1) {
+ close();
+ isc_throw(SocketSessionError,
+ "Failed to set send buffer size to " <<
+ SOCKSESSION_BUFSIZE);
+ }
+ }
+ if (connect(impl_->fd_, convertSockAddr(&impl_->sock_un_),
+ impl_->sock_un_len_) == -1) {
+ close();
+ isc_throw(SocketSessionError, "Failed to connect to UNIX domain "
+ "endpoint " << impl_->sock_un_.sun_path << ": " <<
+ strerror(errno));
+ }
+}
+
+void
+SocketSessionForwarder::close() {
+ if (impl_->fd_ == -1) {
+ isc_throw(BadValue, "Attempt of close before connect");
+ }
+ ::close(impl_->fd_);
+ impl_->fd_ = -1;
+}
+
+void
+SocketSessionForwarder::push(int sock, int family, int type, int protocol,
+ const struct sockaddr& local_end,
+ const struct sockaddr& remote_end,
+ const void* data, size_t data_len)
+{
+ if (impl_->fd_ == -1) {
+ isc_throw(BadValue, "Attempt of push before connect");
+ }
+ if ((local_end.sa_family != AF_INET && local_end.sa_family != AF_INET6) ||
+ (remote_end.sa_family != AF_INET && remote_end.sa_family != AF_INET6))
+ {
+ isc_throw(BadValue, "Invalid address family: must be "
+ "AF_INET or AF_INET6; " <<
+ static_cast<int>(local_end.sa_family) << ", " <<
+ static_cast<int>(remote_end.sa_family) << " given");
+ }
+ if (family != local_end.sa_family || family != remote_end.sa_family) {
+ isc_throw(BadValue, "Inconsistent address family: must be "
+ << static_cast<int>(family) << "; "
+ << static_cast<int>(local_end.sa_family) << ", "
+ << static_cast<int>(remote_end.sa_family) << " given");
+ }
+ if (data_len == 0 || data == NULL) {
+ isc_throw(BadValue, "Data for a socket session must not be empty");
+ }
+ if (data_len > MAX_DATASIZE) {
+ isc_throw(BadValue, "Invalid socket session data size: " <<
+ data_len << ", must not exceed " << MAX_DATASIZE);
+ }
+
+ if (send_fd(impl_->fd_, sock) != 0) {
+ isc_throw(SocketSessionError, "FD passing failed: " <<
+ strerror(errno));
+ }
+
+ impl_->buf_.clear();
+ // Leave the space for the header length
+ impl_->buf_.skip(sizeof(uint16_t));
+ // Socket properties: family, type, protocol
+ impl_->buf_.writeUint32(static_cast<uint32_t>(family));
+ impl_->buf_.writeUint32(static_cast<uint32_t>(type));
+ impl_->buf_.writeUint32(static_cast<uint32_t>(protocol));
+ // Local endpoint
+ impl_->buf_.writeUint32(static_cast<uint32_t>(getSALength(local_end)));
+ impl_->buf_.writeData(&local_end, getSALength(local_end));
+ // Remote endpoint
+ impl_->buf_.writeUint32(static_cast<uint32_t>(getSALength(remote_end)));
+ impl_->buf_.writeData(&remote_end, getSALength(remote_end));
+ // Data length. Must be fit uint32 due to the range check above.
+ const uint32_t data_len32 = static_cast<uint32_t>(data_len);
+ isc_throw_assert(data_len == data_len32); // shouldn't cause overflow.
+ impl_->buf_.writeUint32(data_len32);
+ // Write the resulting header length at the beginning of the buffer
+ impl_->buf_.writeUint16At(impl_->buf_.getLength() - sizeof(uint16_t), 0);
+
+ const struct iovec iov[2] = {
+ { const_cast<void*>(impl_->buf_.getData()), impl_->buf_.getLength() },
+ { const_cast<void*>(data), data_len }
+ };
+ const int cc = writev(impl_->fd_, iov, 2);
+ if (cc != impl_->buf_.getLength() + data_len) {
+ if (cc < 0) {
+ isc_throw(SocketSessionError,
+ "Write failed in forwarding a socket session: " <<
+ strerror(errno));
+ }
+ isc_throw(SocketSessionError,
+ "Incomplete write in forwarding a socket session: " << cc <<
+ "/" << (impl_->buf_.getLength() + data_len));
+ }
+}
+
+SocketSession::SocketSession(int sock, int family, int type, int protocol,
+ const sockaddr* local_end,
+ const sockaddr* remote_end,
+ const void* data, size_t data_len) :
+ sock_(sock), family_(family), type_(type), protocol_(protocol),
+ local_end_(local_end), remote_end_(remote_end),
+ data_(data), data_len_(data_len)
+{
+ if (local_end == NULL || remote_end == NULL) {
+ isc_throw(BadValue, "sockaddr must be non NULL for SocketSession");
+ }
+ if (data_len == 0) {
+ isc_throw(BadValue, "data_len must be non 0 for SocketSession");
+ }
+ if (data == NULL) {
+ isc_throw(BadValue, "data must be non NULL for SocketSession");
+ }
+}
+
+struct SocketSessionReceiver::ReceiverImpl {
+ ReceiverImpl(int fd) : fd_(fd),
+ sa_local_(convertSockAddr(&ss_local_)),
+ sa_remote_(convertSockAddr(&ss_remote_)),
+ header_buf_(DEFAULT_HEADER_BUFLEN),
+ data_buf_(INITIAL_BUFSIZE)
+ {
+ memset(&ss_local_, 0, sizeof(ss_local_));
+ memset(&ss_remote_, 0, sizeof(ss_remote_));
+
+ if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &SOCKSESSION_BUFSIZE,
+ sizeof(SOCKSESSION_BUFSIZE)) == -1) {
+ isc_throw(SocketSessionError,
+ "Failed to set receive buffer size to " <<
+ SOCKSESSION_BUFSIZE);
+ }
+ }
+
+ const int fd_;
+ struct sockaddr_storage ss_local_; // placeholder for local endpoint
+ struct sockaddr* const sa_local_;
+ struct sockaddr_storage ss_remote_; // placeholder for remote endpoint
+ struct sockaddr* const sa_remote_;
+
+ // placeholder for session header and data
+ vector<uint8_t> header_buf_;
+ vector<uint8_t> data_buf_;
+};
+
+SocketSessionReceiver::SocketSessionReceiver(int fd) :
+ impl_(new ReceiverImpl(fd))
+{
+}
+
+SocketSessionReceiver::~SocketSessionReceiver() {
+ delete impl_;
+}
+
+namespace {
+// A shortcut to throw common exception on failure of recv(2)
+void
+readFail(int actual_len, int expected_len) {
+ if (expected_len < 0) {
+ isc_throw(SocketSessionError, "Failed to receive data from "
+ "SocketSessionForwarder: " << strerror(errno));
+ }
+ isc_throw(SocketSessionError, "Incomplete data from "
+ "SocketSessionForwarder: " << actual_len << "/" <<
+ expected_len);
+}
+
+// A helper container for a (socket) file descriptor used in
+// SocketSessionReceiver::pop that ensures the socket is closed unless it
+// can be safely passed to the caller via release().
+struct ScopedSocket : boost::noncopyable {
+ ScopedSocket(int fd) : fd_(fd) {}
+ ~ScopedSocket() {
+ if (fd_ >= 0) {
+ close(fd_);
+ }
+ }
+ int release() {
+ const int fd = fd_;
+ fd_ = -1;
+ return (fd);
+ }
+ int fd_;
+};
+}
+
+SocketSession
+SocketSessionReceiver::pop() {
+ ScopedSocket passed_sock(recv_fd(impl_->fd_));
+ if (passed_sock.fd_ == FD_SYSTEM_ERROR) {
+ isc_throw(SocketSessionError, "Receiving a forwarded FD failed: " <<
+ strerror(errno));
+ } else if (passed_sock.fd_ < 0) {
+ isc_throw(SocketSessionError, "No FD forwarded");
+ }
+
+ uint16_t header_len;
+ const int cc_hlen = recv(impl_->fd_, &header_len, sizeof(header_len),
+ MSG_WAITALL);
+ if (cc_hlen < sizeof(header_len)) {
+ readFail(cc_hlen, sizeof(header_len));
+ }
+ header_len = InputBuffer(&header_len, sizeof(header_len)).readUint16();
+ if (header_len > DEFAULT_HEADER_BUFLEN) {
+ isc_throw(SocketSessionError, "Too large header length: " <<
+ header_len);
+ }
+ impl_->header_buf_.clear();
+ impl_->header_buf_.resize(header_len);
+ const int cc_hdr = recv(impl_->fd_, &impl_->header_buf_[0], header_len,
+ MSG_WAITALL);
+ if (cc_hdr < header_len) {
+ readFail(cc_hdr, header_len);
+ }
+
+ InputBuffer ibuffer(&impl_->header_buf_[0], header_len);
+ try {
+ const int family = static_cast<int>(ibuffer.readUint32());
+ if (family != AF_INET && family != AF_INET6) {
+ isc_throw(SocketSessionError,
+ "Unsupported address family is passed: " << family);
+ }
+ const int type = static_cast<int>(ibuffer.readUint32());
+ const int protocol = static_cast<int>(ibuffer.readUint32());
+ const socklen_t local_end_len = ibuffer.readUint32();
+ const socklen_t endpoint_minlen = (family == AF_INET) ?
+ sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
+ if (local_end_len < endpoint_minlen ||
+ local_end_len > sizeof(impl_->ss_local_)) {
+ isc_throw(SocketSessionError, "Invalid local SA length: " <<
+ local_end_len);
+ }
+ ibuffer.readData(&impl_->ss_local_, local_end_len);
+ const socklen_t remote_end_len = ibuffer.readUint32();
+ if (remote_end_len < endpoint_minlen ||
+ remote_end_len > sizeof(impl_->ss_remote_)) {
+ isc_throw(SocketSessionError, "Invalid remote SA length: " <<
+ remote_end_len);
+ }
+ ibuffer.readData(&impl_->ss_remote_, remote_end_len);
+ if (family != impl_->sa_local_->sa_family ||
+ family != impl_->sa_remote_->sa_family) {
+ isc_throw(SocketSessionError, "SA family inconsistent: " <<
+ static_cast<int>(impl_->sa_local_->sa_family) << ", " <<
+ static_cast<int>(impl_->sa_remote_->sa_family) <<
+ " given, must be " << family);
+ }
+ const size_t data_len = ibuffer.readUint32();
+ if (data_len == 0 || data_len > MAX_DATASIZE) {
+ isc_throw(SocketSessionError,
+ "Invalid socket session data size: " << data_len <<
+ ", must be > 0 and <= " << MAX_DATASIZE);
+ }
+
+ impl_->data_buf_.clear();
+ impl_->data_buf_.resize(data_len);
+ const int cc_data = recv(impl_->fd_, &impl_->data_buf_[0], data_len,
+ MSG_WAITALL);
+ if (cc_data < data_len) {
+ readFail(cc_data, data_len);
+ }
+
+ return (SocketSession(passed_sock.release(), family, type, protocol,
+ impl_->sa_local_, impl_->sa_remote_,
+ &impl_->data_buf_[0], data_len));
+ } catch (const InvalidBufferPosition& ex) {
+ // We catch the case where the given header is too short and convert
+ // the exception to SocketSessionError.
+ isc_throw(SocketSessionError, "bogus socket session header: " <<
+ ex.what());
+ }
+}
+
+}
+}
+}