diff options
Diffstat (limited to '')
-rw-r--r-- | dnsdist-async.cc | 422 |
1 files changed, 422 insertions, 0 deletions
diff --git a/dnsdist-async.cc b/dnsdist-async.cc new file mode 100644 index 0000000..e1acef8 --- /dev/null +++ b/dnsdist-async.cc @@ -0,0 +1,422 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +#include "dnsdist-async.hh" +#include "dnsdist-internal-queries.hh" +#include "dolog.hh" +#include "threadname.hh" + +namespace dnsdist +{ + +AsynchronousHolder::AsynchronousHolder(bool failOpen) : + d_data(std::make_shared<Data>()) +{ + d_data->d_failOpen = failOpen; + + int fds[2] = {-1, -1}; + if (pipe(fds) < 0) { + throw std::runtime_error("Error creating the AsynchronousHolder pipe: " + stringerror()); + } + + for (size_t idx = 0; idx < (sizeof(fds) / sizeof(*fds)); idx++) { + if (!setNonBlocking(fds[idx])) { + int err = errno; + close(fds[0]); + close(fds[1]); + throw std::runtime_error("Error setting the AsynchronousHolder pipe non-blocking: " + stringerror(err)); + } + } + + d_data->d_notifyPipe = FDWrapper(fds[1]); + d_data->d_watchPipe = FDWrapper(fds[0]); + + std::thread main([data = this->d_data] { mainThread(data); }); + main.detach(); +} + +AsynchronousHolder::~AsynchronousHolder() +{ + try { + stop(); + } + catch (...) { + } +} + +bool AsynchronousHolder::notify() const +{ + const char data = 0; + bool failed = false; + do { + auto written = write(d_data->d_notifyPipe.getHandle(), &data, sizeof(data)); + if (written == 0) { + break; + } + if (written > 0 && static_cast<size_t>(written) == sizeof(data)) { + return true; + } + if (errno != EINTR) { + failed = true; + } + } while (!failed); + + return false; +} + +bool AsynchronousHolder::wait(const AsynchronousHolder::Data& data, FDMultiplexer& mplexer, std::vector<int>& readyFDs, int atMostMs) +{ + readyFDs.clear(); + mplexer.getAvailableFDs(readyFDs, atMostMs); + if (readyFDs.size() == 0) { + /* timeout */ + return true; + } + + while (true) { + /* we might have been notified several times, let's read + as much as possible before returning */ + char dummy = 0; + auto got = read(data.d_watchPipe.getHandle(), &dummy, sizeof(dummy)); + if (got == 0) { + break; + } + if (got > 0 && static_cast<size_t>(got) != sizeof(dummy)) { + continue; + } + if (got == -1 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + break; + } + } + + return false; +} + +void AsynchronousHolder::stop() +{ + { + auto content = d_data->d_content.lock(); + d_data->d_done = true; + } + + notify(); +} + +void AsynchronousHolder::mainThread(std::shared_ptr<Data> data) +{ + setThreadName("dnsdist/async"); + struct timeval now; + std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>> expiredEvents; + + auto mplexer = std::unique_ptr<FDMultiplexer>(FDMultiplexer::getMultiplexerSilent(1)); + mplexer->addReadFD(data->d_watchPipe.getHandle(), [](int, FDMultiplexer::funcparam_t&) {}); + std::vector<int> readyFDs; + + while (true) { + bool shouldWait = true; + int timeout = -1; + { + auto content = data->d_content.lock(); + if (data->d_done) { + return; + } + + if (!content->empty()) { + gettimeofday(&now, nullptr); + struct timeval next = getNextTTD(*content); + if (next <= now) { + pickupExpired(*content, now, expiredEvents); + shouldWait = false; + } + else { + auto remainingUsec = uSec(next - now); + timeout = std::round(remainingUsec / 1000.0); + if (timeout == 0 && remainingUsec > 0) { + /* if we have less than 1 ms, let's wait at least 1 ms */ + timeout = 1; + } + } + } + } + + if (shouldWait) { + auto timedOut = wait(*data, *mplexer, readyFDs, timeout); + if (timedOut) { + auto content = data->d_content.lock(); + gettimeofday(&now, nullptr); + pickupExpired(*content, now, expiredEvents); + } + } + + while (!expiredEvents.empty()) { + auto [queryID, query] = std::move(expiredEvents.front()); + expiredEvents.pop_front(); + if (!data->d_failOpen) { + vinfolog("Asynchronous query %d has expired at %d.%d, notifying the sender", queryID, now.tv_sec, now.tv_usec); + auto sender = query->getTCPQuerySender(); + if (sender) { + sender->notifyIOError(std::move(query->query.d_idstate), now); + } + } + else { + vinfolog("Asynchronous query %d has expired at %d.%d, resuming", queryID, now.tv_sec, now.tv_usec); + resumeQuery(std::move(query)); + } + } + } +} + +void AsynchronousHolder::push(uint16_t asyncID, uint16_t queryID, const struct timeval& ttd, std::unique_ptr<CrossProtocolQuery>&& query) +{ + bool needNotify = false; + { + auto content = d_data->d_content.lock(); + if (!content->empty()) { + /* the thread is already waiting on a TTD expiry in addition to notifications, + let's not wake it unless our TTD comes before the current one */ + const struct timeval next = getNextTTD(*content); + if (ttd < next) { + needNotify = true; + } + } + else { + /* the thread is currently only waiting for a notify */ + needNotify = true; + } + content->insert({std::move(query), ttd, asyncID, queryID}); + } + + if (needNotify) { + notify(); + } +} + +std::unique_ptr<CrossProtocolQuery> AsynchronousHolder::get(uint16_t asyncID, uint16_t queryID) +{ + /* no need to notify, worst case the thread wakes up for nothing because this was the next TTD */ + auto content = d_data->d_content.lock(); + auto it = content->find(std::tie(queryID, asyncID)); + if (it == content->end()) { + struct timeval now; + gettimeofday(&now, nullptr); + vinfolog("Asynchronous object %d not found at %d.%d", queryID, now.tv_sec, now.tv_usec); + return nullptr; + } + + auto result = std::move(it->d_query); + content->erase(it); + return result; +} + +void AsynchronousHolder::pickupExpired(content_t& content, const struct timeval& now, std::list<std::pair<uint16_t, std::unique_ptr<CrossProtocolQuery>>>& events) +{ + auto& idx = content.get<TTDTag>(); + for (auto it = idx.begin(); it != idx.end() && it->d_ttd < now;) { + events.emplace_back(it->d_queryID, std::move(it->d_query)); + it = idx.erase(it); + } +} + +struct timeval AsynchronousHolder::getNextTTD(const content_t& content) +{ + if (content.empty()) { + throw std::runtime_error("AsynchronousHolder::getNextTTD() called on an empty holder"); + } + + return content.get<TTDTag>().begin()->d_ttd; +} + +bool AsynchronousHolder::empty() +{ + return d_data->d_content.read_only_lock()->empty(); +} + +static bool resumeResponse(std::unique_ptr<CrossProtocolQuery>&& response) +{ + try { + auto& ids = response->query.d_idstate; + DNSResponse dr = response->getDR(); + + LocalHolders holders; + auto result = processResponseAfterRules(response->query.d_buffer, *holders.cacheInsertedRespRuleActions, dr, ids.cs->muted); + if (!result) { + /* easy */ + return true; + } + + auto sender = response->getTCPQuerySender(); + if (sender) { + struct timeval now; + gettimeofday(&now, nullptr); + + TCPResponse resp(std::move(response->query.d_buffer), std::move(response->query.d_idstate), nullptr, response->downstream); + resp.d_async = true; + sender->handleResponse(now, std::move(resp)); + } + } + catch (const std::exception& e) { + vinfolog("Got exception while resuming cross-protocol response: %s", e.what()); + return false; + } + + return true; +} + +static LockGuarded<std::deque<std::unique_ptr<CrossProtocolQuery>>> s_asynchronousEventsQueue; + +bool queueQueryResumptionEvent(std::unique_ptr<CrossProtocolQuery>&& query) +{ + s_asynchronousEventsQueue.lock()->push_back(std::move(query)); + return true; +} + +void handleQueuedAsynchronousEvents() +{ + while (true) { + std::unique_ptr<CrossProtocolQuery> query; + { + // we do not want to hold the lock while resuming + auto queue = s_asynchronousEventsQueue.lock(); + if (queue->empty()) { + return; + } + + query = std::move(queue->front()); + queue->pop_front(); + } + if (query && !resumeQuery(std::move(query))) { + vinfolog("Unable to resume asynchronous query event"); + } + } +} + +bool resumeQuery(std::unique_ptr<CrossProtocolQuery>&& query) +{ + if (query->d_isResponse) { + return resumeResponse(std::move(query)); + } + + auto& ids = query->query.d_idstate; + DNSQuestion dq = query->getDQ(); + LocalHolders holders; + + auto result = processQueryAfterRules(dq, holders, query->downstream); + if (result == ProcessQueryResult::Drop) { + /* easy */ + return true; + } + else if (result == ProcessQueryResult::PassToBackend) { + if (query->downstream == nullptr) { + return false; + } + +#ifdef HAVE_DNS_OVER_HTTPS + if (dq.ids.du != nullptr) { + dq.ids.du->downstream = query->downstream; + } +#endif + + if (query->downstream->isTCPOnly() || !(dq.getProtocol().isUDP() || dq.getProtocol() == dnsdist::Protocol::DoH)) { + query->downstream->passCrossProtocolQuery(std::move(query)); + return true; + } + + auto queryID = dq.getHeader()->id; + /* at this point 'du', if it is not nullptr, is owned by the DoHCrossProtocolQuery + which will stop existing when we return, so we need to increment the reference count + */ + return assignOutgoingUDPQueryToBackend(query->downstream, queryID, dq, query->query.d_buffer, ids.origDest); + } + else if (result == ProcessQueryResult::SendAnswer) { + auto sender = query->getTCPQuerySender(); + if (!sender) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + + TCPResponse response(std::move(query->query.d_buffer), std::move(query->query.d_idstate), nullptr, query->downstream); + response.d_async = true; + response.d_idstate.selfGenerated = true; + + try { + sender->handleResponse(now, std::move(response)); + return true; + } + catch (const std::exception& e) { + vinfolog("Got exception while resuming cross-protocol self-answered query: %s", e.what()); + return false; + } + } + else if (result == ProcessQueryResult::Asynchronous) { + /* nope */ + errlog("processQueryAfterRules returned 'asynchronous' while trying to resume an already asynchronous query"); + return false; + } + + return false; +} + +bool suspendQuery(DNSQuestion& dq, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +{ + if (!g_asyncHolder) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + struct timeval ttd = now; + ttd.tv_sec += timeoutMs / 1000; + ttd.tv_usec += (timeoutMs % 1000) * 1000; + normalizeTV(ttd); + + vinfolog("Suspending asynchronous query %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec); + auto query = getInternalQueryFromDQ(dq, false); + + g_asyncHolder->push(asyncID, queryID, ttd, std::move(query)); + return true; +} + +bool suspendResponse(DNSResponse& dr, uint16_t asyncID, uint16_t queryID, uint32_t timeoutMs) +{ + if (!g_asyncHolder) { + return false; + } + + struct timeval now; + gettimeofday(&now, nullptr); + struct timeval ttd = now; + ttd.tv_sec += timeoutMs / 1000; + ttd.tv_usec += (timeoutMs % 1000) * 1000; + normalizeTV(ttd); + + vinfolog("Suspending asynchronous response %d at %d.%d until %d.%d", queryID, now.tv_sec, now.tv_usec, ttd.tv_sec, ttd.tv_usec); + auto query = getInternalQueryFromDQ(dr, true); + query->d_isResponse = true; + query->downstream = dr.d_downstream; + + g_asyncHolder->push(asyncID, queryID, ttd, std::move(query)); + return true; +} + +std::unique_ptr<AsynchronousHolder> g_asyncHolder; +} |