summaryrefslogtreecommitdiffstats
path: root/tcpiohandler.hh
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tcpiohandler.hh584
1 files changed, 584 insertions, 0 deletions
diff --git a/tcpiohandler.hh b/tcpiohandler.hh
new file mode 100644
index 0000000..88f0dc7
--- /dev/null
+++ b/tcpiohandler.hh
@@ -0,0 +1,584 @@
+
+#pragma once
+#include <memory>
+/* needed for proper TCP_FASTOPEN_CONNECT detection */
+#include <netinet/tcp.h>
+
+#include "iputils.hh"
+#include "libssl.hh"
+#include "misc.hh"
+#include "noinitvector.hh"
+
+/* Async is only returned for TLS connections, if OpenSSL's async mode has been enabled */
+enum class IOState : uint8_t { Done, NeedRead, NeedWrite, Async };
+
+class TLSSession
+{
+public:
+ virtual ~TLSSession()
+ {
+ }
+};
+
+class TLSConnection
+{
+public:
+ virtual ~TLSConnection() { }
+ virtual void doHandshake() = 0;
+ virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0;
+ virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0;
+ virtual IOState tryHandshake() = 0;
+ virtual size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout={0,0}, bool allowIncomplete=false) = 0;
+ virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0;
+ virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
+ virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0;
+ virtual bool hasBufferedData() const = 0;
+ virtual std::string getServerNameIndication() const = 0;
+ virtual std::vector<uint8_t> getNextProtocol() const = 0;
+ virtual LibsslTLSVersion getTLSVersion() const = 0;
+ virtual bool hasSessionBeenResumed() const = 0;
+ virtual std::vector<std::unique_ptr<TLSSession>> getSessions() = 0;
+ virtual void setSession(std::unique_ptr<TLSSession>& session) = 0;
+ virtual bool isUsable() const = 0;
+ virtual std::vector<int> getAsyncFDs() = 0;
+ virtual void close() = 0;
+
+ void setUnknownTicketKey()
+ {
+ d_unknownTicketKey = true;
+ }
+
+ bool getUnknownTicketKey() const
+ {
+ return d_unknownTicketKey;
+ }
+
+ void setResumedFromInactiveTicketKey()
+ {
+ d_resumedFromInactiveTicketKey = true;
+ }
+
+ bool getResumedFromInactiveTicketKey() const
+ {
+ return d_resumedFromInactiveTicketKey;
+ }
+
+protected:
+ int d_socket{-1};
+ bool d_unknownTicketKey{false};
+ bool d_resumedFromInactiveTicketKey{false};
+};
+
+class TLSCtx
+{
+public:
+ TLSCtx()
+ {
+ d_rotatingTicketsKey.clear();
+ }
+ virtual ~TLSCtx() {}
+ virtual std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) = 0;
+ virtual std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) = 0;
+ virtual void rotateTicketsKey(time_t now) = 0;
+ virtual void loadTicketsKeys(const std::string& /* file */)
+ {
+ throw std::runtime_error("This TLS backend does not have the capability to load a tickets key from a file");
+ }
+
+ void handleTicketsKeyRotation(time_t now)
+ {
+ if (d_ticketsKeyRotationDelay != 0 && now > d_ticketsKeyNextRotation) {
+ if (d_rotatingTicketsKey.test_and_set()) {
+ /* someone is already rotating */
+ return;
+ }
+ try {
+ rotateTicketsKey(now);
+ d_rotatingTicketsKey.clear();
+ }
+ catch(const std::runtime_error& e) {
+ d_rotatingTicketsKey.clear();
+ throw std::runtime_error(std::string("Error generating a new tickets key for TLS context:") + e.what());
+ }
+ catch(...) {
+ d_rotatingTicketsKey.clear();
+ throw;
+ }
+ }
+ }
+
+ time_t getNextTicketsKeyRotation() const
+ {
+ return d_ticketsKeyNextRotation;
+ }
+
+ virtual size_t getTicketsKeysCount() = 0;
+ virtual std::string getName() const = 0;
+
+ /* set the advertised ALPN protocols, in client or server context */
+ virtual bool setALPNProtos(const std::vector<std::vector<uint8_t>>& /* protos */)
+ {
+ return false;
+ }
+
+ /* called in a client context, if the client advertised more than one ALPN values and the server returned more than one as well, to select the one to use. */
+ virtual bool setNextProtocolSelectCallback(bool(*)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen))
+ {
+ return false;
+ }
+
+protected:
+ std::atomic_flag d_rotatingTicketsKey;
+ std::atomic<time_t> d_ticketsKeyNextRotation{0};
+ time_t d_ticketsKeyRotationDelay{0};
+};
+
+class TLSFrontend
+{
+public:
+ TLSFrontend()
+ {
+ }
+
+ TLSFrontend(std::shared_ptr<TLSCtx> ctx): d_ctx(std::move(ctx))
+ {
+ }
+
+ bool setupTLS();
+
+ void rotateTicketsKey(time_t now)
+ {
+ if (d_ctx != nullptr) {
+ d_ctx->rotateTicketsKey(now);
+ }
+ }
+
+ void loadTicketsKeys(const std::string& file)
+ {
+ if (d_ctx != nullptr) {
+ d_ctx->loadTicketsKeys(file);
+ }
+ }
+
+ std::shared_ptr<TLSCtx> getContext()
+ {
+ return std::atomic_load_explicit(&d_ctx, std::memory_order_acquire);
+ }
+
+ void cleanup()
+ {
+ d_ctx.reset();
+ }
+
+ size_t getTicketsKeysCount()
+ {
+ if (d_ctx != nullptr) {
+ return d_ctx->getTicketsKeysCount();
+ }
+
+ return 0;
+ }
+
+ static std::string timeToString(time_t rotationTime)
+ {
+ char buf[20];
+ struct tm date_tm;
+
+ localtime_r(&rotationTime, &date_tm);
+ strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", &date_tm);
+
+ return std::string(buf);
+ }
+
+ time_t getTicketsKeyRotationDelay() const
+ {
+ return d_tlsConfig.d_ticketsKeyRotationDelay;
+ }
+
+ std::string getNextTicketsKeyRotation() const
+ {
+ std::string res;
+
+ if (d_ctx != nullptr) {
+ res = timeToString(d_ctx->getNextTicketsKeyRotation());
+ }
+
+ return res;
+ }
+
+ std::string getRequestedProvider() const
+ {
+ return d_provider;
+ }
+
+ std::string getEffectiveProvider() const
+ {
+ if (d_ctx) {
+ return d_ctx->getName();
+ }
+ return "";
+ }
+
+ TLSConfig d_tlsConfig;
+ TLSErrorCounters d_tlsCounters;
+ ComboAddress d_addr;
+ std::string d_provider;
+
+protected:
+ std::shared_ptr<TLSCtx> d_ctx{nullptr};
+};
+
+class TCPIOHandler
+{
+public:
+ enum class Type : uint8_t { Client, Server };
+
+ TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx): d_socket(socket)
+ {
+ if (ctx) {
+ d_conn = ctx->getClientConnection(host, hostIsAddr, d_socket, timeout);
+ }
+ }
+
+ TCPIOHandler(int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
+ {
+ if (ctx) {
+ d_conn = ctx->getConnection(d_socket, timeout, now);
+ }
+ }
+
+ ~TCPIOHandler()
+ {
+ close();
+ }
+
+ void close()
+ {
+ if (d_conn) {
+ d_conn->close();
+ d_conn.reset();
+ }
+
+ if (d_socket != -1) {
+ shutdown(d_socket, SHUT_RDWR);
+ ::close(d_socket);
+ d_socket = -1;
+ }
+ }
+
+ int getDescriptor() const
+ {
+ return d_socket;
+ }
+
+ IOState tryConnect(bool fastOpen, const ComboAddress& remote)
+ {
+ d_remote = remote;
+
+#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
+ if (fastOpen) {
+ int value = 1;
+ int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
+ if (res == 0) {
+ fastOpen = false;
+ }
+ }
+#endif /* TCP_FASTOPEN_CONNECT */
+
+#ifdef MSG_FASTOPEN
+ if (!d_conn && fastOpen) {
+ d_fastOpen = true;
+ }
+ else {
+ if (!s_disableConnectForUnitTests) {
+ SConnectWithTimeout(d_socket, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
+ }
+ }
+#else
+ if (!s_disableConnectForUnitTests) {
+ SConnectWithTimeout(d_socket, remote, /* no timeout, we will handle it ourselves */ timeval{0,0});
+ }
+#endif /* MSG_FASTOPEN */
+
+ if (d_conn) {
+ return d_conn->tryConnect(fastOpen, remote);
+ }
+
+ return IOState::Done;
+ }
+
+ void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout)
+ {
+ d_remote = remote;
+
+#ifdef TCP_FASTOPEN_CONNECT /* Linux >= 4.11 */
+ if (fastOpen) {
+ int value = 1;
+ int res = setsockopt(d_socket, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &value, sizeof(value));
+ if (res == 0) {
+ fastOpen = false;
+ }
+ }
+#endif /* TCP_FASTOPEN_CONNECT */
+
+#ifdef MSG_FASTOPEN
+ if (!d_conn && fastOpen) {
+ d_fastOpen = true;
+ }
+ else {
+ if (!s_disableConnectForUnitTests) {
+ SConnectWithTimeout(d_socket, remote, timeout);
+ }
+ }
+#else
+ if (!s_disableConnectForUnitTests) {
+ SConnectWithTimeout(d_socket, remote, timeout);
+ }
+#endif /* MSG_FASTOPEN */
+
+ if (d_conn) {
+ d_conn->connect(fastOpen, remote, timeout);
+ }
+ }
+
+ IOState tryHandshake()
+ {
+ if (d_conn) {
+ return d_conn->tryHandshake();
+ }
+ return IOState::Done;
+ }
+
+ size_t read(void* buffer, size_t bufferSize, const struct timeval& readTimeout, const struct timeval& totalTimeout = {0,0}, bool allowIncomplete=false)
+ {
+ if (d_conn) {
+ return d_conn->read(buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
+ } else {
+ return readn2WithTimeout(d_socket, buffer, bufferSize, readTimeout, totalTimeout, allowIncomplete);
+ }
+ }
+
+ /* Tries to read exactly toRead - pos bytes into the buffer, starting at position pos.
+ Updates pos everytime a successful read occurs,
+ throws an std::runtime_error in case of IO error,
+ return Done when toRead bytes have been read, needRead or needWrite if the IO operation
+ would block.
+ */
+ IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false)
+ {
+ if (buffer.size() < toRead || pos >= toRead) {
+ throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
+ }
+
+ if (d_conn) {
+ return d_conn->tryRead(buffer, pos, toRead, allowIncomplete);
+ }
+
+ do {
+ ssize_t res = ::read(d_socket, reinterpret_cast<char*>(&buffer.at(pos)), toRead - pos);
+ if (res == 0) {
+ throw runtime_error("EOF while reading message");
+ }
+ if (res < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
+ return IOState::NeedRead;
+ }
+ else {
+ throw std::runtime_error("Error while reading message: " + stringerror());
+ }
+ }
+
+ pos += static_cast<size_t>(res);
+ if (allowIncomplete) {
+ break;
+ }
+ }
+ while (pos < toRead);
+
+ return IOState::Done;
+ }
+
+ /* Tries to write exactly toWrite - pos bytes from the buffer, starting at position pos.
+ Updates pos everytime a successful write occurs,
+ throws an std::runtime_error in case of IO error,
+ return Done when toWrite bytes have been written, needRead or needWrite if the IO operation
+ would block.
+ */
+ IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite)
+ {
+ if (buffer.size() < toWrite || pos >= toWrite) {
+ throw std::out_of_range("Calling tryWrite() with a too small buffer (" + std::to_string(buffer.size()) + ") for a write of " + std::to_string(toWrite - pos) + " bytes starting at " + std::to_string(pos));
+ }
+ if (d_conn) {
+ return d_conn->tryWrite(buffer, pos, toWrite);
+ }
+
+#ifdef MSG_FASTOPEN
+ if (d_fastOpen) {
+ int socketFlags = MSG_FASTOPEN;
+ size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(&buffer.at(pos)), toWrite - pos, &d_remote, nullptr, 0, socketFlags);
+ if (sent > 0) {
+ d_fastOpen = false;
+ pos += sent;
+ }
+
+ if (pos < toWrite) {
+ return IOState::NeedWrite;
+ }
+
+ return IOState::Done;
+ }
+#endif /* MSG_FASTOPEN */
+
+ do {
+ ssize_t res = ::write(d_socket, reinterpret_cast<const char*>(&buffer.at(pos)), toWrite - pos);
+
+ if (res == 0) {
+ throw runtime_error("EOF while sending message");
+ }
+ if (res < 0) {
+ if (errno == EAGAIN || errno == EWOULDBLOCK || errno == ENOTCONN) {
+ return IOState::NeedWrite;
+ }
+ else {
+ throw std::runtime_error("Error while writing message: " + stringerror());
+ }
+ }
+
+ pos += static_cast<size_t>(res);
+ }
+ while (pos < toWrite);
+
+ return IOState::Done;
+ }
+
+ size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout)
+ {
+ if (d_conn) {
+ return d_conn->write(buffer, bufferSize, writeTimeout);
+ }
+
+#ifdef MSG_FASTOPEN
+ if (d_fastOpen) {
+ int socketFlags = MSG_FASTOPEN;
+ size_t sent = sendMsgWithOptions(d_socket, reinterpret_cast<const char *>(buffer), bufferSize, &d_remote, nullptr, 0, socketFlags);
+ if (sent > 0) {
+ d_fastOpen = false;
+ }
+
+ return sent;
+ }
+#endif /* MSG_FASTOPEN */
+
+ return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
+ }
+
+ bool hasBufferedData() const
+ {
+ if (d_conn) {
+ return d_conn->hasBufferedData();
+ }
+ return false;
+ }
+
+ std::string getServerNameIndication() const
+ {
+ if (d_conn) {
+ return d_conn->getServerNameIndication();
+ }
+ return std::string();
+ }
+
+ std::vector<uint8_t> getNextProtocol() const
+ {
+ if (d_conn) {
+ return d_conn->getNextProtocol();
+ }
+ return std::vector<uint8_t>();
+ }
+
+ LibsslTLSVersion getTLSVersion() const
+ {
+ if (d_conn) {
+ return d_conn->getTLSVersion();
+ }
+ return LibsslTLSVersion::Unknown;
+ }
+
+ bool isTLS() const
+ {
+ return d_conn != nullptr;
+ }
+
+ bool hasTLSSessionBeenResumed() const
+ {
+ return d_conn && d_conn->hasSessionBeenResumed();
+ }
+
+ bool getResumedFromInactiveTicketKey() const
+ {
+ return d_conn && d_conn->getResumedFromInactiveTicketKey();
+ }
+
+ bool getUnknownTicketKey() const
+ {
+ return d_conn && d_conn->getUnknownTicketKey();
+ }
+
+ void setTLSSession(std::unique_ptr<TLSSession>& session)
+ {
+ if (d_conn != nullptr) {
+ d_conn->setSession(session);
+ }
+ }
+
+ std::vector<std::unique_ptr<TLSSession>> getTLSSessions()
+ {
+ if (!d_conn) {
+ throw std::runtime_error("Trying to get TLS sessions from a non-TLS handler");
+ }
+
+ return d_conn->getSessions();
+ }
+
+ bool isUsable() const
+ {
+ if (!d_conn) {
+ return isTCPSocketUsable(d_socket);
+ }
+ return d_conn->isUsable();
+ }
+
+ std::vector<int> getAsyncFDs()
+ {
+ if (!d_conn) {
+ return {};
+ }
+ return d_conn->getAsyncFDs();
+ }
+
+ const static bool s_disableConnectForUnitTests;
+
+private:
+ std::unique_ptr<TLSConnection> d_conn{nullptr};
+ ComboAddress d_remote;
+ int d_socket{-1};
+#ifdef MSG_FASTOPEN
+ bool d_fastOpen{false};
+#endif
+};
+
+struct TLSContextParameters
+{
+ std::string d_provider;
+ std::string d_ciphers;
+ std::string d_ciphers13;
+ std::string d_caStore;
+ bool d_validateCertificates{true};
+ bool d_releaseBuffers{true};
+ bool d_enableRenegotiation{false};
+ bool d_ktls{false};
+};
+
+std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);
+bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);