diff options
Diffstat (limited to 'dnsdist-async.cc')
-rw-r--r-- | dnsdist-async.cc | 154 |
1 files changed, 64 insertions, 90 deletions
diff --git a/dnsdist-async.cc b/dnsdist-async.cc index e1acef8..9cb96d8 100644 --- a/dnsdist-async.cc +++ b/dnsdist-async.cc @@ -27,28 +27,19 @@ namespace dnsdist { -AsynchronousHolder::AsynchronousHolder(bool failOpen) : - d_data(std::make_shared<Data>()) +AsynchronousHolder::Data::Data(bool failOpen) : + d_failOpen(failOpen) { - d_data->d_failOpen = failOpen; - - int fds[2] = {-1, -1}; - if (pipe(fds) < 0) { - throw std::runtime_error("Error creating the AsynchronousHolder pipe: " + stringerror()); - } - - for (size_t idx = 0; idx < (sizeof(fds) / sizeof(*fds)); idx++) { - if (!setNonBlocking(fds[idx])) { - int err = errno; - close(fds[0]); - close(fds[1]); - throw std::runtime_error("Error setting the AsynchronousHolder pipe non-blocking: " + stringerror(err)); - } - } - - d_data->d_notifyPipe = FDWrapper(fds[1]); - d_data->d_watchPipe = FDWrapper(fds[0]); + auto [notifier, waiter] = pdns::channel::createNotificationQueue(true); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): how I am supposed to do that? + d_waiter = std::move(waiter); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): how I am supposed to do that? + d_notifier = std::move(notifier); +} +AsynchronousHolder::AsynchronousHolder(bool failOpen) : + d_data(std::make_shared<Data>(failOpen)) +{ std::thread main([data = this->d_data] { mainThread(data); }); main.detach(); } @@ -64,49 +55,19 @@ AsynchronousHolder::~AsynchronousHolder() bool AsynchronousHolder::notify() const { - const char data = 0; - bool failed = false; - do { - auto written = write(d_data->d_notifyPipe.getHandle(), &data, sizeof(data)); - if (written == 0) { - break; - } - if (written > 0 && static_cast<size_t>(written) == sizeof(data)) { - return true; - } - if (errno != EINTR) { - failed = true; - } - } while (!failed); - - return false; + return d_data->d_notifier.notify(); } -bool AsynchronousHolder::wait(const AsynchronousHolder::Data& data, FDMultiplexer& mplexer, std::vector<int>& readyFDs, int atMostMs) +bool AsynchronousHolder::wait(AsynchronousHolder::Data& data, FDMultiplexer& mplexer, std::vector<int>& readyFDs, int atMostMs) { readyFDs.clear(); mplexer.getAvailableFDs(readyFDs, atMostMs); - if (readyFDs.size() == 0) { + if (readyFDs.empty()) { /* timeout */ return true; } - while (true) { - /* we might have been notified several times, let's read - as much as possible before returning */ - char dummy = 0; - auto got = read(data.d_watchPipe.getHandle(), &dummy, sizeof(dummy)); - if (got == 0) { - break; - } - if (got > 0 && static_cast<size_t>(got) != sizeof(dummy)) { - continue; - } - if (got == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { - break; - } - } - + data.d_waiter.clear(); return false; } @@ -120,14 +81,17 @@ void AsynchronousHolder::stop() notify(); } +// NOLINTNEXTLINE(performance-unnecessary-value-param): this is a long-lived thread, and we want to make sure the reference count of the shared pointer has been increased void AsynchronousHolder::mainThread(std::shared_ptr<Data> data) { setThreadName("dnsdist/async"); - struct timeval now; + struct timeval now + { + }; std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>> expiredEvents; auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(1)); - mplexer->addReadFD(data->d_watchPipe.getHandle(), [](int, FDMultiplexer::funcparam_t&) {}); + mplexer->addReadFD(data->d_waiter.getDescriptor(), [](int, FDMultiplexer::funcparam_t&) {}); std::vector<int> readyFDs; while (true) { @@ -148,7 +112,7 @@ void AsynchronousHolder::mainThread(std::shared_ptr<Data> data) } else { auto remainingUsec = uSec(next - now); - timeout = std::round(remainingUsec / 1000.0); + timeout = static_cast<int>(std::round(static_cast<double>(remainingUsec) / 1000.0)); if (timeout == 0 && remainingUsec > 0) { /* if we have less than 1 ms, let's wait at least 1 ms */ timeout = 1; @@ -173,7 +137,8 @@ void AsynchronousHolder::mainThread(std::shared_ptr<Data> data) vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec); auto sender = query->getTCPQuerySender(); if (sender) { - sender->notifyIOError(std::move(query->query.d_idstate), now); + TCPResponse tresponse(std::move(query->query)); + sender->notifyIOError(now, std::move(tresponse)); } } else { @@ -213,25 +178,27 @@ std::unique_ptr<CrossProtocolQuery> AsynchronousHolder::get(uint16_t asyncID, ui { /* no need to notify, worst case the thread wakes up for nothing because this was the next TTD */ auto content = d_data->d_content.lock(); - auto it = content->find(std::tie(queryID, asyncID)); - if (it == content->end()) { - struct timeval now; + auto contentIt = content->find(std::tie(queryID, asyncID)); + if (contentIt == content->end()) { + struct timeval now + { + }; gettimeofday(&now, nullptr); vinfolog("Asynchronous object %d not found at %d.%d", queryID, now.tv_sec, now.tv_usec); return nullptr; } - auto result = std::move(it->d_query); - content->erase(it); + auto result = std::move(contentIt->d_query); + content->erase(contentIt); return result; } void AsynchronousHolder::pickupExpired(content_t& content, const struct timeval& now, std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>>& events) { auto& idx = content.get<TTDTag>(); - for (auto it = idx.begin(); it != idx.end() && it->d_ttd < now;) { - events.emplace_back(it->d_queryID, std::move(it->d_query)); - it = idx.erase(it); + for (auto contentIt = idx.begin(); contentIt != idx.end() && contentIt->d_ttd < now;) { + events.emplace_back(contentIt->d_queryID, std::move(contentIt->d_query)); + contentIt = idx.erase(contentIt); } } @@ -253,10 +220,10 @@ static bool resumeResponse(std::unique_ptr<CrossProtocolQuery>&& response) { try { auto& ids = response->query.d_idstate; - DNSResponse dr = response->getDR(); + DNSResponse dnsResponse = response->getDR(); LocalHolders holders; - auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dr, ids.cs->muted); + auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dnsResponse, ids.cs->muted); if (!result) { /* easy */ return true; @@ -264,7 +231,9 @@ static bool resumeResponse(std::unique_ptr<CrossProtocolQuery>&& response) auto sender = response->getTCPQuerySender(); if (sender) { - struct timeval now; + struct timeval now + { + }; gettimeofday(&now, nullptr); TCPResponse resp(std::move(response->query.d_buffer), std::move(response->query.d_idstate), nullptr, response->downstream); @@ -314,44 +283,45 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query) return resumeResponse(std::move(query)); } - auto& ids = query->query.d_idstate; - DNSQuestion dq = query->getDQ(); + DNSQuestion dnsQuestion = query->getDQ(); LocalHolders holders; - auto result = processQueryAfterRules(dq, holders, query->downstream); + auto result = processQueryAfterRules(dnsQuestion, holders, query->downstream); if (result == ProcessQueryResult::Drop) { /* easy */ return true; } - else if (result == ProcessQueryResult::PassToBackend) { + if (result == ProcessQueryResult::PassToBackend) { if (query->downstream == nullptr) { return false; } #ifdef HAVE_DNS_OVER_HTTPS - if (dq.ids.du != nullptr) { - dq.ids.du->downstream = query->downstream; + if (dnsQuestion.ids.du != nullptr) { + dnsQuestion.ids.du->downstream = query->downstream; } #endif - if (query->downstream->isTCPOnly() || !(dq.getProtocol().isUDP() || dq.getProtocol() == dnsdist::Protocol::DoH)) { + if (query->downstream->isTCPOnly() || !(dnsQuestion.getProtocol().isUDP() || dnsQuestion.getProtocol() == dnsdist::Protocol::DoH)) { query->downstream->passCrossProtocolQuery(std::move(query)); return true; } - auto queryID = dq.getHeader()->id; + auto queryID = dnsQuestion.getHeader()->id; /* at this point 'du', if it is not nullptr, is owned by the DoHCrossProtocolQuery which will stop existing when we return, so we need to increment the reference count */ - return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dq, query->query.d_buffer, ids.origDest); + return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dnsQuestion, query->query.d_buffer); } - else if (result == ProcessQueryResult::SendAnswer) { + if (result == ProcessQueryResult::SendAnswer) { auto sender = query->getTCPQuerySender(); if (!sender) { return false; } - struct timeval now; + struct timeval now + { + }; gettimeofday(&now, nullptr); TCPResponse response(std::move(query->query.d_buffer), std::move(query->query.d_idstate), nullptr, query->downstream); @@ -367,7 +337,7 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query) return false; } } - else if (result == ProcessQueryResult::Asynchronous) { + if (result == ProcessQueryResult::Asynchronous) { /* nope */ errlog("processQueryAfterRules returned 'asynchronous' while trying to resume an already asynchronous query"); return false; @@ -376,43 +346,47 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query) return false; } -bool suspendQuery(DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +bool suspendQuery(DNSQuestion& dnsQuestion, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) { if (!g_asyncHolder) { return false; } - struct timeval now; + struct timeval now + { + }; gettimeofday(&now, nullptr); struct timeval ttd = now; ttd.tv_sec += timeoutMs / 1000; - ttd.tv_usec += (timeoutMs % 1000) * 1000; + ttd.tv_usec += static_cast<decltype(ttd.tv_usec)>((timeoutMs % 1000) * 1000); normalizeTV(ttd); vinfolog("Suspending asynchronous query %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec); - auto query = getInternalQueryFromDQ(dq, false); + auto query = getInternalQueryFromDQ(dnsQuestion, false); g_asyncHolder->push(asyncID, queryID, ttd, std::move(query)); return true; } -bool suspendResponse(DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +bool suspendResponse(DNSResponse& dnsResponse, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) { if (!g_asyncHolder) { return false; } - struct timeval now; + struct timeval now + { + }; gettimeofday(&now, nullptr); struct timeval ttd = now; ttd.tv_sec += timeoutMs / 1000; - ttd.tv_usec += (timeoutMs % 1000) * 1000; + ttd.tv_usec += static_cast<decltype(ttd.tv_usec)>((timeoutMs % 1000) * 1000); normalizeTV(ttd); vinfolog("Suspending asynchronous response %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec); - auto query = getInternalQueryFromDQ(dr, true); + auto query = getInternalQueryFromDQ(dnsResponse, true); query->d_isResponse = true; - query->downstream = dr.d_downstream; + query->downstream = dnsResponse.d_downstream; g_asyncHolder->push(asyncID, queryID, ttd, std::move(query)); return true; |