summaryrefslogtreecommitdiffstats
path: root/dnsdist-async.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dnsdist-async.cc')
-rw-r--r--dnsdist-async.cc154
1 files changed, 64 insertions, 90 deletions
diff --git a/dnsdist-async.cc b/dnsdist-async.cc
index e1acef8..9cb96d8 100644
--- a/dnsdist-async.cc
+++ b/dnsdist-async.cc
@@ -27,28 +27,19 @@
namespace dnsdist
{
-AsynchronousHolder::AsynchronousHolder(bool failOpen) :
- d_data(std::make_shared<Data>())
+AsynchronousHolder::Data::Data(bool failOpen) :
+ d_failOpen(failOpen)
{
- d_data->d_failOpen = failOpen;
-
- int fds[2] = {-1, -1};
- if (pipe(fds) < 0) {
- throw std::runtime_error("Error creating the AsynchronousHolder pipe: " + stringerror());
- }
-
- for (size_t idx = 0; idx < (sizeof(fds) / sizeof(*fds)); idx++) {
- if (!setNonBlocking(fds[idx])) {
- int err = errno;
- close(fds[0]);
- close(fds[1]);
- throw std::runtime_error("Error setting the AsynchronousHolder pipe non-blocking: " + stringerror(err));
- }
- }
-
- d_data->d_notifyPipe = FDWrapper(fds[1]);
- d_data->d_watchPipe = FDWrapper(fds[0]);
+ auto [notifier, waiter] = pdns::channel::createNotificationQueue(true);
+ // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): how I am supposed to do that?
+ d_waiter = std::move(waiter);
+ // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer): how I am supposed to do that?
+ d_notifier = std::move(notifier);
+}
+AsynchronousHolder::AsynchronousHolder(bool failOpen) :
+ d_data(std::make_shared<Data>(failOpen))
+{
std::thread main([data = this->d_data] { mainThread(data); });
main.detach();
}
@@ -64,49 +55,19 @@ AsynchronousHolder::~AsynchronousHolder()
bool AsynchronousHolder::notify() const
{
- const char data = 0;
- bool failed = false;
- do {
- auto written = write(d_data->d_notifyPipe.getHandle(), &data, sizeof(data));
- if (written == 0) {
- break;
- }
- if (written > 0 && static_cast<size_t>(written) == sizeof(data)) {
- return true;
- }
- if (errno != EINTR) {
- failed = true;
- }
- } while (!failed);
-
- return false;
+ return d_data->d_notifier.notify();
}
-bool AsynchronousHolder::wait(const AsynchronousHolder::Data& data, FDMultiplexer& mplexer, std::vector<int>& readyFDs, int atMostMs)
+bool AsynchronousHolder::wait(AsynchronousHolder::Data& data, FDMultiplexer& mplexer, std::vector<int>& readyFDs, int atMostMs)
{
readyFDs.clear();
mplexer.getAvailableFDs(readyFDs, atMostMs);
- if (readyFDs.size() == 0) {
+ if (readyFDs.empty()) {
/* timeout */
return true;
}
- while (true) {
- /* we might have been notified several times, let's read
- as much as possible before returning */
- char dummy = 0;
- auto got = read(data.d_watchPipe.getHandle(), &dummy, sizeof(dummy));
- if (got == 0) {
- break;
- }
- if (got > 0 && static_cast<size_t>(got) != sizeof(dummy)) {
- continue;
- }
- if (got == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
- break;
- }
- }
-
+ data.d_waiter.clear();
return false;
}
@@ -120,14 +81,17 @@ void AsynchronousHolder::stop()
notify();
}
+// NOLINTNEXTLINE(performance-unnecessary-value-param): this is a long-lived thread, and we want to make sure the reference count of the shared pointer has been increased
void AsynchronousHolder::mainThread(std::shared_ptr<Data> data)
{
setThreadName("dnsdist/async");
- struct timeval now;
+ struct timeval now
+ {
+ };
std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>> expiredEvents;
auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(1));
- mplexer->addReadFD(data->d_watchPipe.getHandle(), [](int, FDMultiplexer::funcparam_t&) {});
+ mplexer->addReadFD(data->d_waiter.getDescriptor(), [](int, FDMultiplexer::funcparam_t&) {});
std::vector<int> readyFDs;
while (true) {
@@ -148,7 +112,7 @@ void AsynchronousHolder::mainThread(std::shared_ptr<Data> data)
}
else {
auto remainingUsec = uSec(next - now);
- timeout = std::round(remainingUsec / 1000.0);
+ timeout = static_cast<int>(std::round(static_cast<double>(remainingUsec) / 1000.0));
if (timeout == 0 && remainingUsec > 0) {
/* if we have less than 1 ms, let's wait at least 1 ms */
timeout = 1;
@@ -173,7 +137,8 @@ void AsynchronousHolder::mainThread(std::shared_ptr<Data> data)
vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec);
auto sender = query->getTCPQuerySender();
if (sender) {
- sender->notifyIOError(std::move(query->query.d_idstate), now);
+ TCPResponse tresponse(std::move(query->query));
+ sender->notifyIOError(now, std::move(tresponse));
}
}
else {
@@ -213,25 +178,27 @@ std::unique_ptr<CrossProtocolQuery> AsynchronousHolder::get(uint16_t asyncID, ui
{
/* no need to notify, worst case the thread wakes up for nothing because this was the next TTD */
auto content = d_data->d_content.lock();
- auto it = content->find(std::tie(queryID, asyncID));
- if (it == content->end()) {
- struct timeval now;
+ auto contentIt = content->find(std::tie(queryID, asyncID));
+ if (contentIt == content->end()) {
+ struct timeval now
+ {
+ };
gettimeofday(&now, nullptr);
vinfolog("Asynchronous object %d not found at %d.%d", queryID, now.tv_sec, now.tv_usec);
return nullptr;
}
- auto result = std::move(it->d_query);
- content->erase(it);
+ auto result = std::move(contentIt->d_query);
+ content->erase(contentIt);
return result;
}
void AsynchronousHolder::pickupExpired(content_t& content, const struct timeval& now, std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>>& events)
{
auto& idx = content.get<TTDTag>();
- for (auto it = idx.begin(); it != idx.end() && it->d_ttd < now;) {
- events.emplace_back(it->d_queryID, std::move(it->d_query));
- it = idx.erase(it);
+ for (auto contentIt = idx.begin(); contentIt != idx.end() && contentIt->d_ttd < now;) {
+ events.emplace_back(contentIt->d_queryID, std::move(contentIt->d_query));
+ contentIt = idx.erase(contentIt);
}
}
@@ -253,10 +220,10 @@ static bool resumeResponse(std::unique_ptr<CrossProtocolQuery>&& response)
{
try {
auto& ids = response->query.d_idstate;
- DNSResponse dr = response->getDR();
+ DNSResponse dnsResponse = response->getDR();
LocalHolders holders;
- auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dr, ids.cs->muted);
+ auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dnsResponse, ids.cs->muted);
if (!result) {
/* easy */
return true;
@@ -264,7 +231,9 @@ static bool resumeResponse(std::unique_ptr<CrossProtocolQuery>&& response)
auto sender = response->getTCPQuerySender();
if (sender) {
- struct timeval now;
+ struct timeval now
+ {
+ };
gettimeofday(&now, nullptr);
TCPResponse resp(std::move(response->query.d_buffer), std::move(response->query.d_idstate), nullptr, response->downstream);
@@ -314,44 +283,45 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query)
return resumeResponse(std::move(query));
}
- auto& ids = query->query.d_idstate;
- DNSQuestion dq = query->getDQ();
+ DNSQuestion dnsQuestion = query->getDQ();
LocalHolders holders;
- auto result = processQueryAfterRules(dq, holders, query->downstream);
+ auto result = processQueryAfterRules(dnsQuestion, holders, query->downstream);
if (result == ProcessQueryResult::Drop) {
/* easy */
return true;
}
- else if (result == ProcessQueryResult::PassToBackend) {
+ if (result == ProcessQueryResult::PassToBackend) {
if (query->downstream == nullptr) {
return false;
}
#ifdef HAVE_DNS_OVER_HTTPS
- if (dq.ids.du != nullptr) {
- dq.ids.du->downstream = query->downstream;
+ if (dnsQuestion.ids.du != nullptr) {
+ dnsQuestion.ids.du->downstream = query->downstream;
}
#endif
- if (query->downstream->isTCPOnly() || !(dq.getProtocol().isUDP() || dq.getProtocol() == dnsdist::Protocol::DoH)) {
+ if (query->downstream->isTCPOnly() || !(dnsQuestion.getProtocol().isUDP() || dnsQuestion.getProtocol() == dnsdist::Protocol::DoH)) {
query->downstream->passCrossProtocolQuery(std::move(query));
return true;
}
- auto queryID = dq.getHeader()->id;
+ auto queryID = dnsQuestion.getHeader()->id;
/* at this point 'du', if it is not nullptr, is owned by the DoHCrossProtocolQuery
which will stop existing when we return, so we need to increment the reference count
*/
- return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dq, query->query.d_buffer, ids.origDest);
+ return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dnsQuestion, query->query.d_buffer);
}
- else if (result == ProcessQueryResult::SendAnswer) {
+ if (result == ProcessQueryResult::SendAnswer) {
auto sender = query->getTCPQuerySender();
if (!sender) {
return false;
}
- struct timeval now;
+ struct timeval now
+ {
+ };
gettimeofday(&now, nullptr);
TCPResponse response(std::move(query->query.d_buffer), std::move(query->query.d_idstate), nullptr, query->downstream);
@@ -367,7 +337,7 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query)
return false;
}
}
- else if (result == ProcessQueryResult::Asynchronous) {
+ if (result == ProcessQueryResult::Asynchronous) {
/* nope */
errlog("processQueryAfterRules returned 'asynchronous' while trying to resume an already asynchronous query");
return false;
@@ -376,43 +346,47 @@ bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query)
return false;
}
-bool suspendQuery(DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
+bool suspendQuery(DNSQuestion& dnsQuestion, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
{
if (!g_asyncHolder) {
return false;
}
- struct timeval now;
+ struct timeval now
+ {
+ };
gettimeofday(&now, nullptr);
struct timeval ttd = now;
ttd.tv_sec += timeoutMs / 1000;
- ttd.tv_usec += (timeoutMs % 1000) * 1000;
+ ttd.tv_usec += static_cast<decltype(ttd.tv_usec)>((timeoutMs % 1000) * 1000);
normalizeTV(ttd);
vinfolog("Suspending asynchronous query %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec);
- auto query = getInternalQueryFromDQ(dq, false);
+ auto query = getInternalQueryFromDQ(dnsQuestion, false);
g_asyncHolder->push(asyncID, queryID, ttd, std::move(query));
return true;
}
-bool suspendResponse(DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
+bool suspendResponse(DNSResponse& dnsResponse, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs)
{
if (!g_asyncHolder) {
return false;
}
- struct timeval now;
+ struct timeval now
+ {
+ };
gettimeofday(&now, nullptr);
struct timeval ttd = now;
ttd.tv_sec += timeoutMs / 1000;
- ttd.tv_usec += (timeoutMs % 1000) * 1000;
+ ttd.tv_usec += static_cast<decltype(ttd.tv_usec)>((timeoutMs % 1000) * 1000);
normalizeTV(ttd);
vinfolog("Suspending asynchronous response %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec);
- auto query = getInternalQueryFromDQ(dr, true);
+ auto query = getInternalQueryFromDQ(dnsResponse, true);
query->d_isResponse = true;
- query->downstream = dr.d_downstream;
+ query->downstream = dnsResponse.d_downstream;
g_asyncHolder->push(asyncID, queryID, ttd, std::move(query));
return true;