diff options
Diffstat (limited to '')
-rw-r--r-- | dnsdist-lbpolicies.cc | 357 |
1 file changed, 357 insertions, 0 deletions
diff --git a/dnsdist-lbpolicies.cc b/dnsdist-lbpolicies.cc new file mode 100644 index 0000000..d747ca7 --- /dev/null +++ b/dnsdist-lbpolicies.cc @@ -0,0 +1,357 @@ +/* + * This file is part of PowerDNS or dnsdist. + * Copyright -- PowerDNS.COM B.V. and its contributors + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of version 2 of the GNU General Public License as + * published by the Free Software Foundation. + * + * In addition, for the avoidance of any doubt, permission is granted to + * link this program with OpenSSL and to (re)distribute the binaries + * produced as the result of such linking. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include "dnsdist.hh" +#include "dnsdist-lbpolicies.hh" +#include "dnsdist-lua.hh" +#include "dnsdist-lua-ffi.hh" +#include "dolog.hh" + +GlobalStateHolder<ServerPolicy> g_policy; +bool g_roundrobinFailOnNoServer{false}; + +// get server with least outstanding queries, and within those, with the lowest order, and within those: the fastest +shared_ptr<DownstreamState> leastOutstanding(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +{ + if (servers.size() == 1 && servers[0].second->isUp()) { + return servers[0].second; + } + + vector<pair<tuple<int,int,double>, size_t>> poss; + /* so you might wonder, why do we go through this trouble? The data on which we sort could change during the sort, + which would suck royally and could even lead to crashes. 
So first we snapshot on what we sort, and then we sort */ + poss.reserve(servers.size()); + size_t position = 0; + for(const auto& d : servers) { + if(d.second->isUp()) { + poss.emplace_back(make_tuple(d.second->outstanding.load(), d.second->order, d.second->latencyUsec), position); + } + ++position; + } + + if (poss.empty()) { + return shared_ptr<DownstreamState>(); + } + + nth_element(poss.begin(), poss.begin(), poss.end(), [](const decltype(poss)::value_type& a, const decltype(poss)::value_type& b) { return a.first < b.first; }); + return servers.at(poss.begin()->second).second; +} + +shared_ptr<DownstreamState> firstAvailable(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +{ + for(auto& d : servers) { + if (d.second->isUp() && d.second->qps.checkOnly()) { + return d.second; + } + } + return leastOutstanding(servers, dq); +} + +double g_weightedBalancingFactor = 0; + +static shared_ptr<DownstreamState> valrandom(unsigned int val, const ServerPolicy::NumberedServerVector& servers) +{ + vector<pair<int, size_t>> poss; + poss.reserve(servers.size()); + int sum = 0; + int max = std::numeric_limits<int>::max(); + double targetLoad = std::numeric_limits<double>::max(); + + if (g_weightedBalancingFactor > 0) { + /* we start with one, representing the query we are currently handling */ + double currentLoad = 1; + size_t totalWeight = 0; + for (const auto& pair : servers) { + if (pair.second->isUp()) { + currentLoad += pair.second->outstanding; + totalWeight += pair.second->weight; + } + } + + if (totalWeight > 0) { + targetLoad = (currentLoad / totalWeight) * g_weightedBalancingFactor; + } + } + + for (const auto& d : servers) { // w=1, w=10 -> 1, 11 + if (d.second->isUp() && (g_weightedBalancingFactor == 0 || (d.second->outstanding <= (targetLoad * d.second->weight)))) { + // Don't overflow sum when adding high weights + if (d.second->weight > max - sum) { + sum = max; + } else { + sum += d.second->weight; + } + + poss.emplace_back(sum, 
d.first); + } + } + + // Catch poss & sum are empty to avoid SIGFPE + if (poss.empty() || sum == 0) { + return shared_ptr<DownstreamState>(); + } + + int r = val % sum; + auto p = upper_bound(poss.begin(), poss.end(),r, [](int r_, const decltype(poss)::value_type& a) { return r_ < a.first;}); + if (p == poss.end()) { + return shared_ptr<DownstreamState>(); + } + + return servers.at(p->second - 1).second; +} + +shared_ptr<DownstreamState> wrandom(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +{ + return valrandom(random(), servers); +} + +uint32_t g_hashperturb; +double g_consistentHashBalancingFactor = 0; + +shared_ptr<DownstreamState> whashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t hash) +{ + return valrandom(hash, servers); +} + +shared_ptr<DownstreamState> whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +{ + return whashedFromHash(servers, dq->qname->hash(g_hashperturb)); +} + +shared_ptr<DownstreamState> chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash) +{ + unsigned int sel = std::numeric_limits<unsigned int>::max(); + unsigned int min = std::numeric_limits<unsigned int>::max(); + shared_ptr<DownstreamState> ret = nullptr, first = nullptr; + + double targetLoad = std::numeric_limits<double>::max(); + if (g_consistentHashBalancingFactor > 0) { + /* we start with one, representing the query we are currently handling */ + double currentLoad = 1; + size_t totalWeight = 0; + for (const auto& pair : servers) { + if (pair.second->isUp()) { + currentLoad += pair.second->outstanding; + totalWeight += pair.second->weight; + } + } + + if (totalWeight > 0) { + targetLoad = (currentLoad / totalWeight) * g_consistentHashBalancingFactor; + } + } + + for (const auto& d: servers) { + if (d.second->isUp() && (g_consistentHashBalancingFactor == 0 || d.second->outstanding <= (targetLoad * d.second->weight))) { + // make sure hashes have been computed + if 
(!d.second->hashesComputed) { + d.second->hash(); + } + { + const auto& server = d.second; + auto hashes = server->hashes.read_lock(); + // we want to keep track of the last hash + if (min > *(hashes->begin())) { + min = *(hashes->begin()); + first = server; + } + + auto hash_it = std::lower_bound(hashes->begin(), hashes->end(), qhash); + if (hash_it != hashes->end()) { + if (*hash_it < sel) { + sel = *hash_it; + ret = server; + } + } + } + } + } + if (ret != nullptr) { + return ret; + } + if (first != nullptr) { + return first; + } + return shared_ptr<DownstreamState>(); +} + +shared_ptr<DownstreamState> chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +{ + return chashedFromHash(servers, dq->qname->hash(g_hashperturb)); +} + +shared_ptr<DownstreamState> roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) +{ + if (servers.empty()) { + return shared_ptr<DownstreamState>(); + } + + vector<size_t> candidates; + candidates.reserve(servers.size()); + + for (auto& d : servers) { + if (d.second->isUp()) { + candidates.push_back(d.first); + } + } + + if (candidates.empty()) { + if (g_roundrobinFailOnNoServer) { + return shared_ptr<DownstreamState>(); + } + for (auto& d : servers) { + candidates.push_back(d.first); + } + } + + static unsigned int counter; + return servers.at(candidates.at((counter++) % candidates.size()) - 1).second; +} + +const std::shared_ptr<ServerPolicy::NumberedServerVector> getDownstreamCandidates(const pools_t& pools, const std::string& poolName) +{ + std::shared_ptr<ServerPool> pool = getPool(pools, poolName); + return pool->getServers(); +} + +std::shared_ptr<ServerPool> createPoolIfNotExists(pools_t& pools, const string& poolName) +{ + std::shared_ptr<ServerPool> pool; + pools_t::iterator it = pools.find(poolName); + if (it != pools.end()) { + pool = it->second; + } + else { + if (!poolName.empty()) + vinfolog("Creating pool %s", poolName); + pool = std::make_shared<ServerPool>(); 
+ pools.insert(std::pair<std::string,std::shared_ptr<ServerPool> >(poolName, pool)); + } + return pool; +} + +void setPoolPolicy(pools_t& pools, const string& poolName, std::shared_ptr<ServerPolicy> policy) +{ + std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName); + if (!poolName.empty()) { + vinfolog("Setting pool %s server selection policy to %s", poolName, policy->getName()); + } else { + vinfolog("Setting default pool server selection policy to %s", policy->getName()); + } + pool->policy = policy; +} + +void addServerToPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server) +{ + std::shared_ptr<ServerPool> pool = createPoolIfNotExists(pools, poolName); + if (!poolName.empty()) { + vinfolog("Adding server to pool %s", poolName); + } else { + vinfolog("Adding server to default pool"); + } + pool->addServer(server); +} + +void removeServerFromPool(pools_t& pools, const string& poolName, std::shared_ptr<DownstreamState> server) +{ + std::shared_ptr<ServerPool> pool = getPool(pools, poolName); + + if (!poolName.empty()) { + vinfolog("Removing server from pool %s", poolName); + } + else { + vinfolog("Removing server from default pool"); + } + + pool->removeServer(server); +} + +std::shared_ptr<ServerPool> getPool(const pools_t& pools, const std::string& poolName) +{ + pools_t::const_iterator it = pools.find(poolName); + + if (it == pools.end()) { + throw std::out_of_range("No pool named " + poolName); + } + + return it->second; +} + +ServerPolicy::ServerPolicy(const std::string& name_, const std::string& code): d_name(name_), d_perThreadPolicyCode(code), d_isLua(true), d_isFFI(true), d_isPerThread(true) +{ + LuaContext tmpContext; + setupLuaLoadBalancingContext(tmpContext); + auto ret = tmpContext.executeCode<ServerPolicy::ffipolicyfunc_t>(code); +} + +thread_local ServerPolicy::PerThreadState ServerPolicy::t_perThreadState; + +const ServerPolicy::ffipolicyfunc_t& ServerPolicy::getPerThreadPolicy() const +{ + auto& 
state = t_perThreadState; + if (!state.d_initialized) { + setupLuaLoadBalancingContext(state.d_luaContext); + state.d_initialized = true; + } + + const auto& it = state.d_policies.find(d_name); + if (it != state.d_policies.end()) { + return it->second; + } + + auto newPolicy = state.d_luaContext.executeCode<ServerPolicy::ffipolicyfunc_t>(d_perThreadPolicyCode); + state.d_policies[d_name] = std::move(newPolicy); + return state.d_policies.at(d_name); +} + +std::shared_ptr<DownstreamState> ServerPolicy::getSelectedBackend(const ServerPolicy::NumberedServerVector& servers, DNSQuestion& dq) const +{ + std::shared_ptr<DownstreamState> selectedBackend{nullptr}; + + if (d_isLua) { + if (!d_isFFI) { + auto lock = g_lua.lock(); + selectedBackend = d_policy(servers, &dq); + } + else { + dnsdist_ffi_dnsquestion_t dnsq(&dq); + dnsdist_ffi_servers_list_t serversList(servers); + unsigned int selected = 0; + + if (!d_isPerThread) { + auto lock = g_lua.lock(); + selected = d_ffipolicy(&serversList, &dnsq); + } + else { + const auto& policy = getPerThreadPolicy(); + selected = policy(&serversList, &dnsq); + } + + selectedBackend = servers.at(selected).second; + } + } + else { + selectedBackend = d_policy(servers, &dq); + } + + return selectedBackend; +} |