diff options
Diffstat (limited to 'src/lib')
-rw-r--r-- | src/lib/base64url.c | 480 | ||||
-rw-r--r-- | src/lib/base64url.h | 28 | ||||
-rw-r--r-- | src/lib/base64url.hh | 99 | ||||
-rw-r--r-- | src/lib/base64url.lua | 98 | ||||
-rw-r--r-- | src/lib/clock.c | 75 | ||||
-rw-r--r-- | src/lib/clock.h | 28 | ||||
-rw-r--r-- | src/lib/clock.hh | 29 | ||||
-rw-r--r-- | src/lib/clock.lua | 47 | ||||
-rw-r--r-- | src/lib/getopt.lua | 365 | ||||
-rw-r--r-- | src/lib/ip.lua | 125 | ||||
-rw-r--r-- | src/lib/parseconf.lua | 181 | ||||
-rw-r--r-- | src/lib/trie.c | 923 | ||||
-rw-r--r-- | src/lib/trie.h | 39 | ||||
-rw-r--r-- | src/lib/trie.hh | 118 | ||||
-rw-r--r-- | src/lib/trie.lua | 172 | ||||
-rw-r--r-- | src/lib/trie/iter.lua | 93 | ||||
-rw-r--r-- | src/lib/trie/node.lua | 84 |
17 files changed, 2984 insertions, 0 deletions
diff --git a/src/lib/base64url.c b/src/lib/base64url.c new file mode 100644 index 0000000..aacd490 --- /dev/null +++ b/src/lib/base64url.c @@ -0,0 +1,480 @@ +/* Copyright (C) 2020 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +#include "lib/base64url.h" + +#include <errno.h> +#include <stdlib.h> +#include <stdint.h> + +/*! \brief Maximal length of binary input to Base64url encoding. */ +#define MAX_BIN_DATA_LEN ((INT32_MAX / 4) * 3) + +/*! \brief Base64url padding character. */ +static const uint8_t base64url_pad = '\0'; +/*! \brief Base64 alphabet. */ +static const uint8_t base64url_enc[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + +/*! \brief Indicates bad Base64 character. */ +#define KO 255 +/*! \brief Indicates Base64 padding character. */ +#define PD 64 + +/*! \brief Transformation and validation table for decoding Base64. */ +static const uint8_t base64url_dec[256] = { + [0] = PD, + [43] = KO, + ['V'] = 21, + [129] = KO, + [172] = KO, + [215] = KO, + [1] = KO, + [44] = KO, + ['W'] = 22, + [130] = KO, + [173] = KO, + [216] = KO, + [2] = KO, + ['-'] = 62, + ['X'] = 23, + [131] = KO, + [174] = KO, + [217] = KO, + [3] = KO, + [46] = KO, + ['Y'] = 24, + [132] = KO, + [175] = KO, + [218] = KO, + [4] = KO, + [47] = KO, + ['Z'] = 25, + [133] = KO, + [176] = KO, + [219] = KO, + [5] = KO, + ['0'] = 52, + [91] = KO, + [134] = KO, + [177] = KO, + [220] = KO, + [6] = KO, + ['1'] = 53, + [92] = KO, + [135] = KO, + [178] = KO, + [221] = KO, + [7] = KO, + ['2'] = 54, + [93] = KO, + [136] = KO, + [179] = KO, + [222] = KO, + [8] = KO, + ['3'] = 55, + [94] = KO, + [137] = KO, + [180] = KO, + [223] = KO, + [9] = KO, + ['4'] = 56, + ['_'] = 63, + [138] = KO, + [181] = KO, + [224] = KO, + [10] = KO, + ['5'] = 57, + [96] = KO, + [139] = KO, + [182] = KO, + [225] = KO, + [11] = KO, + ['6'] = 58, + ['a'] = 26, + [140] = KO, + [183] = KO, + [226] = KO, + [12] = KO, + ['7'] = 59, + ['b'] = 27, + [141] = KO, + [184] = KO, + [227] = KO, + [13] = KO, + ['8'] = 60, + ['c'] = 28, + [142] = KO, + [185] = KO, + [228] = KO, + [14] = KO, + ['9'] = 61, + ['d'] = 29, + [143] = KO, + [186] = KO, + [229] = KO, + [15] = KO, + [58] = KO, + ['e'] = 30, + [144] = KO, + [187] = KO, + [230] = KO, + [16] = KO, + [59] = KO, + ['f'] = 31, + [145] = KO, + [188] = KO, + [231] = KO, + [17] = KO, + [60] = KO, + ['g'] = 32, + [146] = KO, + [189] = KO, + [232] = KO, + [18] = KO, + [61] = KO, + ['h'] = 33, + [147] = KO, + [190] = KO, + [233] = KO, + [19] = KO, + [62] = KO, + ['i'] = 34, + [148] = KO, + [191] = KO, + [234] = KO, + [20] = KO, + [63] = KO, + ['j'] = 35, + [149] = KO, + [192] = KO, + [235] = KO, + [21] = KO, + [64] = KO, + ['k'] = 36, + [150] = KO, + [193] = KO, + [236] = KO, + [22] = KO, + ['A'] = 0, + ['l'] = 37, + [151] = KO, + [194] = KO, + [237] = KO, + [23] = KO, + ['B'] = 1, + ['m'] = 38, + [152] = KO, + [195] = KO, + [238] = KO, + [24] = KO, + ['C'] = 2, + ['n'] = 39, + [153] = KO, + [196] = KO, + [239] = KO, + [25] = KO, + ['D'] = 3, + ['o'] = 40, + [154] = KO, + [197] = KO, + [240] = KO, + [26] = KO, + ['E'] = 4, + ['p'] = 41, + [155] = KO, + [198] = KO, + [241] = KO, + [27] = KO, + ['F'] = 5, + ['q'] = 42, + [156] = KO, + [199] = KO, + [242] = KO, + [28] = KO, + ['G'] = 6, + ['r'] = 43, + [157] = KO, + [200] = KO, + [243] = KO, + [29] = KO, + ['H'] = 7, + ['s'] = 44, + [158] = KO, + [201] = KO, + [244] = KO, + [30] = KO, + ['I'] = 8, + ['t'] = 45, + [159] = KO, + [202] = KO, + [245] = KO, + [31] = KO, + ['J'] = 9, + ['u'] = 46, + [160] = KO, + [203] = KO, + [246] = KO, + [32] = KO, + ['K'] = 10, + ['v'] = 47, + [161] = KO, + [204] = KO, + [247] = KO, + [33] = KO, + ['L'] = 11, + ['w'] = 48, + [162] = KO, + [205] = KO, + [248] = KO, + [34] = KO, + ['M'] = 12, + ['x'] = 49, + [163] = KO, + [206] = KO, + [249] = KO, + [35] = KO, + ['N'] = 13, + ['y'] = 50, + [164] = KO, + [207] = KO, + [250] = KO, + [36] = KO, + ['O'] = 14, + ['z'] = 51, + [165] = KO, + [208] = KO, + [251] = KO, + [37] = KO, + ['P'] = 15, + [123] = KO, + [166] = KO, + [209] = KO, + [252] = KO, + [38] = KO, + ['Q'] = 16, + [124] = KO, + [167] = KO, + [210] = KO, + [253] = KO, + [39] = KO, + ['R'] = 17, + [125] = KO, + [168] = KO, + [211] = KO, + [254] = KO, + [40] = KO, + ['S'] = 18, + [126] = KO, + [169] = KO, + [212] = KO, + [255] = KO, + [41] = KO, + ['T'] = 19, + [127] = KO, + [170] = KO, + [213] = KO, + [42] = KO, + ['U'] = 20, + [128] = KO, + [171] = KO, + [214] = KO, +}; + +int32_t base64url_encode(const uint8_t* in, + const uint32_t in_len, + uint8_t* out, + const uint32_t out_len) +{ + // Checking inputs. + if (in == NULL || out == NULL) { + return -EINVAL; + } + if (in_len > MAX_BIN_DATA_LEN || out_len < ((in_len + 2) / 3) * 4) { + return -ERANGE; + } + + uint8_t rest_len = in_len % 3; + const uint8_t* stop = in + in_len - rest_len; + uint8_t* text = out; + + // Encoding loop takes 3 bytes and creates 4 characters. + while (in < stop) { + text[0] = base64url_enc[in[0] >> 2]; + text[1] = base64url_enc[(in[0] & 0x03) << 4 | in[1] >> 4]; + text[2] = base64url_enc[(in[1] & 0x0F) << 2 | in[2] >> 6]; + text[3] = base64url_enc[in[2] & 0x3F]; + text += 4; + in += 3; + } + + // Processing of padding, if any. + switch (rest_len) { + case 2: + text[0] = base64url_enc[in[0] >> 2]; + text[1] = base64url_enc[(in[0] & 0x03) << 4 | in[1] >> 4]; + text[2] = base64url_enc[(in[1] & 0x0F) << 2]; + text[3] = base64url_pad; + text += 3; + break; + case 1: + text[0] = base64url_enc[in[0] >> 2]; + text[1] = base64url_enc[(in[0] & 0x03) << 4]; + text[2] = base64url_pad; + text[3] = base64url_pad; + text += 2; + break; + } + return (text - out); +} + +int32_t base64url_encode_alloc(const uint8_t* in, + const uint32_t in_len, + uint8_t** out) +{ + // Checking inputs. + if (out == NULL) { + return -EINVAL; + } + if (in_len > MAX_BIN_DATA_LEN) { + return -ERANGE; + } + + // Compute output buffer length. + uint32_t out_len = ((in_len + 2) / 3) * 4; + + // Allocate output buffer. + *out = malloc(out_len); + if (*out == NULL) { + return -ENOMEM; + } + + // Encode data. + int32_t ret = base64url_encode(in, in_len, *out, out_len); + if (ret < 0) { + free(*out); + *out = NULL; + } + + return ret; +} + +int32_t base64url_decode(const uint8_t* in, + const uint32_t in_len, + uint8_t* out, + const uint32_t out_len) +{ + // Checking inputs. + if (in == NULL || out == NULL) { + return -EINVAL; + } + if (in_len > INT32_MAX || out_len < ((in_len + 3) / 4) * 3) { + return -ERANGE; + } + + const uint8_t* stop = in + in_len; + uint8_t* bin = out; + uint8_t pad_len = 0; + uint8_t c1, c2, c3, c4; + + // Decoding loop takes 4 characters and creates 3 bytes. + while (in < stop) { + // Filling and transforming 4 Base64 chars. + c1 = base64url_dec[in[0]]; + c2 = (in + 1 < stop) ? base64url_dec[in[1]] : PD; + c3 = (in + 2 < stop) ? base64url_dec[in[2]] : PD; + c4 = (in + 3 < stop) ? base64url_dec[in[3]] : PD; + + // Check 4. char if is bad or padding. + if (c4 >= PD) { + if (c4 == PD && pad_len == 0) { + pad_len = 1; + } else { + return -1; + } + } + + // Check 3. char if is bad or padding. + if (c3 >= PD) { + if (c3 == PD && pad_len == 1) { + pad_len = 2; + } else { + return -1; + } + } + + // Check 1. and 2. chars if are not padding. + if (c2 >= PD || c1 >= PD) { + return -1; + } + + // Computing of output data based on padding length. + switch (pad_len) { + case 0: + bin[2] = (c3 << 6) + c4; + // FALLTHROUGH + case 1: + bin[1] = (c2 << 4) + (c3 >> 2); + // FALLTHROUGH + case 2: + bin[0] = (c1 << 2) + (c2 >> 4); + } + + // Update output end. + switch (pad_len) { + case 0: + bin += 3; + break; + case 1: + bin += 2; + break; + case 2: + bin += 1; + break; + } + + in += 4; + } + + return (bin - out); +} + +int32_t base64url_decode_alloc(const uint8_t* in, + const uint32_t in_len, + uint8_t** out) +{ + // Checking inputs. + if (out == NULL) { + return -EINVAL; + } + + // Compute output buffer length. + uint32_t out_len = ((in_len + 3) / 4) * 3; + + // Allocate output buffer. + *out = malloc(out_len); + if (*out == NULL) { + return -ENOMEM; + } + + // Decode data. + int32_t ret = base64url_decode(in, in_len, *out, out_len); + if (ret < 0) { + free(*out); + *out = NULL; + } + + return ret; +} diff --git a/src/lib/base64url.h b/src/lib/base64url.h new file mode 100644 index 0000000..d355598 --- /dev/null +++ b/src/lib/base64url.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020, CZ.NIC, z.s.p.o. + * All rights reserved. + * + * This file is part of dnsjit. + * + * dnsjit is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * dnsjit is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + */ + +#include <stdint.h> + +#ifndef __dnsjit_lib_base64url_h +#define __dnsjit_lib_base64url_h + +#include "lib/base64url.hh" + +#endif diff --git a/src/lib/base64url.hh b/src/lib/base64url.hh new file mode 100644 index 0000000..5f19a10 --- /dev/null +++ b/src/lib/base64url.hh @@ -0,0 +1,99 @@ +/* Copyright (C) 2020 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +/*! + * \brief Base64url implementation (RFC 4648). + */ + +/*! + * \brief Encodes binary data using Base64. + * + * \note Output data buffer contains Base64 text string which isn't + * terminated with '\0'! + * + * \param in Input binary data. + * \param in_len Length of input data. + * \param out Output data buffer. + * \param out_len Size of output buffer. + * + * \retval >=0 length of output string. + * \retval KNOT_E* if error. + */ +int32_t base64url_encode(const uint8_t* in, + const uint32_t in_len, + uint8_t* out, + const uint32_t out_len); + +/*! + * \brief Encodes binary data using Base64 and output stores to own buffer. + * + * \note Output data buffer contains Base64 text string which isn't + * terminated with '\0'! + * + * \note Output buffer should be deallocated after use. + * + * \param in Input binary data. + * \param in_len Length of input data. + * \param out Output data buffer. + * + * \retval >=0 length of output string. + * \retval KNOT_E* if error. + */ +int32_t base64url_encode_alloc(const uint8_t* in, + const uint32_t in_len, + uint8_t** out); + +/*! + * \brief Decodes text data using Base64. + * + * \note Input data needn't be terminated with '\0'. + * + * \note Input data must be continuous Base64 string! + * + * \param in Input text data. + * \param in_len Length of input string. + * \param out Output data buffer. + * \param out_len Size of output buffer. + * + * \retval >=0 length of output data. + * \retval KNOT_E* if error. + */ +int32_t base64url_decode(const uint8_t* in, + const uint32_t in_len, + uint8_t* out, + const uint32_t out_len); + +/*! + * \brief Decodes text data using Base64 and output stores to own buffer. + * + * \note Input data needn't be terminated with '\0'. + * + * \note Input data must be continuous Base64 string! + * + * \note Output buffer should be deallocated after use. + * + * \param in Input text data. + * \param in_len Length of input string. + * \param out Output data buffer. + * + * \retval >=0 length of output data. + * \retval KNOT_E* if error. + */ +int32_t base64url_decode_alloc(const uint8_t* in, + const uint32_t in_len, + uint8_t** out); + +/*! @} */ diff --git a/src/lib/base64url.lua b/src/lib/base64url.lua new file mode 100644 index 0000000..e9966d5 --- /dev/null +++ b/src/lib/base64url.lua @@ -0,0 +1,98 @@ +-- Copyright (c) 2020, CZ.NIC, z.s.p.o. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.base64url +-- Utility library to convert data to base64url format +-- local base64url = require("dnsjit.lib.base64url") +-- .SS Encoding and decoding lua strings +-- local encoded = base64url.encode("abcd") +-- local decoded = base64url.decode(encoded) +-- .SS Encoding C byte arrays +-- local pl -- pl is core.object.payload +-- local encoded = base64url.encode(pl.payload, pl.len) +-- +-- Encode and decode data to/from base64url format. +module(...,package.seeall) + +require("dnsjit.lib.base64url_h") +local ffi = require("ffi") +local C = ffi.C +local log = require("dnsjit.core.log") +local module_log = log.new("lib.base64url") + +Base64Url = {} + +-- Encode lua string or C byte array to base64url representation. +-- The input string may contain non-printable characters. +-- +-- .B data_len +-- is length of the input data (optional for lua strings, required for +-- C byte arrays). +function Base64Url.encode(data, data_len) + data_len = tonumber(data_len) -- in case of cdata length + if type(data) == "cdata" then + if type(data_len) ~= "number" then + module_log:fatal("encode: data_len must be specified for cdata") + return + end + elseif type(data) ~= "string" then + module_log:fatal("encode: input must be string") + return + end + + if data_len ~= nil and data_len < 0 then + module_log:fatal("encode: data_len must be greater than 0") + return + end + + local in_len = data_len or string.len(data) + local buf_len = math.ceil(4 * in_len / 3) + 2 + local buf = ffi.new("uint8_t[?]", buf_len) + local out_len = ffi.C.base64url_encode(data, in_len, buf, buf_len) + if out_len < 0 then + module_log:critical("encode: error ("..log.errstr(-out_len)..")") + return + end + return ffi.string(buf, out_len) +end + +-- Decode a base64url encoded lua string. +-- The output string may contain non-printable characters. +function Base64Url.decode(data) + if type(data) ~= "string" then + module_log:fatal("decode: input must be string") + return + end + + local in_len = string.len(data) + local buf_len = math.ceil(3 * in_len / 4) + 1 + local buf = ffi.new("uint8_t[?]", buf_len) + local out_len = ffi.C.base64url_decode(data, in_len, buf, buf_len) + if out_len == -34 then -- ERANGE + module_log:critical("decode: error "..log.errstr(-out_len).." - invalid character(s) in input string?") + return + elseif out_len < 0 then + module_log:critical("decode: error "..log.errstr(-out_len)) + return + end + return ffi.string(buf, out_len) +end + +-- dnsjit.core.object.payload(3) +-- dnsjit.output.dnssim (3) +return Base64Url diff --git a/src/lib/clock.c b/src/lib/clock.c new file mode 100644 index 0000000..0676648 --- /dev/null +++ b/src/lib/clock.c @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2018-2021, OARC, Inc. + * All rights reserved. + * + * This file is part of dnsjit. + * + * dnsjit is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * dnsjit is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "config.h" + +#include "lib/clock.h" + +#include <time.h> + +core_timespec_t lib_clock_getres(lib_clock_clkid_t clkid) +{ + struct timespec ts; + core_timespec_t ret = { 0, 0 }; + clockid_t clk_id; + + switch (clkid) { + case LIB_CLOCK_REALTIME: + clk_id = CLOCK_REALTIME; + break; + case LIB_CLOCK_MONOTONIC: + clk_id = CLOCK_MONOTONIC; + break; + default: + return ret; + } + + if (!clock_getres(clk_id, &ts)) { + ret.sec = ts.tv_sec; + ret.nsec = ts.tv_nsec; + } + + return ret; +} + +core_timespec_t lib_clock_gettime(lib_clock_clkid_t clkid) +{ + struct timespec ts; + core_timespec_t ret = { 0, 0 }; + clockid_t clk_id; + + switch (clkid) { + case LIB_CLOCK_REALTIME: + clk_id = CLOCK_REALTIME; + break; + case LIB_CLOCK_MONOTONIC: + clk_id = CLOCK_MONOTONIC; + break; + default: + return ret; + } + + if (!clock_gettime(clk_id, &ts)) { + ret.sec = ts.tv_sec; + ret.nsec = ts.tv_nsec; + } + + return ret; +} diff --git a/src/lib/clock.h b/src/lib/clock.h new file mode 100644 index 0000000..0dc0faa --- /dev/null +++ b/src/lib/clock.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2018-2021, OARC, Inc. + * All rights reserved. + * + * This file is part of dnsjit. + * + * dnsjit is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * dnsjit is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + */ + +#include "core/timespec.h" + +#ifndef __dnsjit_lib_clock_h +#define __dnsjit_lib_clock_h + +#include "lib/clock.hh" + +#endif diff --git a/src/lib/clock.hh b/src/lib/clock.hh new file mode 100644 index 0000000..9463815 --- /dev/null +++ b/src/lib/clock.hh @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2018-2021, OARC, Inc. + * All rights reserved. + * + * This file is part of dnsjit. + * + * dnsjit is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * dnsjit is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + */ + +//lua:require("dnsjit.core.timespec_h") + +typedef enum lib_clock_clkid { + LIB_CLOCK_REALTIME, + LIB_CLOCK_MONOTONIC +} lib_clock_clkid_t; + +core_timespec_t lib_clock_getres(lib_clock_clkid_t clkid); +core_timespec_t lib_clock_gettime(lib_clock_clkid_t clkid); diff --git a/src/lib/clock.lua b/src/lib/clock.lua new file mode 100644 index 0000000..60ee9f6 --- /dev/null +++ b/src/lib/clock.lua @@ -0,0 +1,47 @@ +-- Copyright (c) 2018-2021, OARC, Inc. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.clock +-- Clock and time functions +-- local clock = require("dnsjit.lib.clock") +-- local sec, nsec = clock.monotonic() +-- +-- Functions to get the time from system-wide clocks. +module(...,package.seeall) + +require("dnsjit.lib.clock_h") +local C = require("ffi").C + +Clock = {} + +-- Return the current seconds and nanoseconds (as a list) from the realtime +-- clock. +function Clock.realtime() + local ts = C.lib_clock_gettime("LIB_CLOCK_REALTIME") + return tonumber(ts.sec), tonumber(ts.nsec) +end + +-- Return the current seconds and nanoseconds (as a list) from the monotonic +-- clock. +function Clock.monotonic() + local ts = C.lib_clock_gettime("LIB_CLOCK_MONOTONIC") + return tonumber(ts.sec), tonumber(ts.nsec) +end + +-- clock_gettime (2) +return Clock diff --git a/src/lib/getopt.lua b/src/lib/getopt.lua new file mode 100644 index 0000000..91ce6cd --- /dev/null +++ b/src/lib/getopt.lua @@ -0,0 +1,365 @@ +-- Copyright (c) 2018-2021, OARC, Inc. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.getopt +-- Parse and handle arguments +-- local getopt = require("dnsjit.lib.getopt").new({ +-- { "v", "verbose", 0, "Enable verbosity", "?+" }, +-- { nil, "host", "localhost", "Set host", "?" }, +-- { "p", nil, 53, "Set port", "?" }, +-- }) +-- . +-- local left = getopt:parse() +-- . +-- print("host", getopt:val("host")) +-- print("port", getopt:val("p")) +-- +-- A "getopt long" implementation to easily handle command line arguments +-- and display usage. +-- An option is the short name (one character), long name, +-- default value (which also defines the type), help text and extensions. +-- Options are by default required, see extensions to change this. +-- .LP +-- The Lua types allowed are +-- .BR boolean , +-- .BR string , +-- .BR number . +-- .LP +-- The extensions available are: +-- .TP +-- .B ? +-- Make the option optional. +-- .TP +-- .B * +-- For string and number options this make it possible to specified it +-- multiple times and all values will be returned in a table. +-- .TP +-- .B + +-- For number options this will act as an counter increaser, the value will +-- be the default value + 1 for each time the option is given. +-- .LP +-- Option +-- .I -h +-- and +-- .I --help +-- are automatically added if the option +-- .I --help +-- is not already defined. +-- .SS Attributes +-- .TP +-- left +-- A table that contains the arguments left after parsing, same as returned by +-- .IR parse() . +-- .TP +-- usage_desc +-- A string that describes the usage of the program, if not set then the +-- default will be " +-- .I "program [options...]" +-- ". +module(...,package.seeall) + +local log = require("dnsjit.core.log") + +local module_log = log.new("lib.getopt") +Getopt = {} + +-- Create a new Getopt object. +-- .I args +-- is a table with tables that specifies the options available. +-- Each entry is unpacked and sent to +-- .BR Getopt:add() . +function Getopt.new(args) + local self = setmetatable({ + left = {}, + usage_desc = nil, + _opt = {}, + _s2l = {}, + _log = log.new("lib.getopt", module_log), + }, { __index = Getopt }) + + self._log:debug("new()") + + for k, v in pairs(args) do + local short, long, default, help, extensions = unpack(v) + self:add(short, long, default, help, extensions) + end + + return self +end + +-- Return the Log object to control logging of this instance or module. +function Getopt:log() + if self == nil then + return module_log + end + return self._log +end + +-- Add an option. +function Getopt:add(short, long, default, help, extensions) + local optional = false + local multiple = false + local counter = false + local name = long or short + + if type(name) ~= "string" then + error("long|short) need to be a string") + elseif name == "" then + error("name (long|short) needs to be set") + end + + if self._opt[name] then + error("option "..name.." alredy exists") + elseif short and self._s2l[short] then + error("option "..short.." alredy exists") + end + + local t = type(default) + if t ~= "string" and t ~= "number" and t ~= "boolean" then + error("option "..name..": invalid type "..t) + end + + if type(extensions) == "string" then + local n + for n = 1, extensions:len() do + local extension = extensions:sub(n, n) + if extension == "?" then + optional = true + elseif extension == "*" then + multiple = true + elseif extension == "+" then + counter = true + else + error("option "..name..": invalid extension "..extension) + end + end + end + + self._opt[name] = { + value = nil, + short = short, + long = long, + type = t, + default = default, + help = help, + optional = optional, + multiple = multiple, + counter = counter, + } + if long and short then + self._s2l[short] = long + elseif short and not long then + self._s2l[short] = short + end + + if not self._opt["help"] then + self._opt["help"] = { + short = nil, + long = "help", + type = "boolean", + default = false, + help = "Display this help text", + optional = true, + } + if not self._s2l["h"] then + self._opt["help"].short = "h" + self._s2l["h"] = "help" + end + end +end + +-- Print the usage. +function Getopt:usage() + if self.usage_desc then + print("usage: " .. self.usage_desc) + else + print("usage: program [options...]") + end + + local opts = {} + for k, _ in pairs(self._opt) do + if k ~= "help" then + table.insert(opts, k) + end + end + table.sort(opts) + table.insert(opts, "help") + + for _, k in pairs(opts) do + local v = self._opt[k] + local arg + if v.type == "string" then + arg = " \""..v.default.."\"" + elseif v.type == "number" and v.counter == false then + arg = " "..v.default + else + arg = "" + end + if v.long then + print("", (v.short and "-"..v.short or " ").." --"..v.long..arg, v.help) + else + print("", "-"..v.short..arg, v.help) + end + end +end + +-- Parse the options. +-- If +-- .I args +-- is not specified or nil then the global +-- .B arg +-- is used. +-- If +-- .I startn +-- is given, it will start parsing arguments in the table from that position. +-- The default position to start at is 2 for +-- .IR dnsjit , +-- see +-- .BR dnsjit.core (3). +function Getopt:parse(args, startn) + if not args then + args = arg + end + + local n + local opt = nil + local left = {} + local need_arg = false + local stop = false + local name + for n = startn or 2, table.maxn(args) do + if need_arg then + if opt.multiple then + if opt.value == nil then + opt.value = {} + end + if opt.type == "number" then + table.insert(opt.value, tonumber(args[n])) + else + table.insert(opt.value, args[n]) + end + else + if opt.type == "number" then + opt.value = tonumber(args[n]) + else + opt.value = args[n] + end + end + need_arg = false + elseif stop or args[n] == "-" then + table.insert(left, args[n]) + elseif args[n] == "--" then + stop = true + elseif args[n]:sub(1, 1) == "-" then + if args[n]:sub(1, 2) == "--" then + name = args[n]:sub(3) + else + name = args[n]:sub(2) + if name:len() > 1 then + local n2, name2 + for n2 = 1, name:len() - 1 do + name2 = name:sub(n2, n2) + opt = self._opt[self._s2l[name2]] + if not opt then + error("unknown option "..name2) + end + if opt.type == "number" and opt.counter then + if opt.value == nil then + opt.value = opt.default + end + opt.value = opt.value + 1 + elseif opt.type == "boolean" then + if opt.value == nil then + opt.value = opt.default + end + if opt.value then + opt.value = false + else + opt.value = true + end + else + error("invalid short option '"..name2.."' in multioption statement") + end + end + name = name:sub(-1) + end + end + if self._s2l[name] then + name = self._s2l[name] + end + if not self._opt[name] then + error("unknown option "..name) + end + opt = self._opt[name] + if opt.type == "string" then + need_arg = true + elseif opt.type == "number" then + if opt.counter then + if opt.value == nil then + opt.value = opt.default + end + opt.value = opt.value + 1 + else + need_arg = true + end + elseif opt.type == "boolean" then + if opt.value == nil then + opt.value = opt.default + end + if opt.value then + opt.value = false + else + opt.value = true + end + else + error("internal error, invalid option type "..opt.type) + end + else + table.insert(left, args[n]) + end + end + + if need_arg then + error("option "..name.." needs argument") + end + + for k, v in pairs(self._opt) do + if v.optional == false and v.value == nil then + error("missing required option "..k.."") + end + end + + self.left = left + return left +end + +-- Return the value of an option. +function Getopt:val(name) + local opt = self._opt[name] or self._opt[self._s2l[name]] + if not opt then + return + end + if opt.value == nil then + return opt.default + else + return opt.value + end +end + +-- dnsjit.core (3) +return Getopt diff --git a/src/lib/ip.lua b/src/lib/ip.lua new file mode 100644 index 0000000..74dc85a --- /dev/null +++ b/src/lib/ip.lua @@ -0,0 +1,125 @@ +-- Copyright (c) 2018-2021, OARC, Inc. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.ip +-- IP address utility library +-- local ip = require("dnsjit.lib.ip") +-- print(ip.ipstring(ipv4_cdata)) +-- print(ip.ip6string(ipv6_cdata), true) +-- +-- A library to help with various IP address related tasks, such as +-- printing them. +module(...,package.seeall) + +local ffi = require("ffi") + +Ip = {} + +-- Return an IPv4 or IPv6 address as a string. +-- If it's an IPv6 address the optional argument +-- .I pretty +-- is true then return an easier to read IPv6 address. +-- Return an empty string on invalid input. +function Ip.tostring(ip, pretty) + if type(ip) == "cdata" then + if ffi.sizeof(ip) == 4 then + return Ip.ipstring(ip) + elseif ffi.sizeof(ip) == 16 then + return Ip.ip6string(ip, pretty) + end + end + return "" +end + +-- Return a IPv4 address as a string. +-- The input is a 4-byte cdata array. +function Ip.ipstring(ip) + return ip[0] ..".".. ip[1] ..".".. ip[2] ..".".. ip[3] +end + +local function _pretty(ip) + local src = {} + + local n, nn + nn = 1 + for n = 0, 15, 2 do + if ip[n] ~= 0 then + src[nn] = string.format("%x%02x", ip[n], ip[n + 1]) + elseif ip[n + 1] ~= 0 then + src[nn] = string.format("%x", ip[n + 1]) + else + src[nn] = "0" + end + nn = nn + 1 + end + + local best_n, best_at, at = 0, 0, 0 + n = 0 + for nn = 1, 8 do + if src[nn] == "0" then + if n == 0 then + at = nn + end + n = n + 1 + else + if n > 0 then + if n > best_n then + best_n = n + best_at = at + end + n = 0 + end + end + end + if n > 0 then + if n > best_n then + best_n = n + best_at = at + end + end + if best_n > 1 then + for n = 2, best_n do + table.remove(src, best_at) + end + if best_at == 1 or best_at + best_n > 8 then + src[best_at] = ":" + else + src[best_at] = "" + end + end + + return table.concat(src,":") +end + +-- Return the IPv6 address as a string. +-- The input is a 16-byte cdata array. +-- If +-- .I pretty +-- is true then return an easier to read IPv6 address. +function Ip.ip6string(ip6, pretty) + if pretty == true then + return _pretty(ip6) + end + return string.format("%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x:%02x%02x", + ip6[0], ip6[1], ip6[2], ip6[3], ip6[4], ip6[5], ip6[6], ip6[7], + ip6[8], ip6[9], ip6[10], ip6[11], ip6[12], ip6[13], ip6[14], ip6[15]) +end + +-- dnsjit.core.object.ip (3), +-- dnsjit.core.object.ip6 (3) +return Ip diff --git a/src/lib/parseconf.lua b/src/lib/parseconf.lua new file mode 100644 index 0000000..638763b --- /dev/null +++ b/src/lib/parseconf.lua @@ -0,0 +1,181 @@ +-- Copyright (c) 2018-2021, OARC, Inc. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.parseconf +-- Parse simple config files +-- local conf = require("dnsjit.lib.parseconf").new() +-- . +-- conf:func("config_name", function(k,...) +-- print(k,...) +-- end) +-- . +-- conf:file(file) +-- . +-- print(conf:val("another_config_name")) +-- +-- This module parses simple config files that are based on the config +-- syntax of DSC, drool and parseconf helper library. +-- Each config begins with a +-- .B name +-- followed by +-- .B options +-- and ends with a +-- .BR ; . +-- Multiple configs can be given on the same line. +-- Valid option types are +-- .IR number , +-- .IR float , +-- .IR string , +-- .IR "quoted string" . +-- Comments can be added by prefixing the comment with +-- .BR # . +-- .SS Example +-- # Comment +-- number 12345; +-- float 123.456; +-- string string string; +-- quoted_string "string string string"; +-- multi config; on one line; +module(...,package.seeall) + +local log = require("dnsjit.core.log") + +local module_log = log.new("lib.parseconf") +Parseconf = {} + +-- Create a new Parseconf object. +function Parseconf.new() + local self = setmetatable({ + conf = {}, + cf = {}, + _log = log.new("lib.parseconf", module_log), + }, { __index = Parseconf }) + + self._log:debug("new()") + + return self +end + +-- Return the Log object to control logging of this instance or module. +function Parseconf:log() + if self == nil then + return module_log + end + return self._log +end + +-- Set a function to call when config +-- .I name +-- is found. +function Parseconf:func(name, func) + self.cf[name] = func +end + +function Parseconf:part(l, n) + local p + p = l:match("^(%d+)[%s;]", n) + if p then + return p, tonumber(p) + end + p = l:match("^(%d+%.%d+)[%s;]", n) + if p then + return p, tonumber(p) + end + p = l:match("^(\"[^\"]+\")[%s;]", n) + if p then + return p, p:sub(2, -2) + end + p = l:match("^([^%s;]+)[%s;]", n) + if p then + return p, p + end +end + +function Parseconf:next(l, n) + local eol = l:match("^%s*;%s*", n) + if eol then + return true, eol + end + local ws = l:match("^%s+", n) + if ws then + return ws + end + return false +end + +-- Parse the given file. +function Parseconf:file(fn) + local ln, l + ln = 1 + for l in io.lines(fn) do + local c = l:find("#") + if c then + l = l:sub(1, c - 1) + end + local e, m = pcall(self.line, self, l) + if e == false then + error("parse error in "..fn.."["..ln.."]: "..m) + end + ln = ln + 1 + end +end + +-- Parse the given line. +function Parseconf:line(l) + local n + n = 1 + while n <= l:len() do + local c = nil + local va = {} + while true do + local p, v = self:part(l, n) + if not p then + error("invalid config at character "..n..": "..l:sub(n)) + end + if not c then + c = p + else + table.insert(va, v) + end + n = n + p:len() + local ws, eol = self:next(l, n) + if ws == true then + if eol then + n = n + eol:len() + end + break + elseif ws == false then + error("invalid config at character "..n..": "..l:sub(n)) + end + n = n + ws:len() + end + if self.cf[c] then + self.cf[c](c, unpack(va)) + else + self.conf[c] = va + end + end +end + +-- Get the value of a config +-- .IR name . +function Parseconf:val(name) + return self.conf[name] +end + +return Parseconf diff --git a/src/lib/trie.c b/src/lib/trie.c new file mode 100644 index 0000000..158c5ca --- /dev/null +++ b/src/lib/trie.c @@ -0,0 +1,923 @@ +/* + * Copyright (C) 2016-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz> + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * The code originated from https://github.com/fanf2/qp/blob/master/qp.c + * at revision 5f6d93753. + */ + +#include <assert.h> +#include <stdlib.h> +#include <string.h> + +#include "lib/trie.h" + +/*! \brief Error codes used in the library. */ +enum knot_error { + KNOT_EOK = 0, + + /* Directly mapped error codes. */ + KNOT_ENOMEM = -ENOMEM, + KNOT_EINVAL = -EINVAL, + KNOT_ENOENT = -ENOENT, +}; + +#if defined(__i386) || defined(__x86_64) || defined(_M_IX86) \ + || (defined(__BYTE_ORDER__) && defined(__ORDER_LITTLE_ENDIAN) \ + && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) + +/*! + * \brief Use a pointer alignment hack to save memory. + * + * When on, isbranch() relies on the fact that in leaf_t the first pointer + * is aligned on multiple of 4 bytes and that the flags bitfield is + * overlaid over the lowest two bits of that pointer. + * Neither is really guaranteed by the C standards; the second part should + * be OK with x86_64 ABI and most likely any other little-endian platform. + * It would be possible to manipulate the right bits portably, but it would + * complicate the code nontrivially. C++ doesn't even guarantee type-punning. + * In debug mode we check this works OK when creating a new trie instance. + */ +#define FLAGS_HACK 1 +#else +#define FLAGS_HACK 0 +#endif + +typedef unsigned char byte; +#ifndef uint +typedef unsigned int uint; +#define uint uint +#endif +typedef uint bitmap_t; /*! Bit-maps, using the range of 1<<0 to 1<<16 (inclusive). */ + +typedef struct { + uint32_t len; // 32 bits are enough for key lengths; probably even 16 bits would be. + uint8_t chars[]; +} tkey_t; + +/*! \brief Leaf of trie. */ +typedef struct { +#if !FLAGS_HACK + byte flags; +#endif + tkey_t* key; /*!< The pointer must be aligned to 4-byte multiples! */ + trie_val_t val; +} leaf_t; + +/*! \brief A trie node is either leaf_t or branch_t. */ +typedef union node node_t; + +/*! + * \brief Branch node of trie. + * + * - The flags distinguish whether the node is a leaf_t (0), or a branch + * testing the more-important nibble (1) or the less-important one (2). + * - It stores the index of the byte that the node tests. The combined + * value (index*4 + flags) increases in branch nodes as you go deeper + * into the trie. All the keys below a branch are identical up to the + * nibble identified by the branch. Indices have to be stored because + * we skip any branch nodes that would have a single child. + * (Consequently, the skipped parts of key have to be validated in a leaf.) + * - The bitmap indicates which subtries are present. The present child nodes + * are stored in the twigs array (with no holes between them). + * - To simplify storing keys that are prefixes of each other, the end-of-string + * position is treated as another nibble value, ordered before all others. + * That affects the bitmap and twigs fields. + * + * \note The branch nodes are never allocated individually, but they are + * always part of either the root node or the twigs array of the parent. + */ +typedef struct { +#if FLAGS_HACK + uint32_t flags : 2, + bitmap : 17; /*!< The first bitmap bit is for end-of-string child. */ +#else + byte flags; + uint32_t bitmap; +#endif + uint32_t index; + node_t* twigs; +} branch_t; + +union node { + leaf_t leaf; + branch_t branch; +}; + +struct trie { + node_t root; // undefined when weight == 0, see empty_root() + size_t weight; + knot_mm_t mm; +}; + +/* Included from other files */ + +/** Readability: avoid const-casts in code. */ +static inline void free_const(const void* what) +{ + free((void*)what); +} + +static inline void* mm_alloc(knot_mm_t* mm, size_t size) +{ + if (mm) + return mm->alloc(mm->ctx, size); + else + return malloc(size); +} + +static inline void mm_free(knot_mm_t* mm, const void* what) +{ + if (mm) { + if (mm->free) + mm->free((void*)what); + } else + free_const(what); +} + +static void* mm_malloc(void* ctx, size_t n) +{ + (void)ctx; + return malloc(n); +} + +static void* mm_realloc(knot_mm_t* mm, void* what, size_t size, size_t prev_size) +{ + if (mm) { + void* p = mm->alloc(mm->ctx, size); + if (p == NULL) { + return NULL; + } else { + if (what) { + memcpy(p, what, + prev_size < size ? prev_size : size); + } + mm_free(mm, what); + return p; + } + } else { + return realloc(what, size); + } +} + +static inline void mm_ctx_init(knot_mm_t* mm) +{ + mm->ctx = NULL; + mm->alloc = mm_malloc; + mm->free = free; +} + +/*! \brief Make the root node empty (debug-only). */ +static inline void empty_root(node_t* root) +{ +#ifndef NDEBUG + *root = (node_t) { .branch = { + .flags = 3, // invalid value that fits + .bitmap = 0, + .index = -1, + .twigs = NULL } }; +#endif +} + +/*! \brief Check that unportable code works OK (debug-only). */ +static void assert_portability(void) +{ +#if FLAGS_HACK + assert(((union node) { .leaf = { + .key = (tkey_t*)(((uint8_t*)NULL) + 1), + .val = NULL } }) + .branch.flags + == 1); +#endif +} + +/*! \brief Propagate error codes. */ +#define ERR_RETURN(x) \ + do { \ + int err_code_ = x; \ + if (unlikely(err_code_ != KNOT_EOK)) \ + return err_code_; \ + } while (false) + +/*! + * \brief Count the number of set bits. + * + * \TODO This implementation may be relatively slow on some HW. + */ +static uint bitmap_weight(bitmap_t w) +{ + assert((w & ~((1 << 17) - 1)) == 0); // using the least-important 17 bits + return __builtin_popcount(w); +} + +/*! \brief Only keep the lowest bit in the bitmap (least significant -> twigs[0]). */ +static bitmap_t bitmap_lowest_bit(bitmap_t w) +{ + assert((w & ~((1 << 17) - 1)) == 0); // using the least-important 17 bits + return 1 << __builtin_ctz(w); +} + +/*! \brief Test flags to determine type of this node. */ +static bool isbranch(const node_t* t) +{ + uint f = t->branch.flags; + assert(f <= 2); + return f != 0; +} + +/*! \brief Make a bitmask for testing a branch bitmap. */ +static bitmap_t nibbit(byte k, uint flags) +{ + uint shift = (2 - flags) << 2; + uint nibble = (k >> shift) & 0xf; + return 1 << (nibble + 1 /*because of prefix keys*/); +} + +/*! \brief Extract a nibble from a key and turn it into a bitmask. */ +static bitmap_t twigbit(const node_t* t, const uint8_t* key, uint32_t len) +{ + assert(isbranch(t)); + uint i = t->branch.index; + + if (i >= len) + return 1 << 0; // leaf position + + return nibbit((byte)key[i], t->branch.flags); +} + +/*! \brief Test if a branch node has a child indicated by a bitmask. */ +static bool hastwig(const node_t* t, bitmap_t bit) +{ + assert(isbranch(t)); + return t->branch.bitmap & bit; +} + +/*! \brief Compute offset of an existing child in a branch node. */ +static uint twigoff(const node_t* t, bitmap_t b) +{ + assert(isbranch(t)); + return bitmap_weight(t->branch.bitmap & (b - 1)); +} + +/*! \brief Get pointer to a particular child of a branch node. */ +static node_t* twig(node_t* t, uint i) +{ + assert(isbranch(t)); + return &t->branch.twigs[i]; +} + +/*! + * \brief For a branch nod, compute offset of a child and child count. + * + * Having this separate might be meaningful for performance optimization. + */ +#define TWIGOFFMAX(off, max, t, b) \ + do { \ + (off) = twigoff((t), (b)); \ + (max) = bitmap_weight((t)->branch.bitmap); \ + } while (0) + +/*! \brief Simple string comparator. */ +static int key_cmp(const uint8_t* k1, uint32_t k1_len, const uint8_t* k2, uint32_t k2_len) +{ + int ret = memcmp(k1, k2, MIN(k1_len, k2_len)); + if (ret != 0) { + return ret; + } + + /* Key string is equal, compare lengths. */ + if (k1_len == k2_len) { + return 0; + } else if (k1_len < k2_len) { + return -1; + } else { + return 1; + } +} + +trie_t* trie_create(knot_mm_t* mm) +{ + assert_portability(); + trie_t* trie = mm_alloc(mm, sizeof(trie_t)); + if (trie != NULL) { + empty_root(&trie->root); + trie->weight = 0; + if (mm != NULL) + trie->mm = *mm; + else + mm_ctx_init(&trie->mm); + } + return trie; +} + +/*! \brief Free anything under the trie node, except for the passed pointer itself. */ +static void clear_trie(node_t* trie, knot_mm_t* mm) +{ + if (!isbranch(trie)) { + mm_free(mm, trie->leaf.key); + } else { + branch_t* b = &trie->branch; + int len = bitmap_weight(b->bitmap); + int i; + for (i = 0; i < len; ++i) + clear_trie(b->twigs + i, mm); + mm_free(mm, b->twigs); + } +} + +void trie_free(trie_t* tbl) +{ + if (tbl == NULL) + return; + if (tbl->weight) + clear_trie(&tbl->root, &tbl->mm); + mm_free(&tbl->mm, tbl); +} + +void trie_clear(trie_t* tbl) +{ + assert(tbl); + if (!tbl->weight) + return; + clear_trie(&tbl->root, &tbl->mm); + empty_root(&tbl->root); + tbl->weight = 0; +} + +size_t trie_weight(const trie_t* tbl) +{ + assert(tbl); + return tbl->weight; +} + +struct found { + leaf_t* l; /**< the found leaf (NULL if not found) */ + branch_t* p; /**< the leaf's parent (if exists) */ + bitmap_t b; /**< bit-mask with a single bit marking l under p */ +}; +/** Search trie for an item with the given key (equality only). */ +static struct found find_equal(trie_t* tbl, const uint8_t* key, uint32_t len) +{ + assert(tbl); + struct found ret0; + memset(&ret0, 0, sizeof(ret0)); + if (!tbl->weight) + return ret0; + /* Current node and parent while descending (returned values basically). */ + node_t* t = &tbl->root; + branch_t* p = NULL; + bitmap_t b = 0; + while (isbranch(t)) { + __builtin_prefetch(t->branch.twigs); + b = twigbit(t, key, len); + if (!hastwig(t, b)) + return ret0; + p = &t->branch; + t = twig(t, twigoff(t, b)); + } + if (key_cmp(key, len, t->leaf.key->chars, t->leaf.key->len) != 0) + return ret0; + return (struct found) { + .l = &t->leaf, + .p = p, + .b = b, + }; +} +/** Find item with the first key (lexicographical order). */ +static struct found find_first(trie_t* tbl) +{ + assert(tbl); + if (!tbl->weight) { + struct found ret0; + memset(&ret0, 0, sizeof(ret0)); + return ret0; + } + /* Current node and parent while descending (returned values basically). */ + node_t* t = &tbl->root; + branch_t* p = NULL; + while (isbranch(t)) { + p = &t->branch; + t = &p->twigs[0]; + } + return (struct found) { + .l = &t->leaf, + .p = p, + .b = p ? bitmap_lowest_bit(p->bitmap) : 0, + }; +} + +trie_val_t* trie_get_try(trie_t* tbl, const uint8_t* key, uint32_t len) +{ + struct found found = find_equal(tbl, key, len); + return found.l ? &found.l->val : NULL; +} + +trie_val_t* trie_get_first(trie_t* tbl, uint8_t** key, uint32_t* len) +{ + struct found found = find_first(tbl); + if (!found.l) + return NULL; + if (key) + *key = found.l->key->chars; + if (len) + *len = found.l->key->len; + return &found.l->val; +} + +/*! + * \brief Stack of nodes, storing a path down a trie. + * + * The structure also serves directly as the public trie_it_t type, + * in which case it always points to the current leaf, unless we've finished + * (i.e. it->len == 0). + */ +typedef struct trie_it { + node_t** stack; /*!< The stack; malloc is used directly instead of mm. */ + uint32_t len; /*!< Current length of the stack. */ + uint32_t alen; /*!< Allocated/available length of the stack. */ + /*! \brief Initial storage for \a stack; it should fit in many use cases. */ + node_t* stack_init[60]; +} nstack_t; + +/*! \brief Create a node stack containing just the root (or empty). */ +static void ns_init(nstack_t* ns, trie_t* tbl) +{ + assert(tbl); + ns->stack = ns->stack_init; + ns->alen = sizeof(ns->stack_init) / sizeof(ns->stack_init[0]); + if (tbl->weight) { + ns->len = 1; + ns->stack[0] = &tbl->root; + } else { + ns->len = 0; + } +} + +/*! \brief Free inside of the stack, i.e. not the passed pointer itself. */ +static void ns_cleanup(nstack_t* ns) +{ + assert(ns && ns->stack); + if (likely(ns->stack == ns->stack_init)) + return; + free(ns->stack); +#ifndef NDEBUG + ns->stack = NULL; + ns->alen = 0; +#endif +} + +/*! \brief Allocate more space for the stack. */ +static int ns_longer_alloc(nstack_t* ns) +{ + ns->alen *= 2; + size_t new_size = sizeof(nstack_t) + ns->alen * sizeof(node_t*); + node_t** st; + if (ns->stack == ns->stack_init) { + st = malloc(new_size); + if (st != NULL) + memcpy(st, ns->stack, ns->len * sizeof(node_t*)); + } else { + st = realloc(ns->stack, new_size); + } + if (st == NULL) + return KNOT_ENOMEM; + ns->stack = st; + return KNOT_EOK; +} + +/*! \brief Ensure the node stack can be extended by one. */ +static inline int ns_longer(nstack_t* ns) +{ + // get a longer stack if needed + if (likely(ns->len < ns->alen)) + return KNOT_EOK; + return ns_longer_alloc(ns); // hand-split the part suitable for inlining +} + +/*! + * \brief Find the "branching point" as if searching for a key. + * + * The whole path to the point is kept on the passed stack; + * always at least the root will remain on the top of it. + * Beware: the precise semantics of this function is rather tricky. + * The top of the stack will contain: the corresponding leaf if exact match is found; + * or the immediate node below a branching-point-on-edge or the branching-point itself. + * + * \param info Set position of the point of first mismatch (in index and flags). + * \param first Set the value of the first non-matching character (from trie), + * optionally; end-of-string character has value -256 (that's why it's int). + * Note: the character is converted to *unsigned* char (i.e. 0..255), + * as that's the ordering used in the trie. + * + * \return KNOT_EOK or KNOT_ENOMEM. + */ +static int ns_find_branch(nstack_t* ns, const uint8_t* key, uint32_t len, + branch_t* info, int* first) +{ + assert(ns && ns->len && info); + // First find some leaf with longest matching prefix. + while (isbranch(ns->stack[ns->len - 1])) { + ERR_RETURN(ns_longer(ns)); + node_t* t = ns->stack[ns->len - 1]; + __builtin_prefetch(t->branch.twigs); + bitmap_t b = twigbit(t, key, len); + // Even if our key is missing from this branch we need to + // keep iterating down to a leaf. It doesn't matter which + // twig we choose since the keys are all the same up to this + // index. Note that blindly using twigoff(t, b) can cause + // an out-of-bounds index if it equals twigmax(t). + uint i = hastwig(t, b) ? twigoff(t, b) : 0; + ns->stack[ns->len++] = twig(t, i); + } + tkey_t* lkey = ns->stack[ns->len - 1]->leaf.key; + // Find index of the first char that differs. + uint32_t index = 0; + while (index < MIN(len, lkey->len)) { + if (key[index] != lkey->chars[index]) + break; + else + ++index; + } + info->index = index; + if (first) + *first = lkey->len > index ? (unsigned char)lkey->chars[index] : -256; + // Find flags: which half-byte has matched. + uint flags; + if (index == len && len == lkey->len) { // found equivalent key + info->flags = flags = 0; + goto success; + } + if (likely(index < MIN(len, lkey->len))) { + byte k2 = (byte)lkey->chars[index]; + byte k1 = (byte)key[index]; + flags = ((k1 ^ k2) & 0xf0) ? 1 : 2; + } else { // one is prefix of another + flags = 1; + } + info->flags = flags; + // now go up the trie from the current leaf + branch_t* t; + do { + if (unlikely(ns->len == 1)) + goto success; // only the root stays on the stack + t = (branch_t*)ns->stack[ns->len - 2]; + if (t->index < index || (t->index == index && t->flags < flags)) + goto success; + --ns->len; + } while (true); +success: +#ifndef NDEBUG // invariants on successful return + assert(ns->len); + if (isbranch(ns->stack[ns->len - 1])) { + t = &ns->stack[ns->len - 1]->branch; + assert(t->index > index || (t->index == index && t->flags >= flags)); + } + if (ns->len > 1) { + t = &ns->stack[ns->len - 2]->branch; + assert(t->index < index || (t->index == index && (t->flags < flags || (t->flags == 1 && flags == 0)))); + } +#endif + return KNOT_EOK; +} + +/*! + * \brief Advance the node stack to the last leaf in the subtree. + * + * \return KNOT_EOK or KNOT_ENOMEM. + */ +static int ns_last_leaf(nstack_t* ns) +{ + assert(ns); + do { + ERR_RETURN(ns_longer(ns)); + node_t* t = ns->stack[ns->len - 1]; + if (!isbranch(t)) + return KNOT_EOK; + int lasti = bitmap_weight(t->branch.bitmap) - 1; + assert(lasti >= 0); + ns->stack[ns->len++] = twig(t, lasti); + } while (true); +} + +/*! + * \brief Advance the node stack to the first leaf in the subtree. + * + * \return KNOT_EOK or KNOT_ENOMEM. + */ +static int ns_first_leaf(nstack_t* ns) +{ + assert(ns && ns->len); + do { + ERR_RETURN(ns_longer(ns)); + node_t* t = ns->stack[ns->len - 1]; + if (!isbranch(t)) + return KNOT_EOK; + ns->stack[ns->len++] = twig(t, 0); + } while (true); +} + +/*! + * \brief Advance the node stack to the leaf that is previous to the current node. + * + * \note Prefix leaf under the current node DOES count (if present; perhaps questionable). + * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM. + */ +static int ns_prev_leaf(nstack_t* ns) +{ + assert(ns && ns->len > 0); + + node_t* t = ns->stack[ns->len - 1]; + if (hastwig(t, 1 << 0)) { // the prefix leaf + t = twig(t, 0); + ERR_RETURN(ns_longer(ns)); + ns->stack[ns->len++] = t; + return KNOT_EOK; + } + + do { + if (ns->len < 2) + return KNOT_ENOENT; // root without empty key has no previous leaf + t = ns->stack[ns->len - 1]; + node_t* p = ns->stack[ns->len - 2]; + int pindex = t - p->branch.twigs; // index in parent via pointer arithmetic + assert(pindex >= 0 && pindex <= 16); + if (pindex > 0) { // t isn't the first child -> go down the previous one + ns->stack[ns->len - 1] = twig(p, pindex - 1); + return ns_last_leaf(ns); + } + // we've got to go up again + --ns->len; + } while (true); +} + +/*! + * \brief Advance the node stack to the leaf that is successor to the current node. + * + * \note Prefix leaf or anything else under the current node DOES count. + * \return KNOT_EOK on success, KNOT_ENOENT on not-found, or possibly KNOT_ENOMEM. + */ +static int ns_next_leaf(nstack_t* ns) +{ + assert(ns && ns->len > 0); + + node_t* t = ns->stack[ns->len - 1]; + if (isbranch(t)) + return ns_first_leaf(ns); + do { + if (ns->len < 2) + return KNOT_ENOENT; // not found, as no more parent is available + t = ns->stack[ns->len - 1]; + node_t* p = ns->stack[ns->len - 2]; + int pindex = t - p->branch.twigs; // index in parent via pointer arithmetic + assert(pindex >= 0 && pindex <= 16); + int pcount = bitmap_weight(p->branch.bitmap); + if (pindex + 1 < pcount) { // t isn't the last child -> go down the next one + ns->stack[ns->len - 1] = twig(p, pindex + 1); + return ns_first_leaf(ns); + } + // we've got to go up again + --ns->len; + } while (true); +} + +int trie_get_leq(trie_t* tbl, const uint8_t* key, uint32_t len, trie_val_t** val) +{ + assert(tbl && val); + *val = NULL; // so on failure we can just return; + if (tbl->weight == 0) + return KNOT_ENOENT; + { // Intentionally un-indented; until end of function, to bound cleanup attr. + // First find a key with longest-matching prefix + __attribute__((cleanup(ns_cleanup))) + nstack_t ns_local; + ns_init(&ns_local, tbl); + nstack_t* ns = &ns_local; + branch_t bp; + int un_leaf; // first unmatched character in the leaf + ERR_RETURN(ns_find_branch(ns, key, len, &bp, &un_leaf)); + int un_key = bp.index < len ? (unsigned char)key[bp.index] : -256; + node_t* t = ns->stack[ns->len - 1]; + if (bp.flags == 0) { // found exact match + *val = &t->leaf.val; + return KNOT_EOK; + } + // Get t: the last node on matching path + if (isbranch(t) && t->branch.index == bp.index && t->branch.flags == bp.flags) { + // t is OK + } else { + // the top of the stack was the first unmatched node -> step up + if (ns->len == 1) { + // root was unmatched already + if (un_key < un_leaf) + return KNOT_ENOENT; + ERR_RETURN(ns_last_leaf(ns)); + goto success; + } + --ns->len; + t = ns->stack[ns->len - 1]; + } + // Now we re-do the first "non-matching" step in the trie + // but try the previous child if key was less (it may not exist) + bitmap_t b = twigbit(t, key, len); + int i = hastwig(t, b) + ? twigoff(t, b) - (un_key < un_leaf) + : twigoff(t, b) - 1 /*twigoff returns successor when !hastwig*/; + if (i >= 0) { + ERR_RETURN(ns_longer(ns)); + ns->stack[ns->len++] = twig(t, i); + ERR_RETURN(ns_last_leaf(ns)); + } else { + ERR_RETURN(ns_prev_leaf(ns)); + } + success: + assert(!isbranch(ns->stack[ns->len - 1])); + *val = &ns->stack[ns->len - 1]->leaf.val; + return 1; + } +} + +/*! \brief Initialize a new leaf, copying the key, and returning failure code. */ +static int mk_leaf(node_t* leaf, const uint8_t* key, uint32_t len, knot_mm_t* mm) +{ + tkey_t* k = mm_alloc(mm, sizeof(tkey_t) + len); +#if FLAGS_HACK + assert(((uintptr_t)k) % 4 == 0); // we need an aligned pointer +#endif + if (unlikely(!k)) + return KNOT_ENOMEM; + k->len = len; + memcpy(k->chars, key, len); + leaf->leaf = (leaf_t) + { +#if !FLAGS_HACK + .flags = 0, +#endif + .val = NULL, + .key = k + }; + return KNOT_EOK; +} + +trie_val_t* trie_get_ins(trie_t* tbl, const uint8_t* key, uint32_t len) +{ + assert(tbl); + // First leaf in an empty tbl? + if (unlikely(!tbl->weight)) { + if (unlikely(mk_leaf(&tbl->root, key, len, &tbl->mm))) + return NULL; + ++tbl->weight; + return &tbl->root.leaf.val; + } + { // Intentionally un-indented; until end of function, to bound cleanup attr. + // Find the branching-point + __attribute__((cleanup(ns_cleanup))) + nstack_t ns_local; + ns_init(&ns_local, tbl); + nstack_t* ns = &ns_local; + branch_t bp; // branch-point: index and flags signifying the longest common prefix + int k2; // the first unmatched character in the leaf + if (unlikely(ns_find_branch(ns, key, len, &bp, &k2))) + return NULL; + node_t* t = ns->stack[ns->len - 1]; + if (bp.flags == 0) // the same key was already present + return &t->leaf.val; + node_t leaf; + if (unlikely(mk_leaf(&leaf, key, len, &tbl->mm))) + return NULL; + + if (isbranch(t) && bp.index == t->branch.index && bp.flags == t->branch.flags) { + // The node t needs a new leaf child. + bitmap_t b1 = twigbit(t, key, len); + assert(!hastwig(t, b1)); + uint s, m; + TWIGOFFMAX(s, m, t, b1); // new child position and original child count + node_t* twigs = mm_realloc(&tbl->mm, t->branch.twigs, + sizeof(node_t) * (m + 1), sizeof(node_t) * m); + if (unlikely(!twigs)) + goto err_leaf; + memmove(twigs + s + 1, twigs + s, sizeof(node_t) * (m - s)); + twigs[s] = leaf; + t->branch.twigs = twigs; + t->branch.bitmap |= b1; + ++tbl->weight; + return &twigs[s].leaf.val; + } else { +// We need to insert a new binary branch with leaf at *t. +// Note: it works the same for the case where we insert above root t. +#ifndef NDEBUG + if (ns->len > 1) { + node_t* pt = ns->stack[ns->len - 2]; + assert(hastwig(pt, twigbit(pt, key, len))); + } +#endif + node_t* twigs = mm_alloc(&tbl->mm, sizeof(node_t) * 2); + if (unlikely(!twigs)) + goto err_leaf; + node_t t2 = *t; // Save before overwriting t. + t->branch.flags = bp.flags; + t->branch.index = bp.index; + t->branch.twigs = twigs; + bitmap_t b1 = twigbit(t, key, len); + bitmap_t b2 = unlikely(k2 == -256) ? (1 << 0) : nibbit(k2, bp.flags); + t->branch.bitmap = b1 | b2; + *twig(t, twigoff(t, b1)) = leaf; + *twig(t, twigoff(t, b2)) = t2; + ++tbl->weight; + return &twig(t, twigoff(t, b1))->leaf.val; + }; + err_leaf: + mm_free(&tbl->mm, leaf.leaf.key); + return NULL; + } +} + +/*! \brief Apply a function to every trie_val_t*, in order; a recursive solution. */ +static int apply_trie(node_t* t, int (*f)(trie_val_t*, void*), void* d) +{ + assert(t); + if (!isbranch(t)) + return f(&t->leaf.val, d); + int child_count = bitmap_weight(t->branch.bitmap); + int i; + for (i = 0; i < child_count; ++i) + ERR_RETURN(apply_trie(twig(t, i), f, d)); + return KNOT_EOK; +} + +int trie_apply(trie_t* tbl, int (*f)(trie_val_t*, void*), void* d) +{ + assert(tbl && f); + if (!tbl->weight) + return KNOT_EOK; + return apply_trie(&tbl->root, f, d); +} + +/* These are all thin wrappers around static Tns* functions. */ +trie_it_t* trie_it_begin(trie_t* tbl) +{ + assert(tbl); + trie_it_t* it = malloc(sizeof(nstack_t)); + if (!it) + return NULL; + ns_init(it, tbl); + if (it->len == 0) // empty tbl + return it; + if (ns_first_leaf(it)) { + ns_cleanup(it); + free(it); + return NULL; + } + return it; +} + +void trie_it_next(trie_it_t* it) +{ + assert(it && it->len); + if (ns_next_leaf(it) != KNOT_EOK) + it->len = 0; +} + +bool trie_it_finished(trie_it_t* it) +{ + assert(it); + return it->len == 0; +} + +void trie_it_free(trie_it_t* it) +{ + if (!it) + return; + ns_cleanup(it); + free(it); +} + +const uint8_t* trie_it_key(trie_it_t* it, size_t* len) +{ + assert(it && it->len); + node_t* t = it->stack[it->len - 1]; + assert(!isbranch(t)); + tkey_t* key = t->leaf.key; + if (len) + *len = key->len; + return key->chars; +} + +trie_val_t* trie_it_val(trie_it_t* it) +{ + assert(it && it->len); + node_t* t = it->stack[it->len - 1]; + assert(!isbranch(t)); + return &t->leaf.val; +} diff --git a/src/lib/trie.h b/src/lib/trie.h new file mode 100644 index 0000000..3ce881c --- /dev/null +++ b/src/lib/trie.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2017-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz> + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#include <errno.h> +#include <stdbool.h> +#include <stdint.h> + +#ifndef __dnsjit_contrib_trie_h +#define __dnsjit_contrib_trie_h + +#include "lib/trie.hh" + +#ifndef likely +/*! \brief Optimize for x to be true value. */ +#define likely(x) __builtin_expect((x), 1) +#endif + +#ifndef unlikely +/*! \brief Optimize for x to be false value. */ +#define unlikely(x) __builtin_expect((x), 0) +#endif + +#define MIN(a, b) (((a) < (b)) ? (a) : (b)) /** Minimum of two numbers **/ + +#endif diff --git a/src/lib/trie.hh b/src/lib/trie.hh new file mode 100644 index 0000000..60c8f8a --- /dev/null +++ b/src/lib/trie.hh @@ -0,0 +1,118 @@ +/* + * Copyright (C) 2017-2019 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz> + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +/* Memory allocation function prototypes. */ +typedef void* (*knot_mm_alloc_t)(void* ctx, size_t len); +typedef void (*knot_mm_free_t)(void* p); + +/*! \brief Memory allocation context. */ +typedef struct knot_mm { + void* ctx; /* \note Must be first */ + knot_mm_alloc_t alloc; + knot_mm_free_t free; +} knot_mm_t; + +/*! + * \brief Native API of QP-tries: + * + * - keys are uint8_t strings, not necessarily zero-terminated, + * the structure copies the contents of the passed keys + * - values are void* pointers, typically you get an ephemeral pointer to it + * - key lengths are limited by 2^32-1 ATM + */ + +/*! \brief Element value. */ +typedef void* trie_val_t; + +/*! \brief Opaque structure holding a QP-trie. */ +typedef struct trie trie_t; + +/*! \brief Opaque type for holding a QP-trie iterator. */ +typedef struct trie_it trie_it_t; + +/*! \brief Create a trie instance. Pass NULL to use malloc+free. */ +trie_t* trie_create(knot_mm_t* mm); + +/*! \brief Free a trie instance. */ +void trie_free(trie_t* tbl); + +/*! \brief Clear a trie instance (make it empty). */ +void trie_clear(trie_t* tbl); + +/*! \brief Return the number of keys in the trie. */ +size_t trie_weight(const trie_t* tbl); + +/*! \brief Search the trie, returning NULL on failure. */ +trie_val_t* trie_get_try(trie_t* tbl, const uint8_t* key, uint32_t len); + +/*! + * \brief Return pointer to the minimum. Optionally with key and its length. */ +trie_val_t* trie_get_first(trie_t* tbl, uint8_t** key, uint32_t* len); + +/*! \brief Search the trie, inserting NULL trie_val_t on failure. */ +trie_val_t* trie_get_ins(trie_t* tbl, const uint8_t* key, uint32_t len); + +/*! + * \brief Search for less-or-equal element. + * + * \param tbl Trie. + * \param key Searched key. + * \param len Key length. + * \param val Must be valid; it will be set to NULL if not found or errored. + * \return KNOT_EOK for exact match, 1 for previous, KNOT_ENOENT for not-found, + * or KNOT_E*. + */ +int trie_get_leq(trie_t* tbl, const uint8_t* key, uint32_t len, trie_val_t** val); + +/*! + * \brief Apply a function to every trie_val_t, in order. + * + * \param d Parameter passed as the second argument to f(). + * \return First nonzero from f() or zero (i.e. KNOT_EOK). + */ +int trie_apply(trie_t* tbl, int (*f)(trie_val_t*, void*), void* d); + +/*! \brief Create a new iterator pointing to the first element (if any). */ +trie_it_t* trie_it_begin(trie_t* tbl); + +/*! + * \brief Advance the iterator to the next element. + * + * Iteration is in ascending lexicographical order. + * In particular, the empty string would be considered as the very first. + * + * \note You may not use this function if the trie's key-set has been modified + * during the lifetime of the iterator (modifying values only is OK). + */ +void trie_it_next(trie_it_t* it); + +/*! \brief Test if the iterator has gone past the last element. */ +bool trie_it_finished(trie_it_t* it); + +/*! \brief Free any resources of the iterator. It's OK to call it on NULL. */ +void trie_it_free(trie_it_t* it); + +/*! + * \brief Return pointer to the key of the current element. + * + * \note The optional len is uint32_t internally but size_t is better for our usage, + * as it is without an additional type conversion. + */ +const uint8_t* trie_it_key(trie_it_t* it, size_t* len); + +/*! \brief Return pointer to the value of the current element (writable). */ +trie_val_t* trie_it_val(trie_it_t* it); diff --git a/src/lib/trie.lua b/src/lib/trie.lua new file mode 100644 index 0000000..f7fee03 --- /dev/null +++ b/src/lib/trie.lua @@ -0,0 +1,172 @@ +-- Copyright (c) 2020, CZ.NIC, z.s.p.o. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.trie +-- Prefix-tree data structure which addresses values by strings or byte arrays +-- .SS Binary-key trie with integer values +-- local trie = require("dnsjit.lib.trie").new("uint64_t", true, 4) +-- -- assume we have a bunch of dnsjit.core.object.ip packets to process +-- for _, pkt in pairs(pkts) do +-- local node = trie:get_ins(pkt.src) +-- local value = node:get() -- new nodes' values are initialized to 0 +-- node:set(value + 1) +-- end +-- -- iterate over unique IPs and print number of packets received from each +-- local iter = trie:iter() +-- local node = iter:node() +-- local p = require("dnsjit.lib.ip") +-- while node ~= nil do +-- local ip_bytes = node:key() +-- local npkts = tonumber(node:get()) +-- print(ip.tostring(ip_bytes).." sent "..npkts.." packets") +-- iter:next() +-- node = iter:node() +-- end +-- .SS String-key trie with cdata values +-- local trie = require("dnsjit.lib.trie").new("core_object_t*") +-- local obj1 -- assume this contains cdata of type core_object_t* +-- local node = trie:get_ins("obj1") +-- node:set(obj1) +-- node = trie:get_try("obj1") +-- assert(node:get() == obj1) +-- +-- Prefix-tree data structure that stores values indexed by strings or byte +-- arrays, such as IP addresses. +-- Values of size up to sizeof(size_t) can be stored directly, otherwise +-- a pointer must be used. +module(...,package.seeall) + +require("dnsjit.lib.trie_h") +local ffi = require("ffi") +local C = ffi.C +local log = require("dnsjit.core.log") +local module_log = log.new("lib.trie") +local TrieNode = require("dnsjit.lib.trie.node") +local TrieIter = require("dnsjit.lib.trie.iter") + +Trie = {} + +-- Create a new Trie that stores +-- .I ctype +-- values as data. +-- By default, keys are handled as strings. +-- To use trie with byte arrays, set +-- .I binary +-- to true. +-- Optionally, +-- .I keylen +-- may be specified as a default keylen for binary keys. +-- For string keys, their string length is used by default. +function Trie.new(ctype, binary, keylen) + if ctype == nil then + module_log:fatal("missing value ctype") + end + if ffi.sizeof(ctype) > ffi.sizeof("void *") then + module_log:fatal("data type exceeds max size, use a pointer instead") + end + if keylen ~= nil and not binary then + module_log:warning("setting keylen has no effect for string-key tries") + end + + local self = setmetatable({ + obj = C.trie_create(nil), + _binary = binary, + _keylen = keylen, + _ctype = ctype, + _log = log.new("lib.trie", module_log), + }, { __index = Trie }) + + ffi.gc(self.obj, C.trie_free) + + return self +end + +function Trie:_get_keylen(key, keylen) + if keylen ~= nil then + if type(keylen) == "number" then + return keylen + else + self._log:fatal("keylen must be numeric") + end + end + if not self._binary then + if type(key) == "string" then + return string.len(key) + else + self._log:fatal("key must be string when using trie with non-binary keys") + end + end + if not self._keylen or type(self._keylen) ~= "number" then + self._log:fatal("default keylen not set or invalid") + end + return self._keylen +end + +-- Return the Log object to control logging of this instance or module. +function Trie:log() + if self == nil then + return module_log + end + return self._log +end + +-- Clear the trie instance (make it empty). +function Trie:clear() + C.trie_clear(self.obj) +end + +-- Return the number of keys in the trie. +function Trie:weight() + return tonumber(C.trie_weight(self.obj)) +end + +-- Search the trie and return nil of failure. +function Trie:get_try(key, keylen) + keylen = self:_get_keylen(key, keylen) + local val = C.trie_get_try(self.obj, key, keylen) + if val == nil then return nil end + val = ffi.cast("trie_val_t *", val) + return TrieNode.new(self, val, key, keylen) +end + +-- Search the trie and insert an empty node (with value set to 0) on failure. +function Trie:get_ins(key, keylen) + keylen = self:_get_keylen(key, keylen) + local val = C.trie_get_ins(self.obj, key, keylen) + val = ffi.cast("trie_val_t *", val) + return TrieNode.new(self, val, key, keylen) +end + +-- Return the first node (minimum). +function Trie:get_first() + local key_ptr = ffi.new("uint8_t *[1]") + local keylen_ptr = ffi.new("uint32_t[1]") + local val = C.trie_get_first(self.obj, key_ptr, keylen_ptr) + local keylen = tonumber(keylen_ptr[0]) + key = key_ptr[0] + return TrieNode.new(self, val, key, keylen) +end + +-- Return a trie iterator. +-- It is only valid as long as the key-set remains unchanged. +function Trie:iter() + return TrieIter.new(self) +end + +-- dnsjit.lib.trie.node (3), dnsjit.lib.trie.iter (3) +return Trie diff --git a/src/lib/trie/iter.lua b/src/lib/trie/iter.lua new file mode 100644 index 0000000..520cc7b --- /dev/null +++ b/src/lib/trie/iter.lua @@ -0,0 +1,93 @@ +-- Copyright (c) 2020, CZ.NIC, z.s.p.o. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.trie.iter +-- Iterator of the trie. +-- .SS Iterate over all trie's key-value pairs +-- local trie = require("dnsjit.lib.trie").new("uint64_t") +-- local iter = trie:iter() +-- local node = iter:node() +-- while node ~= nil do +-- local key = node:key() +-- local value = tonumber(node:get()) +-- print(key..": "..value) +-- iter:next() +-- node = iter:node() +-- end +-- +-- Beware that iterator is only valid as long as the trie's key-set +-- remains unchanged. +module(...,package.seeall) + +require("dnsjit.lib.trie_h") +local ffi = require("ffi") +local C = ffi.C +local log = require("dnsjit.core.log") +local module_log = log.new("lib.trie.iter") +local TrieNode = require("dnsjit.lib.trie.node") + +TrieIter = {} + +-- Create a new iterator pointing to the first element (if any). +function TrieIter.new(trie) + local self = setmetatable({ + obj = C.trie_it_begin(trie.obj), + _trie = trie, + _log = log.new("lib.trie.iter", module_log), + }, { __index = TrieIter }) + + ffi.gc(self.obj, C.trie_it_free) + + return self +end + +-- Return the Log object to control logging of this instance or module. +function TrieIter:log() + if self == nil then + return module_log + end + return self._log +end + +-- Return the node pointer to by the iterator. +-- Returns nil when iterator has gone past the last element. +function TrieIter:node() + if C.trie_it_finished(self.obj) then + return nil + end + + local keylen_ptr = ffi.new("size_t[1]") + local key = C.trie_it_key(self.obj, keylen_ptr) + local keylen = tonumber(keylen_ptr[0]) + + local val = C.trie_it_val(self.obj) + return TrieNode.new(self._trie, val, key, keylen) +end + +-- Advance the iterator to the next element. +-- +-- Iteration is in ascending lexicographical order. +-- Empty string would be considered as the very first. +-- +-- You may not use this function if the trie's key-set has been modified during the lifetime of the iterator (modifying only values is OK). +function TrieIter:next() + C.trie_it_next(self.obj) +end + +-- dnsjit.lib.trie (3), dnsjit.lib.trie.node (3) +return TrieIter diff --git a/src/lib/trie/node.lua b/src/lib/trie/node.lua new file mode 100644 index 0000000..7fdc39d --- /dev/null +++ b/src/lib/trie/node.lua @@ -0,0 +1,84 @@ +-- Copyright (c) 2020, CZ.NIC, z.s.p.o. +-- All rights reserved. +-- +-- This file is part of dnsjit. +-- +-- dnsjit is free software: you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation, either version 3 of the License, or +-- (at your option) any later version. +-- +-- dnsjit is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with dnsjit. If not, see <http://www.gnu.org/licenses/>. + +-- dnsjit.lib.trie.node +-- Node of trie, which contains the value and key. +-- .SS Set a node's value. +-- node:set(42) +-- .SS Get a node's key and value. +-- local key = node:key() +-- local val = node:get() +module(...,package.seeall) + +require("dnsjit.lib.trie_h") +local ffi = require("ffi") +local C = ffi.C +local log = require("dnsjit.core.log") +local module_log = log.new("lib.trie.node") + +TrieNode = {} + +-- Create a new node object. +function TrieNode.new(trie, val, key, keylen) + local self = setmetatable({ + _key = key, + _keylen = keylen, + _val = val, + _trie = trie, + _log = log.new("lib.trie.node", module_log), + }, { __index = TrieNode }) + + return self +end + +-- Return key and keylen of this node. +-- Key is string or byte array if the trie's +-- .I +-- binary +-- setting is set to true. +function TrieNode:key() + if self._trie._binary then + local key = ffi.new("uint8_t[?]", self._keylen) + ffi.copy(key, self._key, self._keylen) + return key, self._keylen + else + return ffi.string(self._key, self._keylen), self._keylen + end +end + +-- Return the Log object to control logging of this instance or module. +function TrieNode:log() + if self == nil then + return module_log + end + return self._log +end + +-- Get the value of this node. +function TrieNode:get() + return ffi.cast(self._trie._ctype, self._val[0]) +end + +-- Set the value of this node. +function TrieNode:set(value) + value = ffi.cast('void *', value) + self._val[0] = value +end + +-- dnsjit.lib.trie (3) +return TrieNode |