summaryrefslogtreecommitdiffstats
path: root/include/dnsjit/filter/ipsplit.c
diff options
context:
space:
mode:
Diffstat (limited to 'include/dnsjit/filter/ipsplit.c')
-rw-r--r--include/dnsjit/filter/ipsplit.c270
1 files changed, 270 insertions, 0 deletions
diff --git a/include/dnsjit/filter/ipsplit.c b/include/dnsjit/filter/ipsplit.c
new file mode 100644
index 0000000..8632e2e
--- /dev/null
+++ b/include/dnsjit/filter/ipsplit.c
@@ -0,0 +1,270 @@
+/*
+ * Copyright (c) 2019-2020, CZ.NIC, z.s.p.o.
+ * All rights reserved.
+ *
+ * This file is part of dnsjit.
+ *
+ * dnsjit is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * dnsjit is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with dnsjit. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#include "config.h"
+
+#include "filter/ipsplit.h"
+#include "core/assert.h"
+#include "core/object/ip.h"
+#include "core/object/ip6.h"
+#include "lib/trie.h"
+
+#include <string.h>
+
+typedef struct _filter_ipsplit {
+ filter_ipsplit_t pub;
+
+ trie_t* trie;
+ uint32_t weight_total;
+} _filter_ipsplit_t;
+
+typedef struct _client {
+ /* Receiver-specific client ID (1..N) in host byte order. */
+ /* Client ID starts at 1 to avoid issues with lua. */
+ uint8_t id[4];
+
+ filter_ipsplit_recv_t* recv;
+} _client_t;
+
+#define _self ((_filter_ipsplit_t*)self)
+
+static core_log_t _log = LOG_T_INIT("filter.ipsplit");
+static filter_ipsplit_t _defaults = {
+ LOG_T_INIT_OBJ("filter.ipsplit"),
+ IPSPLIT_MODE_SEQUENTIAL, IPSPLIT_OVERWRITE_NONE,
+ 0,
+ NULL
+};
+
+core_log_t* filter_ipsplit_log()
+{
+ return &_log;
+}
+
+filter_ipsplit_t* filter_ipsplit_new()
+{
+ filter_ipsplit_t* self;
+
+ mlfatal_oom(self = malloc(sizeof(_filter_ipsplit_t)));
+ *self = _defaults;
+ lfatal_oom(_self->trie = trie_create(NULL));
+ _self->weight_total = 0;
+
+ return self;
+}
+
+static int _free_trie_value(trie_val_t* val, void* ctx)
+{
+ free(*val);
+ return 0;
+}
+
+void filter_ipsplit_free(filter_ipsplit_t* self)
+{
+ filter_ipsplit_recv_t* first;
+ filter_ipsplit_recv_t* r;
+ mlassert_self();
+
+ trie_apply(_self->trie, _free_trie_value, NULL);
+ trie_free(_self->trie);
+
+ if (self->recv) {
+ first = self->recv;
+ do {
+ r = self->recv->next;
+ free(self->recv);
+ self->recv = r;
+ } while (self->recv != first);
+ }
+
+ free(self);
+}
+
+void filter_ipsplit_add(filter_ipsplit_t* self, core_receiver_t recv, void* ctx, uint32_t weight)
+{
+ filter_ipsplit_recv_t* r;
+ mlassert_self();
+ lassert(recv, "recv is nil");
+ lassert(weight > 0, "weight must be positive integer");
+
+ _self->weight_total += weight;
+
+ lfatal_oom(r = malloc(sizeof(filter_ipsplit_recv_t)));
+ r->recv = recv;
+ r->ctx = ctx;
+ r->n_clients = 0;
+ r->weight = weight;
+
+ if (!self->recv) {
+ r->next = r;
+ self->recv = r;
+ } else {
+ r->next = self->recv->next;
+ self->recv->next = r;
+ }
+}
+
+/*
+ * Use portable pseudo-random number generator.
+ */
+static uint32_t _rand_val = 1;
+
+static uint32_t _rand()
+{
+ _rand_val = ((_rand_val * 1103515245) + 12345) & 0x7fffffff;
+ return _rand_val;
+}
+
+void filter_ipsplit_srand(uint32_t seed)
+{
+ _rand_val = seed;
+}
+
+static void _assign_client_to_receiver(filter_ipsplit_t* self, _client_t* client)
+{
+ uint32_t id = 0;
+ filter_ipsplit_recv_t* recv = 0;
+
+ switch (self->mode) {
+ case IPSPLIT_MODE_SEQUENTIAL:
+ recv = self->recv;
+ id = ++recv->n_clients;
+ /* When *weight* clients are assigned, switch to next receiver. */
+ if (recv->n_clients % recv->weight == 0)
+ self->recv = recv->next;
+ break;
+ case IPSPLIT_MODE_RANDOM: {
+ /* Get random number from [1, weight_total], then iterate through
+ * receivers until their weights add up to at least this value. */
+ int32_t random = (int32_t)(_rand() % _self->weight_total) + 1;
+ while (random > 0) {
+ random -= self->recv->weight;
+ if (random > 0)
+ self->recv = self->recv->next;
+ }
+ recv = self->recv;
+ id = ++recv->n_clients;
+ break;
+ }
+ default:
+ lfatal("invalid ipsplit mode");
+ }
+
+ client->recv = recv;
+ memcpy(client->id, &id, sizeof(client->id));
+}
+
+/*
+ * Optionally, write client ID into byte 0-3 of src/dst IP address in the packet.
+ *
+ * Client ID is a 4-byte array in host byte order.
+ */
+static void _overwrite(filter_ipsplit_t* self, core_object_t* obj, _client_t* client)
+{
+ mlassert_self();
+ lassert(obj, "invalid object");
+ lassert(client, "invalid client");
+
+ core_object_ip_t* ip;
+ core_object_ip6_t* ip6;
+
+ switch (self->overwrite) {
+ case IPSPLIT_OVERWRITE_NONE:
+ return;
+ case IPSPLIT_OVERWRITE_SRC:
+ if (obj->obj_type == CORE_OBJECT_IP) {
+ ip = (core_object_ip_t*)obj;
+ memcpy(&ip->src, client->id, sizeof(client->id));
+ } else if (obj->obj_type == CORE_OBJECT_IP6) {
+ ip6 = (core_object_ip6_t*)obj;
+ memcpy(&ip6->src, client->id, sizeof(client->id));
+ }
+ break;
+ case IPSPLIT_OVERWRITE_DST:
+ if (obj->obj_type == CORE_OBJECT_IP) {
+ ip = (core_object_ip_t*)obj;
+ memcpy(&ip->dst, client->id, sizeof(client->id));
+ } else if (obj->obj_type == CORE_OBJECT_IP6) {
+ ip6 = (core_object_ip6_t*)obj;
+ memcpy(&ip6->dst, client->id, sizeof(client->id));
+ }
+ break;
+ default:
+ lfatal("invalid overwrite mode");
+ }
+}
+
+static void _receive(filter_ipsplit_t* self, const core_object_t* obj)
+{
+ mlassert_self();
+
+ /* Find ip/ip6 object in chain. */
+ core_object_t* pkt = (core_object_t*)obj;
+ while (pkt != NULL) {
+ if (pkt->obj_type == CORE_OBJECT_IP || pkt->obj_type == CORE_OBJECT_IP6)
+ break;
+ pkt = (core_object_t*)pkt->obj_prev;
+ }
+ if (pkt == NULL) {
+ self->discarded++;
+ lwarning("packet discarded (missing ip/ip6 object)");
+ return;
+ }
+
+ /* Lookup IPv4/IPv6 address in trie (prefix-tree). Inserts new node if not found. */
+ trie_val_t* node = 0;
+ switch (pkt->obj_type) {
+ case CORE_OBJECT_IP: {
+ core_object_ip_t* ip = (core_object_ip_t*)pkt;
+ node = trie_get_ins(_self->trie, ip->src, sizeof(ip->src));
+ break;
+ }
+ case CORE_OBJECT_IP6: {
+ core_object_ip6_t* ip6 = (core_object_ip6_t*)pkt;
+ node = trie_get_ins(_self->trie, ip6->src, sizeof(ip6->src));
+ break;
+ }
+ default:
+ lfatal("unsupported object type");
+ }
+ lassert(node, "trie failure");
+
+ _client_t* client;
+ if (*node == NULL) { /* IP address not found in tree -> create new client. */
+ lfatal_oom(client = malloc(sizeof(_client_t)));
+ *node = (void*)client;
+ _assign_client_to_receiver(self, client);
+ }
+
+ client = (_client_t*)*node;
+ _overwrite(self, pkt, client);
+ client->recv->recv(client->recv->ctx, obj);
+}
+
+core_receiver_t filter_ipsplit_receiver(filter_ipsplit_t* self)
+{
+ mlassert_self();
+
+ if (!self->recv) {
+ lfatal("no receiver(s) set");
+ }
+
+ return (core_receiver_t)_receive;
+}