summaryrefslogtreecommitdiffstats
path: root/tcpiohandler.hh
diff options
context:
space:
mode:
Diffstat (limited to 'tcpiohandler.hh')
-rw-r--r--tcpiohandler.hh38
1 files changed, 16 insertions, 22 deletions
diff --git a/tcpiohandler.hh b/tcpiohandler.hh
index 88f0dc7..058d104 100644
--- a/tcpiohandler.hh
+++ b/tcpiohandler.hh
@@ -15,15 +15,13 @@ enum class IOState : uint8_t { Done, NeedRead, NeedWrite, Async };
class TLSSession
{
public:
- virtual ~TLSSession()
- {
- }
+ virtual ~TLSSession() = default;
};
class TLSConnection
{
public:
- virtual ~TLSConnection() { }
+ virtual ~TLSConnection() = default;
virtual void doHandshake() = 0;
virtual IOState tryConnect(bool fastOpen, const ComboAddress& remote) = 0;
virtual void connect(bool fastOpen, const ComboAddress& remote, const struct timeval& timeout) = 0;
@@ -32,7 +30,6 @@ public:
virtual size_t write(const void* buffer, size_t bufferSize, const struct timeval& writeTimeout) = 0;
virtual IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) = 0;
virtual IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false) = 0;
- virtual bool hasBufferedData() const = 0;
virtual std::string getServerNameIndication() const = 0;
virtual std::vector<uint8_t> getNextProtocol() const = 0;
virtual LibsslTLSVersion getTLSVersion() const = 0;
@@ -76,7 +73,7 @@ public:
{
d_rotatingTicketsKey.clear();
}
- virtual ~TLSCtx() {}
+ virtual ~TLSCtx() = default;
virtual std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) = 0;
virtual std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) = 0;
virtual void rotateTicketsKey(time_t now) = 0;
@@ -136,7 +133,9 @@ protected:
class TLSFrontend
{
public:
- TLSFrontend()
+ enum class ALPN : uint8_t { Unset, DoT, DoH };
+
+ TLSFrontend(ALPN alpn): d_alpn(alpn)
{
}
@@ -223,7 +222,9 @@ public:
TLSErrorCounters d_tlsCounters;
ComboAddress d_addr;
std::string d_provider;
-
+ ALPN d_alpn{ALPN::Unset};
+ /* whether the proxy protocol is inside or outside the TLS layer */
+ bool d_proxyProtocolOutsideTLS{false};
protected:
std::shared_ptr<TLSCtx> d_ctx{nullptr};
};
@@ -231,16 +232,16 @@ protected:
class TCPIOHandler
{
public:
- enum class Type : uint8_t { Client, Server };
-
- TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx): d_socket(socket)
+ TCPIOHandler(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx) :
+ d_socket(socket)
{
if (ctx) {
d_conn = ctx->getClientConnection(host, hostIsAddr, d_socket, timeout);
}
}
- TCPIOHandler(int socket, const struct timeval& timeout, std::shared_ptr<TLSCtx> ctx, time_t now): d_socket(socket)
+ TCPIOHandler(int socket, const struct timeval& timeout, const std::shared_ptr<TLSCtx>& ctx, time_t now) :
+ d_socket(socket)
{
if (ctx) {
d_conn = ctx->getConnection(d_socket, timeout, now);
@@ -364,13 +365,13 @@ public:
return Done when toRead bytes have been read, needRead or needWrite if the IO operation
would block.
*/
- IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false)
+ IOState tryRead(PacketBuffer& buffer, size_t& pos, size_t toRead, bool allowIncomplete=false, bool bypassFilters=false)
{
if (buffer.size() < toRead || pos >= toRead) {
throw std::out_of_range("Calling tryRead() with a too small buffer (" + std::to_string(buffer.size()) + ") for a read of " + std::to_string(toRead - pos) + " bytes starting at " + std::to_string(pos));
}
- if (d_conn) {
+ if (!bypassFilters && d_conn) {
return d_conn->tryRead(buffer, pos, toRead, allowIncomplete);
}
@@ -473,14 +474,6 @@ public:
return writen2WithTimeout(d_socket, buffer, bufferSize, writeTimeout);
}
- bool hasBufferedData() const
- {
- if (d_conn) {
- return d_conn->hasBufferedData();
- }
- return false;
- }
-
std::string getServerNameIndication() const
{
if (d_conn) {
@@ -582,3 +575,4 @@ struct TLSContextParameters
std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params);
bool setupDoTProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);
+bool setupDoHProtocolNegotiation(std::shared_ptr<TLSCtx>& ctx);