diff options
Diffstat (limited to 'dnsdist-tcp-downstream.cc')
-rw-r--r-- | dnsdist-tcp-downstream.cc | 778 |
1 files changed, 778 insertions, 0 deletions
diff --git a/dnsdist-tcp-downstream.cc b/dnsdist-tcp-downstream.cc new file mode 100644 index 0000000..9b83cad --- /dev/null +++ b/dnsdist-tcp-downstream.cc @@ -0,0 +1,778 @@ + +#include "dnsdist-session-cache.hh" +#include "dnsdist-tcp-downstream.hh" +#include "dnsdist-tcp-upstream.hh" + +#include "dnsparser.hh" + +thread_local DownstreamTCPConnectionsManager t_downstreamTCPConnectionsManager; + +ConnectionToBackend::~ConnectionToBackend() +{ + if (d_ds && d_handler) { + --d_ds->tcpCurrentConnections; + struct timeval now; + gettimeofday(&now, nullptr); + + if (d_handler->isTLS()) { + if (d_handler->hasTLSSessionBeenResumed()) { + ++d_ds->tlsResumptions; + } + try { + auto sessions = d_handler->getTLSSessions(); + if (!sessions.empty()) { + g_sessionCache.putSessions(d_ds->getID(), now.tv_sec, std::move(sessions)); + } + } + catch (const std::exception& e) { + vinfolog("Unable to get a TLS session: %s", e.what()); + } + } + auto diff = now - d_connectionStartTime; + // cerr<<"connection to backend terminated after "<<d_queries<<" queries, "<<diff.tv_sec<<" seconds"<<endl; + d_ds->updateTCPMetrics(d_queries, diff.tv_sec * 1000 + diff.tv_usec / 1000); + } +} + +bool ConnectionToBackend::reconnect() +{ + std::unique_ptr<TLSSession> tlsSession{nullptr}; + if (d_handler) { + DEBUGLOG("closing socket "<<d_handler->getDescriptor()); + if (d_handler->isTLS()) { + if (d_handler->hasTLSSessionBeenResumed()) { + ++d_ds->tlsResumptions; + } + try { + auto sessions = d_handler->getTLSSessions(); + if (!sessions.empty()) { + tlsSession = std::move(sessions.back()); + sessions.pop_back(); + if (!sessions.empty()) { + g_sessionCache.putSessions(d_ds->getID(), time(nullptr), std::move(sessions)); + } + } + } + catch (const std::exception& e) { + vinfolog("Unable to get a TLS session to resume: %s", e.what()); + } + } + d_handler->close(); + d_ioState.reset(); + d_handler.reset(); + --d_ds->tcpCurrentConnections; + } + + d_fresh = true; + d_highestStreamID = 0; + d_proxyProtocolPayloadSent = false; + + do { + vinfolog("TCP connecting to downstream %s (%d)", d_ds->getNameWithAddr(), d_downstreamFailures); + DEBUGLOG("Opening TCP connection to backend "<<d_ds->getNameWithAddr()); + ++d_ds->tcpNewConnections; + try { + auto socket = std::make_unique<Socket>(d_ds->remote.sin4.sin_family, SOCK_STREAM, 0); + DEBUGLOG("result of socket() is "<<socket->getHandle()); + + if (!IsAnyAddress(d_ds->sourceAddr)) { + SSetsockopt(socket->getHandle(), SOL_SOCKET, SO_REUSEADDR, 1); +#ifdef IP_BIND_ADDRESS_NO_PORT + if (d_ds->ipBindAddrNoPort) { + SSetsockopt(socket->getHandle(), SOL_IP, IP_BIND_ADDRESS_NO_PORT, 1); + } +#endif +#ifdef SO_BINDTODEVICE + if (!d_ds->sourceItfName.empty()) { + int res = setsockopt(socket->getHandle(), SOL_SOCKET, SO_BINDTODEVICE, d_ds->sourceItfName.c_str(), d_ds->sourceItfName.length()); + if (res != 0) { + vinfolog("Error setting up the interface on backend TCP socket '%s': %s", d_ds->getNameWithAddr(), stringerror()); + } + } +#endif + socket->bind(d_ds->sourceAddr, false); + } + socket->setNonBlocking(); + + gettimeofday(&d_connectionStartTime, nullptr); + auto handler = std::make_unique<TCPIOHandler>(d_ds->d_tlsSubjectName, socket->releaseHandle(), timeval{0,0}, d_ds->d_tlsCtx, d_connectionStartTime.tv_sec); + if (!tlsSession && d_ds->d_tlsCtx) { + tlsSession = g_sessionCache.getSession(d_ds->getID(), d_connectionStartTime.tv_sec); + } + if (tlsSession) { + handler->setTLSSession(tlsSession); + } + handler->tryConnect(d_ds->tcpFastOpen && isFastOpenEnabled(), d_ds->remote); + d_queries = 0; + + d_handler = std::move(handler); + d_ds->incCurrentConnectionsCount(); + return true; + } + catch (const std::runtime_error& e) { + vinfolog("Connection to downstream server %s failed: %s", d_ds->getName(), e.what()); + d_downstreamFailures++; + if (d_downstreamFailures >= d_ds->d_retries) { + throw; + } + } + } + while (d_downstreamFailures < d_ds->d_retries); + + return false; +} + +TCPConnectionToBackend::~TCPConnectionToBackend() +{ + if (d_ds && !d_pendingResponses.empty()) { + d_ds->outstanding -= d_pendingResponses.size(); + } +} + +void TCPConnectionToBackend::release() +{ + d_ds->outstanding -= d_pendingResponses.size(); + + d_pendingResponses.clear(); + d_pendingQueries.clear(); + + if (d_ioState) { + d_ioState.reset(); + } +} + +static void editPayloadID(PacketBuffer& payload, uint16_t newId, size_t proxyProtocolPayloadSize, bool sizePrepended) +{ + /* we cannot do a direct cast as the alignment might be off (the size of the payload might have been prepended, which is bad enough, + but we might also have a proxy protocol payload */ + size_t startOfHeaderOffset = (sizePrepended ? sizeof(uint16_t) : 0) + proxyProtocolPayloadSize; + if (payload.size() < startOfHeaderOffset + sizeof(dnsheader)) { + throw std::runtime_error("Invalid buffer for outgoing TCP query (size " + std::to_string(payload.size())); + } + uint16_t id = htons(newId); + memcpy(&payload.at(startOfHeaderOffset), &id, sizeof(id)); +} + +enum class QueryState : uint8_t { + hasSizePrepended, + noSize +}; + +enum class ConnectionState : uint8_t { + needProxy, + proxySent +}; + +static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState queryState, ConnectionState connectionState) +{ + if (connectionState == ConnectionState::needProxy) { + if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) { + query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end()); + query.d_proxyProtocolPayloadAdded = true; + query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size(); + } + } + else if (connectionState == ConnectionState::proxySent) { + if (query.d_proxyProtocolPayloadAdded) { + if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) { + throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size())); + } + query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize); + query.d_proxyProtocolPayloadAdded = false; + query.d_proxyProtocolPayloadAddedSize = 0; + } + } + editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true); +} + +IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn) +{ + conn->d_currentQuery = std::move(conn->d_pendingQueries.front()); + + uint16_t id = conn->d_highestStreamID; + prepareQueryForSending(conn->d_currentQuery.d_query, id, QueryState::hasSizePrepended, conn->needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent); + + conn->d_pendingQueries.pop_front(); + conn->d_state = State::sendingQueryToBackend; + conn->d_currentPos = 0; + + return IOState::NeedWrite; +} + +IOState TCPConnectionToBackend::sendQuery(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now) +{ + DEBUGLOG("sending query to backend "<<conn->getDS()->getName()<<" over FD "<<conn->d_handler->getDescriptor()); + + IOState state = conn->d_handler->tryWrite(conn->d_currentQuery.d_query.d_buffer, conn->d_currentPos, conn->d_currentQuery.d_query.d_buffer.size()); + + if (state != IOState::Done) { + return state; + } + + DEBUGLOG("query sent to backend"); + /* request sent ! */ + if (conn->d_currentQuery.d_query.d_proxyProtocolPayloadAdded) { + conn->d_proxyProtocolPayloadSent = true; + } + ++conn->d_queries; + conn->d_currentPos = 0; + + DEBUGLOG("adding a pending response for ID "<<conn->d_highestStreamID<<" and QNAME "<<conn->d_currentQuery.d_query.d_idstate.qname); + auto res = conn->d_pendingResponses.insert({conn->d_highestStreamID, std::move(conn->d_currentQuery)}); + /* if there was already a pending response with that ID, we messed up and we don't expect more + than one response */ + if (res.second) { + ++conn->d_ds->outstanding; + } + ++conn->d_highestStreamID; + conn->d_currentQuery.d_sender.reset(); + conn->d_currentQuery.d_query.d_buffer.clear(); + + return state; +} + +void TCPConnectionToBackend::handleIO(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now) +{ + if (conn->d_handler == nullptr) { + throw std::runtime_error("No downstream socket in " + std::string(__PRETTY_FUNCTION__) + "!"); + } + + bool connectionDied = false; + IOState iostate = IOState::Done; + IOStateGuard ioGuard(conn->d_ioState); + bool reconnected = false; + + do { + reconnected = false; + + try { + if (conn->d_state == State::sendingQueryToBackend) { + iostate = sendQuery(conn, now); + + while (iostate == IOState::Done && !conn->d_pendingQueries.empty()) { + queueNextQuery(conn); + iostate = sendQuery(conn, now); + } + + if (iostate == IOState::Done && conn->d_pendingQueries.empty()) { + conn->d_state = State::waitingForResponseFromBackend; + conn->d_currentPos = 0; + conn->d_responseBuffer.resize(sizeof(uint16_t)); + iostate = IOState::NeedRead; + } + } + + if (conn->d_state == State::waitingForResponseFromBackend || + conn->d_state == State::readingResponseSizeFromBackend) { + DEBUGLOG("reading response size from backend"); + // then we need to allocate a new buffer (new because we might need to re-send the query if the + // backend dies on us) + // We also might need to read and send to the client more than one response in case of XFR (yeah!) + conn->d_responseBuffer.resize(sizeof(uint16_t)); + iostate = conn->d_handler->tryRead(conn->d_responseBuffer, conn->d_currentPos, sizeof(uint16_t)); + if (iostate == IOState::Done) { + DEBUGLOG("got response size from backend"); + conn->d_state = State::readingResponseFromBackend; + conn->d_responseSize = conn->d_responseBuffer.at(0) * 256 + conn->d_responseBuffer.at(1); + conn->d_responseBuffer.reserve(conn->d_responseSize + /* we will need to prepend the size later */ 2); + conn->d_responseBuffer.resize(conn->d_responseSize); + conn->d_currentPos = 0; + conn->d_lastDataReceivedTime = now; + } + else if (conn->d_state == State::waitingForResponseFromBackend && conn->d_currentPos > 0) { + conn->d_state = State::readingResponseSizeFromBackend; + } + } + + if (conn->d_state == State::readingResponseFromBackend) { + DEBUGLOG("reading response from backend"); + iostate = conn->d_handler->tryRead(conn->d_responseBuffer, conn->d_currentPos, conn->d_responseSize); + if (iostate == IOState::Done) { + DEBUGLOG("got response from backend"); + try { + conn->d_lastDataReceivedTime = now; + iostate = conn->handleResponse(conn, now); + } + catch (const std::exception& e) { + vinfolog("Got an exception while handling TCP response from %s (client is %s): %s", conn->d_ds ? conn->d_ds->getName() : "unknown", conn->d_currentQuery.d_query.d_idstate.origRemote.toStringWithPort(), e.what()); + ioGuard.release(); + conn->release(); + return; + } + } + } + + if (conn->d_state != State::idle && + conn->d_state != State::sendingQueryToBackend && + conn->d_state != State::waitingForResponseFromBackend && + conn->d_state != State::readingResponseSizeFromBackend && + conn->d_state != State::readingResponseFromBackend) { + vinfolog("Unexpected state %d in TCPConnectionToBackend::handleIO", static_cast<int>(conn->d_state)); + } + } + catch (const std::exception& e) { + /* most likely an EOF because the other end closed the connection, + but it might also be a real IO error or something else. + Let's just drop the connection + */ + vinfolog("Got an exception while handling (%s backend) TCP query from %s: %s", (conn->d_state == State::sendingQueryToBackend ? "writing to" : "reading from"), conn->d_currentQuery.d_query.d_idstate.origRemote.toStringWithPort(), e.what()); + + if (conn->d_state == State::sendingQueryToBackend) { + ++conn->d_ds->tcpDiedSendingQuery; + } + else if (conn->d_state != State::idle) { + ++conn->d_ds->tcpDiedReadingResponse; + } + + /* don't increase this counter when reusing connections */ + if (conn->d_fresh) { + ++conn->d_downstreamFailures; + } + + /* remove this FD from the IO multiplexer */ + iostate = IOState::Done; + connectionDied = true; + } + + if (connectionDied) { + + DEBUGLOG("connection died, number of failures is "<<conn->d_downstreamFailures<<", retries is "<<conn->d_ds->d_retries); + + if (conn->d_downstreamFailures < conn->d_ds->d_retries) { + + conn->d_ioState.reset(); + ioGuard.release(); + + try { + if (conn->reconnect()) { + conn->d_ioState = make_unique<IOStateHandler>(*conn->d_mplexer, conn->d_handler->getDescriptor()); + + /* we need to resend the queries that were in flight, if any */ + if (conn->d_state == State::sendingQueryToBackend) { + /* we need to edit this query so it has the correct ID */ + auto query = std::move(conn->d_currentQuery); + uint16_t id = conn->d_highestStreamID; + prepareQueryForSending(query.d_query, id, QueryState::hasSizePrepended, ConnectionState::needProxy); + conn->d_currentQuery = std::move(query); + } + + /* if we notify the sender it might terminate us so we need to move these first */ + auto pendingResponses = std::move(conn->d_pendingResponses); + conn->d_pendingResponses.clear(); + for (auto& pending : pendingResponses) { + --conn->d_ds->outstanding; + + if (pending.second.d_query.isXFR() && pending.second.d_query.d_xfrStarted) { + /* this one can't be restarted, sorry */ + DEBUGLOG("A XFR for which a response has already been sent cannot be restarted"); + try { + pending.second.d_sender->notifyIOError(std::move(pending.second.d_query.d_idstate), now); + } + catch (const std::exception& e) { + vinfolog("Got an exception while notifying: %s", e.what()); + } + catch (...) { + vinfolog("Got exception while notifying"); + } + } + else { + conn->d_pendingQueries.push_back(std::move(pending.second)); + } + } + conn->d_currentPos = 0; + + if (conn->d_state == State::sendingQueryToBackend) { + iostate = IOState::NeedWrite; + // resume sending query + } + else { + if (conn->d_pendingQueries.empty()) { + throw std::runtime_error("TCP connection to a backend in state " + std::to_string((int)conn->d_state) + " with no pending queries"); + } + + iostate = queueNextQuery(conn); + } + + reconnected = true; + connectionDied = false; + } + } + catch (const std::exception& e) { + // reconnect might throw on failure, let's ignore that, we just need to know + // it failed + } + } + + if (!reconnected) { + /* reconnect failed, we give up */ + DEBUGLOG("reconnect failed, we give up"); + ++conn->d_ds->tcpGaveUp; + conn->notifyAllQueriesFailed(now, FailureReason::gaveUp); + } + } + + if (conn->d_ioState) { + if (iostate == IOState::Done) { + conn->d_ioState->update(iostate, handleIOCallback, conn); + } + else { + boost::optional<struct timeval> ttd{boost::none}; + if (iostate == IOState::NeedRead) { + ttd = conn->getBackendReadTTD(now); + } + else if (conn->isFresh() && conn->d_queries == 0) { + /* first write just after the non-blocking connect */ + ttd = conn->getBackendConnectTTD(now); + } + else { + ttd = conn->getBackendWriteTTD(now); + } + + conn->d_ioState->update(iostate, handleIOCallback, conn, ttd); + } + } + } + while (reconnected); + + ioGuard.release(); +} + +void TCPConnectionToBackend::handleIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + auto conn = boost::any_cast<std::shared_ptr<TCPConnectionToBackend>>(param); + if (fd != conn->getHandle()) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__PRETTY_FUNCTION__) + ", expected " + std::to_string(conn->getHandle())); + } + + struct timeval now; + gettimeofday(&now, nullptr); + handleIO(conn, now); +} + +void TCPConnectionToBackend::queueQuery(std::shared_ptr<TCPQuerySender>& sender, TCPQuery&& query) +{ + if (!d_ioState) { + d_ioState = make_unique<IOStateHandler>(*d_mplexer, d_handler->getDescriptor()); + } + + // if we are not already sending a query or in the middle of reading a response (so idle), + // start sending the query + if (d_state == State::idle || d_state == State::waitingForResponseFromBackend) { + DEBUGLOG("Sending new query to backend right away, with ID "<<d_highestStreamID); + d_state = State::sendingQueryToBackend; + d_currentPos = 0; + + uint16_t id = d_highestStreamID; + + d_currentQuery = PendingRequest({sender, std::move(query)}); + prepareQueryForSending(d_currentQuery.d_query, id, QueryState::hasSizePrepended, needProxyProtocolPayload() ? ConnectionState::needProxy : ConnectionState::proxySent); + + struct timeval now; + gettimeofday(&now, 0); + + auto shared = std::dynamic_pointer_cast<TCPConnectionToBackend>(shared_from_this()); + handleIO(shared, now); + } + else { + DEBUGLOG("Adding new query to the queue because we are in state "<<(int)d_state); + // store query in the list of queries to send + d_pendingQueries.push_back(PendingRequest({sender, std::move(query)})); + } +} + +void TCPConnectionToBackend::handleTimeout(const struct timeval& now, bool write) +{ + /* in some cases we could retry, here, reconnecting and sending our pending responses again */ + if (write) { + if (isFresh() && d_queries == 0) { + ++d_ds->tcpConnectTimeouts; + vinfolog("Timeout while connecting to TCP backend %s", d_ds->getName()); + } + else { + ++d_ds->tcpWriteTimeouts; + vinfolog("Timeout while writing to TCP backend %s", d_ds->getName()); + } + } + else { + ++d_ds->tcpReadTimeouts; + vinfolog("Timeout while reading from TCP backend %s", d_ds->getName()); + } + + try { + notifyAllQueriesFailed(now, FailureReason::timeout); + } + catch (const std::exception& e) { + vinfolog("Got an exception while notifying a timeout: %s", e.what()); + } + catch (...) { + vinfolog("Got exception while notifying a timeout"); + } + + release(); +} + +void TCPConnectionToBackend::notifyAllQueriesFailed(const struct timeval& now, FailureReason reason) +{ + d_connectionDied = true; + + /* we might be terminated while notifying a query sender */ + d_ds->outstanding -= d_pendingResponses.size(); + auto pendingQueries = std::move(d_pendingQueries); + d_pendingQueries.clear(); + auto pendingResponses = std::move(d_pendingResponses); + d_pendingResponses.clear(); + + auto increaseCounters = [reason](std::shared_ptr<TCPQuerySender>& sender) { + if (reason == FailureReason::timeout) { + const ClientState* cs = sender->getClientState(); + if (cs) { + ++cs->tcpDownstreamTimeouts; + } + } + else if (reason == FailureReason::gaveUp) { + const ClientState* cs = sender->getClientState(); + if (cs) { + ++cs->tcpGaveUp; + } + } + }; + + try { + if (d_state == State::sendingQueryToBackend) { + auto sender = d_currentQuery.d_sender; + if (sender->active()) { + increaseCounters(sender); + sender->notifyIOError(std::move(d_currentQuery.d_query.d_idstate), now); + } + } + + for (auto& query : pendingQueries) { + auto sender = query.d_sender; + if (sender->active()) { + increaseCounters(sender); + sender->notifyIOError(std::move(query.d_query.d_idstate), now); + } + } + + for (auto& response : pendingResponses) { + auto sender = response.second.d_sender; + if (sender->active()) { + increaseCounters(sender); + sender->notifyIOError(std::move(response.second.d_query.d_idstate), now); + } + } + } + catch (const std::exception& e) { + vinfolog("Got an exception while notifying: %s", e.what()); + } + catch (...) { + vinfolog("Got exception while notifying"); + } + + release(); +} + +static uint32_t getSerialFromRawSOAContent(const std::vector<uint8_t>& raw) +{ + /* minimal size for a SOA record, as defined by rfc1035: + MNAME (root): 1 + RNAME (root): 1 + SERIAL: 4 + REFRESH: 4 + RETRY: 4 + EXPIRE: 4 + MINIMUM: 4 + = 22 bytes + */ + if (raw.size() < 22) { + throw std::runtime_error("Invalid content of size " + std::to_string(raw.size()) + " for a SOA record"); + } + /* As rfc1025 states that "all domain names in the RDATA section of these RRs may be compressed", + and we don't want to parse these names, start at the end */ + uint32_t serial = 0; + memcpy(&serial, &raw.at(raw.size() - 20), sizeof(serial)); + return ntohl(serial); +} + +IOState TCPConnectionToBackend::handleResponse(std::shared_ptr<TCPConnectionToBackend>& conn, const struct timeval& now) +{ + d_downstreamFailures = 0; + + uint16_t queryId = 0; + try { + queryId = getQueryIdFromResponse(); + } + catch (const std::exception& e) { + DEBUGLOG("Unable to get query ID"); + notifyAllQueriesFailed(now, FailureReason::unexpectedQueryID); + throw; + } + + auto it = d_pendingResponses.find(queryId); + if (it == d_pendingResponses.end()) { + DEBUGLOG("could not find any corresponding query for ID "<<queryId<<". This is likely a duplicated ID over the same TCP connection, giving up!"); + notifyAllQueriesFailed(now, FailureReason::unexpectedQueryID); + return IOState::Done; + } + + editPayloadID(d_responseBuffer, ntohs(it->second.d_query.d_idstate.origID), 0, false); + + auto sender = it->second.d_sender; + + if (sender->active() && it->second.d_query.isXFR()) { + DEBUGLOG("XFR!"); + bool done = false; + TCPResponse response; + response.d_buffer = std::move(d_responseBuffer); + response.d_connection = conn; + /* we don't move the whole IDS because we will need for the responses to come */ + response.d_idstate.qtype = it->second.d_query.d_idstate.qtype; + response.d_idstate.qname = it->second.d_query.d_idstate.qname; + DEBUGLOG("passing XFRresponse to client connection for "<<response.d_idstate.qname); + + it->second.d_query.d_xfrStarted = true; + done = isXFRFinished(response, it->second.d_query); + + if (done) { + d_pendingResponses.erase(it); + --conn->d_ds->outstanding; + /* marking as idle for now, so we can accept new queries if our queues are empty */ + if (d_pendingQueries.empty() && d_pendingResponses.empty()) { + t_downstreamTCPConnectionsManager.moveToIdle(conn); + d_state = State::idle; + } + } + + sender->handleXFRResponse(now, std::move(response)); + if (done) { + t_downstreamTCPConnectionsManager.moveToIdle(conn); + d_state = State::idle; + return IOState::Done; + } + + d_state = State::waitingForResponseFromBackend; + d_currentPos = 0; + d_responseBuffer.resize(sizeof(uint16_t)); + // get ready to read the next packet, if any + return IOState::NeedRead; + } + + --conn->d_ds->outstanding; + auto ids = std::move(it->second.d_query.d_idstate); + d_pendingResponses.erase(it); + /* marking as idle for now, so we can accept new queries if our queues are empty */ + if (d_pendingQueries.empty() && d_pendingResponses.empty()) { + t_downstreamTCPConnectionsManager.moveToIdle(conn); + d_state = State::idle; + } + + auto shared = conn; + if (sender->active()) { + DEBUGLOG("passing response to client connection for "<<ids.qname); + // make sure that we still exist after calling handleResponse() + sender->handleResponse(now, TCPResponse(std::move(d_responseBuffer), std::move(ids), conn)); + } + + if (!d_pendingQueries.empty()) { + DEBUGLOG("still have some queries to send"); + return queueNextQuery(shared); + } + else if (!d_pendingResponses.empty()) { + DEBUGLOG("still have some responses to read"); + d_state = State::waitingForResponseFromBackend; + d_currentPos = 0; + d_responseBuffer.resize(sizeof(uint16_t)); + return IOState::NeedRead; + } + else { + DEBUGLOG("nothing to do, waiting for a new query"); + t_downstreamTCPConnectionsManager.moveToIdle(conn); + d_state = State::idle; + return IOState::Done; + } +} + +uint16_t TCPConnectionToBackend::getQueryIdFromResponse() const +{ + if (d_responseBuffer.size() < sizeof(dnsheader)) { + throw std::runtime_error("Unable to get query ID in a too small (" + std::to_string(d_responseBuffer.size()) + ") response from " + d_ds->getNameWithAddr()); + } + + uint16_t id; + memcpy(&id, &d_responseBuffer.at(0), sizeof(id)); + return ntohs(id); +} + +void TCPConnectionToBackend::setProxyProtocolValuesSent(std::unique_ptr<std::vector<ProxyProtocolValue>>&& proxyProtocolValuesSent) +{ + /* if we already have some values, we have already verified they match */ + if (!d_proxyProtocolValuesSent) { + d_proxyProtocolValuesSent = std::move(proxyProtocolValuesSent); + } +} + +bool TCPConnectionToBackend::matchesTLVs(const std::unique_ptr<std::vector<ProxyProtocolValue>>& tlvs) const +{ + if (tlvs == nullptr) { + if (d_proxyProtocolValuesSent == nullptr) { + return true; + } + else { + return false; + } + } + + if (d_proxyProtocolValuesSent == nullptr) { + return false; + } + + return *tlvs == *d_proxyProtocolValuesSent; +} + +bool TCPConnectionToBackend::isXFRFinished(const TCPResponse& response, TCPQuery& query) +{ + bool done = false; + try { + MOADNSParser parser(true, reinterpret_cast<const char*>(response.d_buffer.data()), response.d_buffer.size()); + if (parser.d_header.rcode != 0U) { + done = true; + } + else { + for (const auto& record : parser.d_answers) { + if (record.first.d_class != QClass::IN || record.first.d_type != QType::SOA) { + continue; + } + + auto unknownContent = getRR<UnknownRecordContent>(record.first); + if (!unknownContent) { + continue; + } + auto raw = unknownContent->getRawContent(); + auto serial = getSerialFromRawSOAContent(raw); + ++query.d_xfrSerialCount; + if (query.d_xfrMasterSerial == 0) { + // store the first SOA in our client's connection metadata + ++query.d_xfrMasterSerialCount; + query.d_xfrMasterSerial = serial; + } + else if (query.d_xfrMasterSerial == serial) { + ++query.d_xfrMasterSerialCount; + // figure out if it's end when receiving master's SOA again + if (query.d_xfrSerialCount == 2) { + // if there are only two SOA records marks a finished AXFR + done = true; + } + if (query.d_xfrMasterSerialCount == 3) { + // receiving master's SOA 3 times marks a finished IXFR + done = true; + } + } + } + } + } + catch (const MOADNSException& e) { + DEBUGLOG("Exception when parsing TCPResponse to DNS: " << e.what()); + /* ponder what to do here, shall we close the connection? */ + } + return done; +} |