path: root/lib/generic/lru.c
/*  Copyright (C) CZ.NIC, z.s.p.o. <knot-resolver@labs.nic.cz>
 *  SPDX-License-Identifier: GPL-3.0-or-later
 */

#include "lib/generic/lru.h"
#include "contrib/murmurhash3/murmurhash3.h"
#include "contrib/ucw/mempool.h"

typedef struct lru_group lru_group_t;

struct lru_item {
	uint16_t key_len, val_len; /**< Two bytes should be enough for our purposes. */
	char data[];
	/**< Place for both key and value.
	 *
	 * We use "char" to satisfy the C99+ aliasing rules.
	 * See C99 section 6.5 Expressions, paragraph 7.
	 * Any type can be accessed through char-pointer,
	 * so we can use a common struct definition
	 * for all types being held.
	 *
	 * The value address is restricted by val_alignment.
	 * Approach: require slightly larger sizes from the allocator
	 * and shift value on the closest address with val_alignment.
	 */
};

/** @brief Round the value up to a multiple of mult (a power of two). */
static inline size_t round_power(size_t size, size_t mult)
{
	kr_require(__builtin_popcount(mult) == 1);
	size_t res = ((size - 1) & ~(mult - 1)) + mult;
	kr_require(__builtin_ctz(res) >= __builtin_ctz(mult));
	kr_require(size <= res && res < size + mult);
	return res;
}
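
/* Worked example (illustration only): with mult == 8,
 *   round_power(13, 8) == ((13 - 1) & ~7) + 8 == 8 + 8 == 16, and
 *   round_power(16, 8) == ((16 - 1) & ~7) + 8 == 8 + 8 == 16,
 * i.e. the result is the smallest multiple of mult that is >= size. */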

/** @internal Compute the allocation size for an lru_item. */
static uint item_size(const struct lru *lru, uint key_len, uint val_len)
{
	uint key_end = offsetof(struct lru_item, data) + key_len;
	return key_end + (lru->val_alignment - 1) + val_len;
	/*             ^^^ worst-case padding length
	 * Well, we might compute the bound a bit more precisely,
	 * as we know that lru_item will get alignment at least
	 * some sizeof(void*) and we know all the lengths,
	 * but let's not complicate it, as the gain would be small anyway. */
}

/** @internal Return pointer to value in an lru_item. */
static void * item_val(const struct lru *lru, struct lru_item *it)
{
	/* Compute the first sufficiently aligned address past the key;
	 * going through uintptr_t avoids arithmetic on a null pointer. */
	uintptr_t key_end = (uintptr_t)(it->data + it->key_len);
	uintptr_t val_begin = round_power(key_end, lru->val_alignment);
	return (void *)val_begin;
}
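
/* Illustrative memory layout of one lru_item (assuming val_alignment == 8,
 * key_len == 3, and that the allocation itself is 8-byte aligned; offsets
 * are from the start of the item):
 *
 *   0..3  uint16_t key_len, val_len
 *   4..6  the key bytes
 *   7     padding introduced by item_val()'s rounding
 *   8..   the value, as returned by item_val()
 */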

/** @internal Free each item. */
KR_EXPORT void lru_free_items_impl(struct lru *lru)
{
	if (kr_fails_assert(lru))
		return;
	for (size_t i = 0; i < ((size_t)1 << lru->log_groups); ++i) {
		lru_group_t *g = &lru->groups[i];
		for (int j = 0; j < LRU_ASSOC; ++j)
			mm_free(lru->mm, g->items[j]);
	}
}
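
/* Note: this frees only the stored items; releasing the table itself is done
 * by the lru_free() wrapper macro in lib/generic/lru.h, which calls this
 * function first. */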

/** @internal See lru_apply. */
KR_EXPORT void lru_apply_impl(struct lru *lru, lru_apply_fun f, void *baton)
{
	if (kr_fails_assert(lru && f))
		return;
	for (size_t i = 0; i < ((size_t)1 << lru->log_groups); ++i) {
		lru_group_t *g = &lru->groups[i];
		for (uint j = 0; j < LRU_ASSOC; ++j) {
			struct lru_item *it = g->items[j];
			if (!it)
				continue;
			enum lru_apply_do ret =
				f(it->data, it->key_len, item_val(lru, it), baton);
			switch (ret) {
			case LRU_APPLY_DO_EVICT: // evict
				mm_free(lru->mm, it);
				g->items[j] = NULL;
				g->counts[j] = 0;
				g->hashes[j] = 0;
				break;
			default:
				kr_assert(ret == LRU_APPLY_DO_NOTHING);
			}
		}
	}
}
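
/* A minimal callback sketch matching the lru_apply_fun signature used above
 * (illustration only; the evict_zero_values name is hypothetical and not part
 * of the library).  It evicts every entry whose value, read as an int, is 0: */
static inline enum lru_apply_do evict_zero_values(const char *key, uint len,
						  void *val, void *baton)
{
	(void)key; (void)len; (void)baton; // unused in this example
	return *(int *)val == 0 ? LRU_APPLY_DO_EVICT : LRU_APPLY_DO_NOTHING;
}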

/** @internal See lru_create. */
KR_EXPORT struct lru * lru_create_impl(uint max_slots, uint val_alignment,
					knot_mm_t *mm_array, knot_mm_t *mm)
{
	if (kr_fails_assert(max_slots && __builtin_popcount(val_alignment) == 1))
		return NULL;
	// let lru->log_groups = ceil(log2(max_slots / (float) assoc))
	//   without trying for efficiency
	uint group_count = (max_slots - 1) / LRU_ASSOC + 1;
	uint log_groups = 0;
	for (uint s = group_count - 1; s; s /= 2)
		++log_groups;
	group_count = 1 << log_groups;
	if (kr_fails_assert(max_slots <= group_count * LRU_ASSOC && group_count * LRU_ASSOC < 2 * max_slots))
		return NULL;

	/* Get a sufficiently aligning mm_array if NULL is passed. */
	if (!mm_array) {
		static knot_mm_t mm_array_default = { 0 };
		if (!mm_array_default.ctx)
			mm_ctx_init_aligned(&mm_array_default, alignof(struct lru));
		mm_array = &mm_array_default;
	}
	if (kr_fails_assert(mm_array->alloc && mm_array->alloc != (knot_mm_alloc_t)mp_alloc))
		return NULL;

	size_t size = offsetof(struct lru, groups[group_count]);
	struct lru *lru = mm_alloc(mm_array, size);
	if (unlikely(lru == NULL))
		return NULL;
	*lru = (struct lru){
		.mm = mm,
		.mm_array = mm_array,
		.log_groups = log_groups,
		.val_alignment = val_alignment,
	};
	// zeros are a good init
	memset(lru->groups, 0, size - offsetof(struct lru, groups));
	return lru;
}
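
/* A minimal usage sketch (illustration only, assuming the type-safe macro
 * wrappers lru_t, lru_create, lru_get_new and lru_free behave as declared
 * in lib/generic/lru.h): */
static inline void lru_usage_example(void)
{
	lru_t(int) *cache;                    // an LRU mapping string keys to int values
	lru_create(&cache, 100, NULL, NULL);  // at least 100 slots, default allocators
	bool is_new;
	int *val = lru_get_new(cache, "key", 3, &is_new);
	if (val && is_new)
		*val = 42; // freshly inserted values start zeroed (see lru_get_impl)
	lru_free(cache);
}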

/** @internal Decrement all counters within a group. */
static void group_dec_counts(lru_group_t *g) {
	g->counts[LRU_TRACKED] = LRU_TRACKED; // reset the miss ticker; the loop below decrements it, too
	for (uint i = 0; i < LRU_TRACKED + 1; ++i)
		if (likely(g->counts[i]))
			--g->counts[i];
}

/** @internal Increment a counter within a group, saturating at the maximum value. */
static void group_inc_count(lru_group_t *g, int i) {
	if (likely(++(g->counts[i])))
		return;
	g->counts[i] = -1; // the increment wrapped to 0, so pin the counter at its maximum
	// We could've decreased or halved all of them, but let's keep the max.
}

/** @internal Implementation of both getting and insertion.
 * Note: val_len is only meaningful if do_insert.
 *       *is_new is only meaningful when the return value isn't NULL:
 *       it is set to true iff the returned entry was (re)allocated during
 *       this call (its value has just been zeroed).
 *       If the return value is NULL, *is_new is left untouched.
 */
KR_EXPORT void * lru_get_impl(struct lru *lru, const char *key, uint key_len,
			      uint val_len, bool do_insert, bool *is_new)
{
	bool ok = lru && (key || !key_len) && key_len <= UINT16_MAX
		   && (!do_insert || val_len <= UINT16_MAX);
	if (kr_fails_assert(ok))
		return NULL; // reasonable fallback when not debugging
	bool is_new_entry = false;
	// find the right group
	uint32_t khash = hash(key, key_len);
	uint16_t khash_top = khash >> 16;
	lru_group_t *g = &lru->groups[khash & ((1 << lru->log_groups) - 1)];
	struct lru_item *it = NULL;
	uint i;
	// scan the *stored* elements in the group
	for (i = 0; i < LRU_ASSOC; ++i) {
		if (g->hashes[i] == khash_top) {
			it = g->items[i];
			if (likely(it && it->key_len == key_len
					&& (key_len == 0 || memcmp(it->data, key, key_len) == 0))) {
				/* The key matches, but the value being inserted is
				 * larger than the allocated slot; reallocate via insert. */
				if (unlikely(do_insert && val_len > it->val_len)) {
					goto insert;
				} else {
					goto found; // to reduce huge nesting depth
				}
			}
		}
	}
	// key not found; first try an empty/counted-out place to insert
	if (do_insert)
		for (i = 0; i < LRU_ASSOC; ++i)
			if (g->items[i] == NULL || g->counts[i] == 0)
				goto insert;
	// check if we track key's count at least
	for (i = LRU_ASSOC; i < LRU_TRACKED; ++i) {
		if (g->hashes[i] == khash_top) {
			group_inc_count(g, i);
			if (!do_insert)
				return NULL;
			// check if we trumped some stored key
			for (uint j = 0; j < LRU_ASSOC; ++j)
				if (unlikely(g->counts[i] > g->counts[j])) {
					// evict key j, i.e. swap with i
					--g->counts[i]; // we increment it below
					SWAP(g->counts[i], g->counts[j]);
					SWAP(g->hashes[i], g->hashes[j]);
					i = j;
					goto insert;
				}
			return NULL;
		}
	}
	// not found at all: decrement all counters, but a full decrement happens only once per LRU_TRACKED misses
	if (g->counts[LRU_TRACKED])
		--g->counts[LRU_TRACKED];
	else
		group_dec_counts(g);
	return NULL;
insert: // insert into position i (incl. key)
	if (kr_fails_assert(i < LRU_ASSOC))
		return NULL;
	g->hashes[i] = khash_top;
	it = g->items[i];
	uint new_size = item_size(lru, key_len, val_len);
	if (it == NULL || new_size != item_size(lru, it->key_len, it->val_len)) {
		// (re)allocate
		mm_free(lru->mm, it);
		it = g->items[i] = mm_alloc(lru->mm, new_size);
		if (it == NULL)
			return NULL;
	}
	it->key_len = key_len;
	it->val_len = val_len;
	if (key_len > 0) {
		memcpy(it->data, key, key_len);
	}
	memset(item_val(lru, it), 0, val_len); // clear the value
	is_new_entry = true;
found: // key and hash OK on g->items[i]; now update stamps
	if (kr_fails_assert(i < LRU_ASSOC))
		return NULL;
	group_inc_count(g, i);
	if (is_new) {
		*is_new = is_new_entry;
	}
	return item_val(lru, g->items[i]);
}
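
/* Plain lookups go through the same code path with do_insert == false
 * (illustration only, assuming the lru_get_try() wrapper from
 * lib/generic/lru.h):
 *
 *   int *val = lru_get_try(cache, "key", 3); // NULL when "key" is not cached
 */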