diff options
Diffstat (limited to 'doh.cc')
-rw-r--r-- | doh.cc | 1174 |
1 files changed, 581 insertions, 593 deletions
@@ -2,6 +2,7 @@ #include "doh.hh" #ifdef HAVE_DNS_OVER_HTTPS +#ifdef HAVE_LIBH2OEVLOOP #define H2O_USE_EPOLL 1 #include <cerrno> @@ -10,7 +11,6 @@ #include <boost/algorithm/string.hpp> #include <h2o.h> -//#include <h2o/http1.h> #include <h2o/http2.h> #include <openssl/err.h> @@ -25,7 +25,9 @@ #include "dns.hh" #include "dolog.hh" #include "dnsdist-concurrent-connections.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-ecs.hh" +#include "dnsdist-metrics.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsdist-rules.hh" #include "dnsdist-xpf.hh" @@ -55,7 +57,7 @@ */ /* 'Intermediate' compatibility from https://wiki.mozilla.org/Security/Server_Side_TLS#Intermediate_compatibility_.28default.29 */ -#define DOH_DEFAULT_CIPHERS "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA:!DSS" +static constexpr string_view DOH_DEFAULT_CIPHERS = "ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-RSA-AES256-SHA256:DHE-RSA-AES256-SHA:ECDHE-ECDSA-DES-CBC3-SHA:ECDHE-RSA-DES-CBC3-SHA:EDH-RSA-DES-CBC3-SHA:AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA:!DSS"; class DOHAcceptContext { @@ -66,7 +68,9 @@ public: d_rotatingTicketsKey.clear(); } DOHAcceptContext(const DOHAcceptContext&) = delete; + DOHAcceptContext(DOHAcceptContext&&) = delete; DOHAcceptContext& operator=(const DOHAcceptContext&) = delete; + DOHAcceptContext& operator=(DOHAcceptContext&&) = delete; h2o_accept_ctx_t* get() { @@ -79,19 +83,19 @@ public: d_h2o_accept_ctx.ssl_ctx = nullptr; } - void decrementConcurrentConnections() + void decrementConcurrentConnections() const { if (d_cs != nullptr) { --d_cs->tcpCurrentConnections; } } - time_t getNextTicketsKeyRotation() const + [[nodiscard]] time_t getNextTicketsKeyRotation() const { return d_ticketsKeyNextRotation; } - size_t getTicketsKeysCount() const + [[nodiscard]] size_t getTicketsKeysCount() const { size_t res = 0; if (d_ticketKeys) { @@ -155,57 +159,41 @@ public: std::map<int, std::string> d_ocspResponses; std::unique_ptr<OpenSSLTLSTicketKeysRing> d_ticketKeys{nullptr}; - std::unique_ptr<FILE, int(*)(FILE*)> d_keyLogFile{nullptr, fclose}; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + pdns::UniqueFilePtr d_keyLogFile{nullptr}; ClientState* d_cs{nullptr}; time_t d_ticketsKeyRotationDelay{0}; private: - h2o_accept_ctx_t d_h2o_accept_ctx; - std::atomic<uint64_t> d_refcnt{1}; + h2o_accept_ctx_t d_h2o_accept_ctx{}; time_t d_ticketsKeyNextRotation{0}; std::atomic_flag d_rotatingTicketsKey; }; +struct DOHUnit; + // we create one of these per thread, and pass around a pointer to it // through the bowels of h2o struct DOHServerConfig { DOHServerConfig(uint32_t idleTimeout, uint32_t internalPipeBufferSize): accept_ctx(std::make_shared<DOHAcceptContext>()) { - int fd[2]; #ifndef USE_SINGLE_ACCEPTOR_THREAD - if (pipe(fd) < 0) { - unixDie("Creating a pipe for DNS over HTTPS"); - } - dohquerypair[0] = fd[1]; - dohquerypair[1] = fd[0]; - - setNonBlocking(dohquerypair[0]); - if (internalPipeBufferSize > 0) { - setPipeBufferSize(dohquerypair[0], internalPipeBufferSize); + { + auto [sender, receiver] = pdns::channel::createObjectQueue<DOHUnit>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverBlocking, internalPipeBufferSize); + d_querySender = std::move(sender); + d_queryReceiver = std::move(receiver); } #endif /* USE_SINGLE_ACCEPTOR_THREAD */ - if (pipe(fd) < 0) { -#ifndef USE_SINGLE_ACCEPTOR_THREAD - close(dohquerypair[0]); - close(dohquerypair[1]); -#endif /* USE_SINGLE_ACCEPTOR_THREAD */ - unixDie("Creating a pipe for DNS over HTTPS"); - } - - dohresponsepair[0] = fd[1]; - dohresponsepair[1] = fd[0]; - - setNonBlocking(dohresponsepair[0]); - if (internalPipeBufferSize > 0) { - setPipeBufferSize(dohresponsepair[0], internalPipeBufferSize); + { + auto [sender, receiver] = pdns::channel::createObjectQueue<DOHUnit>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize); + d_responseSender = std::move(sender); + d_responseReceiver = std::move(receiver); } - setNonBlocking(dohresponsepair[1]); - h2o_config_init(&h2o_config); - h2o_config.http2.idle_timeout = idleTimeout * 1000; + h2o_config.http2.idle_timeout = static_cast<uint64_t>(idleTimeout) * 1000; /* if you came here for a way to make the number of concurrent streams (concurrent requests per connection) configurable, or even just bigger, I have bad news for you. h2o_config.http2.max_concurrent_requests_per_connection (default of 100) is capped by @@ -215,67 +203,106 @@ struct DOHServerConfig */ } DOHServerConfig(const DOHServerConfig&) = delete; + DOHServerConfig(DOHServerConfig&&) = delete; DOHServerConfig& operator=(const DOHServerConfig&) = delete; + DOHServerConfig& operator=(DOHServerConfig&&) = delete; + ~DOHServerConfig() = default; LocalHolders holders; std::set<std::string, std::less<>> paths; - h2o_globalconf_t h2o_config; - h2o_context_t h2o_ctx; + h2o_globalconf_t h2o_config{}; + h2o_context_t h2o_ctx{}; std::shared_ptr<DOHAcceptContext> accept_ctx{nullptr}; - ClientState* cs{nullptr}; - std::shared_ptr<DOHFrontend> df{nullptr}; + ClientState* clientState{nullptr}; + std::shared_ptr<DOHFrontend> dohFrontend{nullptr}; #ifndef USE_SINGLE_ACCEPTOR_THREAD - int dohquerypair[2]{-1,-1}; + pdns::channel::Sender<DOHUnit> d_querySender; + pdns::channel::Receiver<DOHUnit> d_queryReceiver; #endif /* USE_SINGLE_ACCEPTOR_THREAD */ - int dohresponsepair[2]{-1,-1}; + pdns::channel::Sender<DOHUnit> d_responseSender; + pdns::channel::Receiver<DOHUnit> d_responseReceiver; }; -/* This internal function sends back the object to the main thread to send a reply. - The caller should NOT release or touch the unit after calling this function */ -static void sendDoHUnitToTheMainThread(DOHUnitUniquePtr&& du, const char* description) +struct DOHUnit : public DOHUnitInterface { - /* taking a naked pointer since we are about to send that pointer over a pipe */ - auto ptr = du.release(); - /* increasing the reference counter. This should not be strictly needed because - we already hold a reference and will only release it if we failed to send the - pointer over the pipe, but TSAN seems confused when the responder thread gets - a reply from a backend before the send() syscall sending the corresponding query - to that backend has returned in the initial thread. - The memory barrier needed to increase that counter seems to work around that. + DOHUnit(PacketBuffer&& query_, std::string&& path_, std::string&& host_): path(std::move(path_)), host(std::move(host_)), query(std::move(query_)) + { + ids.ednsAdded = false; + } + ~DOHUnit() override + { + if (self != nullptr) { + *self = nullptr; + } + } + + DOHUnit(const DOHUnit&) = delete; + DOHUnit(DOHUnit&&) = delete; + DOHUnit& operator=(const DOHUnit&) = delete; + DOHUnit& operator=(DOHUnit&&) = delete; + + InternalQueryState ids; + std::string sni; + std::string path; + std::string scheme; + std::string host; + std::string contentType; + PacketBuffer query; + PacketBuffer response; + std::unique_ptr<std::unordered_map<std::string, std::string>> headers; + st_h2o_req_t* req{nullptr}; + DOHUnit** self{nullptr}; + DOHServerConfig* dsc{nullptr}; + pdns::channel::Sender<DOHUnit>* responseSender{nullptr}; + size_t query_at{0}; + int rsock{-1}; + /* the status_code is set from + processDOHQuery() (which is executed in + the DOH client thread) so that the correct + response can be sent in on_dnsdist(), + after the DOHUnit has been passed back to + the main DoH thread. */ - ptr->get(); - static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); + uint16_t status_code{200}; + /* whether the query was re-sent to the backend over + TCP after receiving a truncated answer over UDP */ + bool tcp{false}; + bool truncated{false}; + + [[nodiscard]] std::string getHTTPPath() const override; + [[nodiscard]] std::string getHTTPQueryString() const override; + [[nodiscard]] const std::string& getHTTPHost() const override; + [[nodiscard]] const std::string& getHTTPScheme() const override; + [[nodiscard]] const std::unordered_map<std::string, std::string>& getHTTPHeaders() const override; + void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType="") override; + void handleTimeout() override; + void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, [[maybe_unused]] const std::shared_ptr<DownstreamState>& downstream) override; +}; +using DOHUnitUniquePtr = std::unique_ptr<DOHUnit>; - ssize_t sent = write(ptr->rsock, &ptr, sizeof(ptr)); - if (sent != sizeof(ptr)) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - ++g_stats.dohResponsePipeFull; +/* This internal function sends back the object to the main thread to send a reply. + The caller should NOT release or touch the unit after calling this function */ +static void sendDoHUnitToTheMainThread(DOHUnitUniquePtr&& dohUnit, const char* description) +{ + if (dohUnit->responseSender == nullptr) { + return; + } + try { + if (!dohUnit->responseSender->send(std::move(dohUnit))) { + ++dnsdist::metrics::g_stats.dohResponsePipeFull; vinfolog("Unable to pass a %s to the DoH worker thread because the pipe is full", description); } - else { - vinfolog("Unable to pass a %s to the DoH worker thread because we couldn't write to the pipe: %s", description, stringerror()); - } - - /* we fail to write over the pipe so we do not need to hold to that ref anymore */ - ptr->release(); + } catch (const std::exception& e) { + vinfolog("Unable to pass a %s to the DoH worker thread because we couldn't write to the pipe: %s", description, e.what()); } - /* we decrement the counter incremented above at the beginning of that function */ - ptr->release(); } /* This function is called from other threads than the main DoH one, - instructing it to send a 502 error to the client. - It takes ownership of the unit. */ -void handleDOHTimeout(DOHUnitUniquePtr&& oldDU) + instructing it to send a 502 error to the client. */ +void DOHUnit::handleTimeout() { - if (oldDU == nullptr) { - return; - } - - /* we are about to erase an existing DU */ - oldDU->status_code = 502; - - sendDoHUnitToTheMainThread(std::move(oldDU), "DoH timeout"); + status_code = 502; + sendDoHUnitToTheMainThread(std::unique_ptr<DOHUnit>(this), "DoH timeout"); } struct DOHConnection @@ -292,10 +319,10 @@ static thread_local std::unordered_map<int, DOHConnection> t_conns; static void on_socketclose(void *data) { - auto conn = reinterpret_cast<DOHConnection*>(data); + auto* conn = static_cast<DOHConnection*>(data); if (conn != nullptr) { if (conn->d_acceptCtx) { - struct timeval now; + struct timeval now{}; gettimeofday(&now, nullptr); auto diff = now - conn->d_connectionStartTime; @@ -352,17 +379,15 @@ static const std::string& getReasonFromStatusCode(uint16_t statusCode) }; static const std::string unknown = "Unknown"; - const auto it = reasons.find(statusCode); - if (it == reasons.end()) { + const auto reasonIt = reasons.find(statusCode); + if (reasonIt == reasons.end()) { return unknown; } - else { - return it->second; - } + return reasonIt->second; } /* Always called from the main DoH thread */ -static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCode, const PacketBuffer& response, const std::unordered_map<std::string, std::string>& customResponseHeaders, const std::string& contentType, bool addContentType) +static void handleResponse(DOHFrontend& dohFrontend, st_h2o_req_t* req, uint16_t statusCode, const PacketBuffer& response, const std::unordered_map<std::string, std::string>& customResponseHeaders, const std::string& contentType, bool addContentType) { constexpr int overwrite_if_exists = 1; constexpr int maybe_token = 1; @@ -371,7 +396,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo } if (statusCode == 200) { - ++df.d_validresponses; + ++dohFrontend.d_validresponses; req->res.status = 200; if (addContentType) { @@ -380,12 +405,14 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo } else { /* we need to duplicate the header content because h2o keeps a pointer and we will be deleted before the response has been sent */ - h2o_iovec_t ct = h2o_strdup(&req->pool, contentType.c_str(), contentType.size()); - h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, nullptr, ct.base, ct.len); + h2o_iovec_t contentTypeVect = h2o_strdup(&req->pool, contentType.c_str(), contentType.size()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay,cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API + h2o_add_header(&req->pool, &req->res.headers, H2O_TOKEN_CONTENT_TYPE, nullptr, contentTypeVect.base, contentTypeVect.len); } } - if (df.d_sendCacheControlHeaders && response.size() > sizeof(dnsheader)) { + if (dohFrontend.d_sendCacheControlHeaders && response.size() > sizeof(dnsheader)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) uint32_t minTTL = getDNSPacketMinTTL(reinterpret_cast<const char*>(response.data()), response.size()); if (minTTL != std::numeric_limits<uint32_t>::max()) { std::string cacheControlValue = "max-age=" + std::to_string(minTTL); @@ -396,18 +423,21 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo } req->res.content_length = response.size(); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API h2o_send_inline(req, reinterpret_cast<const char*>(response.data()), response.size()); } else if (statusCode >= 300 && statusCode < 400) { /* in that case the response is actually a URL */ /* we need to duplicate the URL because h2o uses it for the location header, keeping a pointer, and we will be deleted before the response has been sent */ + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API h2o_iovec_t url = h2o_strdup(&req->pool, reinterpret_cast<const char*>(response.data()), response.size()); h2o_send_redirect(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), url.base, url.len); - ++df.d_redirectresponses; + ++dohFrontend.d_redirectresponses; } else { // we need to make sure it's null-terminated */ if (!response.empty() && response.at(response.size() - 1) == 0) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API h2o_send_error_generic(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), reinterpret_cast<const char*>(response.data()), H2O_SEND_ERROR_KEEP_HEADERS); } else { @@ -416,7 +446,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo h2o_send_error_400(req, getReasonFromStatusCode(statusCode).c_str(), "invalid DNS query" , 0); break; case 403: - h2o_send_error_403(req, getReasonFromStatusCode(statusCode).c_str(), "dns query not allowed", 0); + h2o_send_error_403(req, getReasonFromStatusCode(statusCode).c_str(), "DoH query not allowed", 0); break; case 502: h2o_send_error_502(req, getReasonFromStatusCode(statusCode).c_str(), "no downstream server available", 0); @@ -429,18 +459,27 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo } } - ++df.d_errorresponses; + ++dohFrontend.d_errorresponses; } } -class DoHTCPCrossQuerySender : public TCPQuerySender +static std::unique_ptr<DOHUnit> getDUFromIDS(InternalQueryState& ids) { -public: - DoHTCPCrossQuerySender() - { - } + auto dohUnit = std::unique_ptr<DOHUnit>(dynamic_cast<DOHUnit*>(ids.du.release())); + return dohUnit; +} - bool active() const override +class DoHTCPCrossQuerySender final : public TCPQuerySender +{ +public: + DoHTCPCrossQuerySender() = default; + DoHTCPCrossQuerySender(const DoHTCPCrossQuerySender&) = delete; + DoHTCPCrossQuerySender(DoHTCPCrossQuerySender&&) = delete; + DoHTCPCrossQuerySender& operator=(const DoHTCPCrossQuerySender&) = delete; + DoHTCPCrossQuerySender& operator=(DoHTCPCrossQuerySender&&) = delete; + ~DoHTCPCrossQuerySender() final = default; + + [[nodiscard]] bool active() const override { return true; } @@ -451,28 +490,29 @@ public: return; } - auto du = std::move(response.d_idstate.du); - if (du->rsock == -1) { + auto dohUnit = getDUFromIDS(response.d_idstate); + if (dohUnit->responseSender == nullptr) { return; } - du->response = std::move(response.d_buffer); - du->ids = std::move(response.d_idstate); - DNSResponse dr(du->ids, du->response, du->downstream); + dohUnit->response = std::move(response.d_buffer); + dohUnit->ids = std::move(response.d_idstate); + DNSResponse dr(dohUnit->ids, dohUnit->response, dohUnit->downstream); - dnsheader cleartextDH; - memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); + dnsheader cleartextDH{}; + memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH)); if (!response.isAsync()) { static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - dr.ids.du = std::move(du); + dr.ids.du = std::move(dohUnit); - if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { + if (!processResponse(dynamic_cast<DOHUnit*>(dr.ids.du.get())->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { if (dr.ids.du) { - dr.ids.du->status_code = 503; - sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + dohUnit = getDUFromIDS(dr.ids); + dohUnit->status_code = 503; + sendDoHUnitToTheMainThread(std::move(dohUnit), "Response dropped by rules"); } return; } @@ -481,26 +521,26 @@ public: return; } - du = std::move(dr.ids.du); + dohUnit = getDUFromIDS(dr.ids); } - if (!du->ids.selfGenerated) { - double udiff = du->ids.queryRealTime.udiff(); - vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); + if (!dohUnit->ids.selfGenerated) { + double udiff = dohUnit->ids.queryRealTime.udiff(); + vinfolog("Got answer from %s, relayed to %s (https), took %f us", dohUnit->downstream->d_config.remote.toStringWithPort(), dohUnit->ids.origRemote.toStringWithPort(), udiff); - auto backendProtocol = du->downstream->getProtocol(); - if (backendProtocol == dnsdist::Protocol::DoUDP && du->tcp) { + auto backendProtocol = dohUnit->downstream->getProtocol(); + if (backendProtocol == dnsdist::Protocol::DoUDP && dohUnit->tcp) { backendProtocol = dnsdist::Protocol::DoTCP; } - handleResponseSent(du->ids, udiff, du->ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, backendProtocol, true); + handleResponseSent(dohUnit->ids, udiff, dohUnit->ids.origRemote, dohUnit->downstream->d_config.remote, dohUnit->response.size(), cleartextDH, backendProtocol, true); } - ++g_stats.responses; - if (du->ids.cs) { - ++du->ids.cs->responses; + ++dnsdist::metrics::g_stats.responses; + if (dohUnit->ids.cs != nullptr) { + ++dohUnit->ids.cs->responses; } - sendDoHUnitToTheMainThread(std::move(du), "cross-protocol response"); + sendDoHUnitToTheMainThread(std::move(dohUnit), "cross-protocol response"); } void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override @@ -508,62 +548,69 @@ public: return handleResponse(now, std::move(response)); } - void notifyIOError(InternalQueryState&& query, const struct timeval& now) override + void notifyIOError(const struct timeval& now, TCPResponse&& response) override { + auto& query = response.d_idstate; if (!query.du) { return; } - if (query.du->rsock == -1) { + auto dohUnit = getDUFromIDS(query); + if (dohUnit->responseSender == nullptr) { return; } - auto du = std::move(query.du); - du->ids = std::move(query); - du->status_code = 502; - sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response"); + dohUnit->ids = std::move(query); + dohUnit->status_code = 502; + sendDoHUnitToTheMainThread(std::move(dohUnit), "cross-protocol error response"); } }; class DoHCrossProtocolQuery : public CrossProtocolQuery { public: - DoHCrossProtocolQuery(DOHUnitUniquePtr&& du, bool isResponse) + DoHCrossProtocolQuery(DOHUnitUniquePtr&& dohUnit, bool isResponse) { if (isResponse) { /* happens when a response becomes async */ - query = InternalQuery(std::move(du->response), std::move(du->ids)); + query = InternalQuery(std::move(dohUnit->response), std::move(dohUnit->ids)); } else { /* we need to duplicate the query here because we might need the existing query later if we get a truncated answer */ - query = InternalQuery(PacketBuffer(du->query), std::move(du->ids)); + query = InternalQuery(PacketBuffer(dohUnit->query), std::move(dohUnit->ids)); } - /* it might have been moved when we moved du->ids */ - if (du) { - query.d_idstate.du = std::move(du); + /* it might have been moved when we moved dohUnit->ids */ + if (dohUnit) { + query.d_idstate.du = std::move(dohUnit); } /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload, - clearing query.d_proxyProtocolPayloadAdded and du->proxyProtocolPayloadSize. + clearing query.d_proxyProtocolPayloadAdded and dohUnit->proxyProtocolPayloadSize. Leave it for now because we know that the onky case where the payload has been added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT, and we know the TCP/DoT code can handle it. */ - query.d_proxyProtocolPayloadAdded = query.d_idstate.du->proxyProtocolPayloadSize > 0; + query.d_proxyProtocolPayloadAdded = query.d_idstate.d_proxyProtocolPayloadSize > 0; downstream = query.d_idstate.du->downstream; - proxyProtocolPayloadSize = query.d_idstate.du->proxyProtocolPayloadSize; } void handleInternalError() { - query.d_idstate.du->status_code = 502; - sendDoHUnitToTheMainThread(std::move(query.d_idstate.du), "DoH internal error"); + auto dohUnit = getDUFromIDS(query.d_idstate); + if (dohUnit == nullptr) { + return; + } + dohUnit->status_code = 502; + sendDoHUnitToTheMainThread(std::move(dohUnit), "DoH internal error"); } std::shared_ptr<TCPQuerySender> getTCPQuerySender() override { - query.d_idstate.du->downstream = downstream; + auto* unit = dynamic_cast<DOHUnit*>(query.d_idstate.du.get()); + if (unit != nullptr) { + unit->downstream = downstream; + } return s_sender; } @@ -581,9 +628,9 @@ public: return dr; } - DOHUnitUniquePtr&& releaseDU() + DOHUnitUniquePtr releaseDU() { - return std::move(query.d_idstate.du); + return getDUFromIDS(query.d_idstate); } private: @@ -598,25 +645,25 @@ std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion& throw std::runtime_error("Trying to create a DoH cross protocol query without a valid DoH unit"); } - auto du = std::move(dq.ids.du); - if (&dq.ids != &du->ids) { - du->ids = std::move(dq.ids); + auto dohUnit = getDUFromIDS(dq.ids); + if (&dq.ids != &dohUnit->ids) { + dohUnit->ids = std::move(dq.ids); } - du->ids.origID = dq.getHeader()->id; + dohUnit->ids.origID = dq.getHeader()->id; if (!isResponse) { - if (du->query.data() != dq.getMutableData().data()) { - du->query = std::move(dq.getMutableData()); + if (dohUnit->query.data() != dq.getMutableData().data()) { + dohUnit->query = std::move(dq.getMutableData()); } } else { - if (du->response.data() != dq.getMutableData().data()) { - du->response = std::move(dq.getMutableData()); + if (dohUnit->response.data() != dq.getMutableData().data()) { + dohUnit->response = std::move(dq.getMutableData()); } } - return std::make_unique<DoHCrossProtocolQuery>(std::move(du), isResponse); + return std::make_unique<DoHCrossProtocolQuery>(std::move(dohUnit), isResponse); } /* @@ -624,181 +671,191 @@ std::unique_ptr<CrossProtocolQuery> getDoHCrossProtocolQueryFromDQ(DNSQuestion& */ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) { - const auto handleImmediateResponse = [inMainThread](DOHUnitUniquePtr&& du, const char* reason) { + const auto handleImmediateResponse = [inMainThread](DOHUnitUniquePtr&& dohUnit, const char* reason) { if (inMainThread) { - handleResponse(*du->dsc->df, du->req, du->status_code, du->response, du->dsc->df->d_customResponseHeaders, du->contentType, true); + handleResponse(*dohUnit->dsc->dohFrontend, dohUnit->req, dohUnit->status_code, dohUnit->response, dohUnit->dsc->dohFrontend->d_customResponseHeaders, dohUnit->contentType, true); /* so the unique pointer is stored in the InternalState which itself is stored in the unique pointer itself. We likely need a better design, but for now let's just reset the internal one since we know it is no longer needed. */ - du->ids.du.reset(); + dohUnit->ids.du.reset(); } else { - sendDoHUnitToTheMainThread(std::move(du), reason); + sendDoHUnitToTheMainThread(std::move(dohUnit), reason); } }; auto& ids = unit->ids; - ids.du = std::move(unit); - auto& du = ids.du; uint16_t queryId = 0; ComboAddress remote; try { - if (!du->req) { + if (unit->req == nullptr) { // we got closed meanwhile. XXX small race condition here - // but we should be fine as long as we don't touch du->req + // but we should be fine as long as we don't touch dohUnit->req // outside of the main DoH thread - du->status_code = 500; - handleImmediateResponse(std::move(du), "DoH killed in flight"); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH killed in flight"); return; } - { - // if there was no EDNS, we add it with a large buffer size - // so we can use UDP to talk to the backend. - auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(du->query.data())); - - if (!dh->arcount) { - if (generateOptRR(std::string(), du->query, 4096, 4096, 0, false)) { - dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(du->query.data())); // may have reallocated - dh->arcount = htons(1); - du->ids.ednsAdded = true; - } - } - else { - // we leave existing EDNS in place - } - } - - remote = du->ids.origRemote; - DOHServerConfig* dsc = du->dsc; + remote = ids.origRemote; + DOHServerConfig* dsc = unit->dsc; auto& holders = dsc->holders; - ClientState& cs = *dsc->cs; + ClientState& clientState = *dsc->clientState; - if (du->query.size() < sizeof(dnsheader)) { - ++g_stats.nonCompliantQueries; - ++cs.nonCompliantQueries; - du->status_code = 400; - handleImmediateResponse(std::move(du), "DoH non-compliant query"); + if (unit->query.size() < sizeof(dnsheader) || unit->query.size() > std::numeric_limits<uint16_t>::max()) { + ++dnsdist::metrics::g_stats.nonCompliantQueries; + ++clientState.nonCompliantQueries; + unit->status_code = 400; + handleImmediateResponse(std::move(unit), "DoH non-compliant query"); return; } - ++cs.queries; - ++g_stats.queries; - du->ids.queryRealTime.start(); + ++clientState.queries; + ++dnsdist::metrics::g_stats.queries; + ids.queryRealTime.start(); { /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */ - struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(du->query.data()); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + const dnsheader_aligned dnsHeader(unit->query.data()); - if (!checkQueryHeaders(dh, cs)) { - du->status_code = 400; - handleImmediateResponse(std::move(du), "DoH invalid headers"); + if (!checkQueryHeaders(*dnsHeader, clientState)) { + unit->status_code = 400; + handleImmediateResponse(std::move(unit), "DoH invalid headers"); return; } - if (dh->qdcount == 0) { - dh->rcode = RCode::NotImp; - dh->qr = true; - du->response = std::move(du->query); + if (dnsHeader->qdcount == 0U) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); + unit->response = std::move(unit->query); - handleImmediateResponse(std::move(du), "DoH empty query"); + handleImmediateResponse(std::move(unit), "DoH empty query"); return; } - queryId = ntohs(dh->id); + queryId = ntohs(dnsHeader->id); } - auto downstream = du->downstream; - du->ids.qname = DNSName(reinterpret_cast<const char*>(du->query.data()), du->query.size(), sizeof(dnsheader), false, &du->ids.qtype, &du->ids.qclass); - DNSQuestion dq(du->ids, du->query); - const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); - ids.origFlags = *flags; - du->ids.cs = &cs; - dq.sni = std::move(du->sni); + { + // if there was no EDNS, we add it with a large buffer size + // so we can use UDP to talk to the backend. + dnsheader_aligned dnsHeader(unit->query.data()); + if (dnsHeader.get()->arcount == 0U) { + if (addEDNS(unit->query, 4096, false, 4096, 0)) { + ids.ednsAdded = true; + } + } + } - auto result = processQuery(dq, holders, downstream); + auto downstream = unit->downstream; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), static_cast<int>(sizeof(dnsheader)), false, &ids.qtype, &ids.qclass); + DNSQuestion dnsQuestion(ids, unit->query); + const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader().get()); + ids.origFlags = *flags; + ids.cs = &clientState; + dnsQuestion.sni = std::move(unit->sni); + ids.du = std::move(unit); + auto result = processQuery(dnsQuestion, holders, downstream); if (result == ProcessQueryResult::Drop) { - du->status_code = 403; - handleImmediateResponse(std::move(du), "DoH dropped query"); + unit = getDUFromIDS(ids); + unit->status_code = 403; + handleImmediateResponse(std::move(unit), "DoH dropped query"); return; } - else if (result == ProcessQueryResult::Asynchronous) { + if (result == ProcessQueryResult::Asynchronous) { return; } - else if (result == ProcessQueryResult::SendAnswer) { - if (du->response.empty()) { - du->response = std::move(du->query); + if (result == ProcessQueryResult::SendAnswer) { + unit = getDUFromIDS(ids); + if (unit->response.empty()) { + unit->response = std::move(unit->query); } - if (du->response.size() >= sizeof(dnsheader) && du->contentType.empty()) { - auto dh = reinterpret_cast<const struct dnsheader*>(du->response.data()); - - handleResponseSent(du->ids.qname, QType(du->ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH, false); + if (unit->response.size() >= sizeof(dnsheader) && unit->contentType.empty()) { + dnsheader_aligned dnsHeader(unit->response.data()); + handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *(dnsHeader.get()), dnsdist::Protocol::DoH, dnsdist::Protocol::DoH, false); } - handleImmediateResponse(std::move(du), "DoH self-answered response"); + handleImmediateResponse(std::move(unit), "DoH self-answered response"); return; } + unit = getDUFromIDS(ids); if (result != ProcessQueryResult::PassToBackend) { - du->status_code = 500; - handleImmediateResponse(std::move(du), "DoH no backend available"); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH no backend available"); return; } if (downstream == nullptr) { - du->status_code = 502; - handleImmediateResponse(std::move(du), "DoH no backend available"); + unit->status_code = 502; + handleImmediateResponse(std::move(unit), "DoH no backend available"); return; } - du->downstream = downstream; + unit->downstream = downstream; if (downstream->isTCPOnly()) { std::string proxyProtocolPayload; /* we need to do this _before_ creating the cross protocol query because after that the buffer will have been moved */ if (downstream->d_config.useProxyProtocol) { - proxyProtocolPayload = getProxyProtocolPayload(dq); + proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion); } - du->ids.origID = htons(queryId); - du->tcp = true; + unit->ids.origID = htons(queryId); + unit->tcp = true; /* this moves du->ids, careful! */ - auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du), false); + auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(unit), false); + if (!cpq) { + // make linters happy + return; + } cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); if (downstream->passCrossProtocolQuery(std::move(cpq))) { return; } - else { - if (inMainThread) { - du = cpq->releaseDU(); - du->status_code = 502; - handleImmediateResponse(std::move(du), "DoH internal error"); + + if (inMainThread) { + // cpq is not altered if the call fails but linters are not smart enough to notice that + if (cpq) { + // NOLINTNEXTLINE(bugprone-use-after-move): cpq is not altered if the call fails + unit = cpq->releaseDU(); } - else { + unit->status_code = 502; + handleImmediateResponse(std::move(unit), "DoH internal error"); + } + else { + // cpq is not altered if the call fails but linters are not smart enough to notice that + if (cpq) { + // NOLINTNEXTLINE(bugprone-use-after-move): cpq is not altered if the call fails cpq->handleInternalError(); } - return; } + return; } - ComboAddress dest = dq.ids.origDest; - if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, du->query, dest)) { - du->status_code = 502; - handleImmediateResponse(std::move(du), "DoH internal error"); + auto& query = unit->query; + ids.du = std::move(unit); + if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dnsQuestion, query)) { + unit = getDUFromIDS(ids); + unit->status_code = 502; + handleImmediateResponse(std::move(unit), "DoH internal error"); return; } } catch (const std::exception& e) { vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); - du->status_code = 500; - handleImmediateResponse(std::move(du), "DoH internal error"); + unit->status_code = 500; + handleImmediateResponse(std::move(unit), "DoH internal error"); return; } - - return; } /* called when a HTTP response is about to be sent, from the main DoH thread */ @@ -808,16 +865,17 @@ static void on_response_ready_cb(struct st_h2o_filter_t *self, h2o_req_t *req, h return; } - DOHServerConfig* dsc = reinterpret_cast<DOHServerConfig*>(req->conn->ctx->storage.entries[0].data); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API + auto* dsc = static_cast<DOHServerConfig*>(req->conn->ctx->storage.entries[0].data); DOHFrontend::HTTPVersionStats* stats = nullptr; if (req->version < 0x200) { /* HTTP 1.x */ - stats = &dsc->df->d_http1Stats; + stats = &dsc->dohFrontend->d_http1Stats; } else { /* HTTP 2.0 */ - stats = &dsc->df->d_http2Stats; + stats = &dsc->dohFrontend->d_http2Stats; } switch (req->res.status) { @@ -848,10 +906,10 @@ static void on_response_ready_cb(struct st_h2o_filter_t *self, h2o_req_t *req, h We use this to signal to the 'du' that this req is no longer alive */ static void on_generator_dispose(void *_self) { - DOHUnit** du = reinterpret_cast<DOHUnit**>(_self); - if (*du) { // if 0, on_dnsdist cleaned up du already - (*du)->self = nullptr; - (*du)->req = nullptr; + auto* dohUnit = static_cast<DOHUnit**>(_self); + if (*dohUnit != nullptr) { // if nullptr, on_dnsdist cleaned up dohUnit already + (*dohUnit)->self = nullptr; + (*dohUnit)->req = nullptr; } } @@ -862,6 +920,7 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re { try { /* we only parse it there as a sanity check, we will parse it again later */ + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) DNSPacketMangler mangler(reinterpret_cast<char*>(query.data()), query.size()); mangler.skipDomainName(); mangler.skipBytes(4); @@ -869,23 +928,24 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re /* we are doing quite some copies here, sorry about that, but we can't keep accessing the req object once we are in a different thread because the request might get killed by h2o at pretty much any time */ - auto du = std::make_unique<DOHUnit>(std::move(query), std::move(path), std::string(req->authority.base, req->authority.len)); - du->dsc = dsc; - du->req = req; - du->ids.origDest = local; - du->ids.origRemote = remote; - du->ids.protocol = dnsdist::Protocol::DoH; - du->rsock = dsc->dohresponsepair[0]; + auto dohUnit = std::make_unique<DOHUnit>(std::move(query), std::move(path), std::string(req->authority.base, req->authority.len)); + dohUnit->dsc = dsc; + dohUnit->req = req; + dohUnit->ids.origDest = local; + dohUnit->ids.origRemote = remote; + dohUnit->ids.protocol = dnsdist::Protocol::DoH; + dohUnit->responseSender = &dsc->d_responseSender; if (req->scheme != nullptr) { - du->scheme = std::string(req->scheme->name.base, req->scheme->name.len); + dohUnit->scheme = std::string(req->scheme->name.base, req->scheme->name.len); } - du->query_at = req->query_at; + dohUnit->query_at = req->query_at; - if (dsc->df->d_keepIncomingHeaders) { - du->headers = std::make_unique<std::unordered_map<std::string, std::string>>(); - du->headers->reserve(req->headers.size); + if (dsc->dohFrontend->d_keepIncomingHeaders) { + dohUnit->headers = std::make_unique<std::unordered_map<std::string, std::string>>(); + dohUnit->headers->reserve(req->headers.size); for (size_t i = 0; i < req->headers.size; ++i) { - (*du->headers)[std::string(req->headers.entries[i].name->base, req->headers.entries[i].name->len)] = std::string(req->headers.entries[i].value.base, req->headers.entries[i].value.len); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API + (*dohUnit->headers)[std::string(req->headers.entries[i].name->base, req->headers.entries[i].name->len)] = std::string(req->headers.entries[i].value.base, req->headers.entries[i].value.len); } } @@ -893,36 +953,25 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re h2o_socket_t* sock = req->conn->callbacks->get_socket(req->conn); const char * sni = h2o_socket_get_ssl_server_name(sock); if (sni != nullptr) { - du->sni = sni; + dohUnit->sni = sni; } #endif /* HAVE_H2O_SOCKET_GET_SSL_SERVER_NAME */ - du->self = reinterpret_cast<DOHUnit**>(h2o_mem_alloc_shared(&req->pool, sizeof(*self), on_generator_dispose)); - auto ptr = du.release(); - *(ptr->self) = ptr; + dohUnit->self = static_cast<DOHUnit**>(h2o_mem_alloc_shared(&req->pool, sizeof(*self), on_generator_dispose)); + *(dohUnit->self) = dohUnit.get(); #ifdef USE_SINGLE_ACCEPTOR_THREAD - processDOHQuery(DOHUnitUniquePtr(ptr, DOHUnit::release), true); + processDOHQuery(std::move(dohUnit), true); #else /* USE_SINGLE_ACCEPTOR_THREAD */ - try { - static_assert(sizeof(ptr) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); - ssize_t sent = write(dsc->dohquerypair[0], &ptr, sizeof(ptr)); - if (sent != sizeof(ptr)) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - ++g_stats.dohQueryPipeFull; - vinfolog("Unable to pass a DoH query to the DoH worker thread because the pipe is full"); - } - else { - vinfolog("Unable to pass a DoH query to the DoH worker thread because we couldn't write to the pipe: %s", stringerror()); - } - ptr->release(); - ptr = nullptr; + try { + if (!dsc->d_querySender.send(std::move(dohUnit))) { + ++dnsdist::metrics::g_stats.dohQueryPipeFull; + vinfolog("Unable to pass a DoH query to the DoH worker thread because the pipe is full"); h2o_send_error_500(req, "Internal Server Error", "Internal Server Error", 0); } } catch (...) { - if (ptr != nullptr) { - ptr->release(); - } + vinfolog("Unable to pass a DoH query to the DoH worker thread because we couldn't write to the pipe: %s", stringerror()); + h2o_send_error_500(req, "Internal Server Error", "Internal Server Error", 0); } #endif /* USE_SINGLE_ACCEPTOR_THREAD */ } @@ -940,7 +989,9 @@ static bool getHTTPHeaderValue(const h2o_req_t* req, const std::string& headerNa std::string_view headerNameView(headerName); for (size_t i = 0; i < req->headers.size; ++i) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API if (std::string_view(req->headers.entries[i].name->base, req->headers.entries[i].name->len) == headerNameView) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API value = std::string_view(req->headers.entries[i].value.base, req->headers.entries[i].value.len); /* don't stop there, we might have more than one header with the same name, and we want the last one */ found = true; @@ -990,10 +1041,11 @@ static std::optional<ComboAddress> processForwardedForHeader(const h2o_req_t* re static int doh_handler(h2o_handler_t *self, h2o_req_t *req) { try { - if (!req->conn->ctx->storage.size) { + if (req->conn->ctx->storage.size == 0) { return 0; // although we might was well crash on this } - DOHServerConfig* dsc = reinterpret_cast<DOHServerConfig*>(req->conn->ctx->storage.entries[0].data); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API + auto* dsc = static_cast<DOHServerConfig*>(req->conn->ctx->storage.entries[0].data); h2o_socket_t* sock = req->conn->callbacks->get_socket(req->conn); const int descriptor = h2o_socket_get_fd(sock); @@ -1005,45 +1057,51 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req) ++conn.d_nbQueries; if (conn.d_nbQueries == 1) { if (h2o_socket_get_ssl_session_reused(sock) == 0) { - ++dsc->cs->tlsNewSessions; + ++dsc->clientState->tlsNewSessions; } else { - ++dsc->cs->tlsResumptions; + ++dsc->clientState->tlsResumptions; } + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API h2o_socket_getsockname(sock, reinterpret_cast<struct sockaddr*>(&conn.d_local)); } auto remote = conn.d_remote; - if (dsc->df->d_trustForwardedForHeader) { + if (dsc->dohFrontend->d_trustForwardedForHeader) { auto newRemote = processForwardedForHeader(req, remote); if (newRemote) { - remote = std::move(*newRemote); + remote = *newRemote; } } auto& holders = dsc->holders; if (!holders.acl->match(remote)) { - ++g_stats.aclDrops; + ++dnsdist::metrics::g_stats.aclDrops; vinfolog("Query from %s (DoH) dropped because of ACL", remote.toStringWithPort()); - h2o_send_error_403(req, "Forbidden", "dns query not allowed because of ACL", 0); + h2o_send_error_403(req, "Forbidden", "DoH query not allowed because of ACL", 0); return 0; } - if (auto tlsversion = h2o_socket_get_ssl_protocol_version(sock)) { - if(!strcmp(tlsversion, "TLSv1.0")) - ++dsc->cs->tls10queries; - else if(!strcmp(tlsversion, "TLSv1.1")) - ++dsc->cs->tls11queries; - else if(!strcmp(tlsversion, "TLSv1.2")) - ++dsc->cs->tls12queries; - else if(!strcmp(tlsversion, "TLSv1.3")) - ++dsc->cs->tls13queries; - else - ++dsc->cs->tlsUnknownqueries; + if (const auto* tlsversion = h2o_socket_get_ssl_protocol_version(sock)) { + if (strcmp(tlsversion, "TLSv1.0") == 0) { + ++dsc->clientState->tls10queries; + } + else if (strcmp(tlsversion, "TLSv1.1") == 0) { + ++dsc->clientState->tls11queries; + } + else if (strcmp(tlsversion, "TLSv1.2") == 0) { + ++dsc->clientState->tls12queries; + } + else if (strcmp(tlsversion, "TLSv1.3") == 0) { + ++dsc->clientState->tls13queries; + } + else { + ++dsc->clientState->tlsUnknownqueries; + } } - if (dsc->df->d_exactPathMatching) { + if (dsc->dohFrontend->d_exactPathMatching) { const std::string_view pathOnly(req->path_normalized.base, req->path_normalized.len); if (dsc->paths.count(pathOnly) == 0) { h2o_send_error_404(req, "Not Found", "there is no endpoint configured for this path", 0); @@ -1056,25 +1114,27 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req) string path(req->path.base, req->path.len); /* the responses map can be updated at runtime, so we need to take a copy of the shared pointer, increasing the reference counter */ - auto responsesMap = dsc->df->d_responsesMap; + auto responsesMap = dsc->dohFrontend->d_responsesMap; /* 1 byte for the root label, 2 type, 2 class, 4 TTL (fake), 2 record length, 2 option length, 2 option code, 2 family, 1 source, 1 scope, 16 max for a full v6 */ const size_t maxAdditionalSizeForEDNS = 35U; if (responsesMap) { for (const auto& entry : *responsesMap) { if (entry->matches(path)) { const auto& customHeaders = entry->getHeaders(); - handleResponse(*dsc->df, req, entry->getStatusCode(), entry->getContent(), customHeaders ? *customHeaders : dsc->df->d_customResponseHeaders, std::string(), false); + handleResponse(*dsc->dohFrontend, req, entry->getStatusCode(), entry->getContent(), customHeaders ? *customHeaders : dsc->dohFrontend->d_customResponseHeaders, std::string(), false); return 0; } } } - if (h2o_memis(req->method.base, req->method.len, H2O_STRLIT("POST"))) { - ++dsc->df->d_postqueries; - if(req->version >= 0x0200) - ++dsc->df->d_http2Stats.d_nbQueries; - else - ++dsc->df->d_http1Stats.d_nbQueries; + if (h2o_memis(req->method.base, req->method.len, H2O_STRLIT("POST")) != 0) { + ++dsc->dohFrontend->d_postqueries; + if (req->version >= 0x0200) { + ++dsc->dohFrontend->d_http2Stats.d_nbQueries; + } + else { + ++dsc->dohFrontend->d_http1Stats.d_nbQueries; + } PacketBuffer query; /* We reserve a few additional bytes to be able to add EDNS later */ @@ -1085,9 +1145,10 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req) } else if(req->query_at != SIZE_MAX && (req->path.len - req->query_at > 5)) { auto pos = path.find("?dns="); - if(pos == string::npos) + if (pos == string::npos) { pos = path.find("&dns="); - if(pos != string::npos) { + } + if (pos != string::npos) { // need to base64url decode this string sdns(path.substr(pos+5)); boost::replace_all(sdns,"-", "+"); @@ -1110,119 +1171,47 @@ static int doh_handler(h2o_handler_t *self, h2o_req_t *req) decoded.reserve(estimate + maxAdditionalSizeForEDNS); if(B64Decode(sdns, decoded) < 0) { h2o_send_error_400(req, "Bad Request", "Unable to decode BASE64-URL", 0); - ++dsc->df->d_badrequests; + ++dsc->dohFrontend->d_badrequests; return 0; } - else { - ++dsc->df->d_getqueries; - if(req->version >= 0x0200) - ++dsc->df->d_http2Stats.d_nbQueries; - else - ++dsc->df->d_http1Stats.d_nbQueries; - doh_dispatch_query(dsc, self, req, std::move(decoded), conn.d_local, remote, std::move(path)); + ++dsc->dohFrontend->d_getqueries; + if (req->version >= 0x0200) { + ++dsc->dohFrontend->d_http2Stats.d_nbQueries; } + else { + ++dsc->dohFrontend->d_http1Stats.d_nbQueries; + } + + doh_dispatch_query(dsc, self, req, std::move(decoded), conn.d_local, remote, std::move(path)); } else { vinfolog("HTTP request without DNS parameter: %s", req->path.base); h2o_send_error_400(req, "Bad Request", "Unable to find the DNS parameter", 0); - ++dsc->df->d_badrequests; + ++dsc->dohFrontend->d_badrequests; return 0; } } else { h2o_send_error_400(req, "Bad Request", "Unable to parse the request", 0); - ++dsc->df->d_badrequests; + ++dsc->dohFrontend->d_badrequests; } return 0; } - catch(const std::exception& e) - { - errlog("DOH Handler function failed with error %s", e.what()); + catch (const std::exception& e) { + vinfolog("DOH Handler function failed with error: '%s'", e.what()); return 0; } } -HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& regex) - : d_header(toLower(header)), d_regex(regex), d_visual("http[" + header+ "] ~ " + regex) -{ -} - -bool HTTPHeaderRule::matches(const DNSQuestion* dq) const -{ - if (!dq->ids.du || !dq->ids.du->headers) { - return false; - } - - for (const auto& header : *dq->ids.du->headers) { - if (header.first == d_header) { - return d_regex.match(header.second); - } - } - return false; -} - -string HTTPHeaderRule::toString() const -{ - return d_visual; -} - -HTTPPathRule::HTTPPathRule(const std::string& path) - : d_path(path) -{ - -} - -bool HTTPPathRule::matches(const DNSQuestion* dq) const -{ - if (!dq->ids.du) { - return false; - } - - if (dq->ids.du->query_at == SIZE_MAX) { - return dq->ids.du->path == d_path; - } - else { - return d_path.compare(0, d_path.size(), dq->ids.du->path, 0, dq->ids.du->query_at) == 0; - } -} - -string HTTPPathRule::toString() const -{ - return "url path == " + d_path; -} - -HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex): d_regex(regex), d_visual("http path ~ " + regex) -{ -} - -bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const +const std::unordered_map<std::string, std::string>& DOHUnit::getHTTPHeaders() const { - if (!dq->ids.du) { - return false; + if (!headers) { + static const HeadersMap empty{}; + return empty; } - - return d_regex.match(dq->ids.du->getHTTPPath()); -} - -string HTTPPathRegexRule::toString() const -{ - return d_visual; -} - -std::unordered_map<std::string, std::string> DOHUnit::getHTTPHeaders() const -{ - std::unordered_map<std::string, std::string> results; - if (headers) { - results.reserve(headers->size()); - - for (const auto& header : *headers) { - results.insert(header); - } - } - - return results; + return *headers; } std::string DOHUnit::getHTTPPath() const @@ -1230,17 +1219,15 @@ std::string DOHUnit::getHTTPPath() const if (query_at == SIZE_MAX) { return path; } - else { - return std::string(path, 0, query_at); - } + return {path, 0, query_at}; } -std::string DOHUnit::getHTTPHost() const +const std::string& DOHUnit::getHTTPHost() const { return host; } -std::string DOHUnit::getHTTPScheme() const +const std::string& DOHUnit::getHTTPScheme() const { return scheme; } @@ -1248,11 +1235,9 @@ std::string DOHUnit::getHTTPScheme() const std::string DOHUnit::getHTTPQueryString() const { if (query_at == SIZE_MAX) { - return std::string(); - } - else { - return path.substr(query_at); + return {}; } + return path.substr(query_at); } void DOHUnit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body_, const std::string& contentType_) @@ -1273,47 +1258,41 @@ void DOHUnit::setHTTPResponse(uint16_t statusCode, PacketBuffer&& body_, const s /* query has been parsed by h2o, which called doh_handler() in the main DoH thread. In order not to block for long, doh_handler() called doh_dispatch_query() which allocated a DOHUnit object and passed it to us */ -static void dnsdistclient(int qsock) +static void dnsdistclient(pdns::channel::Receiver<DOHUnit>&& receiver) { setThreadName("dnsdist/doh-cli"); for(;;) { try { - DOHUnit* ptr = nullptr; - ssize_t got = read(qsock, &ptr, sizeof(ptr)); - if (got < 0) { - warnlog("Error receiving internal DoH query: %s", strerror(errno)); - continue; - } - else if (static_cast<size_t>(got) < sizeof(ptr)) { + auto tmp = receiver.receive(); + if (!tmp) { continue; } - - DOHUnitUniquePtr du(ptr, DOHUnit::release); + auto dohUnit = std::move(*tmp); /* we are not in the main DoH thread anymore, so there is a real risk of a race condition where h2o kills the query while we are processing it, - so we can't touch the content of du->req until we are back into the + so we can't touch the content of dohUnit->req until we are back into the main DoH thread */ - if (!du->req) { + if (dohUnit->req == nullptr) { // it got killed in flight already - du->self = nullptr; + dohUnit->self = nullptr; continue; } - processDOHQuery(std::move(du), false); + processDOHQuery(std::move(dohUnit), false); } catch (const std::exception& e) { - errlog("Error while processing query received over DoH: %s", e.what()); + vinfolog("Error while processing query received over DoH: %s", e.what()); } catch (...) { - errlog("Unspecified error while processing query received over DoH"); + vinfolog("Unspecified error while processing query received over DoH"); } } } #endif /* USE_SINGLE_ACCEPTOR_THREAD */ /* Called in the main DoH thread if h2o finds that dnsdist gave us an answer by writing into - the dohresponsepair[0] side of the pipe so from: + the response channel so from: - handleDOHTimeout() when we did not get a response fast enough (called either from the health check thread (active) or from the frontend ones (reused)) - dnsdistclient (error 500 because processDOHQuery() returned a negative value) @@ -1326,73 +1305,71 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) for the CPU, the first thing we need to do is to send responses to free slots anyway, otherwise queries and responses are piling up in our pipes, consuming memory and likely coming up too late after the client has gone away */ + auto* dsc = static_cast<DOHServerConfig*>(listener->data); while (true) { - DOHUnit *ptr = nullptr; - DOHServerConfig* dsc = reinterpret_cast<DOHServerConfig*>(listener->data); - ssize_t got = read(dsc->dohresponsepair[1], &ptr, sizeof(ptr)); - - if (got < 0) { - if (errno != EWOULDBLOCK && errno != EAGAIN) { - errlog("Error reading a DOH internal response: %s", strerror(errno)); + DOHUnitUniquePtr dohUnit{nullptr}; + try { + auto tmp = dsc->d_responseReceiver.receive(); + if (!tmp) { + return; } - return; + dohUnit = std::move(*tmp); } - else if (static_cast<size_t>(got) != sizeof(ptr)) { - errlog("Error reading a DoH internal response, got %d bytes instead of the expected %d", got, sizeof(ptr)); + catch (const std::exception& e) { + warnlog("Error reading a DOH internal response: %s", e.what()); return; } - DOHUnitUniquePtr du(ptr, DOHUnit::release); - if (!du->req) { // it got killed in flight - du->self = nullptr; + if (dohUnit->req == nullptr) { // it got killed in flight + dohUnit->self = nullptr; continue; } - if (!du->tcp && - du->truncated && - du->query.size() > du->proxyProtocolPayloadSize && - (du->query.size() - du->proxyProtocolPayloadSize) > sizeof(dnsheader)) { + if (!dohUnit->tcp && + dohUnit->truncated && + dohUnit->query.size() > dohUnit->ids.d_proxyProtocolPayloadSize && + (dohUnit->query.size() - dohUnit->ids.d_proxyProtocolPayloadSize) > sizeof(dnsheader)) { /* restoring the original ID */ - dnsheader* queryDH = reinterpret_cast<struct dnsheader*>(du->query.data() + du->proxyProtocolPayloadSize); - queryDH->id = du->ids.origID; - du->ids.forwardedOverUDP = false; - du->tcp = true; - du->truncated = false; - du->response.clear(); + dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize), [oldID=dohUnit->ids.origID](dnsheader& header) { + header.id = oldID; + return true; + }); + dohUnit->ids.forwardedOverUDP = false; + dohUnit->tcp = true; + dohUnit->truncated = false; + dohUnit->response.clear(); - auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(du), false); + auto cpq = std::make_unique<DoHCrossProtocolQuery>(std::move(dohUnit), false); if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { continue; } - else { - vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP"); - continue; - } + vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP"); + continue; } - if (du->self) { + if (dohUnit->self != nullptr) { // we are back in the h2o main thread now, so we don't risk - // a race (h2o killing the query) when accessing du->req anymore - *du->self = nullptr; // so we don't clean up again in on_generator_dispose - du->self = nullptr; + // a race (h2o killing the query) when accessing dohUnit->req anymore + *dohUnit->self = nullptr; // so we don't clean up again in on_generator_dispose + dohUnit->self = nullptr; } - handleResponse(*dsc->df, du->req, du->status_code, du->response, dsc->df->d_customResponseHeaders, du->contentType, true); + handleResponse(*dsc->dohFrontend, dohUnit->req, dohUnit->status_code, dohUnit->response, dsc->dohFrontend->d_customResponseHeaders, dohUnit->contentType, true); } } /* called when a TCP connection has been accepted, the TLS session has not been established */ static void on_accept(h2o_socket_t *listener, const char *err) { - DOHServerConfig* dsc = reinterpret_cast<DOHServerConfig*>(listener->data); - h2o_socket_t *sock = nullptr; + auto* dsc = static_cast<DOHServerConfig*>(listener->data); if (err != nullptr) { return; } - if ((sock = h2o_evloop_socket_accept(listener)) == nullptr) { + h2o_socket_t* sock = h2o_evloop_socket_accept(listener); + if (sock == nullptr) { return; } @@ -1403,27 +1380,35 @@ static void on_accept(h2o_socket_t *listener, const char *err) } ComboAddress remote; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): h2o API if (h2o_socket_getpeername(sock, reinterpret_cast<struct sockaddr*>(&remote)) == 0) { vinfolog("Dropping DoH connection because we could not retrieve the remote host"); h2o_socket_close(sock); return; } + if (dsc->dohFrontend->d_earlyACLDrop && !dsc->dohFrontend->d_trustForwardedForHeader && !dsc->holders.acl->match(remote)) { + ++dnsdist::metrics::g_stats.aclDrops; + vinfolog("Dropping DoH connection from %s because of ACL", remote.toStringWithPort()); + h2o_socket_close(sock); + return; + } + if (!dnsdist::IncomingConcurrentTCPConnectionsManager::accountNewTCPConnection(remote)) { vinfolog("Dropping DoH connection from %s because we have too many from this client already", remote.toStringWithPort()); h2o_socket_close(sock); return; } - auto concurrentConnections = ++dsc->cs->tcpCurrentConnections; - if (dsc->cs->d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > dsc->cs->d_tcpConcurrentConnectionsLimit) { - --dsc->cs->tcpCurrentConnections; + auto concurrentConnections = ++dsc->clientState->tcpCurrentConnections; + if (dsc->clientState->d_tcpConcurrentConnectionsLimit > 0 && concurrentConnections > dsc->clientState->d_tcpConcurrentConnectionsLimit) { + --dsc->clientState->tcpCurrentConnections; h2o_socket_close(sock); return; } - if (concurrentConnections > dsc->cs->tcpMaxConcurrentConnections.load()) { - dsc->cs->tcpMaxConcurrentConnections.store(concurrentConnections); + if (concurrentConnections > dsc->clientState->tcpMaxConcurrentConnections.load()) { + dsc->clientState->tcpMaxConcurrentConnections.store(concurrentConnections); } auto& conn = t_conns[descriptor]; @@ -1438,14 +1423,14 @@ static void on_accept(h2o_socket_t *listener, const char *err) sock->on_close.data = &conn; sock->data = dsc; - ++dsc->df->d_httpconnects; + ++dsc->dohFrontend->d_httpconnects; h2o_accept(conn.d_acceptCtx->get(), sock); } -static int create_listener(std::shared_ptr<DOHServerConfig>& dsc, int fd) +static int create_listener(std::shared_ptr<DOHServerConfig>& dsc, int descriptor) { - auto sock = h2o_evloop_socket_create(dsc->h2o_ctx.loop, fd, H2O_SOCKET_FLAG_DONT_READ); + auto* sock = h2o_evloop_socket_create(dsc->h2o_ctx.loop, descriptor, H2O_SOCKET_FLAG_DONT_READ); sock->data = dsc.get(); h2o_socket_read_start(sock, on_accept); @@ -1458,25 +1443,27 @@ static int ocsp_stapling_callback(SSL* ssl, void* arg) if (ssl == nullptr || arg == nullptr) { return SSL_TLSEXT_ERR_NOACK; } - const auto ocspMap = reinterpret_cast<std::map<int, std::string>*>(arg); + const auto* ocspMap = static_cast<std::map<int, std::string>*>(arg); return libssl_ocsp_stapling_callback(ssl, *ocspMap); } #endif /* DISABLE_OCSP_STAPLING */ #if OPENSSL_VERSION_MAJOR >= 3 -static int ticket_key_callback(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, EVP_MAC_CTX *hctx, int enc) +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays): OpenSSL API +static int ticket_key_callback(SSL* sslContext, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* ivector, EVP_CIPHER_CTX* ectx, EVP_MAC_CTX* hctx, int enc) #else -static int ticket_key_callback(SSL *s, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char *iv, EVP_CIPHER_CTX *ectx, HMAC_CTX *hctx, int enc) +// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays): OpenSSL API +static int ticket_key_callback(SSL *sslContext, unsigned char keyName[TLS_TICKETS_KEY_NAME_SIZE], unsigned char* ivector, EVP_CIPHER_CTX* ectx, HMAC_CTX* hctx, int enc) #endif { - DOHAcceptContext* ctx = reinterpret_cast<DOHAcceptContext*>(libssl_get_ticket_key_callback_data(s)); + auto* ctx = static_cast<DOHAcceptContext*>(libssl_get_ticket_key_callback_data(sslContext)); if (ctx == nullptr || !ctx->d_ticketKeys) { return -1; } ctx->handleTicketsKeyRotation(); - auto ret = libssl_ticket_key_callback(s, *ctx->d_ticketKeys, keyName, iv, ectx, hctx, enc); + auto ret = libssl_ticket_key_callback(sslContext, *ctx->d_ticketKeys, keyName, ivector, ectx, hctx, enc); if (enc == 0) { if (ret == 0) { ++ctx->d_cs->tlsUnknownTicketKey; @@ -1494,7 +1481,7 @@ static void setupTLSContext(DOHAcceptContext& acceptCtx, TLSErrorCounters& counters) { if (tlsConfig.d_ciphers.empty()) { - tlsConfig.d_ciphers = DOH_DEFAULT_CIPHERS; + tlsConfig.d_ciphers = DOH_DEFAULT_CIPHERS.data(); } auto [ctx, warnings] = libssl_init_server_context(tlsConfig, acceptCtx.d_ocspResponses); @@ -1535,91 +1522,29 @@ static void setupTLSContext(DOHAcceptContext& acceptCtx, acceptCtx.loadTicketsKeys(tlsConfig.d_ticketKeyFile); } - auto nativeCtx = acceptCtx.get(); + auto* nativeCtx = acceptCtx.get(); nativeCtx->ssl_ctx = ctx.release(); } static void setupAcceptContext(DOHAcceptContext& ctx, DOHServerConfig& dsc, bool setupTLS) { - auto nativeCtx = ctx.get(); + auto* nativeCtx = ctx.get(); nativeCtx->ctx = &dsc.h2o_ctx; nativeCtx->hosts = dsc.h2o_config.hosts; - ctx.d_ticketsKeyRotationDelay = dsc.df->d_tlsConfig.d_ticketsKeyRotationDelay; + auto dohFrontend = std::atomic_load_explicit(&dsc.dohFrontend, std::memory_order_acquire); + ctx.d_ticketsKeyRotationDelay = dohFrontend->d_tlsContext.d_tlsConfig.d_ticketsKeyRotationDelay; - if (setupTLS && dsc.df->isHTTPS()) { + if (setupTLS && dohFrontend->isHTTPS()) { try { setupTLSContext(ctx, - dsc.df->d_tlsConfig, - dsc.df->d_tlsCounters); + dohFrontend->d_tlsContext.d_tlsConfig, + dohFrontend->d_tlsContext.d_tlsCounters); } catch (const std::runtime_error& e) { - throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dsc.df->d_local.toStringWithPort() + "': " + e.what()); - } - } - ctx.d_cs = dsc.cs; -} - -void DOHFrontend::rotateTicketsKey(time_t now) -{ - if (d_dsc && d_dsc->accept_ctx) { - d_dsc->accept_ctx->rotateTicketsKey(now); - } -} - -void DOHFrontend::loadTicketsKeys(const std::string& keyFile) -{ - if (d_dsc && d_dsc->accept_ctx) { - d_dsc->accept_ctx->loadTicketsKeys(keyFile); - } -} - -void DOHFrontend::handleTicketsKeyRotation() -{ - if (d_dsc && d_dsc->accept_ctx) { - d_dsc->accept_ctx->handleTicketsKeyRotation(); - } -} - -time_t DOHFrontend::getNextTicketsKeyRotation() const -{ - if (d_dsc && d_dsc->accept_ctx) { - return d_dsc->accept_ctx->getNextTicketsKeyRotation(); - } - return 0; -} - -size_t DOHFrontend::getTicketsKeysCount() const -{ - size_t res = 0; - if (d_dsc && d_dsc->accept_ctx) { - res = d_dsc->accept_ctx->getTicketsKeysCount(); - } - return res; -} - -void DOHFrontend::reloadCertificates() -{ - auto newAcceptContext = std::make_shared<DOHAcceptContext>(); - setupAcceptContext(*newAcceptContext, *d_dsc, true); - std::atomic_store_explicit(&d_dsc->accept_ctx, newAcceptContext, std::memory_order_release); -} - -void DOHFrontend::setup() -{ - registerOpenSSLUser(); - - d_dsc = std::make_shared<DOHServerConfig>(d_idleTimeout, d_internalPipeBufferSize); - - if (isHTTPS()) { - try { - setupTLSContext(*d_dsc->accept_ctx, - d_tlsConfig, - d_tlsCounters); - } - catch (const std::runtime_error& e) { - throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_local.toStringWithPort() + "': " + e.what()); + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + "': " + e.what()); } } + ctx.d_cs = dsc.clientState; } static h2o_pathconf_t *register_handler(h2o_hostconf_t *hostconf, const char *path, int (*on_req)(h2o_handler_t *, h2o_req_t *)) @@ -1629,7 +1554,7 @@ static h2o_pathconf_t *register_handler(h2o_hostconf_t *hostconf, const char *pa return pathconf; } h2o_filter_t *filter = h2o_create_filter(pathconf, sizeof(*filter)); - if (filter) { + if (filter != nullptr) { filter->on_setup_ostream = on_response_ready_cb; } @@ -1642,38 +1567,39 @@ static h2o_pathconf_t *register_handler(h2o_hostconf_t *hostconf, const char *pa } // this is the entrypoint from dnsdist.cc -void dohThread(ClientState* cs) +void dohThread(ClientState* clientState) { try { - std::shared_ptr<DOHFrontend>& df = cs->dohFrontend; - auto& dsc = df->d_dsc; - dsc->cs = cs; - dsc->df = cs->dohFrontend; - dsc->h2o_config.server_name = h2o_iovec_init(df->d_serverTokens.c_str(), df->d_serverTokens.size()); + std::shared_ptr<DOHFrontend>& dohFrontend = clientState->dohFrontend; + auto& dsc = dohFrontend->d_dsc; + dsc->clientState = clientState; + std::atomic_store_explicit(&dsc->dohFrontend, clientState->dohFrontend, std::memory_order_release); + dsc->h2o_config.server_name = h2o_iovec_init(dohFrontend->d_serverTokens.c_str(), dohFrontend->d_serverTokens.size()); #ifndef USE_SINGLE_ACCEPTOR_THREAD - std::thread dnsdistThread(dnsdistclient, dsc->dohquerypair[1]); + std::thread dnsdistThread(dnsdistclient, std::move(dsc->d_queryReceiver)); dnsdistThread.detach(); // gets us better error reporting #endif setThreadName("dnsdist/doh"); // I wonder if this registers an IP address.. I think it does // this may mean we need to actually register a site "name" here and not the IP address - h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(df->d_local.toString().c_str(), df->d_local.toString().size()), 65535); + h2o_hostconf_t *hostconf = h2o_config_register_host(&dsc->h2o_config, h2o_iovec_init(dohFrontend->d_tlsContext.d_addr.toString().c_str(), dohFrontend->d_tlsContext.d_addr.toString().size()), 65535); - for(const auto& url : df->d_urls) { + dsc->paths = dohFrontend->d_urls; + for (const auto& url : dsc->paths) { register_handler(hostconf, url.c_str(), doh_handler); - dsc->paths.insert(url); } h2o_context_init(&dsc->h2o_ctx, h2o_evloop_create(), &dsc->h2o_config); // in this complicated way we insert the DOHServerConfig pointer in there h2o_vector_reserve(nullptr, &dsc->h2o_ctx.storage, 1); + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): h2o API dsc->h2o_ctx.storage.entries[0].data = dsc.get(); ++dsc->h2o_ctx.storage.size; - auto sock = h2o_evloop_socket_create(dsc->h2o_ctx.loop, dsc->dohresponsepair[1], H2O_SOCKET_FLAG_DONT_READ); + auto* sock = h2o_evloop_socket_create(dsc->h2o_ctx.loop, dsc->d_responseReceiver.getDescriptor(), H2O_SOCKET_FLAG_DONT_READ); sock->data = dsc.get(); // this listens to responses from dnsdist to turn into http responses @@ -1681,12 +1607,12 @@ void dohThread(ClientState* cs) setupAcceptContext(*dsc->accept_ctx, *dsc, false); - if (create_listener(dsc, cs->tcpFD) != 0) { - throw std::runtime_error("DOH server failed to listen on " + df->d_local.toStringWithPort() + ": " + strerror(errno)); + if (create_listener(dsc, clientState->tcpFD) != 0) { + throw std::runtime_error("DOH server failed to listen on " + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + ": " + stringerror(errno)); } - for (const auto& [addr, fd] : cs->d_additionalAddresses) { - if (create_listener(dsc, fd) != 0) { - throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + df->d_local.toStringWithPort() + ": " + strerror(errno)); + for (const auto& [addr, descriptor] : clientState->d_additionalAddresses) { + if (create_listener(dsc, descriptor) != 0) { + throw std::runtime_error("DOH server failed to listen on additional address " + addr.toStringWithPort() + " for DOH local" + dohFrontend->d_tlsContext.d_addr.toStringWithPort() + ": " + stringerror(errno)); } } @@ -1695,12 +1621,12 @@ void dohThread(ClientState* cs) int result = h2o_evloop_run(dsc->h2o_ctx.loop, INT32_MAX); if (result == -1) { if (errno != EINTR) { - errlog("Error in the DoH event loop: %s", strerror(errno)); + errlog("Error in the DoH event loop: %s", stringerror(errno)); stop = true; } } } - while (stop == false); + while (!stop); } catch (const std::exception& e) { @@ -1711,55 +1637,117 @@ void dohThread(ClientState* cs) } } -void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, InternalQueryState&& state) +void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, InternalQueryState&& state, [[maybe_unused]] const std::shared_ptr<DownstreamState>& downstream_) { - du->response = std::move(udpResponse); - du->ids = std::move(state); + auto dohUnit = std::unique_ptr<DOHUnit>(this); + dohUnit->ids = std::move(state); - const dnsheader* dh = reinterpret_cast<const struct dnsheader*>(du->response.data()); - if (!dh->tc) { + { + dnsheader_aligned dnsHeader(udpResponse.data()); + if (dnsHeader.get()->tc) { + dohUnit->truncated = true; + } + } + if (!dohUnit->truncated) { static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - DNSResponse dr(du->ids, du->response, du->downstream); - dnsheader cleartextDH; - memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); + DNSResponse dnsResponse(dohUnit->ids, udpResponse, dohUnit->downstream); + dnsheader cleartextDH{}; + memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH)); - dr.ids.du = std::move(du); - if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { - if (dr.ids.du) { - dr.ids.du->status_code = 503; - sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + dnsResponse.ids.du = std::move(dohUnit); + if (!processResponse(udpResponse, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) { + if (dnsResponse.ids.du) { + dohUnit = getDUFromIDS(dnsResponse.ids); + dohUnit->status_code = 503; + sendDoHUnitToTheMainThread(std::move(dohUnit), "Response dropped by rules"); } return; } - if (dr.isAsynchronous()) { + if (dnsResponse.isAsynchronous()) { return; } - du = std::move(dr.ids.du); - double udiff = du->ids.queryRealTime.udiff(); - vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); + dohUnit = getDUFromIDS(dnsResponse.ids); + dohUnit->response = std::move(udpResponse); + double udiff = dohUnit->ids.queryRealTime.udiff(); + vinfolog("Got answer from %s, relayed to %s (https), took %f us", dohUnit->downstream->d_config.remote.toStringWithPort(), dohUnit->ids.origRemote.toStringWithPort(), udiff); - handleResponseSent(du->ids, udiff, dr.ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, du->downstream->getProtocol(), true); + handleResponseSent(dohUnit->ids, udiff, dnsResponse.ids.origRemote, dohUnit->downstream->d_config.remote, dohUnit->response.size(), cleartextDH, dohUnit->downstream->getProtocol(), true); - ++g_stats.responses; - if (du->ids.cs) { - ++du->ids.cs->responses; + ++dnsdist::metrics::g_stats.responses; + if (dohUnit->ids.cs != nullptr) { + ++dohUnit->ids.cs->responses; } } - else { - du->truncated = true; + + sendDoHUnitToTheMainThread(std::move(dohUnit), "DoH response"); +} + +void H2ODOHFrontend::rotateTicketsKey(time_t now) +{ + if (d_dsc && d_dsc->accept_ctx) { + d_dsc->accept_ctx->rotateTicketsKey(now); } +} - sendDoHUnitToTheMainThread(std::move(du), "DoH response"); +void H2ODOHFrontend::loadTicketsKeys(const std::string& keyFile) +{ + if (d_dsc && d_dsc->accept_ctx) { + d_dsc->accept_ctx->loadTicketsKeys(keyFile); + } } -#else /* HAVE_DNS_OVER_HTTPS */ +void H2ODOHFrontend::handleTicketsKeyRotation() +{ + if (d_dsc && d_dsc->accept_ctx) { + d_dsc->accept_ctx->handleTicketsKeyRotation(); + } +} + +std::string H2ODOHFrontend::getNextTicketsKeyRotation() const +{ + if (d_dsc && d_dsc->accept_ctx) { + return std::to_string(d_dsc->accept_ctx->getNextTicketsKeyRotation()); + } + return {}; +} + +size_t H2ODOHFrontend::getTicketsKeysCount() +{ + size_t res = 0; + if (d_dsc && d_dsc->accept_ctx) { + res = d_dsc->accept_ctx->getTicketsKeysCount(); + } + return res; +} -void handleDOHTimeout(DOHUnitUniquePtr&& oldDU) +void H2ODOHFrontend::reloadCertificates() { + auto newAcceptContext = std::make_shared<DOHAcceptContext>(); + setupAcceptContext(*newAcceptContext, *d_dsc, true); + std::atomic_store_explicit(&d_dsc->accept_ctx, std::move(newAcceptContext), std::memory_order_release); +} + +void H2ODOHFrontend::setup() +{ + registerOpenSSLUser(); + + d_dsc = std::make_shared<DOHServerConfig>(d_idleTimeout, d_internalPipeBufferSize); + + if (isHTTPS()) { + try { + setupTLSContext(*d_dsc->accept_ctx, + d_tlsContext.d_tlsConfig, + d_tlsContext.d_tlsCounters); + } + catch (const std::runtime_error& e) { + throw std::runtime_error("Error setting up TLS context for DoH listener on '" + d_tlsContext.d_addr.toStringWithPort() + "': " + e.what()); + } + } } +#endif /* HAVE_LIBH2OEVLOOP */ #endif /* HAVE_DNS_OVER_HTTPS */ |