summaryrefslogtreecommitdiffstats
path: root/doh3.cc
diff options
context:
space:
mode:
Diffstat (limited to 'doh3.cc')
-rw-r--r--doh3.cc1031
1 files changed, 1031 insertions, 0 deletions
diff --git a/doh3.cc b/doh3.cc
new file mode 100644
index 0000000..d5216aa
--- /dev/null
+++ b/doh3.cc
@@ -0,0 +1,1031 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "doh3.hh"
+
+#ifdef HAVE_DNS_OVER_HTTP3
+#include <quiche.h>
+
+#include "dolog.hh"
+#include "iputils.hh"
+#include "misc.hh"
+#include "sstuff.hh"
+#include "threadname.hh"
+#include "base64.hh"
+
+#include "dnsdist-dnsparser.hh"
+#include "dnsdist-ecs.hh"
+#include "dnsdist-proxy-protocol.hh"
+#include "dnsdist-tcp.hh"
+#include "dnsdist-random.hh"
+
+#include "doq-common.hh"
+
+#if 0
+#define DEBUGLOG_ENABLED
+#define DEBUGLOG(x) std::cerr << x << std::endl;
+#else
+#define DEBUGLOG(x)
+#endif
+
+using namespace dnsdist::doq;
+
+using h3_headers_t = std::map<std::string, std::string>;
+
+class H3Connection
+{
+public:
+ H3Connection(const ComboAddress& peer, QuicheConfig config, QuicheConnection&& conn) :
+ d_peer(peer), d_conn(std::move(conn)), d_config(std::move(config))
+ {
+ }
+ H3Connection(const H3Connection&) = delete;
+ H3Connection(H3Connection&&) = default;
+ H3Connection& operator=(const H3Connection&) = delete;
+ H3Connection& operator=(H3Connection&&) = default;
+ ~H3Connection() = default;
+
+ ComboAddress d_peer;
+ QuicheConnection d_conn;
+ QuicheConfig d_config;
+ QuicheHTTP3Connection d_http3{nullptr, quiche_h3_conn_free};
+ // buffer request headers by streamID
+ std::unordered_map<uint64_t, h3_headers_t> d_headersBuffers;
+ std::unordered_map<uint64_t, PacketBuffer> d_streamBuffers;
+ std::unordered_map<uint64_t, PacketBuffer> d_streamOutBuffers;
+};
+
+static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description);
+
+struct DOH3ServerConfig
+{
+ DOH3ServerConfig(QuicheConfig&& config_, QuicheHTTP3Config&& http3config_, uint32_t internalPipeBufferSize) :
+ config(std::move(config_)), http3config(std::move(http3config_))
+ {
+ {
+ auto [sender, receiver] = pdns::channel::createObjectQueue<DOH3Unit>(pdns::channel::SenderBlockingMode::SenderNonBlocking, pdns::channel::ReceiverBlockingMode::ReceiverNonBlocking, internalPipeBufferSize);
+ d_responseSender = std::move(sender);
+ d_responseReceiver = std::move(receiver);
+ }
+ }
+ DOH3ServerConfig(const DOH3ServerConfig&) = delete;
+ DOH3ServerConfig(DOH3ServerConfig&&) = default;
+ DOH3ServerConfig& operator=(const DOH3ServerConfig&) = delete;
+ DOH3ServerConfig& operator=(DOH3ServerConfig&&) = default;
+ ~DOH3ServerConfig() = default;
+
+ using ConnectionsMap = std::map<PacketBuffer, H3Connection>;
+
+ LocalHolders holders;
+ ConnectionsMap d_connections;
+ QuicheConfig config;
+ QuicheHTTP3Config http3config;
+ ClientState* clientState{nullptr};
+ std::shared_ptr<DOH3Frontend> df{nullptr};
+ pdns::channel::Sender<DOH3Unit> d_responseSender;
+ pdns::channel::Receiver<DOH3Unit> d_responseReceiver;
+};
+
+/* these might seem useless, but they are needed because
+ they need to be declared _after_ the definition of DOH3ServerConfig
+ so that we can use a unique_ptr in DOH3Frontend */
+DOH3Frontend::DOH3Frontend() = default;
+DOH3Frontend::~DOH3Frontend() = default;
+
+class DOH3TCPCrossQuerySender final : public TCPQuerySender
+{
+public:
+ DOH3TCPCrossQuerySender() = default;
+
+ [[nodiscard]] bool active() const override
+ {
+ return true;
+ }
+
+ void handleResponse([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override
+ {
+ if (!response.d_idstate.doh3u) {
+ return;
+ }
+
+ auto unit = std::move(response.d_idstate.doh3u);
+ if (unit->dsc == nullptr) {
+ return;
+ }
+
+ unit->response = std::move(response.d_buffer);
+ unit->ids = std::move(response.d_idstate);
+ DNSResponse dnsResponse(unit->ids, unit->response, unit->downstream);
+
+ dnsheader cleartextDH{};
+ memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH));
+
+ if (!response.isAsync()) {
+
+ static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
+ static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal();
+
+ dnsResponse.ids.doh3u = std::move(unit);
+
+ if (!processResponse(dnsResponse.ids.doh3u->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) {
+ if (dnsResponse.ids.doh3u) {
+
+ sendBackDOH3Unit(std::move(dnsResponse.ids.doh3u), "Response dropped by rules");
+ }
+ return;
+ }
+
+ if (dnsResponse.isAsynchronous()) {
+ return;
+ }
+
+ unit = std::move(dnsResponse.ids.doh3u);
+ }
+
+ if (!unit->ids.selfGenerated) {
+ double udiff = unit->ids.queryRealTime.udiff();
+ vinfolog("Got answer from %s, relayed to %s (DoH3, %d bytes), took %f us", unit->downstream->d_config.remote.toStringWithPort(), unit->ids.origRemote.toStringWithPort(), unit->response.size(), udiff);
+
+ auto backendProtocol = unit->downstream->getProtocol();
+ if (backendProtocol == dnsdist::Protocol::DoUDP && unit->tcp) {
+ backendProtocol = dnsdist::Protocol::DoTCP;
+ }
+ handleResponseSent(unit->ids, udiff, unit->ids.origRemote, unit->downstream->d_config.remote, unit->response.size(), cleartextDH, backendProtocol, true);
+ }
+
+ ++dnsdist::metrics::g_stats.responses;
+ if (unit->ids.cs != nullptr) {
+ ++unit->ids.cs->responses;
+ }
+
+ sendBackDOH3Unit(std::move(unit), "Cross-protocol response");
+ }
+
+ void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override
+ {
+ return handleResponse(now, std::move(response));
+ }
+
+ void notifyIOError([[maybe_unused]] const struct timeval& now, TCPResponse&& response) override
+ {
+ if (!response.d_idstate.doh3u) {
+ return;
+ }
+
+ auto unit = std::move(response.d_idstate.doh3u);
+ if (unit->dsc == nullptr) {
+ return;
+ }
+
+ /* this will signal an error */
+ unit->response.clear();
+ unit->ids = std::move(response.d_idstate);
+ sendBackDOH3Unit(std::move(unit), "Cross-protocol error");
+ }
+};
+
+class DOH3CrossProtocolQuery : public CrossProtocolQuery
+{
+public:
+ DOH3CrossProtocolQuery(DOH3UnitUniquePtr&& unit, bool isResponse)
+ {
+ if (isResponse) {
+ /* happens when a response becomes async */
+ query = InternalQuery(std::move(unit->response), std::move(unit->ids));
+ }
+ else {
+ /* we need to duplicate the query here because we might need
+ the existing query later if we get a truncated answer */
+ query = InternalQuery(PacketBuffer(unit->query), std::move(unit->ids));
+ }
+
+ /* it might have been moved when we moved unit->ids */
+ if (unit) {
+ query.d_idstate.doh3u = std::move(unit);
+ }
+
+ /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload,
+ clearing query.d_proxyProtocolPayloadAdded and unit->proxyProtocolPayloadSize.
+ Leave it for now because we know that the onky case where the payload has been
+ added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT,
+ and we know the TCP/DoT code can handle it. */
+ query.d_proxyProtocolPayloadAdded = query.d_idstate.doh3u->proxyProtocolPayloadSize > 0;
+ downstream = query.d_idstate.doh3u->downstream;
+ }
+
+ void handleInternalError()
+ {
+ sendBackDOH3Unit(std::move(query.d_idstate.doh3u), "DOH3 internal error");
+ }
+
+ std::shared_ptr<TCPQuerySender> getTCPQuerySender() override
+ {
+ query.d_idstate.doh3u->downstream = downstream;
+ return s_sender;
+ }
+
+ DNSQuestion getDQ() override
+ {
+ auto& ids = query.d_idstate;
+ DNSQuestion dnsQuestion(ids, query.d_buffer);
+ return dnsQuestion;
+ }
+
+ DNSResponse getDR() override
+ {
+ auto& ids = query.d_idstate;
+ DNSResponse dnsResponse(ids, query.d_buffer, downstream);
+ return dnsResponse;
+ }
+
+ DOH3UnitUniquePtr&& releaseDU()
+ {
+ return std::move(query.d_idstate.doh3u);
+ }
+
+private:
+ static std::shared_ptr<DOH3TCPCrossQuerySender> s_sender;
+};
+
+std::shared_ptr<DOH3TCPCrossQuerySender> DOH3CrossProtocolQuery::s_sender = std::make_shared<DOH3TCPCrossQuerySender>();
+
+static bool tryWriteResponse(H3Connection& conn, const uint64_t streamID, PacketBuffer& response)
+{
+ size_t pos = 0;
+ while (pos < response.size()) {
+ // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
+ auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
+ streamID, &response.at(pos), response.size() - pos, true);
+ if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
+ response.erase(response.begin(), response.begin() + static_cast<ssize_t>(pos));
+ return false;
+ }
+ if (res < 0) {
+ // Shutdown with internal error code
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+ return true;
+ }
+ pos += res;
+ }
+
+ return true;
+}
+
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const uint8_t* body, size_t len)
+{
+ std::string status = std::to_string(statusCode);
+ std::string lenStr = std::to_string(len);
+ std::array<quiche_h3_header, 3> headers{
+ (quiche_h3_header){
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ .name = reinterpret_cast<const uint8_t*>(":status"),
+ .name_len = sizeof(":status") - 1,
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ .value = reinterpret_cast<const uint8_t*>(status.data()),
+ .value_len = status.size(),
+ },
+ (quiche_h3_header){
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ .name = reinterpret_cast<const uint8_t*>("content-length"),
+ .name_len = sizeof("content-length") - 1,
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ .value = reinterpret_cast<const uint8_t*>(lenStr.data()),
+ .value_len = lenStr.size(),
+ },
+ (quiche_h3_header){
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ .name = reinterpret_cast<const uint8_t*>("content-type"),
+ .name_len = sizeof("content-type") - 1,
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ .value = reinterpret_cast<const uint8_t*>("application/dns-message"),
+ .value_len = sizeof("application/dns-message") - 1,
+ },
+ };
+ auto returnValue = quiche_h3_send_response(conn.d_http3.get(), conn.d_conn.get(),
+ streamID, headers.data(),
+ // do not include content-type header info if there is no content
+ (len > 0 && statusCode == 200U ? headers.size() : headers.size() - 1),
+ len == 0);
+ if (returnValue != 0) {
+ /* in theory it could be QUICHE_H3_ERR_STREAM_BLOCKED if the stream is not writable / congested, but we are not going to handle this case */
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(dnsdist::doq::DOQ_Error_Codes::DOQ_INTERNAL_ERROR));
+ return;
+ }
+
+ if (len == 0) {
+ return;
+ }
+
+ size_t pos = 0;
+ while (pos < len) {
+ // send_body takes care of setting fin to false if it cannot send the entire content so we can try again.
+ auto res = quiche_h3_send_body(conn.d_http3.get(), conn.d_conn.get(),
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
+ streamID, const_cast<uint8_t*>(body) + pos, len - pos, true);
+ if (res == QUICHE_H3_ERR_DONE || res == QUICHE_H3_TRANSPORT_ERR_DONE) {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-bounds-pointer-arithmetic): Quiche API
+ conn.d_streamOutBuffers[streamID] = PacketBuffer(body + pos, body + len);
+ return;
+ }
+ if (res < 0) {
+ // Shutdown with internal error code
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(1));
+ return;
+ }
+ pos += res;
+ }
+}
+
+static void h3_send_response(H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const std::string& content)
+{
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ h3_send_response(conn, streamID, statusCode, reinterpret_cast<const uint8_t*>(content.data()), content.size());
+}
+
+static void handleResponse(DOH3Frontend& frontend, H3Connection& conn, const uint64_t streamID, uint16_t statusCode, const PacketBuffer& response)
+{
+ if (statusCode == 200) {
+ ++frontend.d_validResponses;
+ }
+ else {
+ ++frontend.d_errorResponses;
+ }
+ if (response.empty()) {
+ quiche_conn_stream_shutdown(conn.d_conn.get(), streamID, QUICHE_SHUTDOWN_WRITE, static_cast<uint64_t>(DOQ_Error_Codes::DOQ_UNSPECIFIED_ERROR));
+ }
+ else {
+ h3_send_response(conn, streamID, statusCode, &response.at(0), response.size());
+ }
+}
+
+void DOH3Frontend::setup()
+{
+ auto config = QuicheConfig(quiche_config_new(QUICHE_PROTOCOL_VERSION), quiche_config_free);
+ d_quicheParams.d_alpn = std::string(DOH3_ALPN.begin(), DOH3_ALPN.end());
+ configureQuiche(config, d_quicheParams, true);
+
+ auto http3config = QuicheHTTP3Config(quiche_h3_config_new(), quiche_h3_config_free);
+
+ d_server_config = std::make_unique<DOH3ServerConfig>(std::move(config), std::move(http3config), d_internalPipeBufferSize);
+}
+
+void DOH3Frontend::reloadCertificates()
+{
+ auto config = QuicheConfig(quiche_config_new(QUICHE_PROTOCOL_VERSION), quiche_config_free);
+ d_quicheParams.d_alpn = std::string(DOH3_ALPN.begin(), DOH3_ALPN.end());
+ configureQuiche(config, d_quicheParams, true);
+ std::atomic_store_explicit(&d_server_config->config, std::move(config), std::memory_order_release);
+}
+
+static std::optional<std::reference_wrapper<H3Connection>> getConnection(DOH3ServerConfig::ConnectionsMap& connMap, const PacketBuffer& connID)
+{
+ auto iter = connMap.find(connID);
+ if (iter == connMap.end()) {
+ return std::nullopt;
+ }
+ return iter->second;
+}
+
+static void sendBackDOH3Unit(DOH3UnitUniquePtr&& unit, const char* description)
+{
+ if (unit->dsc == nullptr) {
+ return;
+ }
+ try {
+ if (!unit->dsc->d_responseSender.send(std::move(unit))) {
+ ++dnsdist::metrics::g_stats.doh3ResponsePipeFull;
+ vinfolog("Unable to pass a %s to the DoH3 worker thread because the pipe is full", description);
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Unable to pass a %s to the DoH3 worker thread because we couldn't write to the pipe: %s", description, e.what());
+ }
+}
+
+static std::optional<std::reference_wrapper<H3Connection>> createConnection(DOH3ServerConfig& config, const PacketBuffer& serverSideID, const PacketBuffer& originalDestinationID, const ComboAddress& local, const ComboAddress& peer)
+{
+ auto quicheConfig = std::atomic_load_explicit(&config.config, std::memory_order_acquire);
+ auto quicheConn = QuicheConnection(quiche_accept(serverSideID.data(), serverSideID.size(),
+ originalDestinationID.data(), originalDestinationID.size(),
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ reinterpret_cast<const struct sockaddr*>(&local),
+ local.getSocklen(),
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ reinterpret_cast<const struct sockaddr*>(&peer),
+ peer.getSocklen(),
+ quicheConfig.get()),
+ quiche_conn_free);
+
+ if (config.df && !config.df->d_quicheParams.d_keyLogFile.empty()) {
+ quiche_conn_set_keylog_path(quicheConn.get(), config.df->d_quicheParams.d_keyLogFile.c_str());
+ }
+
+ auto conn = H3Connection(peer, std::move(quicheConfig), std::move(quicheConn));
+ auto pair = config.d_connections.emplace(serverSideID, std::move(conn));
+ return pair.first->second;
+}
+
+std::unique_ptr<CrossProtocolQuery> getDOH3CrossProtocolQueryFromDQ(DNSQuestion& dnsQuestion, bool isResponse)
+{
+ if (!dnsQuestion.ids.doh3u) {
+ throw std::runtime_error("Trying to create a DoH3 cross protocol query without a valid DoH3 unit");
+ }
+
+ auto unit = std::move(dnsQuestion.ids.doh3u);
+ if (&dnsQuestion.ids != &unit->ids) {
+ unit->ids = std::move(dnsQuestion.ids);
+ }
+
+ unit->ids.origID = dnsQuestion.getHeader()->id;
+
+ if (!isResponse) {
+ if (unit->query.data() != dnsQuestion.getMutableData().data()) {
+ unit->query = std::move(dnsQuestion.getMutableData());
+ }
+ }
+ else {
+ if (unit->response.data() != dnsQuestion.getMutableData().data()) {
+ unit->response = std::move(dnsQuestion.getMutableData());
+ }
+ }
+
+ return std::make_unique<DOH3CrossProtocolQuery>(std::move(unit), isResponse);
+}
+
+static void processDOH3Query(DOH3UnitUniquePtr&& doh3Unit)
+{
+ const auto handleImmediateResponse = [](DOH3UnitUniquePtr&& unit, [[maybe_unused]] const char* reason) {
+ DEBUGLOG("handleImmediateResponse() reason=" << reason);
+ auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
+ handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
+ unit->ids.doh3u.reset();
+ };
+
+ auto& ids = doh3Unit->ids;
+ ids.doh3u = std::move(doh3Unit);
+ auto& unit = ids.doh3u;
+ uint16_t queryId = 0;
+ ComboAddress remote;
+
+ try {
+
+ remote = unit->ids.origRemote;
+ DOH3ServerConfig* dsc = unit->dsc;
+ auto& holders = dsc->holders;
+ ClientState& clientState = *dsc->clientState;
+
+ if (!holders.acl->match(remote)) {
+ vinfolog("Query from %s (DoH3) dropped because of ACL", remote.toStringWithPort());
+ ++dnsdist::metrics::g_stats.aclDrops;
+ unit->response.clear();
+
+ unit->status_code = 403;
+ handleImmediateResponse(std::move(unit), "DoH3 query dropped because of ACL");
+ return;
+ }
+
+ if (unit->query.size() < sizeof(dnsheader)) {
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState.nonCompliantQueries;
+ unit->response.clear();
+
+ unit->status_code = 400;
+ handleImmediateResponse(std::move(unit), "DoH3 non-compliant query");
+ return;
+ }
+
+ ++clientState.queries;
+ ++dnsdist::metrics::g_stats.queries;
+ unit->ids.queryRealTime.start();
+
+ {
+ /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
+ dnsheader_aligned dnsHeader(unit->query.data());
+
+ if (!checkQueryHeaders(*dnsHeader, clientState)) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+ header.rcode = RCode::ServFail;
+ header.qr = true;
+ return true;
+ });
+ unit->response = std::move(unit->query);
+
+ unit->status_code = 400;
+ handleImmediateResponse(std::move(unit), "DoH3 invalid headers");
+ return;
+ }
+
+ if (dnsHeader->qdcount == 0) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+ header.rcode = RCode::NotImp;
+ header.qr = true;
+ return true;
+ });
+ unit->response = std::move(unit->query);
+
+ unit->status_code = 400;
+ handleImmediateResponse(std::move(unit), "DoH3 empty query");
+ return;
+ }
+
+ queryId = ntohs(dnsHeader->id);
+ }
+
+ auto downstream = unit->downstream;
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ unit->ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), sizeof(dnsheader), false, &unit->ids.qtype, &unit->ids.qclass);
+ DNSQuestion dnsQuestion(unit->ids, unit->query);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) {
+ const uint16_t* flags = getFlagsFromDNSHeader(&header);
+ ids.origFlags = *flags;
+ return true;
+ });
+ unit->ids.cs = &clientState;
+
+ auto result = processQuery(dnsQuestion, holders, downstream);
+ if (result == ProcessQueryResult::Drop) {
+ unit->status_code = 403;
+ handleImmediateResponse(std::move(unit), "DoH3 dropped query");
+ return;
+ }
+ if (result == ProcessQueryResult::Asynchronous) {
+ return;
+ }
+ if (result == ProcessQueryResult::SendAnswer) {
+ if (unit->response.empty()) {
+ unit->response = std::move(unit->query);
+ }
+ if (unit->response.size() >= sizeof(dnsheader)) {
+ const dnsheader_aligned dnsHeader(unit->response.data());
+
+ handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dnsHeader, dnsdist::Protocol::DoH3, dnsdist::Protocol::DoH3, false);
+ }
+ handleImmediateResponse(std::move(unit), "DoH3 self-answered response");
+ return;
+ }
+
+ ++dnsdist::metrics::g_stats.responses;
+ if (unit->ids.cs != nullptr) {
+ ++unit->ids.cs->responses;
+ }
+
+ if (result != ProcessQueryResult::PassToBackend) {
+ unit->status_code = 500;
+ handleImmediateResponse(std::move(unit), "DoH3 no backend available");
+ return;
+ }
+
+ if (downstream == nullptr) {
+ unit->status_code = 502;
+ handleImmediateResponse(std::move(unit), "DoH3 no backend available");
+ return;
+ }
+
+ unit->downstream = downstream;
+
+ std::string proxyProtocolPayload;
+ /* we need to do this _before_ creating the cross protocol query because
+ after that the buffer will have been moved */
+ if (downstream->d_config.useProxyProtocol) {
+ proxyProtocolPayload = getProxyProtocolPayload(dnsQuestion);
+ }
+
+ unit->ids.origID = htons(queryId);
+ unit->tcp = true;
+
+ /* this moves unit->ids, careful! */
+ auto cpq = std::make_unique<DOH3CrossProtocolQuery>(std::move(unit), false);
+ cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload);
+
+ if (downstream->passCrossProtocolQuery(std::move(cpq))) {
+ return;
+ }
+ // NOLINTNEXTLINE(bugprone-use-after-move): it was only moved if the call succeeded
+ unit = cpq->releaseDU();
+ unit->status_code = 500;
+ handleImmediateResponse(std::move(unit), "DoH3 internal error");
+ return;
+ }
+ catch (const std::exception& e) {
+ vinfolog("Got an error in DOH3 question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what());
+ unit->status_code = 500;
+ handleImmediateResponse(std::move(unit), "DoH3 internal error");
+ return;
+ }
+}
+
+static void doh3_dispatch_query(DOH3ServerConfig& dsc, PacketBuffer&& query, const ComboAddress& local, const ComboAddress& remote, const PacketBuffer& serverConnID, const uint64_t streamID)
+{
+ try {
+ auto unit = std::make_unique<DOH3Unit>(std::move(query));
+ unit->dsc = &dsc;
+ unit->ids.origDest = local;
+ unit->ids.origRemote = remote;
+ unit->ids.protocol = dnsdist::Protocol::DoH3;
+ unit->serverConnID = serverConnID;
+ unit->streamID = streamID;
+
+ processDOH3Query(std::move(unit));
+ }
+ catch (const std::exception& exp) {
+ vinfolog("Had error handling DoH3 DNS packet from %s: %s", remote.toStringWithPort(), exp.what());
+ }
+}
+
+static void flushResponses(pdns::channel::Receiver<DOH3Unit>& receiver)
+{
+ for (;;) {
+ try {
+ auto tmp = receiver.receive();
+ if (!tmp) {
+ return;
+ }
+
+ auto unit = std::move(*tmp);
+ auto conn = getConnection(unit->dsc->df->d_server_config->d_connections, unit->serverConnID);
+ if (conn) {
+ handleResponse(*unit->dsc->df, *conn, unit->streamID, unit->status_code, unit->response);
+ }
+ }
+ catch (const std::exception& e) {
+ errlog("Error while processing response received over DoH3: %s", e.what());
+ }
+ catch (...) {
+ errlog("Unspecified error while processing response received over DoH3");
+ }
+ }
+}
+
+static void flushStalledResponses(H3Connection& conn)
+{
+ for (auto streamIt = conn.d_streamOutBuffers.begin(); streamIt != conn.d_streamOutBuffers.end();) {
+ const auto streamID = streamIt->first;
+ auto& response = streamIt->second;
+ if (quiche_conn_stream_writable(conn.d_conn.get(), streamID, response.size()) == 1) {
+ if (tryWriteResponse(conn, streamID, response)) {
+ streamIt = conn.d_streamOutBuffers.erase(streamIt);
+ continue;
+ }
+ }
+ ++streamIt;
+ }
+}
+
+static void processH3HeaderEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, const uint64_t streamID, quiche_h3_event* event)
+{
+ auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
+ DEBUGLOG(msg);
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState.nonCompliantQueries;
+ ++frontend.d_errorResponses;
+ h3_send_response(conn, streamID, 400, msg);
+ };
+
+ auto& headers = conn.d_headersBuffers.at(streamID);
+ // Callback result. Any value other than 0 will interrupt further header processing.
+ int cbresult = quiche_h3_event_for_each_header(
+ event,
+ [](uint8_t* name, size_t name_len, uint8_t* value, size_t value_len, void* argp) -> int {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ std::string_view key(reinterpret_cast<char*>(name), name_len);
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ std::string_view content(reinterpret_cast<char*>(value), value_len);
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): Quiche API
+ auto* headersptr = reinterpret_cast<h3_headers_t*>(argp);
+ headersptr->emplace(key, content);
+ return 0;
+ },
+ &headers);
+
+#ifdef DEBUGLOG_ENABLED
+ DEBUGLOG("Processed headers of stream " << streamID);
+ for (const auto& [key, value] : headers) {
+ DEBUGLOG(" " << key << ": " << value);
+ }
+#endif
+ if (cbresult != 0 || headers.count(":method") == 0) {
+ handleImmediateError("Unable to process query headers");
+ return;
+ }
+
+ if (headers.at(":method") == "GET") {
+ if (headers.count(":path") == 0 || headers.at(":path").empty()) {
+ handleImmediateError("Path not found");
+ return;
+ }
+ const auto& path = headers.at(":path");
+ auto payload = dnsdist::doh::getPayloadFromPath(path);
+ if (!payload) {
+ handleImmediateError("Unable to find the DNS parameter");
+ return;
+ }
+ if (payload->size() < sizeof(dnsheader)) {
+ handleImmediateError("DoH3 non-compliant query");
+ return;
+ }
+ DEBUGLOG("Dispatching GET query");
+ doh3_dispatch_query(*(frontend.d_server_config), std::move(*payload), clientState.local, client, serverConnID, streamID);
+ conn.d_streamBuffers.erase(streamID);
+ conn.d_headersBuffers.erase(streamID);
+ return;
+ }
+
+ if (headers.at(":method") == "POST") {
+ if (!quiche_h3_event_headers_has_body(event)) {
+ handleImmediateError("Empty POST query");
+ }
+ return;
+ }
+
+ handleImmediateError("Unsupported HTTP method");
+}
+
+static void processH3DataEvent(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, const uint64_t streamID, quiche_h3_event* event, PacketBuffer& buffer)
+{
+ auto handleImmediateError = [&clientState, &frontend, &conn, streamID](const char* msg) {
+ DEBUGLOG(msg);
+ ++dnsdist::metrics::g_stats.nonCompliantQueries;
+ ++clientState.nonCompliantQueries;
+ ++frontend.d_errorResponses;
+ h3_send_response(conn, streamID, 400, msg);
+ };
+ auto& headers = conn.d_headersBuffers.at(streamID);
+
+ if (headers.at(":method") != "POST") {
+ handleImmediateError("DATA frame for non-POST method");
+ return;
+ }
+
+ if (headers.count("content-type") == 0 || headers.at("content-type") != "application/dns-message") {
+ handleImmediateError("Unsupported content-type");
+ return;
+ }
+
+ buffer.resize(std::numeric_limits<uint16_t>::max());
+ auto& streamBuffer = conn.d_streamBuffers[streamID];
+
+ while (true) {
+ buffer.resize(std::numeric_limits<uint16_t>::max());
+ ssize_t len = quiche_h3_recv_body(conn.d_http3.get(),
+ conn.d_conn.get(), streamID,
+ buffer.data(), buffer.size());
+
+ if (len <= 0) {
+ break;
+ }
+
+ buffer.resize(static_cast<size_t>(len));
+ streamBuffer.insert(streamBuffer.end(), buffer.begin(), buffer.end());
+ }
+
+ if (!quiche_conn_stream_finished(conn.d_conn.get(), streamID)) {
+ return;
+ }
+
+ if (streamBuffer.size() < sizeof(dnsheader)) {
+ conn.d_streamBuffers.erase(streamID);
+ handleImmediateError("DoH3 non-compliant query");
+ return;
+ }
+
+ DEBUGLOG("Dispatching POST query");
+ doh3_dispatch_query(*(frontend.d_server_config), std::move(streamBuffer), clientState.local, client, serverConnID, streamID);
+ conn.d_headersBuffers.erase(streamID);
+ conn.d_streamBuffers.erase(streamID);
+}
+
+static void processH3Events(ClientState& clientState, DOH3Frontend& frontend, H3Connection& conn, const ComboAddress& client, const PacketBuffer& serverConnID, PacketBuffer& buffer)
+{
+ while (true) {
+ quiche_h3_event* event{nullptr};
+ // Processes HTTP/3 data received from the peer
+ const int64_t streamID = quiche_h3_conn_poll(conn.d_http3.get(),
+ conn.d_conn.get(),
+ &event);
+
+ if (streamID < 0) {
+ break;
+ }
+ conn.d_headersBuffers.try_emplace(streamID, h3_headers_t{});
+
+ switch (quiche_h3_event_type(event)) {
+ case QUICHE_H3_EVENT_HEADERS: {
+ processH3HeaderEvent(clientState, frontend, conn, client, serverConnID, streamID, event);
+ break;
+ }
+ case QUICHE_H3_EVENT_DATA: {
+ processH3DataEvent(clientState, frontend, conn, client, serverConnID, streamID, event, buffer);
+ break;
+ }
+ case QUICHE_H3_EVENT_FINISHED:
+ case QUICHE_H3_EVENT_RESET:
+ case QUICHE_H3_EVENT_PRIORITY_UPDATE:
+ case QUICHE_H3_EVENT_GOAWAY:
+ break;
+ }
+
+ quiche_h3_event_free(event);
+ }
+}
+
+static void handleSocketReadable(DOH3Frontend& frontend, ClientState& clientState, Socket& sock, PacketBuffer& buffer)
+{
+ // destination connection ID, will have to be sent as original destination connection ID
+ PacketBuffer serverConnID;
+ // source connection ID, will have to be sent as destination connection ID
+ PacketBuffer clientConnID;
+ PacketBuffer tokenBuf;
+ while (true) {
+ ComboAddress client;
+ buffer.resize(4096);
+ if (!sock.recvFromAsync(buffer, client) || buffer.empty()) {
+ return;
+ }
+ DEBUGLOG("Received DoH3 datagram of size " << buffer.size() << " from " << client.toStringWithPort());
+
+ uint32_t version{0};
+ uint8_t type{0};
+ std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> scid{};
+ size_t scid_len = scid.size();
+ std::array<uint8_t, QUICHE_MAX_CONN_ID_LEN> dcid{};
+ size_t dcid_len = dcid.size();
+ std::array<uint8_t, MAX_TOKEN_LEN> token{};
+ size_t token_len = token.size();
+
+ auto res = quiche_header_info(buffer.data(), buffer.size(), LOCAL_CONN_ID_LEN,
+ &version, &type,
+ scid.data(), &scid_len,
+ dcid.data(), &dcid_len,
+ token.data(), &token_len);
+ if (res != 0) {
+ DEBUGLOG("Error in quiche_header_info: " << res);
+ continue;
+ }
+
+ serverConnID.assign(dcid.begin(), dcid.begin() + dcid_len);
+ // source connection ID, will have to be sent as destination connection ID
+ clientConnID.assign(scid.begin(), scid.begin() + scid_len);
+ auto conn = getConnection(frontend.d_server_config->d_connections, serverConnID);
+
+ if (!conn) {
+ DEBUGLOG("Connection not found");
+ if (type != static_cast<uint8_t>(DOQ_Packet_Types::QUIC_PACKET_TYPE_INITIAL)) {
+ DEBUGLOG("Packet is not initial");
+ continue;
+ }
+
+ if (!quiche_version_is_supported(version)) {
+ DEBUGLOG("Unsupported version");
+ ++frontend.d_doh3UnsupportedVersionErrors;
+ handleVersionNegociation(sock, clientConnID, serverConnID, client, buffer);
+ continue;
+ }
+
+ if (token_len == 0) {
+ /* stateless retry */
+ DEBUGLOG("No token received");
+ handleStatelessRetry(sock, clientConnID, serverConnID, client, version, buffer);
+ continue;
+ }
+
+ tokenBuf.assign(token.begin(), token.begin() + token_len);
+ auto originalDestinationID = validateToken(tokenBuf, client);
+ if (!originalDestinationID) {
+ ++frontend.d_doh3InvalidTokensReceived;
+ DEBUGLOG("Discarding invalid token");
+ continue;
+ }
+
+ DEBUGLOG("Creating a new connection");
+ conn = createConnection(*frontend.d_server_config, serverConnID, *originalDestinationID, clientState.local, client);
+ if (!conn) {
+ continue;
+ }
+ }
+ DEBUGLOG("Connection found");
+ quiche_recv_info recv_info = {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ reinterpret_cast<struct sockaddr*>(&client),
+ client.getSocklen(),
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ reinterpret_cast<struct sockaddr*>(&clientState.local),
+ clientState.local.getSocklen(),
+ };
+
+ auto done = quiche_conn_recv(conn->get().d_conn.get(), buffer.data(), buffer.size(), &recv_info);
+ if (done < 0) {
+ continue;
+ }
+
+ if (quiche_conn_is_established(conn->get().d_conn.get()) || quiche_conn_is_in_early_data(conn->get().d_conn.get())) {
+ DEBUGLOG("Connection is established");
+
+ if (!conn->get().d_http3) {
+ conn->get().d_http3 = QuicheHTTP3Connection(quiche_h3_conn_new_with_transport(conn->get().d_conn.get(), frontend.d_server_config->http3config.get()),
+ quiche_h3_conn_free);
+ if (!conn->get().d_http3) {
+ continue;
+ }
+ DEBUGLOG("Successfully created HTTP/3 connection");
+ }
+
+ processH3Events(clientState, frontend, conn->get(), client, serverConnID, buffer);
+
+ flushEgress(sock, conn->get().d_conn, client, buffer);
+ }
+ else {
+ DEBUGLOG("Connection not established");
+ }
+ }
+}
+
+// this is the entrypoint from dnsdist.cc
+void doh3Thread(ClientState* clientState)
+{
+ try {
+ std::shared_ptr<DOH3Frontend>& frontend = clientState->doh3Frontend;
+
+ frontend->d_server_config->clientState = clientState;
+ frontend->d_server_config->df = clientState->doh3Frontend;
+
+ setThreadName("dnsdist/doh3");
+
+ Socket sock(clientState->udpFD);
+ sock.setNonBlocking();
+
+ auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent());
+
+ auto responseReceiverFD = frontend->d_server_config->d_responseReceiver.getDescriptor();
+ mplexer->addReadFD(sock.getHandle(), [](int, FDMultiplexer::funcparam_t&) {});
+ mplexer->addReadFD(responseReceiverFD, [](int, FDMultiplexer::funcparam_t&) {});
+ std::vector<int> readyFDs;
+ PacketBuffer buffer(4096);
+ while (true) {
+ readyFDs.clear();
+ mplexer->getAvailableFDs(readyFDs, 500);
+
+ try {
+ if (std::find(readyFDs.begin(), readyFDs.end(), sock.getHandle()) != readyFDs.end()) {
+ handleSocketReadable(*frontend, *clientState, sock, buffer);
+ }
+
+ if (std::find(readyFDs.begin(), readyFDs.end(), responseReceiverFD) != readyFDs.end()) {
+ flushResponses(frontend->d_server_config->d_responseReceiver);
+ }
+
+ for (auto conn = frontend->d_server_config->d_connections.begin(); conn != frontend->d_server_config->d_connections.end();) {
+ quiche_conn_on_timeout(conn->second.d_conn.get());
+
+ flushEgress(sock, conn->second.d_conn, conn->second.d_peer, buffer);
+
+ if (quiche_conn_is_closed(conn->second.d_conn.get())) {
+#ifdef DEBUGLOG_ENABLED
+ quiche_stats stats;
+ quiche_path_stats path_stats;
+
+ quiche_conn_stats(conn->second.d_conn.get(), &stats);
+ quiche_conn_path_stats(conn->second.d_conn.get(), 0, &path_stats);
+
+ DEBUGLOG("Connection (DoH3) closed, recv=" << stats.recv << " sent=" << stats.sent << " lost=" << stats.lost << " rtt=" << path_stats.rtt << "ns cwnd=" << path_stats.cwnd);
+#endif
+ conn = frontend->d_server_config->d_connections.erase(conn);
+ }
+ else {
+ flushStalledResponses(conn->second);
+ ++conn;
+ }
+ }
+ }
+ catch (const std::exception& exp) {
+ vinfolog("Caught exception in the main DoH3 thread: %s", exp.what());
+ }
+ catch (...) {
+ vinfolog("Unknown exception in the main DoH3 thread");
+ }
+ }
+ }
+ catch (const std::exception& e) {
+ DEBUGLOG("Caught fatal error in the main DoH3 thread: " << e.what());
+ }
+}
+
+#endif /* HAVE_DNS_OVER_HTTP3 */