diff options
Diffstat (limited to 'dnsdist-tcp-upstream.hh')
-rw-r--r-- | dnsdist-tcp-upstream.hh | 74 |
1 files changed, 57 insertions, 17 deletions
diff --git a/dnsdist-tcp-upstream.hh b/dnsdist-tcp-upstream.hh index 59c4df4..c6410df 100644 --- a/dnsdist-tcp-upstream.hh +++ b/dnsdist-tcp-upstream.hh @@ -2,6 +2,9 @@ #include "dolog.hh" #include "dnsdist-tcp.hh" +#include "dnsdist-tcp-downstream.hh" + +struct TCPCrossProtocolResponse; class TCPClientThreadData { @@ -15,13 +18,19 @@ public: LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions; LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions; std::unique_ptr<FDMultiplexer> mplexer{nullptr}; - int crossProtocolResponsesPipe{-1}; + pdns::channel::Receiver<ConnectionInfo> queryReceiver; + pdns::channel::Receiver<CrossProtocolQuery> crossProtocolQueryReceiver; + pdns::channel::Receiver<TCPCrossProtocolResponse> crossProtocolResponseReceiver; + pdns::channel::Sender<TCPCrossProtocolResponse> crossProtocolResponseSender; }; class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState> { public: - IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id()) + enum class QueryProcessingResult : uint8_t { Forwarded, TooSmall, InvalidHeaders, Dropped, SelfAnswered, NoBackend, Asynchronous }; + enum class ProxyProtocolResult : uint8_t { Reading, Done, Error }; + + IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(sizeof(uint16_t)), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : (d_ci.cs->dohFrontend ? d_ci.cs->dohFrontend->d_tlsContext.getContext() : nullptr), now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id()) { d_origDest.reset(); d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family; @@ -41,7 +50,7 @@ public: IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete; IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete; - ~IncomingTCPConnectionState(); + virtual ~IncomingTCPConnectionState(); void resetForNewQuery(); @@ -107,30 +116,34 @@ public: return false; } - std::shared_ptr<TCPConnectionToBackend> getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs); - std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now); + std::shared_ptr<TCPConnectionToBackend> getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs); + std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now); void registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn); static size_t clearAllDownstreamConnections(); - static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now); - static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); - static void handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param); + static void handleIOCallback(int desc, FDMultiplexer::funcparam_t& param); + static void handleAsyncReady(int desc, FDMultiplexer::funcparam_t& param); static void updateIO(std::shared_ptr<IncomingTCPConnectionState>& state, IOState newState, const struct timeval& now); - static IOState sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response); - static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response); -static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write); + static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response, bool fromBackend); + static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write); - /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */ - void handleResponse(const struct timeval& now, TCPResponse&& response) override; + virtual void handleIO(); + + QueryProcessingResult handleQuery(PacketBuffer&& query, const struct timeval& now, std::optional<int32_t> streamID); + virtual void handleResponse(const struct timeval& now, TCPResponse&& response) override; + virtual void notifyIOError(const struct timeval& now, TCPResponse&& response) override; void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override; - void notifyIOError(InternalQueryState&& query, const struct timeval& now) override; + virtual IOState sendResponse(const struct timeval& now, TCPResponse&& response); + void handleResponseSent(TCPResponse& currentResponse); + virtual IOState handleHandshake(const struct timeval& now); + void handleHandshakeDone(const struct timeval& now); + ProxyProtocolResult handleProxyProtocolPayload(); void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response); void terminateClientConnection(); - void queueQuery(TCPQuery&& query); bool canAcceptNewQueries(const struct timeval& now); @@ -138,6 +151,28 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo { return d_ioState != nullptr; } + bool isProxyPayloadOutsideTLS() const + { + if (!d_ci.cs->hasTLS()) { + return false; + } + return d_ci.cs->getTLSFrontend().d_proxyProtocolOutsideTLS; + } + + virtual bool forwardViaUDPFirst() const + { + return false; + } + virtual std::unique_ptr<DOHUnitInterface> getDOHUnit(uint32_t streamID) + { + throw std::runtime_error("Getting a DOHUnit state from a generic TCP/DoT connection is not supported"); + } + virtual void restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&&) + { + throw std::runtime_error("Restoring a DOHUnit state to a generic TCP/DoT connection is not supported"); + } + + std::unique_ptr<CrossProtocolQuery> getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& backend); std::string toString() const { @@ -146,7 +181,12 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo return o.str(); } - enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ }; + dnsdist::Protocol getProtocol() const; + IOState handleIncomingQueryReceived(const struct timeval& now); + void handleExceptionDuringIO(const std::exception& exp); + bool readIncomingQuery(const timeval& now, IOState& iostate); + + enum class State : uint8_t { starting, doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ }; TCPResponse d_currentResponse; std::map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> d_ownedConnectionsToBackend; @@ -171,7 +211,7 @@ static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bo size_t d_currentQueriesCount{0}; std::thread::id d_creatorThreadID; uint16_t d_querySize{0}; - State d_state{State::doingHandshake}; + State d_state{State::starting}; bool d_isXFR{false}; bool d_proxyProtocolPayloadHasTLV{false}; bool d_lastIOBlocked{false}; |