Diffstat
-rw-r--r-- | dnsdist-tcp.cc | 1541
1 file changed, 1541 insertions, 0 deletions
diff --git a/dnsdist-tcp.cc b/dnsdist-tcp.cc
new file mode 100644
index 0000000..b927cbe
--- /dev/null
+++ b/dnsdist-tcp.cc
@@ -0,0 +1,1541 @@
/*
 * This file is part of PowerDNS or dnsdist.
 * Copyright -- PowerDNS.COM B.V. and its contributors
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of version 2 of the GNU General Public License as
 * published by the Free Software Foundation.
 *
 * In addition, for the avoidance of any doubt, permission is granted to
 * link this program with OpenSSL and to (re)distribute the binaries
 * produced as the result of such linking.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

#include <thread>
#include <netinet/tcp.h>
#include <queue>

#include "dnsdist.hh"
#include "dnsdist-concurrent-connections.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsdist-rings.hh"
#include "dnsdist-tcp.hh"
#include "dnsdist-tcp-downstream.hh"
#include "dnsdist-downstream-connection.hh"
#include "dnsdist-tcp-upstream.hh"
#include "dnsdist-xpf.hh"
#include "dnsparser.hh"
#include "dolog.hh"
#include "gettime.hh"
#include "lock.hh"
#include "sstuff.hh"
#include "tcpiohandler.hh"
#include "tcpiohandler-mplexer.hh"
#include "threadname.hh"

/* TCP: the grand design.
   We forward 'messages' between clients and downstream servers. Messages are 65k bytes large, tops.
   An answer might theoretically consist of multiple messages (for example, in the case of AXFR), but
   initially we will not go there.

   In a sense there is a strong symmetry between UDP and TCP, once a connection to a downstream has been set up.
   This symmetry is broken because of head-of-line blocking within TCP, though, necessitating additional connections
   to guarantee performance.

   So the idea is to have a 'pool' of available downstream connections, and to forward messages to/from them without
   ever queueing. That way, whenever an answer comes in, we know where it needs to go.

   Let's start naively.
*/
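The 'pool' described above is implemented further down by t_downstreamTCPConnectionsManager together with the per-connection map of 'owned' connections. As a rough illustration of the reuse-or-create pattern only, here is a minimal sketch with hypothetical names (Backend, Connection, ConnectionPool); these are not the real classes used in this file:

    #include <deque>
    #include <map>
    #include <memory>
    #include <string>

    struct Backend { std::string name; };

    struct Connection
    {
      explicit Connection(std::shared_ptr<Backend> backend): d_backend(std::move(backend)) {}
      bool canBeReused() const { return d_usable; }
      std::shared_ptr<Backend> d_backend;
      bool d_usable{true};
    };

    class ConnectionPool
    {
    public:
      // Reuse an idle connection to this backend if we have one, otherwise create a fresh one.
      std::shared_ptr<Connection> get(const std::shared_ptr<Backend>& backend)
      {
        auto& idle = d_idle[backend->name];
        while (!idle.empty()) {
          auto conn = std::move(idle.front());
          idle.pop_front();
          if (conn->canBeReused()) {
            return conn;
          }
          // otherwise drop it and keep looking
        }
        return std::make_shared<Connection>(backend);
      }

      // Called once a connection has no more queries in flight.
      void release(std::shared_ptr<Connection> conn)
      {
        if (conn->canBeReused()) {
          d_idle[conn->d_backend->name].push_back(std::move(conn));
        }
      }

    private:
      std::map<std::string, std::deque<std::shared_ptr<Connection>>> d_idle;
    };

In the real code each worker thread owns its own manager (it is thread-local), so no locking is needed on this path.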
size_t g_maxTCPQueriesPerConn{0};
size_t g_maxTCPConnectionDuration{0};

#ifdef __linux__
// On Linux this gives us 128k pending queries (default is 8192 queries),
// which should be enough to deal with huge spikes.
// (Each pending query is passed as a single pointer, so on a 64-bit system a
// 1 MB pipe buffer holds 1048576 / 8 = 131072 of them, versus 65536 / 8 = 8192
// for the default 64 kB pipe buffer.)
size_t g_tcpInternalPipeBufferSize{1024*1024};
uint64_t g_maxTCPQueuedConnections{10000};
#else
size_t g_tcpInternalPipeBufferSize{0};
uint64_t g_maxTCPQueuedConnections{1000};
#endif

int g_tcpRecvTimeout{2};
int g_tcpSendTimeout{2};
std::atomic<uint64_t> g_tcpStatesDumpRequested{0};

LockGuarded<std::map<ComboAddress, size_t, ComboAddress::addressOnlyLessThan>> dnsdist::IncomingConcurrentTCPConnectionsManager::s_tcpClientsConcurrentConnectionsCount;
size_t dnsdist::IncomingConcurrentTCPConnectionsManager::s_maxTCPConnectionsPerClient = 0;

IncomingTCPConnectionState::~IncomingTCPConnectionState()
{
  dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(d_ci.remote);

  if (d_ci.cs != nullptr) {
    struct timeval now;
    gettimeofday(&now, nullptr);

    auto diff = now - d_connectionStartTime;
    d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0);
  }

  // this would have been done when the object is destroyed anyway,
  // but that way we make sure it's done before the ConnectionInfo is destroyed,
  // closing the descriptor, instead of relying on the declaration order of the objects in the class
  d_handler.close();
}

size_t IncomingTCPConnectionState::clearAllDownstreamConnections()
{
  return t_downstreamTCPConnectionsManager.clear();
}

std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now)
{
  std::shared_ptr<TCPConnectionToBackend> downstream{nullptr};

  downstream = getOwnedDownstreamConnection(ds, tlvs);

  if (!downstream) {
    /* we don't own a connection to this backend yet, so let's get one (it might not be a fresh one, though) */
    downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, ds, now, std::string());
    if (ds->d_config.useProxyProtocol) {
      registerOwnedDownstreamConnection(downstream);
    }
  }

  return downstream;
}

static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector<ClientState*> tcpAcceptStates);

TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientState*> tcpAcceptStates): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads)
{
  for (size_t idx = 0; idx < maxThreads; idx++) {
    addTCPClientThread(tcpAcceptStates);
  }
}
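The preparePipe lambda in addTCPClientThread() below grows the pipe buffer via the getPipeBufferSize()/setPipeBufferSize() helpers, which are defined elsewhere in dnsdist. On Linux these boil down to the F_GETPIPE_SZ/F_SETPIPE_SZ fcntl(2) commands; a minimal standalone sketch of that mechanism, without dnsdist's logging:

    #include <cstdio>
    #include <fcntl.h>
    #include <unistd.h>

    // Grow a pipe's kernel buffer (Linux only). Returns the resulting size, or -1.
    static int growPipeBuffer(int fd, int wanted)
    {
    #ifdef F_SETPIPE_SZ
      int current = fcntl(fd, F_GETPIPE_SZ);
      if (current >= wanted) {
        return current;
      }
      // The kernel rounds the size up, and caps it at /proc/sys/fs/pipe-max-size
      // for unprivileged processes, so this can fail with EPERM.
      return fcntl(fd, F_SETPIPE_SZ, wanted);
    #else
      (void)fd; (void)wanted;
      return -1;
    #endif
    }

    int main()
    {
      int fds[2];
      if (pipe(fds) == 0) {
        printf("pipe buffer is now %d bytes\n", growPipeBuffer(fds[0], 1024 * 1024));
        close(fds[0]);
        close(fds[1]);
      }
      return 0;
    }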
void TCPClientCollection::addTCPClientThread(std::vector<ClientState*>& tcpAcceptStates)
{
  auto preparePipe = [](int fds[2], const std::string& type) -> bool {
    if (pipe(fds) < 0) {
      errlog("Error creating the TCP thread %s pipe: %s", type, stringerror());
      return false;
    }

    if (!setNonBlocking(fds[0])) {
      int err = errno;
      close(fds[0]);
      close(fds[1]);
      errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err));
      return false;
    }

    if (!setNonBlocking(fds[1])) {
      int err = errno;
      close(fds[0]);
      close(fds[1]);
      errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err));
      return false;
    }

    if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) {
      setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize);
    }

    return true;
  };

  int pipefds[2] = { -1, -1};
  if (!preparePipe(pipefds, "communication")) {
    return;
  }

  int crossProtocolQueriesFDs[2] = { -1, -1};
  if (!preparePipe(crossProtocolQueriesFDs, "cross-protocol queries")) {
    return;
  }

  int crossProtocolResponsesFDs[2] = { -1, -1};
  if (!preparePipe(crossProtocolResponsesFDs, "cross-protocol responses")) {
    return;
  }

  vinfolog("Adding TCP Client thread");

  {
    if (d_numthreads >= d_tcpclientthreads.size()) {
      vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum number of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size());
      close(crossProtocolQueriesFDs[0]);
      close(crossProtocolQueriesFDs[1]);
      close(crossProtocolResponsesFDs[0]);
      close(crossProtocolResponsesFDs[1]);
      close(pipefds[0]);
      close(pipefds[1]);
      return;
    }

    /* from now on this side of the pipe will be managed by that object,
       no need to worry about it */
    TCPWorkerThread worker(pipefds[1], crossProtocolQueriesFDs[1], crossProtocolResponsesFDs[1]);
    try {
      std::thread t1(tcpClientThread, pipefds[0], crossProtocolQueriesFDs[0], crossProtocolResponsesFDs[0], crossProtocolResponsesFDs[1], tcpAcceptStates);
      t1.detach();
    }
    catch (const std::runtime_error& e) {
      /* the thread creation failed, don't leak */
      errlog("Error creating a TCP thread: %s", e.what());
      close(pipefds[0]);
      close(crossProtocolQueriesFDs[0]);
      close(crossProtocolResponsesFDs[0]);
      return;
    }

    d_tcpclientthreads.at(d_numthreads) = std::move(worker);
    ++d_numthreads;
  }
}

std::unique_ptr<TCPClientCollection> g_tcpclientthreads;

static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
  IOState result = IOState::Done;

  while (state->active() && !state->d_queuedResponses.empty()) {
    DEBUGLOG("queue size is "<<state->d_queuedResponses.size()<<", sending the next one");
    TCPResponse resp = std::move(state->d_queuedResponses.front());
    state->d_queuedResponses.pop_front();
    state->d_state = IncomingTCPConnectionState::State::idle;
    result = state->sendResponse(state, now, std::move(resp));
    if (result != IOState::Done) {
      return result;
    }
  }

  state->d_state = IncomingTCPConnectionState::State::idle;
  return IOState::Done;
}

static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, TCPResponse& currentResponse)
{
  if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) {
    return;
  }

  --state->d_currentQueriesCount;

  const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds;
  if (currentResponse.d_idstate.selfGenerated == false && ds) {
    const auto& ids = currentResponse.d_idstate;
    double udiff = ids.queryRealTime.udiff();
    vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), currentResponse.d_buffer.size(), udiff);

    auto backendProtocol = ds->getProtocol();
    if (backendProtocol == dnsdist::Protocol::DoUDP) {
      backendProtocol = dnsdist::Protocol::DoTCP;
    }
    ::handleResponseSent(ids, udiff, state->d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true);
  } else {
    const auto& ids = currentResponse.d_idstate;
    ::handleResponseSent(ids, 0., state->d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false);
  }

  currentResponse.d_buffer.clear();
  currentResponse.d_connection.reset();
}

static void prependSizeToTCPQuery(PacketBuffer& buffer, size_t proxyProtocolPayloadSize)
{
  if (buffer.size() <= proxyProtocolPayloadSize) {
    throw std::runtime_error("The proxy protocol payload size is not smaller than the buffer size");
  }

  uint16_t queryLen = proxyProtocolPayloadSize > 0 ? (buffer.size() - proxyProtocolPayloadSize) : buffer.size();
  const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256) };
  /* prepend the size. Yes, this is not the most efficient way, but it prevents mistakes
     that could occur if we had to deal with the size during the processing,
     especially alignment issues */
  buffer.insert(buffer.begin() + proxyProtocolPayloadSize, sizeBytes, sizeBytes + 2);
}
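DNS over TCP (RFC 1035, section 4.2.2) prefixes every message with a two-byte length field in network byte order; that is what prependSizeToTCPQuery() above and the query-size parsing in handleIO() below implement. A standalone example of the arithmetic, for a hypothetical 300-byte query:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main()
    {
      const uint16_t queryLen = 300; // 300 == 0x012C

      // Writing the prefix, as prependSizeToTCPQuery() does:
      const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256),   // 0x01, high byte first
                                    static_cast<uint8_t>(queryLen % 256) }; // 0x2C
      assert(sizeBytes[0] == 0x01 && sizeBytes[1] == 0x2C);

      // Reading it back, as the readingQuerySize state does:
      std::vector<uint8_t> buffer(sizeBytes, sizeBytes + 2);
      uint16_t parsed = buffer.at(0) * 256 + buffer.at(1);
      assert(parsed == queryLen);
      return 0;
    }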
"DoT" : "TCP"), currentResponse.d_buffer.size(), udiff); + + auto backendProtocol = ds->getProtocol(); + if (backendProtocol == dnsdist::Protocol::DoUDP) { + backendProtocol = dnsdist::Protocol::DoTCP; + } + ::handleResponseSent(ids, udiff, state->d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true); + } else { + const auto& ids = currentResponse.d_idstate; + ::handleResponseSent(ids, 0., state->d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false); + } + + currentResponse.d_buffer.clear(); + currentResponse.d_connection.reset(); +} + +static void prependSizeToTCPQuery(PacketBuffer& buffer, size_t proxyProtocolPayloadSize) +{ + if (buffer.size() <= proxyProtocolPayloadSize) { + throw std::runtime_error("The payload size is smaller or equal to the buffer size"); + } + + uint16_t queryLen = proxyProtocolPayloadSize > 0 ? (buffer.size() - proxyProtocolPayloadSize) : buffer.size(); + const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256) }; + /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes + that could occur if we had to deal with the size during the processing, + especially alignment issues */ + buffer.insert(buffer.begin() + proxyProtocolPayloadSize, sizeBytes, sizeBytes + 2); +} + +bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) +{ + if (d_hadErrors) { + DEBUGLOG("not accepting new queries because we encountered some error during the processing already"); + return false; + } + + if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) { + DEBUGLOG("not accepting new queries because we already have "<<d_currentQueriesCount<<" out of "<<d_ci.cs->d_maxInFlightQueriesPerConn); + return false; + } + + if (g_maxTCPQueriesPerConn && d_queriesCount > g_maxTCPQueriesPerConn) { + vinfolog("not accepting new queries from %s because it reached the maximum number of queries per conn (%d / %d)", d_ci.remote.toStringWithPort(), d_queriesCount, g_maxTCPQueriesPerConn); + return false; + } + + if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("not accepting new queries from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort()); + return false; + } + + return true; +} + +void IncomingTCPConnectionState::resetForNewQuery() +{ + d_buffer.resize(sizeof(uint16_t)); + d_currentPos = 0; + d_querySize = 0; + d_state = State::waitingForQuery; +} + +std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) +{ + auto it = d_ownedConnectionsToBackend.find(ds); + if (it == d_ownedConnectionsToBackend.end()) { + DEBUGLOG("no owned connection found for "<<ds->getName()); + return nullptr; + } + + for (auto& conn : it->second) { + if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) { + DEBUGLOG("Got one owned connection accepting more for "<<ds->getName()); + conn->setReused(); + return conn; + } + DEBUGLOG("not accepting more for "<<ds->getName()); + } + + return nullptr; +} + +void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_ptr<TCPConnectionToBackend>& conn) +{ + d_ownedConnectionsToBackend[conn->getDS()].push_front(conn); +} + +/* called when the buffer has been set and 
the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */ +IOState IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response) +{ + state->d_state = IncomingTCPConnectionState::State::sendingResponse; + + uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size()); + const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) }; + /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes + that could occur if we had to deal with the size during the processing, + especially alignment issues */ + response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2); + state->d_currentPos = 0; + state->d_currentResponse = std::move(response); + + try { + auto iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size()); + if (iostate == IOState::Done) { + DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); + handleResponseSent(state, state->d_currentResponse); + return iostate; + } else { + state->d_lastIOBlocked = true; + DEBUGLOG("partial write"); + return iostate; + } + } + catch (const std::exception& e) { + vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what()); + DEBUGLOG("Closing TCP client connection: "<<e.what()); + ++state->d_ci.cs->tcpDiedSendingResponse; + + state->terminateClientConnection(); + + return IOState::Done; + } +} + +void IncomingTCPConnectionState::terminateClientConnection() +{ + DEBUGLOG("terminating client connection"); + d_queuedResponses.clear(); + /* we have already released idle connections that could be reused, + we don't care about the ones still waiting for responses */ + for (auto& backend : d_ownedConnectionsToBackend) { + for (auto& conn : backend.second) { + conn->release(); + } + } + d_ownedConnectionsToBackend.clear(); + + /* meaning we will no longer be 'active' when the backend + response or timeout comes in */ + d_ioState.reset(); + + /* if we do have remaining async descriptors associated with this TLS + connection, we need to defer the destruction of the TLS object until + the engine has reported back, otherwise we have a use-after-free.. */ + auto afds = d_handler.getAsyncFDs(); + if (afds.empty()) { + d_handler.close(); + } + else { + /* we might already be waiting, but we might also not because sometimes we have already been + notified via the descriptor, not received Async again, but the async job still exists.. */ + auto state = shared_from_this(); + for (const auto fd : afds) { + try { + state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state); + } + catch (...) { + } + } + + } +} + +void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response) +{ + // queue response + state->d_queuedResponses.push_back(std::move(response)); + DEBUGLOG("queueing response, state is "<<(int)state->d_state<<", queue size is now "<<state->d_queuedResponses.size()); + + // when the response comes from a backend, there is a real possibility that we are currently + // idle, and thus not trying to send the response right away would make our ref count go to 0. 
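queueResponse() below is what makes out-of-order pipelining work: several queries can be in flight on one client connection, and responses are queued and written back in whatever order they arrive. What lets the client reassociate them is the 16-bit DNS message ID, which this file saves per query (ids.origID) and restores in the response. A toy illustration of that bookkeeping, with hypothetical names rather than the actual structures used here:

    #include <cstdint>
    #include <map>
    #include <optional>

    // Toy bookkeeping for pipelined DNS over TCP: remember which client-visible
    // ID corresponds to which in-flight query, and recover it when the answer
    // comes back, whatever order the backend answers in.
    class PendingQueries
    {
    public:
      // Called when forwarding: stash the client's ID, keyed by our internal one.
      void recordQuery(uint16_t internalID, uint16_t clientID)
      {
        d_pending[internalID] = clientID;
      }

      // Called when a response arrives: recover the client's ID, if we know it.
      std::optional<uint16_t> matchResponse(uint16_t internalID)
      {
        auto it = d_pending.find(internalID);
        if (it == d_pending.end()) {
          return std::nullopt; // unexpected response, drop it
        }
        uint16_t clientID = it->second;
        d_pending.erase(it);
        return clientID;
      }

    private:
      std::map<uint16_t, uint16_t> d_pending;
    };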
void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response)
{
  // queue response
  state->d_queuedResponses.push_back(std::move(response));
  DEBUGLOG("queueing response, state is "<<(int)state->d_state<<", queue size is now "<<state->d_queuedResponses.size());

  // when the response comes from a backend, there is a real possibility that we are currently
  // idle, and thus not trying to send the response right away would make our ref count go to 0.
  // Even if we are waiting for a query, we will not wake up before the new query arrives or a
  // timeout occurs
  if (state->d_state == IncomingTCPConnectionState::State::idle ||
      state->d_state == IncomingTCPConnectionState::State::waitingForQuery) {
    auto iostate = sendQueuedResponses(state, now);

    if (iostate == IOState::Done && state->active()) {
      if (state->canAcceptNewQueries(now)) {
        state->resetForNewQuery();
        state->d_state = IncomingTCPConnectionState::State::waitingForQuery;
        iostate = IOState::NeedRead;
      }
      else {
        state->d_state = IncomingTCPConnectionState::State::idle;
      }
    }

    // for the same reason we need to update the state right away, nobody will do that for us
    if (state->active()) {
      updateIO(state, iostate, now);
    }
  }
}

void IncomingTCPConnectionState::handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param)
{
  auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);

  /* If we are here, the async jobs for this SSL* are finished
     so we should be able to remove all FDs */
  auto afds = state->d_handler.getAsyncFDs();
  for (const auto afd : afds) {
    try {
      state->d_threadData.mplexer->removeReadFD(afd);
    }
    catch (...) {
    }
  }

  if (state->active()) {
    /* and now we restart our own I/O state machine */
    struct timeval now;
    gettimeofday(&now, nullptr);
    handleIO(state, now);
  }
  else {
    /* we were only waiting for the engine to come back,
       to prevent a use-after-free */
    state->d_handler.close();
  }
}

void IncomingTCPConnectionState::updateIO(std::shared_ptr<IncomingTCPConnectionState>& state, IOState newState, const struct timeval& now)
{
  if (newState == IOState::Async) {
    auto fds = state->d_handler.getAsyncFDs();
    for (const auto fd : fds) {
      state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state);
    }
    state->d_ioState->update(IOState::Done, handleIOCallback, state);
  }
  else {
    state->d_ioState->update(newState, handleIOCallback, state, newState == IOState::NeedWrite ? state->getClientWriteTTD(now) : state->getClientReadTTD(now));
  }
}

/* called from the backend code when a new response has been received */
void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPResponse&& response)
{
  if (std::this_thread::get_id() != d_creatorThreadID) {
    handleCrossProtocolResponse(now, std::move(response));
    return;
  }

  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();

  if (!response.isAsync() && response.d_connection && response.d_connection->getDS() && response.d_connection->getDS()->d_config.useProxyProtocol) {
    // if we have added a TCP Proxy Protocol payload to a connection, don't release it to the general pool as no one else will be able to use it anyway
    if (!response.d_connection->willBeReusable(true)) {
      // if it cannot be reused even by us, drop it
      const auto connIt = state->d_ownedConnectionsToBackend.find(response.d_connection->getDS());
      if (connIt != state->d_ownedConnectionsToBackend.end()) {
        auto& list = connIt->second;

        for (auto it = list.begin(); it != list.end(); ++it) {
          if (*it == response.d_connection) {
            try {
              response.d_connection->release();
            }
            catch (const std::exception& e) {
              vinfolog("Error releasing connection: %s", e.what());
            }
            list.erase(it);
            break;
          }
        }
      }
    }
  }

  if (response.d_buffer.size() < sizeof(dnsheader)) {
    state->terminateClientConnection();
    return;
  }

  if (!response.isAsync()) {
    try {
      auto& ids = response.d_idstate;
      unsigned int qnameWireLength;
      if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) {
        state->terminateClientConnection();
        return;
      }

      if (response.d_connection->getDS()) {
        ++response.d_connection->getDS()->responses;
      }

      DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS());
      dr.d_incomingTCPState = state;

      memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));

      if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) {
        state->terminateClientConnection();
        return;
      }

      if (dr.isAsynchronous()) {
        /* we are done for now */
        return;
      }
    }
    catch (const std::exception& e) {
      vinfolog("Unexpected exception while handling response from backend: %s", e.what());
      state->terminateClientConnection();
      return;
    }
  }

  ++g_stats.responses;
  ++state->d_ci.cs->responses;

  queueResponse(state, now, std::move(response));
}

struct TCPCrossProtocolResponse
{
  TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now): d_response(std::move(response)), d_state(state), d_now(now)
  {
  }

  TCPResponse d_response;
  std::shared_ptr<IncomingTCPConnectionState> d_state;
  struct timeval d_now;
};

class TCPCrossProtocolQuery : public CrossProtocolQuery
{
public:
  TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender))
  {
    proxyProtocolPayloadSize = 0;
  }

  ~TCPCrossProtocolQuery()
  {
  }

  std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
  {
    return d_sender;
  }

  DNSQuestion getDQ() override
  {
    auto& ids = query.d_idstate;
    DNSQuestion dq(ids, query.d_buffer);
    dq.d_incomingTCPState = d_sender;
    return dq;
  }

  DNSResponse getDR() override
  {
    auto& ids = query.d_idstate;
    DNSResponse dr(ids, query.d_buffer, downstream);
    dr.d_incomingTCPState = d_sender;
    return dr;
  }

private:
  std::shared_ptr<IncomingTCPConnectionState> d_sender;
};

std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq)
{
  auto state = dq.getIncomingTCPState();
  if (!state) {
    throw std::runtime_error("Trying to create a TCP cross-protocol query without a valid TCP state");
  }

  dq.ids.origID = dq.getHeader()->id;
  return std::make_unique<TCPCrossProtocolQuery>(std::move(dq.getMutableData()), std::move(dq.ids), nullptr, std::move(state));
}

void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response)
{
  if (d_threadData.crossProtocolResponsesPipe == -1) {
    throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender");
  }

  std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this();
  auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now);
  static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail");
  ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr));
  if (sent != sizeof(ptr)) {
    if (errno == EAGAIN || errno == EWOULDBLOCK) {
      ++g_stats.tcpCrossProtocolResponsePipeFull;
      vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full");
    }
    else {
      vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror());
    }
    delete ptr;
  }
}
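The static_assert above relies on the POSIX guarantee that writes of up to PIPE_BUF bytes to a pipe are atomic, so a raw pointer (8 bytes, far below the minimum PIPE_BUF of 512) can be handed from one thread to another without locking or framing. A self-contained sketch of the pattern, outside of the dnsdist classes:

    #include <cassert>
    #include <climits>   // PIPE_BUF
    #include <string>
    #include <unistd.h>

    struct Message
    {
      std::string payload;
    };

    int main()
    {
      int fds[2];
      if (pipe(fds) != 0) {
        return 1;
      }

      // Producer side: hand over ownership of a heap object by writing the
      // pointer itself. Writes of sizeof(msg) <= PIPE_BUF bytes are atomic.
      auto* msg = new Message{"hello from another thread"};
      static_assert(sizeof(msg) <= PIPE_BUF, "pointer writes must be atomic");
      if (write(fds[1], &msg, sizeof(msg)) != sizeof(msg)) {
        delete msg; // the transfer failed, we still own the object
        return 1;
      }

      // Consumer side (in dnsdist this runs in the worker thread's event loop):
      Message* received = nullptr;
      if (read(fds[0], &received, sizeof(received)) == sizeof(received)) {
        assert(received->payload == "hello from another thread");
        delete received; // ownership crossed the pipe along with the pointer
      }

      close(fds[0]);
      close(fds[1]);
      return 0;
    }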
static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
  if (state->d_querySize < sizeof(dnsheader)) {
    ++g_stats.nonCompliantQueries;
    ++state->d_ci.cs->nonCompliantQueries;
    state->terminateClientConnection();
    return;
  }

  ++state->d_queriesCount;
  ++state->d_ci.cs->queries;
  ++g_stats.queries;

  if (state->d_handler.isTLS()) {
    auto tlsVersion = state->d_handler.getTLSVersion();
    switch (tlsVersion) {
    case LibsslTLSVersion::TLS10:
      ++state->d_ci.cs->tls10queries;
      break;
    case LibsslTLSVersion::TLS11:
      ++state->d_ci.cs->tls11queries;
      break;
    case LibsslTLSVersion::TLS12:
      ++state->d_ci.cs->tls12queries;
      break;
    case LibsslTLSVersion::TLS13:
      ++state->d_ci.cs->tls13queries;
      break;
    default:
      ++state->d_ci.cs->tlsUnknownqueries;
    }
  }

  InternalQueryState ids;
  ids.origDest = state->d_proxiedDestination;
  ids.origRemote = state->d_proxiedRemote;
  ids.cs = state->d_ci.cs;
  ids.queryRealTime.start();

  auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true);
  if (dnsCryptResponse) {
    TCPResponse response;
    state->d_state = IncomingTCPConnectionState::State::idle;
    ++state->d_currentQueriesCount;
    state->queueResponse(state, now, std::move(response));
    return;
  }

  {
    /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
    auto* dh = reinterpret_cast<dnsheader*>(state->d_buffer.data());
    if (!checkQueryHeaders(dh, *state->d_ci.cs)) {
      state->terminateClientConnection();
      return;
    }

    if (dh->qdcount == 0) {
      TCPResponse response;
      dh->rcode = RCode::NotImp;
      dh->qr = true;
      response.d_idstate.selfGenerated = true;
      response.d_buffer = std::move(state->d_buffer);
      state->d_state = IncomingTCPConnectionState::State::idle;
      ++state->d_currentQueriesCount;
      state->queueResponse(state, now, std::move(response));
      return;
    }
  }

  ids.qname = DNSName(reinterpret_cast<const char*>(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
  ids.protocol = dnsdist::Protocol::DoTCP;
  if (ids.dnsCryptQuery) {
    ids.protocol = dnsdist::Protocol::DNSCryptTCP;
  }
  else if (state->d_handler.isTLS()) {
    ids.protocol = dnsdist::Protocol::DoT;
  }

  DNSQuestion dq(ids, state->d_buffer);
  const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
  ids.origFlags = *flags;
  dq.d_incomingTCPState = state;
  dq.sni = state->d_handler.getServerNameIndication();

  if (state->d_proxyProtocolValues) {
    /* we need to copy them, because the next queries received on that connection will
       need to get the _unaltered_ values */
    dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*state->d_proxyProtocolValues);
  }

  if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) {
    dq.ids.skipCache = true;
  }

  std::shared_ptr<DownstreamState> ds;
  auto result = processQuery(dq, state->d_threadData.holders, ds);

  if (result == ProcessQueryResult::Drop) {
    state->terminateClientConnection();
    return;
  }
  else if (result == ProcessQueryResult::Asynchronous) {
    /* we are done for now */
    ++state->d_currentQueriesCount;
    return;
  }

  // the buffer might have been invalidated by now
  const dnsheader* dh = dq.getHeader();
  if (result == ProcessQueryResult::SendAnswer) {
    TCPResponse response;
    memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
    response.d_idstate = std::move(ids);
    response.d_idstate.origID = dh->id;
    response.d_idstate.selfGenerated = true;
    response.d_idstate.cs = state->d_ci.cs;
    response.d_buffer = std::move(state->d_buffer);

    state->d_state = IncomingTCPConnectionState::State::idle;
    ++state->d_currentQueriesCount;
    state->queueResponse(state, now, std::move(response));
    return;
  }

  if (result != ProcessQueryResult::PassToBackend || ds == nullptr) {
    state->terminateClientConnection();
    return;
  }

  dq.ids.origID = dh->id;

  ++state->d_currentQueriesCount;

  std::string proxyProtocolPayload;
  if (ds->isDoH()) {
    vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getNameWithAddr());

    /* we need to do this _before_ creating the cross-protocol query because
       after that the buffer will have been moved */
    if (ds->d_config.useProxyProtocol) {
      proxyProtocolPayload = getProxyProtocolPayload(dq);
    }

    auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state);
    cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);

    ds->passCrossProtocolQuery(std::move(cpq));
    return;
  }

  prependSizeToTCPQuery(state->d_buffer, 0);

  auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now);

  if (ds->d_config.useProxyProtocol) {
    /* if we ever sent a TLV over a connection, we can never go back */
    if (!state->d_proxyProtocolPayloadHasTLV) {
      state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty();
    }

    proxyProtocolPayload = getProxyProtocolPayload(dq);
  }

  if (dq.proxyProtocolValues) {
    downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues));
  }

  TCPQuery query(std::move(state->d_buffer), std::move(ids));
  query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);

  vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getNameWithAddr());
  std::shared_ptr<TCPQuerySender> incoming = state;
  downstreamConnection->queueQuery(incoming, std::move(query));
}

void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param)
{
  auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param);
  if (fd != conn->d_handler.getDescriptor()) {
    throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor()));
  }

  struct timeval now;
  gettimeofday(&now, nullptr);
  handleIO(conn, now);
}

void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now)
{
  // Why do we loop? Because the TLS layer does its own buffering, and thus can have data
  // ready to read even though the underlying socket is not ready, so we need to actually
  // ask for the data first
  IOState iostate = IOState::Done;
  do {
    iostate = IOState::Done;
    IOStateGuard ioGuard(state->d_ioState);

    if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
      vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort());
      // will be handled by the ioGuard
      //handleNewIOState(state, IOState::Done, fd, handleIOCallback);
      return;
    }

    state->d_lastIOBlocked = false;

    try {
      if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) {
        DEBUGLOG("doing handshake");
        iostate = state->d_handler.tryHandshake();
        if (iostate == IOState::Done) {
          DEBUGLOG("handshake done");
          if (state->d_handler.isTLS()) {
            if (!state->d_handler.hasTLSSessionBeenResumed()) {
              ++state->d_ci.cs->tlsNewSessions;
            }
            else {
              ++state->d_ci.cs->tlsResumptions;
            }
            if (state->d_handler.getResumedFromInactiveTicketKey()) {
              ++state->d_ci.cs->tlsInactiveTicketKey;
            }
            if (state->d_handler.getUnknownTicketKey()) {
              ++state->d_ci.cs->tlsUnknownTicketKey;
            }
          }

          state->d_handshakeDoneTime = now;
          if (expectProxyProtocolFrom(state->d_ci.remote)) {
            state->d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader;
            state->d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
            state->d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
          }
          else {
            state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
          }
        }
        else {
          state->d_lastIOBlocked = true;
        }
      }

      if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) {
        do {
          DEBUGLOG("reading proxy protocol header");
          iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_proxyProtocolNeed);
          if (iostate == IOState::Done) {
            state->d_buffer.resize(state->d_currentPos);
            ssize_t remaining = isProxyHeaderComplete(state->d_buffer);
            if (remaining == 0) {
              vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", state->d_ci.remote.toStringWithPort());
              ++g_stats.proxyProtocolInvalid;
              break;
            }
            else if (remaining < 0) {
              state->d_proxyProtocolNeed += -remaining;
              state->d_buffer.resize(state->d_currentPos + state->d_proxyProtocolNeed);
              /* we need to keep reading, since we might have buffered data */
              iostate = IOState::NeedRead;
            }
            else {
              /* proxy header received */
              std::vector<ProxyProtocolValue> proxyProtocolValues;
              if (!handleProxyProtocol(state->d_ci.remote, true, *state->d_threadData.holders.acl, state->d_buffer, state->d_proxiedRemote, state->d_proxiedDestination, proxyProtocolValues)) {
                vinfolog("Error handling the Proxy Protocol received from TCP client %s", state->d_ci.remote.toStringWithPort());
                break;
              }

              if (!proxyProtocolValues.empty()) {
                state->d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues));
              }

              state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
              state->d_buffer.resize(sizeof(uint16_t));
              state->d_currentPos = 0;
              state->d_proxyProtocolNeed = 0;
              break;
            }
          }
          else {
            state->d_lastIOBlocked = true;
          }
        }
        while (state->active() && !state->d_lastIOBlocked);
      }

      if (!state->d_lastIOBlocked && (state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
                                      state->d_state == IncomingTCPConnectionState::State::readingQuerySize)) {
        DEBUGLOG("reading query size");
        iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t));
        if (state->d_currentPos > 0) {
          /* if we got at least one byte, we can't go around sending responses */
          state->d_state = IncomingTCPConnectionState::State::readingQuerySize;
        }

        if (iostate == IOState::Done) {
          DEBUGLOG("query size received");
          state->d_state = IncomingTCPConnectionState::State::readingQuery;
          state->d_querySizeReadTime = now;
          if (state->d_queriesCount == 0) {
            state->d_firstQuerySizeReadTime = now;
          }
          state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1);
          if (state->d_querySize < sizeof(dnsheader)) {
            /* go away */
            state->terminateClientConnection();
            return;
          }

          /* allocate a bit more memory to be able to spoof the content, get an answer from the cache
             or to add ECS without allocating a new buffer */
          state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize));
          state->d_currentPos = 0;
        }
        else {
          state->d_lastIOBlocked = true;
        }
      }

      if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingQuery) {
        DEBUGLOG("reading query");
        iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize);
        if (iostate == IOState::Done) {
          DEBUGLOG("query received");
          state->d_buffer.resize(state->d_querySize);

          state->d_state = IncomingTCPConnectionState::State::idle;
          handleQuery(state, now);
          /* the state might have been updated in the meantime, we don't want to override it
             in that case */
          if (state->active() && state->d_state != IncomingTCPConnectionState::State::idle) {
            if (state->d_ioState->isWaitingForRead()) {
              iostate = IOState::NeedRead;
            }
            else if (state->d_ioState->isWaitingForWrite()) {
              iostate = IOState::NeedWrite;
            }
            else {
              iostate = IOState::Done;
            }
          }
        }
        else {
          state->d_lastIOBlocked = true;
        }
      }

      if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
        DEBUGLOG("sending response");
        iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size());
        if (iostate == IOState::Done) {
          DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__);
          handleResponseSent(state, state->d_currentResponse);
          state->d_state = IncomingTCPConnectionState::State::idle;
        }
        else {
          state->d_lastIOBlocked = true;
        }
      }

      if (state->active() &&
          !state->d_lastIOBlocked &&
          iostate == IOState::Done &&
          (state->d_state == IncomingTCPConnectionState::State::idle ||
           state->d_state == IncomingTCPConnectionState::State::waitingForQuery))
      {
        // try sending queued responses
        DEBUGLOG("send responses, if any");
        iostate = sendQueuedResponses(state, now);

        if (!state->d_lastIOBlocked && state->active() && iostate == IOState::Done) {
          // if the query has been passed to a backend, or dropped, and the responses have been sent,
          // we can start reading again
          if (state->canAcceptNewQueries(now)) {
            state->resetForNewQuery();
            iostate = IOState::NeedRead;
          }
          else {
            state->d_state = IncomingTCPConnectionState::State::idle;
            iostate = IOState::Done;
          }
        }
      }

      if (state->d_state != IncomingTCPConnectionState::State::idle &&
          state->d_state != IncomingTCPConnectionState::State::doingHandshake &&
          state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader &&
          state->d_state != IncomingTCPConnectionState::State::waitingForQuery &&
          state->d_state != IncomingTCPConnectionState::State::readingQuerySize &&
          state->d_state != IncomingTCPConnectionState::State::readingQuery &&
          state->d_state != IncomingTCPConnectionState::State::sendingResponse) {
        vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state));
      }
    }
    catch (const std::exception& e) {
      /* most likely an EOF because the other end closed the connection,
         but it might also be a real IO error or something else.
         Let's just drop the connection
      */
      if (state->d_state == IncomingTCPConnectionState::State::idle ||
          state->d_state == IncomingTCPConnectionState::State::waitingForQuery) {
        /* no need to increase any counters in that case, the client is simply done with us */
      }
      else if (state->d_state == IncomingTCPConnectionState::State::doingHandshake ||
               state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader ||
               state->d_state == IncomingTCPConnectionState::State::waitingForQuery ||
               state->d_state == IncomingTCPConnectionState::State::readingQuerySize ||
               state->d_state == IncomingTCPConnectionState::State::readingQuery) {
        ++state->d_ci.cs->tcpDiedReadingQuery;
      }
      else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) {
        /* unlikely to happen here, the exception should be handled in sendResponse() */
        ++state->d_ci.cs->tcpDiedSendingResponse;
      }

      if (state->d_ioState->isWaitingForWrite() || state->d_queriesCount == 0) {
        DEBUGLOG("Got an exception while handling TCP query: "<<e.what());
        vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->isWaitingForRead() ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what());
      }
      else {
        vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what());
        DEBUGLOG("Closing TCP client connection: "<<e.what());
      }
      /* remove this FD from the IO multiplexer */
      state->terminateClientConnection();
    }

    if (!state->active()) {
      DEBUGLOG("state is no longer active");
      return;
    }

    if (iostate == IOState::Done) {
      state->d_ioState->update(iostate, handleIOCallback, state);
    }
    else {
      updateIO(state, iostate, now);
    }
    ioGuard.release();
  }
  while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !state->d_lastIOBlocked);
}
"non-blocking" : "blocking") + " mode"); + } + else if (got == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + } + else if (got != sizeof(citmp)) { + throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + + try { + g_tcpclientthreads->decrementQueuedCount(); + + struct timeval now; + gettimeofday(&now, nullptr); + auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now); + delete citmp; + citmp = nullptr; + + IncomingTCPConnectionState::handleIO(state, now); + } + catch (...) { + delete citmp; + citmp = nullptr; + throw; + } +} + +static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) +{ + auto threadData = boost::any_cast<TCPClientThreadData*>(param); + CrossProtocolQuery* tmp{nullptr}; + + ssize_t got = read(pipefd, &tmp, sizeof(tmp)); + if (got == 0) { + throw std::runtime_error("EOF while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + else if (got == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + throw std::runtime_error("Error while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + } + else if (got != sizeof(tmp)) { + throw std::runtime_error("Partial read while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + + try { + struct timeval now; + gettimeofday(&now, nullptr); + + std::shared_ptr<TCPQuerySender> tqs = tmp->getTCPQuerySender(); + auto query = std::move(tmp->query); + auto downstreamServer = std::move(tmp->downstream); + auto proxyProtocolPayloadSize = tmp->proxyProtocolPayloadSize; + delete tmp; + tmp = nullptr; + + try { + auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string()); + + prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); + query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize; + + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr()); + + downstream->queueQuery(tqs, std::move(query)); + } + catch (...) { + tqs->notifyIOError(std::move(query.d_idstate), now); + } + } + catch (...) { + delete tmp; + tmp = nullptr; + } +} + +static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param) +{ + TCPCrossProtocolResponse* tmp{nullptr}; + + ssize_t got = read(pipefd, &tmp, sizeof(tmp)); + if (got == 0) { + throw std::runtime_error("EOF while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? 
"non-blocking" : "blocking") + " mode"); + } + else if (got == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + throw std::runtime_error("Error while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + } + else if (got != sizeof(tmp)) { + throw std::runtime_error("Partial read while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + } + + auto response = std::move(*tmp); + delete tmp; + tmp = nullptr; + + try { + if (response.d_response.d_buffer.empty()) { + response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now); + } + else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) { + response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response)); + } + else { + response.d_state->handleResponse(response.d_now, std::move(response.d_response)); + } + } + catch (...) { + /* no point bubbling up from there */ + } +} + +struct TCPAcceptorParam +{ + ClientState& cs; + ComboAddress local; + LocalStateHolder<NetmaskGroup>& acl; + int socket{-1}; +}; + +static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData); + +static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector<ClientState*> tcpAcceptStates) +{ + /* we get launched with a pipe on which we receive file descriptors from clients that we own + from that point on */ + + setThreadName("dnsdist/tcpClie"); + + try { + TCPClientThreadData data; + /* this is the writing end! 
*/ + data.crossProtocolResponsesPipe = crossProtocolResponsesWritePipeFD; + data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data); + data.mplexer->addReadFD(crossProtocolQueriesPipeFD, handleCrossProtocolQuery, &data); + data.mplexer->addReadFD(crossProtocolResponsesListenPipeFD, handleCrossProtocolResponse, &data); + + /* only used in single acceptor mode for now */ + auto acl = g_ACL.getLocal(); + std::vector<TCPAcceptorParam> acceptParams; + acceptParams.reserve(tcpAcceptStates.size()); + + for (auto& state : tcpAcceptStates) { + acceptParams.emplace_back(TCPAcceptorParam{*state, state->local, acl, state->tcpFD}); + for (const auto& [addr, socket] : state->d_additionalAddresses) { + acceptParams.emplace_back(TCPAcceptorParam{*state, addr, acl, socket}); + } + } + + auto acceptCallback = [&data](int socket, FDMultiplexer::funcparam_t& funcparam) { + auto acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam); + acceptNewConnection(*acceptorParam, &data); + }; + + for (size_t idx = 0; idx < acceptParams.size(); idx++) { + const auto& param = acceptParams.at(idx); + setNonBlocking(param.socket); + data.mplexer->addReadFD(param.socket, acceptCallback, ¶m); + } + + struct timeval now; + gettimeofday(&now, nullptr); + time_t lastTimeoutScan = now.tv_sec; + + for (;;) { + data.mplexer->run(&now); + + try { + t_downstreamTCPConnectionsManager.cleanupClosedConnections(now); + + if (now.tv_sec > lastTimeoutScan) { + lastTimeoutScan = now.tv_sec; + auto expiredReadConns = data.mplexer->getTimeouts(now, false); + for (const auto& cbData : expiredReadConns) { + if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(state, false); + } + } + else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { + auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second); + vinfolog("Timeout (read) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, false); + } + } + + auto expiredWriteConns = data.mplexer->getTimeouts(now, true); + for (const auto& cbData : expiredWriteConns) { + if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(state, true); + } + } + else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { + auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second); + vinfolog("Timeout (write) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, true); + } + } + + if (g_tcpStatesDumpRequested > 0) { + /* just to keep things clean in the output, debug only */ + static std::mutex s_lock; + std::lock_guard<decltype(s_lock)> lck(s_lock); + if (g_tcpStatesDumpRequested > 0) { + /* no race here, we took the lock so it can only be increased in the meantime */ + --g_tcpStatesDumpRequested; + errlog("Dumping the TCP states, as requested:"); + data.mplexer->runForAllWatchedFDs([](bool isRead, int fd, const FDMultiplexer::funcparam_t& param, struct timeval ttd) + 
{ + struct timeval lnow; + gettimeofday(&lnow, nullptr); + if (ttd.tv_sec > 0) { + errlog("- Descriptor %d is in %s state, TTD in %d", fd, (isRead ? "read" : "write"), (ttd.tv_sec-lnow.tv_sec)); + } + else { + errlog("- Descriptor %d is in %s state, no TTD set", fd, (isRead ? "read" : "write")); + } + + if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param); + errlog(" - %s", state->toString()); + } + else if (param.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { + auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param); + errlog(" - %s", conn->toString()); + } + else if (param.type() == typeid(TCPClientThreadData*)) { + errlog(" - Worker thread pipe"); + } + }); + errlog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount()); + } + } + } + } + catch (const std::exception& e) { + errlog("Error in TCP worker thread: %s", e.what()); + } + } + } + catch (const std::exception& e) { + errlog("Fatal error in TCP worker thread: %s", e.what()); + } +} + +static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData) +{ + auto& cs = param.cs; + auto& acl = param.acl; + int socket = param.socket; + bool tcpClientCountIncremented = false; + ComboAddress remote; + remote.sin4.sin_family = param.local.sin4.sin_family; + + tcpClientCountIncremented = false; + try { + socklen_t remlen = remote.getSocklen(); + ConnectionInfo ci(&cs); +#ifdef HAVE_ACCEPT4 + ci.fd = accept4(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK); +#else + ci.fd = accept(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen); +#endif + // will be decremented when the ConnectionInfo object is destroyed, no matter the reason + auto concurrentConnections = ++cs.tcpCurrentConnections; + + if (ci.fd < 0) { + throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str()); + } + + if (!acl->match(remote)) { + ++g_stats.aclDrops; + vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort()); + return; + } + + if (cs.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > cs.d_tcpConcurrentConnectionsLimit) { + vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort()); + return; + } + + if (concurrentConnections > cs.tcpMaxConcurrentConnections.load()) { + cs.tcpMaxConcurrentConnections.store(concurrentConnections); + } + +#ifndef HAVE_ACCEPT4 + if (!setNonBlocking(ci.fd)) { + return; + } +#endif + + setTCPNoDelay(ci.fd); // disable NAGLE + + if (g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) { + vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort()); + return; + } + + if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) { + vinfolog("Dropping TCP connection from %s because we have too many from this client already", remote.toStringWithPort()); + return; + } + tcpClientCountIncremented = true; + + vinfolog("Got TCP connection from %s", remote.toStringWithPort()); + + ci.remote = remote; + if (threadData == nullptr) { + if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(ci)))) { + if 
(tcpClientCountIncremented) { + dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote); + } + } + } + else { + struct timeval now; + gettimeofday(&now, nullptr); + auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now); + IncomingTCPConnectionState::handleIO(state, now); + } + } + catch (const std::exception& e) { + errlog("While reading a TCP question: %s", e.what()); + if (tcpClientCountIncremented) { + dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote); + } + } + catch (...){} +} + +/* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and + they will hand off to worker threads & spawn more of them if required +*/ +#ifndef USE_SINGLE_ACCEPTOR_THREAD +void tcpAcceptorThread(std::vector<ClientState*> states) +{ + setThreadName("dnsdist/tcpAcce"); + + auto acl = g_ACL.getLocal(); + std::vector<TCPAcceptorParam> params; + params.reserve(states.size()); + + for (auto& state : states) { + params.emplace_back(TCPAcceptorParam{*state, state->local, acl, state->tcpFD}); + for (const auto& [addr, socket] : state->d_additionalAddresses) { + params.emplace_back(TCPAcceptorParam{*state, addr, acl, socket}); + } + } + + if (params.size() == 1) { + while (true) { + acceptNewConnection(params.at(0), nullptr); + } + } + else { + auto acceptCallback = [](int socket, FDMultiplexer::funcparam_t& funcparam) { + auto acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam); + acceptNewConnection(*acceptorParam, nullptr); + }; + + auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(params.size())); + for (size_t idx = 0; idx < params.size(); idx++) { + const auto& param = params.at(idx); + mplexer->addReadFD(param.socket, acceptCallback, ¶m); + } + + struct timeval tv; + while (true) { + mplexer->run(&tv, -1); + } + } +} +#endif |
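tcpAcceptorThread() above switches between a plain blocking accept() loop (one listening socket) and a multiplexer-driven loop (several sockets). A self-contained sketch of the same decision using plain poll(2) instead of dnsdist's FDMultiplexer, with a hypothetical onNewConnection() handler:

    #include <poll.h>
    #include <sys/socket.h>
    #include <vector>

    // Hypothetical handler for an accepted connection.
    void onNewConnection(int fd);

    // Accept loop over one or more listening sockets, mirroring the
    // single-acceptor vs. multiplexed split in tcpAcceptorThread().
    void acceptLoop(const std::vector<int>& listeners)
    {
      if (listeners.size() == 1) {
        // Single socket: a blocking accept() loop is the simplest and cheapest.
        for (;;) {
          int fd = accept(listeners.front(), nullptr, nullptr);
          if (fd >= 0) {
            onNewConnection(fd);
          }
        }
      }

      // Several sockets: wait until any of them is readable, then accept.
      std::vector<pollfd> pfds;
      for (int sock : listeners) {
        pfds.push_back(pollfd{sock, POLLIN, 0});
      }
      for (;;) {
        if (poll(pfds.data(), pfds.size(), -1) <= 0) {
          continue;
        }
        for (const auto& pfd : pfds) {
          if ((pfd.revents & POLLIN) != 0) {
            int fd = accept(pfd.fd, nullptr, nullptr);
            if (fd >= 0) {
              onNewConnection(fd);
            }
          }
        }
      }
    }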