summaryrefslogtreecommitdiffstats
path: root/dnsdist-tcp-upstream.hh
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--dnsdist-tcp-upstream.hh179
1 files changed, 179 insertions, 0 deletions
diff --git a/dnsdist-tcp-upstream.hh b/dnsdist-tcp-upstream.hh
new file mode 100644
index 0000000..59c4df4
--- /dev/null
+++ b/dnsdist-tcp-upstream.hh
@@ -0,0 +1,179 @@
+#pragma once
+
+#include "dolog.hh"
+#include "dnsdist-tcp.hh"
+
+class TCPClientThreadData
+{
+public:
+ TCPClientThreadData():
+ localRespRuleActions(g_respruleactions.getLocal()), localCacheInsertedRespRuleActions(g_cacheInsertedRespRuleActions.getLocal()), mplexer(std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent()))
+ {
+ }
+
+ LocalHolders holders;
+ LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions;
+ LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions;
+ std::unique_ptr<FDMultiplexer> mplexer{nullptr};
+ int crossProtocolResponsesPipe{-1};
+};
+
+class IncomingTCPConnectionState : public TCPQuerySender, public std::enable_shared_from_this<IncomingTCPConnectionState>
+{
+public:
+ IncomingTCPConnectionState(ConnectionInfo&& ci, TCPClientThreadData& threadData, const struct timeval& now): d_buffer(s_maxPacketCacheEntrySize), d_ci(std::move(ci)), d_handler(d_ci.fd, timeval{g_tcpRecvTimeout,0}, d_ci.cs->tlsFrontend ? d_ci.cs->tlsFrontend->getContext() : nullptr, now.tv_sec), d_connectionStartTime(now), d_ioState(make_unique<IOStateHandler>(*threadData.mplexer, d_ci.fd)), d_threadData(threadData), d_creatorThreadID(std::this_thread::get_id())
+ {
+ d_origDest.reset();
+ d_origDest.sin4.sin_family = d_ci.remote.sin4.sin_family;
+ socklen_t socklen = d_origDest.getSocklen();
+ if (getsockname(d_ci.fd, reinterpret_cast<sockaddr*>(&d_origDest), &socklen)) {
+ d_origDest = d_ci.cs->local;
+ }
+ /* belongs to the handler now */
+ d_ci.fd = -1;
+ d_proxiedDestination = d_origDest;
+ d_proxiedRemote = d_ci.remote;
+
+ /* we manage the release of the downstream connection ourselves */
+ d_releaseConnection = false;
+ }
+
+ IncomingTCPConnectionState(const IncomingTCPConnectionState& rhs) = delete;
+ IncomingTCPConnectionState& operator=(const IncomingTCPConnectionState& rhs) = delete;
+
+ ~IncomingTCPConnectionState();
+
+ void resetForNewQuery();
+
+ boost::optional<struct timeval> getClientReadTTD(struct timeval now) const
+ {
+ if (g_maxTCPConnectionDuration == 0 && g_tcpRecvTimeout == 0) {
+ return boost::none;
+ }
+
+ if (g_maxTCPConnectionDuration > 0) {
+ auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
+ if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
+ return now;
+ }
+ auto remaining = g_maxTCPConnectionDuration - elapsed;
+ if (g_tcpRecvTimeout == 0 || remaining <= static_cast<size_t>(g_tcpRecvTimeout)) {
+ now.tv_sec += remaining;
+ return now;
+ }
+ }
+
+ now.tv_sec += g_tcpRecvTimeout;
+ return now;
+ }
+
+ boost::optional<struct timeval> getClientWriteTTD(const struct timeval& now) const
+ {
+ if (g_maxTCPConnectionDuration == 0 && g_tcpSendTimeout == 0) {
+ return boost::none;
+ }
+
+ struct timeval res = now;
+
+ if (g_maxTCPConnectionDuration > 0) {
+ auto elapsed = res.tv_sec - d_connectionStartTime.tv_sec;
+ if (elapsed < 0 || static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration) {
+ return res;
+ }
+ auto remaining = g_maxTCPConnectionDuration - elapsed;
+ if (g_tcpSendTimeout == 0 || remaining <= static_cast<size_t>(g_tcpSendTimeout)) {
+ res.tv_sec += remaining;
+ return res;
+ }
+ }
+
+ res.tv_sec += g_tcpSendTimeout;
+ return res;
+ }
+
+ bool maxConnectionDurationReached(unsigned int maxConnectionDuration, const struct timeval& now)
+ {
+ if (maxConnectionDuration) {
+ time_t curtime = now.tv_sec;
+ unsigned int elapsed = 0;
+ if (curtime > d_connectionStartTime.tv_sec) { // To prevent issues when time goes backward
+ elapsed = curtime - d_connectionStartTime.tv_sec;
+ }
+ if (elapsed >= maxConnectionDuration) {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ std::shared_ptr<TCPConnectionToBackend> getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs);
+ std::shared_ptr<TCPConnectionToBackend> getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now);
+ void registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn);
+
+ static size_t clearAllDownstreamConnections();
+
+ static void handleIO(std::shared_ptr<IncomingTCPConnectionState>& conn, const struct timeval& now);
+ static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param);
+ static void handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param);
+ static void updateIO(std::shared_ptr<IncomingTCPConnectionState>& state, IOState newState, const struct timeval& now);
+
+ static IOState sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
+ static void queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response);
+static void handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write);
+
+ /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */
+ void handleResponse(const struct timeval& now, TCPResponse&& response) override;
+ void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override;
+ void notifyIOError(InternalQueryState&& query, const struct timeval& now) override;
+
+ void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response);
+
+ void terminateClientConnection();
+ void queueQuery(TCPQuery&& query);
+
+ bool canAcceptNewQueries(const struct timeval& now);
+
+ bool active() const override
+ {
+ return d_ioState != nullptr;
+ }
+
+ std::string toString() const
+ {
+ ostringstream o;
+ o << "Incoming TCP connection from "<<d_ci.remote.toStringWithPort()<<" over FD "<<d_handler.getDescriptor()<<", state is "<<(int)d_state<<", io state is "<<(d_ioState ? d_ioState->getState() : "empty")<<", queries count is "<<d_queriesCount<<", current queries count is "<<d_currentQueriesCount<<", "<<d_queuedResponses.size()<<" queued responses, "<<d_ownedConnectionsToBackend.size()<<" owned connections to a backend";
+ return o.str();
+ }
+
+ enum class State : uint8_t { doingHandshake, readingProxyProtocolHeader, waitingForQuery, readingQuerySize, readingQuery, sendingResponse, idle /* in case of XFR, we stop processing queries */ };
+
+ TCPResponse d_currentResponse;
+ std::map<std::shared_ptr<DownstreamState>, std::deque<std::shared_ptr<TCPConnectionToBackend>>> d_ownedConnectionsToBackend;
+ std::deque<TCPResponse> d_queuedResponses;
+ PacketBuffer d_buffer;
+ ConnectionInfo d_ci;
+ ComboAddress d_origDest;
+ ComboAddress d_proxiedRemote;
+ ComboAddress d_proxiedDestination;
+ TCPIOHandler d_handler;
+ struct timeval d_connectionStartTime;
+ struct timeval d_handshakeDoneTime;
+ struct timeval d_firstQuerySizeReadTime;
+ struct timeval d_querySizeReadTime;
+ struct timeval d_queryReadTime;
+ std::unique_ptr<IOStateHandler> d_ioState{nullptr};
+ std::unique_ptr<std::vector<ProxyProtocolValue>> d_proxyProtocolValues{nullptr};
+ TCPClientThreadData& d_threadData;
+ size_t d_currentPos{0};
+ size_t d_proxyProtocolNeed{0};
+ size_t d_queriesCount{0};
+ size_t d_currentQueriesCount{0};
+ std::thread::id d_creatorThreadID;
+ uint16_t d_querySize{0};
+ State d_state{State::doingHandshake};
+ bool d_isXFR{false};
+ bool d_proxyProtocolPayloadHasTLV{false};
+ bool d_lastIOBlocked{false};
+ bool d_hadErrors{false};
+};