summaryrefslogtreecommitdiffstats
path: root/dnsdist-nghttp2-in.cc
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--dnsdist-nghttp2-in.cc1187
1 files changed, 1187 insertions, 0 deletions
diff --git a/dnsdist-nghttp2-in.cc b/dnsdist-nghttp2-in.cc
new file mode 100644
index 0000000..32dc254
--- /dev/null
+++ b/dnsdist-nghttp2-in.cc
@@ -0,0 +1,1187 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#include "dnsdist-dnsparser.hh"
+#include "dnsdist-doh-common.hh"
+#include "dnsdist-nghttp2-in.hh"
+#include "dnsdist-proxy-protocol.hh"
+#include "dnsparser.hh"
+
+#if defined(HAVE_DNS_OVER_HTTPS) && defined(HAVE_NGHTTP2)
+
+#if 0
+class IncomingDoHCrossProtocolContext : public CrossProtocolContext
+{
+public:
+ IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, std::shared_ptr<IncomingHTTP2Connection> connection, IncomingHTTP2Connection::StreamID streamID): CrossProtocolContext(std::move(query.d_buffer)), d_connection(connection), d_query(std::move(query))
+ {
+ }
+
+ std::optional<std::string> getHTTPPath() const override
+ {
+ return d_query.d_path;
+ }
+
+ std::optional<std::string> getHTTPScheme() const override
+ {
+ return d_query.d_scheme;
+ }
+
+ std::optional<std::string> getHTTPHost() const override
+ {
+ return d_query.d_host;
+ }
+
+ std::optional<std::string> getHTTPQueryString() const override
+ {
+ return d_query.d_queryString;
+ }
+
+ std::optional<HeadersMap> getHTTPHeaders() const override
+ {
+ if (!d_query.d_headers) {
+ return std::nullopt;
+ }
+ return *d_query.d_headers;
+ }
+
+ void handleResponse(PacketBuffer&& response, InternalQueryState&& state) override
+ {
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+ }
+
+ void handleTimeout() override
+ {
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+ }
+
+ ~IncomingDoHCrossProtocolContext() override
+ {
+ }
+
+private:
+ std::weak_ptr<IncomingHTTP2Connection> d_connection;
+ IncomingHTTP2Connection::PendingQuery d_query;
+ IncomingHTTP2Connection::StreamID d_streamID{-1};
+};
+#endif
+
+class IncomingDoHCrossProtocolContext : public DOHUnitInterface
+{
+public:
+ IncomingDoHCrossProtocolContext(IncomingHTTP2Connection::PendingQuery&& query, const std::shared_ptr<IncomingHTTP2Connection>& connection, IncomingHTTP2Connection::StreamID streamID) :
+ d_connection(connection), d_query(std::move(query)), d_streamID(streamID)
+ {
+ }
+ IncomingDoHCrossProtocolContext(const IncomingDoHCrossProtocolContext&) = delete;
+ IncomingDoHCrossProtocolContext(IncomingDoHCrossProtocolContext&&) = delete;
+ IncomingDoHCrossProtocolContext& operator=(const IncomingDoHCrossProtocolContext&) = delete;
+ IncomingDoHCrossProtocolContext& operator=(IncomingDoHCrossProtocolContext&&) = delete;
+
+ ~IncomingDoHCrossProtocolContext() override = default;
+
+ [[nodiscard]] std::string getHTTPPath() const override
+ {
+ return d_query.d_path;
+ }
+
+ [[nodiscard]] const std::string& getHTTPScheme() const override
+ {
+ return d_query.d_scheme;
+ }
+
+ [[nodiscard]] const std::string& getHTTPHost() const override
+ {
+ return d_query.d_host;
+ }
+
+ [[nodiscard]] std::string getHTTPQueryString() const override
+ {
+ return d_query.d_queryString;
+ }
+
+ [[nodiscard]] const HeadersMap& getHTTPHeaders() const override
+ {
+ if (!d_query.d_headers) {
+ static const HeadersMap empty{};
+ return empty;
+ }
+ return *d_query.d_headers;
+ }
+
+ void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType = "") override
+ {
+ d_query.d_statusCode = statusCode;
+ d_query.d_response = std::move(body);
+ d_query.d_contentTypeOut = contentType;
+ }
+
+ void handleUDPResponse(PacketBuffer&& response, InternalQueryState&& state, const std::shared_ptr<DownstreamState>& downstream_) override
+ {
+ std::unique_ptr<DOHUnitInterface> unit(this);
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+
+ state.du = std::move(unit);
+ TCPResponse resp(std::move(response), std::move(state), nullptr, nullptr);
+ resp.d_ds = downstream_;
+ struct timeval now
+ {
+ };
+ gettimeofday(&now, nullptr);
+ conn->handleResponse(now, std::move(resp));
+ }
+
+ void handleTimeout() override
+ {
+ std::unique_ptr<DOHUnitInterface> unit(this);
+ auto conn = d_connection.lock();
+ if (!conn) {
+ /* the connection has been closed in the meantime */
+ return;
+ }
+ struct timeval now
+ {
+ };
+ gettimeofday(&now, nullptr);
+ TCPResponse resp;
+ resp.d_idstate.d_streamID = d_streamID;
+ conn->notifyIOError(now, std::move(resp));
+ }
+
+ std::weak_ptr<IncomingHTTP2Connection> d_connection;
+ IncomingHTTP2Connection::PendingQuery d_query;
+ IncomingHTTP2Connection::StreamID d_streamID{-1};
+};
+
+void IncomingHTTP2Connection::handleResponse(const struct timeval& now, TCPResponse&& response)
+{
+ if (std::this_thread::get_id() != d_creatorThreadID) {
+ handleCrossProtocolResponse(now, std::move(response));
+ return;
+ }
+
+ auto& state = response.d_idstate;
+ if (state.forwardedOverUDP) {
+ dnsheader_aligned responseDH(response.d_buffer.data());
+
+ if (responseDH.get()->tc && state.d_packet && state.d_packet->size() > state.d_proxyProtocolPayloadSize && state.d_packet->size() - state.d_proxyProtocolPayloadSize > sizeof(dnsheader)) {
+ vinfolog("Response received from backend %s via UDP, for query %d received from %s via DoH, is truncated, retrying over TCP", response.d_ds->getNameWithAddr(), state.d_streamID, state.origRemote.toStringWithPort());
+ auto& query = *state.d_packet;
+ dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&query.at(state.d_proxyProtocolPayloadSize), [origID = state.origID](dnsheader& header) {
+ /* restoring the original ID */
+ header.id = origID;
+ return true;
+ });
+
+ state.forwardedOverUDP = false;
+ bool proxyProtocolPayloadAdded = state.d_proxyProtocolPayloadSize > 0;
+ auto cpq = getCrossProtocolQuery(std::move(query), std::move(state), response.d_ds);
+ cpq->query.d_proxyProtocolPayloadAdded = proxyProtocolPayloadAdded;
+ if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) {
+ return;
+ }
+ vinfolog("Unable to pass DoH query to a TCP worker thread after getting a TC response over UDP");
+ notifyIOError(now, std::move(response));
+ return;
+ }
+ }
+
+ IncomingTCPConnectionState::handleResponse(now, std::move(response));
+}
+
+std::unique_ptr<DOHUnitInterface> IncomingHTTP2Connection::getDOHUnit(uint32_t streamID)
+{
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay): clang-tidy is getting confused by assert()
+ assert(streamID <= std::numeric_limits<IncomingHTTP2Connection::StreamID>::max());
+ // NOLINTNEXTLINE(*-narrowing-conversions): generic interface between DNS and DoH with different types
+ auto query = std::move(d_currentStreams.at(static_cast<IncomingHTTP2Connection::StreamID>(streamID)));
+ return std::make_unique<IncomingDoHCrossProtocolContext>(std::move(query), std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this()), streamID);
+}
+
+void IncomingHTTP2Connection::restoreDOHUnit(std::unique_ptr<DOHUnitInterface>&& unit)
+{
+ auto context = std::unique_ptr<IncomingDoHCrossProtocolContext>(dynamic_cast<IncomingDoHCrossProtocolContext*>(unit.release()));
+ if (context) {
+ d_currentStreams.at(context->d_streamID) = std::move(context->d_query);
+ }
+}
+
+IncomingHTTP2Connection::IncomingHTTP2Connection(ConnectionInfo&& connectionInfo, TCPClientThreadData& threadData, const struct timeval& now) :
+ IncomingTCPConnectionState(std::move(connectionInfo), threadData, now)
+{
+ nghttp2_session_callbacks* cbs = nullptr;
+ if (nghttp2_session_callbacks_new(&cbs) != 0) {
+ throw std::runtime_error("Unable to create a callback object for a new incoming HTTP/2 session");
+ }
+ std::unique_ptr<nghttp2_session_callbacks, void (*)(nghttp2_session_callbacks*)> callbacks(cbs, nghttp2_session_callbacks_del);
+ cbs = nullptr;
+
+ nghttp2_session_callbacks_set_send_callback(callbacks.get(), send_callback);
+ nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks.get(), on_frame_recv_callback);
+ nghttp2_session_callbacks_set_on_stream_close_callback(callbacks.get(), on_stream_close_callback);
+ nghttp2_session_callbacks_set_on_begin_headers_callback(callbacks.get(), on_begin_headers_callback);
+ nghttp2_session_callbacks_set_on_header_callback(callbacks.get(), on_header_callback);
+ nghttp2_session_callbacks_set_on_data_chunk_recv_callback(callbacks.get(), on_data_chunk_recv_callback);
+ nghttp2_session_callbacks_set_error_callback2(callbacks.get(), on_error_callback);
+
+ nghttp2_session* sess = nullptr;
+ if (nghttp2_session_server_new(&sess, callbacks.get(), this) != 0) {
+ throw std::runtime_error("Coult not allocate a new incoming HTTP/2 session");
+ }
+
+ d_session = std::unique_ptr<nghttp2_session, decltype(&nghttp2_session_del)>(sess, nghttp2_session_del);
+ sess = nullptr;
+}
+
+bool IncomingHTTP2Connection::checkALPN()
+{
+ constexpr std::array<uint8_t, 2> h2ALPN{'h', '2'};
+ const auto protocols = d_handler.getNextProtocol();
+ if (protocols.size() == h2ALPN.size() && memcmp(protocols.data(), h2ALPN.data(), h2ALPN.size()) == 0) {
+ return true;
+ }
+
+ constexpr std::array<uint8_t, 8> http11ALPN{'h', 't', 't', 'p', '/', '1', '.', '1'};
+ if (protocols.size() == http11ALPN.size() && memcmp(protocols.data(), http11ALPN.data(), http11ALPN.size()) == 0) {
+ ++d_ci.cs->dohFrontend->d_http1Stats.d_nbQueries;
+ }
+
+ const std::string data("HTTP/1.1 400 Bad Request\r\nConnection: Close\r\n\r\n<html><body>This server implements RFC 8484 - DNS Queries over HTTP, and requires HTTP/2 in accordance with section 5.2 of the RFC.</body></html>\r\n");
+ d_out.insert(d_out.end(), data.begin(), data.end());
+ writeToSocket(false);
+
+ vinfolog("DoH connection from %s expected ALPN value 'h2', got '%s'", d_ci.remote.toStringWithPort(), std::string(protocols.begin(), protocols.end()));
+ return false;
+}
+
+void IncomingHTTP2Connection::handleConnectionReady()
+{
+ constexpr std::array<nghttp2_settings_entry, 1> settings{{{NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100U}}};
+ auto ret = nghttp2_submit_settings(d_session.get(), NGHTTP2_FLAG_NONE, settings.data(), settings.size());
+ if (ret != 0) {
+ throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(ret)));
+ }
+ d_needFlush = true;
+ ret = nghttp2_session_send(d_session.get());
+ if (ret != 0) {
+ throw std::runtime_error("Fatal error: " + std::string(nghttp2_strerror(ret)));
+ }
+}
+
+bool IncomingHTTP2Connection::hasPendingWrite() const
+{
+ return d_pendingWrite;
+}
+
+IOState IncomingHTTP2Connection::handleHandshake(const struct timeval& now)
+{
+ auto iostate = d_handler.tryHandshake();
+ if (iostate == IOState::Done) {
+ handleHandshakeDone(now);
+ if (d_handler.isTLS()) {
+ if (!checkALPN()) {
+ d_connectionDied = true;
+ stopIO();
+ return iostate;
+ }
+ }
+
+ if (d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && !isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
+ d_state = State::readingProxyProtocolHeader;
+ d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+ d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+ }
+ else {
+ d_state = State::waitingForQuery;
+ handleConnectionReady();
+ }
+ }
+ return iostate;
+}
+
+void IncomingHTTP2Connection::handleIO()
+{
+ IOState iostate = IOState::Done;
+ struct timeval now
+ {
+ };
+ gettimeofday(&now, nullptr);
+
+ try {
+ if (maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) {
+ vinfolog("Terminating DoH connection from %s because it reached the maximum TCP connection duration", d_ci.remote.toStringWithPort());
+ stopIO();
+ d_connectionClosing = true;
+ return;
+ }
+
+ if (d_state == State::starting) {
+ if (d_ci.cs != nullptr && d_ci.cs->dohFrontend != nullptr) {
+ ++d_ci.cs->dohFrontend->d_httpconnects;
+ }
+ if (d_ci.cs != nullptr && d_ci.cs->d_enableProxyProtocol && isProxyPayloadOutsideTLS() && expectProxyProtocolFrom(d_ci.remote)) {
+ d_state = State::readingProxyProtocolHeader;
+ d_buffer.resize(s_proxyProtocolMinimumHeaderSize);
+ d_proxyProtocolNeed = s_proxyProtocolMinimumHeaderSize;
+ }
+ else {
+ d_state = State::doingHandshake;
+ }
+ }
+
+ if (d_state == State::doingHandshake) {
+ iostate = handleHandshake(now);
+ if (d_connectionDied) {
+ return;
+ }
+ }
+
+ if (d_state == State::readingProxyProtocolHeader) {
+ auto status = handleProxyProtocolPayload();
+ if (status == ProxyProtocolResult::Done) {
+ if (isProxyPayloadOutsideTLS()) {
+ d_state = State::doingHandshake;
+ iostate = handleHandshake(now);
+ if (d_connectionDied) {
+ return;
+ }
+ }
+ else {
+ d_currentPos = 0;
+ d_proxyProtocolNeed = 0;
+ d_buffer.clear();
+ d_state = State::waitingForQuery;
+ handleConnectionReady();
+ }
+ }
+ else if (status == ProxyProtocolResult::Error) {
+ d_connectionDied = true;
+ stopIO();
+ return;
+ }
+ }
+
+ if (active() && !d_connectionClosing && (d_state == State::waitingForQuery || d_state == State::idle)) {
+ do {
+ iostate = readHTTPData();
+ } while (active() && !d_connectionClosing && iostate == IOState::Done);
+ }
+
+ if (!active()) {
+ stopIO();
+ return;
+ }
+ /*
+ So:
+ - if we have a pending write, we need to wait until the socket becomes writable
+ and then call handleWritableCallback
+ - if we have NeedWrite but no pending write, we need to wait until the socket
+ becomes writable but for handleReadableIOCallback
+ - if we have NeedRead, or nghttp2_session_want_read, wait until the socket
+ becomes readable and call handleReadableIOCallback
+ */
+ if (hasPendingWrite()) {
+ updateIO(IOState::NeedWrite, handleWritableIOCallback);
+ }
+ else if (iostate == IOState::NeedWrite) {
+ updateIO(IOState::NeedWrite, handleReadableIOCallback);
+ }
+ else if (!d_connectionClosing) {
+ if (nghttp2_session_want_read(d_session.get()) != 0) {
+ updateIO(IOState::NeedRead, handleReadableIOCallback);
+ }
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception when processing IO for incoming DoH connection from %s: %s", d_ci.remote.toStringWithPort(), e.what());
+ d_connectionDied = true;
+ stopIO();
+ }
+}
+
+void IncomingHTTP2Connection::writeToSocket(bool socketReady)
+{
+ try {
+ d_needFlush = false;
+ IOState newState = d_handler.tryWrite(d_out, d_outPos, d_out.size());
+
+ if (newState == IOState::Done) {
+ d_pendingWrite = false;
+ d_out.clear();
+ d_outPos = 0;
+ if (active() && !d_connectionClosing) {
+ updateIO(IOState::NeedRead, handleReadableIOCallback);
+ }
+ else {
+ stopIO();
+ }
+ }
+ else {
+ updateIO(newState, handleWritableIOCallback);
+ d_pendingWrite = true;
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while trying to write (%s) to HTTP client connection to %s: %s", (socketReady ? "ready" : "send"), d_ci.remote.toStringWithPort(), e.what());
+ handleIOError();
+ }
+}
+
+ssize_t IncomingHTTP2Connection::send_callback(nghttp2_session* session, const uint8_t* data, size_t length, int flags, void* user_data)
+{
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+ if (conn->d_connectionDied) {
+ return static_cast<ssize_t>(length);
+ }
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API
+ conn->d_out.insert(conn->d_out.end(), data, data + length);
+
+ if (conn->d_connectionClosing || conn->d_needFlush) {
+ conn->writeToSocket(false);
+ }
+
+ return static_cast<ssize_t>(length);
+}
+
+static const std::array<const std::string, static_cast<size_t>(NGHTTP2Headers::HeaderConstantIndexes::COUNT)> s_headerConstants{
+ "200",
+ ":method",
+ "POST",
+ ":scheme",
+ "https",
+ ":authority",
+ "x-forwarded-for",
+ ":path",
+ "content-length",
+ ":status",
+ "location",
+ "accept",
+ "application/dns-message",
+ "cache-control",
+ "content-type",
+ "application/dns-message",
+ "user-agent",
+ "nghttp2-" NGHTTP2_VERSION "/dnsdist",
+ "x-forwarded-port",
+ "x-forwarded-proto",
+ "dns-over-udp",
+ "dns-over-tcp",
+ "dns-over-tls",
+ "dns-over-http",
+ "dns-over-https"};
+
+static const std::string s_authorityHeaderName(":authority");
+static const std::string s_pathHeaderName(":path");
+static const std::string s_methodHeaderName(":method");
+static const std::string s_schemeHeaderName(":scheme");
+static const std::string s_xForwardedForHeaderName("x-forwarded-for");
+
+void NGHTTP2Headers::addStaticHeader(std::vector<nghttp2_nv>& headers, NGHTTP2Headers::HeaderConstantIndexes nameKey, NGHTTP2Headers::HeaderConstantIndexes valueKey)
+{
+ const auto& name = s_headerConstants.at(static_cast<size_t>(nameKey));
+ const auto& value = s_headerConstants.at(static_cast<size_t>(valueKey));
+
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-type-reinterpret-cast): nghttp2 API
+ headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.c_str())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.c_str())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
+}
+
+void NGHTTP2Headers::addCustomDynamicHeader(std::vector<nghttp2_nv>& headers, const std::string& name, const std::string_view& value)
+{
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast,cppcoreguidelines-pro-type-reinterpret-cast): nghttp2 API
+ headers.push_back({const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(name.data())), const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(value.data())), name.size(), value.size(), NGHTTP2_NV_FLAG_NO_COPY_NAME | NGHTTP2_NV_FLAG_NO_COPY_VALUE});
+}
+
+void NGHTTP2Headers::addDynamicHeader(std::vector<nghttp2_nv>& headers, NGHTTP2Headers::HeaderConstantIndexes nameKey, const std::string_view& value)
+{
+ const auto& name = s_headerConstants.at(static_cast<size_t>(nameKey));
+ NGHTTP2Headers::addCustomDynamicHeader(headers, name, value);
+}
+
+IOState IncomingHTTP2Connection::sendResponse(const struct timeval& now, TCPResponse&& response)
+{
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay): clang-tidy is getting confused by assert()
+ assert(response.d_idstate.d_streamID != -1);
+ auto& context = d_currentStreams.at(response.d_idstate.d_streamID);
+
+ uint32_t statusCode = 200U;
+ std::string contentType;
+ bool sendContentType = true;
+ auto& responseBuffer = context.d_buffer;
+ if (context.d_statusCode != 0) {
+ responseBuffer = std::move(context.d_response);
+ statusCode = context.d_statusCode;
+ contentType = std::move(context.d_contentTypeOut);
+ }
+ else {
+ responseBuffer = std::move(response.d_buffer);
+ }
+
+ sendResponse(response.d_idstate.d_streamID, context, statusCode, d_ci.cs->dohFrontend->d_customResponseHeaders, contentType, sendContentType);
+ handleResponseSent(response);
+
+ return hasPendingWrite() ? IOState::NeedWrite : IOState::Done;
+}
+
+void IncomingHTTP2Connection::notifyIOError(const struct timeval& now, TCPResponse&& response)
+{
+ if (std::this_thread::get_id() != d_creatorThreadID) {
+ /* empty buffer will signal an IO error */
+ response.d_buffer.clear();
+ handleCrossProtocolResponse(now, std::move(response));
+ return;
+ }
+
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay): clang-tidy is getting confused by assert()
+ assert(response.d_idstate.d_streamID != -1);
+ auto& context = d_currentStreams.at(response.d_idstate.d_streamID);
+ context.d_buffer = std::move(response.d_buffer);
+ sendResponse(response.d_idstate.d_streamID, context, 502, d_ci.cs->dohFrontend->d_customResponseHeaders);
+}
+
+bool IncomingHTTP2Connection::sendResponse(IncomingHTTP2Connection::StreamID streamID, IncomingHTTP2Connection::PendingQuery& context, uint16_t responseCode, const HeadersMap& customResponseHeaders, const std::string& contentType, bool addContentType)
+{
+ /* if data_prd is not NULL, it provides data which will be sent in subsequent DATA frames. In this case, a method that allows request message bodies (https://tools.ietf.org/html/rfc7231#section-4) must be specified with :method key (e.g. POST). This function does not take ownership of the data_prd. The function copies the members of the data_prd. If data_prd is NULL, HEADERS have END_STREAM set.
+ */
+ nghttp2_data_provider data_provider;
+
+ data_provider.source.ptr = this;
+ data_provider.read_callback = [](nghttp2_session*, IncomingHTTP2Connection::StreamID stream_id, uint8_t* buf, size_t length, uint32_t* data_flags, nghttp2_data_source* source, void* cb_data) -> ssize_t {
+ auto* connection = static_cast<IncomingHTTP2Connection*>(cb_data);
+ auto& obj = connection->d_currentStreams.at(stream_id);
+ size_t toCopy = 0;
+ if (obj.d_queryPos < obj.d_buffer.size()) {
+ size_t remaining = obj.d_buffer.size() - obj.d_queryPos;
+ toCopy = length > remaining ? remaining : length;
+ memcpy(buf, &obj.d_buffer.at(obj.d_queryPos), toCopy);
+ obj.d_queryPos += toCopy;
+ }
+
+ if (obj.d_queryPos >= obj.d_buffer.size()) {
+ *data_flags |= NGHTTP2_DATA_FLAG_EOF;
+ obj.d_buffer.clear();
+ connection->d_needFlush = true;
+ }
+ return static_cast<ssize_t>(toCopy);
+ };
+
+ const auto& dohFrontend = d_ci.cs->dohFrontend;
+ auto& responseBody = context.d_buffer;
+
+ std::vector<nghttp2_nv> headers;
+ std::string responseCodeStr;
+ std::string cacheControlValue;
+ std::string location;
+ /* remember that dynamic header values should be kept alive
+ until we have called nghttp2_submit_response(), at least */
+ /* status, content-type, cache-control, content-length */
+ headers.reserve(4);
+
+ if (responseCode == 200) {
+ NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::STATUS_NAME, NGHTTP2Headers::HeaderConstantIndexes::OK_200_VALUE);
+ ++dohFrontend->d_validresponses;
+ ++dohFrontend->d_http2Stats.d_nb200Responses;
+
+ if (addContentType) {
+ if (contentType.empty()) {
+ NGHTTP2Headers::addStaticHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_NAME, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_VALUE);
+ }
+ else {
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_NAME, contentType);
+ }
+ }
+
+ if (dohFrontend->d_sendCacheControlHeaders && responseBody.size() > sizeof(dnsheader)) {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): API
+ uint32_t minTTL = getDNSPacketMinTTL(reinterpret_cast<const char*>(responseBody.data()), responseBody.size());
+ if (minTTL != std::numeric_limits<uint32_t>::max()) {
+ cacheControlValue = "max-age=" + std::to_string(minTTL);
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CACHE_CONTROL_NAME, cacheControlValue);
+ }
+ }
+ }
+ else {
+ responseCodeStr = std::to_string(responseCode);
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::STATUS_NAME, responseCodeStr);
+
+ if (responseCode >= 300 && responseCode < 400) {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ location = std::string(reinterpret_cast<const char*>(responseBody.data()), responseBody.size());
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_NAME, "text/html; charset=utf-8");
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::LOCATION_NAME, location);
+ static const std::string s_redirectStart{"<!DOCTYPE html><TITLE>Moved</TITLE><P>The document has moved <A HREF=\""};
+ static const std::string s_redirectEnd{"\">here</A>"};
+ responseBody.reserve(s_redirectStart.size() + responseBody.size() + s_redirectEnd.size());
+ responseBody.insert(responseBody.begin(), s_redirectStart.begin(), s_redirectStart.end());
+ responseBody.insert(responseBody.end(), s_redirectEnd.begin(), s_redirectEnd.end());
+ ++dohFrontend->d_redirectresponses;
+ }
+ else {
+ ++dohFrontend->d_errorresponses;
+ switch (responseCode) {
+ case 400:
+ ++dohFrontend->d_http2Stats.d_nb400Responses;
+ break;
+ case 403:
+ ++dohFrontend->d_http2Stats.d_nb403Responses;
+ break;
+ case 500:
+ ++dohFrontend->d_http2Stats.d_nb500Responses;
+ break;
+ case 502:
+ ++dohFrontend->d_http2Stats.d_nb502Responses;
+ break;
+ default:
+ ++dohFrontend->d_http2Stats.d_nbOtherResponses;
+ break;
+ }
+
+ if (!responseBody.empty()) {
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_TYPE_NAME, "text/plain; charset=utf-8");
+ }
+ else {
+ static const std::string invalid{"invalid DNS query"};
+ static const std::string notAllowed{"dns query not allowed"};
+ static const std::string noDownstream{"no downstream server available"};
+ static const std::string internalServerError{"Internal Server Error"};
+
+ switch (responseCode) {
+ case 400:
+ responseBody.insert(responseBody.begin(), invalid.begin(), invalid.end());
+ break;
+ case 403:
+ responseBody.insert(responseBody.begin(), notAllowed.begin(), notAllowed.end());
+ break;
+ case 502:
+ responseBody.insert(responseBody.begin(), noDownstream.begin(), noDownstream.end());
+ break;
+ case 500:
+ /* fall-through */
+ default:
+ responseBody.insert(responseBody.begin(), internalServerError.begin(), internalServerError.end());
+ break;
+ }
+ }
+ }
+ }
+
+ const std::string contentLength = std::to_string(responseBody.size());
+ NGHTTP2Headers::addDynamicHeader(headers, NGHTTP2Headers::HeaderConstantIndexes::CONTENT_LENGTH_NAME, contentLength);
+
+ for (const auto& [key, value] : customResponseHeaders) {
+ NGHTTP2Headers::addCustomDynamicHeader(headers, key, value);
+ }
+
+ auto ret = nghttp2_submit_response(d_session.get(), streamID, headers.data(), headers.size(), &data_provider);
+ if (ret != 0) {
+ d_currentStreams.erase(streamID);
+ vinfolog("Error submitting HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret));
+ return false;
+ }
+
+ ret = nghttp2_session_send(d_session.get());
+ if (ret != 0) {
+ d_currentStreams.erase(streamID);
+ vinfolog("Error flushing HTTP response for stream %d: %s", streamID, nghttp2_strerror(ret));
+ return false;
+ }
+
+ return true;
+}
+
+static void processForwardedForHeader(const std::unique_ptr<HeadersMap>& headers, ComboAddress& remote)
+{
+ if (!headers) {
+ return;
+ }
+
+ auto headerIt = headers->find(s_xForwardedForHeaderName);
+ if (headerIt == headers->end()) {
+ return;
+ }
+
+ std::string_view value = headerIt->second;
+ try {
+ auto pos = value.rfind(',');
+ if (pos != std::string_view::npos) {
+ ++pos;
+ for (; pos < value.size() && value[pos] == ' '; ++pos) {
+ }
+
+ if (pos < value.size()) {
+ value = value.substr(pos);
+ }
+ }
+ auto newRemote = ComboAddress(std::string(value));
+ remote = newRemote;
+ }
+ catch (const std::exception& e) {
+ vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.what());
+ }
+ catch (const PDNSException& e) {
+ vinfolog("Invalid X-Forwarded-For header ('%s') received from %s : %s", std::string(value), remote.toStringWithPort(), e.reason);
+ }
+}
+
+void IncomingHTTP2Connection::handleIncomingQuery(IncomingHTTP2Connection::PendingQuery&& query, IncomingHTTP2Connection::StreamID streamID)
+{
+ const auto handleImmediateResponse = [this, &query, streamID](uint16_t code, const std::string& reason, PacketBuffer&& response = PacketBuffer()) {
+ if (response.empty()) {
+ query.d_buffer.clear();
+ query.d_buffer.insert(query.d_buffer.begin(), reason.begin(), reason.end());
+ }
+ else {
+ query.d_buffer = std::move(response);
+ }
+ vinfolog("Sending an immediate %d response to incoming DoH query: %s", code, reason);
+ sendResponse(streamID, query, code, d_ci.cs->dohFrontend->d_customResponseHeaders);
+ };
+
+ if (query.d_method == PendingQuery::Method::Unknown || query.d_method == PendingQuery::Method::Unsupported) {
+ handleImmediateResponse(400, "DoH query not allowed because of unsupported HTTP method");
+ return;
+ }
+
+ ++d_ci.cs->dohFrontend->d_http2Stats.d_nbQueries;
+
+ if (d_ci.cs->dohFrontend->d_trustForwardedForHeader) {
+ processForwardedForHeader(query.d_headers, d_proxiedRemote);
+
+ /* second ACL lookup based on the updated address */
+ auto& holders = d_threadData.holders;
+ if (!holders.acl->match(d_proxiedRemote)) {
+ ++dnsdist::metrics::g_stats.aclDrops;
+ vinfolog("Query from %s (%s) (DoH) dropped because of ACL", d_ci.remote.toStringWithPort(), d_proxiedRemote.toStringWithPort());
+ handleImmediateResponse(403, "DoH query not allowed because of ACL");
+ return;
+ }
+
+ if (!d_ci.cs->dohFrontend->d_keepIncomingHeaders) {
+ query.d_headers.reset();
+ }
+ }
+
+ if (d_ci.cs->dohFrontend->d_exactPathMatching) {
+ if (d_ci.cs->dohFrontend->d_urls.count(query.d_path) == 0) {
+ handleImmediateResponse(404, "there is no endpoint configured for this path");
+ return;
+ }
+ }
+ else {
+ bool found = false;
+ for (const auto& path : d_ci.cs->dohFrontend->d_urls) {
+ if (boost::starts_with(query.d_path, path)) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ handleImmediateResponse(404, "there is no endpoint configured for this path");
+ return;
+ }
+ }
+
+ /* the responses map can be updated at runtime, so we need to take a copy of
+ the shared pointer, increasing the reference counter */
+ auto responsesMap = d_ci.cs->dohFrontend->d_responsesMap;
+ if (responsesMap) {
+ for (const auto& entry : *responsesMap) {
+ if (entry->matches(query.d_path)) {
+ const auto& customHeaders = entry->getHeaders();
+ query.d_buffer = entry->getContent();
+ if (entry->getStatusCode() >= 400 && !query.d_buffer.empty()) {
+ // legacy trailing 0 from the h2o era
+ query.d_buffer.pop_back();
+ }
+
+ sendResponse(streamID, query, entry->getStatusCode(), customHeaders ? *customHeaders : d_ci.cs->dohFrontend->d_customResponseHeaders, std::string(), false);
+ return;
+ }
+ }
+ }
+
+ if (query.d_buffer.empty() && query.d_method == PendingQuery::Method::Get && !query.d_queryString.empty()) {
+ auto payload = dnsdist::doh::getPayloadFromPath(query.d_queryString);
+ if (payload) {
+ query.d_buffer = std::move(*payload);
+ }
+ else {
+ ++d_ci.cs->dohFrontend->d_badrequests;
+ handleImmediateResponse(400, "DoH unable to decode BASE64-URL");
+ return;
+ }
+ }
+
+ if (query.d_method == PendingQuery::Method::Get) {
+ ++d_ci.cs->dohFrontend->d_getqueries;
+ }
+ else if (query.d_method == PendingQuery::Method::Post) {
+ ++d_ci.cs->dohFrontend->d_postqueries;
+ }
+
+ try {
+ struct timeval now
+ {
+ };
+ gettimeofday(&now, nullptr);
+ auto processingResult = handleQuery(std::move(query.d_buffer), now, streamID);
+
+ switch (processingResult) {
+ case QueryProcessingResult::TooSmall:
+ handleImmediateResponse(400, "DoH non-compliant query");
+ break;
+ case QueryProcessingResult::InvalidHeaders:
+ handleImmediateResponse(400, "DoH invalid headers");
+ break;
+ case QueryProcessingResult::Dropped:
+ handleImmediateResponse(403, "DoH dropped query");
+ break;
+ case QueryProcessingResult::NoBackend:
+ handleImmediateResponse(502, "DoH no backend available");
+ return;
+ case QueryProcessingResult::Forwarded:
+ case QueryProcessingResult::Asynchronous:
+ case QueryProcessingResult::SelfAnswered:
+ break;
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while processing DoH query: %s", e.what());
+ handleImmediateResponse(400, "DoH non-compliant query");
+ return;
+ }
+}
+
+int IncomingHTTP2Connection::on_frame_recv_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+{
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+ /* is this the last frame for this stream? */
+ if ((frame->hd.type == NGHTTP2_HEADERS || frame->hd.type == NGHTTP2_DATA) && (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) != 0) {
+ auto streamID = frame->hd.stream_id;
+ auto stream = conn->d_currentStreams.find(streamID);
+ if (stream != conn->d_currentStreams.end()) {
+ conn->handleIncomingQuery(std::move(stream->second), streamID);
+ }
+ else {
+ vinfolog("Stream %d NOT FOUND", streamID);
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ }
+
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_stream_close_callback(nghttp2_session* session, IncomingHTTP2Connection::StreamID stream_id, uint32_t error_code, void* user_data)
+{
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+
+ conn->d_currentStreams.erase(stream_id);
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_begin_headers_callback(nghttp2_session* session, const nghttp2_frame* frame, void* user_data)
+{
+ if (frame->hd.type != NGHTTP2_HEADERS || frame->headers.cat != NGHTTP2_HCAT_REQUEST) {
+ return 0;
+ }
+
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+ auto insertPair = conn->d_currentStreams.emplace(frame->hd.stream_id, PendingQuery());
+ if (!insertPair.second) {
+ /* there is a stream ID collision, something is very wrong! */
+ vinfolog("Stream ID collision (%d) on connection from %d", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort());
+ conn->d_connectionClosing = true;
+ conn->d_needFlush = true;
+ nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
+ auto ret = nghttp2_session_send(conn->d_session.get());
+ if (ret != 0) {
+ vinfolog("Error flushing HTTP response for stream %d from %s: %s", frame->hd.stream_id, conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret));
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+ return 0;
+ }
+
+ return 0;
+}
+
+static std::string::size_type getLengthOfPathWithoutParameters(const std::string_view& path)
+{
+ auto pos = path.find('?');
+ if (pos == string::npos) {
+ return path.size();
+ }
+
+ return pos;
+}
+
+int IncomingHTTP2Connection::on_header_callback(nghttp2_session* session, const nghttp2_frame* frame, const uint8_t* name, size_t nameLen, const uint8_t* value, size_t valuelen, uint8_t flags, void* user_data)
+{
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+
+ if (frame->hd.type == NGHTTP2_HEADERS && frame->headers.cat == NGHTTP2_HCAT_REQUEST) {
+ if (nghttp2_check_header_name(name, nameLen) == 0) {
+ vinfolog("Invalid header name");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+#if HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113
+ if (nghttp2_check_header_value_rfc9113(value, valuelen) == 0) {
+ vinfolog("Invalid header value");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+#endif /* HAVE_NGHTTP2_CHECK_HEADER_VALUE_RFC9113 */
+
+ auto headerMatches = [name, nameLen](const std::string& expected) -> bool {
+ return nameLen == expected.size() && memcmp(name, expected.data(), expected.size()) == 0;
+ };
+
+ auto stream = conn->d_currentStreams.find(frame->hd.stream_id);
+ if (stream == conn->d_currentStreams.end()) {
+ vinfolog("Unable to match the stream ID %d to a known one!", frame->hd.stream_id);
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ auto& query = stream->second;
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): nghttp2 API
+ auto valueView = std::string_view(reinterpret_cast<const char*>(value), valuelen);
+ if (headerMatches(s_pathHeaderName)) {
+#if HAVE_NGHTTP2_CHECK_PATH
+ if (nghttp2_check_path(value, valuelen) == 0) {
+ vinfolog("Invalid path value");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+#endif /* HAVE_NGHTTP2_CHECK_PATH */
+
+ auto pathLen = getLengthOfPathWithoutParameters(valueView);
+ query.d_path = valueView.substr(0, pathLen);
+ if (pathLen < valueView.size()) {
+ query.d_queryString = valueView.substr(pathLen);
+ }
+ }
+ else if (headerMatches(s_authorityHeaderName)) {
+ query.d_host = valueView;
+ }
+ else if (headerMatches(s_schemeHeaderName)) {
+ query.d_scheme = valueView;
+ }
+ else if (headerMatches(s_methodHeaderName)) {
+#if HAVE_NGHTTP2_CHECK_METHOD
+ if (nghttp2_check_method(value, valuelen) == 0) {
+ vinfolog("Invalid method value");
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+#endif /* HAVE_NGHTTP2_CHECK_METHOD */
+
+ if (valueView == "GET") {
+ query.d_method = PendingQuery::Method::Get;
+ }
+ else if (valueView == "POST") {
+ query.d_method = PendingQuery::Method::Post;
+ }
+ else {
+ query.d_method = PendingQuery::Method::Unsupported;
+ vinfolog("Unsupported method value");
+ return 0;
+ }
+ }
+
+ if (conn->d_ci.cs->dohFrontend->d_keepIncomingHeaders || (conn->d_ci.cs->dohFrontend->d_trustForwardedForHeader && headerMatches(s_xForwardedForHeaderName))) {
+ if (!query.d_headers) {
+ query.d_headers = std::make_unique<HeadersMap>();
+ }
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast): nghttp2 API
+ query.d_headers->insert({std::string(reinterpret_cast<const char*>(name), nameLen), std::string(valueView)});
+ }
+ }
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_data_chunk_recv_callback(nghttp2_session* session, uint8_t flags, IncomingHTTP2Connection::StreamID stream_id, const uint8_t* data, size_t len, void* user_data)
+{
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+ auto stream = conn->d_currentStreams.find(stream_id);
+ if (stream == conn->d_currentStreams.end()) {
+ vinfolog("Unable to match the stream ID %d to a known one!", stream_id);
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+ if (len > std::numeric_limits<uint16_t>::max() || (std::numeric_limits<uint16_t>::max() - stream->second.d_buffer.size()) < len) {
+ vinfolog("Data frame of size %d is too large for a DNS query (we already have %d)", len, stream->second.d_buffer.size());
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic): nghttp2 API
+ stream->second.d_buffer.insert(stream->second.d_buffer.end(), data, data + len);
+
+ return 0;
+}
+
+int IncomingHTTP2Connection::on_error_callback(nghttp2_session* session, int lib_error_code, const char* msg, size_t len, void* user_data)
+{
+ auto* conn = static_cast<IncomingHTTP2Connection*>(user_data);
+
+ vinfolog("Error in HTTP/2 connection from %d: %s", conn->d_ci.remote.toStringWithPort(), std::string(msg, len));
+ conn->d_connectionClosing = true;
+ conn->d_needFlush = true;
+ nghttp2_session_terminate_session(conn->d_session.get(), NGHTTP2_NO_ERROR);
+ auto ret = nghttp2_session_send(conn->d_session.get());
+ if (ret != 0) {
+ vinfolog("Error flushing HTTP response on connection from %s: %s", conn->d_ci.remote.toStringWithPort(), nghttp2_strerror(ret));
+ return NGHTTP2_ERR_CALLBACK_FAILURE;
+ }
+
+ return 0;
+}
+
+IOState IncomingHTTP2Connection::readHTTPData()
+{
+ IOState newState = IOState::Done;
+ size_t got = 0;
+ if (d_in.size() < s_initialReceiveBufferSize) {
+ d_in.resize(std::max(s_initialReceiveBufferSize, d_in.capacity()));
+ }
+ try {
+ newState = d_handler.tryRead(d_in, got, d_in.size(), true);
+ d_in.resize(got);
+
+ if (got > 0) {
+ /* we got something */
+ auto readlen = nghttp2_session_mem_recv(d_session.get(), d_in.data(), d_in.size());
+ /* as long as we don't require a pause by returning nghttp2_error.NGHTTP2_ERR_PAUSE from a CB,
+ all data should be consumed before returning */
+ if (readlen < 0 || static_cast<size_t>(readlen) < d_in.size()) {
+ throw std::runtime_error("Fatal error while passing received data to nghttp2: " + std::string(nghttp2_strerror((int)readlen)));
+ }
+
+ nghttp2_session_send(d_session.get());
+ }
+ }
+ catch (const std::exception& e) {
+ vinfolog("Exception while trying to read from HTTP client connection to %s: %s", d_ci.remote.toStringWithPort(), e.what());
+ handleIOError();
+ return IOState::Done;
+ }
+ return newState;
+}
+
+void IncomingHTTP2Connection::handleReadableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param)
+{
+ auto conn = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
+ conn->handleIO();
+}
+
+void IncomingHTTP2Connection::handleWritableIOCallback([[maybe_unused]] int descriptor, FDMultiplexer::funcparam_t& param)
+{
+ auto conn = boost::any_cast<std::shared_ptr<IncomingHTTP2Connection>>(param);
+ conn->writeToSocket(true);
+}
+
+void IncomingHTTP2Connection::stopIO()
+{
+ if (d_ioState) {
+ d_ioState->reset();
+ }
+}
+
+uint32_t IncomingHTTP2Connection::getConcurrentStreamsCount() const
+{
+ return d_currentStreams.size();
+}
+
+boost::optional<struct timeval> IncomingHTTP2Connection::getIdleClientReadTTD(struct timeval now) const
+{
+ auto idleTimeout = d_ci.cs->dohFrontend->d_idleTimeout;
+ if (g_maxTCPConnectionDuration == 0 && idleTimeout == 0) {
+ return boost::none;
+ }
+
+ if (g_maxTCPConnectionDuration > 0) {
+ auto elapsed = now.tv_sec - d_connectionStartTime.tv_sec;
+ if (elapsed < 0 || (static_cast<size_t>(elapsed) >= g_maxTCPConnectionDuration)) {
+ return now;
+ }
+ auto remaining = g_maxTCPConnectionDuration - elapsed;
+ if (idleTimeout == 0 || remaining <= static_cast<size_t>(idleTimeout)) {
+ now.tv_sec += static_cast<time_t>(remaining);
+ return now;
+ }
+ }
+
+ now.tv_sec += idleTimeout;
+ return now;
+}
+
+void IncomingHTTP2Connection::updateIO(IOState newState, const FDMultiplexer::callbackfunc_t& callback)
+{
+ boost::optional<struct timeval> ttd{boost::none};
+
+ auto shared = std::dynamic_pointer_cast<IncomingHTTP2Connection>(shared_from_this());
+ if (!shared || !d_ioState) {
+ return;
+ }
+
+ timeval now{};
+ gettimeofday(&now, nullptr);
+
+ if (newState == IOState::NeedRead) {
+ /* use the idle TTL if the handshake has been completed (and proxy protocol payload received, if any),
+ and we have processed at least one query, otherwise we use the shorter read TTL */
+ if ((d_state == State::waitingForQuery || d_state == State::idle) && (d_queriesCount > 0 || d_currentQueriesCount > 0)) {
+ ttd = getIdleClientReadTTD(now);
+ }
+ else {
+ ttd = getClientReadTTD(now);
+ }
+ d_ioState->update(newState, callback, shared, ttd);
+ }
+ else if (newState == IOState::NeedWrite) {
+ ttd = getClientWriteTTD(now);
+ d_ioState->update(newState, callback, shared, ttd);
+ }
+}
+
+void IncomingHTTP2Connection::handleIOError()
+{
+ d_connectionDied = true;
+ d_out.clear();
+ d_outPos = 0;
+ nghttp2_session_terminate_session(d_session.get(), NGHTTP2_PROTOCOL_ERROR);
+ d_currentStreams.clear();
+ stopIO();
+}
+
+bool IncomingHTTP2Connection::active() const
+{
+ return !d_connectionDied && d_ioState != nullptr;
+}
+
+#endif /* HAVE_DNS_OVER_HTTPS && HAVE_NGHTTP2 */