From 830407e88f9d40d954356c3754f2647f91d5c06a Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 7 Apr 2024 17:26:00 +0200 Subject: Adding upstream version 5.6.0. Signed-off-by: Daniel Baumann --- lib/selection.c | 795 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 795 insertions(+) create mode 100644 lib/selection.c (limited to 'lib/selection.c') diff --git a/lib/selection.c b/lib/selection.c new file mode 100644 index 0000000..5aa2992 --- /dev/null +++ b/lib/selection.c @@ -0,0 +1,795 @@ +/* Copyright (C) CZ.NIC, z.s.p.o. + * SPDX-License-Identifier: GPL-3.0-or-later + */ + +#include + +#include "lib/selection.h" +#include "lib/selection_forward.h" +#include "lib/selection_iter.h" +#include "lib/rplan.h" +#include "lib/cache/api.h" +#include "lib/resolve.h" + +#include "lib/utils.h" + +#define VERBOSE_MSG(qry, ...) kr_log_q((qry), SELECTION, __VA_ARGS__) + +#define DEFAULT_TIMEOUT 400 +#define MAX_TIMEOUT 10000 +#define EXPLORE_TIMEOUT_COEFFICIENT 2 +#define MAX_BACKOFF 8 +#define MINIMAL_TIMEOUT_ADDITION 20 + +/* After TCP_TIMEOUT_THRESHOLD timeouts one transport, we'll switch to TCP. */ +#define TCP_TIMEOUT_THRESHOLD 2 +/* If the expected RTT is over TCP_RTT_THRESHOLD we switch to TCP instead. */ +#define TCP_RTT_THRESHOLD 2000 + +/* Define ε for ε-greedy algorithm (see select_transport) + * as ε=EPSILON_NOMIN/EPSILON_DENOM */ +#define EPSILON_NOMIN 1 +#define EPSILON_DENOM 20 + +static const char *kr_selection_error_str(enum kr_selection_error err) { + switch (err) { + #define X(ENAME) case KR_SELECTION_ ## ENAME: return #ENAME + X(OK); + X(QUERY_TIMEOUT); + X(TLS_HANDSHAKE_FAILED); + X(TCP_CONNECT_FAILED); + X(TCP_CONNECT_TIMEOUT); + X(REFUSED); + X(SERVFAIL); + X(FORMERR); + X(FORMERR_EDNS); + X(NOTIMPL); + X(OTHER_RCODE); + X(MALFORMED); + X(MISMATCHED); + X(TRUNCATED); + X(DNSSEC_ERROR); + X(LAME_DELEGATION); + X(BAD_CNAME); + case KR_SELECTION_NUMBER_OF_ERRORS: break; // not a valid code + #undef X + } + kr_assert(false); // we want to define all; compiler helps by -Wswitch (no default:) + return NULL; +} + + +/* Simple detection of IPv6 being broken. + * + * We follow all IPv6 timeouts and successes. Consider it broken iff we've had + * timeouts on several different IPv6 prefixes since the last IPv6 success. + * Note: unlike the rtt_state, this happens only per-process (for simplicity). + * + * ## NO6_PREFIX_* choice + * For our practical use we choose primarily based on root and typical TLD servers. + * Looking at *.{root,gtld}-servers.net, we have 7/26 AAAAs in 2001:500:00**:: + * but adding one more byte makes these completely unique, so we choose /48. + * As distribution to ASs seems to be on shorter prefixes (RIPE: /32 -- /24?), + * we wait for several distinct prefixes. + */ + +#define NO6_PREFIX_COUNT 6 +#define NO6_PREFIX_BYTES (48/8) +static struct { + int len_used; + uint8_t addr_prefixes[NO6_PREFIX_COUNT][NO6_PREFIX_BYTES]; +} no6_est = { .len_used = 0 }; + +bool no6_is_bad(void) +{ + return no6_est.len_used == NO6_PREFIX_COUNT; +} + +static void no6_timed_out(const struct kr_query *qry, const uint8_t *addr) +{ + if (no6_is_bad()) { // we can't get worse + VERBOSE_MSG(qry, "NO6: timed out, but bad already\n"); + return; + } + // If we have the address already, do nothing. + for (int i = 0; i < no6_est.len_used; ++i) { + if (memcmp(addr, no6_est.addr_prefixes[i], NO6_PREFIX_BYTES) == 0) { + VERBOSE_MSG(qry, "NO6: timed out, repeated prefix, timeouts %d/%d\n", + no6_est.len_used, (int)NO6_PREFIX_COUNT); + return; + } + } + // Append! + memcpy(no6_est.addr_prefixes[no6_est.len_used++], addr, NO6_PREFIX_BYTES); + VERBOSE_MSG(qry, "NO6: timed out, appended, timeouts %d/%d\n", + no6_est.len_used, (int)NO6_PREFIX_COUNT); +} + +static inline void no6_success(const struct kr_query *qry) +{ + if (no6_est.len_used) { + VERBOSE_MSG(qry, "NO6: success, zeroing %d/%d\n", + no6_est.len_used, (int)NO6_PREFIX_COUNT); + } + no6_est.len_used = 0; +} + + +/* Simple cache interface follows */ + +static knot_db_val_t cache_key(const uint8_t *ip, size_t len) +{ + // CACHE_KEY_DEF: '\0' + 'S' + raw IP + const size_t key_len = len + 2; + uint8_t *key_data = malloc(key_len); + key_data[0] = '\0'; + key_data[1] = 'S'; + memcpy(key_data + 2, ip, len); + knot_db_val_t key = { + .len = key_len, + .data = key_data, + }; + return key; +} + +/* First value of timeout will be calculated as SRTT+4*VARIANCE + * by calc_timeout(), so it'll be equal to DEFAULT_TIMEOUT. */ +static const struct rtt_state default_rtt_state = { .srtt = 0, + .variance = DEFAULT_TIMEOUT / 4, + .consecutive_timeouts = 0, + .dead_since = 0 }; + +struct rtt_state get_rtt_state(const uint8_t *ip, size_t len, + struct kr_cache *cache) +{ + struct rtt_state state; + knot_db_val_t value; + knot_db_t *db = cache->db; + struct kr_cdb_stats *stats = &cache->stats; + + knot_db_val_t key = cache_key(ip, len); + + if (cache->api->read(db, stats, &key, &value, 1)) { + state = default_rtt_state; + } else if (kr_fails_assert(value.len == sizeof(struct rtt_state))) { + // shouldn't happen but let's be more robust + state = default_rtt_state; + } else { // memcpy is safe for unaligned case (on non-x86) + memcpy(&state, value.data, sizeof(state)); + } + + free(key.data); + return state; +} + +int put_rtt_state(const uint8_t *ip, size_t len, struct rtt_state state, + struct kr_cache *cache) +{ + knot_db_t *db = cache->db; + struct kr_cdb_stats *stats = &cache->stats; + + knot_db_val_t key = cache_key(ip, len); + knot_db_val_t value = { .len = sizeof(struct rtt_state), + .data = &state }; + + int ret = cache->api->write(db, stats, &key, &value, 1); + cache->api->commit(db, stats); + + free(key.data); + return ret; +} + +void bytes_to_ip(uint8_t *bytes, size_t len, uint16_t port, union kr_sockaddr *dst) +{ + switch (len) { + case sizeof(struct in_addr): + dst->ip4.sin_family = AF_INET; + memcpy(&dst->ip4.sin_addr, bytes, len); + dst->ip4.sin_port = htons(port); + break; + case sizeof(struct in6_addr): + memset(&dst->ip6, 0, sizeof(dst->ip6)); // avoid uninit surprises + dst->ip6.sin6_family = AF_INET6; + memcpy(&dst->ip6.sin6_addr, bytes, len); + dst->ip6.sin6_port = htons(port); + break; + default: + kr_assert(false); + } +} + +uint8_t *ip_to_bytes(const union kr_sockaddr *src, size_t len) +{ + switch (len) { + case sizeof(struct in_addr): + return (uint8_t *)&src->ip4.sin_addr; + case sizeof(struct in6_addr): + return (uint8_t *)&src->ip6.sin6_addr; + default: + kr_assert(false); + return NULL; + } +} + +static bool no_rtt_info(struct rtt_state s) +{ + return s.srtt == 0 && s.consecutive_timeouts == 0; +} + +static unsigned back_off_timeout(uint32_t to, int pow) +{ + pow = MIN(pow, MAX_BACKOFF); + to <<= pow; + return MIN(to, MAX_TIMEOUT); +} + +/* This is verbatim (minus the default timeout value and minimal variance) + * RFC6298, sec. 2. */ +static unsigned calc_timeout(struct rtt_state state) +{ + int32_t timeout = state.srtt + MAX(4 * state.variance, MINIMAL_TIMEOUT_ADDITION); + return back_off_timeout(timeout, state.consecutive_timeouts); +} + +/* This is verbatim RFC6298, sec. 2. */ +static struct rtt_state calc_rtt_state(struct rtt_state old, unsigned new_rtt) +{ + if (no_rtt_info(old)) { + return (struct rtt_state){ new_rtt, new_rtt / 2, 0 }; + } + + struct rtt_state ret = { 0 }; + ret.variance = (3 * old.variance + abs(old.srtt - (int32_t)new_rtt) + + 2/*rounding*/) / 4; + ret.srtt = (7 * old.srtt + new_rtt + 4/*rounding*/) / 8; + + return ret; +} + +/** + * @internal Invalidate addresses which should be considered dead + */ +static void invalidate_dead_upstream(struct address_state *state, + unsigned int retry_timeout) +{ + struct rtt_state *rs = &state->rtt_state; + if (rs->dead_since) { + uint64_t now = kr_now(); + if (now < rs->dead_since) { + // broken continuity of timestamp (reboot, different machine, etc.) + *rs = default_rtt_state; + } else if (now < rs->dead_since + retry_timeout) { + // period when we don't want to use the address + state->generation = -1; + } else { + kr_assert(now >= rs->dead_since + retry_timeout); + // we allow to retry the server now + // TODO: perhaps tweak *rs? + } + } +} + +/** + * @internal Check if IP address is TLS capable. + * + * @p req has to have the selection_context properly initialized. + */ +static void check_tls_capable(struct address_state *address_state, + struct kr_request *req, struct sockaddr *address) +{ + address_state->tls_capable = + req->selection_context.is_tls_capable ? + req->selection_context.is_tls_capable(address) : + false; +} + +#if 0 +/* TODO: uncomment these once we actually use the information it collects. */ +/** + * Check if there is a existing TCP connection to this address. + * + * @p req has to have the selection_context properly initialized. + */ +void check_tcp_connections(struct address_state *address_state, struct kr_request *req, struct sockaddr *address) { + address_state->tcp_connected = req->selection_context.is_tcp_connected ? req->selection_context.is_tcp_connected(address) : false; + address_state->tcp_waiting = req->selection_context.is_tcp_waiting ? req->selection_context.is_tcp_waiting(address) : false; +} +#endif + +/** + * @internal Invalidate address if the respective IP version is disabled. + */ +static void check_network_settings(struct address_state *address_state, + size_t address_len, bool no_ipv4, bool no_ipv6) +{ + if (no_ipv4 && address_len == sizeof(struct in_addr)) { + address_state->generation = -1; + } + if (no_ipv6 && address_len == sizeof(struct in6_addr)) { + address_state->generation = -1; + } +} + +void update_address_state(struct address_state *state, union kr_sockaddr *address, + size_t address_len, struct kr_query *qry) +{ + check_tls_capable(state, qry->request, &address->ip); + /* TODO: uncomment this once we actually use the information it collects + check_tcp_connections(address_state, qry->request, &address->ip); + */ + check_network_settings(state, address_len, qry->flags.NO_IPV4, + qry->flags.NO_IPV6); + state->rtt_state = + get_rtt_state(ip_to_bytes(address, address_len), + address_len, &qry->request->ctx->cache); + invalidate_dead_upstream( + state, qry->request->ctx->cache_rtt_tout_retry_interval); +#ifdef SELECTION_CHOICE_LOGGING + // This is sometimes useful for debugging, but usually too verbose + if (kr_log_is_debug_qry(SELECTION, qry)) { + const char *ns_str = kr_straddr(&address->ip); + VERBOSE_MSG(qry, "rtt of %s is %d, variance is %d\n", ns_str, + state->rtt_state.srtt, state->rtt_state.variance); + } +#endif +} + +static int cmp_choices(const struct choice *a_, const struct choice *b_) +{ + int diff; + /* Prefer IPv4 if IPv6 appears to be generally broken. */ + diff = (int)a_->address_len - (int)b_->address_len; + if (diff && no6_is_bad()) { + return diff; + } + /* Address with no RTT information is better than address + * with some information. */ + if ((diff = no_rtt_info(b_->address_state->rtt_state) - + no_rtt_info(a_->address_state->rtt_state))) { + return diff; + } + /* Address with less errors is better. */ + if ((diff = a_->address_state->error_count - + b_->address_state->error_count)) { + return diff; + } + /* Address with smaller expected timeout is better. */ + if ((diff = calc_timeout(a_->address_state->rtt_state) - + calc_timeout(b_->address_state->rtt_state))) { + return diff; + } + return 0; +} +/** Select the best entry from choices[] according to cmp_choices() comparator. + * + * Ties are decided in an (almost) uniformly random fashion. + */ +static const struct choice * select_best(const struct choice choices[], int choices_len) +{ + /* Deciding ties: it's as-if each index carries one byte of randomness. + * Ties get decided by comparing that byte, and the byte itself + * is computed lazily (negative until computed). + */ + int best_i = 0; + int best_rnd = -1; + for (int i = 1; i < choices_len; ++i) { + int diff = cmp_choices(&choices[i], &choices[best_i]); + if (diff > 0) + continue; + if (diff < 0) { + best_i = i; + best_rnd = -1; + continue; + } + if (best_rnd < 0) + best_rnd = kr_rand_bytes(1); + int new_rnd = kr_rand_bytes(1); + if (new_rnd < best_rnd) { + best_i = i; + best_rnd = new_rnd; + } + } + return &choices[best_i]; +} + +/* Adjust choice from `unresolved` in case of NO6 (broken IPv6). */ +static struct kr_transport unresolved_adjust(const struct to_resolve unresolved[], + int unresolved_len, int index) +{ + if (unresolved[index].type != KR_TRANSPORT_RESOLVE_AAAA || !no6_is_bad()) + goto finish; + /* AAAA is detected as bad; let's choose randomly from others, if there are any. */ + int aaaa_count = 0; + for (int i = 0; i < unresolved_len; ++i) + aaaa_count += (unresolved[i].type == KR_TRANSPORT_RESOLVE_AAAA); + if (aaaa_count == unresolved_len) + goto finish; + /* Chosen index within non-AAAA items. */ + int i_no6 = kr_rand_bytes(1) % (unresolved_len - aaaa_count); + for (int i = 0; i < unresolved_len; ++i) { + if (unresolved[i].type == KR_TRANSPORT_RESOLVE_AAAA) { + //continue + } else if (i_no6 == 0) { + index = i; + break; + } else { + --i_no6; + } + } +finish: + return (struct kr_transport){ + .protocol = unresolved[index].type, + .ns_name = unresolved[index].name + }; +} + +/* Performs the actual selection (currently variation on epsilon-greedy). */ +struct kr_transport *select_transport(const struct choice choices[], int choices_len, + const struct to_resolve unresolved[], + int unresolved_len, int timeouts, + struct knot_mm *mempool, bool tcp, + size_t *choice_index) +{ + if (!choices_len && !unresolved_len) { + /* There is nothing to choose from */ + return NULL; + } + + struct kr_transport *transport = mm_calloc(mempool, 1, sizeof(*transport)); + + /* If there are some addresses with no rtt_info we try them + * first (see cmp_choices). So unknown servers are chosen + * *before* the best know server. This ensures that every option + * is tried before going back to some that was tried before. */ + const struct choice *best = select_best(choices, choices_len); + const struct choice *chosen; + + const bool explore = choices_len == 0 || kr_rand_coin(EPSILON_NOMIN, EPSILON_DENOM) + /* We may need to explore to get at least one A record. */ + || (no6_is_bad() && best->address.ip.sa_family == AF_INET6); + if (explore) { + /* "EXPLORE": + * randomly choose some option + * (including resolution of some new name). */ + int index = kr_rand_bytes(1) % (choices_len + unresolved_len); + if (index < unresolved_len) { + // We will resolve a new NS name + *transport = unresolved_adjust(unresolved, unresolved_len, index); + return transport; + } else { + chosen = &choices[index - unresolved_len]; + } + } else { + /* "EXPLOIT": + * choose a resolved address which seems best right now. */ + chosen = best; + } + + /* Don't try the same server again when there are other choices to be explored */ + if (chosen->address_state->error_count && unresolved_len) { + int index = kr_rand_bytes(1) % unresolved_len; + *transport = unresolved_adjust(unresolved, unresolved_len, index); + return transport; + } + + unsigned timeout; + if (no_rtt_info(chosen->address_state->rtt_state)) { + /* Exponential back-off when retrying after timeout and choosing + * an unknown server. */ + timeout = back_off_timeout(DEFAULT_TIMEOUT, timeouts); + } else { + timeout = calc_timeout(chosen->address_state->rtt_state); + if (explore) { + /* When trying a random server, we cap the timeout to EXPLORE_TIMEOUT_COEFFICIENT + * times the timeout for the best server. This is done so we don't spend + * unreasonable amounts of time probing really bad servers while still + * checking once in a while for e.g. big network change etc. + * We also note this capping was done and don't punish the bad server + * further if it fails to answer in the capped timeout. */ + unsigned best_timeout = calc_timeout(best->address_state->rtt_state); + if (timeout > best_timeout * EXPLORE_TIMEOUT_COEFFICIENT) { + timeout = best_timeout * EXPLORE_TIMEOUT_COEFFICIENT; + transport->timeout_capped = true; + } + } + } + + enum kr_transport_protocol protocol; + if (chosen->address_state->tls_capable) { + protocol = KR_TRANSPORT_TLS; + } else if (tcp || + chosen->address_state->errors[KR_SELECTION_QUERY_TIMEOUT] >= TCP_TIMEOUT_THRESHOLD || + timeout > TCP_RTT_THRESHOLD) { + protocol = KR_TRANSPORT_TCP; + } else { + protocol = KR_TRANSPORT_UDP; + } + + *transport = (struct kr_transport){ + .ns_name = chosen->address_state->ns_name, + .protocol = protocol, + .timeout = timeout, + }; + + int port = chosen->port; + if (!port) { + switch (transport->protocol) { + case KR_TRANSPORT_TLS: + port = KR_DNS_TLS_PORT; + break; + case KR_TRANSPORT_UDP: + case KR_TRANSPORT_TCP: + port = KR_DNS_PORT; + break; + default: + kr_assert(false); + return NULL; + } + } + + switch (chosen->address_len) + { + case sizeof(struct in_addr): + transport->address.ip4 = chosen->address.ip4; + transport->address.ip4.sin_port = htons(port); + break; + case sizeof(struct in6_addr): + transport->address.ip6 = chosen->address.ip6; + transport->address.ip6.sin6_port = htons(port); + break; + default: + kr_assert(false); + return NULL; + } + + transport->address_len = chosen->address_len; + + if (choice_index) { + *choice_index = chosen->address_state->choice_array_index; + } + + return transport; +} + +void update_rtt(struct kr_query *qry, struct address_state *addr_state, + const struct kr_transport *transport, unsigned rtt) +{ + if (!transport || !addr_state) { + /* Answers from cache have NULL transport, ignore them. */ + return; + } + + struct kr_cache *cache = &qry->request->ctx->cache; + + uint8_t *address = ip_to_bytes(&transport->address, transport->address_len); + /* This construct is a bit racy since the global state may change + * between calls to `get_rtt_state` and `put_rtt_state` but we don't + * care that much since it is rare and we only risk slightly suboptimal + * transport choice. */ + struct rtt_state cur_rtt_state = + get_rtt_state(address, transport->address_len, cache); + struct rtt_state new_rtt_state = calc_rtt_state(cur_rtt_state, rtt); + put_rtt_state(address, transport->address_len, new_rtt_state, cache); + + if (transport->address_len == sizeof(struct in6_addr)) + no6_success(qry); + + if (kr_log_is_debug_qry(SELECTION, qry)) { + KR_DNAME_GET_STR(ns_name, transport->ns_name); + KR_DNAME_GET_STR(zonecut_str, qry->zone_cut.name); + const char *ns_str = kr_straddr(&transport->address.ip); + + VERBOSE_MSG( + qry, + "=> id: '%05u' updating: '%s'@'%s' zone cut: '%s'" + " with rtt %u to srtt: %d and variance: %d \n", + qry->id, ns_name, ns_str ? ns_str : "", zonecut_str, + rtt, new_rtt_state.srtt, new_rtt_state.variance); + } +} + +/// Update rtt_state (including caching) after a server timed out. +static void server_timeout(const struct kr_query *qry, const struct kr_transport *transport, + struct address_state *addr_state, struct kr_cache *cache) +{ + // Make sure that the timeout wasn't capped; see kr_transport::timeout_capped + if (transport->timeout_capped) + return; + + const uint8_t *address = ip_to_bytes(&transport->address, transport->address_len); + if (transport->address_len == sizeof(struct in6_addr)) + no6_timed_out(qry, address); + + struct rtt_state *state = &addr_state->rtt_state; + // While we were waiting for timeout, the stats might have changed considerably, + // so let's overwrite what we had by fresh cache contents. + // This is useful when the address is busy (we query it concurrently). + *state = get_rtt_state(address, transport->address_len, cache); + + ++state->consecutive_timeouts; + // Avoid overflow; we don't utilize very high values anyway (arbitrary limit). + state->consecutive_timeouts = MIN(64, state->consecutive_timeouts); + if (state->consecutive_timeouts >= KR_NS_TIMEOUT_ROW_DEAD) { + // Only mark as dead if we waited long enough, + // so that many (concurrent) short attempts can't cause the dead state. + if (transport->timeout >= KR_NS_TIMEOUT_MIN_DEAD_TIMEOUT) + state->dead_since = kr_now(); + } + + // If transport was chosen by a different query, that one will cache it. + if (!transport->deduplicated) { + put_rtt_state(address, transport->address_len, *state, cache); + } else { + kr_cache_commit(cache); // Avoid any risk of long transaction. + } +} +// Not everything can be checked in nice ways like static_assert() +static __attribute__((constructor)) void test_RTT_consts(void) +{ + // See KR_NS_TIMEOUT_MIN_DEAD_TIMEOUT above. + kr_require( + calc_timeout((struct rtt_state){ .consecutive_timeouts = MAX_BACKOFF, }) + >= KR_NS_TIMEOUT_MIN_DEAD_TIMEOUT + ); +} + +void error(struct kr_query *qry, struct address_state *addr_state, + const struct kr_transport *transport, + enum kr_selection_error sel_error) +{ + if (!transport || !addr_state) { + /* Answers from cache have NULL transport, ignore them. */ + return; + } + + switch (sel_error) { + case KR_SELECTION_OK: + return; + case KR_SELECTION_TCP_CONNECT_FAILED: + case KR_SELECTION_TCP_CONNECT_TIMEOUT: + qry->server_selection.local_state->force_udp = true; + qry->flags.NO_0X20 = false; + /* Connection and handshake failures have properties similar + * to UDP timeouts, so we handle them (almost) the same way. */ + /* fall-through */ + case KR_SELECTION_TLS_HANDSHAKE_FAILED: + case KR_SELECTION_QUERY_TIMEOUT: + qry->server_selection.local_state->timeouts++; + server_timeout(qry, transport, addr_state, &qry->request->ctx->cache); + break; + case KR_SELECTION_FORMERR: + if (qry->flags.NO_EDNS) { + addr_state->broken = true; + } else { + qry->flags.NO_EDNS = true; + } + break; + case KR_SELECTION_FORMERR_EDNS: + addr_state->broken = true; + break; + case KR_SELECTION_MISMATCHED: + if (qry->flags.NO_0X20 && qry->flags.TCP) { + addr_state->broken = true; + } else { + qry->flags.TCP = true; + qry->flags.NO_0X20 = true; + } + break; + case KR_SELECTION_TRUNCATED: + if (transport->protocol == KR_TRANSPORT_UDP) { + qry->server_selection.local_state->truncated = true; + /* TC=1 over UDP is not an error, so we compensate. */ + addr_state->error_count--; + } else { + addr_state->broken = true; + } + break; + case KR_SELECTION_REFUSED: + case KR_SELECTION_SERVFAIL: + if (qry->flags.NO_MINIMIZE && qry->flags.NO_0X20 && qry->flags.TCP) { + addr_state->broken = true; + } else if (qry->flags.NO_MINIMIZE) { + qry->flags.NO_0X20 = true; + qry->flags.TCP = true; + } else { + qry->flags.NO_MINIMIZE = true; + } + break; + case KR_SELECTION_LAME_DELEGATION: + if (qry->flags.NO_MINIMIZE) { + /* Lame delegations are weird, they breed more lame delegations on broken + * zones since trying another server from the same set usually doesn't help. + * We force resolution of another NS name in hope of getting somewhere. */ + qry->server_selection.local_state->force_resolve = true; + addr_state->broken = true; + } else { + qry->flags.NO_MINIMIZE = true; + } + break; + case KR_SELECTION_NOTIMPL: + case KR_SELECTION_OTHER_RCODE: + case KR_SELECTION_DNSSEC_ERROR: + case KR_SELECTION_BAD_CNAME: + case KR_SELECTION_MALFORMED: + /* These errors are fatal, no point in trying this server again. */ + addr_state->broken = true; + break; + default: + kr_assert(false); + return; + } + + addr_state->error_count++; + addr_state->errors[sel_error]++; + + if (kr_log_is_debug_qry(SELECTION, qry)) { + KR_DNAME_GET_STR(ns_name, transport->ns_name); + KR_DNAME_GET_STR(zonecut_str, qry->zone_cut.name); + const char *ns_str = kr_straddr(&transport->address.ip); + const char *err_str = kr_selection_error_str(sel_error); + + VERBOSE_MSG( + qry, + "=> id: '%05u' noting selection error: '%s'@'%s'" + " zone cut: '%s' error: %d %s\n", + qry->id, ns_name, ns_str ? ns_str : "", + zonecut_str, sel_error, err_str ? err_str : "??"); + } +} + +void kr_server_selection_init(struct kr_query *qry) +{ + struct knot_mm *mempool = &qry->request->pool; + struct local_state *local_state = mm_calloc(mempool, 1, sizeof(*local_state)); + + if (qry->flags.FORWARD || qry->flags.STUB) { + qry->server_selection = (struct kr_server_selection){ + .initialized = true, + .choose_transport = forward_choose_transport, + .update_rtt = forward_update_rtt, + .error = forward_error, + .local_state = local_state, + }; + forward_local_state_alloc( + mempool, &qry->server_selection.local_state->private, + qry->request); + } else { + qry->server_selection = (struct kr_server_selection){ + .initialized = true, + .choose_transport = iter_choose_transport, + .update_rtt = iter_update_rtt, + .error = iter_error, + .local_state = local_state, + }; + iter_local_state_alloc( + mempool, &qry->server_selection.local_state->private); + } +} + +int kr_forward_add_target(struct kr_request *req, const struct sockaddr *sock) +{ + if (!req->selection_context.forwarding_targets.at) { + return kr_error(EINVAL); + } + + union kr_sockaddr address; + + switch (sock->sa_family) { + case AF_INET: + if (req->options.NO_IPV4) + return kr_error(EINVAL); + address.ip4 = *(const struct sockaddr_in *)sock; + break; + case AF_INET6: + if (req->options.NO_IPV6) + return kr_error(EINVAL); + address.ip6 = *(const struct sockaddr_in6 *)sock; + break; + default: + return kr_error(EINVAL); + } + + array_push_mm(req->selection_context.forwarding_targets, address, + kr_memreserve, &req->pool); + return kr_ok(); +} -- cgit v1.2.3