summaryrefslogtreecommitdiffstats
path: root/lib/selection_forward.c
blob: 54f9a12262562c9f8f1efc064a665e4d4ab439e3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
/*  Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
 *  SPDX-License-Identifier: GPL-3.0-or-later
 */

#include "lib/selection_forward.h"
#include "lib/resolve.h"

/** Shorthand for kr_log_q() on the given query, always passing SELECTION
 * as its second argument (presumably the log group -- confirm in the
 * logging header). */
#define VERBOSE_MSG(qry, ...) kr_log_q((qry), SELECTION, __VA_ARGS__)

/** Fixed value stored into the `timeout` of every transport chosen by
 * forward_choose_transport().  The unit is not visible in this file;
 * presumably milliseconds -- TODO confirm. */
#define FORWARDING_TIMEOUT 2000
/* TODO: well, this is a bit hard; maybe we'd better have a different approach
 * for estimating DEAD-ness for forwarding.
 * Even ACKs on connections might be useful here. */
/* Compile-time guard: the build fails if the forwarding timeout is ever set
 * below KR_NS_TIMEOUT_MIN_DEAD_TIMEOUT. */
static_assert(FORWARDING_TIMEOUT >= KR_NS_TIMEOUT_MIN_DEAD_TIMEOUT,
		"Bad combination of NS selection limits.");

/** Private state of the forwarding strategy; the functions in this file
 * reach it through qry->server_selection.local_state->private. */
struct forward_local_state {
	/** Forwarding targets.  Points at the request's
	 * selection_context.forwarding_targets array; it is not a copy. */
	kr_sockaddr_array_t *targets;
	/** Array allocated with targets->len entries; addr_states[i] is the
	 * state kept for targets->at[i]. */
	struct address_state *addr_states;
	/** Index of last choice in the targets array.  Filled in through
	 * select_transport() and used to pick the address state for error
	 * reporting and RTT updates. */
	size_t last_choice_index;
};

void forward_local_state_alloc(struct knot_mm *mm, void **local_state,
			       struct kr_request *req)
{
	kr_require(req->selection_context.forwarding_targets.at);
	*local_state = mm_calloc(mm, 1, sizeof(struct forward_local_state));

	struct forward_local_state *forward_state = *local_state;
	forward_state->targets = &req->selection_context.forwarding_targets;

	forward_state->addr_states = mm_calloc(mm, forward_state->targets->len,
						sizeof(struct address_state));
}

void forward_choose_transport(struct kr_query *qry,
			      struct kr_transport **transport)
{
	struct forward_local_state *local_state =
		qry->server_selection.local_state->private;
	struct choice choices[local_state->targets->len];
	int valid = 0;

	for (int i = 0; i < local_state->targets->len; i++) {
		union kr_sockaddr *address = &local_state->targets->at[i];
		size_t addr_len;
		uint16_t port;
		switch (address->ip.sa_family) {
		case AF_INET:
			port = ntohs(address->ip4.sin_port);
			addr_len = sizeof(struct in_addr);
			break;
		case AF_INET6:
			port = ntohs(address->ip6.sin6_port);
			addr_len = sizeof(struct in6_addr);
			break;
		default:
			kr_assert(false);
			*transport = NULL;
			goto cleanup;
		}

		struct address_state *addr_state = &local_state->addr_states[i];
		addr_state->ns_name = (knot_dname_t *)"";

		update_address_state(addr_state, address, addr_len, qry);

		if (addr_state->generation == -1) {
			continue;
		}
		addr_state->choice_array_index = i;

		choices[valid++] = (struct choice){
			.address = *address,
			.address_len = addr_len,
			.address_state = addr_state,
			.port = port,
		};
	}

	bool tcp = qry->flags.TCP || qry->server_selection.local_state->truncated;
	*transport =
		select_transport(choices, valid, NULL, 0,
				 qry->server_selection.local_state->timeouts,
				 &qry->request->pool, tcp,
				 &local_state->last_choice_index);
	if (*transport) {
		/* Set static timeout for forwarding; there is no point in this
		 * being dynamic since the RTT of a packet to forwarding target
		 * says nothing about the network RTT of said target, since
		 * it is doing resolution upstream. */
		(*transport)->timeout = FORWARDING_TIMEOUT;
		/* Try to avoid TCP in STUB case.  It seems better for common use cases. */
		if (qry->flags.STUB && !tcp && (*transport)->protocol == KR_TRANSPORT_TCP)
			(*transport)->protocol = KR_TRANSPORT_UDP;
		/* We need to propagate this to flags since it's used in other
		 * parts of the resolver (e.g. logging and stats). */
		qry->flags.TCP = (*transport)->protocol == KR_TRANSPORT_TCP
			      || (*transport)->protocol == KR_TRANSPORT_TLS;
	}
cleanup:
	kr_cache_commit(&qry->request->ctx->cache);
}

void forward_error(struct kr_query *qry, const struct kr_transport *transport,
		   enum kr_selection_error sel_error)
{
	if (!qry->server_selection.initialized) {
		return;
	}
	struct forward_local_state *local_state =
		qry->server_selection.local_state->private;
	struct address_state *addr_state =
		&local_state->addr_states[local_state->last_choice_index];
	error(qry, addr_state, transport, sel_error);
}

/**
 * Record a measured round-trip time for the most recently chosen forwarding
 * target.
 *
 * Does nothing when server selection has not been initialised for @a qry
 * or when @a transport is NULL.  Otherwise the address state at
 * last_choice_index is passed on to update_rtt().
 *
 * @param qry        query the measurement belongs to
 * @param transport  transport that was used; forwarded to update_rtt()
 * @param rtt        measured round-trip time; forwarded to update_rtt()
 */
void forward_update_rtt(struct kr_query *qry,
			const struct kr_transport *transport, unsigned rtt)
{
	if (!qry->server_selection.initialized || !transport) {
		return;
	}

	struct forward_local_state *fwd =
		qry->server_selection.local_state->private;
	const size_t last = fwd->last_choice_index;

	update_rtt(qry, &fwd->addr_states[last], transport, rtt);
}