summaryrefslogtreecommitdiffstats
path: root/dnsdist-tcp-upstream.hh
diff options
context:
space:
mode:
Diffstat (limited to 'dnsdist-tcp-upstream.hh')
-rw-r--r--dnsdist-tcp-upstream.hh74
1 files changed, 57 insertions, 17 deletions
diff --git a/dnsdist-tcp-upstream.hh b/dnsdist-tcp-upstream.hh
index 59c4df4..c6410df 100644
--- a/dnsdist-tcp-upstream.hh
+++ b/dnsdist-tcp-upstream.hh
@@ -2,6 +2,9 @@
#include "dolog.hh"
#include "dnsdist-tcp.hh"
+#include "dnsdist-tcp-downstream.hh"
+
+struct TCPCrossProtocolResponse;
class TCPClientThreadData
{
@@ -15,13 +18,19 @@ public:
LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions;
LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions;
std::unique_ptr<FDMultiplexer> mplexer{nullptr};
- int crossProtocolResponsesPipe{-1};
+ pdns::channel::Receiver<ConnectionInfo> queryReceiver;
+ pdns::channel::Receiver<CrossProtocolQuery> crossProtocolQueryReceiver;
+ pdns::channel::Receiver<TCPCrossProtocolResponse> crossProtocolResponseReceiver;
+ pdns::channel::Sender<TCPCrossProtocolResponse> crossProtocolResponseSender;
};
class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
{
public:
- IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
+ enum class QueryProcessingResult : uint8_t { Forwarded, TooSmall, InvalidHeaders, Dropped, SelfAnswered, NoBackend, Asynchronous };
+ enum class ProxyProtocolResult : uint8_t { Reading, Done, Error };
+
+ IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(sizeof(uint16_t)), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext.getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
{
d_origDest.reset();
d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
@@ -41,7 +50,7 @@ public:
IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
- ~IncomingTCPConnectionState();
+ virtual ~IncomingTCPConnectionState();
void resetForNewQuery();
@@ -107,30 +116,34 @@ public:
return false;
}
- std::shared_ptr<TCPConnectionToBackend> getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs);
- std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now);
+ std::shared_ptr<TCPConnectionToBackend> getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs);
+ std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now);
void registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn);
static size_t clearAllDownstreamConnections();
- static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now);
- static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
- static void handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param);
+ static void handleIOCallback(int desc, FDMultiplexer::funcparam_t& param);
+ static void handleAsyncReady(int desc, FDMultiplexer::funcparam_t& param);
static void updateIO(std::shared_ptr<IncomingTCPConnectionState>& state, IOState newState, const struct timeval& now);
- static IOState sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
- static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
-static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
+ static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response, bool fromBackend);
+ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
- /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */
- void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+ virtual void handleIO();
+
+ QueryProcessingResult handleQuery(PacketBuffer&& query, const struct timeval& now, std::optional<int32_t> streamID);
+ virtual void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+ virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) override;
void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
- void notifyIOError(InternalQueryState&& query, const struct timeval& now) override;
+ virtual IOState sendResponse(const struct timeval& now, TCPResponse&& response);
+ void handleResponseSent(TCPResponse& currentResponse);
+ virtual IOState handleHandshake(const struct timeval& now);
+ void handleHandshakeDone(const struct timeval& now);
+ ProxyProtocolResult handleProxyProtocolPayload();
void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);
void terminateClientConnection();
- void queueQuery(TCPQuery&& query);
bool canAcceptNewQueries(const struct timeval& now);
@@ -138,6 +151,28 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
{
return d_ioState != nullptr;
}
+ bool isProxyPayloadOutsideTLS() const
+ {
+ if (!d_ci.cs->hasTLS()) {
+ return false;
+ }
+ return d_ci.cs->getTLSFrontend().d_proxyProtocolOutsideTLS;
+ }
+
+ virtual bool forwardViaUDPFirst() const
+ {
+ return false;
+ }
+ virtual std::unique_ptr<DOHUnitInterface> getDOHUnit(uint32_t streamID)
+ {
+ throw std::runtime_error("Getting a DOHUnit state from a generic TCP/DoT connection is not supported");
+ }
+ virtual void restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&&)
+ {
+ throw std::runtime_error("Restoring a DOHUnit state to a generic TCP/DoT connection is not supported");
+ }
+
+ std::unique_ptr<CrossProtocolQuery> getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& backend);
std::string toString() const
{
@@ -146,7 +181,12 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
return o.str();
}
- enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
+ dnsdist::Protocol getProtocol() const;
+ IOState handleIncomingQueryReceived(const struct timeval& now);
+ void handleExceptionDuringIO(const std::exception& exp);
+ bool readIncomingQuery(const timeval& now, IOState& iostate);
+
+ enum class State : uint8_t { starting, doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
TCPResponse d_currentResponse;
std::map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> d_ownedConnectionsToBackend;
@@ -171,7 +211,7 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo
size_t d_currentQueriesCount{0};
std::thread::id d_creatorThreadID;
uint16_t d_querySize{0};
- State d_state{State::doingHandshake};
+ State d_state{State::starting};
bool d_isXFR{false};
bool d_proxyProtocolPayloadHasTLV{false};
bool d_lastIOBlocked{false};