/*
 * Copyright (c) 2019-2020, CZ.NIC, z.s.p.o.
 * All rights reserved.
 *
 * This file is part of dnsjit.
 *
 * dnsjit is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * dnsjit is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with dnsjit.  If not, see <http://www.gnu.org/licenses/>.
 */

#include "config.h"

#include "filter/ipsplit.h"
#include "core/assert.h"
#include "core/object/ip.h"
#include "core/object/ip6.h"
#include "lib/trie.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/*
 * Private view of the filter: the public structure comes first so that a
 * filter_ipsplit_t* can be cast to _filter_ipsplit_t* (see the _self macro).
 */
typedef struct _filter_ipsplit {
    filter_ipsplit_t pub;

    /* Maps a source IP address (the raw address bytes are the key) to a _client_t*. */
    trie_t* trie;
    /* Sum of the weights of all receivers added so far. */
    uint32_t weight_total;
} _filter_ipsplit_t;

/* Per-source-address state stored as the value of a trie node. */
typedef struct _client {
    /* Receiver-specific client ID (1..N) in host byte order. */
    /* Client ID starts at 1 to avoid issues with lua. */
    uint8_t id[4];

    /* Receiver this client was assigned to; not owned by the client. */
    filter_ipsplit_recv_t* recv;
} _client_t;

#define _self ((_filter_ipsplit_t*)self)

static core_log_t       _log      = LOG_T_INIT("filter.ipsplit");
static filter_ipsplit_t _defaults = {
    LOG_T_INIT_OBJ("filter.ipsplit"),
    IPSPLIT_MODE_SEQUENTIAL, IPSPLIT_OVERWRITE_NONE, 0, NULL
};

/* Return the module-wide log object. */
core_log_t* filter_ipsplit_log()
{
    return &_log;
}

/*
 * Allocate a new filter initialised from _defaults, with an empty trie and
 * a total weight of zero. Release it with filter_ipsplit_free().
 */
filter_ipsplit_t* filter_ipsplit_new()
{
    filter_ipsplit_t* self;

    mlfatal_oom(self = malloc(sizeof(_filter_ipsplit_t)));
    *self = _defaults;
    lfatal_oom(_self->trie = trie_create(NULL));
    _self->weight_total = 0;

    return self;
}

/* trie_apply() callback: free the _client_t stored in a trie node. */
static int _free_trie_value(trie_val_t* val, void* ctx)
{
    (void)ctx; /* unused */
    free(*val);
    return 0;
}

/*
 * Free the filter: every client stored in the trie, the trie itself, every
 * receiver entry in the circular list and finally the filter structure.
 */
void filter_ipsplit_free(filter_ipsplit_t* self)
{
    mlassert_self();

    trie_apply(_self->trie, _free_trie_value, NULL);
    trie_free(_self->trie);

    if (self->recv) {
        /*
         * The receivers form a ring. Free every node except the starting
         * one first and the starting node last, so that the loop condition
         * never compares against a pointer that has already been freed.
         */
        filter_ipsplit_recv_t* first = self->recv;
        filter_ipsplit_recv_t* r     = first->next;

        while (r != first) {
            filter_ipsplit_recv_t* next = r->next;
            free(r);
            r = next;
        }
        free(first);
        self->recv = NULL;
    }

    free(self);
}

/*
 * Register a receiver.
 *
 * recv/ctx are the receiver callback and the context handed to it on every
 * delivery; weight (> 0) is the receiver's share when clients are
 * distributed. The new entry is linked into the ring right after the
 * current position (or becomes the ring if it is the first entry).
 */
void filter_ipsplit_add(filter_ipsplit_t* self, core_receiver_t recv, void* ctx, uint32_t weight)
{
    filter_ipsplit_recv_t* r;

    mlassert_self();
    lassert(recv, "recv is nil");
    lassert(weight > 0, "weight must be positive integer");
    /*
     * weight_total is used as a modulus in random mode; letting it wrap
     * around (possibly to zero) would skew or break the selection.
     */
    lassert(weight <= UINT32_MAX - _self->weight_total, "total weight overflows 32 bits");

    _self->weight_total += weight;

    lfatal_oom(r = malloc(sizeof(*r)));
    r->recv      = recv;
    r->ctx       = ctx;
    r->n_clients = 0;
    r->weight    = weight;

    if (!self->recv) {
        r->next    = r;
        self->recv = r;
    } else {
        r->next          = self->recv->next;
        self->recv->next = r;
    }
}

/*
 * Use portable pseudo-random number generator.
 */
static uint32_t _rand_val = 1;

/* Advance the generator and return the new 31-bit state (0..0x7fffffff). */
static uint32_t _rand()
{
    _rand_val = ((_rand_val * 1103515245) + 12345) & 0x7fffffff;
    return _rand_val;
}

/* Set the generator state used by random mode. */
void filter_ipsplit_srand(uint32_t seed)
{
    _rand_val = seed;
}

/*
 * Pick a receiver for a newly seen client according to self->mode, bump
 * that receiver's client counter and store the receiver together with the
 * new counter value (the client ID) in *client.
 */
static void _assign_client_to_receiver(filter_ipsplit_t* self, _client_t* client)
{
    uint32_t               id   = 0;
    filter_ipsplit_recv_t* recv = NULL;

    switch (self->mode) {
    case IPSPLIT_MODE_SEQUENTIAL:
        recv = self->recv;
        id   = ++recv->n_clients;
        /* When *weight* clients are assigned, switch to next receiver. */
        if (recv->n_clients % recv->weight == 0)
            self->recv = recv->next;
        break;
    case IPSPLIT_MODE_RANDOM: {
        /* Get random number from [1, weight_total], then iterate through
         * receivers until their weights add up to at least this value.
         * Unsigned arithmetic throughout: _rand() is at most 0x7fffffff,
         * so adding one cannot wrap. */
        uint32_t pick = (_rand() % _self->weight_total) + 1;
        while (pick > self->recv->weight) {
            pick -= self->recv->weight;
            self->recv = self->recv->next;
        }
        recv = self->recv;
        id   = ++recv->n_clients;
        break;
    }
    default:
        lfatal("invalid ipsplit mode");
    }

    client->recv = recv;
    memcpy(client->id, &id, sizeof(client->id));
}

/*
 * Optionally, write client ID into byte 0-3 of src/dst IP address in the packet.
 *
 * Client ID is a 4-byte array in host byte order.
 *
 * obj must be the IP/IPv6 object itself; objects of any other type are
 * left untouched.
 */
static void _overwrite(filter_ipsplit_t* self, core_object_t* obj, _client_t* client)
{
    mlassert_self();
    lassert(obj, "invalid object");
    lassert(client, "invalid client");

    core_object_ip_t*  ip;
    core_object_ip6_t* ip6;

    switch (self->overwrite) {
    case IPSPLIT_OVERWRITE_NONE:
        return;
    case IPSPLIT_OVERWRITE_SRC:
        if (obj->obj_type == CORE_OBJECT_IP) {
            ip = (core_object_ip_t*)obj;
            memcpy(&ip->src, client->id, sizeof(client->id));
        } else if (obj->obj_type == CORE_OBJECT_IP6) {
            ip6 = (core_object_ip6_t*)obj;
            memcpy(&ip6->src, client->id, sizeof(client->id));
        }
        break;
    case IPSPLIT_OVERWRITE_DST:
        if (obj->obj_type == CORE_OBJECT_IP) {
            ip = (core_object_ip_t*)obj;
            memcpy(&ip->dst, client->id, sizeof(client->id));
        } else if (obj->obj_type == CORE_OBJECT_IP6) {
            ip6 = (core_object_ip6_t*)obj;
            memcpy(&ip6->dst, client->id, sizeof(client->id));
        }
        break;
    default:
        lfatal("invalid overwrite mode");
    }
}

/*
 * Receiver callback: look up (or create) the client for the packet's source
 * address, optionally overwrite an address with the client ID and pass the
 * original object on to the receiver the client is assigned to.
 *
 * Objects without an IP/IPv6 object in their obj_prev chain are counted in
 * self->discarded and dropped.
 */
static void _receive(filter_ipsplit_t* self, const core_object_t* obj)
{
    mlassert_self();

    /* Find ip/ip6 object in chain. */
    /* The const is cast away because _overwrite() may modify the addresses in place. */
    core_object_t* pkt = (core_object_t*)obj;
    while (pkt != NULL) {
        if (pkt->obj_type == CORE_OBJECT_IP || pkt->obj_type == CORE_OBJECT_IP6)
            break;
        pkt = (core_object_t*)pkt->obj_prev;
    }
    if (pkt == NULL) {
        self->discarded++;
        lwarning("packet discarded (missing ip/ip6 object)");
        return;
    }

    /* Lookup IPv4/IPv6 address in trie (prefix-tree). Inserts new node if not found. */
    trie_val_t* node = NULL;
    switch (pkt->obj_type) {
    case CORE_OBJECT_IP: {
        core_object_ip_t* ip = (core_object_ip_t*)pkt;
        node                 = trie_get_ins(_self->trie, ip->src, sizeof(ip->src));
        break;
    }
    case CORE_OBJECT_IP6: {
        core_object_ip6_t* ip6 = (core_object_ip6_t*)pkt;
        node                   = trie_get_ins(_self->trie, ip6->src, sizeof(ip6->src));
        break;
    }
    default:
        lfatal("unsupported object type");
    }
    lassert(node, "trie failure");

    _client_t* client;
    if (*node == NULL) {
        /* IP address not found in tree -> create new client. */
        lfatal_oom(client = malloc(sizeof(*client)));
        *node = (void*)client;
        _assign_client_to_receiver(self, client);
    }

    client = (_client_t*)*node;
    _overwrite(self, pkt, client);
    client->recv->recv(client->recv->ctx, obj);
}

/*
 * Return the receiver callback of this filter. At least one receiver must
 * have been added with filter_ipsplit_add() beforehand.
 */
core_receiver_t filter_ipsplit_receiver(filter_ipsplit_t* self)
{
    mlassert_self();

    if (!self->recv) {
        lfatal("no receiver(s) set");
    }

    return (core_receiver_t)_receive;
}