diff options
Diffstat (limited to 'src/lua/lua_trie.c')
-rw-r--r-- | src/lua/lua_trie.c | 500 |
1 files changed, 500 insertions, 0 deletions
diff --git a/src/lua/lua_trie.c b/src/lua/lua_trie.c new file mode 100644 index 0000000..3b0b55e --- /dev/null +++ b/src/lua/lua_trie.c @@ -0,0 +1,500 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "message.h" +#include "libutil/multipattern.h" + +/*** + * @module rspamd_trie + * Rspamd trie module provides the data structure suitable for searching of many + * patterns in arbitrary texts (or binary chunks). The algorithmic complexity of + * this algorithm is at most O(n + m + z), where `n` is the length of text, `m` is a length of pattern and `z` is a number of patterns in the text. + * + * Here is a typical example of trie usage: + * @example +local rspamd_trie = require "rspamd_trie" +local patterns = {'aab', 'ab', 'bcd\0ef'} + +local trie = rspamd_trie.create(patterns) + +local function trie_callback(number, pos) + print('Matched pattern number ' .. tostring(number) .. ' at pos: ' .. tostring(pos)) +end + +trie:match('some big text', trie_callback) + */ + +/* Suffix trie */ +LUA_FUNCTION_DEF(trie, create); +LUA_FUNCTION_DEF(trie, has_hyperscan); +LUA_FUNCTION_DEF(trie, match); +LUA_FUNCTION_DEF(trie, search_mime); +LUA_FUNCTION_DEF(trie, search_rawmsg); +LUA_FUNCTION_DEF(trie, search_rawbody); +LUA_FUNCTION_DEF(trie, destroy); + +static const struct luaL_reg trielib_m[] = { + LUA_INTERFACE_DEF(trie, match), + LUA_INTERFACE_DEF(trie, search_mime), + LUA_INTERFACE_DEF(trie, search_rawmsg), + LUA_INTERFACE_DEF(trie, search_rawbody), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_trie_destroy}, + {NULL, NULL}}; +static const struct luaL_reg trielib_f[] = { + LUA_INTERFACE_DEF(trie, create), + LUA_INTERFACE_DEF(trie, has_hyperscan), + {NULL, NULL}}; + +static struct rspamd_multipattern * +lua_check_trie(lua_State *L, gint idx) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{trie}"); + + luaL_argcheck(L, ud != NULL, 1, "'trie' expected"); + return ud ? *((struct rspamd_multipattern **) ud) : NULL; +} + +static gint +lua_trie_destroy(lua_State *L) +{ + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + + if (trie) { + rspamd_multipattern_destroy(trie); + } + + return 0; +} + +/*** + * function trie.has_hyperscan() + * Checks for hyperscan support + * + * @return {bool} true if hyperscan is supported + */ +static gint +lua_trie_has_hyperscan(lua_State *L) +{ + lua_pushboolean(L, rspamd_multipattern_has_hyperscan()); + return 1; +} + +/*** + * function trie.create(patterns, [flags]) + * Creates new trie data structure + * @param {table} array of string patterns + * @return {trie} new trie object + */ +static gint +lua_trie_create(lua_State *L) +{ + struct rspamd_multipattern *trie, **ptrie; + gint npat = 0, flags = RSPAMD_MULTIPATTERN_ICASE | RSPAMD_MULTIPATTERN_GLOB; + GError *err = NULL; + + if (lua_isnumber(L, 2)) { + flags = lua_tointeger(L, 2); + } + + if (!lua_istable(L, 1)) { + return luaL_error(L, "lua trie expects array of patterns for now"); + } + else { + lua_pushvalue(L, 1); + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + if (lua_isstring(L, -1)) { + npat++; + } + + lua_pop(L, 1); + } + + trie = rspamd_multipattern_create_sized(npat, flags); + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + if (lua_isstring(L, -1)) { + const gchar *pat; + gsize patlen; + + pat = lua_tolstring(L, -1, &patlen); + rspamd_multipattern_add_pattern_len(trie, pat, patlen, flags); + } + + lua_pop(L, 1); + } + + lua_pop(L, 1); /* table */ + + if (!rspamd_multipattern_compile(trie, &err)) { + msg_err("cannot compile multipattern: %e", err); + g_error_free(err); + rspamd_multipattern_destroy(trie); + lua_pushnil(L); + } + else { + ptrie = lua_newuserdata(L, sizeof(void *)); + rspamd_lua_setclass(L, "rspamd{trie}", -1); + *ptrie = trie; + } + } + + return 1; +} + +#define PUSH_TRIE_MATCH(L, start, end, report_start) \ + do { \ + if (report_start) { \ + lua_createtable(L, 2, 0); \ + lua_pushinteger(L, (start)); \ + lua_rawseti(L, -2, 1); \ + lua_pushinteger(L, (end)); \ + lua_rawseti(L, -2, 2); \ + } \ + else { \ + lua_pushinteger(L, (end)); \ + } \ + } while (0) + +/* Normal callback type */ +static gint +lua_trie_lua_cb_callback(struct rspamd_multipattern *mp, + guint strnum, + gint match_start, + gint textpos, + const gchar *text, + gsize len, + void *context) +{ + lua_State *L = context; + gint ret; + + gboolean report_start = lua_toboolean(L, -1); + + /* Function */ + lua_pushvalue(L, 3); + lua_pushinteger(L, strnum + 1); + + PUSH_TRIE_MATCH(L, match_start, textpos, report_start); + + if (lua_pcall(L, 2, 1, 0) != 0) { + msg_info("call to trie callback has failed: %s", + lua_tostring(L, -1)); + lua_pop(L, 1); + + return 1; + } + + ret = lua_tonumber(L, -1); + lua_pop(L, 1); + + return ret; +} + +/* Table like callback, expect result table on top of the stack */ +static gint +lua_trie_table_callback(struct rspamd_multipattern *mp, + guint strnum, + gint match_start, + gint textpos, + const gchar *text, + gsize len, + void *context) +{ + lua_State *L = context; + + gint report_start = lua_toboolean(L, -2); + /* Set table, indexed by pattern number */ + lua_rawgeti(L, -1, strnum + 1); + + if (lua_istable(L, -1)) { + /* Already have table, add offset */ + gsize last = rspamd_lua_table_size(L, -1); + PUSH_TRIE_MATCH(L, match_start, textpos, report_start); + lua_rawseti(L, -2, last + 1); + /* Remove table from the stack */ + lua_pop(L, 1); + } + else { + /* Pop none */ + lua_pop(L, 1); + /* New table */ + lua_newtable(L); + PUSH_TRIE_MATCH(L, match_start, textpos, report_start); + lua_rawseti(L, -2, 1); + lua_rawseti(L, -2, strnum + 1); + } + + return 0; +} + +/* + * We assume that callback argument is at pos 3 and icase is in position 4 + */ +static gint +lua_trie_search_str(lua_State *L, struct rspamd_multipattern *trie, + const gchar *str, gsize len, rspamd_multipattern_cb_t cb) +{ + gint ret; + guint nfound = 0; + + if ((ret = rspamd_multipattern_lookup(trie, str, len, + cb, L, &nfound)) == 0) { + return nfound; + } + + return ret; +} + +/*** + * @method trie:match(input, [cb][, report_start]) + * Search for patterns in `input` invoking `cb` optionally ignoring case + * @param {table or string} input one or several (if `input` is an array) strings of input text + * @param {function} cb callback called on each pattern match in form `function (idx, pos)` where `idx` is a numeric index of pattern (starting from 1) and `pos` is a numeric offset where the pattern ends + * @param {boolean} report_start report both start and end offset when matching patterns + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however). If `cb` is not defined then it returns a table of match positions indexed by pattern number + */ +static gint +lua_trie_match(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + const gchar *text; + gsize len; + gboolean found = FALSE, report_start = FALSE; + struct rspamd_lua_text *t; + rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback; + + gint old_top = lua_gettop(L); + + if (trie) { + if (lua_type(L, 3) != LUA_TFUNCTION) { + if (lua_isboolean(L, 3)) { + report_start = lua_toboolean(L, 3); + } + + lua_pushboolean(L, report_start); + /* Table like match */ + lua_newtable(L); + cb = lua_trie_table_callback; + } + else { + if (lua_isboolean(L, 4)) { + report_start = lua_toboolean(L, 4); + } + lua_pushboolean(L, report_start); + } + + if (lua_type(L, 2) == LUA_TTABLE) { + lua_pushvalue(L, 2); + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + if (lua_isstring(L, -1)) { + text = lua_tolstring(L, -1, &len); + + if (lua_trie_search_str(L, trie, text, len, cb)) { + found = TRUE; + } + } + else if (lua_isuserdata(L, -1)) { + t = lua_check_text(L, -1); + + if (t) { + if (lua_trie_search_str(L, trie, t->start, t->len, cb)) { + found = TRUE; + } + } + } + lua_pop(L, 1); + } + } + else if (lua_type(L, 2) == LUA_TSTRING) { + text = lua_tolstring(L, 2, &len); + + if (lua_trie_search_str(L, trie, text, len, cb)) { + found = TRUE; + } + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + + if (t && lua_trie_search_str(L, trie, t->start, t->len, cb)) { + found = TRUE; + } + } + } + + if (lua_type(L, 3) == LUA_TFUNCTION) { + lua_settop(L, old_top); + lua_pushboolean(L, found); + } + else { + lua_remove(L, -2); + } + + return 1; +} + +/*** + * @method trie:search_mime(task, cb) + * This is a helper mehthod to search pattern within text parts of a message in rspamd task + * @param {task} task object + * @param {function} cb callback called on each pattern match @see trie:match + * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only) + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however) + */ +static gint +lua_trie_search_mime(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + struct rspamd_task *task = lua_check_task(L, 2); + struct rspamd_mime_text_part *part; + const gchar *text; + gsize len, i; + gboolean found = FALSE; + rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback; + + if (trie && task) { + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) + { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_content.len > 0) { + text = part->utf_content.begin; + len = part->utf_content.len; + + if (lua_trie_search_str(L, trie, text, len, cb) != 0) { + found = TRUE; + } + } + } + } + + lua_pushboolean(L, found); + return 1; +} + +/*** + * @method trie:search_rawmsg(task, cb[, caseless]) + * This is a helper mehthod to search pattern within the whole undecoded content of rspamd task + * @param {task} task object + * @param {function} cb callback called on each pattern match @see trie:match + * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only) + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however) + */ +static gint +lua_trie_search_rawmsg(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + struct rspamd_task *task = lua_check_task(L, 2); + const gchar *text; + gsize len; + gboolean found = FALSE; + + if (trie && task) { + text = task->msg.begin; + len = task->msg.len; + + if (lua_trie_search_str(L, trie, text, len, lua_trie_lua_cb_callback) != 0) { + found = TRUE; + } + } + + lua_pushboolean(L, found); + return 1; +} + +/*** + * @method trie:search_rawbody(task, cb[, caseless]) + * This is a helper mehthod to search pattern within the whole undecoded content of task's body (not including headers) + * @param {task} task object + * @param {function} cb callback called on each pattern match @see trie:match + * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only) + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however) + */ +static gint +lua_trie_search_rawbody(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + struct rspamd_task *task = lua_check_task(L, 2); + const gchar *text; + gsize len; + gboolean found = FALSE; + + if (trie && task) { + if (MESSAGE_FIELD(task, raw_headers_content).len > 0) { + text = task->msg.begin + MESSAGE_FIELD(task, raw_headers_content).len; + len = task->msg.len - MESSAGE_FIELD(task, raw_headers_content).len; + } + else { + /* Treat as raw message */ + text = task->msg.begin; + len = task->msg.len; + } + + if (lua_trie_search_str(L, trie, text, len, lua_trie_lua_cb_callback) != 0) { + found = TRUE; + } + } + + lua_pushboolean(L, found); + return 1; +} + +static gint +lua_load_trie(lua_State *L) +{ + lua_newtable(L); + + /* Flags */ + lua_pushstring(L, "flags"); + lua_newtable(L); + + lua_pushinteger(L, RSPAMD_MULTIPATTERN_GLOB); + lua_setfield(L, -2, "glob"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_RE); + lua_setfield(L, -2, "re"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_ICASE); + lua_setfield(L, -2, "icase"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_UTF8); + lua_setfield(L, -2, "utf8"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_TLD); + lua_setfield(L, -2, "tld"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_DOTALL); + lua_setfield(L, -2, "dot_all"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_SINGLEMATCH); + lua_setfield(L, -2, "single_match"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_NO_START); + lua_setfield(L, -2, "no_start"); + lua_settable(L, -3); + + /* Main content */ + luaL_register(L, NULL, trielib_f); + + return 1; +} + +void luaopen_trie(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{trie}", trielib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_trie", lua_load_trie); +} |