summaryrefslogtreecommitdiffstats
path: root/src/lib-dict/dict-memcached.c
blob: c5b7ce63faf001fe8c1a978a4288b15d88f2ea25 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
/* Copyright (c) 2013-2018 Dovecot authors, see the included COPYING file */

#include "lib.h"
#include "array.h"
#include "str.h"
#include "istream.h"
#include "ostream.h"
#include "connection.h"
#include "dict-private.h"

/* Port used when the URI has no port= parameter. */
#define MEMCACHED_DEFAULT_PORT 11211
/* Lookup timeout in milliseconds (30 seconds) used when the URI has no
   timeout_msecs= parameter. */
#define MEMCACHED_DEFAULT_LOOKUP_TIMEOUT_MSECS (1000*30)

/* We need only very limited memcached functionality, so just define the
   binary protocol ourselves instead of requiring protocol_binary.h. */
/* First byte of every request we send / every reply we accept. */
#define MEMCACHED_REQUEST_HDR_MAGIC 0x80
#define MEMCACHED_REPLY_HDR_MAGIC 0x81

/* Fixed size of the request / reply packet header in bytes. */
#define MEMCACHED_REQUEST_HDR_LENGTH 24
#define MEMCACHED_REPLY_HDR_LENGTH 24

/* Opcode written into the request header; GET is the only command this
   driver sends. */
#define MEMCACHED_CMD_GET 0x00

/* The only data type this driver sends in requests and accepts in replies. */
#define MEMCACHED_DATA_TYPE_RAW 0x00

/* Status codes read from the reply header that memcached_dict_lookup() maps
   to a specific result or error message. Any other status value is reported
   to the caller by its numeric code. */
enum memcached_response {
	MEMCACHED_RESPONSE_OK		= 0x0000,
	MEMCACHED_RESPONSE_NOTFOUND	= 0x0001,
	MEMCACHED_RESPONSE_INTERNALERROR= 0x0084,
	MEMCACHED_RESPONSE_BUSY		= 0x0085,
	MEMCACHED_RESPONSE_TEMPFAILURE	= 0x0086,
};

/* Per-dict client connection state. The callbacks in this file cast a
   struct connection pointer directly to this struct, so `conn' has to stay
   the first member. */
struct memcached_connection {
	struct connection conn;
	/* back pointer to the owning dict */
	struct memcached_dict *dict;

	/* request buffer, reset and refilled for every lookup */
	buffer_t *cmd;
	/* Result of the latest reply parsed by memcached_input_get().
	   `value' points into the data returned by the connection's input
	   stream rather than into a private copy. */
	struct {
		const unsigned char *value;
		size_t value_len;
		uint16_t status; /* enum memcached_response */
		bool reply_received;
	} reply;
};

/* memcached dict instance. Functions in this file cast a struct dict pointer
   directly to this struct, so `dict' has to stay the first member. */
struct memcached_dict {
	struct dict dict;
	/* server address; defaults to 127.0.0.1 and the default port */
	struct ip_addr ip;
	/* prepended to every looked up key; empty string when not configured */
	char *key_prefix;
	in_port_t port;
	/* lookup timeout, covers both connecting and waiting for the reply */
	unsigned int timeout_msecs;

	struct memcached_connection conn;

	/* set once the connect has succeeded, cleared when the connection
	   is destroyed */
	bool connected;
};

/* Connection list shared by all memcached dicts. Created by
   memcached_dict_init() when it is still NULL, and freed by
   memcached_dict_deinit() once it has no connections left. */
static struct connection_list *memcached_connections;

static void memcached_conn_destroy(struct connection *_conn)
{
	struct memcached_connection *conn = (struct memcached_connection *)_conn;

	conn->dict->connected = FALSE;
	connection_disconnect(_conn);

	if (conn->dict->dict.ioloop != NULL)
		io_loop_stop(conn->dict->dict.ioloop);
}

/* Try to parse one complete GET reply from the connection's input stream.

   Header fields read here (multi-byte fields are in network byte order):
     byte  0     magic
     bytes 2-3   key length
     byte  4     extras length
     byte  5     data type
     bytes 6-7   status
     bytes 8-11  total body length (extras + key + value)

   Returns 1 if a reply was parsed and stored into conn->reply, 0 if more
   input is needed, -1 if the reply is invalid (an error is logged). */
static int memcached_input_get(struct memcached_connection *conn)
{
	const unsigned char *data;
	size_t size;
	uint32_t body_len, value_pos;
	uint16_t key_len, key_pos, status;
	uint8_t extras_len, data_type;

	data = i_stream_get_data(conn->conn.input, &size);
	if (size < MEMCACHED_REPLY_HDR_LENGTH)
		return 0;

	if (data[0] != MEMCACHED_REPLY_HDR_MAGIC) {
		e_error(conn->conn.event, "Invalid reply magic: %u != %u",
			data[0], MEMCACHED_REPLY_HDR_MAGIC);
		return -1;
	}
	/* From here on body_len is the size of the whole packet, header
	   included. If the addition wraps around, the result is smaller than
	   the header length and gets rejected by the value_pos check below. */
	memcpy(&body_len, data+8, 4); body_len = ntohl(body_len);
	body_len += MEMCACHED_REPLY_HDR_LENGTH;
	if (size < body_len) {
		/* we haven't read the whole response yet */
		return 0;
	}

	memcpy(&key_len, data+2, 2); key_len = ntohs(key_len);
	extras_len = data[4];
	data_type = data[5];
	memcpy(&status, data+6, 2); status = ntohs(status);
	if (data_type != MEMCACHED_DATA_TYPE_RAW) {
		/* log the data type we actually received, not the magic byte */
		e_error(conn->conn.event, "Unsupported data type: %u != %u",
			data_type, MEMCACHED_DATA_TYPE_RAW);
		return -1;
	}

	/* packet layout: header, extras, key, value */
	key_pos = MEMCACHED_REPLY_HDR_LENGTH + extras_len;
	value_pos = key_pos + key_len;
	if (value_pos > body_len) {
		e_error(conn->conn.event, "Invalid key/extras lengths");
		return -1;
	}
	/* The value isn't copied: it keeps pointing to the stream's data
	   even after the skip below.
	   NOTE(review): this relies on the istream keeping skipped data
	   readable until the next read - confirm against the istream API. */
	conn->reply.value = data + value_pos;
	conn->reply.value_len = body_len - value_pos;
	conn->reply.status = status;

	i_stream_skip(conn->conn.input, body_len);
	conn->reply.reply_received = TRUE;

	/* wake up the lookup waiting for this reply */
	if (conn->dict->dict.ioloop != NULL)
		io_loop_stop(conn->dict->dict.ioloop);
	return 1;
}

/* Connection "input" callback: read more data from the stream and try to
   parse a reply out of it. The connection is destroyed if the read returns
   -1 or if the reply turns out to be invalid. */
static void memcached_conn_input(struct connection *_conn)
{
	struct memcached_connection *mconn =
		(struct memcached_connection *)_conn;
	ssize_t nread = i_stream_read(_conn->input);

	if (nread == 0) {
		/* nothing new to look at */
		return;
	}
	/* any return value other than 0 and -1 goes on to parsing */
	if (nread == -1 || memcached_input_get(mconn) < 0)
		memcached_conn_destroy(_conn);
}

/* Connection "client_connected" callback: record a successful connect (or
   log the failure) and wake up the lookup waiting for the connection. */
static void memcached_conn_connected(struct connection *_conn, bool success)
{
	struct memcached_connection *mconn =
		(struct memcached_connection *)_conn;
	struct memcached_dict *owner = mconn->dict;

	if (success)
		owner->connected = TRUE;
	else
		e_error(mconn->conn.event, "connect() failed: %m");

	/* stop the loop in both cases; the lookup checks `connected' */
	if (owner->dict.ioloop != NULL)
		io_loop_stop(owner->dict.ioloop);
}

/* Settings for the shared connection list: client-side connections with the
   input and output size limits set to SIZE_MAX. */
static const struct connection_settings memcached_conn_set = {
	.input_max_size = SIZE_MAX,
	.output_max_size = SIZE_MAX,
	.client = TRUE
};

/* Callbacks given to the shared connection list; all of them are defined
   earlier in this file. */
static const struct connection_vfuncs memcached_conn_vfuncs = {
	.destroy = memcached_conn_destroy,
	.input = memcached_conn_input,
	.client_connected = memcached_conn_connected
};

/* Create a memcached dict from a ':'-separated parameter list.

   Supported parameters:
     host=<ip>            server IP address (default 127.0.0.1)
     port=<port>          server port (default MEMCACHED_DEFAULT_PORT)
     prefix=<string>      prefix prepended to every looked up key
     timeout_msecs=<n>    lookup timeout in milliseconds

   Parameters are split on every ':', so a value can't contain that character
   (NOTE(review): this presumably rules out IPv6 literals in host= - confirm).

   Returns 0 and sets *dict_r on success. Returns -1 and sets *error_r if any
   parameter is invalid or unknown; if several are, the last error is the one
   reported. No connection is attempted here - that is done lazily by the
   first lookup. */
static int
memcached_dict_init(struct dict *driver, const char *uri,
		    const struct dict_settings *set,
		    struct dict **dict_r, const char **error_r)
{
	struct memcached_dict *dict;
	const char *const *args;
	int ret = 0;

	dict = i_new(struct memcached_dict, 1);
	if (net_addr2ip("127.0.0.1", &dict->ip) < 0)
		i_unreached();
	dict->port = MEMCACHED_DEFAULT_PORT;
	dict->timeout_msecs = MEMCACHED_DEFAULT_LOOKUP_TIMEOUT_MSECS;
	dict->key_prefix = i_strdup("");

	args = t_strsplit(uri, ":");
	for (; *args != NULL; args++) {
		if (str_begins(*args, "host=")) {
			if (net_addr2ip(*args+5, &dict->ip) < 0) {
				*error_r = t_strdup_printf("Invalid IP: %s",
							   *args+5);
				ret = -1;
			}
		} else if (str_begins(*args, "port=")) {
			if (net_str2port(*args+5, &dict->port) < 0) {
				*error_r = t_strdup_printf("Invalid port: %s",
							   *args+5);
				ret = -1;
			}
		} else if (str_begins(*args, "prefix=")) {
			/* replace the default (or an earlier) prefix */
			i_free(dict->key_prefix);
			dict->key_prefix = i_strdup(*args + 7);
		} else if (str_begins(*args, "timeout_msecs=")) {
			if (str_to_uint(*args+14, &dict->timeout_msecs) < 0) {
				*error_r = t_strdup_printf(
					"Invalid timeout_msecs: %s", *args+14);
				ret = -1;
			}
		} else {
			*error_r = t_strdup_printf("Unknown parameter: %s",
						   *args);
			ret = -1;
		}
	}
	if (ret < 0) {
		i_free(dict->key_prefix);
		i_free(dict);
		return -1;
	}

	/* Create the shared connection list only after the parameters are
	   known to be valid, so that a failed init doesn't leave behind an
	   empty list that no deinit would ever free. */
	if (memcached_connections == NULL) {
		memcached_connections =
			connection_list_init(&memcached_conn_set,
					     &memcached_conn_vfuncs);
	}

	dict->conn.conn.event_parent = set->event_parent;

	connection_init_client_ip(memcached_connections, &dict->conn.conn,
				  NULL, &dict->ip, dict->port);
	event_set_append_log_prefix(dict->conn.conn.event, "memcached: ");
	dict->dict = *driver;
	dict->conn.cmd = buffer_create_dynamic(default_pool, 256);
	dict->conn.dict = dict;
	*dict_r = &dict->dict;
	return 0;
}

static void memcached_dict_deinit(struct dict *_dict)
{
	struct memcached_dict *dict = (struct memcached_dict *)_dict;

	connection_deinit(&dict->conn.conn);
	buffer_free(&dict->conn.cmd);
	i_free(dict->key_prefix);
	i_free(dict);

	if (memcached_connections->connections == NULL)
		connection_list_deinit(&memcached_connections);
}

/* Timeout callback for a lookup: log the configured timeout and stop the
   dict's ioloop so that the waiting lookup gets to continue. */
static void memcached_dict_lookup_timeout(struct memcached_dict *dict)
{
	unsigned int whole_secs = dict->timeout_msecs / 1000;
	unsigned int rest_msecs = dict->timeout_msecs % 1000;

	e_error(dict->dict.event, "Lookup timed out in %u.%03u secs",
		whole_secs, rest_msecs);
	io_loop_stop(dict->dict.ioloop);
}

/* Append a GET request header for a key of key_len bytes to buf. The buffer
   is expected to be empty: afterwards it must contain exactly the header.
   The key itself is appended by the caller. */
static void memcached_add_header(buffer_t *buf, unsigned int key_len)
{
	unsigned char hdr[MEMCACHED_REQUEST_HDR_LENGTH];
	uint32_t body_len_be = htonl(key_len);

	i_assert(key_len <= 0xffff);

	/* Fields left as zero: extras length (byte 4), vbucket id (bytes
	   6-7), opaque (bytes 12-15) and cas (bytes 16-23). */
	memset(hdr, 0, sizeof(hdr));
	hdr[0] = MEMCACHED_REQUEST_HDR_MAGIC;
	hdr[1] = MEMCACHED_CMD_GET;
	/* key length as 16 bits, most significant byte first */
	hdr[2] = (key_len >> 8) & 0xff;
	hdr[3] = key_len & 0xff;
	hdr[5] = MEMCACHED_DATA_TYPE_RAW;
	/* the request has no extras and no value, so the body is the key */
	memcpy(hdr + 8, &body_len_be, sizeof(body_len_be));

	buffer_append(buf, hdr, sizeof(hdr));
	i_assert(buf->used == MEMCACHED_REQUEST_HDR_LENGTH);
}

/* Synchronously look up a key from the memcached server.

   Only keys under DICT_PATH_SHARED are supported. That prefix is stripped
   and the configured key prefix (if any) is prepended before the key is
   sent. The lookup runs in its own temporary ioloop: it connects if there is
   no connection yet, sends a GET request and waits until a reply arrives,
   the connection is destroyed or the timeout triggers.

   Returns 1 and sets *value_r (allocated from pool) if the key was found,
   0 if the server replied that the key doesn't exist, and -1 with *error_r
   set on any failure. */
static int
memcached_dict_lookup(struct dict *_dict, const struct dict_op_settings *set ATTR_UNUSED,
		      pool_t pool, const char *key, const char **value_r,
		      const char **error_r)
{
	struct memcached_dict *dict = (struct memcached_dict *)_dict;
	struct ioloop *prev_ioloop = current_ioloop;
	struct timeout *to;
	size_t key_len;

	if (!str_begins(key, DICT_PATH_SHARED)) {
		*error_r = "memcached: Only shared keys supported currently";
		return -1;
	}
	key += strlen(DICT_PATH_SHARED);
	if (*dict->key_prefix != '\0')
		key = t_strconcat(dict->key_prefix, key, NULL);
	key_len = strlen(key);
	if (key_len > 0xffff) {
		/* the header has only 16 bits for the key length */
		*error_r = t_strdup_printf(
			"memcached: Key is too long (%zu bytes): %s", key_len, key);
		return -1;
	}

	i_assert(dict->dict.ioloop == NULL);

	/* Forget the previous lookup's reply before anything can fail.
	   Otherwise a failed or timed out connect would leave the old
	   reply_received flag set, and the stale reply would be returned as
	   if it were the answer to this lookup. */
	i_zero(&dict->conn.reply);

	dict->dict.ioloop = io_loop_create();
	connection_switch_ioloop(&dict->conn.conn);

	if (dict->conn.conn.fd_in == -1 &&
	    connection_client_connect(&dict->conn.conn) < 0) {
		e_error(dict->conn.conn.event, "Couldn't connect");
	} else {
		/* the same timeout covers both connecting and the reply */
		to = timeout_add(dict->timeout_msecs,
				 memcached_dict_lookup_timeout, dict);
		if (!dict->connected) {
			/* wait for connection */
			io_loop_run(dict->dict.ioloop);
		}

		if (dict->connected) {
			buffer_set_used_size(dict->conn.cmd, 0);
			memcached_add_header(dict->conn.cmd, key_len);
			buffer_append(dict->conn.cmd, key, key_len);

			o_stream_nsend(dict->conn.conn.output,
				       dict->conn.cmd->data,
				       dict->conn.cmd->used);

			/* ignore anything parsed before the request was
			   sent, then wait for the actual reply */
			i_zero(&dict->conn.reply);
			io_loop_run(dict->dict.ioloop);
		}
		timeout_remove(&to);
	}

	/* move the connection back to the caller's ioloop before destroying
	   the temporary one */
	io_loop_set_current(prev_ioloop);
	connection_switch_ioloop(&dict->conn.conn);
	io_loop_set_current(dict->dict.ioloop);
	io_loop_destroy(&dict->dict.ioloop);

	if (!dict->conn.reply.reply_received) {
		/* we failed in some way. make sure we disconnect since the
		   connection state isn't known anymore */
		memcached_conn_destroy(&dict->conn.conn);
		*error_r = "Communication failure";
		return -1;
	}
	switch (dict->conn.reply.status) {
	case MEMCACHED_RESPONSE_OK:
		*value_r = p_strndup(pool, dict->conn.reply.value,
				     dict->conn.reply.value_len);
		return 1;
	case MEMCACHED_RESPONSE_NOTFOUND:
		return 0;
	case MEMCACHED_RESPONSE_INTERNALERROR:
		*error_r = "Lookup failed: Internal error";
		return -1;
	case MEMCACHED_RESPONSE_BUSY:
		*error_r = "Lookup failed: Busy";
		return -1;
	case MEMCACHED_RESPONSE_TEMPFAILURE:
		*error_r = "Lookup failed: Temporary failure";
		return -1;
	}

	/* a status that isn't listed in enum memcached_response */
	*error_r = t_strdup_printf("Lookup failed: Error code=%u",
				   dict->conn.reply.status);
	return -1;
}

/* The "memcached" dict driver. Only init, deinit and lookup are set in this
   initializer; no other dict operation is given an implementation here. */
struct dict dict_driver_memcached = {
	.name = "memcached",
	{
		.init = memcached_dict_init,
		.deinit = memcached_dict_deinit,
		.lookup = memcached_dict_lookup,
	}
};