diff options
Diffstat (limited to 'tcpiohandler.hh')
-rw-r--r-- | tcpiohandler.hh | 38 |
1 files changed, 16 insertions, 22 deletions
diff --git a/tcpiohandler.hh b/tcpiohandler.hh index 88f0dc7..058d104 100644 --- a/tcpiohandler.hh +++ b/tcpiohandler.hh @@ -15,15 +15,13 @@ enum class IOState : uint8_t { Done, NeedRead, NeedWrite, Async }; class TLSSession { public: - virtual ~TLSSession() - { - } + virtual ~TLSSession() = default; }; class TLSConnection { public: - virtual ~TLSConnection() { } + virtual ~TLSConnection() = default; virtual void doHandshake() = 0; virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0; virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0; @@ -32,7 +30,6 @@ public: virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0; virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0; virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0; - virtual bool hasBufferedData() const = 0; virtual std::string getServerNameIndication() const = 0; virtual std::vector<uint8_t> getNextProtocol() const = 0; virtual LibsslTLSVersion getTLSVersion() const = 0; @@ -76,7 +73,7 @@ public: { d_rotatingTicketsKey.clear(); } - virtual ~TLSCtx() {} + virtual ~TLSCtx() = default; virtual std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) = 0; virtual std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) = 0; virtual void rotateTicketsKey(time_t now) = 0; @@ -136,7 +133,9 @@ protected: class TLSFrontend { public: - TLSFrontend() + enum class ALPN : uint8_t { Unset, DoT, DoH }; + + TLSFrontend(ALPN alpn): d_alpn(alpn) { } @@ -223,7 +222,9 @@ public: TLSErrorCounters d_tlsCounters; ComboAddress d_addr; std::string d_provider; - + ALPN d_alpn{ALPN::Unset}; + /* whether the proxy protocol is inside or outside the TLS layer */ + bool d_proxyProtocolOutsideTLS{false}; protected: std::shared_ptr<TLSCtx> d_ctx{nullptr}; }; @@ -231,16 +232,16 @@ protected: class TCPIOHandler { public: - enum class Type : uint8_t { Client, Server }; - - TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx): d_socket(socket) + TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx) : + d_socket(socket) { if (ctx) { d_conn = ctx->getClientConnection(host, hostIsAddr, d_socket, timeout); } } - TCPIOHandler(int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket) + TCPIOHandler(int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx, time_t now) : + d_socket(socket) { if (ctx) { d_conn = ctx->getConnection(d_socket, timeout, now); @@ -364,13 +365,13 @@ public: return Done when toRead bytes have been read, needRead or needWrite if the IO operation would block. */ - IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) + IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false, bool bypassFilters=false) { if (buffer.size() < toRead || pos >= toRead) { throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos)); } - if (d_conn) { + if (!bypassFilters && d_conn) { return d_conn->tryRead(buffer, pos, toRead, allowIncomplete); } @@ -473,14 +474,6 @@ public: return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout); } - bool hasBufferedData() const - { - if (d_conn) { - return d_conn->hasBufferedData(); - } - return false; - } - std::string getServerNameIndication() const { if (d_conn) { @@ -582,3 +575,4 @@ struct TLSContextParameters std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params); bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx); +bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx); |