diff options
Diffstat (limited to 'dnsdist-tcp.cc')
-rw-r--r-- | dnsdist-tcp.cc | 1383 |
1 files changed, 754 insertions, 629 deletions
diff --git a/dnsdist-tcp.cc b/dnsdist-tcp.cc index b927cbe..e3eb68e 100644 --- a/dnsdist-tcp.cc +++ b/dnsdist-tcp.cc @@ -26,7 +26,9 @@ #include "dnsdist.hh" #include "dnsdist-concurrent-connections.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-nghttp2-in.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rings.hh" #include "dnsdist-tcp.hh" @@ -64,7 +66,7 @@ size_t g_maxTCPConnectionDuration{0}; #ifdef __linux__ // On Linux this gives us 128k pending queries (default is 8192 queries), // which should be enough to deal with huge spikes -size_t g_tcpInternalPipeBufferSize{1024*1024}; +size_t g_tcpInternalPipeBufferSize{1048576U}; uint64_t g_maxTCPQueuedConnections{10000}; #else size_t g_tcpInternalPipeBufferSize{0}; @@ -83,11 +85,11 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState() dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(d_ci.remote); if (d_ci.cs != nullptr) { - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); auto diff = now - d_connectionStartTime; - d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000.0 + diff.tv_usec / 1000.0); + d_ci.cs->updateTCPMetrics(d_queriesCount, diff.tv_sec * 1000 + diff.tv_usec / 1000); } // would have been done when the object is destroyed anyway, @@ -96,21 +98,30 @@ IncomingTCPConnectionState::~IncomingTCPConnectionState() d_handler.close(); } +dnsdist::Protocol IncomingTCPConnectionState::getProtocol() const +{ + if (d_ci.cs->dohFrontend) { + return dnsdist::Protocol::DoH; + } + if (d_handler.isTLS()) { + return dnsdist::Protocol::DoT; + } + return dnsdist::Protocol::DoTCP; +} + size_t IncomingTCPConnectionState::clearAllDownstreamConnections() { return t_downstreamTCPConnectionsManager.clear(); } -std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now) +std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstreamConnection(std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs, const struct timeval& now) { - std::shared_ptr<TCPConnectionToBackend> downstream{nullptr}; - - downstream = getOwnedDownstreamConnection(ds, tlvs); + auto downstream = getOwnedDownstreamConnection(backend, tlvs); if (!downstream) { /* we don't have a connection to this backend owned yet, let's get one (it might not be a fresh one, though) */ - downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, ds, now, std::string()); - if (ds->d_config.useProxyProtocol) { + downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(d_threadData.mplexer, backend, now, std::string()); + if (backend->d_config.useProxyProtocol) { registerOwnedDownstreamConnection(downstream); } } @@ -118,9 +129,10 @@ std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getDownstrea return downstream; } -static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector<ClientState*> tcpAcceptStates); +static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates); -TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientState*> tcpAcceptStates): d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads) +TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientState*> tcpAcceptStates) : + d_tcpclientthreads(maxThreads), d_maxthreads(maxThreads) { for (size_t idx = 0; idx < maxThreads; idx++) { addTCPClientThread(tcpAcceptStates); @@ -129,83 +141,37 @@ TCPClientCollection::TCPClientCollection(size_t maxThreads, std::vector<ClientSt void TCPClientCollection::addTCPClientThread(std::vector<ClientState*>& tcpAcceptStates) { - auto preparePipe = [](int fds[2], const std::string& type) -> bool { - if (pipe(fds) < 0) { - errlog("Error creating the TCP thread %s pipe: %s", type, stringerror()); - return false; - } - - if (!setNonBlocking(fds[0])) { - int err = errno; - close(fds[0]); - close(fds[1]); - errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err)); - return false; - } - - if (!setNonBlocking(fds[1])) { - int err = errno; - close(fds[0]); - close(fds[1]); - errlog("Error setting the TCP thread %s pipe non-blocking: %s", type, stringerror(err)); - return false; - } - - if (g_tcpInternalPipeBufferSize > 0 && getPipeBufferSize(fds[0]) < g_tcpInternalPipeBufferSize) { - setPipeBufferSize(fds[0], g_tcpInternalPipeBufferSize); - } - - return true; - }; - - int pipefds[2] = { -1, -1}; - if (!preparePipe(pipefds, "communication")) { - return; - } + try { + auto [queryChannelSender, queryChannelReceiver] = pdns::channel::createObjectQueue<ConnectionInfo>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, g_tcpInternalPipeBufferSize); - int crossProtocolQueriesFDs[2] = { -1, -1}; - if (!preparePipe(crossProtocolQueriesFDs, "cross-protocol queries")) { - return; - } + auto [crossProtocolQueryChannelSender, crossProtocolQueryChannelReceiver] = pdns::channel::createObjectQueue<CrossProtocolQuery>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, g_tcpInternalPipeBufferSize); - int crossProtocolResponsesFDs[2] = { -1, -1}; - if (!preparePipe(crossProtocolResponsesFDs, "cross-protocol responses")) { - return; - } + auto [crossProtocolResponseChannelSender, crossProtocolResponseChannelReceiver] = pdns::channel::createObjectQueue<TCPCrossProtocolResponse>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, g_tcpInternalPipeBufferSize); - vinfolog("Adding TCP Client thread"); + vinfolog("Adding TCP Client thread"); - { if (d_numthreads >= d_tcpclientthreads.size()) { vinfolog("Adding a new TCP client thread would exceed the vector size (%d/%d), skipping. Consider increasing the maximum amount of TCP client threads with setMaxTCPClientThreads() in the configuration.", d_numthreads.load(), d_tcpclientthreads.size()); - close(crossProtocolQueriesFDs[0]); - close(crossProtocolQueriesFDs[1]); - close(crossProtocolResponsesFDs[0]); - close(crossProtocolResponsesFDs[1]); - close(pipefds[0]); - close(pipefds[1]); return; } - /* from now on this side of the pipe will be managed by that object, - no need to worry about it */ - TCPWorkerThread worker(pipefds[1], crossProtocolQueriesFDs[1], crossProtocolResponsesFDs[1]); + TCPWorkerThread worker(std::move(queryChannelSender), std::move(crossProtocolQueryChannelSender)); + try { - std::thread t1(tcpClientThread, pipefds[0], crossProtocolQueriesFDs[0], crossProtocolResponsesFDs[0], crossProtocolResponsesFDs[1], tcpAcceptStates); - t1.detach(); + std::thread clientThread(tcpClientThread, std::move(queryChannelReceiver), std::move(crossProtocolQueryChannelReceiver), std::move(crossProtocolResponseChannelReceiver), std::move(crossProtocolResponseChannelSender), tcpAcceptStates); + clientThread.detach(); } catch (const std::runtime_error& e) { - /* the thread creation failed, don't leak */ errlog("Error creating a TCP thread: %s", e.what()); - close(pipefds[0]); - close(crossProtocolQueriesFDs[0]); - close(crossProtocolResponsesFDs[0]); return; } d_tcpclientthreads.at(d_numthreads) = std::move(worker); ++d_numthreads; } + catch (const std::exception& e) { + errlog("Error creating TCP worker: %", e.what()); + } } std::unique_ptr<TCPClientCollection> g_tcpclientthreads; @@ -215,11 +181,11 @@ static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>& IOState result = IOState::Done; while (state->active() && !state->d_queuedResponses.empty()) { - DEBUGLOG("queue size is "<<state->d_queuedResponses.size()<<", sending the next one"); + DEBUGLOG("queue size is " << state->d_queuedResponses.size() << ", sending the next one"); TCPResponse resp = std::move(state->d_queuedResponses.front()); state->d_queuedResponses.pop_front(); state->d_state = IncomingTCPConnectionState::State::idle; - result = state->sendResponse(state, now, std::move(resp)); + result = state->sendResponse(now, std::move(resp)); if (result != IOState::Done) { return result; } @@ -229,28 +195,29 @@ static IOState sendQueuedResponses(std::shared_ptr<IncomingTCPConnectionState>& return IOState::Done; } -static void handleResponseSent(std::shared_ptr<IncomingTCPConnectionState>& state, TCPResponse& currentResponse) +void IncomingTCPConnectionState::handleResponseSent(TCPResponse& currentResponse) { if (currentResponse.d_idstate.qtype == QType::AXFR || currentResponse.d_idstate.qtype == QType::IXFR) { return; } - --state->d_currentQueriesCount; + --d_currentQueriesCount; - const auto& ds = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds; - if (currentResponse.d_idstate.selfGenerated == false && ds) { + const auto& backend = currentResponse.d_connection ? currentResponse.d_connection->getDS() : currentResponse.d_ds; + if (!currentResponse.d_idstate.selfGenerated && backend) { const auto& ids = currentResponse.d_idstate; double udiff = ids.queryRealTime.udiff(); - vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), currentResponse.d_buffer.size(), udiff); + vinfolog("Got answer from %s, relayed to %s (%s, %d bytes), took %f us", backend->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), getProtocol().toString(), currentResponse.d_buffer.size(), udiff); - auto backendProtocol = ds->getProtocol(); - if (backendProtocol == dnsdist::Protocol::DoUDP) { + auto backendProtocol = backend->getProtocol(); + if (backendProtocol == dnsdist::Protocol::DoUDP && !currentResponse.d_idstate.forwardedOverUDP) { backendProtocol = dnsdist::Protocol::DoTCP; } - ::handleResponseSent(ids, udiff, state->d_ci.remote, ds->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true); - } else { + ::handleResponseSent(ids, udiff, d_ci.remote, backend->d_config.remote, static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, backendProtocol, true); + } + else { const auto& ids = currentResponse.d_idstate; - ::handleResponseSent(ids, 0., state->d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false); + ::handleResponseSent(ids, 0., d_ci.remote, ComboAddress(), static_cast<unsigned int>(currentResponse.d_buffer.size()), currentResponse.d_cleartextDH, ids.protocol, false); } currentResponse.d_buffer.clear(); @@ -264,11 +231,11 @@ static void prependSizeToTCPQuery(PacketBuffer& buffer, size_t proxyProtocolPayl } uint16_t queryLen = proxyProtocolPayloadSize > 0 ? (buffer.size() - proxyProtocolPayloadSize) : buffer.size(); - const uint8_t sizeBytes[] = { static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256) }; + const std::array<uint8_t, 2> sizeBytes{static_cast<uint8_t>(queryLen / 256), static_cast<uint8_t>(queryLen % 256)}; /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes that could occur if we had to deal with the size during the processing, especially alignment issues */ - buffer.insert(buffer.begin() + proxyProtocolPayloadSize, sizeBytes, sizeBytes + 2); + buffer.insert(buffer.begin() + static_cast<PacketBuffer::iterator::difference_type>(proxyProtocolPayloadSize), sizeBytes.begin(), sizeBytes.end()); } bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) @@ -278,12 +245,13 @@ bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) return false; } - if (d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) { - DEBUGLOG("not accepting new queries because we already have "<<d_currentQueriesCount<<" out of "<<d_ci.cs->d_maxInFlightQueriesPerConn); + // for DoH, this is already handled by the underlying library + if (!d_ci.cs->dohFrontend && d_currentQueriesCount >= d_ci.cs->d_maxInFlightQueriesPerConn) { + DEBUGLOG("not accepting new queries because we already have " << d_currentQueriesCount << " out of " << d_ci.cs->d_maxInFlightQueriesPerConn); return false; } - if (g_maxTCPQueriesPerConn && d_queriesCount > g_maxTCPQueriesPerConn) { + if (g_maxTCPQueriesPerConn != 0 && d_queriesCount > g_maxTCPQueriesPerConn) { vinfolog("not accepting new queries from %s because it reached the maximum number of queries per conn (%d / %d)", d_ci.remote.toStringWithPort(), d_queriesCount, g_maxTCPQueriesPerConn); return false; } @@ -298,27 +266,27 @@ bool IncomingTCPConnectionState::canAcceptNewQueries(const struct timeval& now) void IncomingTCPConnectionState::resetForNewQuery() { - d_buffer.resize(sizeof(uint16_t)); + d_buffer.clear(); d_currentPos = 0; d_querySize = 0; d_state = State::waitingForQuery; } -std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& ds, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) +std::shared_ptr<TCPConnectionToBackend> IncomingTCPConnectionState::getOwnedDownstreamConnection(const std::shared_ptr<DownstreamState>& backend, const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) { - auto it = d_ownedConnectionsToBackend.find(ds); - if (it == d_ownedConnectionsToBackend.end()) { - DEBUGLOG("no owned connection found for "<<ds->getName()); + auto connIt = d_ownedConnectionsToBackend.find(backend); + if (connIt == d_ownedConnectionsToBackend.end()) { + DEBUGLOG("no owned connection found for " << backend->getName()); return nullptr; } - for (auto& conn : it->second) { + for (auto& conn : connIt->second) { if (conn->canBeReused(true) && conn->matchesTLVs(tlvs)) { - DEBUGLOG("Got one owned connection accepting more for "<<ds->getName()); + DEBUGLOG("Got one owned connection accepting more for " << backend->getName()); conn->setReused(); return conn; } - DEBUGLOG("not accepting more for "<<ds->getName()); + DEBUGLOG("not accepting more for " << backend->getName()); } return nullptr; @@ -330,37 +298,36 @@ void IncomingTCPConnectionState::registerOwnedDownstreamConnection(std::shared_p } /* called when the buffer has been set and the rules have been processed, and only from handleIO (sometimes indirectly via handleQuery) */ -IOState IncomingTCPConnectionState::sendResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response) +IOState IncomingTCPConnectionState::sendResponse(const struct timeval& now, TCPResponse&& response) { - state->d_state = IncomingTCPConnectionState::State::sendingResponse; + d_state = State::sendingResponse; - uint16_t responseSize = static_cast<uint16_t>(response.d_buffer.size()); - const uint8_t sizeBytes[] = { static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256) }; + const auto responseSize = static_cast<uint16_t>(response.d_buffer.size()); + const std::array<uint8_t, 2> sizeBytes{static_cast<uint8_t>(responseSize / 256), static_cast<uint8_t>(responseSize % 256)}; /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes that could occur if we had to deal with the size during the processing, especially alignment issues */ - response.d_buffer.insert(response.d_buffer.begin(), sizeBytes, sizeBytes + 2); - state->d_currentPos = 0; - state->d_currentResponse = std::move(response); + response.d_buffer.insert(response.d_buffer.begin(), sizeBytes.begin(), sizeBytes.end()); + d_currentPos = 0; + d_currentResponse = std::move(response); try { - auto iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size()); + auto iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size()); if (iostate == IOState::Done) { - DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); - handleResponseSent(state, state->d_currentResponse); - return iostate; - } else { - state->d_lastIOBlocked = true; - DEBUGLOG("partial write"); + DEBUGLOG("response sent from " << __PRETTY_FUNCTION__); + handleResponseSent(d_currentResponse); return iostate; } + d_lastIOBlocked = true; + DEBUGLOG("partial write"); + return iostate; } catch (const std::exception& e) { - vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what()); - DEBUGLOG("Closing TCP client connection: "<<e.what()); - ++state->d_ci.cs->tcpDiedSendingResponse; + vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), e.what()); + DEBUGLOG("Closing TCP client connection: " << e.what()); + ++d_ci.cs->tcpDiedSendingResponse; - state->terminateClientConnection(); + terminateClientConnection(); return IOState::Done; } @@ -394,50 +361,59 @@ void IncomingTCPConnectionState::terminateClientConnection() /* we might already be waiting, but we might also not because sometimes we have already been notified via the descriptor, not received Async again, but the async job still exists.. */ auto state = shared_from_this(); - for (const auto fd : afds) { + for (const auto desc : afds) { try { - state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state); + state->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, state); } catch (...) { } } - } } -void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response) +void IncomingTCPConnectionState::queueResponse(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now, TCPResponse&& response, bool fromBackend) { // queue response - state->d_queuedResponses.push_back(std::move(response)); - DEBUGLOG("queueing response, state is "<<(int)state->d_state<<", queue size is now "<<state->d_queuedResponses.size()); + state->d_queuedResponses.emplace_back(std::move(response)); + DEBUGLOG("queueing response, state is " << (int)state->d_state << ", queue size is now " << state->d_queuedResponses.size()); // when the response comes from a backend, there is a real possibility that we are currently // idle, and thus not trying to send the response right away would make our ref count go to 0. // Even if we are waiting for a query, we will not wake up before the new query arrives or a // timeout occurs - if (state->d_state == IncomingTCPConnectionState::State::idle || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery) { + if (state->d_state == State::idle || state->d_state == State::waitingForQuery) { auto iostate = sendQueuedResponses(state, now); if (iostate == IOState::Done && state->active()) { if (state->canAcceptNewQueries(now)) { state->resetForNewQuery(); - state->d_state = IncomingTCPConnectionState::State::waitingForQuery; + state->d_state = State::waitingForQuery; iostate = IOState::NeedRead; } else { - state->d_state = IncomingTCPConnectionState::State::idle; + state->d_state = State::idle; } } // for the same reason we need to update the state right away, nobody will do that for us if (state->active()) { updateIO(state, iostate, now); + // if we have not finished reading every available byte, we _need_ to do an actual read + // attempt before waiting for the socket to become readable again, because if there is + // buffered data available the socket might never become readable again. + // This is true as soon as we deal with TLS because TLS records are processed one by + // one and might not match what we see at the application layer, so data might already + // be available in the TLS library's buffers. This is especially true when OpenSSL's + // read-ahead mode is enabled because then it buffers even more than one SSL record + // for performance reasons. + if (fromBackend && !state->d_lastIOBlocked) { + state->handleIO(); + } } } } -void IncomingTCPConnectionState::handleAsyncReady(int fd, FDMultiplexer::funcparam_t& param) +void IncomingTCPConnectionState::handleAsyncReady([[maybe_unused]] int desc, FDMultiplexer::funcparam_t& param) { auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param); @@ -454,9 +430,7 @@ void IncomingTCPConnectionState::handleAsyncReady(int fd, FDMultiplexer::funcpar if (state->active()) { /* and now we restart our own I/O state machine */ - struct timeval now; - gettimeofday(&now, nullptr); - handleIO(state, now); + state->handleIO(); } else { /* we were only waiting for the engine to come back, @@ -469,8 +443,8 @@ void IncomingTCPConnectionState::updateIO(std::shared_ptr<IncomingTCPConnectionS { if (newState == IOState::Async) { auto fds = state->d_handler.getAsyncFDs(); - for (const auto fd : fds) { - state->d_threadData.mplexer->addReadFD(fd, handleAsyncReady, state); + for (const auto desc : fds) { + state->d_threadData.mplexer->addReadFD(desc, handleAsyncReady, state); } state->d_ioState->update(IOState::Done, handleIOCallback, state); } @@ -521,27 +495,27 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe if (!response.isAsync()) { try { auto& ids = response.d_idstate; - unsigned int qnameWireLength; - if (!response.d_connection || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getDS(), qnameWireLength)) { + std::shared_ptr<DownstreamState> backend = response.d_ds ? response.d_ds : (response.d_connection ? response.d_connection->getDS() : nullptr); + if (backend == nullptr || !responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, backend)) { state->terminateClientConnection(); return; } - if (response.d_connection->getDS()) { - ++response.d_connection->getDS()->responses; + if (backend != nullptr) { + ++backend->responses; } - DNSResponse dr(ids, response.d_buffer, response.d_connection->getDS()); - dr.d_incomingTCPState = state; + DNSResponse dnsResponse(ids, response.d_buffer, backend); + dnsResponse.d_incomingTCPState = state; - memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH)); + memcpy(&response.d_cleartextDH, dnsResponse.getHeader().get(), sizeof(response.d_cleartextDH)); - if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dnsResponse, false)) { state->terminateClientConnection(); return; } - if (dr.isAsynchronous()) { + if (dnsResponse.isAsynchronous()) { /* we are done for now */ return; } @@ -553,17 +527,23 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe } } - ++g_stats.responses; + ++dnsdist::metrics::g_stats.responses; ++state->d_ci.cs->responses; - queueResponse(state, now, std::move(response)); + queueResponse(state, now, std::move(response), true); } struct TCPCrossProtocolResponse { - TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now): d_response(std::move(response)), d_state(state), d_now(now) + TCPCrossProtocolResponse(TCPResponse&& response, std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now) : + d_response(std::move(response)), d_state(state), d_now(now) { } + TCPCrossProtocolResponse(const TCPCrossProtocolResponse&) = delete; + TCPCrossProtocolResponse& operator=(const TCPCrossProtocolResponse&) = delete; + TCPCrossProtocolResponse(TCPCrossProtocolResponse&&) = delete; + TCPCrossProtocolResponse& operator=(TCPCrossProtocolResponse&&) = delete; + ~TCPCrossProtocolResponse() = default; TCPResponse d_response; std::shared_ptr<IncomingTCPConnectionState> d_state; @@ -573,14 +553,15 @@ struct TCPCrossProtocolResponse class TCPCrossProtocolQuery : public CrossProtocolQuery { public: - TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> ds, std::shared_ptr<IncomingTCPConnectionState> sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) - { - proxyProtocolPayloadSize = 0; - } - - ~TCPCrossProtocolQuery() + TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr<DownstreamState> backend, std::shared_ptr<IncomingTCPConnectionState> sender) : + CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), backend), d_sender(std::move(sender)) { } + TCPCrossProtocolQuery(const TCPCrossProtocolQuery&) = delete; + TCPCrossProtocolQuery& operator=(const TCPCrossProtocolQuery&) = delete; + TCPCrossProtocolQuery(TCPCrossProtocolQuery&&) = delete; + TCPCrossProtocolQuery& operator=(TCPCrossProtocolQuery&&) = delete; + ~TCPCrossProtocolQuery() override = default; std::shared_ptr<TCPQuerySender> getTCPQuerySender() override { @@ -590,503 +571,616 @@ public: DNSQuestion getDQ() override { auto& ids = query.d_idstate; - DNSQuestion dq(ids, query.d_buffer); - dq.d_incomingTCPState = d_sender; - return dq; + DNSQuestion dnsQuestion(ids, query.d_buffer); + dnsQuestion.d_incomingTCPState = d_sender; + return dnsQuestion; } DNSResponse getDR() override { auto& ids = query.d_idstate; - DNSResponse dr(ids, query.d_buffer, downstream); - dr.d_incomingTCPState = d_sender; - return dr; + DNSResponse dnsResponse(ids, query.d_buffer, downstream); + dnsResponse.d_incomingTCPState = d_sender; + return dnsResponse; } private: std::shared_ptr<IncomingTCPConnectionState> d_sender; }; -std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dq) +std::unique_ptr<CrossProtocolQuery> IncomingTCPConnectionState::getCrossProtocolQuery(PacketBuffer&& query, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& backend) { - auto state = dq.getIncomingTCPState(); + return std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(state), backend, shared_from_this()); +} + +std::unique_ptr<CrossProtocolQuery> getTCPCrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion) +{ + auto state = dnsQuestion.getIncomingTCPState(); if (!state) { throw std::runtime_error("Trying to create a TCP cross protocol query without a valid TCP state"); } - dq.ids.origID = dq.getHeader()->id; - return std::make_unique<TCPCrossProtocolQuery>(std::move(dq.getMutableData()), std::move(dq.ids), nullptr, std::move(state)); + dnsQuestion.ids.origID = dnsQuestion.getHeader()->id; + return std::make_unique<TCPCrossProtocolQuery>(std::move(dnsQuestion.getMutableData()), std::move(dnsQuestion.ids), nullptr, std::move(state)); } void IncomingTCPConnectionState::handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response) { - if (d_threadData.crossProtocolResponsesPipe == -1) { - throw std::runtime_error("Invalid pipe descriptor in TCP Cross Protocol Query Sender"); - } - std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this(); - auto ptr = new TCPCrossProtocolResponse(std::move(response), state, now); - static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); - ssize_t sent = write(d_threadData.crossProtocolResponsesPipe, &ptr, sizeof(ptr)); - if (sent != sizeof(ptr)) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - ++g_stats.tcpCrossProtocolResponsePipeFull; + try { + auto ptr = std::make_unique<TCPCrossProtocolResponse>(std::move(response), state, now); + if (!state->d_threadData.crossProtocolResponseSender.send(std::move(ptr))) { + ++dnsdist::metrics::g_stats.tcpCrossProtocolResponsePipeFull; vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because the pipe is full"); } - else { - vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror()); - } - delete ptr; + } + catch (const std::exception& e) { + vinfolog("Unable to pass a cross-protocol response to the TCP worker thread because we couldn't write to the pipe: %s", stringerror()); } } -static void handleQuery(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now) +IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::handleQuery(PacketBuffer&& queryIn, const struct timeval& now, std::optional<int32_t> streamID) { - if (state->d_querySize < sizeof(dnsheader)) { - ++g_stats.nonCompliantQueries; - ++state->d_ci.cs->nonCompliantQueries; - state->terminateClientConnection(); - return; + auto query = std::move(queryIn); + if (query.size() < sizeof(dnsheader)) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++d_ci.cs->nonCompliantQueries; + return QueryProcessingResult::TooSmall; } - ++state->d_queriesCount; - ++state->d_ci.cs->queries; - ++g_stats.queries; + ++d_queriesCount; + ++d_ci.cs->queries; + ++dnsdist::metrics::g_stats.queries; - if (state->d_handler.isTLS()) { - auto tlsVersion = state->d_handler.getTLSVersion(); + if (d_handler.isTLS()) { + auto tlsVersion = d_handler.getTLSVersion(); switch (tlsVersion) { case LibsslTLSVersion::TLS10: - ++state->d_ci.cs->tls10queries; + ++d_ci.cs->tls10queries; break; case LibsslTLSVersion::TLS11: - ++state->d_ci.cs->tls11queries; + ++d_ci.cs->tls11queries; break; case LibsslTLSVersion::TLS12: - ++state->d_ci.cs->tls12queries; + ++d_ci.cs->tls12queries; break; case LibsslTLSVersion::TLS13: - ++state->d_ci.cs->tls13queries; + ++d_ci.cs->tls13queries; break; default: - ++state->d_ci.cs->tlsUnknownqueries; + ++d_ci.cs->tlsUnknownqueries; } } + auto state = shared_from_this(); InternalQueryState ids; - ids.origDest = state->d_proxiedDestination; - ids.origRemote = state->d_proxiedRemote; - ids.cs = state->d_ci.cs; + ids.origDest = d_proxiedDestination; + ids.origRemote = d_proxiedRemote; + ids.cs = d_ci.cs; ids.queryRealTime.start(); + if (streamID) { + ids.d_streamID = *streamID; + } - auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true); + auto dnsCryptResponse = checkDNSCryptQuery(*d_ci.cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, true); if (dnsCryptResponse) { TCPResponse response; - state->d_state = IncomingTCPConnectionState::State::idle; - ++state->d_currentQueriesCount; - state->queueResponse(state, now, std::move(response)); - return; + d_state = State::idle; + ++d_currentQueriesCount; + queueResponse(state, now, std::move(response), false); + return QueryProcessingResult::SelfAnswered; } { /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ - auto* dh = reinterpret_cast<dnsheader*>(state->d_buffer.data()); - if (!checkQueryHeaders(dh, *state->d_ci.cs)) { - state->terminateClientConnection(); - return; + const dnsheader_aligned dnsHeader(query.data()); + if (!checkQueryHeaders(*dnsHeader, *d_ci.cs)) { + return QueryProcessingResult::InvalidHeaders; } - if (dh->qdcount == 0) { + if (dnsHeader->qdcount == 0) { TCPResponse response; - dh->rcode = RCode::NotImp; - dh->qr = true; + auto queryID = dnsHeader->id; + dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); + response.d_idstate = std::move(ids); + response.d_idstate.origID = queryID; response.d_idstate.selfGenerated = true; - response.d_buffer = std::move(state->d_buffer); - state->d_state = IncomingTCPConnectionState::State::idle; - ++state->d_currentQueriesCount; - state->queueResponse(state, now, std::move(response)); - return; + response.d_buffer = std::move(query); + d_state = State::idle; + ++d_currentQueriesCount; + queueResponse(state, now, std::move(response), false); + return QueryProcessingResult::SelfAnswered; } } - ids.qname = DNSName(reinterpret_cast<const char*>(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); - ids.protocol = dnsdist::Protocol::DoTCP; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast + ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), static_cast<int>(query.size()), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + ids.protocol = getProtocol(); if (ids.dnsCryptQuery) { ids.protocol = dnsdist::Protocol::DNSCryptTCP; } - else if (state->d_handler.isTLS()) { - ids.protocol = dnsdist::Protocol::DoT; - } - DNSQuestion dq(ids, state->d_buffer); - const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); - ids.origFlags = *flags; - dq.d_incomingTCPState = state; - dq.sni = state->d_handler.getServerNameIndication(); + DNSQuestion dnsQuestion(ids, query); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) { + const uint16_t* flags = getFlagsFromDNSHeader(&header); + ids.origFlags = *flags; + return true; + }); + dnsQuestion.d_incomingTCPState = state; + dnsQuestion.sni = d_handler.getServerNameIndication(); - if (state->d_proxyProtocolValues) { + if (d_proxyProtocolValues) { /* we need to copy them, because the next queries received on that connection will need to get the _unaltered_ values */ - dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*state->d_proxyProtocolValues); + dnsQuestion.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(*d_proxyProtocolValues); } - if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) { - dq.ids.skipCache = true; + if (dnsQuestion.ids.qtype == QType::AXFR || dnsQuestion.ids.qtype == QType::IXFR) { + dnsQuestion.ids.skipCache = true; } - std::shared_ptr<DownstreamState> ds; - auto result = processQuery(dq, state->d_threadData.holders, ds); + if (forwardViaUDPFirst()) { + // if there was no EDNS, we add it with a large buffer size + // so we can use UDP to talk to the backend. + const dnsheader_aligned dnsHeader(query.data()); + if (dnsHeader->arcount == 0U) { + if (addEDNS(query, 4096, false, 4096, 0)) { + dnsQuestion.ids.ednsAdded = true; + } + } + } - if (result == ProcessQueryResult::Drop) { - state->terminateClientConnection(); - return; + if (streamID) { + auto unit = getDOHUnit(*streamID); + if (unit) { + dnsQuestion.ids.du = std::move(unit); + } } - else if (result == ProcessQueryResult::Asynchronous) { + + std::shared_ptr<DownstreamState> backend; + auto result = processQuery(dnsQuestion, d_threadData.holders, backend); + + if (result == ProcessQueryResult::Asynchronous) { /* we are done for now */ - ++state->d_currentQueriesCount; - return; + ++d_currentQueriesCount; + return QueryProcessingResult::Asynchronous; + } + + if (streamID) { + restoreDOHUnit(std::move(dnsQuestion.ids.du)); + } + + if (result == ProcessQueryResult::Drop) { + return QueryProcessingResult::Dropped; } // the buffer might have been invalidated by now - const dnsheader* dh = dq.getHeader(); + uint16_t queryID{0}; + { + const auto dnsHeader = dnsQuestion.getHeader(); + queryID = dnsHeader->id; + } + if (result == ProcessQueryResult::SendAnswer) { TCPResponse response; - memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH)); + { + const auto dnsHeader = dnsQuestion.getHeader(); + memcpy(&response.d_cleartextDH, dnsHeader.get(), sizeof(response.d_cleartextDH)); + } response.d_idstate = std::move(ids); - response.d_idstate.origID = dh->id; + response.d_idstate.origID = queryID; response.d_idstate.selfGenerated = true; - response.d_idstate.cs = state->d_ci.cs; - response.d_buffer = std::move(state->d_buffer); + response.d_idstate.cs = d_ci.cs; + response.d_buffer = std::move(query); - state->d_state = IncomingTCPConnectionState::State::idle; - ++state->d_currentQueriesCount; - state->queueResponse(state, now, std::move(response)); - return; + d_state = State::idle; + ++d_currentQueriesCount; + queueResponse(state, now, std::move(response), false); + return QueryProcessingResult::SelfAnswered; } - if (result != ProcessQueryResult::PassToBackend || ds == nullptr) { - state->terminateClientConnection(); - return; + if (result != ProcessQueryResult::PassToBackend || backend == nullptr) { + return QueryProcessingResult::NoBackend; } - dq.ids.origID = dh->id; + dnsQuestion.ids.origID = queryID; - ++state->d_currentQueriesCount; + ++d_currentQueriesCount; std::string proxyProtocolPayload; - if (ds->isDoH()) { - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getNameWithAddr()); + if (backend->isDoH()) { + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), query.size(), backend->getNameWithAddr()); /* we need to do this _before_ creating the cross protocol query because after that the buffer will have been moved */ - if (ds->d_config.useProxyProtocol) { - proxyProtocolPayload = getProxyProtocolPayload(dq); + if (backend->d_config.useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion); } - auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(state->d_buffer), std::move(ids), ds, state); + auto cpq = std::make_unique<TCPCrossProtocolQuery>(std::move(query), std::move(ids), backend, state); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - ds->passCrossProtocolQuery(std::move(cpq)); - return; + backend->passCrossProtocolQuery(std::move(cpq)); + return QueryProcessingResult::Forwarded; + } + if (!backend->isTCPOnly() && forwardViaUDPFirst()) { + if (streamID) { + auto unit = getDOHUnit(*streamID); + if (unit) { + dnsQuestion.ids.du = std::move(unit); + } + } + if (assignOutgoingUDPQueryToBackend(backend, queryID, dnsQuestion, query)) { + return QueryProcessingResult::Forwarded; + } + restoreDOHUnit(std::move(dnsQuestion.ids.du)); + // fallback to the normal flow } - prependSizeToTCPQuery(state->d_buffer, 0); + prependSizeToTCPQuery(query, 0); - auto downstreamConnection = state->getDownstreamConnection(ds, dq.proxyProtocolValues, now); + auto downstreamConnection = getDownstreamConnection(backend, dnsQuestion.proxyProtocolValues, now); - if (ds->d_config.useProxyProtocol) { + if (backend->d_config.useProxyProtocol) { /* if we ever sent a TLV over a connection, we can never go back */ - if (!state->d_proxyProtocolPayloadHasTLV) { - state->d_proxyProtocolPayloadHasTLV = dq.proxyProtocolValues && !dq.proxyProtocolValues->empty(); + if (!d_proxyProtocolPayloadHasTLV) { + d_proxyProtocolPayloadHasTLV = dnsQuestion.proxyProtocolValues && !dnsQuestion.proxyProtocolValues->empty(); } - proxyProtocolPayload = getProxyProtocolPayload(dq); + proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion); } - if (dq.proxyProtocolValues) { - downstreamConnection->setProxyProtocolValuesSent(std::move(dq.proxyProtocolValues)); + if (dnsQuestion.proxyProtocolValues) { + downstreamConnection->setProxyProtocolValuesSent(std::move(dnsQuestion.proxyProtocolValues)); } - TCPQuery query(std::move(state->d_buffer), std::move(ids)); - query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); + TCPQuery tcpquery(std::move(query), std::move(ids)); + tcpquery.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getNameWithAddr()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", tcpquery.d_idstate.qname.toLogString(), QType(tcpquery.d_idstate.qtype).toString(), d_proxiedRemote.toStringWithPort(), getProtocol().toString(), tcpquery.d_buffer.size(), backend->getNameWithAddr()); std::shared_ptr<TCPQuerySender> incoming = state; - downstreamConnection->queueQuery(incoming, std::move(query)); + downstreamConnection->queueQuery(incoming, std::move(tcpquery)); + return QueryProcessingResult::Forwarded; } -void IncomingTCPConnectionState::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) +void IncomingTCPConnectionState::handleIOCallback(int desc, FDMultiplexer::funcparam_t& param) { auto conn = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param); - if (fd != conn->d_handler.getDescriptor()) { - throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor())); + if (desc != conn->d_handler.getDescriptor()) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay): __PRETTY_FUNCTION__ is fine + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(desc) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->d_handler.getDescriptor())); } - struct timeval now; - gettimeofday(&now, nullptr); - handleIO(conn, now); + conn->handleIO(); +} + +void IncomingTCPConnectionState::handleHandshakeDone(const struct timeval& now) +{ + if (d_handler.isTLS()) { + if (!d_handler.hasTLSSessionBeenResumed()) { + ++d_ci.cs->tlsNewSessions; + } + else { + ++d_ci.cs->tlsResumptions; + } + if (d_handler.getResumedFromInactiveTicketKey()) { + ++d_ci.cs->tlsInactiveTicketKey; + } + if (d_handler.getUnknownTicketKey()) { + ++d_ci.cs->tlsUnknownTicketKey; + } + } + + d_handshakeDoneTime = now; +} + +IncomingTCPConnectionState::ProxyProtocolResult IncomingTCPConnectionState::handleProxyProtocolPayload() +{ + do { + DEBUGLOG("reading proxy protocol header"); + auto iostate = d_handler.tryRead(d_buffer, d_currentPos, d_proxyProtocolNeed, false, isProxyPayloadOutsideTLS()); + if (iostate == IOState::Done) { + d_buffer.resize(d_currentPos); + ssize_t remaining = isProxyHeaderComplete(d_buffer); + if (remaining == 0) { + vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", d_ci.remote.toStringWithPort()); + ++dnsdist::metrics::g_stats.proxyProtocolInvalid; + return ProxyProtocolResult::Error; + } + if (remaining < 0) { + d_proxyProtocolNeed += -remaining; + d_buffer.resize(d_currentPos + d_proxyProtocolNeed); + /* we need to keep reading, since we might have buffered data */ + } + else { + /* proxy header received */ + std::vector<ProxyProtocolValue> proxyProtocolValues; + if (!handleProxyProtocol(d_ci.remote, true, *d_threadData.holders.acl, d_buffer, d_proxiedRemote, d_proxiedDestination, proxyProtocolValues)) { + vinfolog("Error handling the Proxy Protocol received from TCP client %s", d_ci.remote.toStringWithPort()); + return ProxyProtocolResult::Error; + } + + if (!proxyProtocolValues.empty()) { + d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues)); + } + + return ProxyProtocolResult::Done; + } + } + else { + d_lastIOBlocked = true; + } + } while (active() && !d_lastIOBlocked); + + return ProxyProtocolResult::Reading; +} + +IOState IncomingTCPConnectionState::handleHandshake(const struct timeval& now) +{ + DEBUGLOG("doing handshake"); + auto iostate = d_handler.tryHandshake(); + if (iostate == IOState::Done) { + DEBUGLOG("handshake done"); + handleHandshakeDone(now); + + if (d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && !isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) { + d_state = State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; + } + else { + d_state = State::readingQuerySize; + } + } + else { + d_lastIOBlocked = true; + } + + return iostate; +} + +IOState IncomingTCPConnectionState::handleIncomingQueryReceived(const struct timeval& now) +{ + DEBUGLOG("query received"); + d_buffer.resize(d_querySize); + + d_state = State::idle; + auto processingResult = handleQuery(std::move(d_buffer), now, std::nullopt); + switch (processingResult) { + case QueryProcessingResult::TooSmall: + /* fall-through */ + case QueryProcessingResult::InvalidHeaders: + /* fall-through */ + case QueryProcessingResult::Dropped: + /* fall-through */ + case QueryProcessingResult::NoBackend: + terminateClientConnection(); + ; + default: + break; + } + + /* the state might have been updated in the meantime, we don't want to override it + in that case */ + if (active() && d_state != State::idle) { + if (d_ioState->isWaitingForRead()) { + return IOState::NeedRead; + } + if (d_ioState->isWaitingForWrite()) { + return IOState::NeedWrite; + } + return IOState::Done; + } + return IOState::Done; +}; + +void IncomingTCPConnectionState::handleExceptionDuringIO(const std::exception& exp) +{ + if (d_state == State::idle || d_state == State::waitingForQuery) { + /* no need to increase any counters in that case, the client is simply done with us */ + } + else if (d_state == State::doingHandshake || d_state == State::readingProxyProtocolHeader || d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery) { + ++d_ci.cs->tcpDiedReadingQuery; + } + else if (d_state == State::sendingResponse) { + /* unlikely to happen here, the exception should be handled in sendResponse() */ + ++d_ci.cs->tcpDiedSendingResponse; + } + + if (d_ioState->isWaitingForWrite() || d_queriesCount == 0) { + DEBUGLOG("Got an exception while handling TCP query: " << exp.what()); + vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (d_ioState->isWaitingForRead() ? "reading" : "writing"), d_ci.remote.toStringWithPort(), exp.what()); + } + else { + vinfolog("Closing TCP client connection with %s: %s", d_ci.remote.toStringWithPort(), exp.what()); + DEBUGLOG("Closing TCP client connection: " << exp.what()); + } + /* remove this FD from the IO multiplexer */ + terminateClientConnection(); } -void IncomingTCPConnectionState::handleIO(std::shared_ptr<IncomingTCPConnectionState>& state, const struct timeval& now) +bool IncomingTCPConnectionState::readIncomingQuery(const timeval& now, IOState& iostate) +{ + if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize)) { + DEBUGLOG("reading query size"); + d_buffer.resize(sizeof(uint16_t)); + iostate = d_handler.tryRead(d_buffer, d_currentPos, sizeof(uint16_t)); + if (d_currentPos > 0) { + /* if we got at least one byte, we can't go around sending responses */ + d_state = State::readingQuerySize; + } + + if (iostate == IOState::Done) { + DEBUGLOG("query size received"); + d_state = State::readingQuery; + d_querySizeReadTime = now; + if (d_queriesCount == 0) { + d_firstQuerySizeReadTime = now; + } + d_querySize = d_buffer.at(0) * 256 + d_buffer.at(1); + if (d_querySize < sizeof(dnsheader)) { + /* go away */ + terminateClientConnection(); + return true; + } + + d_buffer.resize(d_querySize); + d_currentPos = 0; + } + else { + d_lastIOBlocked = true; + } + } + + if (!d_lastIOBlocked && d_state == State::readingQuery) { + DEBUGLOG("reading query"); + iostate = d_handler.tryRead(d_buffer, d_currentPos, d_querySize); + if (iostate == IOState::Done) { + iostate = handleIncomingQueryReceived(now); + } + else { + d_lastIOBlocked = true; + } + } + + return false; +} + +void IncomingTCPConnectionState::handleIO() { // why do we loop? Because the TLS layer does buffering, and thus can have data ready to read // even though the underlying socket is not ready, so we need to actually ask for the data first IOState iostate = IOState::Done; + timeval now{}; + gettimeofday(&now, nullptr); + do { iostate = IOState::Done; - IOStateGuard ioGuard(state->d_ioState); + IOStateGuard ioGuard(d_ioState); - if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { - vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); + if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { + vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort()); // will be handled by the ioGuard - //handleNewIOState(state, IOState::Done, fd, handleIOCallback); + // handleNewIOState(state, IOState::Done, fd, handleIOCallback); return; } - state->d_lastIOBlocked = false; + d_lastIOBlocked = false; try { - if (state->d_state == IncomingTCPConnectionState::State::doingHandshake) { - DEBUGLOG("doing handshake"); - iostate = state->d_handler.tryHandshake(); - if (iostate == IOState::Done) { - DEBUGLOG("handshake done"); - if (state->d_handler.isTLS()) { - if (!state->d_handler.hasTLSSessionBeenResumed()) { - ++state->d_ci.cs->tlsNewSessions; - } - else { - ++state->d_ci.cs->tlsResumptions; - } - if (state->d_handler.getResumedFromInactiveTicketKey()) { - ++state->d_ci.cs->tlsInactiveTicketKey; - } - if (state->d_handler.getUnknownTicketKey()) { - ++state->d_ci.cs->tlsUnknownTicketKey; - } - } - - state->d_handshakeDoneTime = now; - if (expectProxyProtocolFrom(state->d_ci.remote)) { - state->d_state = IncomingTCPConnectionState::State::readingProxyProtocolHeader; - state->d_buffer.resize(s_proxyProtocolMinimumHeaderSize); - state->d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; - } - else { - state->d_state = IncomingTCPConnectionState::State::readingQuerySize; - } + if (d_state == State::starting) { + if (d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) { + d_state = State::readingProxyProtocolHeader; + d_buffer.resize(s_proxyProtocolMinimumHeaderSize); + d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize; } else { - state->d_lastIOBlocked = true; + d_state = State::doingHandshake; } } - if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingProxyProtocolHeader) { - do { - DEBUGLOG("reading proxy protocol header"); - iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_proxyProtocolNeed); - if (iostate == IOState::Done) { - state->d_buffer.resize(state->d_currentPos); - ssize_t remaining = isProxyHeaderComplete(state->d_buffer); - if (remaining == 0) { - vinfolog("Unable to consume proxy protocol header in packet from TCP client %s", state->d_ci.remote.toStringWithPort()); - ++g_stats.proxyProtocolInvalid; - break; - } - else if (remaining < 0) { - state->d_proxyProtocolNeed += -remaining; - state->d_buffer.resize(state->d_currentPos + state->d_proxyProtocolNeed); - /* we need to keep reading, since we might have buffered data */ - iostate = IOState::NeedRead; - } - else { - /* proxy header received */ - std::vector<ProxyProtocolValue> proxyProtocolValues; - if (!handleProxyProtocol(state->d_ci.remote, true, *state->d_threadData.holders.acl, state->d_buffer, state->d_proxiedRemote, state->d_proxiedDestination, proxyProtocolValues)) { - vinfolog("Error handling the Proxy Protocol received from TCP client %s", state->d_ci.remote.toStringWithPort()); - break; - } - - if (!proxyProtocolValues.empty()) { - state->d_proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues)); - } - - state->d_state = IncomingTCPConnectionState::State::readingQuerySize; - state->d_buffer.resize(sizeof(uint16_t)); - state->d_currentPos = 0; - state->d_proxyProtocolNeed = 0; - break; - } - } - else { - state->d_lastIOBlocked = true; - } - } - while (state->active() && !state->d_lastIOBlocked); + if (d_state == State::doingHandshake) { + iostate = handleHandshake(now); } - if (!state->d_lastIOBlocked && (state->d_state == IncomingTCPConnectionState::State::waitingForQuery || - state->d_state == IncomingTCPConnectionState::State::readingQuerySize)) { - DEBUGLOG("reading query size"); - iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, sizeof(uint16_t)); - if (state->d_currentPos > 0) { - /* if we got at least one byte, we can't go around sending responses */ - state->d_state = IncomingTCPConnectionState::State::readingQuerySize; - } - - if (iostate == IOState::Done) { - DEBUGLOG("query size received"); - state->d_state = IncomingTCPConnectionState::State::readingQuery; - state->d_querySizeReadTime = now; - if (state->d_queriesCount == 0) { - state->d_firstQuerySizeReadTime = now; + if (!d_lastIOBlocked && d_state == State::readingProxyProtocolHeader) { + auto status = handleProxyProtocolPayload(); + if (status == ProxyProtocolResult::Done) { + if (isProxyPayloadOutsideTLS()) { + d_state = State::doingHandshake; + iostate = handleHandshake(now); } - state->d_querySize = state->d_buffer.at(0) * 256 + state->d_buffer.at(1); - if (state->d_querySize < sizeof(dnsheader)) { - /* go away */ - state->terminateClientConnection(); - return; + else { + d_state = State::readingQuerySize; + d_buffer.resize(sizeof(uint16_t)); + d_currentPos = 0; + d_proxyProtocolNeed = 0; } - - /* allocate a bit more memory to be able to spoof the content, get an answer from the cache - or to add ECS without allocating a new buffer */ - state->d_buffer.resize(std::max(state->d_querySize + static_cast<size_t>(512), s_maxPacketCacheEntrySize)); - state->d_currentPos = 0; + } + else if (status == ProxyProtocolResult::Error) { + iostate = IOState::Done; } else { - state->d_lastIOBlocked = true; + iostate = IOState::NeedRead; } } - if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::readingQuery) { - DEBUGLOG("reading query"); - iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize); - if (iostate == IOState::Done) { - DEBUGLOG("query received"); - state->d_buffer.resize(state->d_querySize); - - state->d_state = IncomingTCPConnectionState::State::idle; - handleQuery(state, now); - /* the state might have been updated in the meantime, we don't want to override it - in that case */ - if (state->active() && state->d_state != IncomingTCPConnectionState::State::idle) { - if (state->d_ioState->isWaitingForRead()) { - iostate = IOState::NeedRead; - } - else if (state->d_ioState->isWaitingForWrite()) { - iostate = IOState::NeedWrite; - } - else { - iostate = IOState::Done; - } - } - } - else { - state->d_lastIOBlocked = true; + if (!d_lastIOBlocked && (d_state == State::waitingForQuery || d_state == State::readingQuerySize || d_state == State::readingQuery)) { + if (readIncomingQuery(now, iostate)) { + return; } } - if (!state->d_lastIOBlocked && state->d_state == IncomingTCPConnectionState::State::sendingResponse) { + if (!d_lastIOBlocked && d_state == State::sendingResponse) { DEBUGLOG("sending response"); - iostate = state->d_handler.tryWrite(state->d_currentResponse.d_buffer, state->d_currentPos, state->d_currentResponse.d_buffer.size()); + iostate = d_handler.tryWrite(d_currentResponse.d_buffer, d_currentPos, d_currentResponse.d_buffer.size()); if (iostate == IOState::Done) { - DEBUGLOG("response sent from "<<__PRETTY_FUNCTION__); - handleResponseSent(state, state->d_currentResponse); - state->d_state = IncomingTCPConnectionState::State::idle; + DEBUGLOG("response sent from " << __PRETTY_FUNCTION__); + handleResponseSent(d_currentResponse); + d_state = State::idle; } else { - state->d_lastIOBlocked = true; + d_lastIOBlocked = true; } } - if (state->active() && - !state->d_lastIOBlocked && - iostate == IOState::Done && - (state->d_state == IncomingTCPConnectionState::State::idle || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery)) - { + if (active() && !d_lastIOBlocked && iostate == IOState::Done && (d_state == State::idle || d_state == State::waitingForQuery)) { // try sending queued responses DEBUGLOG("send responses, if any"); + auto state = shared_from_this(); iostate = sendQueuedResponses(state, now); - if (!state->d_lastIOBlocked && state->active() && iostate == IOState::Done) { + if (!d_lastIOBlocked && active() && iostate == IOState::Done) { // if the query has been passed to a backend, or dropped, and the responses have been sent, // we can start reading again - if (state->canAcceptNewQueries(now)) { - state->resetForNewQuery(); + if (canAcceptNewQueries(now)) { + resetForNewQuery(); iostate = IOState::NeedRead; } else { - state->d_state = IncomingTCPConnectionState::State::idle; + d_state = State::idle; iostate = IOState::Done; } } } - if (state->d_state != IncomingTCPConnectionState::State::idle && - state->d_state != IncomingTCPConnectionState::State::doingHandshake && - state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader && - state->d_state != IncomingTCPConnectionState::State::waitingForQuery && - state->d_state != IncomingTCPConnectionState::State::readingQuerySize && - state->d_state != IncomingTCPConnectionState::State::readingQuery && - state->d_state != IncomingTCPConnectionState::State::sendingResponse) { - vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(state->d_state)); + if (d_state != State::idle && d_state != State::doingHandshake && d_state != State::readingProxyProtocolHeader && d_state != State::waitingForQuery && d_state != State::readingQuerySize && d_state != State::readingQuery && d_state != State::sendingResponse) { + vinfolog("Unexpected state %d in handleIOCallback", static_cast<int>(d_state)); } } - catch (const std::exception& e) { + catch (const std::exception& exp) { /* most likely an EOF because the other end closed the connection, but it might also be a real IO error or something else. Let's just drop the connection */ - if (state->d_state == IncomingTCPConnectionState::State::idle || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery) { - /* no need to increase any counters in that case, the client is simply done with us */ - } - else if (state->d_state == IncomingTCPConnectionState::State::doingHandshake || - state->d_state != IncomingTCPConnectionState::State::readingProxyProtocolHeader || - state->d_state == IncomingTCPConnectionState::State::waitingForQuery || - state->d_state == IncomingTCPConnectionState::State::readingQuerySize || - state->d_state == IncomingTCPConnectionState::State::readingQuery) { - ++state->d_ci.cs->tcpDiedReadingQuery; - } - else if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { - /* unlikely to happen here, the exception should be handled in sendResponse() */ - ++state->d_ci.cs->tcpDiedSendingResponse; - } - - if (state->d_ioState->isWaitingForWrite() || state->d_queriesCount == 0) { - DEBUGLOG("Got an exception while handling TCP query: "<<e.what()); - vinfolog("Got an exception while handling (%s) TCP query from %s: %s", (state->d_ioState->isWaitingForRead() ? "reading" : "writing"), state->d_ci.remote.toStringWithPort(), e.what()); - } - else { - vinfolog("Closing TCP client connection with %s: %s", state->d_ci.remote.toStringWithPort(), e.what()); - DEBUGLOG("Closing TCP client connection: "<<e.what()); - } - /* remove this FD from the IO multiplexer */ - state->terminateClientConnection(); + handleExceptionDuringIO(exp); } - if (!state->active()) { + if (!active()) { DEBUGLOG("state is no longer active"); return; } + auto state = shared_from_this(); if (iostate == IOState::Done) { - state->d_ioState->update(iostate, handleIOCallback, state); + d_ioState->update(iostate, handleIOCallback, state); } else { updateIO(state, iostate, now); } ioGuard.release(); - } - while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !state->d_lastIOBlocked); + } while ((iostate == IOState::NeedRead || iostate == IOState::NeedWrite) && !d_lastIOBlocked); } -void IncomingTCPConnectionState::notifyIOError(InternalQueryState&& query, const struct timeval& now) +void IncomingTCPConnectionState::notifyIOError(const struct timeval& now, TCPResponse&& response) { if (std::this_thread::get_id() != d_creatorThreadID) { /* empty buffer will signal an IO error */ - TCPResponse response(PacketBuffer(), std::move(query), nullptr, nullptr); + response.d_buffer.clear(); handleCrossProtocolResponse(now, std::move(response)); return; } @@ -1105,7 +1199,7 @@ void IncomingTCPConnectionState::notifyIOError(InternalQueryState&& query, const if (state->active() && iostate != IOState::Done) { // we need to update the state right away, nobody will do that for us - updateIO(state, iostate, now); + updateIO(state, iostate, now); } } catch (const std::exception& e) { @@ -1126,14 +1220,14 @@ void IncomingTCPConnectionState::handleXFRResponse(const struct timeval& now, TC } std::shared_ptr<IncomingTCPConnectionState> state = shared_from_this(); - queueResponse(state, now, std::move(response)); + queueResponse(state, now, std::move(response), true); } void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnectionState>& state, bool write) { vinfolog("Timeout while %s TCP client %s", (write ? "writing to" : "reading from"), state->d_ci.remote.toStringWithPort()); DEBUGLOG("client timeout"); - DEBUGLOG("Processed "<<state->d_queriesCount<<" queries, current count is "<<state->d_currentQueriesCount<<", "<<state->d_ownedConnectionsToBackend.size()<<" owned connections, "<<state->d_queuedResponses.size()<<" response queued"); + DEBUGLOG("Processed " << state->d_queriesCount << " queries, current count is " << state->d_currentQueriesCount << ", " << state->d_ownedConnectionsToBackend.size() << " owned connections, " << state->d_queuedResponses.size() << " response queued"); if (write || state->d_currentQueriesCount == 0) { ++state->d_ci.cs->tcpClientTimeouts; @@ -1142,124 +1236,102 @@ void IncomingTCPConnectionState::handleTimeout(std::shared_ptr<IncomingTCPConnec else { DEBUGLOG("Going idle"); /* we still have some queries in flight, let's just stop reading for now */ - state->d_state = IncomingTCPConnectionState::State::idle; + state->d_state = State::idle; state->d_ioState->update(IOState::Done, handleIOCallback, state); } } static void handleIncomingTCPQuery(int pipefd, FDMultiplexer::funcparam_t& param) { - auto threadData = boost::any_cast<TCPClientThreadData*>(param); + auto* threadData = boost::any_cast<TCPClientThreadData*>(param); - ConnectionInfo* citmp{nullptr}; - - ssize_t got = read(pipefd, &citmp, sizeof(citmp)); - if (got == 0) { - throw std::runtime_error("EOF while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); - } - else if (got == -1) { - if (errno == EAGAIN || errno == EINTR) { + std::unique_ptr<ConnectionInfo> citmp{nullptr}; + try { + auto tmp = threadData->queryReceiver.receive(); + if (!tmp) { return; } - throw std::runtime_error("Error while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + citmp = std::move(*tmp); } - else if (got != sizeof(citmp)) { - throw std::runtime_error("Partial read while reading from the TCP acceptor pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + catch (const std::exception& e) { + throw std::runtime_error("Error while reading from the TCP query channel: " + std::string(e.what())); } - try { - g_tcpclientthreads->decrementQueuedCount(); + g_tcpclientthreads->decrementQueuedCount(); - struct timeval now; - gettimeofday(&now, nullptr); - auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now); - delete citmp; - citmp = nullptr; + timeval now{}; + gettimeofday(&now, nullptr); - IncomingTCPConnectionState::handleIO(state, now); + if (citmp->cs->dohFrontend) { +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + auto state = std::make_shared<IncomingHTTP2Connection>(std::move(*citmp), *threadData, now); + state->handleIO(); +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ } - catch (...) { - delete citmp; - citmp = nullptr; - throw; + else { + auto state = std::make_shared<IncomingTCPConnectionState>(std::move(*citmp), *threadData, now); + state->handleIO(); } } static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& param) { - auto threadData = boost::any_cast<TCPClientThreadData*>(param); - CrossProtocolQuery* tmp{nullptr}; + auto* threadData = boost::any_cast<TCPClientThreadData*>(param); - ssize_t got = read(pipefd, &tmp, sizeof(tmp)); - if (got == 0) { - throw std::runtime_error("EOF while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); - } - else if (got == -1) { - if (errno == EAGAIN || errno == EINTR) { + std::unique_ptr<CrossProtocolQuery> cpq{nullptr}; + try { + auto tmp = threadData->crossProtocolQueryReceiver.receive(); + if (!tmp) { return; } - throw std::runtime_error("Error while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + cpq = std::move(*tmp); } - else if (got != sizeof(tmp)) { - throw std::runtime_error("Partial read while reading from the TCP cross-protocol pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + catch (const std::exception& e) { + throw std::runtime_error("Error while reading from the TCP cross-protocol channel: " + std::string(e.what())); } - try { - struct timeval now; - gettimeofday(&now, nullptr); + timeval now{}; + gettimeofday(&now, nullptr); - std::shared_ptr<TCPQuerySender> tqs = tmp->getTCPQuerySender(); - auto query = std::move(tmp->query); - auto downstreamServer = std::move(tmp->downstream); - auto proxyProtocolPayloadSize = tmp->proxyProtocolPayloadSize; - delete tmp; - tmp = nullptr; + std::shared_ptr<TCPQuerySender> tqs = cpq->getTCPQuerySender(); + auto query = std::move(cpq->query); + auto downstreamServer = std::move(cpq->downstream); - try { - auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string()); + try { + auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string()); - prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); - query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize; + prependSizeToTCPQuery(query.d_buffer, query.d_idstate.d_proxyProtocolPayloadSize); - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr()); - downstream->queueQuery(tqs, std::move(query)); - } - catch (...) { - tqs->notifyIOError(std::move(query.d_idstate), now); - } + downstream->queueQuery(tqs, std::move(query)); } catch (...) { - delete tmp; - tmp = nullptr; + tqs->notifyIOError(now, std::move(query)); } } static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& param) { - TCPCrossProtocolResponse* tmp{nullptr}; + auto* threadData = boost::any_cast<TCPClientThreadData*>(param); - ssize_t got = read(pipefd, &tmp, sizeof(tmp)); - if (got == 0) { - throw std::runtime_error("EOF while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); - } - else if (got == -1) { - if (errno == EAGAIN || errno == EINTR) { + std::unique_ptr<TCPCrossProtocolResponse> cpr{nullptr}; + try { + auto tmp = threadData->crossProtocolResponseReceiver.receive(); + if (!tmp) { return; } - throw std::runtime_error("Error while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode:" + stringerror()); + cpr = std::move(*tmp); } - else if (got != sizeof(tmp)) { - throw std::runtime_error("Partial read while reading from the TCP cross-protocol response pipe (" + std::to_string(pipefd) + ") in " + std::string(isNonBlocking(pipefd) ? "non-blocking" : "blocking") + " mode"); + catch (const std::exception& e) { + throw std::runtime_error("Error while reading from the TCP cross-protocol response: " + std::string(e.what())); } - auto response = std::move(*tmp); - delete tmp; - tmp = nullptr; + auto& response = *cpr; try { if (response.d_response.d_buffer.empty()) { - response.d_state->notifyIOError(std::move(response.d_response.d_idstate), response.d_now); + response.d_state->notifyIOError(response.d_now, std::move(response.d_response)); } else if (response.d_response.d_idstate.qtype == QType::AXFR || response.d_response.d_idstate.qtype == QType::IXFR) { response.d_state->handleXFRResponse(response.d_now, std::move(response.d_response)); @@ -1275,15 +1347,114 @@ static void handleCrossProtocolResponse(int pipefd, FDMultiplexer::funcparam_t& struct TCPAcceptorParam { - ClientState& cs; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + ClientState& clientState; ComboAddress local; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) LocalStateHolder<NetmaskGroup>& acl; int socket{-1}; }; static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData); -static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int crossProtocolResponsesListenPipeFD, int crossProtocolResponsesWritePipeFD, std::vector<ClientState*> tcpAcceptStates) +static void scanForTimeouts(const TCPClientThreadData& data, const timeval& now) +{ + auto expiredReadConns = data.mplexer->getTimeouts(now, false); + for (const auto& cbData : expiredReadConns) { + if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(state, false); + } + } +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (read) from remote H2 client %s", state->d_ci.remote.toStringWithPort()); + std::shared_ptr<IncomingTCPConnectionState> parentState = state; + state->handleTimeout(parentState, false); + } + } +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ + else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { + auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second); + vinfolog("Timeout (read) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, false); + } + } + + auto expiredWriteConns = data.mplexer->getTimeouts(now, true); + for (const auto& cbData : expiredWriteConns) { + if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); + state->handleTimeout(state, true); + } + } +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + else if (cbData.second.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(cbData.second); + if (cbData.first == state->d_handler.getDescriptor()) { + vinfolog("Timeout (write) from remote H2 client %s", state->d_ci.remote.toStringWithPort()); + std::shared_ptr<IncomingTCPConnectionState> parentState = state; + state->handleTimeout(parentState, true); + } + } +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ + else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { + auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second); + vinfolog("Timeout (write) from remote backend %s", conn->getBackendName()); + conn->handleTimeout(now, true); + } + } +} + +static void dumpTCPStates(const TCPClientThreadData& data) +{ + /* just to keep things clean in the output, debug only */ + static std::mutex s_lock; + std::lock_guard<decltype(s_lock)> lck(s_lock); + if (g_tcpStatesDumpRequested > 0) { + /* no race here, we took the lock so it can only be increased in the meantime */ + --g_tcpStatesDumpRequested; + infolog("Dumping the TCP states, as requested:"); + data.mplexer->runForAllWatchedFDs([](bool isRead, int desc, const FDMultiplexer::funcparam_t& param, struct timeval ttd) { + timeval lnow{}; + gettimeofday(&lnow, nullptr); + if (ttd.tv_sec > 0) { + infolog("- Descriptor %d is in %s state, TTD in %d", desc, (isRead ? "read" : "write"), (ttd.tv_sec - lnow.tv_sec)); + } + else { + infolog("- Descriptor %d is in %s state, no TTD set", desc, (isRead ? "read" : "write")); + } + + if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param); + infolog(" - %s", state->toString()); + } +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + else if (param.type() == typeid(std::shared_ptr<IncomingHTTP2Connection>)) { + auto state = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param); + infolog(" - %s", state->toString()); + } +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ + else if (param.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { + auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param); + infolog(" - %s", conn->toString()); + } + else if (param.type() == typeid(TCPClientThreadData*)) { + infolog(" - Worker thread pipe"); + } + }); + infolog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount()); + } +} + +// NOLINTNEXTLINE(performance-unnecessary-value-param): you are wrong, clang-tidy, go home +static void tcpClientThread(pdns::channel::Receiver<ConnectionInfo>&& queryReceiver, pdns::channel::Receiver<CrossProtocolQuery>&& crossProtocolQueryReceiver, pdns::channel::Receiver<TCPCrossProtocolResponse>&& crossProtocolResponseReceiver, pdns::channel::Sender<TCPCrossProtocolResponse>&& crossProtocolResponseSender, std::vector<ClientState*> tcpAcceptStates) { /* we get launched with a pipe on which we receive file descriptors from clients that we own from that point on */ @@ -1292,11 +1463,14 @@ static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int cros try { TCPClientThreadData data; - /* this is the writing end! */ - data.crossProtocolResponsesPipe = crossProtocolResponsesWritePipeFD; - data.mplexer->addReadFD(pipefd, handleIncomingTCPQuery, &data); - data.mplexer->addReadFD(crossProtocolQueriesPipeFD, handleCrossProtocolQuery, &data); - data.mplexer->addReadFD(crossProtocolResponsesListenPipeFD, handleCrossProtocolResponse, &data); + data.crossProtocolResponseSender = std::move(crossProtocolResponseSender); + data.queryReceiver = std::move(queryReceiver); + data.crossProtocolQueryReceiver = std::move(crossProtocolQueryReceiver); + data.crossProtocolResponseReceiver = std::move(crossProtocolResponseReceiver); + + data.mplexer->addReadFD(data.queryReceiver.getDescriptor(), handleIncomingTCPQuery, &data); + data.mplexer->addReadFD(data.crossProtocolQueryReceiver.getDescriptor(), handleCrossProtocolQuery, &data); + data.mplexer->addReadFD(data.crossProtocolResponseReceiver.getDescriptor(), handleCrossProtocolResponse, &data); /* only used in single acceptor mode for now */ auto acl = g_ACL.getLocal(); @@ -1311,17 +1485,16 @@ static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int cros } auto acceptCallback = [&data](int socket, FDMultiplexer::funcparam_t& funcparam) { - auto acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam); + const auto* acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam); acceptNewConnection(*acceptorParam, &data); }; - for (size_t idx = 0; idx < acceptParams.size(); idx++) { - const auto& param = acceptParams.at(idx); + for (const auto& param : acceptParams) { setNonBlocking(param.socket); data.mplexer->addReadFD(param.socket, acceptCallback, ¶m); } - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); time_t lastTimeoutScan = now.tv_sec; @@ -1333,76 +1506,15 @@ static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int cros if (now.tv_sec > lastTimeoutScan) { lastTimeoutScan = now.tv_sec; - auto expiredReadConns = data.mplexer->getTimeouts(now, false); - for (const auto& cbData : expiredReadConns) { - if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { - auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second); - if (cbData.first == state->d_handler.getDescriptor()) { - vinfolog("Timeout (read) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); - state->handleTimeout(state, false); - } - } - else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { - auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second); - vinfolog("Timeout (read) from remote backend %s", conn->getBackendName()); - conn->handleTimeout(now, false); - } - } - - auto expiredWriteConns = data.mplexer->getTimeouts(now, true); - for (const auto& cbData : expiredWriteConns) { - if (cbData.second.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { - auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(cbData.second); - if (cbData.first == state->d_handler.getDescriptor()) { - vinfolog("Timeout (write) from remote TCP client %s", state->d_ci.remote.toStringWithPort()); - state->handleTimeout(state, true); - } - } - else if (cbData.second.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { - auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(cbData.second); - vinfolog("Timeout (write) from remote backend %s", conn->getBackendName()); - conn->handleTimeout(now, true); - } - } + scanForTimeouts(data, now); if (g_tcpStatesDumpRequested > 0) { - /* just to keep things clean in the output, debug only */ - static std::mutex s_lock; - std::lock_guard<decltype(s_lock)> lck(s_lock); - if (g_tcpStatesDumpRequested > 0) { - /* no race here, we took the lock so it can only be increased in the meantime */ - --g_tcpStatesDumpRequested; - errlog("Dumping the TCP states, as requested:"); - data.mplexer->runForAllWatchedFDs([](bool isRead, int fd, const FDMultiplexer::funcparam_t& param, struct timeval ttd) - { - struct timeval lnow; - gettimeofday(&lnow, nullptr); - if (ttd.tv_sec > 0) { - errlog("- Descriptor %d is in %s state, TTD in %d", fd, (isRead ? "read" : "write"), (ttd.tv_sec-lnow.tv_sec)); - } - else { - errlog("- Descriptor %d is in %s state, no TTD set", fd, (isRead ? "read" : "write")); - } - - if (param.type() == typeid(std::shared_ptr<IncomingTCPConnectionState>)) { - auto state = boost::any_cast<std::shared_ptr<IncomingTCPConnectionState>>(param); - errlog(" - %s", state->toString()); - } - else if (param.type() == typeid(std::shared_ptr<TCPConnectionToBackend>)) { - auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param); - errlog(" - %s", conn->toString()); - } - else if (param.type() == typeid(TCPClientThreadData*)) { - errlog(" - Worker thread pipe"); - } - }); - errlog("The TCP/DoT client cache has %d active and %d idle outgoing connections cached", t_downstreamTCPConnectionsManager.getActiveCount(), t_downstreamTCPConnectionsManager.getIdleCount()); - } + dumpTCPStates(data); } } } catch (const std::exception& e) { - errlog("Error in TCP worker thread: %s", e.what()); + warnlog("Error in TCP worker thread: %s", e.what()); } } } @@ -1413,9 +1525,10 @@ static void tcpClientThread(int pipefd, int crossProtocolQueriesPipeFD, int cros static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadData* threadData) { - auto& cs = param.cs; + auto& clientState = param.clientState; auto& acl = param.acl; - int socket = param.socket; + const bool checkACL = clientState.dohFrontend == nullptr || (!clientState.dohFrontend->d_trustForwardedForHeader && clientState.dohFrontend->d_earlyACLDrop); + const int socket = param.socket; bool tcpClientCountIncremented = false; ComboAddress remote; remote.sin4.sin_family = param.local.sin4.sin_family; @@ -1423,41 +1536,43 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa tcpClientCountIncremented = false; try { socklen_t remlen = remote.getSocklen(); - ConnectionInfo ci(&cs); + ConnectionInfo connInfo(&clientState); #ifdef HAVE_ACCEPT4 - ci.fd = accept4(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + connInfo.fd = accept4(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen, SOCK_NONBLOCK); #else - ci.fd = accept(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + connInfo.fd = accept(socket, reinterpret_cast<struct sockaddr*>(&remote), &remlen); #endif // will be decremented when the ConnectionInfo object is destroyed, no matter the reason - auto concurrentConnections = ++cs.tcpCurrentConnections; + auto concurrentConnections = ++clientState.tcpCurrentConnections; - if (ci.fd < 0) { + if (connInfo.fd < 0) { throw std::runtime_error((boost::format("accepting new connection on socket: %s") % stringerror()).str()); } - if (!acl->match(remote)) { - ++g_stats.aclDrops; + if (checkACL && !acl->match(remote)) { + ++dnsdist::metrics::g_stats.aclDrops; vinfolog("Dropped TCP connection from %s because of ACL", remote.toStringWithPort()); return; } - if (cs.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > cs.d_tcpConcurrentConnectionsLimit) { + if (clientState.d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > clientState.d_tcpConcurrentConnectionsLimit) { vinfolog("Dropped TCP connection from %s because of concurrent connections limit", remote.toStringWithPort()); return; } - if (concurrentConnections > cs.tcpMaxConcurrentConnections.load()) { - cs.tcpMaxConcurrentConnections.store(concurrentConnections); + if (concurrentConnections > clientState.tcpMaxConcurrentConnections.load()) { + clientState.tcpMaxConcurrentConnections.store(concurrentConnections); } #ifndef HAVE_ACCEPT4 - if (!setNonBlocking(ci.fd)) { + if (!setNonBlocking(connInfo.fd)) { return; } #endif - setTCPNoDelay(ci.fd); // disable NAGLE + setTCPNoDelay(connInfo.fd); // disable NAGLE if (g_maxTCPQueuedConnections > 0 && g_tcpclientthreads->getQueuedCount() >= g_maxTCPQueuedConnections) { vinfolog("Dropping TCP connection from %s because we have too many queued already", remote.toStringWithPort()); @@ -1472,19 +1587,29 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa vinfolog("Got TCP connection from %s", remote.toStringWithPort()); - ci.remote = remote; + connInfo.remote = remote; + if (threadData == nullptr) { - if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(ci)))) { + if (!g_tcpclientthreads->passConnectionToThread(std::make_unique<ConnectionInfo>(std::move(connInfo)))) { if (tcpClientCountIncremented) { dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote); } } } else { - struct timeval now; + timeval now{}; gettimeofday(&now, nullptr); - auto state = std::make_shared<IncomingTCPConnectionState>(std::move(ci), *threadData, now); - IncomingTCPConnectionState::handleIO(state, now); + + if (connInfo.cs->dohFrontend) { +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) + auto state = std::make_shared<IncomingHTTP2Connection>(std::move(connInfo), *threadData, now); + state->handleIO(); +#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */ + } + else { + auto state = std::make_shared<IncomingTCPConnectionState>(std::move(connInfo), *threadData, now); + state->handleIO(); + } } } catch (const std::exception& e) { @@ -1493,14 +1618,15 @@ static void acceptNewConnection(const TCPAcceptorParam& param, TCPClientThreadDa dnsdist::IncomingConcurrentTCPConnectionsManager::accountClosedTCPConnection(remote); } } - catch (...){} + catch (...) { + } } /* spawn as many of these as required, they call Accept on a socket on which they will accept queries, and they will hand off to worker threads & spawn more of them if required */ #ifndef USE_SINGLE_ACCEPTOR_THREAD -void tcpAcceptorThread(std::vector<ClientState*> states) +void tcpAcceptorThread(const std::vector<ClientState*>& states) { setThreadName("dnsdist/tcpAcce"); @@ -1508,7 +1634,7 @@ void tcpAcceptorThread(std::vector<ClientState*> states) std::vector<TCPAcceptorParam> params; params.reserve(states.size()); - for (auto& state : states) { + for (const auto& state : states) { params.emplace_back(TCPAcceptorParam{*state, state->local, acl, state->tcpFD}); for (const auto& [addr, socket] : state->d_additionalAddresses) { params.emplace_back(TCPAcceptorParam{*state, addr, acl, socket}); @@ -1522,19 +1648,18 @@ void tcpAcceptorThread(std::vector<ClientState*> states) } else { auto acceptCallback = [](int socket, FDMultiplexer::funcparam_t& funcparam) { - auto acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam); + const auto* acceptorParam = boost::any_cast<const TCPAcceptorParam*>(funcparam); acceptNewConnection(*acceptorParam, nullptr); }; auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(params.size())); - for (size_t idx = 0; idx < params.size(); idx++) { - const auto& param = params.at(idx); + for (const auto& param : params) { mplexer->addReadFD(param.socket, acceptCallback, ¶m); } - struct timeval tv; + timeval now{}; while (true) { - mplexer->run(&tv, -1); + mplexer->run(&now, -1); } } } |