diff options
Diffstat (limited to '')
-rw-r--r-- | dnsdist.cc | 1783 |
1 files changed, 1114 insertions, 669 deletions
@@ -29,6 +29,7 @@ #include <limits> #include <netinet/tcp.h> #include <pwd.h> +#include <set> #include <sys/resource.h> #include <unistd.h> @@ -52,11 +53,15 @@ #include "dnsdist-cache.hh" #include "dnsdist-carbon.hh" #include "dnsdist-console.hh" +#include "dnsdist-crypto.hh" #include "dnsdist-discovery.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-dynblocks.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-edns.hh" #include "dnsdist-healthchecks.hh" #include "dnsdist-lua.hh" +#include "dnsdist-lua-hooks.hh" #include "dnsdist-nghttp2.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-random.hh" @@ -65,10 +70,13 @@ #include "dnsdist-tcp.hh" #include "dnsdist-web.hh" #include "dnsdist-xpf.hh" +#include "dnsdist-xsk.hh" #include "base64.hh" #include "capabilities.hh" +#include "coverage.hh" #include "delaypipe.hh" +#include "doh.hh" #include "dolog.hh" #include "dnsname.hh" #include "dnsparser.hh" @@ -76,9 +84,9 @@ #include "gettime.hh" #include "lock.hh" #include "misc.hh" -#include "sodcrypto.hh" #include "sstuff.hh" #include "threadname.hh" +#include "xsk.hh" /* Known sins: @@ -94,14 +102,9 @@ using std::thread; bool g_verbose; -std::optional<std::ofstream> g_verboseStream{std::nullopt}; - -struct DNSDistStats g_stats; uint16_t g_maxOutstanding{std::numeric_limits<uint16_t>::max()}; uint32_t g_staleCacheEntriesTTL{0}; -bool g_syslog{true}; -bool g_logtimestamps{false}; bool g_allowEmptyResponse{false}; GlobalStateHolder<NetmaskGroup> g_ACL; @@ -109,6 +112,8 @@ string g_outputBuffer; std::vector<std::shared_ptr<TLSFrontend>> g_tlslocals; std::vector<std::shared_ptr<DOHFrontend>> g_dohlocals; +std::vector<std::shared_ptr<DOQFrontend>> g_doqlocals; +std::vector<std::shared_ptr<DOH3Frontend>> g_doh3locals; std::vector<std::shared_ptr<DNSCryptContext>> g_dnsCryptLocals; shared_ptr<BPFFilter> g_defaultBPFFilter{nullptr}; @@ -153,7 +158,9 @@ uint32_t g_socketUDPRecvBuffer{0}; std::set<std::string> g_capabilitiesToRetain; -static size_t const s_initialUDPPacketBufferSize = s_maxPacketCacheEntrySize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; +// we are not willing to receive a bigger UDP response than that, no matter what +static constexpr size_t s_maxUDPResponsePacketSize{4096U}; +static size_t const s_initialUDPPacketBufferSize = s_maxUDPResponsePacketSize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; static_assert(s_initialUDPPacketBufferSize <= UINT16_MAX, "Packet size should fit in a uint16_t"); static ssize_t sendfromto(int sock, const void* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to) @@ -196,8 +203,12 @@ static void truncateTC(PacketBuffer& packet, size_t maximumSize, unsigned int qn } packet.resize(static_cast<uint16_t>(sizeof(dnsheader)+qnameWireLength+DNS_TYPE_SIZE+DNS_CLASS_SIZE)); - struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data()); - dh->ancount = dh->arcount = dh->nscount = 0; + dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) { + header.ancount = 0; + header.arcount = 0; + header.nscount = 0; + return true; + }); if (hadEDNS) { addEDNS(packet, maximumSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0); @@ -205,7 +216,7 @@ static void truncateTC(PacketBuffer& packet, size_t maximumSize, unsigned int qn } catch(...) { - ++g_stats.truncFail; + ++dnsdist::metrics::g_stats.truncFail; } } @@ -226,13 +237,14 @@ struct DelayedPacket } }; -static DelayPipe<DelayedPacket>* g_delay = nullptr; +static std::unique_ptr<DelayPipe<DelayedPacket>> g_delay{nullptr}; #endif /* DISABLE_DELAY_PIPE */ std::string DNSQuestion::getTrailingData() const { - const char* message = reinterpret_cast<const char*>(this->getHeader()); - const uint16_t messageLen = getDNSPacketLength(message, this->data.size()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + const auto* message = reinterpret_cast<const char*>(this->getData().data()); + const uint16_t messageLen = getDNSPacketLength(message, this->getData().size()); return std::string(message + messageLen, this->getData().size() - messageLen); } @@ -250,6 +262,14 @@ bool DNSQuestion::setTrailingData(const std::string& tail) return true; } +bool DNSQuestion::editHeader(const std::function<bool(dnsheader&)>& editFunction) +{ + if (data.size() < sizeof(dnsheader)) { + throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer"); + } + return dnsdist::PacketMangling::editDNSHeaderFromPacket(data, editFunction); +} + static void doLatencyStats(dnsdist::Protocol protocol, double udiff) { constexpr auto doAvg = [](double& var, double n, double weight) { @@ -258,61 +278,73 @@ static void doLatencyStats(dnsdist::Protocol protocol, double udiff) if (protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP) { if (udiff < 1000) { - ++g_stats.latency0_1; + ++dnsdist::metrics::g_stats.latency0_1; } else if (udiff < 10000) { - ++g_stats.latency1_10; + ++dnsdist::metrics::g_stats.latency1_10; } else if (udiff < 50000) { - ++g_stats.latency10_50; + ++dnsdist::metrics::g_stats.latency10_50; } else if (udiff < 100000) { - ++g_stats.latency50_100; + ++dnsdist::metrics::g_stats.latency50_100; } else if (udiff < 1000000) { - ++g_stats.latency100_1000; + ++dnsdist::metrics::g_stats.latency100_1000; } else { - ++g_stats.latencySlow; + ++dnsdist::metrics::g_stats.latencySlow; } - g_stats.latencySum += udiff / 1000; - ++g_stats.latencyCount; + dnsdist::metrics::g_stats.latencySum += udiff / 1000; + ++dnsdist::metrics::g_stats.latencyCount; - doAvg(g_stats.latencyAvg100, udiff, 100); - doAvg(g_stats.latencyAvg1000, udiff, 1000); - doAvg(g_stats.latencyAvg10000, udiff, 10000); - doAvg(g_stats.latencyAvg1000000, udiff, 1000000); + doAvg(dnsdist::metrics::g_stats.latencyAvg100, udiff, 100); + doAvg(dnsdist::metrics::g_stats.latencyAvg1000, udiff, 1000); + doAvg(dnsdist::metrics::g_stats.latencyAvg10000, udiff, 10000); + doAvg(dnsdist::metrics::g_stats.latencyAvg1000000, udiff, 1000000); } else if (protocol == dnsdist::Protocol::DoTCP || protocol == dnsdist::Protocol::DNSCryptTCP) { - doAvg(g_stats.latencyTCPAvg100, udiff, 100); - doAvg(g_stats.latencyTCPAvg1000, udiff, 1000); - doAvg(g_stats.latencyTCPAvg10000, udiff, 10000); - doAvg(g_stats.latencyTCPAvg1000000, udiff, 1000000); + doAvg(dnsdist::metrics::g_stats.latencyTCPAvg100, udiff, 100); + doAvg(dnsdist::metrics::g_stats.latencyTCPAvg1000, udiff, 1000); + doAvg(dnsdist::metrics::g_stats.latencyTCPAvg10000, udiff, 10000); + doAvg(dnsdist::metrics::g_stats.latencyTCPAvg1000000, udiff, 1000000); } else if (protocol == dnsdist::Protocol::DoT) { - doAvg(g_stats.latencyDoTAvg100, udiff, 100); - doAvg(g_stats.latencyDoTAvg1000, udiff, 1000); - doAvg(g_stats.latencyDoTAvg10000, udiff, 10000); - doAvg(g_stats.latencyDoTAvg1000000, udiff, 1000000); + doAvg(dnsdist::metrics::g_stats.latencyDoTAvg100, udiff, 100); + doAvg(dnsdist::metrics::g_stats.latencyDoTAvg1000, udiff, 1000); + doAvg(dnsdist::metrics::g_stats.latencyDoTAvg10000, udiff, 10000); + doAvg(dnsdist::metrics::g_stats.latencyDoTAvg1000000, udiff, 1000000); } else if (protocol == dnsdist::Protocol::DoH) { - doAvg(g_stats.latencyDoHAvg100, udiff, 100); - doAvg(g_stats.latencyDoHAvg1000, udiff, 1000); - doAvg(g_stats.latencyDoHAvg10000, udiff, 10000); - doAvg(g_stats.latencyDoHAvg1000000, udiff, 1000000); + doAvg(dnsdist::metrics::g_stats.latencyDoHAvg100, udiff, 100); + doAvg(dnsdist::metrics::g_stats.latencyDoHAvg1000, udiff, 1000); + doAvg(dnsdist::metrics::g_stats.latencyDoHAvg10000, udiff, 10000); + doAvg(dnsdist::metrics::g_stats.latencyDoHAvg1000000, udiff, 1000000); + } + else if (protocol == dnsdist::Protocol::DoQ) { + doAvg(dnsdist::metrics::g_stats.latencyDoQAvg100, udiff, 100); + doAvg(dnsdist::metrics::g_stats.latencyDoQAvg1000, udiff, 1000); + doAvg(dnsdist::metrics::g_stats.latencyDoQAvg10000, udiff, 10000); + doAvg(dnsdist::metrics::g_stats.latencyDoQAvg1000000, udiff, 1000000); + } + else if (protocol == dnsdist::Protocol::DoH3) { + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg100, udiff, 100); + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg1000, udiff, 1000); + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg10000, udiff, 10000); + doAvg(dnsdist::metrics::g_stats.latencyDoH3Avg1000000, udiff, 1000000); } } -bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote, unsigned int& qnameWireLength) +bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr<DownstreamState>& remote) { if (response.size() < sizeof(dnsheader)) { return false; } - const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response.data()); + const dnsheader_aligned dh(response.data()); if (dh->qr == 0) { - ++g_stats.nonCompliantResponses; + ++dnsdist::metrics::g_stats.nonCompliantResponses; if (remote) { ++remote->nonCompliantResponses; } @@ -324,7 +356,7 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, return true; } else { - ++g_stats.nonCompliantResponses; + ++dnsdist::metrics::g_stats.nonCompliantResponses; if (remote) { ++remote->nonCompliantResponses; } @@ -335,13 +367,14 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, uint16_t rqtype, rqclass; DNSName rqname; try { - rqname = DNSName(reinterpret_cast<const char*>(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass, &qnameWireLength); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + rqname = DNSName(reinterpret_cast<const char*>(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass); } catch (const std::exception& e) { if (remote && response.size() > 0 && static_cast<size_t>(response.size()) > sizeof(dnsheader)) { infolog("Backend %s sent us a response with id %d that did not parse: %s", remote->d_config.remote.toStringWithPort(), ntohs(dh->id), e.what()); } - ++g_stats.nonCompliantResponses; + ++dnsdist::metrics::g_stats.nonCompliantResponses; if (remote) { ++remote->nonCompliantResponses; } @@ -369,11 +402,14 @@ static void restoreFlags(struct dnsheader* dh, uint16_t origFlags) *flags |= origFlags; } -static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags) +static bool fixUpQueryTurnedResponse(DNSQuestion& dnsQuestion, const uint16_t origFlags) { - restoreFlags(dq.getHeader(), origFlags); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [origFlags](dnsheader& header) { + restoreFlags(&header, origFlags); + return true; + }); - return addEDNSToQueryTurnedResponse(dq); + return addEDNSToQueryTurnedResponse(dnsQuestion); } static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, bool* zeroScope) @@ -382,8 +418,10 @@ static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t return false; } - struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data()); - restoreFlags(dh, origFlags); + dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [origFlags](dnsheader& header) { + restoreFlags(&header, origFlags); + return true; + }); if (response.size() == sizeof(dnsheader)) { return true; @@ -421,10 +459,12 @@ static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t if (last) { /* simply remove the last AR */ response.resize(response.size() - optLen); - dh = reinterpret_cast<struct dnsheader*>(response.data()); - uint16_t arcount = ntohs(dh->arcount); - arcount--; - dh->arcount = htons(arcount); + dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [](dnsheader& header) { + uint16_t arcount = ntohs(header.arcount); + arcount--; + header.arcount = htons(arcount); + return true; + }); } else { /* Removing an intermediary RR could lead to compression error */ @@ -498,9 +538,24 @@ static bool applyRulesToResponse(const std::vector<DNSDistResponseRuleAction>& r return true; break; case DNSResponseAction::Action::ServFail: - dr.getHeader()->rcode = RCode::ServFail; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) { + header.rcode = RCode::ServFail; + return true; + }); return true; break; + case DNSResponseAction::Action::Truncate: + if (!dr.overTCP()) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) { + header.tc = true; + header.qr = true; + return true; + }); + truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.ids.qname.wirelength()); + ++dnsdist::metrics::g_stats.ruleTruncated; + return true; + } + break; /* non-terminal actions follow */ case DNSResponseAction::Action::Delay: pdns::checked_stoi_into(dr.ids.delayMsec, ruleresult); // sorry @@ -521,7 +576,7 @@ bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDist return false; } - if (dr.ids.packetCache && !dr.ids.selfGenerated && !dr.ids.skipCache && response.size() <= s_maxPacketCacheEntrySize) { + if (dr.ids.packetCache && !dr.ids.selfGenerated && !dr.ids.skipCache && (!dr.ids.forwardedOverUDP || response.size() <= s_maxUDPResponsePacketSize)) { if (!dr.ids.useZeroScope) { /* if the query was not suitable for zero-scope, for example because it had an existing ECS entry so the hash is @@ -555,6 +610,10 @@ bool processResponseAfterRules(PacketBuffer& response, const std::vector<DNSDist ac(&dr, &result); } + if (dr.ids.d_extendedError) { + dnsdist::edns::addExtendedDNSError(dr.getMutableData(), dr.getMaximumSize(), dr.ids.d_extendedError->infoCode, dr.ids.d_extendedError->extraText); + } + #ifdef HAVE_DNSCRYPT if (!muted) { if (!encryptResponse(response, dr.getMaximumSize(), dr.overTCP(), dr.ids.dnsCryptQuery)) { @@ -579,11 +638,11 @@ bool processResponse(PacketBuffer& response, const std::vector<DNSDistResponseRu return processResponseAfterRules(response, cacheInsertedRespRuleActions, dr, muted); } -static size_t getInitialUDPPacketBufferSize() +static size_t getInitialUDPPacketBufferSize(bool expectProxyProtocol) { static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "The incoming buffer size should not be larger than s_initialUDPPacketBufferSize"); - if (g_proxyProtocolACL.empty()) { + if (!expectProxyProtocol || g_proxyProtocolACL.empty()) { return s_initialUDPPacketBufferSize; } @@ -593,10 +652,10 @@ static size_t getInitialUDPPacketBufferSize() static size_t getMaximumIncomingPacketSize(const ClientState& cs) { if (cs.dnscryptCtx) { - return getInitialUDPPacketBufferSize(); + return getInitialUDPPacketBufferSize(cs.d_enableProxyProtocol); } - if (g_proxyProtocolACL.empty()) { + if (!cs.d_enableProxyProtocol || g_proxyProtocolACL.empty()) { return s_udpIncomingBufferSize; } @@ -606,7 +665,7 @@ static size_t getMaximumIncomingPacketSize(const ClientState& cs) bool sendUDPResponse(int origFD, const PacketBuffer& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote) { #ifndef DISABLE_DELAY_PIPE - if (delayMsec && g_delay) { + if (delayMsec > 0 && g_delay != nullptr) { DelayedPacket dp{origFD, response, origRemote, origDest}; g_delay->submit(dp, delayMsec); return true; @@ -636,16 +695,16 @@ void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, switch (cleartextDH.rcode) { case RCode::NXDomain: - ++g_stats.frontendNXDomain; + ++dnsdist::metrics::g_stats.frontendNXDomain; break; case RCode::ServFail: if (fromBackend) { - ++g_stats.servfailResponses; + ++dnsdist::metrics::g_stats.servfailResponses; } - ++g_stats.frontendServFail; + ++dnsdist::metrics::g_stats.frontendServFail; break; case RCode::NoError: - ++g_stats.frontendNoError; + ++dnsdist::metrics::g_stats.frontendNoError; break; } @@ -659,7 +718,10 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re if (ids.udpPayloadSize > 0 && response.size() > ids.udpPayloadSize) { vinfolog("Got a response of size %d while the initial UDP payload size was %d, truncating", response.size(), ids.udpPayloadSize); truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.ids.qname.wirelength()); - dr.getHeader()->tc = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) { + header.tc = true; + return true; + }); } else if (dr.getHeader()->tc && g_truncateTC) { truncateTC(response, dr.getMaximumSize(), dr.ids.qname.wirelength()); @@ -668,7 +730,7 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re /* when the answer is encrypted in place, we need to get a copy of the original header before encryption to fill the ring buffer */ dnsheader cleartextDH; - memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); + memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH)); if (!isAsync) { if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) { @@ -680,15 +742,13 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re } } - ++g_stats.responses; + ++dnsdist::metrics::g_stats.responses; if (ids.cs) { ++ids.cs->responses; } bool muted = true; - if (ids.cs && !ids.cs->muted) { - ComboAddress empty; - empty.sin4.sin_family = 0; + if (ids.cs != nullptr && !ids.cs->muted && !ids.isXSK()) { sendUDPResponse(ids.cs->udpFD, response, dr.ids.delayMsec, ids.hopLocal, ids.hopRemote); muted = false; } @@ -696,10 +756,15 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re if (!selfGenerated) { double udiff = ids.queryRealTime.udiff(); if (!muted) { - vinfolog("Got answer from %s, relayed to %s (UDP), took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); + vinfolog("Got answer from %s, relayed to %s (UDP), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); } else { - vinfolog("Got answer from %s, NOT relayed to %s (UDP) since that frontend is muted, took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); + if (!ids.isXSK()) { + vinfolog("Got answer from %s, NOT relayed to %s (UDP) since that frontend is muted, took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); + } + else { + vinfolog("Got answer from %s, relayed to %s (UDP via XSK), took %f us", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); + } } handleResponseSent(ids, udiff, dr.ids.origRemote, ds->d_config.remote, response.size(), cleartextDH, ds->getProtocol(), true); @@ -709,6 +774,42 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re } } +bool processResponderPacket(std::shared_ptr<DownstreamState>& dss, PacketBuffer& response, const std::vector<DNSDistResponseRuleAction>& localRespRuleActions, const std::vector<DNSDistResponseRuleAction>& cacheInsertedRespRuleActions, InternalQueryState&& ids) +{ + + const dnsheader_aligned dnsHeader(response.data()); + auto queryId = dnsHeader->id; + + if (!responseContentMatches(response, ids.qname, ids.qtype, ids.qclass, dss)) { + dss->restoreState(queryId, std::move(ids)); + return false; + } + + auto dohUnit = std::move(ids.du); + dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) { + header.id = ids.origID; + return true; + }); + ++dss->responses; + + double udiff = ids.queryRealTime.udiff(); + // do that _before_ the processing, otherwise it's not fair to the backend + dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0; + dss->reportResponse(dnsHeader->rcode); + + /* don't call processResponse for DOH */ + if (dohUnit) { +#ifdef HAVE_DNS_OVER_HTTPS + // DoH query, we cannot touch dohUnit after that + DOHUnitInterface::handleUDPResponse(std::move(dohUnit), std::move(response), std::move(ids), dss); +#endif + return false; + } + + handleResponseForUDPClient(ids, response, localRespRuleActions, cacheInsertedRespRuleActions, dss, false, false); + return true; +} + // listens on a dedicated socket, lobs answers from downstream servers to original requestors void responderThread(std::shared_ptr<DownstreamState> dss) { @@ -716,7 +817,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss) setThreadName("dnsdist/respond"); auto localRespRuleActions = g_respruleactions.getLocal(); auto localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - const size_t initialBufferSize = getInitialUDPPacketBufferSize(); + const size_t initialBufferSize = getInitialUDPPacketBufferSize(false); /* allocate one more byte so we can detect truncation */ PacketBuffer response(initialBufferSize + 1); uint16_t queryId = 0; @@ -748,7 +849,7 @@ void responderThread(std::shared_ptr<DownstreamState> dss) for (const auto& fd : sockets) { /* allocate one more byte so we can detect truncation */ - // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it + // NOLINTNEXTLINE(bugprone-use-after-move): resizing a vector has no preconditions so it is valid to do so after moving it response.resize(initialBufferSize + 1); ssize_t got = recv(fd, response.data(), response.size(), 0); @@ -761,40 +862,37 @@ void responderThread(std::shared_ptr<DownstreamState> dss) } response.resize(static_cast<size_t>(got)); - dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data()); - queryId = dh->id; + const dnsheader_aligned dnsHeader(response.data()); + queryId = dnsHeader->id; auto ids = dss->getState(queryId); if (!ids) { continue; } - unsigned int qnameWireLength = 0; - if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) { + if (!ids->isXSK() && fd != ids->backendFD) { dss->restoreState(queryId, std::move(*ids)); continue; } - auto du = std::move(ids->du); - - dh->id = ids->origID; - ++dss->responses; - - double udiff = ids->queryRealTime.udiff(); - // do that _before_ the processing, otherwise it's not fair to the backend - dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0; - dss->reportResponse(dh->rcode); - - /* don't call processResponse for DOH */ - if (du) { -#ifdef HAVE_DNS_OVER_HTTPS - // DoH query, we cannot touch du after that - handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids)); -#endif - continue; + if (processResponderPacket(dss, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, std::move(*ids)) && ids->isXSK() && ids->cs->xskInfo) { +#ifdef HAVE_XSK + auto& xskInfo = ids->cs->xskInfo; + auto xskPacket = xskInfo->getEmptyFrame(); + if (!xskPacket) { + continue; + } + xskPacket->setHeader(ids->xskPacketHeader); + if (!xskPacket->setPayload(response)) { + } + if (ids->delayMsec > 0) { + xskPacket->addDelay(ids->delayMsec); + } + xskPacket->updatePacket(); + xskInfo->pushToSendQueue(*xskPacket); + xskInfo->notifyXskSocket(); +#endif /* HAVE_XSK */ } - - handleResponseForUDPClient(*ids, response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, false); } } catch (const std::exception& e) { @@ -824,8 +922,8 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent, if (raw) { std::vector<std::string> raws; stringtok(raws, spoofContent, ","); - SpoofAction sa(raws); - sa(&dq, &result); + SpoofAction tempSpoofAction(raws, std::nullopt); + tempSpoofAction(&dq, &result); } else { std::vector<std::string> addrs; @@ -834,13 +932,13 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent, if (addrs.size() == 1) { try { ComboAddress spoofAddr(spoofContent); - SpoofAction sa({spoofAddr}); - sa(&dq, &result); + SpoofAction tempSpoofAction({spoofAddr}); + tempSpoofAction(&dq, &result); } catch(const PDNSException &e) { DNSName cname(spoofContent); - SpoofAction sa(cname); // CNAME then - sa(&dq, &result); + SpoofAction tempSpoofAction(cname); // CNAME then + tempSpoofAction(&dq, &result); } } else { std::vector<ComboAddress> cas; @@ -851,8 +949,8 @@ static void spoofResponseFromString(DNSQuestion& dq, const string& spoofContent, catch (...) { } } - SpoofAction sa(cas); - sa(&dq, &result); + SpoofAction tempSpoofAction(cas); + tempSpoofAction(&dq, &result); } } } @@ -861,8 +959,8 @@ static void spoofPacketFromString(DNSQuestion& dq, const string& spoofContent) { string result; - SpoofAction sa(spoofContent.c_str(), spoofContent.size()); - sa(&dq, &result); + SpoofAction tempSpoofAction(spoofContent.c_str(), spoofContent.size()); + tempSpoofAction(&dq, &result); } bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop) @@ -871,28 +969,33 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return false; } - switch(action) { + auto setRCode = [&dq](uint8_t rcode) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) { + header.rcode = rcode; + header.qr = true; + return true; + }); + }; + + switch (action) { case DNSAction::Action::Allow: return true; break; case DNSAction::Action::Drop: - ++g_stats.ruleDrop; + ++dnsdist::metrics::g_stats.ruleDrop; drop = true; return true; break; case DNSAction::Action::Nxdomain: - dq.getHeader()->rcode = RCode::NXDomain; - dq.getHeader()->qr = true; + setRCode(RCode::NXDomain); return true; break; case DNSAction::Action::Refused: - dq.getHeader()->rcode = RCode::Refused; - dq.getHeader()->qr = true; + setRCode(RCode::Refused); return true; break; case DNSAction::Action::ServFail: - dq.getHeader()->rcode = RCode::ServFail; - dq.getHeader()->qr = true; + setRCode(RCode::ServFail); return true; break; case DNSAction::Action::Spoof: @@ -909,12 +1012,15 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s break; case DNSAction::Action::Truncate: if (!dq.overTCP()) { - dq.getHeader()->tc = true; - dq.getHeader()->qr = true; - dq.getHeader()->ra = dq.getHeader()->rd; - dq.getHeader()->aa = false; - dq.getHeader()->ad = false; - ++g_stats.ruleTruncated; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.tc = true; + header.qr = true; + header.ra = header.rd; + header.aa = false; + header.ad = false; + return true; + }); + ++dnsdist::metrics::g_stats.ruleTruncated; return true; } break; @@ -928,7 +1034,10 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return true; break; case DNSAction::Action::NoRecurse: - dq.getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return true; break; /* non-terminal actions follow */ @@ -970,10 +1079,18 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru } #ifndef DISABLE_DYNBLOCKS + auto setRCode = [&dq](uint8_t rcode) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) { + header.rcode = rcode; + header.qr = true; + return true; + }); + }; + /* the Dynamic Block mechanism supports address and port ranges, so we need to pass the full address and port */ if (auto got = holders.dynNMGBlock->lookup(AddressAndPortRange(dq.ids.origRemote, dq.ids.origRemote.isIPv4() ? 32 : 128, 16))) { auto updateBlockStats = [&got]() { - ++g_stats.dynBlocked; + ++dnsdist::metrics::g_stats.dynBlocked; got->second.blocks++; }; @@ -982,6 +1099,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru if (action == DNSAction::Action::None) { action = g_dynBlockAction; } + switch (action) { case DNSAction::Action::NoOp: /* do nothing */ @@ -991,27 +1109,28 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort()); updateBlockStats(); - dq.getHeader()->rcode = RCode::NXDomain; - dq.getHeader()->qr=true; + setRCode(RCode::NXDomain); return true; case DNSAction::Action::Refused: vinfolog("Query from %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort()); updateBlockStats(); - dq.getHeader()->rcode = RCode::Refused; - dq.getHeader()->qr = true; + setRCode(RCode::Refused); return true; case DNSAction::Action::Truncate: if (!dq.overTCP()) { updateBlockStats(); vinfolog("Query from %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort()); - dq.getHeader()->tc = true; - dq.getHeader()->qr = true; - dq.getHeader()->ra = dq.getHeader()->rd; - dq.getHeader()->aa = false; - dq.getHeader()->ad = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.tc = true; + header.qr = true; + header.ra = header.rd; + header.aa = false; + header.ad = false; + return true; + }); return true; } else { @@ -1021,7 +1140,10 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::NoRecurse: updateBlockStats(); vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort()); - dq.getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return true; default: updateBlockStats(); @@ -1033,7 +1155,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru if (auto got = holders.dynSMTBlock->lookup(dq.ids.qname)) { auto updateBlockStats = [&got]() { - ++g_stats.dynBlocked; + ++dnsdist::metrics::g_stats.dynBlocked; got->blocks++; }; @@ -1050,26 +1172,27 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); updateBlockStats(); - dq.getHeader()->rcode = RCode::NXDomain; - dq.getHeader()->qr = true; + setRCode(RCode::NXDomain); return true; case DNSAction::Action::Refused: vinfolog("Query from %s for %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); updateBlockStats(); - dq.getHeader()->rcode = RCode::Refused; - dq.getHeader()->qr = true; + setRCode(RCode::Refused); return true; case DNSAction::Action::Truncate: if (!dq.overTCP()) { updateBlockStats(); vinfolog("Query from %s for %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); - dq.getHeader()->tc = true; - dq.getHeader()->qr = true; - dq.getHeader()->ra = dq.getHeader()->rd; - dq.getHeader()->aa = false; - dq.getHeader()->ad = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.tc = true; + header.qr = true; + header.ra = header.rd; + header.aa = false; + header.ad = false; + return true; + }); return true; } else { @@ -1079,7 +1202,10 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::NoRecurse: updateBlockStats(); vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort()); - dq.getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return true; default: updateBlockStats(); @@ -1110,34 +1236,37 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru return true; } -ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& ss, const int sd, const PacketBuffer& request, bool healthCheck) +ssize_t udpClientSendRequestToBackend(const std::shared_ptr<DownstreamState>& backend, const int socketDesc, const PacketBuffer& request, bool healthCheck) { ssize_t result; - if (ss->d_config.sourceItf == 0) { - result = send(sd, request.data(), request.size(), 0); + if (backend->d_config.sourceItf == 0) { + result = send(socketDesc, request.data(), request.size(), 0); } else { struct msghdr msgh; struct iovec iov; cmsgbuf_aligned cbuf; - ComboAddress remote(ss->d_config.remote); + ComboAddress remote(backend->d_config.remote); fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), const_cast<char*>(reinterpret_cast<const char *>(request.data())), request.size(), &remote); - addCMsgSrcAddr(&msgh, &cbuf, &ss->d_config.sourceAddr, ss->d_config.sourceItf); - result = sendmsg(sd, &msgh, 0); + addCMsgSrcAddr(&msgh, &cbuf, &backend->d_config.sourceAddr, static_cast<int>(backend->d_config.sourceItf)); + result = sendmsg(socketDesc, &msgh, 0); } if (result == -1) { int savederrno = errno; - vinfolog("Error sending request to backend %s: %s", ss->d_config.remote.toStringWithPort(), stringerror(savederrno)); + vinfolog("Error sending request to backend %s: %s", backend->d_config.remote.toStringWithPort(), stringerror(savederrno)); /* This might sound silly, but on Linux send() might fail with EINVAL if the interface the socket was bound to doesn't exist anymore. We don't want to reconnect the real socket if the healthcheck failed, because it's not using the same socket. */ - if (!healthCheck && (savederrno == EINVAL || savederrno == ENODEV || savederrno == ENETUNREACH || savederrno == EBADF)) { - ss->reconnect(); + if (!healthCheck) { + if (savederrno == EINVAL || savederrno == ENODEV || savederrno == ENETUNREACH || savederrno == EHOSTUNREACH || savederrno == EBADF) { + backend->reconnect(); + } + backend->reportTimeoutOrError(); } } @@ -1150,14 +1279,14 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s /* message was too large for our buffer */ vinfolog("Dropping message too large for our buffer"); ++cs.nonCompliantQueries; - ++g_stats.nonCompliantQueries; + ++dnsdist::metrics::g_stats.nonCompliantQueries; return false; } - expectProxyProtocol = expectProxyProtocolFrom(remote); + expectProxyProtocol = cs.d_enableProxyProtocol && expectProxyProtocolFrom(remote); if (!holders.acl->match(remote) && !expectProxyProtocol) { vinfolog("Query from %s dropped because of ACL", remote.toStringWithPort()); - ++g_stats.aclDrops; + ++dnsdist::metrics::g_stats.aclDrops; return false; } @@ -1187,7 +1316,7 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s } ++cs.queries; - ++g_stats.queries; + ++dnsdist::metrics::g_stats.queries; return true; } @@ -1213,23 +1342,23 @@ bool checkDNSCryptQuery(const ClientState& cs, PacketBuffer& query, std::unique_ return false; } -bool checkQueryHeaders(const struct dnsheader* dh, ClientState& cs) +bool checkQueryHeaders(const struct dnsheader& dnsHeader, ClientState& clientState) { - if (dh->qr) { // don't respond to responses - ++g_stats.nonCompliantQueries; - ++cs.nonCompliantQueries; + if (dnsHeader.qr) { // don't respond to responses + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; return false; } - if (dh->qdcount == 0) { - ++g_stats.emptyQueries; + if (dnsHeader.qdcount == 0) { + ++dnsdist::metrics::g_stats.emptyQueries; if (g_dropEmptyQueries) { return false; } } - if (dh->rd) { - ++g_stats.rdQueries; + if (dnsHeader.rd) { + ++dnsdist::metrics::g_stats.rdQueries; } return true; @@ -1270,8 +1399,12 @@ static bool prepareOutgoingResponse(LocalHolders& holders, const ClientState& cs ac(&dr, &result); } + if (dr.ids.d_extendedError) { + dnsdist::edns::addExtendedDNSError(dr.getMutableData(), dr.getMaximumSize(), dr.ids.d_extendedError->infoCode, dr.ids.d_extendedError->extraText); + } + if (cacheHit) { - ++g_stats.cacheHits; + ++dnsdist::metrics::g_stats.cacheHits; } if (dr.isAsynchronous()) { @@ -1303,20 +1436,19 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders const auto rcode = dq.getHeader()->rcode; if (rcode == RCode::NXDomain) { - ++g_stats.ruleNXDomain; + ++dnsdist::metrics::g_stats.ruleNXDomain; } else if (rcode == RCode::Refused) { - ++g_stats.ruleRefused; + ++dnsdist::metrics::g_stats.ruleRefused; } else if (rcode == RCode::ServFail) { - ++g_stats.ruleServFail; + ++dnsdist::metrics::g_stats.ruleServFail; } - ++g_stats.selfAnswered; + ++dnsdist::metrics::g_stats.selfAnswered; ++dq.ids.cs->responses; return ProcessQueryResult::SendAnswer; } - std::shared_ptr<ServerPool> serverPool = getPool(*holders.pools, dq.ids.poolName); std::shared_ptr<ServerPolicy> poolPolicy = serverPool->policy; dq.ids.packetCache = serverPool->packetCache; @@ -1343,7 +1475,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders return ProcessQueryResult::Drop; } - ++g_stats.responses; + ++dnsdist::metrics::g_stats.responses; ++dq.ids.cs->responses; return ProcessQueryResult::SendAnswer; } @@ -1370,7 +1502,10 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders yet, as we will do a second-lookup */ if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKey, dq.ids.subnet, dq.ids.dnssecOK, forwardedOverUDP, allowExpired, false, true, dq.ids.protocol != dnsdist::Protocol::DoH || forwardedOverUDP)) { - restoreFlags(dq.getHeader(), dq.ids.origFlags); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [flags=dq.ids.origFlags](dnsheader& header) { + restoreFlags(&header, flags); + return true; + }); vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size()); @@ -1378,7 +1513,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders return ProcessQueryResult::Drop; } - ++g_stats.responses; + ++dnsdist::metrics::g_stats.responses; ++dq.ids.cs->responses; return ProcessQueryResult::SendAnswer; } @@ -1389,7 +1524,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders return ProcessQueryResult::Drop; } - ++g_stats.responses; + ++dnsdist::metrics::g_stats.responses; ++dq.ids.cs->responses; return ProcessQueryResult::SendAnswer; } @@ -1397,23 +1532,26 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders vinfolog("Packet cache miss for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size()); - ++g_stats.cacheMisses; + ++dnsdist::metrics::g_stats.cacheMisses; } if (!selectedBackend) { - ++g_stats.noPolicy; + ++dnsdist::metrics::g_stats.noPolicy; vinfolog("%s query for %s|%s from %s, no downstream server available", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort()); if (g_servFailOnNoPolicy) { - dq.getHeader()->rcode = RCode::ServFail; - dq.getHeader()->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rcode = RCode::ServFail; + header.qr = true; + return true; + }); fixUpQueryTurnedResponse(dq, dq.ids.origFlags); if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, false)) { return ProcessQueryResult::Drop; } - ++g_stats.responses; + ++dnsdist::metrics::g_stats.responses; ++dq.ids.cs->responses; // no response-only statistics counter to update. return ProcessQueryResult::SendAnswer; @@ -1423,12 +1561,19 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders } /* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */ - dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader()); + dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader().get()); if (dq.addXPF && selectedBackend->d_config.xpfRRCode != 0) { addXPF(dq, selectedBackend->d_config.xpfRRCode); } + if (selectedBackend->d_config.useProxyProtocol && dq.getProtocol().isEncrypted() && selectedBackend->d_config.d_proxyProtocolAdvertiseTLS) { + if (!dq.proxyProtocolValues) { + dq.proxyProtocolValues = std::make_unique<std::vector<ProxyProtocolValue>>(); + } + dq.proxyProtocolValues->push_back(ProxyProtocolValue{"", static_cast<uint8_t>(ProxyProtocolValue::Types::PP_TLV_SSL)}); + } + selectedBackend->incQueriesCount(); return ProcessQueryResult::PassToBackend; } @@ -1473,7 +1618,7 @@ public: return handleResponse(now, std::move(response)); } - void notifyIOError(InternalQueryState&& query, const struct timeval& now) override + void notifyIOError(const struct timeval&, TCPResponse&&) override { // nothing to do } @@ -1544,41 +1689,48 @@ ProcessQueryResult processQuery(DNSQuestion& dq, LocalHolders& holders, std::sha return ProcessQueryResult::Drop; } -bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer& query, ComboAddress& dest) +bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& downstream, uint16_t queryID, DNSQuestion& dnsQuestion, PacketBuffer& query, bool actuallySend) { - bool doh = dq.ids.du != nullptr; + bool doh = dnsQuestion.ids.du != nullptr; bool failed = false; - size_t proxyPayloadSize = 0; - if (ds->d_config.useProxyProtocol) { + if (downstream->d_config.useProxyProtocol) { try { - if (addProxyProtocol(dq, &proxyPayloadSize)) { - if (dq.ids.du) { - dq.ids.du->proxyProtocolPayloadSize = proxyPayloadSize; - } - } + addProxyProtocol(dnsQuestion, &dnsQuestion.ids.d_proxyProtocolPayloadSize); } catch (const std::exception& e) { - vinfolog("Adding proxy protocol payload to %s query from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what()); + vinfolog("Adding proxy protocol payload to %s query from %s failed: %s", (dnsQuestion.ids.du ? "DoH" : ""), dnsQuestion.ids.origDest.toStringWithPort(), e.what()); return false; } } + if (doh && !dnsQuestion.ids.d_packet) { + dnsQuestion.ids.d_packet = std::make_unique<PacketBuffer>(query); + } + try { - int fd = ds->pickSocketForSending(); - dq.ids.backendFD = fd; - dq.ids.origID = queryID; - dq.ids.forwardedOverUDP = true; + int descriptor = downstream->pickSocketForSending(); + if (actuallySend) { + dnsQuestion.ids.backendFD = descriptor; + } + dnsQuestion.ids.origID = queryID; + dnsQuestion.ids.forwardedOverUDP = true; - vinfolog("Got query for %s|%s from %s%s, relayed to %s", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr()); + vinfolog("Got query for %s|%s from %s%s, relayed to %s%s", dnsQuestion.ids.qname.toLogString(), QType(dnsQuestion.ids.qtype).toString(), dnsQuestion.ids.origRemote.toStringWithPort(), (doh ? " (https)" : ""), downstream->getNameWithAddr(), actuallySend ? "" : " (xsk)"); - auto idOffset = ds->saveState(std::move(dq.ids)); + /* make a copy since we cannot touch dnsQuestion.ids after the move */ + auto proxyProtocolPayloadSize = dnsQuestion.ids.d_proxyProtocolPayloadSize; + auto idOffset = downstream->saveState(std::move(dnsQuestion.ids)); /* set the correct ID */ - memcpy(query.data() + proxyPayloadSize, &idOffset, sizeof(idOffset)); + memcpy(&query.at(proxyProtocolPayloadSize), &idOffset, sizeof(idOffset)); + + if (!actuallySend) { + return true; + } /* you can't touch ids or du after this line, unless the call returned a non-negative value, because it might already have been freed */ - ssize_t ret = udpClientSendRequestToBackend(ds, fd, query); + ssize_t ret = udpClientSendRequestToBackend(downstream, descriptor, query); if (ret < 0) { failed = true; @@ -1587,15 +1739,12 @@ bool assignOutgoingUDPQueryToBackend(std::shared_ptr<DownstreamState>& ds, uint1 if (failed) { /* clear up the state. In the very unlikely event it was reused in the meantime, so be it. */ - auto cleared = ds->getState(idOffset); + auto cleared = downstream->getState(idOffset); if (cleared) { - dq.ids.du = std::move(cleared->du); - if (dq.ids.du) { - dq.ids.du->status_code = 502; - } + dnsQuestion.ids.du = std::move(cleared->du); } - ++g_stats.downstreamSendErrors; - ++ds->sendErrors; + ++dnsdist::metrics::g_stats.downstreamSendErrors; + ++downstream->sendErrors; return false; } } @@ -1651,16 +1800,20 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct { /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ - struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query.data()); - queryId = ntohs(dh->id); + const dnsheader_aligned dnsHeader(query.data()); + queryId = ntohs(dnsHeader->id); - if (!checkQueryHeaders(dh, cs)) { + if (!checkQueryHeaders(*dnsHeader, cs)) { return; } - if (dh->qdcount == 0) { - dh->rcode = RCode::NotImp; - dh->qr = true; + if (dnsHeader->qdcount == 0) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); + sendUDPResponse(cs.udpFD, query, 0, dest, remote); return; } @@ -1671,7 +1824,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids.protocol = dnsdist::Protocol::DNSCryptUDP; } DNSQuestion dq(ids, query); - const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); + const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader().get()); ids.origFlags = *flags; if (!proxyProtocolValues.empty()) { @@ -1686,7 +1839,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } // the buffer might have been invalidated by now (resized) - struct dnsheader* dh = dq.getHeader(); + const auto dh = dq.getHeader(); if (result == ProcessQueryResult::SendAnswer) { #ifndef DISABLE_RECVMMSG #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) @@ -1725,13 +1878,142 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct return; } - assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, dest); + assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query); } catch(const std::exception& e){ vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what()); } } +#ifdef HAVE_XSK +namespace dnsdist::xsk +{ +bool XskProcessQuery(ClientState& cs, LocalHolders& holders, XskPacket& packet) +{ + uint16_t queryId = 0; + const auto& remote = packet.getFromAddr(); + const auto& dest = packet.getToAddr(); + InternalQueryState ids; + ids.cs = &cs; + ids.origRemote = remote; + ids.hopRemote = remote; + ids.origDest = dest; + ids.hopLocal = dest; + ids.protocol = dnsdist::Protocol::DoUDP; + ids.xskPacketHeader = packet.cloneHeaderToPacketBuffer(); + + try { + bool expectProxyProtocol = false; + if (!XskIsQueryAcceptable(packet, cs, holders, expectProxyProtocol)) { + return false; + } + + auto query = packet.clonePacketBuffer(); + std::vector<ProxyProtocolValue> proxyProtocolValues; + if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, ids.origRemote, ids.origDest, proxyProtocolValues)) { + return false; + } + + ids.queryRealTime.start(); + + auto dnsCryptResponse = checkDNSCryptQuery(cs, query, ids.dnsCryptQuery, ids.queryRealTime.d_start.tv_sec, false); + if (dnsCryptResponse) { + packet.setPayload(query); + return true; + } + + { + /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ + dnsheader_aligned dnsHeader(query.data()); + queryId = ntohs(dnsHeader->id); + + if (!checkQueryHeaders(*dnsHeader.get(), cs)) { + return false; + } + + if (dnsHeader->qdcount == 0) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); + packet.setPayload(query); + return true; + } + } + + ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + if (ids.origDest.sin4.sin_family == 0) { + ids.origDest = cs.local; + } + if (ids.dnsCryptQuery) { + ids.protocol = dnsdist::Protocol::DNSCryptUDP; + } + DNSQuestion dq(ids, query); + if (!proxyProtocolValues.empty()) { + dq.proxyProtocolValues = make_unique<std::vector<ProxyProtocolValue>>(std::move(proxyProtocolValues)); + } + std::shared_ptr<DownstreamState> ss{nullptr}; + auto result = processQuery(dq, holders, ss); + + if (result == ProcessQueryResult::Drop) { + return false; + } + + if (result == ProcessQueryResult::SendAnswer) { + packet.setPayload(query); + if (dq.ids.delayMsec > 0) { + packet.addDelay(dq.ids.delayMsec); + } + const auto dh = dq.getHeader(); + handleResponseSent(ids.qname, ids.qtype, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP, dnsdist::Protocol::DoUDP, false); + return true; + } + + if (result != ProcessQueryResult::PassToBackend || ss == nullptr) { + return false; + } + + // the buffer might have been invalidated by now (resized) + const auto dh = dq.getHeader(); + if (ss->isTCPOnly()) { + std::string proxyProtocolPayload; + /* we need to do this _before_ creating the cross protocol query because + after that the buffer will have been moved */ + if (ss->d_config.useProxyProtocol) { + proxyProtocolPayload = getProxyProtocolPayload(dq); + } + + ids.origID = dh->id; + auto cpq = std::make_unique<UDPCrossProtocolQuery>(std::move(query), std::move(ids), ss); + cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); + + ss->passCrossProtocolQuery(std::move(cpq)); + return false; + } + + if (ss->d_xskInfos.empty()) { + assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, true); + return false; + } + else { + assignOutgoingUDPQueryToBackend(ss, dh->id, dq, query, false); + auto sourceAddr = ss->pickSourceAddressForSending(); + packet.setAddr(sourceAddr, ss->d_config.sourceMACAddr, ss->d_config.remote, ss->d_config.destMACAddr); + packet.setPayload(query); + packet.rewrite(); + return true; + } + } + catch (const std::exception& e) { + vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); + } + return false; +} + +} +#endif /* HAVE_XSK */ + #ifndef DISABLE_RECVMMSG #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holders) @@ -1760,7 +2042,7 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde - we use it for self-generated responses (from rule or cache) but we only accept incoming payloads up to that size */ - const size_t initialBufferSize = getInitialUDPPacketBufferSize(); + const size_t initialBufferSize = getInitialUDPPacketBufferSize(cs->d_enableProxyProtocol); const size_t maxIncomingPacketSize = getMaximumIncomingPacketSize(*cs); /* initialize the structures needed to receive our messages */ @@ -1799,7 +2081,7 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde const ComboAddress& remote = recvData[msgIdx].remote; if (static_cast<size_t>(got) < sizeof(struct dnsheader)) { - ++g_stats.nonCompliantQueries; + ++dnsdist::metrics::g_stats.nonCompliantQueries; ++cs->nonCompliantQueries; continue; } @@ -1850,7 +2132,7 @@ static void udpClientThread(std::vector<ClientState*> states) size_t maxIncomingPacketSize{0}; int socket{-1}; }; - const size_t initialBufferSize = getInitialUDPPacketBufferSize(); + const size_t initialBufferSize = getInitialUDPPacketBufferSize(true); PacketBuffer packet(initialBufferSize); struct msghdr msgh; @@ -1866,7 +2148,7 @@ static void udpClientThread(std::vector<ClientState*> states) ssize_t got = recvmsg(param.socket, &msgh, 0); if (got < 0 || static_cast<size_t>(got) < sizeof(struct dnsheader)) { - ++g_stats.nonCompliantQueries; + ++dnsdist::metrics::g_stats.nonCompliantQueries; ++param.cs->nonCompliantQueries; return; } @@ -1883,7 +2165,7 @@ static void udpClientThread(std::vector<ClientState*> states) } if (params.size() == 1) { - auto param = params.at(0); + const auto& param = params.at(0); remote.sin4.sin_family = param.cs->local.sin4.sin_family; /* used by HarvestDestinationAddress */ cmsgbuf_aligned cbuf; @@ -1949,28 +2231,29 @@ pdns::stat16_t g_cacheCleaningPercentage{100}; static void maintThread() { setThreadName("dnsdist/main"); - int interval = 1; + constexpr int interval = 1; size_t counter = 0; int32_t secondsToWaitLog = 0; for (;;) { - sleep(interval); + std::this_thread::sleep_for(std::chrono::seconds(interval)); { auto lua = g_lua.lock(); - auto f = lua->readVariable<boost::optional<std::function<void()> > >("maintenance"); - if (f) { - try { - (*f)(); - secondsToWaitLog = 0; + try { + auto maintenanceCallback = lua->readVariable<boost::optional<std::function<void()> > >("maintenance"); + if (maintenanceCallback) { + (*maintenanceCallback)(); } - catch(const std::exception &e) { - if (secondsToWaitLog <= 0) { - infolog("Error during execution of maintenance function: %s", e.what()); - secondsToWaitLog = 61; - } - secondsToWaitLog -= interval; + dnsdist::lua::hooks::runMaintenanceHooks(*lua); + secondsToWaitLog = 0; + } + catch (const std::exception &e) { + if (secondsToWaitLog <= 0) { + warnlog("Error during execution of maintenance function(s): %s", e.what()); + secondsToWaitLog = 61; } + secondsToWaitLog -= interval; } } @@ -1984,7 +2267,7 @@ static void maintThread() if something prevents us from cleaning the expired entries */ auto localPools = g_pools.getLocal(); for (const auto& entry : *localPools) { - auto& pool = entry.second; + const auto& pool = entry.second; auto packetCache = pool->packetCache; if (!packetCache) { @@ -1996,7 +2279,7 @@ static void maintThread() /* if we need to keep stale data for this cache (ie, not clear expired entries when at least one pool using this cache has all its backends down) */ - if (packetCache->keepStaleData() && iter->second == false) { + if (packetCache->keepStaleData() && !iter->second) { /* so far all pools had at least one backend up */ if (pool->countServers(true) == 0) { iter->second = true; @@ -2007,10 +2290,10 @@ static void maintThread() const time_t now = time(nullptr); for (const auto& pair : caches) { /* shall we keep expired entries ? */ - if (pair.second == true) { + if (pair.second) { continue; } - auto& packetCache = pair.first; + const auto& packetCache = pair.first; size_t upTo = (packetCache->getMaxEntries()* (100 - g_cacheCleaningPercentage)) / 100; packetCache->purgeExpired(upTo, now); } @@ -2070,11 +2353,7 @@ static void healthChecksThread() std::unique_ptr<FDMultiplexer> mplexer{nullptr}; for (auto& dss : *states) { - auto delta = dss->sw.udiffAndSet()/1000000.0; - dss->queryLoad.store(1.0*(dss->queries.load() - dss->prev.queries.load())/delta); - dss->dropRate.store(1.0*(dss->reuseds.load() - dss->prev.reuseds.load())/delta); - dss->prev.queries.store(dss->queries.load()); - dss->prev.reuseds.store(dss->reuseds.load()); + dss->updateStatisticsInfo(); dss->handleUDPTimeouts(); @@ -2124,9 +2403,9 @@ static void bindAny(int af, int sock) static void dropGroupPrivs(gid_t gid) { - if (gid) { + if (gid != 0) { if (setgid(gid) == 0) { - if (setgroups(0, NULL) < 0) { + if (setgroups(0, nullptr) < 0) { warnlog("Warning: Unable to drop supplementary gids: %s", stringerror()); } } @@ -2138,8 +2417,8 @@ static void dropGroupPrivs(gid_t gid) static void dropUserPrivs(uid_t uid) { - if(uid) { - if(setuid(uid) < 0) { + if (uid != 0) { + if (setuid(uid) < 0) { warnlog("Warning: Unable to set user ID to %d: %s", uid, stringerror()); } } @@ -2197,145 +2476,185 @@ static void checkFileDescriptorsLimits(size_t udpBindsCount, size_t tcpBindsCoun static bool g_warned_ipv6_recvpktinfo = false; -static void setUpLocalBind(std::unique_ptr<ClientState>& cstate) +static void setupLocalSocket(ClientState& clientState, const ComboAddress& addr, int& socket, bool tcp, bool warn) { - auto setupSocket = [](ClientState& cs, const ComboAddress& addr, int& socket, bool tcp, bool warn) { - (void) warn; - socket = SSocket(addr.sin4.sin_family, tcp == false ? SOCK_DGRAM : SOCK_STREAM, 0); + (void) warn; + socket = SSocket(addr.sin4.sin_family, !tcp ? SOCK_DGRAM : SOCK_STREAM, 0); - if (tcp) { - SSetsockopt(socket, SOL_SOCKET, SO_REUSEADDR, 1); + if (tcp) { + SSetsockopt(socket, SOL_SOCKET, SO_REUSEADDR, 1); #ifdef TCP_DEFER_ACCEPT - SSetsockopt(socket, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1); + SSetsockopt(socket, IPPROTO_TCP, TCP_DEFER_ACCEPT, 1); #endif - if (cs.fastOpenQueueSize > 0) { + if (clientState.fastOpenQueueSize > 0) { #ifdef TCP_FASTOPEN - SSetsockopt(socket, IPPROTO_TCP, TCP_FASTOPEN, cs.fastOpenQueueSize); + SSetsockopt(socket, IPPROTO_TCP, TCP_FASTOPEN, clientState.fastOpenQueueSize); #ifdef TCP_FASTOPEN_KEY - if (!g_TCPFastOpenKey.empty()) { - auto res = setsockopt(socket, IPPROTO_IP, TCP_FASTOPEN_KEY, g_TCPFastOpenKey.data(), g_TCPFastOpenKey.size() * sizeof(g_TCPFastOpenKey[0])); - if (res == -1) { - throw runtime_error("setsockopt for level IPPROTO_TCP and opname TCP_FASTOPEN_KEY failed: " + stringerror()); - } + if (!g_TCPFastOpenKey.empty()) { + auto res = setsockopt(socket, IPPROTO_IP, TCP_FASTOPEN_KEY, g_TCPFastOpenKey.data(), g_TCPFastOpenKey.size() * sizeof(g_TCPFastOpenKey[0])); + if (res == -1) { + throw runtime_error("setsockopt for level IPPROTO_TCP and opname TCP_FASTOPEN_KEY failed: " + stringerror()); } + } #endif /* TCP_FASTOPEN_KEY */ #else /* TCP_FASTOPEN */ - if (warn) { - warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", addr.toStringWithPort()); - } -#endif /* TCP_FASTOPEN */ + if (warn) { + warnlog("TCP Fast Open has been configured on local address '%s' but is not supported", addr.toStringWithPort()); } +#endif /* TCP_FASTOPEN */ } + } - if (addr.sin4.sin_family == AF_INET6) { - SSetsockopt(socket, IPPROTO_IPV6, IPV6_V6ONLY, 1); - } + if (addr.sin4.sin_family == AF_INET6) { + SSetsockopt(socket, IPPROTO_IPV6, IPV6_V6ONLY, 1); + } - bindAny(addr.sin4.sin_family, socket); + bindAny(addr.sin4.sin_family, socket); - if (!tcp && IsAnyAddress(addr)) { - int one = 1; - (void) setsockopt(socket, IPPROTO_IP, GEN_IP_PKTINFO, &one, sizeof(one)); // linux supports this, so why not - might fail on other systems + if (!tcp && IsAnyAddress(addr)) { + int one = 1; + (void) setsockopt(socket, IPPROTO_IP, GEN_IP_PKTINFO, &one, sizeof(one)); // linux supports this, so why not - might fail on other systems #ifdef IPV6_RECVPKTINFO - if (addr.isIPv6() && setsockopt(socket, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one)) < 0 && - !g_warned_ipv6_recvpktinfo) { - warnlog("Warning: IPV6_RECVPKTINFO setsockopt failed: %s", stringerror()); - g_warned_ipv6_recvpktinfo = true; - } -#endif + if (addr.isIPv6() && setsockopt(socket, IPPROTO_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one)) < 0 && + !g_warned_ipv6_recvpktinfo) { + warnlog("Warning: IPV6_RECVPKTINFO setsockopt failed: %s", stringerror()); + g_warned_ipv6_recvpktinfo = true; } +#endif + } - if (cs.reuseport) { - if (!setReusePort(socket)) { - if (warn) { - /* no need to warn again if configured but support is not available, we already did for UDP */ - warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", addr.toStringWithPort()); - } + if (clientState.reuseport) { + if (!setReusePort(socket)) { + if (warn) { + /* no need to warn again if configured but support is not available, we already did for UDP */ + warnlog("SO_REUSEPORT has been configured on local address '%s' but is not supported", addr.toStringWithPort()); } } + } + const bool isQUIC = clientState.doqFrontend != nullptr || clientState.doh3Frontend != nullptr; + if (isQUIC) { + /* disable fragmentation and force PMTU discovery for QUIC-enabled sockets */ + try { + setSocketForcePMTU(socket, addr.sin4.sin_family); + } + catch (const std::exception& e) { + warnlog("Failed to set IP_MTU_DISCOVER on QUIC server socket for local address '%s': %s", addr.toStringWithPort(), e.what()); + } + } + else if (!tcp && !clientState.dnscryptCtx) { /* Only set this on IPv4 UDP sockets. Don't set it for DNSCrypt binds. DNSCrypt pads queries for privacy purposes, so we do receive large, sometimes fragmented datagrams. */ - if (!tcp && !cs.dnscryptCtx) { + try { + setSocketIgnorePMTU(socket, addr.sin4.sin_family); + } + catch (const std::exception& e) { + warnlog("Failed to set IP_MTU_DISCOVER on UDP server socket for local address '%s': %s", addr.toStringWithPort(), e.what()); + } + } + + if (!tcp) { + if (g_socketUDPSendBuffer > 0) { try { - setSocketIgnorePMTU(socket, addr.sin4.sin_family); + setSocketSendBuffer(socket, g_socketUDPSendBuffer); } catch (const std::exception& e) { - warnlog("Failed to set IP_MTU_DISCOVER on UDP server socket for local address '%s': %s", addr.toStringWithPort(), e.what()); + warnlog(e.what()); } - } - - if (!tcp) { - if (g_socketUDPSendBuffer > 0) { - try { - setSocketSendBuffer(socket, g_socketUDPSendBuffer); - } - catch (const std::exception& e) { - warnlog(e.what()); + } else { + try { + auto result = raiseSocketSendBufferToMax(socket); + if (result > 0) { + infolog("Raised send buffer to %u for local address '%s'", result, addr.toStringWithPort()); } + } catch (const std::exception& e) { + warnlog(e.what()); } + } - if (g_socketUDPRecvBuffer > 0) { - try { - setSocketReceiveBuffer(socket, g_socketUDPRecvBuffer); - } - catch (const std::exception& e) { - warnlog(e.what()); + if (g_socketUDPRecvBuffer > 0) { + try { + setSocketReceiveBuffer(socket, g_socketUDPRecvBuffer); + } + catch (const std::exception& e) { + warnlog(e.what()); + } + } else { + try { + auto result = raiseSocketReceiveBufferToMax(socket); + if (result > 0) { + infolog("Raised receive buffer to %u for local address '%s'", result, addr.toStringWithPort()); } + } catch (const std::exception& e) { + warnlog(e.what()); } } + } - const std::string& itf = cs.interface; - if (!itf.empty()) { + const std::string& itf = clientState.interface; + if (!itf.empty()) { #ifdef SO_BINDTODEVICE - int res = setsockopt(socket, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length()); - if (res != 0) { - warnlog("Error setting up the interface on local address '%s': %s", addr.toStringWithPort(), stringerror()); - } + int res = setsockopt(socket, SOL_SOCKET, SO_BINDTODEVICE, itf.c_str(), itf.length()); + if (res != 0) { + warnlog("Error setting up the interface on local address '%s': %s", addr.toStringWithPort(), stringerror()); + } #else - if (warn) { - warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", addr.toStringWithPort()); - } -#endif + if (warn) { + warnlog("An interface has been configured on local address '%s' but SO_BINDTODEVICE is not supported", addr.toStringWithPort()); } +#endif + } #ifdef HAVE_EBPF - if (g_defaultBPFFilter && !g_defaultBPFFilter->isExternal()) { - cs.attachFilter(g_defaultBPFFilter, socket); - vinfolog("Attaching default BPF Filter to %s frontend %s", (!tcp ? "UDP" : "TCP"), addr.toStringWithPort()); - } + if (g_defaultBPFFilter && !g_defaultBPFFilter->isExternal()) { + clientState.attachFilter(g_defaultBPFFilter, socket); + vinfolog("Attaching default BPF Filter to %s frontend %s", (!tcp ? std::string("UDP") : std::string("TCP")), addr.toStringWithPort()); + } #endif /* HAVE_EBPF */ - SBind(socket, addr); + SBind(socket, addr); - if (tcp) { - SListen(socket, cs.tcpListenQueueSize); + if (tcp) { + SListen(socket, clientState.tcpListenQueueSize); - if (cs.tlsFrontend != nullptr) { - infolog("Listening on %s for TLS", addr.toStringWithPort()); - } - else if (cs.dohFrontend != nullptr) { - infolog("Listening on %s for DoH", addr.toStringWithPort()); - } - else if (cs.dnscryptCtx != nullptr) { - infolog("Listening on %s for DNSCrypt", addr.toStringWithPort()); - } - else { - infolog("Listening on %s", addr.toStringWithPort()); - } + if (clientState.tlsFrontend != nullptr) { + infolog("Listening on %s for TLS", addr.toStringWithPort()); } - }; + else if (clientState.dohFrontend != nullptr) { + infolog("Listening on %s for DoH", addr.toStringWithPort()); + } + else if (clientState.dnscryptCtx != nullptr) { + infolog("Listening on %s for DNSCrypt", addr.toStringWithPort()); + } + else { + infolog("Listening on %s", addr.toStringWithPort()); + } + } else { + if (clientState.doqFrontend != nullptr) { + infolog("Listening on %s for DoQ", addr.toStringWithPort()); + } else if (clientState.doh3Frontend != nullptr) { + infolog("Listening on %s for DoH3", addr.toStringWithPort()); + } +#ifdef HAVE_XSK + else if (clientState.xskInfo != nullptr) { + infolog("Listening on %s (XSK-enabled)", addr.toStringWithPort()); + } +#endif + } +} +static void setUpLocalBind(std::unique_ptr<ClientState>& cstate) +{ /* skip some warnings if there is an identical UDP context */ - bool warn = cstate->tcp == false || cstate->tlsFrontend != nullptr || cstate->dohFrontend != nullptr; - int& fd = cstate->tcp == false ? cstate->udpFD : cstate->tcpFD; + bool warn = !cstate->tcp || cstate->tlsFrontend != nullptr || cstate->dohFrontend != nullptr; + int& descriptor = !cstate->tcp ? cstate->udpFD : cstate->tcpFD; (void) warn; - setupSocket(*cstate, cstate->local, fd, cstate->tcp, warn); + setupLocalSocket(*cstate, cstate->local, descriptor, cstate->tcp, warn); for (auto& [addr, socket] : cstate->d_additionalAddresses) { - setupSocket(*cstate, addr, socket, true, false); + setupLocalSocket(*cstate, addr, socket, true, false); } if (cstate->tlsFrontend != nullptr) { @@ -2348,6 +2667,12 @@ static void setUpLocalBind(std::unique_ptr<ClientState>& cstate) if (cstate->dohFrontend != nullptr) { cstate->dohFrontend->setup(); } + if (cstate->doqFrontend != nullptr) { + cstate->doqFrontend->setup(); + } + if (cstate->doh3Frontend != nullptr) { + cstate->doh3Frontend->setup(); + } cstate->ready = true; } @@ -2379,7 +2704,7 @@ static void usage() cout<<"-c,--client Operate as a client, connect to dnsdist. This reads\n"; cout<<" controlSocket from your configuration file, but also\n"; cout<<" accepts an IP:PORT argument\n"; -#ifdef HAVE_LIBSODIUM +#if defined(HAVE_LIBSODIUM) || defined(HAVE_LIBCRYPTO) cout<<"-k,--setkey KEY Use KEY for encrypted communication to dnsdist. This\n"; cout<<" is similar to setting setKey in the configuration file.\n"; cout<<" NOTE: this will leak this key in your shell's history\n"; @@ -2403,28 +2728,25 @@ static void usage() } #ifdef COVERAGE -extern "C" -{ - void __gcov_dump(void); -} - static void cleanupLuaObjects() { - /* when our coverage mode is enabled, we need to make - that the Lua objects destroyed before the Lua contexts. */ + /* when our coverage mode is enabled, we need to make sure + that the Lua objects are destroyed before the Lua contexts. */ g_ruleactions.setState({}); g_respruleactions.setState({}); g_cachehitrespruleactions.setState({}); g_selfansweredrespruleactions.setState({}); g_dstates.setState({}); g_policy.setState(ServerPolicy()); + g_pools.setState({}); clearWebHandlers(); + dnsdist::lua::hooks::clearMaintenanceHooks(); } static void sigTermHandler(int) { cleanupLuaObjects(); - __gcov_dump(); + pdns::coverage::dumpCoverageData(); _exit(EXIT_SUCCESS); } #else /* COVERAGE */ @@ -2447,7 +2769,7 @@ static void sigTermHandler(int) we crash trying to exit, but let's try to avoid the warnings in our tests. */ - if (g_syslog) { + if (dnsdist::logging::LoggingConfiguration::getSyslog()) { syslog(LOG_INFO, "Exiting on user request"); } std::cout<<"Exiting on user request"<<std::endl; @@ -2457,211 +2779,463 @@ static void sigTermHandler(int) } #endif /* COVERAGE */ -int main(int argc, char** argv) +static void reportFeatures() { - try { - size_t udpBindsCount = 0; - size_t tcpBindsCount = 0; -#ifdef HAVE_LIBEDIT -#ifndef DISABLE_COMPLETION - rl_attempted_completion_function = my_completion; - rl_completion_append_character = 0; -#endif /* DISABLE_COMPLETION */ -#endif /* HAVE_LIBEDIT */ - - signal(SIGPIPE, SIG_IGN); - signal(SIGCHLD, SIG_IGN); - signal(SIGTERM, sigTermHandler); - - openlog("dnsdist", LOG_PID|LOG_NDELAY, LOG_DAEMON); - -#ifdef HAVE_LIBSODIUM - if (sodium_init() == -1) { - cerr<<"Unable to initialize crypto library"<<endl; - exit(EXIT_FAILURE); - } -#endif - dnsdist::initRandom(); - g_hashperturb = dnsdist::getRandomValue(0xffffffff); - - ComboAddress clientAddress = ComboAddress(); - g_cmdLine.config=SYSCONFDIR "/dnsdist.conf"; - struct option longopts[]={ - {"acl", required_argument, 0, 'a'}, - {"check-config", no_argument, 0, 1}, - {"client", no_argument, 0, 'c'}, - {"config", required_argument, 0, 'C'}, - {"disable-syslog", no_argument, 0, 2}, - {"execute", required_argument, 0, 'e'}, - {"gid", required_argument, 0, 'g'}, - {"help", no_argument, 0, 'h'}, - {"local", required_argument, 0, 'l'}, - {"log-timestamps", no_argument, 0, 4}, - {"setkey", required_argument, 0, 'k'}, - {"supervised", no_argument, 0, 3}, - {"uid", required_argument, 0, 'u'}, - {"verbose", no_argument, 0, 'v'}, - {"version", no_argument, 0, 'V'}, - {0,0,0,0} - }; - int longindex=0; - string optstring; - for(;;) { - int c=getopt_long(argc, argv, "a:cC:e:g:hk:l:u:vV", longopts, &longindex); - if(c==-1) - break; - switch(c) { - case 1: - g_cmdLine.checkConfig=true; - break; - case 2: - g_syslog=false; - break; - case 3: - g_cmdLine.beSupervised=true; - break; - case 4: - g_logtimestamps=true; - break; - case 'C': - g_cmdLine.config=optarg; - break; - case 'c': - g_cmdLine.beClient=true; - break; - case 'e': - g_cmdLine.command=optarg; - break; - case 'g': - g_cmdLine.gid=optarg; - break; - case 'h': - cout<<"dnsdist "<<VERSION<<endl; - usage(); - cout<<"\n"; - exit(EXIT_SUCCESS); - break; - case 'a': - optstring=optarg; - g_ACL.modify([optstring](NetmaskGroup& nmg) { nmg.addMask(optstring); }); - break; - case 'k': -#ifdef HAVE_LIBSODIUM - if (B64Decode(string(optarg), g_consoleKey) < 0) { - cerr<<"Unable to decode key '"<<optarg<<"'."<<endl; - exit(EXIT_FAILURE); - } -#else - cerr<<"dnsdist has been built without libsodium, -k/--setkey is unsupported."<<endl; - exit(EXIT_FAILURE); -#endif - break; - case 'l': - g_cmdLine.locals.push_back(boost::trim_copy(string(optarg))); - break; - case 'u': - g_cmdLine.uid=optarg; - break; - case 'v': - g_verbose=true; - break; - case 'V': #ifdef LUAJIT_VERSION - cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<" ["<<LUAJIT_VERSION<<"])"<<endl; + cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<" ["<<LUAJIT_VERSION<<"])"<<endl; #else - cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<")"<<endl; + cout<<"dnsdist "<<VERSION<<" ("<<LUA_RELEASE<<")"<<endl; +#endif + cout<<"Enabled features: "; +#ifdef HAVE_XSK + cout<<"AF_XDP "; #endif - cout<<"Enabled features: "; #ifdef HAVE_CDB - cout<<"cdb "; + cout<<"cdb "; +#endif +#ifdef HAVE_DNS_OVER_QUIC + cout<<"dns-over-quic "; +#endif +#ifdef HAVE_DNS_OVER_HTTP3 + cout<<"dns-over-http3 "; #endif #ifdef HAVE_DNS_OVER_TLS - cout<<"dns-over-tls("; + cout<<"dns-over-tls("; #ifdef HAVE_GNUTLS - cout<<"gnutls"; + cout<<"gnutls"; #ifdef HAVE_LIBSSL - cout<<" "; -#endif + cout<<" "; #endif +#endif /* HAVE_GNUTLS */ #ifdef HAVE_LIBSSL - cout<<"openssl"; -#endif - cout<<") "; + cout<<"openssl"; #endif + cout<<") "; +#endif /* HAVE_DNS_OVER_TLS */ #ifdef HAVE_DNS_OVER_HTTPS - cout<<"dns-over-https(DOH) "; -#endif + cout<<"dns-over-https("; +#ifdef HAVE_LIBH2OEVLOOP + cout<<"h2o"; +#endif /* HAVE_LIBH2OEVLOOP */ +#if defined(HAVE_LIBH2OEVLOOP) && defined(HAVE_NGHTTP2) + cout<<" "; +#endif /* defined(HAVE_LIBH2OEVLOOP) && defined(HAVE_NGHTTP2) */ +#ifdef HAVE_NGHTTP2 + cout<<"nghttp2"; +#endif /* HAVE_NGHTTP2 */ + cout<<") "; +#endif /* HAVE_DNS_OVER_HTTPS */ #ifdef HAVE_DNSCRYPT - cout<<"dnscrypt "; + cout<<"dnscrypt "; #endif #ifdef HAVE_EBPF - cout<<"ebpf "; + cout<<"ebpf "; #endif #ifdef HAVE_FSTRM - cout<<"fstrm "; + cout<<"fstrm "; #endif #ifdef HAVE_IPCIPHER - cout<<"ipcipher "; + cout<<"ipcipher "; #endif #ifdef HAVE_LIBEDIT - cout<<"libedit "; + cout<<"libedit "; #endif #ifdef HAVE_LIBSODIUM - cout<<"libsodium "; + cout<<"libsodium "; #endif #ifdef HAVE_LMDB - cout<<"lmdb "; -#endif -#ifdef HAVE_NGHTTP2 - cout<<"outgoing-dns-over-https(nghttp2) "; + cout<<"lmdb "; #endif #ifndef DISABLE_PROTOBUF - cout<<"protobuf "; + cout<<"protobuf "; #endif #ifdef HAVE_RE2 - cout<<"re2 "; + cout<<"re2 "; #endif #ifndef DISABLE_RECVMMSG #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) - cout<<"recvmmsg/sendmmsg "; + cout<<"recvmmsg/sendmmsg "; #endif #endif /* DISABLE_RECVMMSG */ #ifdef HAVE_NET_SNMP - cout<<"snmp "; + cout<<"snmp "; #endif #ifdef HAVE_SYSTEMD - cout<<"systemd"; + cout<<"systemd"; #endif - cout<<endl; - exit(EXIT_SUCCESS); - break; - case '?': - //getopt_long printed an error message. - usage(); + cout<<endl; +} + +static void parseParameters(int argc, char** argv, ComboAddress& clientAddress) +{ + const std::array<struct option,16> longopts{{ + {"acl", required_argument, nullptr, 'a'}, + {"check-config", no_argument, nullptr, 1}, + {"client", no_argument, nullptr, 'c'}, + {"config", required_argument, nullptr, 'C'}, + {"disable-syslog", no_argument, nullptr, 2}, + {"execute", required_argument, nullptr, 'e'}, + {"gid", required_argument, nullptr, 'g'}, + {"help", no_argument, nullptr, 'h'}, + {"local", required_argument, nullptr, 'l'}, + {"log-timestamps", no_argument, nullptr, 4}, + {"setkey", required_argument, nullptr, 'k'}, + {"supervised", no_argument, nullptr, 3}, + {"uid", required_argument, nullptr, 'u'}, + {"verbose", no_argument, nullptr, 'v'}, + {"version", no_argument, nullptr, 'V'}, + {nullptr, 0, nullptr, 0} + }}; + int longindex = 0; + string optstring; + while (true) { + // NOLINTNEXTLINE(concurrency-mt-unsafe): only one thread at this point + int gotChar = getopt_long(argc, argv, "a:cC:e:g:hk:l:u:vV", longopts.data(), &longindex); + if (gotChar == -1) { + break; + } + switch (gotChar) { + case 1: + g_cmdLine.checkConfig = true; + break; + case 2: + dnsdist::logging::LoggingConfiguration::setSyslog(false); + break; + case 3: + g_cmdLine.beSupervised = true; + break; + case 4: + dnsdist::logging::LoggingConfiguration::setLogTimestamps(true); + break; + case 'C': + g_cmdLine.config = optarg; + break; + case 'c': + g_cmdLine.beClient = true; + break; + case 'e': + g_cmdLine.command = optarg; + break; + case 'g': + g_cmdLine.gid = optarg; + break; + case 'h': + cout<<"dnsdist "<<VERSION<<endl; + usage(); + cout<<"\n"; + // NOLINTNEXTLINE(concurrency-mt-unsafe): only one thread at this point + exit(EXIT_SUCCESS); + break; + case 'a': + optstring = optarg; + g_ACL.modify([optstring](NetmaskGroup& nmg) { nmg.addMask(optstring); }); + break; + case 'k': +#if defined HAVE_LIBSODIUM || defined(HAVE_LIBCRYPTO) + if (B64Decode(string(optarg), g_consoleKey) < 0) { + cerr<<"Unable to decode key '"<<optarg<<"'."<<endl; + // NOLINTNEXTLINE(concurrency-mt-unsafe): only one thread at this point exit(EXIT_FAILURE); - break; } +#else + cerr<<"dnsdist has been built without libsodium or libcrypto, -k/--setkey is unsupported."<<endl; + // NOLINTNEXTLINE(concurrency-mt-unsafe): only one thread at this point + exit(EXIT_FAILURE); +#endif + break; + case 'l': + g_cmdLine.locals.push_back(boost::trim_copy(string(optarg))); + break; + case 'u': + g_cmdLine.uid = optarg; + break; + case 'v': + g_verbose = true; + break; + case 'V': + reportFeatures(); + // NOLINTNEXTLINE(concurrency-mt-unsafe): only one thread at this point + exit(EXIT_SUCCESS); + break; + case '?': + //getopt_long printed an error message. + usage(); + // NOLINTNEXTLINE(concurrency-mt-unsafe): only one thread at this point + exit(EXIT_FAILURE); + break; } + } - argc -= optind; - argv += optind; - (void) argc; + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): argv + argv += optind; - for (auto p = argv; *p; ++p) { - if(g_cmdLine.beClient) { - clientAddress = ComboAddress(*p, 5199); - } else { - g_cmdLine.remotes.push_back(*p); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): argv + for (const auto* ptr = argv; *ptr != nullptr; ++ptr) { + if (g_cmdLine.beClient) { + clientAddress = ComboAddress(*ptr, 5199); + } else { + g_cmdLine.remotes.emplace_back(*ptr); + } + } +} +static void setupPools() +{ + auto pools = g_pools.getCopy(); + { + bool precompute = false; + if (g_policy.getLocal()->getName() == "chashed") { + precompute = true; + } else { + for (const auto& entry: pools) { + if (entry.second->policy != nullptr && entry.second->policy->getName() == "chashed") { + precompute = true; + break ; + } } } + if (precompute) { + vinfolog("Pre-computing hashes for consistent hash load-balancing policy"); + // pre compute hashes + auto backends = g_dstates.getLocal(); + for (const auto& backend: *backends) { + if (backend->d_config.d_weight < 100) { + vinfolog("Warning, the backend '%s' has a very low weight (%d), which will not yield a good distribution of queries with the 'chashed' policy. Please consider raising it to at least '100'.", backend->getName(), backend->d_config.d_weight); + } + + backend->hash(); + } + } + } +} + +static void dropPrivileges() +{ + uid_t newgid = getegid(); + gid_t newuid = geteuid(); + + if (!g_cmdLine.gid.empty()) { + newgid = strToGID(g_cmdLine.gid); + } + + if (!g_cmdLine.uid.empty()) { + newuid = strToUID(g_cmdLine.uid); + } + + bool retainedCapabilities = true; + if (!g_capabilitiesToRetain.empty() && + (getegid() != newgid || geteuid() != newuid)) { + retainedCapabilities = keepCapabilitiesAfterSwitchingIDs(); + } + + if (getegid() != newgid) { + if (running_in_service_mgr()) { + errlog("--gid/-g set on command-line, but dnsdist was started as a systemd service. Use the 'Group' setting in the systemd unit file to set the group to run as"); + _exit(EXIT_FAILURE); + } + dropGroupPrivs(newgid); + } + + if (geteuid() != newuid) { + if (running_in_service_mgr()) { + errlog("--uid/-u set on command-line, but dnsdist was started as a systemd service. Use the 'User' setting in the systemd unit file to set the user to run as"); + _exit(EXIT_FAILURE); + } + dropUserPrivs(newuid); + } + + if (retainedCapabilities) { + dropCapabilitiesAfterSwitchingIDs(); + } + + try { + /* we might still have capabilities remaining, + for example if we have been started as root + without --uid or --gid (please don't do that) + or as an unprivileged user with ambient + capabilities like CAP_NET_BIND_SERVICE. + */ + dropCapabilities(g_capabilitiesToRetain); + } + catch (const std::exception& e) { + warnlog("%s", e.what()); + } +} + +static void initFrontends() +{ + if (!g_cmdLine.locals.empty()) { + for (auto it = g_frontends.begin(); it != g_frontends.end(); ) { + /* DoH, DoT and DNSCrypt frontends are separate */ + if ((*it)->dohFrontend == nullptr && (*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr && (*it)->doqFrontend == nullptr && (*it)->doh3Frontend == nullptr) { + it = g_frontends.erase(it); + } + else { + ++it; + } + } + + for (const auto& loc : g_cmdLine.locals) { + /* UDP */ + g_frontends.emplace_back(std::make_unique<ClientState>(ComboAddress(loc, 53), false, false, 0, "", std::set<int>{}, true)); + /* TCP */ + g_frontends.emplace_back(std::make_unique<ClientState>(ComboAddress(loc, 53), true, false, 0, "", std::set<int>{}, true)); + } + } + + if (g_frontends.empty()) { + /* UDP */ + g_frontends.emplace_back(std::make_unique<ClientState>(ComboAddress("127.0.0.1", 53), false, false, 0, "", std::set<int>{}, true)); + /* TCP */ + g_frontends.emplace_back(std::make_unique<ClientState>(ComboAddress("127.0.0.1", 53), true, false, 0, "", std::set<int>{}, true)); + } +} + +namespace dnsdist +{ +static void startFrontends() +{ +#ifdef HAVE_XSK + for (auto& xskContext : dnsdist::xsk::g_xsk) { + std::thread xskThread(dnsdist::xsk::XskRouter, std::move(xskContext)); + xskThread.detach(); + } +#endif /* HAVE_XSK */ + + std::vector<ClientState*> tcpStates; + std::vector<ClientState*> udpStates; + for (auto& clientState : g_frontends) { +#ifdef HAVE_XSK + if (clientState->xskInfo) { + dnsdist::xsk::addDestinationAddress(clientState->local); + + std::thread xskCT(dnsdist::xsk::XskClientThread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(xskCT.native_handle(), clientState->cpus); + } + xskCT.detach(); + } +#endif /* HAVE_XSK */ + + if (clientState->dohFrontend != nullptr && clientState->dohFrontend->d_library == "h2o") { +#ifdef HAVE_DNS_OVER_HTTPS +#ifdef HAVE_LIBH2OEVLOOP + std::thread dotThreadHandle(dohThread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(dotThreadHandle.native_handle(), clientState->cpus); + } + dotThreadHandle.detach(); +#endif /* HAVE_LIBH2OEVLOOP */ +#endif /* HAVE_DNS_OVER_HTTPS */ + continue; + } + if (clientState->doqFrontend != nullptr) { +#ifdef HAVE_DNS_OVER_QUIC + std::thread doqThreadHandle(doqThread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(doqThreadHandle.native_handle(), clientState->cpus); + } + doqThreadHandle.detach(); +#endif /* HAVE_DNS_OVER_QUIC */ + continue; + } + if (clientState->doh3Frontend != nullptr) { +#ifdef HAVE_DNS_OVER_HTTP3 + std::thread doh3ThreadHandle(doh3Thread, clientState.get()); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(doh3ThreadHandle.native_handle(), clientState->cpus); + } + doh3ThreadHandle.detach(); +#endif /* HAVE_DNS_OVER_HTTP3 */ + continue; + } + if (clientState->udpFD >= 0) { +#ifdef USE_SINGLE_ACCEPTOR_THREAD + udpStates.push_back(clientState.get()); +#else /* USE_SINGLE_ACCEPTOR_THREAD */ + std::thread udpClientThreadHandle(udpClientThread, std::vector<ClientState*>{ clientState.get() }); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(udpClientThreadHandle.native_handle(), clientState->cpus); + } + udpClientThreadHandle.detach(); +#endif /* USE_SINGLE_ACCEPTOR_THREAD */ + } + else if (clientState->tcpFD >= 0) { +#ifdef USE_SINGLE_ACCEPTOR_THREAD + tcpStates.push_back(clientState.get()); +#else /* USE_SINGLE_ACCEPTOR_THREAD */ + std::thread tcpAcceptorThreadHandle(tcpAcceptorThread, std::vector<ClientState*>{clientState.get() }); + if (!clientState->cpus.empty()) { + mapThreadToCPUList(tcpAcceptorThreadHandle.native_handle(), clientState->cpus); + } + tcpAcceptorThreadHandle.detach(); +#endif /* USE_SINGLE_ACCEPTOR_THREAD */ + } + } +#ifdef USE_SINGLE_ACCEPTOR_THREAD + if (!udpStates.empty()) { + std::thread udpThreadHandle(udpClientThread, udpStates); + udpThreadHandle.detach(); + } + if (!tcpStates.empty()) { + g_tcpclientthreads = std::make_unique<TCPClientCollection>(1, tcpStates); + } +#endif /* USE_SINGLE_ACCEPTOR_THREAD */ +} +} + +int main(int argc, char** argv) +{ + try { + size_t udpBindsCount = 0; + size_t tcpBindsCount = 0; +#ifdef HAVE_LIBEDIT +#ifndef DISABLE_COMPLETION + rl_attempted_completion_function = my_completion; + rl_completion_append_character = 0; +#endif /* DISABLE_COMPLETION */ +#endif /* HAVE_LIBEDIT */ + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast): SIG_IGN macro + signal(SIGPIPE, SIG_IGN); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast): SIG_IGN macro + signal(SIGCHLD, SIG_IGN); + signal(SIGTERM, sigTermHandler); + + openlog("dnsdist", LOG_PID|LOG_NDELAY, LOG_DAEMON); + +#ifdef HAVE_LIBSODIUM + if (sodium_init() == -1) { + cerr<<"Unable to initialize crypto library"<<endl; + // NOLINTNEXTLINE(concurrency-mt-unsafe): only on thread at this point + exit(EXIT_FAILURE); + } +#endif + dnsdist::initRandom(); + g_hashperturb = dnsdist::getRandomValue(0xffffffff); + +#ifdef HAVE_XSK + try { + dnsdist::xsk::clearDestinationAddresses(); + } + catch (const std::exception& exp) { + /* silently handle failures: at this point we don't even know if XSK is enabled, + and we might not have the correct map (not the default one). */ + } +#endif /* HAVE_XSK */ + + ComboAddress clientAddress = ComboAddress(); + g_cmdLine.config=SYSCONFDIR "/dnsdist.conf"; + + parseParameters(argc, argv, clientAddress); ServerPolicy leastOutstandingPol{"leastOutstanding", leastOutstanding, false}; g_policy.setState(leastOutstandingPol); if (g_cmdLine.beClient || !g_cmdLine.command.empty()) { setupLua(*(g_lua.lock()), true, false, g_cmdLine.config); - if (clientAddress != ComboAddress()) + if (clientAddress != ComboAddress()) { g_serverControl = clientAddress; + } doClient(g_serverControl, g_cmdLine.command); #ifdef COVERAGE exit(EXIT_SUCCESS); @@ -2671,9 +3245,10 @@ int main(int argc, char** argv) } auto acl = g_ACL.getCopy(); - if(acl.empty()) { - for(auto& addr : {"127.0.0.0/8", "10.0.0.0/8", "100.64.0.0/10", "169.254.0.0/16", "192.168.0.0/16", "172.16.0.0/12", "::1/128", "fc00::/7", "fe80::/10"}) + if (acl.empty()) { + for (const auto& addr : {"127.0.0.0/8", "10.0.0.0/8", "100.64.0.0/10", "169.254.0.0/16", "192.168.0.0/16", "172.16.0.0/12", "::1/128", "fc00::/7", "fe80::/10"}) { acl.addMask(addr); + } g_ACL.setState(acl); } @@ -2702,67 +3277,18 @@ int main(int argc, char** argv) auto todo = setupLua(*(g_lua.lock()), false, false, g_cmdLine.config); - auto localPools = g_pools.getCopy(); - { - bool precompute = false; - if (g_policy.getLocal()->getName() == "chashed") { - precompute = true; - } else { - for (const auto& entry: localPools) { - if (entry.second->policy != nullptr && entry.second->policy->getName() == "chashed") { - precompute = true; - break ; - } - } - } - if (precompute) { - vinfolog("Pre-computing hashes for consistent hash load-balancing policy"); - // pre compute hashes - auto backends = g_dstates.getLocal(); - for (auto& backend: *backends) { - if (backend->d_config.d_weight < 100) { - vinfolog("Warning, the backend '%s' has a very low weight (%d), which will not yield a good distribution of queries with the 'chashed' policy. Please consider raising it to at least '100'.", backend->getName(), backend->d_config.d_weight); - } - - backend->hash(); - } - } - } - - if (!g_cmdLine.locals.empty()) { - for (auto it = g_frontends.begin(); it != g_frontends.end(); ) { - /* DoH, DoT and DNSCrypt frontends are separate */ - if ((*it)->dohFrontend == nullptr && (*it)->tlsFrontend == nullptr && (*it)->dnscryptCtx == nullptr) { - it = g_frontends.erase(it); - } - else { - ++it; - } - } + setupPools(); - for (const auto& loc : g_cmdLine.locals) { - /* UDP */ - g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), false, false, 0, "", {}))); - /* TCP */ - g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress(loc, 53), true, false, 0, "", {}))); - } - } - - if (g_frontends.empty()) { - /* UDP */ - g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), false, false, 0, "", {}))); - /* TCP */ - g_frontends.push_back(std::unique_ptr<ClientState>(new ClientState(ComboAddress("127.0.0.1", 53), true, false, 0, "", {}))); - } + initFrontends(); g_configurationDone = true; g_rings.init(); - for(auto& frontend : g_frontends) { + for (auto& frontend : g_frontends) { setUpLocalBind(frontend); - if (frontend->tcp == false) { + if (!frontend->tcp) { ++udpBindsCount; } else { @@ -2770,88 +3296,43 @@ int main(int argc, char** argv) } } - vector<string> vec; - std::string acls; - g_ACL.getLocal()->toStringVector(&vec); - for(const auto& s : vec) { - if (!acls.empty()) - acls += ", "; - acls += s; + { + std::string acls; + auto aclEntries = g_ACL.getLocal()->toStringVector(); + for (const auto& aclEntry : aclEntries) { + if (!acls.empty()) { + acls += ", "; + } + acls += aclEntry; + } + infolog("ACL allowing queries from: %s", acls); } - infolog("ACL allowing queries from: %s", acls.c_str()); - vec.clear(); - acls.clear(); - g_consoleACL.getLocal()->toStringVector(&vec); - for (const auto& entry : vec) { - if (!acls.empty()) { - acls += ", "; + { + std::string acls; + auto aclEntries = g_consoleACL.getLocal()->toStringVector(); + for (const auto& entry : aclEntries) { + if (!acls.empty()) { + acls += ", "; + } + acls += entry; } - acls += entry; + infolog("Console ACL allowing connections from: %s", acls.c_str()); } - infolog("Console ACL allowing connections from: %s", acls.c_str()); -#ifdef HAVE_LIBSODIUM +#if defined(HAVE_LIBSODIUM) || defined(HAVE_LIBCRYPTO) if (g_consoleEnabled && g_consoleKey.empty()) { warnlog("Warning, the console has been enabled via 'controlSocket()' but no key has been set with 'setKey()' so all connections will fail until a key has been set"); } #endif - uid_t newgid=getegid(); - gid_t newuid=geteuid(); - - if (!g_cmdLine.gid.empty()) { - newgid = strToGID(g_cmdLine.gid); - } - - if (!g_cmdLine.uid.empty()) { - newuid = strToUID(g_cmdLine.uid); - } - - bool retainedCapabilities = true; - if (!g_capabilitiesToRetain.empty() && - (getegid() != newgid || geteuid() != newuid)) { - retainedCapabilities = keepCapabilitiesAfterSwitchingIDs(); - } - - if (getegid() != newgid) { - if (running_in_service_mgr()) { - errlog("--gid/-g set on command-line, but dnsdist was started as a systemd service. Use the 'Group' setting in the systemd unit file to set the group to run as"); - _exit(EXIT_FAILURE); - } - dropGroupPrivs(newgid); - } - - if (geteuid() != newuid) { - if (running_in_service_mgr()) { - errlog("--uid/-u set on command-line, but dnsdist was started as a systemd service. Use the 'User' setting in the systemd unit file to set the user to run as"); - _exit(EXIT_FAILURE); - } - dropUserPrivs(newuid); - } - - if (retainedCapabilities) { - dropCapabilitiesAfterSwitchingIDs(); - } - - try { - /* we might still have capabilities remaining, - for example if we have been started as root - without --uid or --gid (please don't do that) - or as an unprivileged user with ambient - capabilities like CAP_NET_BIND_SERVICE. - */ - dropCapabilities(g_capabilitiesToRetain); - } - catch (const std::exception& e) { - warnlog("%s", e.what()); - } + dropPrivileges(); /* this need to be done _after_ dropping privileges */ #ifndef DISABLE_DELAY_PIPE - g_delay = new DelayPipe<DelayedPacket>(); + g_delay = std::make_unique<DelayPipe<DelayedPacket>>(); #endif /* DISABLE_DELAY_PIPE */ - if (g_snmpAgent) { + if (g_snmpAgent != nullptr) { g_snmpAgent->run(); } @@ -2870,16 +3351,18 @@ int main(int argc, char** argv) g_tcpclientthreads = std::make_unique<TCPClientCollection>(*g_maxTCPClientThreads, std::vector<ClientState*>()); #endif +#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2) initDoHWorkers(); +#endif - for (auto& t : todo) { - t(); + for (auto& todoItem : todo) { + todoItem(); } - localPools = g_pools.getCopy(); + auto localPools = g_pools.getCopy(); /* create the default pool no matter what */ createPoolIfNotExists(localPools, ""); - if (g_cmdLine.remotes.size()) { + if (!g_cmdLine.remotes.empty()) { for (const auto& address : g_cmdLine.remotes) { DownstreamState::Config config; config.remote = ComboAddress(address, 53); @@ -2902,12 +3385,14 @@ int main(int argc, char** argv) auto states = g_dstates.getCopy(); // it is a copy, but the internal shared_ptrs are the real deal auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(states.size())); for (auto& dss : states) { + if (dss->d_config.availability == DownstreamState::Availability::Auto || dss->d_config.availability == DownstreamState::Availability::Lazy) { if (dss->d_config.availability == DownstreamState::Availability::Auto) { dss->d_nextCheck = dss->d_config.checkInterval; } if (!queueHealthCheck(mplexer, dss, true)) { + dss->submitHealthCheckResult(true, false); dss->setUpStatus(false); warnlog("Marking downstream %s as 'down'", dss->getNameWithAddr()); } @@ -2916,51 +3401,8 @@ int main(int argc, char** argv) handleQueuedHealthChecks(*mplexer, true); } - std::vector<ClientState*> tcpStates; - std::vector<ClientState*> udpStates; - for(auto& cs : g_frontends) { - if (cs->dohFrontend != nullptr) { -#ifdef HAVE_DNS_OVER_HTTPS - std::thread t1(dohThread, cs.get()); - if (!cs->cpus.empty()) { - mapThreadToCPUList(t1.native_handle(), cs->cpus); - } - t1.detach(); -#endif /* HAVE_DNS_OVER_HTTPS */ - continue; - } - if (cs->udpFD >= 0) { -#ifdef USE_SINGLE_ACCEPTOR_THREAD - udpStates.push_back(cs.get()); -#else /* USE_SINGLE_ACCEPTOR_THREAD */ - thread t1(udpClientThread, std::vector<ClientState*>{ cs.get() }); - if (!cs->cpus.empty()) { - mapThreadToCPUList(t1.native_handle(), cs->cpus); - } - t1.detach(); -#endif /* USE_SINGLE_ACCEPTOR_THREAD */ - } - else if (cs->tcpFD >= 0) { -#ifdef USE_SINGLE_ACCEPTOR_THREAD - tcpStates.push_back(cs.get()); -#else /* USE_SINGLE_ACCEPTOR_THREAD */ - thread t1(tcpAcceptorThread, std::vector<ClientState*>{cs.get() }); - if (!cs->cpus.empty()) { - mapThreadToCPUList(t1.native_handle(), cs->cpus); - } - t1.detach(); -#endif /* USE_SINGLE_ACCEPTOR_THREAD */ - } - } -#ifdef USE_SINGLE_ACCEPTOR_THREAD - if (!udpStates.empty()) { - thread udp(udpClientThread, udpStates); - udp.detach(); - } - if (!tcpStates.empty()) { - g_tcpclientthreads = std::make_unique<TCPClientCollection>(1, tcpStates); - } -#endif /* USE_SINGLE_ACCEPTOR_THREAD */ + dnsdist::startFrontends(); + dnsdist::ServiceDiscovery::run(); #ifndef DISABLE_CARBON @@ -3013,6 +3455,7 @@ int main(int argc, char** argv) errlog("Fatal pdns error: %s", ae.reason); } #ifdef COVERAGE + cleanupLuaObjects(); exit(EXIT_FAILURE); #else _exit(EXIT_FAILURE); @@ -3022,6 +3465,7 @@ int main(int argc, char** argv) { errlog("Fatal error: %s", e.what()); #ifdef COVERAGE + cleanupLuaObjects(); exit(EXIT_FAILURE); #else _exit(EXIT_FAILURE); @@ -3031,6 +3475,7 @@ int main(int argc, char** argv) { errlog("Fatal pdns error: %s", ae.reason); #ifdef COVERAGE + cleanupLuaObjects(); exit(EXIT_FAILURE); #else _exit(EXIT_FAILURE); |