summaryrefslogtreecommitdiffstats
path: root/dnsdist-cache.cc
diff options
context:
space:
mode:
Diffstat (limited to 'dnsdist-cache.cc')
-rw-r--r--dnsdist-cache.cc622
1 files changed, 622 insertions, 0 deletions
diff --git a/dnsdist-cache.cc b/dnsdist-cache.cc
new file mode 100644
index 0000000..7ca9be2
--- /dev/null
+++ b/dnsdist-cache.cc
@@ -0,0 +1,622 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#include <cinttypes>
+
+#include "dnsdist.hh"
+#include "dolog.hh"
+#include "dnsparser.hh"
+#include "dnsdist-cache.hh"
+#include "dnsdist-ecs.hh"
+#include "ednssubnet.hh"
+#include "packetcache.hh"
+
+/* Build a packet cache holding at most maxEntries responses, spread over
+   `shards` shards, each of them protected by its own lock.
+   Throws std::runtime_error if either maxEntries or shards is zero. */
+DNSDistPacketCache::DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL, uint32_t minTTL, uint32_t tempFailureTTL, uint32_t maxNegativeTTL, uint32_t staleTTL, bool dontAge, uint32_t shards, bool deferrableInsertLock, bool parseECS): d_maxEntries(maxEntries), d_shardCount(shards), d_maxTTL(maxTTL), d_tempFailureTTL(tempFailureTTL), d_maxNegativeTTL(maxNegativeTTL), d_minTTL(minTTL), d_staleTTL(staleTTL), d_dontAge(dontAge), d_deferrableInsertLock(deferrableInsertLock), d_parseECS(parseECS)
+{
+  if (d_maxEntries == 0) {
+    throw std::runtime_error("Trying to create a 0-sized packet-cache");
+  }
+
+  /* the per-shard capacity (d_maxEntries / d_shardCount) and the shard
+     selection (key % d_shardCount) both divide by the number of shards,
+     so a cache without any shard can never work */
+  if (d_shardCount == 0) {
+    throw std::runtime_error("Trying to create a packet-cache with 0 shards");
+  }
+
+  d_shards.resize(d_shardCount);
+
+  /* we reserve maxEntries + 1 to avoid rehashing from occurring
+     when we get to maxEntries, as it means a load factor of 1 */
+  for (auto& shard : d_shards) {
+    shard.setSize((maxEntries / d_shardCount) + 1);
+  }
+}
+
+/* Extract the EDNS Client Subnet carried by `packet`, if any.
+   qnameWireLength is the length of the qname as passed to getEDNSOptionsStart().
+   Returns true and sets `subnet` when an ECS option with a non-empty payload
+   was found and parsed, false (leaving `subnet` untouched) otherwise. */
+bool DNSDistPacketCache::getClientSubnet(const PacketBuffer& packet, size_t qnameWireLength, boost::optional<Netmask>& subnet)
+{
+  /* every EDNS option starts with its code followed by its length */
+  const size_t optionHeaderSize = EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE;
+
+  uint16_t optionsStart = 0;
+  size_t optionsLength = 0;
+  if (getEDNSOptionsStart(packet, qnameWireLength, &optionsStart, &optionsLength) != 0) {
+    return false;
+  }
+
+  size_t ecsOffset = 0;
+  size_t ecsLength = 0;
+  const int lookup = getEDNSOption(reinterpret_cast<const char*>(&packet.at(optionsStart)), optionsLength, EDNSOptionCode::ECS, &ecsOffset, &ecsLength);
+  if (lookup != 0 || ecsLength <= optionHeaderSize) {
+    /* no ECS option, or one without any payload */
+    return false;
+  }
+
+  EDNSSubnetOpts parsed;
+  const char* payload = reinterpret_cast<const char*>(&packet.at(optionsStart + ecsOffset + optionHeaderSize));
+  if (!getEDNSSubnetOptsFromString(payload, ecsLength - optionHeaderSize, &parsed)) {
+    return false;
+  }
+
+  subnet = parsed.source;
+  return true;
+}
+
+/* Tell whether a cached entry was stored for the very same question:
+   same flags, DNSSEC OK bit, transport, type, class and name, plus the same
+   client subnet when ECS parsing is enabled.
+   Used to detect two different queries sharing the same cache key. */
+bool DNSDistPacketCache::cachedValueMatches(const CacheValue& cachedValue, uint16_t queryFlags, const DNSName& qname, uint16_t qtype, uint16_t qclass, bool receivedOverUDP, bool dnssecOK, const boost::optional<Netmask>& subnet) const
+{
+  /* cheap scalar comparisons first, the name comparison last */
+  const bool sameQuestion = cachedValue.queryFlags == queryFlags
+    && cachedValue.dnssecOK == dnssecOK
+    && cachedValue.receivedOverUDP == receivedOverUDP
+    && cachedValue.qtype == qtype
+    && cachedValue.qclass == qclass
+    && cachedValue.qname == qname;
+
+  if (!sameQuestion) {
+    return false;
+  }
+
+  /* the subnet only matters when we have been asked to parse ECS */
+  const bool subnetDiffers = d_parseECS && cachedValue.subnet != subnet;
+  return !subnetDiffers;
+}
+
+/* Store newValue under `key` in `map`, which the caller has already locked
+   for writing and which belongs to `shard`.
+   Nothing is stored when the shard is full. When the key is already present,
+   the existing entry is only replaced if the new one is valid for longer and
+   the existing one has either expired or was stored for the same question. */
+void DNSDistPacketCache::insertLocked(CacheShard& shard, std::unordered_map<uint32_t,CacheValue>& map, uint32_t key, CacheValue& newValue)
+{
+  /* check again now that we hold the lock to prevent a race */
+  const auto perShardLimit = d_maxEntries / d_shardCount;
+  if (map.size() >= perShardLimit) {
+    return;
+  }
+
+  const auto insertion = map.insert({key, newValue});
+  if (insertion.second) {
+    /* brand new key */
+    ++shard.d_entriesCount;
+    return;
+  }
+
+  /* the key is already in use: don't override the existing entry
+     except if it has expired */
+  CacheValue& existing = insertion.first->second;
+  const bool existingExpired = existing.validity <= newValue.added;
+
+  if (!existingExpired) {
+    const bool sameQuestion = cachedValueMatches(existing, newValue.queryFlags, newValue.qname, newValue.qtype, newValue.qclass, newValue.receivedOverUDP, newValue.dnssecOK, newValue.subnet);
+    if (!sameQuestion) {
+      ++d_insertCollisions;
+      return;
+    }
+  }
+
+  /* only replace the existing entry if the new one lives longer */
+  if (newValue.validity > existing.validity) {
+    existing = newValue;
+  }
+}
+
+/* Try to cache `response` under `key`.
+   The response is silently dropped when it is smaller than a DNS header, is
+   for an AXFR or IXFR query, has no usable TTL, has a TTL below the
+   configured minimum, or when the target shard is already full.
+   ServFail and Refused responses are cached for tempFailureTTL seconds if
+   set, for the cache-wide default otherwise, and not at all if that is 0.
+   If the insert lock is deferrable and already held, the insertion is
+   skipped and counted as deferred. */
+void DNSDistPacketCache::insert(uint32_t key, const boost::optional<Netmask>& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const PacketBuffer& response, bool receivedOverUDP, uint8_t rcode, boost::optional<uint32_t> tempFailureTTL)
+{
+  if (response.size() < sizeof(dnsheader) || qtype == QType::AXFR || qtype == QType::IXFR) {
+    return;
+  }
+
+  uint32_t ttl = 0;
+  const bool temporaryFailure = rcode == RCode::ServFail || rcode == RCode::Refused;
+
+  if (temporaryFailure) {
+    ttl = tempFailureTTL ? *tempFailureTTL : d_tempFailureTTL;
+    if (ttl == 0) {
+      return;
+    }
+  }
+  else {
+    bool seenAuthSOA = false;
+    ttl = getMinTTL(reinterpret_cast<const char*>(response.data()), response.size(), &seenAuthSOA);
+
+    /* no TTL found, we don't want to cache this */
+    if (ttl == std::numeric_limits<uint32_t>::max()) {
+      return;
+    }
+
+    const bool negativeAnswer = rcode == RCode::NXDomain || (rcode == RCode::NoError && seenAuthSOA);
+    if (negativeAnswer) {
+      ttl = std::min(ttl, d_maxNegativeTTL);
+    }
+    else if (ttl > d_maxTTL) {
+      ttl = d_maxTTL;
+    }
+
+    if (ttl < d_minTTL) {
+      ++d_ttlTooShorts;
+      return;
+    }
+  }
+
+  auto& shard = d_shards.at(getShardIndex(key));
+
+  /* unlocked check, done again under the lock in insertLocked() */
+  if (shard.d_entriesCount >= (d_maxEntries / d_shardCount)) {
+    return;
+  }
+
+  const time_t now = time(nullptr);
+
+  CacheValue entry;
+  entry.added = now;
+  entry.validity = now + ttl;
+  entry.qname = qname;
+  entry.qtype = qtype;
+  entry.qclass = qclass;
+  entry.queryFlags = queryFlags;
+  entry.dnssecOK = dnssecOK;
+  entry.receivedOverUDP = receivedOverUDP;
+  entry.subnet = subnet;
+  entry.len = response.size();
+  entry.value = std::string(response.begin(), response.end());
+
+  if (!d_deferrableInsertLock) {
+    auto locked = shard.d_map.write_lock();
+    insertLocked(shard, *locked, key, entry);
+    return;
+  }
+
+  auto locked = shard.d_map.try_write_lock();
+  if (!locked.owns_lock()) {
+    /* someone else holds the lock, give up instead of waiting */
+    ++d_deferredInserts;
+    return;
+  }
+
+  insertLocked(shard, *locked, key, entry);
+}
+
+/* Look up a cached response for the query held in dq.
+
+   On a hit, the packet buffer of dq is overwritten with the cached response,
+   using queryId as its ID and the qname taken from dq.ids.qname, and true is
+   returned. Otherwise false is returned and the packet buffer still holds
+   the query.
+
+   keyOut, when not null, receives the computed cache key (it is not set for
+   AXFR and IXFR queries, which are never looked up).
+   subnet is filled from the query's ECS option when ECS parsing is enabled.
+   allowExpired is the number of seconds after its expiration during which an
+   entry can still be served, as a stale one; 0 means expired entries are
+   never served.
+   skipAging leaves the TTLs of the cached response as they were stored.
+   truncatedOK allows a cached response with the TC bit set to be returned.
+   recordMiss tells whether a missing or expired entry increments the miss
+   counter.
+
+   The shard is only try-locked: if the lock cannot be taken right away, the
+   lookup is counted as deferred and false is returned. */
+bool DNSDistPacketCache::get(DNSQuestion& dq, uint16_t queryId, uint32_t* keyOut, boost::optional<Netmask>& subnet, bool dnssecOK, bool receivedOverUDP, uint32_t allowExpired, bool skipAging, bool truncatedOK, bool recordMiss)
+{
+  if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) {
+    ++d_misses;
+    return false;
+  }
+
+  const auto& dnsQName = dq.ids.qname.getStorage();
+  uint32_t key = getKey(dnsQName, dq.ids.qname.wirelength(), dq.getData(), receivedOverUDP);
+
+  if (keyOut) {
+    *keyOut = key;
+  }
+
+  if (d_parseECS) {
+    getClientSubnet(dq.getData(), dq.ids.qname.wirelength(), subnet);
+  }
+
+  uint32_t shardIndex = getShardIndex(key);
+  time_t now = time(nullptr);
+  time_t age = 0;
+  bool stale = false;
+  auto& response = dq.getMutableData();
+  auto& shard = d_shards.at(shardIndex);
+  {
+    auto map = shard.d_map.try_read_lock();
+    if (!map.owns_lock()) {
+      ++d_deferredLookups;
+      return false;
+    }
+
+    std::unordered_map<uint32_t,CacheValue>::const_iterator it = map->find(key);
+    if (it == map->end()) {
+      if (recordMiss) {
+        ++d_misses;
+      }
+      return false;
+    }
+
+    const CacheValue& value = it->second;
+    if (value.validity <= now) {
+      if ((now - value.validity) >= static_cast<time_t>(allowExpired)) {
+        if (recordMiss) {
+          ++d_misses;
+        }
+        return false;
+      }
+      else {
+        stale = true;
+      }
+    }
+
+    if (value.len < sizeof(dnsheader)) {
+      return false;
+    }
+
+    /* check for collision */
+    if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) {
+      ++d_lookupCollisions;
+      return false;
+    }
+
+    if (!truncatedOK) {
+      dnsheader dh;
+      memcpy(&dh, value.value.data(), sizeof(dh));
+      if (dh.tc != 0) {
+        return false;
+      }
+    }
+
+    /* validate the size of the cached value BEFORE touching the packet
+       buffer: it still holds the query, which has to remain usable
+       if we end up reporting a miss */
+    const size_t dnsQNameLen = dnsQName.length();
+    const bool headerOnly = value.len == sizeof(dnsheader);
+    if (!headerOnly && value.len < (sizeof(dnsheader) + dnsQNameLen)) {
+      return false;
+    }
+
+    /* from now on this is a hit, the query can be overwritten:
+       first the ID of the query, then the rest of the cached header */
+    response.resize(value.len);
+    memcpy(&response.at(0), &queryId, sizeof(queryId));
+    memcpy(&response.at(sizeof(queryId)), &value.value.at(sizeof(queryId)), sizeof(dnsheader) - sizeof(queryId));
+
+    if (headerOnly) {
+      /* DNS header only, our work here is done */
+      ++d_hits;
+      return true;
+    }
+
+    /* the qname comes from the query, not from the cached response */
+    memcpy(&response.at(sizeof(dnsheader)), dnsQName.c_str(), dnsQNameLen);
+    if (value.len > (sizeof(dnsheader) + dnsQNameLen)) {
+      memcpy(&response.at(sizeof(dnsheader) + dnsQNameLen), &value.value.at(sizeof(dnsheader) + dnsQNameLen), value.len - (sizeof(dnsheader) + dnsQNameLen));
+    }
+
+    if (!stale) {
+      age = now - value.added;
+    }
+    else {
+      age = (value.validity - value.added) - d_staleTTL;
+    }
+  }
+
+  /* the shard lock has been released, we now only work on our own copy */
+  if (!d_dontAge && !skipAging) {
+    if (!stale) {
+      // coverity[store_truncates_time_t]
+      dnsheader_aligned dh_aligned(response.data());
+      ageDNSPacket(reinterpret_cast<char *>(&response[0]), response.size(), age, dh_aligned);
+    }
+    else {
+      /* stale answers get the stale TTL on every record */
+      editDNSPacketTTL(reinterpret_cast<char*>(&response[0]), response.size(),
+                       [staleTTL = d_staleTTL](uint8_t /* section */, uint16_t /* class_ */, uint16_t /* type */, uint32_t /* ttl */) { return staleTTL; });
+    }
+  }
+
+  ++d_hits;
+  return true;
+}
+
+/* Drop entries whose validity is at or before `now`, one shard at a time.
+   upTo is divided evenly between the shards, and a shard is left alone as
+   soon as it holds no more than its share, or has no expired entry left.
+   Returns the number of entries removed. */
+size_t DNSDistPacketCache::purgeExpired(size_t upTo, const time_t now)
+{
+  const size_t perShardTarget = upTo / d_shardCount;
+  size_t purged = 0;
+
+  ++d_cleanupCount;
+
+  for (auto& shard : d_shards) {
+    auto lockedMap = shard.d_map.write_lock();
+
+    const size_t currentSize = lockedMap->size();
+    if (currentSize <= perShardTarget) {
+      continue;
+    }
+
+    size_t excess = currentSize - perShardTarget;
+    auto it = lockedMap->begin();
+
+    while (excess > 0 && it != lockedMap->end()) {
+      if (it->second.validity > now) {
+        /* still valid, keep it */
+        ++it;
+        continue;
+      }
+
+      it = lockedMap->erase(it);
+      --excess;
+      --shard.d_entriesCount;
+      ++purged;
+    }
+  }
+
+  return purged;
+}
+
+/* Remove entries, expired or not, until at most upTo entries
+   remain in the cache.
+   upTo is divided evenly between the shards, and every shard holding more
+   than its share is brought back to it, so that every shard has free space
+   remaining afterwards.
+   Returns the number of entries removed. */
+size_t DNSDistPacketCache::expunge(size_t upTo)
+{
+  const size_t maxPerShard = upTo / d_shardCount;
+
+  size_t removed = 0;
+
+  for (auto& shard : d_shards) {
+    auto map = shard.d_map.write_lock();
+
+    if (map->size() <= maxPerShard) {
+      continue;
+    }
+
+    /* we know that map->size() > maxPerShard, so toRemove is at least 1
+       and can never be larger than the number of entries in the map */
+    const size_t toRemove = map->size() - maxPerShard;
+
+    auto beginIt = map->begin();
+    auto endIt = beginIt;
+    std::advance(endIt, toRemove);
+
+    map->erase(beginIt, endIt);
+    shard.d_entriesCount -= toRemove;
+    removed += toRemove;
+  }
+
+  return removed;
+}
+
+/* Remove the entries stored for `name`, or for any name that is part of
+   `name` when suffixMatch is set.
+   Only entries of type qtype are removed, unless qtype is QType::ANY
+   in which case the type is ignored.
+   Returns the number of entries removed. */
+size_t DNSDistPacketCache::expungeByName(const DNSName& name, uint16_t qtype, bool suffixMatch)
+{
+  size_t expunged = 0;
+
+  for (auto& shard : d_shards) {
+    auto lockedMap = shard.d_map.write_lock();
+
+    auto it = lockedMap->begin();
+    while (it != lockedMap->end()) {
+      const CacheValue& cached = it->second;
+
+      const bool nameMatches = cached.qname == name || (suffixMatch && cached.qname.isPartOf(name));
+      if (nameMatches && (qtype == QType::ANY || qtype == cached.qtype)) {
+        it = lockedMap->erase(it);
+        --shard.d_entriesCount;
+        ++expunged;
+        continue;
+      }
+
+      ++it;
+    }
+  }
+
+  return expunged;
+}
+
+/* Tell whether the shards together hold at least d_maxEntries entries. */
+bool DNSDistPacketCache::isFull()
+{
+  const uint64_t currentSize = getSize();
+  return currentSize >= d_maxEntries;
+}
+
+/* Return the total number of entries, as the sum of the entry counters
+   of all shards. The shards themselves are not locked. */
+uint64_t DNSDistPacketCache::getSize()
+{
+  uint64_t total = 0;
+  for (auto& shard : d_shards) {
+    total += shard.d_entriesCount;
+  }
+  return total;
+}
+
+/* Thin wrapper around getDNSPacketMinTTL(): returns whatever it computes
+   for the `length` bytes at `packet`, passing seenNoDataSOA through. */
+uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length, bool* seenNoDataSOA)
+{
+  const uint32_t minTTL = getDNSPacketMinTTL(packet, length, seenNoDataSOA);
+  return minTTL;
+}
+
+/* Compute the cache key of a query. The hash covers, in that order:
+   - the DNS header minus the query ID,
+   - the qname, hashed with burtleCI(),
+   - whatever follows the header and the qname in the packet, leaving out the
+     EDNS options listed in d_optionsToSkip if there are any,
+   - the transport the query was received over.
+   Throws std::range_error if the packet is smaller than a DNS header,
+   or smaller than a DNS header followed by qnameWireLength bytes. */
+uint32_t DNSDistPacketCache::getKey(const DNSName::string_t& qname, size_t qnameWireLength, const PacketBuffer& packet, bool receivedOverUDP)
+{
+  if (packet.size() < sizeof(dnsheader)) {
+    throw std::range_error("Computing packet cache key for an invalid packet size (" + std::to_string(packet.size()) +")");
+  }
+
+  const size_t afterQName = sizeof(dnsheader) + qnameWireLength;
+
+  /* skip the query ID, which takes the first two bytes of the header */
+  uint32_t hash = burtle(&packet.at(2), sizeof(dnsheader) - 2, 0);
+  hash = burtleCI(reinterpret_cast<const unsigned char*>(qname.c_str()), qname.length(), hash);
+
+  if (packet.size() < afterQName) {
+    throw std::range_error("Computing packet cache key for an invalid packet (" + std::to_string(packet.size()) + " < " + std::to_string(sizeof(dnsheader) + qnameWireLength) + ")");
+  }
+
+  if (packet.size() > afterQName) {
+    if (d_optionsToSkip.empty()) {
+      hash = burtle(&packet.at(afterQName), packet.size() - afterQName, hash);
+    }
+    else {
+      /* skip EDNS options if any */
+      hash = PacketCache::hashAfterQname(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()), hash, afterQName, d_optionsToSkip);
+    }
+  }
+
+  return burtle(reinterpret_cast<const unsigned char*>(&receivedOverUDP), sizeof(receivedOverUDP), hash);
+}
+
+/* Map a cache key to the index of the shard it belongs to. */
+uint32_t DNSDistPacketCache::getShardIndex(uint32_t key) const
+{
+  const uint32_t index = key % d_shardCount;
+  return index;
+}
+
+/* Describe how full the cache is, as "<current entries>/<maximum entries>". */
+string DNSDistPacketCache::toString()
+{
+  std::string description = std::to_string(getSize());
+  description += "/";
+  description += std::to_string(d_maxEntries);
+  return description;
+}
+
+/* Return the number of entries currently in the cache,
+   which is exactly what getSize() reports. */
+uint64_t DNSDistPacketCache::getEntriesCount()
+{
+  const uint64_t entries = getSize();
+  return entries;
+}
+
+/* Write a human-readable listing of every cached entry to fd.
+   The descriptor is duplicated first, so the caller keeps ownership of fd
+   and it is still open when we return.
+   Returns the number of entries visited, or 0 if the descriptor could not
+   be duplicated or turned into a stream. */
+uint64_t DNSDistPacketCache::dump(int fd)
+{
+  const int dupFd = dup(fd);
+  if (dupFd < 0) {
+    return 0;
+  }
+
+  auto fp = std::unique_ptr<FILE, int(*)(FILE*)>(fdopen(dupFd, "w"), fclose);
+  if (fp == nullptr) {
+    /* no stream has taken ownership of the duplicated descriptor,
+       so we have to close it ourselves or it would be leaked */
+    close(dupFd);
+    return 0;
+  }
+
+  fprintf(fp.get(), "; dnsdist's packet cache dump follows\n;\n");
+
+  uint64_t count = 0;
+  time_t now = time(nullptr);
+  for (auto& shard : d_shards) {
+    auto map = shard.d_map.read_lock();
+
+    for (const auto& entry : *map) {
+      const CacheValue& value = entry.second;
+      count++;
+
+      try {
+        /* the rcode is read from the header of the stored response,
+           and reported as 0 if the stored response is too short */
+        uint8_t rcode = 0;
+        if (value.len >= sizeof(dnsheader)) {
+          dnsheader dh;
+          memcpy(&dh, value.value.data(), sizeof(dnsheader));
+          rcode = dh.rcode;
+        }
+
+        fprintf(fp.get(), "%s %" PRId64 " %s ; rcode %" PRIu8 ", key %" PRIu32 ", length %" PRIu16 ", received over UDP %d, added %" PRId64 "\n", value.qname.toString().c_str(), static_cast<int64_t>(value.validity - now), QType(value.qtype).toString().c_str(), rcode, entry.first, value.len, value.receivedOverUDP, static_cast<int64_t>(value.added));
+      }
+      catch(...) {
+        fprintf(fp.get(), "; error printing '%s'\n", value.qname.empty() ? "EMPTY" : value.qname.toString().c_str());
+      }
+    }
+  }
+
+  return count;
+}
+
+/* Replace the set of EDNS option codes that getKey() leaves out
+   of the hash when computing the cache key of a query. */
+void DNSDistPacketCache::setSkippedOptions(const std::unordered_set<uint16_t>& optionsToSkip)
+{
+  std::unordered_set<uint16_t> replacement(optionsToSkip);
+  d_optionsToSkip.swap(replacement);
+}
+
+/* Return the qnames of the cached NoError responses containing an IN A
+   or IN AAAA record whose address is equal to addr.
+   Entries that are too short, cannot be parsed or make the parsing throw
+   are silently skipped. */
+std::set<DNSName> DNSDistPacketCache::getDomainsContainingRecords(const ComboAddress& addr)
+{
+  std::set<DNSName> matches;
+
+  for (auto& shard : d_shards) {
+    auto lockedMap = shard.d_map.read_lock();
+
+    for (const auto& entry : *lockedMap) {
+      const CacheValue& cached = entry.second;
+
+      try {
+        if (cached.len < sizeof(dnsheader)) {
+          continue;
+        }
+
+        dnsheader header;
+        memcpy(&header, cached.value.data(), sizeof(dnsheader));
+
+        const bool noRecords = header.ancount == 0 && header.nscount == 0 && header.arcount == 0;
+        if (header.rcode != RCode::NoError || noRecords) {
+          continue;
+        }
+
+        bool found = false;
+        /* returns true as soon as a record holding the address has been seen */
+        const auto visitor = [&addr, &found](uint8_t /* section */, uint16_t qclass, uint16_t qtype, uint32_t /* ttl */, uint16_t rdatalength, const char* rdata) {
+          const bool usable = qclass == QClass::IN && rdata != nullptr;
+          if (!usable) {
+            return false;
+          }
+
+          ComboAddress parsed;
+          if (qtype == QType::A && addr.isIPv4() && rdatalength == 4) {
+            parsed.sin4.sin_family = AF_INET;
+            memcpy(&parsed.sin4.sin_addr.s_addr, rdata, rdatalength);
+          }
+          else if (qtype == QType::AAAA && addr.isIPv6() && rdatalength == 16) {
+            parsed.sin6.sin6_family = AF_INET6;
+            memcpy(&parsed.sin6.sin6_addr.s6_addr, rdata, rdatalength);
+          }
+          else {
+            /* not an address record of the family we are looking for */
+            return false;
+          }
+
+          if (parsed == addr) {
+            found = true;
+            return true;
+          }
+          return false;
+        };
+
+        const bool valid = visitDNSPacket(cached.value, visitor);
+        if (valid && found) {
+          matches.insert(cached.qname);
+        }
+      }
+      catch (...) {
+        continue;
+      }
+    }
+  }
+
+  return matches;
+}
+
+/* Return the addresses found in the IN A and IN AAAA records of the cached
+   NoError responses whose qname is equal to domain.
+   Entries that are too short or make the parsing throw are silently skipped. */
+std::set<ComboAddress> DNSDistPacketCache::getRecordsForDomain(const DNSName& domain)
+{
+  std::set<ComboAddress> found;
+
+  /* collects every address record, never asks for the visit to stop */
+  const auto collector = [&found](uint8_t /* section */, uint16_t qclass, uint16_t qtype, uint32_t /* ttl */, uint16_t rdatalength, const char* rdata) {
+    const bool usable = qclass == QClass::IN && rdata != nullptr;
+    if (!usable) {
+      return false;
+    }
+
+    if (qtype == QType::A && rdatalength == 4) {
+      ComboAddress parsed;
+      parsed.sin4.sin_family = AF_INET;
+      memcpy(&parsed.sin4.sin_addr.s_addr, rdata, rdatalength);
+      found.insert(parsed);
+    }
+    else if (qtype == QType::AAAA && rdatalength == 16) {
+      ComboAddress parsed;
+      parsed.sin6.sin6_family = AF_INET6;
+      memcpy(&parsed.sin6.sin6_addr.s6_addr, rdata, rdatalength);
+      found.insert(parsed);
+    }
+
+    return false;
+  };
+
+  for (auto& shard : d_shards) {
+    auto lockedMap = shard.d_map.read_lock();
+
+    for (const auto& entry : *lockedMap) {
+      const CacheValue& cached = entry.second;
+
+      try {
+        if (cached.qname != domain) {
+          continue;
+        }
+
+        if (cached.len < sizeof(dnsheader)) {
+          continue;
+        }
+
+        dnsheader header;
+        memcpy(&header, cached.value.data(), sizeof(dnsheader));
+
+        const bool noRecords = header.ancount == 0 && header.nscount == 0 && header.arcount == 0;
+        if (header.rcode != RCode::NoError || noRecords) {
+          continue;
+        }
+
+        visitDNSPacket(cached.value, collector);
+      }
+      catch (...) {
+        continue;
+      }
+    }
+  }
+
+  return found;
+}