diff options
Diffstat (limited to '')
46 files changed, 50307 insertions, 0 deletions
diff --git a/src/lua/CMakeLists.txt b/src/lua/CMakeLists.txt new file mode 100644 index 0000000..a504f99 --- /dev/null +++ b/src/lua/CMakeLists.txt @@ -0,0 +1,39 @@ +# Lua support makefile +SET(LUASRC ${CMAKE_CURRENT_SOURCE_DIR}/lua_common.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_logger.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_task.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_config.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_classifier.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_cfg_file.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_regexp.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_cdb.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_xmlrpc.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_http.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_redis.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_upstream.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_mempool.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_dns_resolver.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_rsa.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_ip.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_expression.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_trie.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_mimepart.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_url.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_util.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_tcp.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_html.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/lua_sqlite3.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_cryptobox.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_map.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_thread_pool.cxx + ${CMAKE_CURRENT_SOURCE_DIR}/lua_dns.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_udp.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_text.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_worker.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_kann.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_spf.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_tensor.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_parsers.c + ${CMAKE_CURRENT_SOURCE_DIR}/lua_compress.c) + +SET(RSPAMD_LUA ${LUASRC} PARENT_SCOPE)
\ No newline at end of file diff --git a/src/lua/lua_cdb.c b/src/lua/lua_cdb.c new file mode 100644 index 0000000..76a5795 --- /dev/null +++ b/src/lua/lua_cdb.c @@ -0,0 +1,391 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "cdb.h" + +#define CDB_REFRESH_TIME 60 + +/*** + * @module rspamd_cdb + * Rspamd CDB module is used to read and write key/value pairs to the CDB file + * + * @example +local rspamd_cdb = require "rspamd_cdb" +rspamd_cdb.build('/tmp/test.cdb'):add('test', 'value'):finalize() +local c = rspamd_cdb.open('/tmp/test.cdb') +c:find('test') +-- will return 'value' + */ + +/*** + * @function rspamd_cdb.open(filename, [ev_base]) + * Opens an existing CDB for reading. If `ev_base` is specified, then cdb file is added + * for monitoring, that will get updates on disk file changes. + * @param {string} filename path to file + * @param {ev_base} event loop object + * @return {rspamd_cdb} cdb object + */ +LUA_FUNCTION_DEF(cdb, create); +/*** + * @method rspamd_cdb:find(key) + * Finds a specific key in cdb and returns a string or nil if a key has not been found + * @param {string} key key to find + * @return {string/nil} value for the specific key + */ +LUA_FUNCTION_DEF(cdb, lookup); +/*** + * @method rspamd_cdb:get_name() + * Returns filename for the specific cdb + * @return {string} filename for cdb + */ +LUA_FUNCTION_DEF(cdb, get_name); +LUA_FUNCTION_DEF(cdb, destroy); + +/*** + * @function rspamd_cdb.build(filename, [mode]) + * Creates a new cdb in a file (existing one will be overwritten!). The object + * returned can be used merely for adding data. Upon finalizing, the data is written to + * disk and cdb can no longer be changed. + * @param {string} filename path to file + * @param {int} mode numeric mode to create a file + * @return {rspamd_cdb_builder} cdb builder object (or nil + error message) + */ +LUA_FUNCTION_DEF(cdb, build); +/*** + * @method rspamd_cdb_builder:add(key, value) + * Adds new value to cdb in the builder mode + * @param {string} key key to add + * @param {string} value value to associate with the key + * @return {rspamd_cdb_builder} the same object to allow chaining calls + */ +LUA_FUNCTION_DEF(cdb_builder, add); +/*** + * @method rspamd_cdb_builder:finalize() + * Finalizes the CDB and writes it to disk. This method also closes FD associated with + * CDB builder. No further additions are allowed after this point + */ +LUA_FUNCTION_DEF(cdb_builder, finalize); +LUA_FUNCTION_DEF(cdb_builder, dtor); + +static const struct luaL_reg cdblib_m[] = { + LUA_INTERFACE_DEF(cdb, lookup), + {"find", lua_cdb_lookup}, + LUA_INTERFACE_DEF(cdb, get_name), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_cdb_destroy}, + {NULL, NULL}}; + +static const struct luaL_reg cdbbuilderlib_m[] = { + LUA_INTERFACE_DEF(cdb_builder, add), + LUA_INTERFACE_DEF(cdb_builder, finalize), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_cdb_builder_dtor}, + {NULL, NULL}}; + +static const struct luaL_reg cdblib_f[] = { + LUA_INTERFACE_DEF(cdb, create), + {"open", lua_cdb_create}, + {"build", lua_cdb_build}, + {NULL, NULL}}; + +static struct cdb * +lua_check_cdb(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{cdb}"); + + luaL_argcheck(L, ud != NULL, pos, "'cdb' expected"); + return ud ? *((struct cdb **) ud) : NULL; +} + +static struct cdb_make * +lua_check_cdb_builder(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{cdb_builder}"); + + luaL_argcheck(L, ud != NULL, pos, "'cdb_builder' expected"); + return ud ? ((struct cdb_make *) ud) : NULL; +} + +static const char * +lua_cdb_get_input(lua_State *L, int pos, gsize *olen) +{ + int t = lua_type(L, pos); + + switch (t) { + case LUA_TSTRING: + return lua_tolstring(L, pos, olen); + case LUA_TNUMBER: { + static char numbuf[sizeof(lua_Number)]; + lua_Number n = lua_tonumber(L, pos); + memcpy(numbuf, &n, sizeof(numbuf)); + *olen = sizeof(n); + return numbuf; + } + case LUA_TUSERDATA: { + void *p = rspamd_lua_check_udata_maybe(L, pos, "rspamd{text}"); + if (p) { + struct rspamd_lua_text *t = (struct rspamd_lua_text *) p; + *olen = t->len; + return t->start; + } + + p = rspamd_lua_check_udata_maybe(L, pos, "rspamd{int64}"); + if (p) { + static char numbuf[sizeof(gint64)]; + + memcpy(numbuf, p, sizeof(numbuf)); + *olen = sizeof(numbuf); + return numbuf; + } + } + default: + break; + } + + return NULL; +} + +static gint +lua_cdb_create(lua_State *L) +{ + struct cdb *cdb, **pcdb; + const gchar *filename; + gint fd; + + struct ev_loop *ev_base = NULL; + + if (lua_type(L, 2) == LUA_TUSERDATA) { + ev_base = lua_check_ev_base(L, 2); + } + + filename = luaL_checkstring(L, 1); + /* If file begins with cdb://, just skip it */ + if (g_ascii_strncasecmp(filename, "cdb://", sizeof("cdb://") - 1) == 0) { + filename += sizeof("cdb://") - 1; + } + + if ((fd = open(filename, O_RDONLY)) == -1) { + msg_warn("cannot open cdb: %s, %s", filename, strerror(errno)); + lua_pushnil(L); + } + else { + cdb = g_malloc0(sizeof(struct cdb)); + cdb->filename = g_strdup(filename); + if (cdb_init(cdb, fd) == -1) { + g_free(cdb->filename); + g_free(cdb); + msg_warn("cannot open cdb: %s, %s", filename, strerror(errno)); + lua_pushnil(L); + } + else { +#ifdef HAVE_READAHEAD + struct stat st; + /* + * Do not readahead more than 100mb, + * which is enough for the vast majority of the use cases + */ + static const size_t max_readahead = 100 * 0x100000; + + if (fstat(cdb_fileno(cdb), &st) != 1) { + /* Must always be true because cdb_init calls it as well */ + if (readahead(cdb_fileno(cdb), 0, MIN(max_readahead, st.st_size)) == -1) { + msg_warn("cannot readahead cdb: %s, %s", filename, strerror(errno)); + } + } +#endif + if (ev_base) { + cdb_add_timer(cdb, ev_base, CDB_REFRESH_TIME); + } + pcdb = lua_newuserdata(L, sizeof(struct cdb *)); + rspamd_lua_setclass(L, "rspamd{cdb}", -1); + *pcdb = cdb; + } + } + + return 1; +} + +static gint +lua_cdb_get_name(lua_State *L) +{ + struct cdb *cdb = lua_check_cdb(L, 1); + + if (!cdb) { + lua_error(L); + return 1; + } + lua_pushstring(L, cdb->filename); + return 1; +} + +static gint +lua_cdb_lookup(lua_State *L) +{ + struct cdb *cdb = lua_check_cdb(L, 1); + gsize klen; + const gchar *what = lua_cdb_get_input(L, 2, &klen); + + if (!cdb || what == NULL) { + return lua_error(L); + } + + if (cdb_find(cdb, what, klen) > 0) { + /* Extract and push value to lua as string */ + lua_pushlstring(L, cdb_getdata(cdb), cdb_datalen(cdb)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_cdb_destroy(lua_State *L) +{ + struct cdb *cdb = lua_check_cdb(L, 1); + + if (cdb) { + cdb_free(cdb); + if (cdb->cdb_fd != -1) { + (void) close(cdb->cdb_fd); + } + g_free(cdb->filename); + g_free(cdb); + } + + return 0; +} + +static gint +lua_cdb_build(lua_State *L) +{ + const char *filename = luaL_checkstring(L, 1); + int fd, mode = 00755; + + if (filename == NULL) { + return luaL_error(L, "invalid arguments, filename expected"); + } + + /* If file begins with cdb://, just skip it */ + if (g_ascii_strncasecmp(filename, "cdb://", sizeof("cdb://") - 1) == 0) { + filename += sizeof("cdb://") - 1; + } + + if (lua_isnumber(L, 2)) { + mode = lua_tointeger(L, 2); + } + + fd = rspamd_file_xopen(filename, O_RDWR | O_CREAT | O_TRUNC, mode, 0); + + if (fd == -1) { + lua_pushnil(L); + lua_pushfstring(L, "cannot open cdb: %s, %s", filename, strerror(errno)); + + return 2; + } + + struct cdb_make *cdbm = lua_newuserdata(L, sizeof(struct cdb_make)); + + g_assert(cdb_make_start(cdbm, fd) == 0); + rspamd_lua_setclass(L, "rspamd{cdb_builder}", -1); + + return 1; +} + +static gint +lua_cdb_builder_add(lua_State *L) +{ + struct cdb_make *cdbm = lua_check_cdb_builder(L, 1); + gsize data_sz, key_sz; + const char *key = lua_cdb_get_input(L, 2, &key_sz); + const char *data = lua_cdb_get_input(L, 3, &data_sz); + + if (cdbm == NULL || key == NULL || data == NULL || cdbm->cdb_fd == -1) { + return luaL_error(L, "invalid arguments"); + } + + if (cdb_make_add(cdbm, key, key_sz, data, data_sz) == -1) { + lua_pushvalue(L, 1); + lua_pushfstring(L, "cannot push value to cdb: %s", strerror(errno)); + + return 2; + } + + /* Allow chaining */ + lua_pushvalue(L, 1); + return 1; +} + +static gint +lua_cdb_builder_finalize(lua_State *L) +{ + struct cdb_make *cdbm = lua_check_cdb_builder(L, 1); + + if (cdbm == NULL || cdbm->cdb_fd == -1) { + return luaL_error(L, "invalid arguments"); + } + + if (cdb_make_finish(cdbm) == -1) { + lua_pushvalue(L, 1); + lua_pushfstring(L, "cannot finish value to cdb: %s", strerror(errno)); + + return 2; + } + + close(cdbm->cdb_fd); + cdbm->cdb_fd = -1; /* To distinguish finalized object */ + + /* Allow chaining */ + lua_pushvalue(L, 1); + return 1; +} + +static gint +lua_cdb_builder_dtor(lua_State *L) +{ + struct cdb_make *cdbm = lua_check_cdb_builder(L, 1); + + if (cdbm == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (cdbm->cdb_fd != -1) { + cdb_make_finish(cdbm); + close(cdbm->cdb_fd); + cdbm->cdb_fd = -1; /* Finalized object */ + } + + return 0; +} + +static gint +lua_load_cdb(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, cdblib_f); + + return 1; +} + +void luaopen_cdb(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{cdb}", cdblib_m); + lua_pop(L, 1); + rspamd_lua_new_class(L, "rspamd{cdb_builder}", cdbbuilderlib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_cdb", lua_load_cdb); +} diff --git a/src/lua/lua_cfg_file.c b/src/lua/lua_cfg_file.c new file mode 100644 index 0000000..75bc380 --- /dev/null +++ b/src/lua/lua_cfg_file.c @@ -0,0 +1,157 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "expression.h" +#include "src/libserver/composites/composites.h" + +/* + * This is implementation of lua routines to handle config file params + */ + +/* Process a single item in 'metrics' table */ +static void +lua_process_metric(lua_State *L, const gchar *name, struct rspamd_config *cfg) +{ + gchar *symbol; + const gchar *desc = NULL; + gdouble *score; + struct rspamd_symbol *s; + + /* Now iterate through module table */ + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + /* key - -2, value - -1 */ + symbol = rspamd_mempool_strdup(cfg->cfg_pool, luaL_checkstring(L, -2)); + if (symbol != NULL) { + if (lua_istable(L, -1)) { + /* We got a table, so extract individual attributes */ + lua_pushstring(L, "weight"); + lua_gettable(L, -2); + if (lua_isnumber(L, -1)) { + score = rspamd_mempool_alloc(cfg->cfg_pool, sizeof(double)); + *score = lua_tonumber(L, -1); + } + else { + msg_warn_config("cannot get weight of symbol: %s", symbol); + continue; + } + lua_pop(L, 1); + lua_pushstring(L, "description"); + lua_gettable(L, -2); + if (lua_isstring(L, -1)) { + desc = lua_tostring(L, -1); + } + lua_pop(L, 1); + } + else if (lua_isnumber(L, -1)) { + /* Just got weight */ + score = rspamd_mempool_alloc(cfg->cfg_pool, sizeof(double)); + *score = lua_tonumber(L, -1); + } + else { + msg_warn_config("cannot get weight of symbol: %s", symbol); + continue; + } + /* Insert symbol */ + if ((s = + g_hash_table_lookup(cfg->symbols, symbol)) != NULL) { + msg_info_config("replacing weight for symbol %s: %.2f -> %.2f", + symbol, + *s->weight_ptr, + *score); + s->weight_ptr = score; + } + else { + s = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*s)); + s->name = symbol; + s->weight_ptr = score; + g_hash_table_insert(cfg->symbols, symbol, s); + } + + if (desc) { + s->description = rspamd_mempool_strdup(cfg->cfg_pool, desc); + } + } + } +} + +/* Do post load initialization based on lua */ +void rspamd_lua_post_load_config(struct rspamd_config *cfg) +{ + lua_State *L = cfg->lua_state; + const gchar *name; + ucl_object_t *obj; + gsize keylen, i; + + /* First check all module options that may be overridden in 'config' global */ + lua_getglobal(L, "config"); + + if (lua_istable(L, -1)) { + /* Iterate to get all keys */ + GPtrArray *names = g_ptr_array_new_full(rspamd_lua_table_size(L, -1), + g_free); + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 2)) { + gchar *tmp; + lua_pushvalue(L, -2); + name = luaL_checklstring(L, -1, &keylen); + + if (name && lua_istable(L, -2)) { + tmp = g_malloc(keylen + 1); + rspamd_strlcpy(tmp, name, keylen + 1); + g_ptr_array_add(names, tmp); + } + } + + PTR_ARRAY_FOREACH(names, i, name) + { + lua_getfield(L, -1, name); + + if (lua_istable(L, -1)) { + obj = ucl_object_lua_import(L, lua_gettop(L)); + + if (obj != NULL) { + ucl_object_sort_keys(obj, UCL_SORT_KEYS_DEFAULT); + ucl_object_insert_key_merged(cfg->cfg_ucl_obj, + obj, + name, + strlen(name), + true); + } + } + } + + g_ptr_array_free(names, TRUE); + } + + /* Check metrics settings */ + lua_getglobal(L, "metrics"); + + if (lua_istable(L, -1)) { + /* Iterate */ + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + /* 'key' is at index -2 and 'value' is at index -1 */ + /* Key must be a string and value must be a table */ + name = luaL_checkstring(L, -2); + if (name != NULL && lua_istable(L, -1)) { + lua_process_metric(L, name, cfg); + } + } + } + + lua_settop(L, 0); + + rspamd_lua_start_gc(cfg); +} diff --git a/src/lua/lua_classifier.c b/src/lua/lua_classifier.c new file mode 100644 index 0000000..39580a6 --- /dev/null +++ b/src/lua/lua_classifier.c @@ -0,0 +1,230 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" + +/* Classifier methods */ +LUA_FUNCTION_DEF(classifier, get_statfiles); +LUA_FUNCTION_DEF(classifier, get_statfile_by_label); +LUA_FUNCTION_DEF(classifier, get_param); + +static const struct luaL_reg classifierlib_m[] = { + LUA_INTERFACE_DEF(classifier, get_statfiles), + LUA_INTERFACE_DEF(classifier, get_param), + LUA_INTERFACE_DEF(classifier, get_statfile_by_label), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +LUA_FUNCTION_DEF(statfile, get_symbol); +LUA_FUNCTION_DEF(statfile, get_label); +LUA_FUNCTION_DEF(statfile, is_spam); +LUA_FUNCTION_DEF(statfile, get_param); + +static const struct luaL_reg statfilelib_m[] = { + LUA_INTERFACE_DEF(statfile, get_symbol), + LUA_INTERFACE_DEF(statfile, get_label), + LUA_INTERFACE_DEF(statfile, is_spam), + LUA_INTERFACE_DEF(statfile, get_param), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static struct rspamd_statfile_config *lua_check_statfile(lua_State *L); + +/* Classifier implementation */ + + +static struct rspamd_classifier_config * +lua_check_classifier(lua_State *L) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{classifier}"); + luaL_argcheck(L, ud != NULL, 1, "'classifier' expected"); + return ud ? *((struct rspamd_classifier_config **) ud) : NULL; +} + +/* Return table of statfiles indexed by name */ +static gint +lua_classifier_get_statfiles(lua_State *L) +{ + struct rspamd_classifier_config *ccf = lua_check_classifier(L); + GList *cur; + struct rspamd_statfile_config *st, **pst; + gint i; + + if (ccf) { + lua_newtable(L); + cur = g_list_first(ccf->statfiles); + i = 1; + while (cur) { + st = cur->data; + pst = lua_newuserdata(L, sizeof(struct rspamd_statfile_config *)); + rspamd_lua_setclass(L, "rspamd{statfile}", -1); + *pst = st; + lua_rawseti(L, -2, i++); + + cur = g_list_next(cur); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_classifier_get_param(lua_State *L) +{ + struct rspamd_classifier_config *ccf = lua_check_classifier(L); + const gchar *param; + const ucl_object_t *value; + + param = luaL_checkstring(L, 2); + + if (ccf != NULL && param != NULL) { + value = ucl_object_lookup(ccf->opts, param); + + if (value != NULL) { + ucl_object_push_lua(L, value, true); + return 1; + } + } + + lua_pushnil(L); + + return 1; +} + +/* Get statfile with specified label */ +static gint +lua_classifier_get_statfile_by_label(lua_State *L) +{ + struct rspamd_classifier_config *ccf = lua_check_classifier(L); + struct rspamd_statfile_config *st, **pst; + const gchar *label; + GList *cur; + gint i; + + label = luaL_checkstring(L, 2); + if (ccf && label) { + cur = g_hash_table_lookup(ccf->labels, label); + if (cur) { + lua_newtable(L); + i = 1; + while (cur) { + st = cur->data; + pst = + lua_newuserdata(L, + sizeof(struct rspamd_statfile_config *)); + rspamd_lua_setclass(L, "rspamd{statfile}", -1); + *pst = st; + lua_rawseti(L, -2, i++); + cur = g_list_next(cur); + } + return 1; + } + } + lua_pushnil(L); + return 1; +} + +/* Statfile functions */ +static gint +lua_statfile_get_symbol(lua_State *L) +{ + struct rspamd_statfile_config *st = lua_check_statfile(L); + + if (st != NULL) { + lua_pushstring(L, st->symbol); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_statfile_get_label(lua_State *L) +{ + struct rspamd_statfile_config *st = lua_check_statfile(L); + + if (st != NULL && st->label != NULL) { + lua_pushstring(L, st->label); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_statfile_is_spam(lua_State *L) +{ + struct rspamd_statfile_config *st = lua_check_statfile(L); + + if (st != NULL) { + lua_pushboolean(L, st->is_spam); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_statfile_get_param(lua_State *L) +{ + struct rspamd_statfile_config *st = lua_check_statfile(L); + const gchar *param; + const ucl_object_t *value; + + param = luaL_checkstring(L, 2); + + if (st != NULL && param != NULL) { + value = ucl_object_lookup(st->opts, param); + if (value != NULL) { + lua_pushstring(L, ucl_object_tostring_forced(value)); + return 1; + } + } + lua_pushnil(L); + + return 1; +} + +static struct rspamd_statfile_config * +lua_check_statfile(lua_State *L) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{statfile}"); + luaL_argcheck(L, ud != NULL, 1, "'statfile' expected"); + return ud ? *((struct rspamd_statfile_config **) ud) : NULL; +} + + +/* Open functions */ + +void luaopen_classifier(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{classifier}", classifierlib_m); + lua_pop(L, 1); /* remove metatable from stack */ +} + +void luaopen_statfile(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{statfile}", statfilelib_m); + lua_pop(L, 1); /* remove metatable from stack */ +} diff --git a/src/lua/lua_common.c b/src/lua/lua_common.c new file mode 100644 index 0000000..9bf9514 --- /dev/null +++ b/src/lua/lua_common.c @@ -0,0 +1,2659 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_compress.h" +#include "lptree.h" +#include "utlist.h" +#include "unix-std.h" +#include "ottery.h" +#include "lua_thread_pool.h" +#include "libstat/stat_api.h" +#include "libserver/rspamd_control.h" + +#include <math.h> + + +/* Lua module init function */ +#define MODULE_INIT_FUNC "module_init" + +#ifdef WITH_LUA_TRACE +ucl_object_t *lua_traces; +#endif + +const luaL_reg null_reg[] = { + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static const char rspamd_modules_state_global[] = "rspamd_plugins_state"; + +static GQuark +lua_error_quark(void) +{ + return g_quark_from_static_string("lua-routines"); +} + +/* + * Used to map string to a pointer + */ +KHASH_INIT(lua_class_set, const char *, int, 1, rspamd_str_hash, rspamd_str_equal); +struct rspamd_lua_context { + lua_State *L; + khash_t(lua_class_set) * classes; + struct rspamd_lua_context *prev, *next; /* Expensive but we usually have exactly one lua state */ +}; +struct rspamd_lua_context *rspamd_lua_global_ctx = NULL; +#define RSPAMD_LUA_NCLASSES 64 +static inline struct rspamd_lua_context * +rspamd_lua_ctx_by_state(lua_State *L) +{ + struct rspamd_lua_context *cur; + + DL_FOREACH(rspamd_lua_global_ctx, cur) + { + if (cur->L == L) { + return cur; + } + } + + /* When we are using thread pool, this is the case... */ + return rspamd_lua_global_ctx; +} + +/* Util functions */ +/** + * Create new class and store metatable on top of the stack (must be popped if not needed) + * @param L + * @param classname name of class + * @param func table of class methods + */ +void rspamd_lua_new_class(lua_State *L, + const gchar *classname, + const struct luaL_reg *methods) +{ + khiter_t k; + gint r, nmethods = 0; + gboolean seen_index = false; + struct rspamd_lua_context *ctx = rspamd_lua_ctx_by_state(L); + + if (methods) { + for (;;) { + if (methods[nmethods].name != NULL) { + if (strcmp(methods[nmethods].name, "__index") == 0) { + seen_index = true; + } + nmethods++; + } + else { + break; + } + } + } + + lua_createtable(L, 0, 3 + nmethods); + + if (!seen_index) { + lua_pushstring(L, "__index"); + lua_pushvalue(L, -2); /* pushes the metatable */ + lua_settable(L, -3); /* metatable.__index = metatable */ + } + + lua_pushstring(L, "class"); + lua_pushstring(L, classname); + lua_rawset(L, -3); + + if (methods) { + luaL_register(L, NULL, methods); /* pushes all methods as MT fields */ + } + + lua_pushvalue(L, -1); /* Preserves metatable */ + int offset = luaL_ref(L, LUA_REGISTRYINDEX); + k = kh_put(lua_class_set, ctx->classes, classname, &r); + kh_value(ctx->classes, k) = offset; + /* MT is left on stack ! */ +} + +static const gchar * +rspamd_lua_class_tostring_buf(lua_State *L, gboolean print_pointer, gint pos) +{ + static gchar buf[64]; + const gchar *ret = NULL; + gint pop = 0; + + if (!lua_getmetatable(L, pos)) { + goto err; + } + + pop++; + lua_pushstring(L, "class"); + lua_gettable(L, -2); + pop++; + + if (!lua_isstring(L, -1)) { + goto err; + } + + if (print_pointer) { + rspamd_snprintf(buf, sizeof(buf), "%s(%p)", lua_tostring(L, -1), + lua_touserdata(L, 1)); + } + else { + rspamd_snprintf(buf, sizeof(buf), "%s", lua_tostring(L, -1)); + } + + ret = buf; + +err: + lua_pop(L, pop); + + return ret; +} + +gint rspamd_lua_class_tostring(lua_State *L) +{ + const gchar *p; + + p = rspamd_lua_class_tostring_buf(L, TRUE, 1); + + if (!p) { + lua_pushstring(L, "invalid object passed to 'lua_common.c:__tostring'"); + return lua_error(L); + } + + lua_pushstring(L, p); + + return 1; +} + + +void rspamd_lua_setclass(lua_State *L, const gchar *classname, gint objidx) +{ + khiter_t k; + struct rspamd_lua_context *ctx = rspamd_lua_ctx_by_state(L); + + k = kh_get(lua_class_set, ctx->classes, classname); + + g_assert(k != kh_end(ctx->classes)); + lua_rawgeti(L, LUA_REGISTRYINDEX, kh_value(ctx->classes, k)); + + if (objidx < 0) { + objidx--; + } + lua_setmetatable(L, objidx); +} + +void rspamd_lua_class_metatable(lua_State *L, const gchar *classname) +{ + khiter_t k; + struct rspamd_lua_context *ctx = rspamd_lua_ctx_by_state(L); + + k = kh_get(lua_class_set, ctx->classes, classname); + + g_assert(k != kh_end(ctx->classes)); + lua_rawgeti(L, LUA_REGISTRYINDEX, kh_value(ctx->classes, k)); +} + +void rspamd_lua_add_metamethod(lua_State *L, const gchar *classname, + luaL_Reg *meth) +{ + khiter_t k; + struct rspamd_lua_context *ctx = rspamd_lua_ctx_by_state(L); + + k = kh_get(lua_class_set, ctx->classes, classname); + + g_assert(k != kh_end(ctx->classes)); + lua_rawgeti(L, LUA_REGISTRYINDEX, kh_value(ctx->classes, k)); + + lua_pushcfunction(L, meth->func); + lua_setfield(L, -2, meth->name); + lua_pop(L, 1); /* remove metatable */ +} + +/* assume that table is at the top */ +void rspamd_lua_table_set(lua_State *L, const gchar *index, const gchar *value) +{ + lua_pushstring(L, index); + if (value) { + lua_pushstring(L, value); + } + else { + lua_pushnil(L); + } + lua_settable(L, -3); +} + +const gchar * +rspamd_lua_table_get(lua_State *L, const gchar *index) +{ + const gchar *result; + + lua_pushstring(L, index); + lua_gettable(L, -2); + if (!lua_isstring(L, -1)) { + return NULL; + } + result = lua_tostring(L, -1); + lua_pop(L, 1); + return result; +} + +static void +lua_add_actions_global(lua_State *L) +{ + gint i; + + lua_newtable(L); + + for (i = METRIC_ACTION_REJECT; i <= METRIC_ACTION_NOACTION; i++) { + lua_pushstring(L, rspamd_action_to_str(i)); + lua_pushinteger(L, i); + lua_settable(L, -3); + } + /* Set global table */ + lua_setglobal(L, "rspamd_actions"); +} + +#ifndef __APPLE__ +#define OS_SO_SUFFIX ".so" +#else +#define OS_SO_SUFFIX ".dylib" +#endif + +void rspamd_lua_set_path(lua_State *L, const ucl_object_t *cfg_obj, GHashTable *vars) +{ + const gchar *old_path, *additional_path = NULL; + const ucl_object_t *opts = NULL; + const gchar *rulesdir = RSPAMD_RULESDIR, + *lualibdir = RSPAMD_LUALIBDIR, + *libdir = RSPAMD_LIBDIR; + const gchar *t; + + gchar path_buf[PATH_MAX]; + + lua_getglobal(L, "package"); + lua_getfield(L, -1, "path"); + old_path = luaL_checkstring(L, -1); + + if (strstr(old_path, RSPAMD_LUALIBDIR) != NULL) { + /* Path has been already set, do not touch it */ + lua_pop(L, 2); + return; + } + + if (cfg_obj) { + opts = ucl_object_lookup(cfg_obj, "options"); + if (opts != NULL) { + opts = ucl_object_lookup(opts, "lua_path"); + if (opts != NULL && ucl_object_type(opts) == UCL_STRING) { + additional_path = ucl_object_tostring(opts); + } + } + } + + if (additional_path) { + rspamd_snprintf(path_buf, sizeof(path_buf), + "%s;" + "%s", + additional_path, old_path); + } + else { + /* Try environment */ + t = getenv("RULESDIR"); + if (t) { + rulesdir = t; + } + + t = getenv("LUALIBDIR"); + if (t) { + lualibdir = t; + } + + t = getenv("LIBDIR"); + if (t) { + libdir = t; + } + + t = getenv("RSPAMD_LIBDIR"); + if (t) { + libdir = t; + } + + if (vars) { + t = g_hash_table_lookup(vars, "RULESDIR"); + if (t) { + rulesdir = t; + } + + t = g_hash_table_lookup(vars, "LUALIBDIR"); + if (t) { + lualibdir = t; + } + + t = g_hash_table_lookup(vars, "LIBDIR"); + if (t) { + libdir = t; + } + + t = g_hash_table_lookup(vars, "RSPAMD_LIBDIR"); + if (t) { + libdir = t; + } + } + + rspamd_snprintf(path_buf, sizeof(path_buf), + "%s/lua/?.lua;" + "%s/?.lua;" + "%s/?.lua;" + "%s/?/init.lua;" + "%s", + RSPAMD_CONFDIR, + rulesdir, + lualibdir, lualibdir, + old_path); + } + + lua_pop(L, 1); + lua_pushstring(L, path_buf); + lua_setfield(L, -2, "path"); + + lua_getglobal(L, "package"); + lua_getfield(L, -1, "cpath"); + old_path = luaL_checkstring(L, -1); + + additional_path = NULL; + + if (opts != NULL) { + opts = ucl_object_lookup(opts, "lua_cpath"); + if (opts != NULL && ucl_object_type(opts) == UCL_STRING) { + additional_path = ucl_object_tostring(opts); + } + } + + if (additional_path) { + rspamd_snprintf(path_buf, sizeof(path_buf), + "%s/?%s;" + "%s", + additional_path, + OS_SO_SUFFIX, + old_path); + } + else { + rspamd_snprintf(path_buf, sizeof(path_buf), + "%s/?%s;" + "%s", + libdir, + OS_SO_SUFFIX, + old_path); + } + + lua_pop(L, 1); + lua_pushstring(L, path_buf); + lua_setfield(L, -2, "cpath"); + + lua_pop(L, 1); +} + +static gint +rspamd_lua_cmp_version_components(const gchar *comp1, const gchar *comp2) +{ + guint v1, v2; + + v1 = strtoul(comp1, NULL, 10); + v2 = strtoul(comp2, NULL, 10); + + return v1 - v2; +} + +static int +rspamd_lua_rspamd_version_cmp(lua_State *L) +{ + const gchar *ver; + gchar **components; + gint ret = 0; + + if (lua_type(L, 2) == LUA_TSTRING) { + ver = lua_tostring(L, 2); + + components = g_strsplit_set(ver, ".-_", -1); + + if (!components) { + return luaL_error(L, "invalid arguments to 'cmp': %s", ver); + } + + if (components[0]) { + ret = rspamd_lua_cmp_version_components(components[0], + RSPAMD_VERSION_MAJOR); + } + + if (ret) { + goto set; + } + + if (components[1]) { + ret = rspamd_lua_cmp_version_components(components[1], + RSPAMD_VERSION_MINOR); + } + + if (ret) { + goto set; + } + + /* + * XXX: we don't compare git releases assuming that it is meaningless + */ + } + else { + return luaL_error(L, "invalid arguments to 'cmp'"); + } + +set: + g_strfreev(components); + lua_pushinteger(L, ret); + + return 1; +} + +static int +rspamd_lua_rspamd_version_numeric(lua_State *L) +{ + static gint64 version_num = RSPAMD_VERSION_NUM; + const gchar *type; + + if (lua_gettop(L) >= 2 && lua_type(L, 1) == LUA_TSTRING) { + type = lua_tostring(L, 1); + if (g_ascii_strcasecmp(type, "short") == 0) { + version_num = RSPAMD_VERSION_MAJOR_NUM * 1000 + + RSPAMD_VERSION_MINOR_NUM * 100 + + RSPAMD_VERSION_PATCH_NUM * 10; + } + else if (g_ascii_strcasecmp(type, "main") == 0) { + version_num = RSPAMD_VERSION_MAJOR_NUM * 1000 + + RSPAMD_VERSION_MINOR_NUM * 100 + + RSPAMD_VERSION_PATCH_NUM * 10; + } + else if (g_ascii_strcasecmp(type, "major") == 0) { + version_num = RSPAMD_VERSION_MAJOR_NUM; + } + else if (g_ascii_strcasecmp(type, "patch") == 0) { + version_num = RSPAMD_VERSION_PATCH_NUM; + } + else if (g_ascii_strcasecmp(type, "minor") == 0) { + version_num = RSPAMD_VERSION_MINOR_NUM; + } + } + + lua_pushinteger(L, version_num); + + return 1; +} + +static int +rspamd_lua_rspamd_version(lua_State *L) +{ + const gchar *result = NULL, *type; + + if (lua_gettop(L) == 0) { + result = RVERSION; + } + else if (lua_gettop(L) >= 1 && lua_type(L, 1) == LUA_TSTRING) { + /* We got something like string */ + type = lua_tostring(L, 1); + + if (g_ascii_strcasecmp(type, "short") == 0) { + result = RSPAMD_VERSION_MAJOR + "." RSPAMD_VERSION_MINOR; + } + else if (g_ascii_strcasecmp(type, "main") == 0) { + result = RSPAMD_VERSION_MAJOR "." RSPAMD_VERSION_MINOR "." RSPAMD_VERSION_PATCH; + } + else if (g_ascii_strcasecmp(type, "major") == 0) { + result = RSPAMD_VERSION_MAJOR; + } + else if (g_ascii_strcasecmp(type, "minor") == 0) { + result = RSPAMD_VERSION_MINOR; + } + else if (g_ascii_strcasecmp(type, "patch") == 0) { + result = RSPAMD_VERSION_PATCH; + } + else if (g_ascii_strcasecmp(type, "id") == 0) { + result = RID; + } + else if (g_ascii_strcasecmp(type, "num") == 0) { + return rspamd_lua_rspamd_version_numeric(L); + } + else if (g_ascii_strcasecmp(type, "cmp") == 0) { + return rspamd_lua_rspamd_version_cmp(L); + } + } + + lua_pushstring(L, result); + + return 1; +} + +static gboolean +rspamd_lua_load_env(lua_State *L, const char *fname, gint tbl_pos, GError **err) +{ + gint orig_top = lua_gettop(L), err_idx; + gboolean ret = TRUE; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + if (luaL_loadfile(L, fname) != 0) { + g_set_error(err, g_quark_from_static_string("lua_env"), errno, + "cannot load lua file %s: %s", + fname, + lua_tostring(L, -1)); + ret = FALSE; + } + + if (ret && lua_pcall(L, 0, 1, err_idx) != 0) { + g_set_error(err, g_quark_from_static_string("lua_env"), errno, + "cannot init lua file %s: %s", + fname, + lua_tostring(L, -1)); + ret = FALSE; + } + + if (ret && lua_type(L, -1) == LUA_TTABLE) { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + lua_pushvalue(L, -2); /* Store key */ + lua_pushvalue(L, -2); /* Store value */ + lua_settable(L, tbl_pos); + } + } + else if (ret) { + g_set_error(err, g_quark_from_static_string("lua_env"), errno, + "invalid return type when loading env from %s: %s", + fname, + lua_typename(L, lua_type(L, -1))); + ret = FALSE; + } + + lua_settop(L, orig_top); + + return ret; +} + +gboolean +rspamd_lua_set_env(lua_State *L, GHashTable *vars, char **lua_env, GError **err) +{ + gint orig_top = lua_gettop(L); + gchar **env = g_get_environ(); + + /* Set known paths as rspamd_paths global */ + lua_getglobal(L, "rspamd_paths"); + if (lua_isnil(L, -1)) { + const gchar *confdir = RSPAMD_CONFDIR, + *local_confdir = RSPAMD_LOCAL_CONFDIR, + *rundir = RSPAMD_RUNDIR, + *dbdir = RSPAMD_DBDIR, + *logdir = RSPAMD_LOGDIR, + *wwwdir = RSPAMD_WWWDIR, + *pluginsdir = RSPAMD_PLUGINSDIR, + *rulesdir = RSPAMD_RULESDIR, + *lualibdir = RSPAMD_LUALIBDIR, + *prefix = RSPAMD_PREFIX, + *sharedir = RSPAMD_SHAREDIR; + const gchar *t; + + /* Try environment */ + t = g_environ_getenv(env, "SHAREDIR"); + if (t) { + sharedir = t; + } + + t = g_environ_getenv(env, "PLUGINSDIR"); + if (t) { + pluginsdir = t; + } + + t = g_environ_getenv(env, "RULESDIR"); + if (t) { + rulesdir = t; + } + + t = g_environ_getenv(env, "DBDIR"); + if (t) { + dbdir = t; + } + + t = g_environ_getenv(env, "RUNDIR"); + if (t) { + rundir = t; + } + + t = g_environ_getenv(env, "LUALIBDIR"); + if (t) { + lualibdir = t; + } + + t = g_environ_getenv(env, "LOGDIR"); + if (t) { + logdir = t; + } + + t = g_environ_getenv(env, "WWWDIR"); + if (t) { + wwwdir = t; + } + + t = g_environ_getenv(env, "CONFDIR"); + if (t) { + confdir = t; + } + + t = g_environ_getenv(env, "LOCAL_CONFDIR"); + if (t) { + local_confdir = t; + } + + + if (vars) { + t = g_hash_table_lookup(vars, "SHAREDIR"); + if (t) { + sharedir = t; + } + + t = g_hash_table_lookup(vars, "PLUGINSDIR"); + if (t) { + pluginsdir = t; + } + + t = g_hash_table_lookup(vars, "RULESDIR"); + if (t) { + rulesdir = t; + } + + t = g_hash_table_lookup(vars, "LUALIBDIR"); + if (t) { + lualibdir = t; + } + + t = g_hash_table_lookup(vars, "RUNDIR"); + if (t) { + rundir = t; + } + + t = g_hash_table_lookup(vars, "WWWDIR"); + if (t) { + wwwdir = t; + } + + t = g_hash_table_lookup(vars, "CONFDIR"); + if (t) { + confdir = t; + } + + t = g_hash_table_lookup(vars, "LOCAL_CONFDIR"); + if (t) { + local_confdir = t; + } + + t = g_hash_table_lookup(vars, "DBDIR"); + if (t) { + dbdir = t; + } + + t = g_hash_table_lookup(vars, "LOGDIR"); + if (t) { + logdir = t; + } + } + + lua_createtable(L, 0, 9); + + rspamd_lua_table_set(L, RSPAMD_SHAREDIR_INDEX, sharedir); + rspamd_lua_table_set(L, RSPAMD_CONFDIR_INDEX, confdir); + rspamd_lua_table_set(L, RSPAMD_LOCAL_CONFDIR_INDEX, local_confdir); + rspamd_lua_table_set(L, RSPAMD_RUNDIR_INDEX, rundir); + rspamd_lua_table_set(L, RSPAMD_DBDIR_INDEX, dbdir); + rspamd_lua_table_set(L, RSPAMD_LOGDIR_INDEX, logdir); + rspamd_lua_table_set(L, RSPAMD_WWWDIR_INDEX, wwwdir); + rspamd_lua_table_set(L, RSPAMD_PLUGINSDIR_INDEX, pluginsdir); + rspamd_lua_table_set(L, RSPAMD_RULESDIR_INDEX, rulesdir); + rspamd_lua_table_set(L, RSPAMD_LUALIBDIR_INDEX, lualibdir); + rspamd_lua_table_set(L, RSPAMD_PREFIX_INDEX, prefix); + + lua_setglobal(L, "rspamd_paths"); + } + + lua_getglobal(L, "rspamd_env"); + if (lua_isnil(L, -1)) { + lua_newtable(L); + + if (vars != NULL) { + GHashTableIter it; + gpointer k, v; + + g_hash_table_iter_init(&it, vars); + + while (g_hash_table_iter_next(&it, &k, &v)) { + rspamd_lua_table_set(L, k, v); + } + } + + gint hostlen = sysconf(_SC_HOST_NAME_MAX); + + if (hostlen <= 0) { + hostlen = 256; + } + else { + hostlen++; + } + + gchar *hostbuf = g_alloca(hostlen); + memset(hostbuf, 0, hostlen); + gethostname(hostbuf, hostlen - 1); + + rspamd_lua_table_set(L, "hostname", hostbuf); + + rspamd_lua_table_set(L, "version", RVERSION); + rspamd_lua_table_set(L, "ver_major", RSPAMD_VERSION_MAJOR); + rspamd_lua_table_set(L, "ver_minor", RSPAMD_VERSION_MINOR); + rspamd_lua_table_set(L, "ver_id", RID); + lua_pushstring(L, "ver_num"); + lua_pushinteger(L, RSPAMD_VERSION_NUM); + lua_settable(L, -3); + + if (env) { + gint lim = g_strv_length(env); + + for (gint i = 0; i < lim; i++) { + if (RSPAMD_LEN_CHECK_STARTS_WITH(env[i], strlen(env[i]), "RSPAMD_")) { + const char *var = env[i] + sizeof("RSPAMD_") - 1, *value; + gint varlen; + + varlen = strcspn(var, "="); + value = var + varlen; + + if (*value == '=') { + value++; + + lua_pushlstring(L, var, varlen); + lua_pushstring(L, value); + lua_settable(L, -3); + } + } + } + } + + if (lua_env) { + gint lim = g_strv_length(lua_env); + + for (gint i = 0; i < lim; i++) { + if (!rspamd_lua_load_env(L, lua_env[i], lua_gettop(L), err)) { + return FALSE; + } + } + } + + lua_setglobal(L, "rspamd_env"); + } + + lua_settop(L, orig_top); + g_strfreev(env); + + return TRUE; +} + +void rspamd_lua_set_globals(struct rspamd_config *cfg, lua_State *L) +{ + struct rspamd_config **pcfg; + gint orig_top = lua_gettop(L); + + /* First check for global variable 'config' */ + lua_getglobal(L, "config"); + if (lua_isnil(L, -1)) { + /* Assign global table to set up attributes */ + lua_newtable(L); + lua_setglobal(L, "config"); + } + + lua_getglobal(L, "metrics"); + if (lua_isnil(L, -1)) { + lua_newtable(L); + lua_setglobal(L, "metrics"); + } + + lua_getglobal(L, "composites"); + if (lua_isnil(L, -1)) { + lua_newtable(L); + lua_setglobal(L, "composites"); + } + + lua_getglobal(L, "rspamd_classifiers"); + if (lua_isnil(L, -1)) { + lua_newtable(L); + lua_setglobal(L, "rspamd_classifiers"); + } + + lua_getglobal(L, "classifiers"); + if (lua_isnil(L, -1)) { + lua_newtable(L); + lua_setglobal(L, "classifiers"); + } + + lua_getglobal(L, "rspamd_version"); + if (lua_isnil(L, -1)) { + lua_pushcfunction(L, rspamd_lua_rspamd_version); + lua_setglobal(L, "rspamd_version"); + } + + if (cfg != NULL) { + pcfg = lua_newuserdata(L, sizeof(struct rspamd_config *)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + *pcfg = cfg; + lua_setglobal(L, "rspamd_config"); + } + + lua_settop(L, orig_top); +} + +#ifdef WITH_LUA_TRACE +static gint +lua_push_trace_data(lua_State *L) +{ + if (lua_traces) { + ucl_object_push_lua(L, lua_traces, true); + } + else { + lua_pushnil(L); + } + + return 1; +} +#endif + + +static void * +rspamd_lua_wipe_realloc(void *ud, + void *ptr, + size_t osize, + size_t nsize) RSPAMD_ATTR_ALLOC_SIZE(4); +static void * +rspamd_lua_wipe_realloc(void *ud, + void *ptr, + size_t osize, + size_t nsize) +{ + if (nsize == 0) { + if (ptr) { + rspamd_explicit_memzero(ptr, osize); + } + + free(ptr); + } + else if (ptr == NULL) { + return malloc(nsize); + } + else { + if (nsize < osize) { + /* Wipe on shrinking (actually never used) */ + rspamd_explicit_memzero(((unsigned char *) ptr) + nsize, osize - nsize); + } + + return realloc(ptr, nsize); + } + + return NULL; +} + +#ifndef WITH_LUAJIT +extern int luaopen_bit(lua_State *L); +#endif + +static unsigned int lua_initialized = 0; + +lua_State * +rspamd_lua_init(bool wipe_mem) +{ + lua_State *L; + + if (wipe_mem) { +#ifdef WITH_LUAJIT + /* TODO: broken on luajit without GC64 */ + L = luaL_newstate(); +#else + L = lua_newstate(rspamd_lua_wipe_realloc, NULL); +#endif + } + else { + L = luaL_newstate(); + } + + struct rspamd_lua_context *ctx; + + ctx = (struct rspamd_lua_context *) g_malloc0(sizeof(*ctx)); + ctx->L = L; + ctx->classes = kh_init(lua_class_set); + kh_resize(lua_class_set, ctx->classes, RSPAMD_LUA_NCLASSES); + DL_APPEND(rspamd_lua_global_ctx, ctx); + + lua_gc(L, LUA_GCSTOP, 0); + luaL_openlibs(L); + luaopen_logger(L); + luaopen_mempool(L); + luaopen_config(L); + luaopen_map(L); + luaopen_trie(L); + luaopen_task(L); + luaopen_textpart(L); + luaopen_mimepart(L); + luaopen_image(L); + luaopen_url(L); + luaopen_classifier(L); + luaopen_statfile(L); + luaopen_regexp(L); + luaopen_cdb(L); + luaopen_xmlrpc(L); + luaopen_http(L); + luaopen_redis(L); + luaopen_upstream(L); + lua_add_actions_global(L); + luaopen_dns_resolver(L); + luaopen_rsa(L); + luaopen_ip(L); + luaopen_expression(L); + luaopen_text(L); + luaopen_util(L); + luaopen_tcp(L); + luaopen_html(L); + luaopen_sqlite3(L); + luaopen_cryptobox(L); + luaopen_dns(L); + luaopen_udp(L); + luaopen_worker(L); + luaopen_kann(L); + luaopen_spf(L); + luaopen_tensor(L); + luaopen_parsers(L); + luaopen_compress(L); +#ifndef WITH_LUAJIT + rspamd_lua_add_preload(L, "bit", luaopen_bit); + lua_settop(L, 0); +#endif + + rspamd_lua_new_class(L, "rspamd{session}", NULL); + lua_pop(L, 1); + + rspamd_lua_add_preload(L, "lpeg", luaopen_lpeg); + luaopen_ucl(L); + rspamd_lua_add_preload(L, "ucl", luaopen_ucl); + + /* Add plugins global */ + lua_newtable(L); + lua_setglobal(L, "rspamd_plugins"); + + /* Set PRNG */ + lua_getglobal(L, "math"); + lua_pushstring(L, "randomseed"); /* Push math.randomseed function on top of the stack */ + lua_gettable(L, -2); + lua_pushinteger(L, ottery_rand_uint64()); + g_assert(lua_pcall(L, 1, 0, 0) == 0); + lua_pop(L, 1); /* math table */ + + /* Modules state */ + lua_newtable(L); + /* + * rspamd_plugins_state = { + * enabled = {}, + * disabled_unconfigured = {}, + * disabled_redis = {}, + * disabled_explicitly = {}, + * disabled_failed = {}, + * disabled_experimental = {}, + * disabled_unknown = {}, + * } + */ +#define ADD_TABLE(name) \ + do { \ + lua_pushstring(L, #name); \ + lua_newtable(L); \ + lua_settable(L, -3); \ + } while (0) + + ADD_TABLE(enabled); + ADD_TABLE(disabled_unconfigured); + ADD_TABLE(disabled_redis); + ADD_TABLE(disabled_explicitly); + ADD_TABLE(disabled_failed); + ADD_TABLE(disabled_experimental); + ADD_TABLE(disabled_unknown); + +#undef ADD_TABLE + lua_setglobal(L, rspamd_modules_state_global); + +#ifdef WITH_LUA_TRACE + lua_pushcfunction(L, lua_push_trace_data); + lua_setglobal(L, "get_traces"); +#endif + + lua_initialized++; + + return L; +} + +void rspamd_lua_close(lua_State *L) +{ + struct rspamd_lua_context *ctx = rspamd_lua_ctx_by_state(L); + + /* TODO: we will leak this memory, but I don't know how to resolve + * the chicked-egg problem when lua_close calls GC for many + * userdata that requires classes metatables to be represented + * For now, it is safe to leave it as is, I'm afraid + */ +#if 0 + int ref; + kh_foreach_value(ctx->classes, ref, { + luaL_unref(L, LUA_REGISTRYINDEX, ref); + }); +#endif + + lua_close(L); + DL_DELETE(rspamd_lua_global_ctx, ctx); + kh_destroy(lua_class_set, ctx->classes); + g_free(ctx); + + lua_initialized--; +} + +bool rspamd_lua_is_initialised(void) +{ + return lua_initialized != 0; +} + +void rspamd_lua_start_gc(struct rspamd_config *cfg) +{ + lua_State *L = (lua_State *) cfg->lua_state; + + lua_settop(L, 0); + /* Set up GC */ + lua_gc(L, LUA_GCCOLLECT, 0); + lua_gc(L, LUA_GCSETSTEPMUL, cfg->lua_gc_step); + lua_gc(L, LUA_GCSETPAUSE, cfg->lua_gc_pause); + lua_gc(L, LUA_GCRESTART, 0); +} + + +void rspamd_plugins_table_push_elt(lua_State *L, const gchar *field_name, + const gchar *new_elt) +{ + lua_getglobal(L, rspamd_modules_state_global); + + if (lua_istable(L, -1)) { + lua_pushstring(L, field_name); + lua_gettable(L, -2); + + if (lua_istable(L, -1)) { + lua_pushstring(L, new_elt); + lua_newtable(L); + lua_settable(L, -3); + lua_pop(L, 2); /* Global + element */ + } + else { + lua_pop(L, 2); /* Global + element */ + } + } + else { + lua_pop(L, 1); + } +} + +gboolean +rspamd_init_lua_filters(struct rspamd_config *cfg, bool force_load, bool strict) +{ + struct rspamd_config **pcfg; + struct script_module *module; + lua_State *L = cfg->lua_state; + gint err_idx, i; + + pcfg = lua_newuserdata(L, sizeof(struct rspamd_config *)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + *pcfg = cfg; + lua_setglobal(L, "rspamd_config"); + + PTR_ARRAY_FOREACH(cfg->script_modules, i, module) + { + if (module->path) { + if (!force_load) { + if (!rspamd_config_is_module_enabled(cfg, module->name)) { + continue; + } + } + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + gsize fsize; + guint8 *data = rspamd_file_xmap(module->path, + PROT_READ, &fsize, TRUE); + guchar digest[rspamd_cryptobox_HASHBYTES]; + gchar *lua_fname; + + if (data == NULL) { + msg_err_config("cannot mmap %s failed: %s", module->path, + strerror(errno)); + + lua_settop(L, err_idx - 1); /* Error function */ + + rspamd_plugins_table_push_elt(L, "disabled_failed", + module->name); + + if (strict) { + return FALSE; + } + + continue; + } + + module->digest = rspamd_mempool_alloc(cfg->cfg_pool, + rspamd_cryptobox_HASHBYTES * 2 + 1); + rspamd_cryptobox_hash(digest, data, fsize, NULL, 0); + rspamd_encode_hex_buf(digest, sizeof(digest), + module->digest, rspamd_cryptobox_HASHBYTES * 2 + 1); + module->digest[rspamd_cryptobox_HASHBYTES * 2] = '\0'; + lua_fname = g_malloc(strlen(module->path) + 2); + rspamd_snprintf(lua_fname, strlen(module->path) + 2, "@%s", + module->path); + + if (luaL_loadbuffer(L, data, fsize, lua_fname) != 0) { + msg_err_config("load of %s failed: %s", module->path, + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); /* Error function */ + + rspamd_plugins_table_push_elt(L, "disabled_failed", + module->name); + munmap(data, fsize); + g_free(lua_fname); + + if (strict) { + return FALSE; + } + + continue; + } + + munmap(data, fsize); + g_free(lua_fname); + + if (lua_pcall(L, 0, 0, err_idx) != 0) { + msg_err_config("init of %s failed: %s", + module->path, + lua_tostring(L, -1)); + + lua_settop(L, err_idx - 1); + rspamd_plugins_table_push_elt(L, "disabled_failed", + module->name); + + if (strict) { + return FALSE; + } + + continue; + } + + if (!force_load) { + msg_info_config("init lua module %s from %s; digest: %*s", + module->name, + module->path, + 10, module->digest); + } + + lua_pop(L, 1); /* Error function */ + } + } + + return TRUE; +} + +void rspamd_lua_dumpstack(lua_State *L) +{ + gint i, t, r = 0; + gint top = lua_gettop(L); + gchar buf[BUFSIZ]; + + r += rspamd_snprintf(buf + r, sizeof(buf) - r, "lua stack: "); + for (i = 1; i <= top; i++) { /* repeat for each level */ + t = lua_type(L, i); + switch (t) { + case LUA_TSTRING: /* strings */ + r += rspamd_snprintf(buf + r, + sizeof(buf) - r, + "str: %s", + lua_tostring(L, i)); + break; + + case LUA_TBOOLEAN: /* booleans */ + r += rspamd_snprintf(buf + r, sizeof(buf) - r, + lua_toboolean(L, i) ? "bool: true" : "bool: false"); + break; + + case LUA_TNUMBER: /* numbers */ + r += rspamd_snprintf(buf + r, + sizeof(buf) - r, + "number: %.2f", + lua_tonumber(L, i)); + break; + + default: /* other values */ + r += rspamd_snprintf(buf + r, + sizeof(buf) - r, + "type: %s", + lua_typename(L, t)); + break; + } + if (i < top) { + r += rspamd_snprintf(buf + r, sizeof(buf) - r, + " -> "); /* put a separator */ + } + } + + msg_info("%*s", r, buf); +} + +gpointer +rspamd_lua_check_class(lua_State *L, gint index, const gchar *name) +{ + gpointer p; + khiter_t k; + + if (lua_type(L, index) == LUA_TUSERDATA) { + p = lua_touserdata(L, index); + if (p) { + if (lua_getmetatable(L, index)) { + struct rspamd_lua_context *ctx = rspamd_lua_ctx_by_state(L); + + k = kh_get(lua_class_set, ctx->classes, name); + + if (k == kh_end(ctx->classes)) { + lua_pop(L, 1); + + return NULL; + } + + lua_rawgeti(L, LUA_REGISTRYINDEX, kh_value(ctx->classes, k)); + + if (lua_rawequal(L, -1, -2)) { /* does it have the correct mt? */ + lua_pop(L, 2); /* remove both metatables */ + return p; + } + lua_pop(L, 2); + } + } + } + return NULL; +} + +int rspamd_lua_typerror(lua_State *L, int narg, const char *tname) +{ + const char *msg = lua_pushfstring(L, "%s expected, got %s", tname, + luaL_typename(L, narg)); + return luaL_argerror(L, narg, msg); +} + + +void rspamd_lua_add_preload(lua_State *L, const gchar *name, lua_CFunction func) +{ + lua_getglobal(L, "package"); + lua_pushstring(L, "preload"); + lua_gettable(L, -2); + lua_pushcfunction(L, func); + lua_setfield(L, -2, name); + lua_pop(L, 2); /* preload key + global package */ +} + + +gboolean +rspamd_lua_parse_table_arguments(lua_State *L, gint pos, + GError **err, + enum rspamd_lua_parse_arguments_flags how, + const gchar *extraction_pattern, ...) +{ + const gchar *p, *key = NULL, *end, *cls; + va_list ap; + gboolean required = FALSE, failed = FALSE, is_table; + gchar classbuf[128]; + enum { + read_key = 0, + read_arg, + read_class_start, + read_class, + read_semicolon + } state = read_key; + gsize keylen = 0, *valuelen, clslen; + gint idx = 0, t, direct_userdata = 0; + + g_assert(extraction_pattern != NULL); + + if (pos < 0) { + /* Get absolute pos */ + pos = lua_gettop(L) + pos + 1; + } + + if (lua_type(L, pos) == LUA_TTABLE) { + is_table = TRUE; + } + else { + is_table = FALSE; + idx = pos; + } + + p = extraction_pattern; + end = p + strlen(extraction_pattern); + + va_start(ap, extraction_pattern); + + while (p <= end) { + switch (state) { + case read_key: + if (*p == '=') { + if (key == NULL) { + g_set_error(err, lua_error_quark(), 1, "cannot read key"); + va_end(ap); + + return FALSE; + } + + state = read_arg; + keylen = p - key; + } + else if (*p == '*' && key == NULL) { + required = TRUE; + } + else if (key == NULL) { + key = p; + } + p++; + break; + case read_arg: + g_assert(keylen != 0); + + if (is_table) { + lua_pushlstring(L, key, keylen); + lua_gettable(L, pos); + idx = -1; + } + + t = lua_type(L, idx); + + switch (*p) { + case 'S': + if (t == LUA_TSTRING) { + *(va_arg(ap, const gchar **)) = lua_tostring(L, idx); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, const gchar **)) = NULL; + } + else { + (void) va_arg(ap, gchar **); + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), "string"); + va_end(ap); + + return FALSE; + } + + if (is_table) { + lua_pop(L, 1); + } + break; + + case 'I': + if (t == LUA_TNUMBER) { + *(va_arg(ap, gint64 *)) = lua_tointeger(L, idx); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, gint64 *)) = 0; + } + else { + (void) va_arg(ap, gint64 *); + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "int64"); + va_end(ap); + + return FALSE; + } + if (is_table) { + lua_pop(L, 1); + } + break; + + case 'i': + if (t == LUA_TNUMBER) { + *(va_arg(ap, gint32 *)) = lua_tointeger(L, idx); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, gint32 *)) = 0; + } + else { + (void) va_arg(ap, gint32 *); + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "int64"); + va_end(ap); + + return FALSE; + } + if (is_table) { + lua_pop(L, 1); + } + break; + + case 'F': + if (t == LUA_TFUNCTION) { + if (!is_table) { + lua_pushvalue(L, idx); + } + + *(va_arg(ap, gint *)) = luaL_ref(L, LUA_REGISTRYINDEX); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, gint *)) = -1; + } + else { + (void) va_arg(ap, gint *); + } + + if (is_table) { + lua_pop(L, 1); + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "function"); + va_end(ap); + if (is_table) { + lua_pop(L, 1); + } + + return FALSE; + } + + /* luaL_ref pops argument from the stack */ + break; + + case 'B': + if (t == LUA_TBOOLEAN) { + *(va_arg(ap, gboolean *)) = lua_toboolean(L, idx); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, gboolean *)) = 0; + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "bool"); + va_end(ap); + + return FALSE; + } + + if (is_table) { + lua_pop(L, 1); + } + break; + + case 'N': + if (t == LUA_TNUMBER) { + *(va_arg(ap, gdouble *)) = lua_tonumber(L, idx); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, gdouble *)) = 0; + } + else { + (void) va_arg(ap, gdouble *); + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "double"); + va_end(ap); + + return FALSE; + } + + if (is_table) { + lua_pop(L, 1); + } + break; + + case 'D': + if (t == LUA_TNUMBER) { + *(va_arg(ap, gdouble *)) = lua_tonumber(L, idx); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, gdouble *)) = NAN; + } + else { + (void) va_arg(ap, gdouble *); + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "double"); + va_end(ap); + + return FALSE; + } + + if (is_table) { + lua_pop(L, 1); + } + break; + + case 'V': + valuelen = va_arg(ap, gsize *); + + if (t == LUA_TSTRING) { + *(va_arg(ap, const gchar **)) = lua_tolstring(L, idx, + valuelen); + } + else if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, const char **)) = NULL; + *valuelen = 0; + } + else { + (void) va_arg(ap, const char **); + } + } + else { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "string"); + va_end(ap); + + return FALSE; + } + + if (is_table) { + lua_pop(L, 1); + } + break; + case 'O': + if (t != LUA_TNONE) { + *(va_arg(ap, ucl_object_t **)) = ucl_object_lua_import(L, + idx); + } + else { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, ucl_object_t **)) = NULL; + } + else { + (void) va_arg(ap, ucl_object_t **); + } + } + + if (is_table) { + lua_pop(L, 1); + } + break; + case 'U': + if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, void **)) = NULL; + } + else { + (void) va_arg(ap, void **); + } + } + else if (t != LUA_TUSERDATA) { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "int64"); + va_end(ap); + + return FALSE; + } + + state = read_class_start; + clslen = 0; + direct_userdata = 0; + cls = NULL; + p++; + continue; + case 'u': + if (t == LUA_TNIL || t == LUA_TNONE) { + failed = TRUE; + + if (how != RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING) { + *(va_arg(ap, void **)) = NULL; + } + else { + (void) va_arg(ap, void **); + } + } + else if (t != LUA_TUSERDATA) { + g_set_error(err, + lua_error_quark(), + 1, + "bad type for key:" + " %.*s: '%s', '%s' is expected", + (gint) keylen, + key, + lua_typename(L, lua_type(L, idx)), + "int64"); + va_end(ap); + + return FALSE; + } + + state = read_class_start; + clslen = 0; + direct_userdata = 1; + cls = NULL; + p++; + continue; + default: + g_assert(0); + break; + } + + if (failed && required) { + g_set_error(err, lua_error_quark(), 2, "required parameter " + "%.*s is missing", + (gint) keylen, key); + va_end(ap); + + return FALSE; + } + + if (!is_table) { + idx++; + } + + /* Reset read params */ + state = read_semicolon; + failed = FALSE; + required = FALSE; + keylen = 0; + key = NULL; + p++; + break; + + case read_class_start: + if (*p == '{') { + cls = p + 1; + state = read_class; + } + else { + if (is_table) { + lua_pop(L, 1); + } + + g_set_error(err, lua_error_quark(), 2, "missing classname for " + "%.*s", + (gint) keylen, key); + va_end(ap); + + return FALSE; + } + p++; + break; + + case read_class: + if (*p == '}') { + clslen = p - cls; + if (clslen == 0) { + if (is_table) { + lua_pop(L, 1); + } + + g_set_error(err, + lua_error_quark(), + 2, + "empty classname for " + "%*.s", + (gint) keylen, + key); + va_end(ap); + + return FALSE; + } + + rspamd_snprintf(classbuf, sizeof(classbuf), "rspamd{%*s}", + (gint) clslen, cls); + + + /* + * We skip class check here for speed in non-table mode + */ + if (!failed && (!is_table || + rspamd_lua_check_class(L, idx, classbuf))) { + if (direct_userdata) { + void **arg_p = (va_arg(ap, void **)); + *arg_p = lua_touserdata(L, idx); + } + else { + *(va_arg(ap, + void **)) = *(void **) lua_touserdata(L, idx); + } + } + else { + if (!failed) { + g_set_error(err, + lua_error_quark(), + 2, + "invalid class for key %.*s, expected %s, got %s", + (gint) keylen, + key, + classbuf, + rspamd_lua_class_tostring_buf(L, FALSE, idx)); + va_end(ap); + + return FALSE; + } + } + + if (is_table) { + lua_pop(L, 1); + } + else { + idx++; + } + + if (failed && required) { + g_set_error(err, + lua_error_quark(), + 2, + "required parameter " + "%.*s is missing", + (gint) keylen, + key); + va_end(ap); + + return FALSE; + } + + /* Reset read params */ + state = read_semicolon; + failed = FALSE; + required = FALSE; + keylen = 0; + key = NULL; + } + p++; + break; + + case read_semicolon: + if (*p == ';' || p == end) { + state = read_key; + key = NULL; + keylen = 0; + failed = FALSE; + } + else { + g_set_error(err, lua_error_quark(), 2, "bad format string: %s," + " at char %c, position %d", + extraction_pattern, *p, (int) (p - extraction_pattern)); + va_end(ap); + + return FALSE; + } + + p++; + break; + } + } + + va_end(ap); + + return TRUE; +} + +static void +rspamd_lua_traceback_string(lua_State *L, luaL_Buffer *buf) +{ + gint i = 1, r; + lua_Debug d; + gchar tmp[256]; + + while (lua_getstack(L, i++, &d)) { + lua_getinfo(L, "nSl", &d); + r = rspamd_snprintf(tmp, sizeof(tmp), " [%d]:{%s:%d - %s [%s]};", + i - 1, d.short_src, d.currentline, + (d.name ? d.name : "<unknown>"), d.what); + luaL_addlstring(buf, tmp, r); + } +} + +gint rspamd_lua_traceback(lua_State *L) +{ + luaL_Buffer b; + + luaL_buffinit(L, &b); + rspamd_lua_get_traceback_string(L, &b); + luaL_pushresult(&b); + + return 1; +} + +void rspamd_lua_get_traceback_string(lua_State *L, luaL_Buffer *buf) +{ + const gchar *msg = lua_tostring(L, -1); + + if (msg) { + luaL_addstring(buf, msg); + lua_pop(L, 1); /* Error string */ + } + else { + luaL_addstring(buf, "unknown error"); + } + + luaL_addstring(buf, "; trace:"); + rspamd_lua_traceback_string(L, buf); +} + +guint rspamd_lua_table_size(lua_State *L, gint tbl_pos) +{ + guint tbl_size = 0; + + if (!lua_istable(L, tbl_pos)) { + return 0; + } + +#if LUA_VERSION_NUM >= 502 + tbl_size = lua_rawlen(L, tbl_pos); +#else + tbl_size = lua_objlen(L, tbl_pos); +#endif + + return tbl_size; +} + +static void * +rspamd_lua_check_udata_common(lua_State *L, gint pos, const gchar *classname, + gboolean fatal) +{ + void *p = lua_touserdata(L, pos); + guint i, top = lua_gettop(L); + khiter_t k; + + if (p == NULL) { + goto err; + } + else { + /* Match class */ + if (lua_getmetatable(L, pos)) { + struct rspamd_lua_context *ctx = rspamd_lua_ctx_by_state(L); + + k = kh_get(lua_class_set, ctx->classes, classname); + + if (k == kh_end(ctx->classes)) { + goto err; + } + + lua_rawgeti(L, LUA_REGISTRYINDEX, kh_value(ctx->classes, k)); + + if (!lua_rawequal(L, -1, -2)) { + goto err; + } + } + else { + goto err; + } + } + + lua_settop(L, top); + + return p; + +err: + if (fatal) { + const gchar *actual_classname = NULL; + + if (lua_type(L, pos) == LUA_TUSERDATA && lua_getmetatable(L, pos)) { + lua_pushstring(L, "__index"); + lua_gettable(L, -2); + lua_pushstring(L, "class"); + lua_gettable(L, -2); + actual_classname = lua_tostring(L, -1); + } + else { + actual_classname = lua_typename(L, lua_type(L, pos)); + } + + luaL_Buffer buf; + gchar tmp[512]; + gint r; + + luaL_buffinit(L, &buf); + r = rspamd_snprintf(tmp, sizeof(tmp), + "expected %s at position %d, but userdata has " + "%s metatable; trace: ", + classname, pos, actual_classname); + luaL_addlstring(&buf, tmp, r); + rspamd_lua_traceback_string(L, &buf); + r = rspamd_snprintf(tmp, sizeof(tmp), " stack(%d): ", top); + luaL_addlstring(&buf, tmp, r); + + for (i = 1; i <= MIN(top, 10); i++) { + if (lua_type(L, i) == LUA_TUSERDATA) { + const char *clsname; + + if (lua_getmetatable(L, i)) { + lua_pushstring(L, "__index"); + lua_gettable(L, -2); + lua_pushstring(L, "class"); + lua_gettable(L, -2); + clsname = lua_tostring(L, -1); + } + else { + clsname = lua_typename(L, lua_type(L, i)); + } + + r = rspamd_snprintf(tmp, sizeof(tmp), "[%d: ud=%s] ", i, + clsname); + luaL_addlstring(&buf, tmp, r); + } + else { + r = rspamd_snprintf(tmp, sizeof(tmp), "[%d: %s] ", i, + lua_typename(L, lua_type(L, i))); + luaL_addlstring(&buf, tmp, r); + } + } + + luaL_pushresult(&buf); + msg_err("lua type error: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + + return NULL; +} + +void * +rspamd_lua_check_udata(lua_State *L, gint pos, const gchar *classname) +{ + return rspamd_lua_check_udata_common(L, pos, classname, TRUE); +} + +void * +rspamd_lua_check_udata_maybe(lua_State *L, gint pos, const gchar *classname) +{ + return rspamd_lua_check_udata_common(L, pos, classname, FALSE); +} + +struct rspamd_async_session * +lua_check_session(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{session}"); + luaL_argcheck(L, ud != NULL, pos, "'session' expected"); + return ud ? *((struct rspamd_async_session **) ud) : NULL; +} + +struct ev_loop * +lua_check_ev_base(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{ev_base}"); + luaL_argcheck(L, ud != NULL, pos, "'event_base' expected"); + return ud ? *((struct ev_loop **) ud) : NULL; +} + +static void rspamd_lua_run_postloads_error(struct thread_entry *thread, int ret, const char *msg); + +void rspamd_lua_run_postloads(lua_State *L, struct rspamd_config *cfg, + struct ev_loop *ev_base, struct rspamd_worker *w) +{ + struct rspamd_config_cfg_lua_script *sc; + struct rspamd_config **pcfg; + struct ev_loop **pev_base; + struct rspamd_worker **pw; + + /* Execute post load scripts */ + LL_FOREACH(cfg->on_load_scripts, sc) + { + struct thread_entry *thread = lua_thread_pool_get_for_config(cfg); + thread->error_callback = rspamd_lua_run_postloads_error; + L = thread->lua_state; + + lua_rawgeti(L, LUA_REGISTRYINDEX, sc->cbref); + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + *pcfg = cfg; + rspamd_lua_setclass(L, "rspamd{config}", -1); + + pev_base = lua_newuserdata(L, sizeof(*pev_base)); + *pev_base = ev_base; + rspamd_lua_setclass(L, "rspamd{ev_base}", -1); + + pw = lua_newuserdata(L, sizeof(*pw)); + *pw = w; + rspamd_lua_setclass(L, "rspamd{worker}", -1); + + lua_thread_call(thread, 3); + } +} + + +void rspamd_lua_run_config_post_init(lua_State *L, struct rspamd_config *cfg) +{ + struct rspamd_config_cfg_lua_script *sc; + struct rspamd_config **pcfg; + + LL_FOREACH(cfg->post_init_scripts, sc) + { + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, sc->cbref); + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + *pcfg = cfg; + rspamd_lua_setclass(L, "rspamd{config}", -1); + + if (lua_pcall(L, 1, 0, err_idx) != 0) { + msg_err_config("cannot run config post init script: %s; priority = %d", + lua_tostring(L, -1), sc->priority); + } + + lua_settop(L, err_idx - 1); + } +} + + +void rspamd_lua_run_config_unload(lua_State *L, struct rspamd_config *cfg) +{ + struct rspamd_config_cfg_lua_script *sc; + struct rspamd_config **pcfg; + + LL_FOREACH(cfg->config_unload_scripts, sc) + { + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, sc->cbref); + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + *pcfg = cfg; + rspamd_lua_setclass(L, "rspamd{config}", -1); + + if (lua_pcall(L, 1, 0, err_idx) != 0) { + msg_err_config("cannot run config post init script: %s", + lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); + } +} + +static void +rspamd_lua_run_postloads_error(struct thread_entry *thread, int ret, const char *msg) +{ + struct rspamd_config *cfg = thread->cfg; + + msg_err_config("error executing post load code: %s", msg); +} + + +struct rspamd_lua_ref_cbdata { + lua_State *L; + gint cbref; +}; + +static void +rspamd_lua_ref_dtor(gpointer p) +{ + struct rspamd_lua_ref_cbdata *cbdata = p; + + luaL_unref(cbdata->L, LUA_REGISTRYINDEX, cbdata->cbref); +} + +void rspamd_lua_add_ref_dtor(lua_State *L, rspamd_mempool_t *pool, + gint ref) +{ + struct rspamd_lua_ref_cbdata *cbdata; + + if (ref != -1) { + cbdata = rspamd_mempool_alloc(pool, sizeof(*cbdata)); + cbdata->cbref = ref; + cbdata->L = L; + + rspamd_mempool_add_destructor(pool, rspamd_lua_ref_dtor, cbdata); + } +} + +gboolean +rspamd_lua_require_function(lua_State *L, const gchar *modname, + const gchar *funcname) +{ + gint table_pos, err_pos; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_pos = lua_gettop(L); + lua_getglobal(L, "require"); + + if (lua_isnil(L, -1)) { + lua_remove(L, err_pos); + lua_pop(L, 1); + + return FALSE; + } + + lua_pushstring(L, modname); + + /* Now try to call */ + if (lua_pcall(L, 1, 1, 0) != 0) { + lua_remove(L, err_pos); + msg_warn("require of %s.%s failed: %s", modname, + funcname, lua_tostring(L, -1)); + lua_pop(L, 1); + + return FALSE; + } + + lua_remove(L, err_pos); + + /* Now we should have a table with results */ + if (funcname) { + if (!lua_istable(L, -1)) { + msg_warn("require of %s.%s failed: not a table but %s", modname, + funcname, lua_typename(L, lua_type(L, -1))); + + lua_pop(L, 1); + + return FALSE; + } + + table_pos = lua_gettop(L); + lua_pushstring(L, funcname); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TFUNCTION) { + /* Remove table, preserve just a function */ + lua_remove(L, table_pos); + + return TRUE; + } + else { + msg_warn("require of %s.%s failed: not a function but %s", modname, + funcname, lua_typename(L, lua_type(L, -1))); + } + + lua_pop(L, 2); + + return FALSE; + } + else if (lua_isfunction(L, -1)) { + return TRUE; + } + else { + msg_warn("require of %s failed: not a function but %s", modname, + lua_typename(L, lua_type(L, -1))); + lua_pop(L, 1); + + return FALSE; + } +} + +gint rspamd_lua_function_ref_from_str(lua_State *L, const gchar *str, gsize slen, + const gchar *modname, GError **err) +{ + gint err_idx, ref_idx; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + /* Load file */ + if (luaL_loadbuffer(L, str, slen, modname) != 0) { + g_set_error(err, + lua_error_quark(), + EINVAL, + "%s: cannot load lua script: %s", + modname, + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); /* Error function */ + + return LUA_NOREF; + } + + /* Now call it */ + if (lua_pcall(L, 0, 1, err_idx) != 0) { + g_set_error(err, + lua_error_quark(), + EINVAL, + "%s: cannot init lua script: %s", + modname, + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + + return LUA_NOREF; + } + + if (!lua_isfunction(L, -1)) { + g_set_error(err, + lua_error_quark(), + EINVAL, + "%s: cannot init lua script: " + "must return function not %s", + modname, + lua_typename(L, lua_type(L, -1))); + lua_settop(L, err_idx - 1); + + return LUA_NOREF; + } + + ref_idx = luaL_ref(L, LUA_REGISTRYINDEX); + lua_settop(L, err_idx - 1); + + return ref_idx; +} + + +gboolean +rspamd_lua_try_load_redis(lua_State *L, const ucl_object_t *obj, + struct rspamd_config *cfg, gint *ref_id) +{ + gint err_idx; + struct rspamd_config **pcfg; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + /* Obtain function */ + if (!rspamd_lua_require_function(L, "lua_redis", "try_load_redis_servers")) { + msg_err_config("cannot require lua_redis"); + lua_pop(L, 2); + + return FALSE; + } + + /* Function arguments */ + ucl_object_push_lua(L, obj, false); + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + *pcfg = cfg; + lua_pushboolean(L, false); /* no_fallback */ + + if (lua_pcall(L, 3, 1, err_idx) != 0) { + msg_err_config("cannot call lua try_load_redis_servers script: %s", + lua_tostring(L, -1)); + lua_settop(L, 0); + + return FALSE; + } + + if (lua_istable(L, -1)) { + if (ref_id) { + /* Ref table */ + lua_pushvalue(L, -1); + *ref_id = luaL_ref(L, LUA_REGISTRYINDEX); + lua_settop(L, 0); + } + else { + /* Leave it on the stack */ + lua_insert(L, err_idx); + lua_settop(L, err_idx); + } + + return TRUE; + } + else { + lua_settop(L, 0); + } + + return FALSE; +} + +void rspamd_lua_push_full_word(lua_State *L, rspamd_stat_token_t *w) +{ + gint fl_cnt; + + lua_createtable(L, 4, 0); + + if (w->stemmed.len > 0) { + lua_pushlstring(L, w->stemmed.begin, w->stemmed.len); + lua_rawseti(L, -2, 1); + } + else { + lua_pushstring(L, ""); + lua_rawseti(L, -2, 1); + } + + if (w->normalized.len > 0) { + lua_pushlstring(L, w->normalized.begin, w->normalized.len); + lua_rawseti(L, -2, 2); + } + else { + lua_pushstring(L, ""); + lua_rawseti(L, -2, 2); + } + + if (w->original.len > 0) { + lua_pushlstring(L, w->original.begin, w->original.len); + lua_rawseti(L, -2, 3); + } + else { + lua_pushstring(L, ""); + lua_rawseti(L, -2, 3); + } + + /* Flags part */ + fl_cnt = 1; + lua_createtable(L, 4, 0); + + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_NORMALISED) { + lua_pushstring(L, "normalised"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_BROKEN_UNICODE) { + lua_pushstring(L, "broken_unicode"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_UTF) { + lua_pushstring(L, "utf"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + lua_pushstring(L, "text"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_HEADER) { + lua_pushstring(L, "header"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & (RSPAMD_STAT_TOKEN_FLAG_META | RSPAMD_STAT_TOKEN_FLAG_LUA_META)) { + lua_pushstring(L, "meta"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_STOP_WORD) { + lua_pushstring(L, "stop_word"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_INVISIBLE_SPACES) { + lua_pushstring(L, "invisible_spaces"); + lua_rawseti(L, -2, fl_cnt++); + } + if (w->flags & RSPAMD_STAT_TOKEN_FLAG_STEMMED) { + lua_pushstring(L, "stemmed"); + lua_rawseti(L, -2, fl_cnt++); + } + + lua_rawseti(L, -2, 4); +} + +gint rspamd_lua_push_words(lua_State *L, GArray *words, + enum rspamd_lua_words_type how) +{ + rspamd_stat_token_t *w; + guint i, cnt; + + lua_createtable(L, words->len, 0); + + for (i = 0, cnt = 1; i < words->len; i++) { + w = &g_array_index(words, rspamd_stat_token_t, i); + + switch (how) { + case RSPAMD_LUA_WORDS_STEM: + if (w->stemmed.len > 0) { + lua_pushlstring(L, w->stemmed.begin, w->stemmed.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_NORM: + if (w->normalized.len > 0) { + lua_pushlstring(L, w->normalized.begin, w->normalized.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_RAW: + if (w->original.len > 0) { + lua_pushlstring(L, w->original.begin, w->original.len); + lua_rawseti(L, -2, cnt++); + } + break; + case RSPAMD_LUA_WORDS_FULL: + rspamd_lua_push_full_word(L, w); + /* Push to the resulting vector */ + lua_rawseti(L, -2, cnt++); + break; + default: + break; + } + } + + return 1; +} + +gchar * +rspamd_lua_get_module_name(lua_State *L) +{ + lua_Debug d; + gchar *p; + gchar func_buf[128]; + + if (lua_getstack(L, 1, &d) == 1) { + (void) lua_getinfo(L, "Sl", &d); + if ((p = strrchr(d.short_src, '/')) == NULL) { + p = d.short_src; + } + else { + p++; + } + + if (strlen(p) > 20) { + rspamd_snprintf(func_buf, sizeof(func_buf), "%10s...]:%d", p, + d.currentline); + } + else { + rspamd_snprintf(func_buf, sizeof(func_buf), "%s:%d", p, + d.currentline); + } + + return g_strdup(func_buf); + } + + return NULL; +} + +bool rspamd_lua_universal_pcall(lua_State *L, gint cbref, const gchar *strloc, + gint nret, const gchar *args, GError **err, ...) +{ + va_list ap; + const gchar *argp = args, *classname; + gint err_idx, nargs = 0; + gpointer *cls_ptr; + gsize sz; + + /* Error function */ + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + va_start(ap, err); + /* Called function */ + if (cbref > 0) { + lua_rawgeti(L, LUA_REGISTRYINDEX, cbref); + } + else { + /* Assume that function was on top of the stack */ + lua_pushvalue(L, err_idx - 1); + } + /* + * Possible arguments + * - i - lua_integer, argument - gint64 + * - n - lua_number, argument - gdouble + * - s - lua_string, argument - const gchar * (zero terminated) + * - l - lua_lstring, argument - (size_t + const gchar *) pair + * - u - lua_userdata, argument - (const char * + void *) - classname + pointer + * - b - lua_boolean, argument - gboolean (not bool due to varargs promotion) + * - f - lua_function, argument - int - position of the function on stack (not lua_registry) + * - t - lua_text, argument - int - position of the lua_text on stack (not lua_registry) + */ + while (*argp) { + switch (*argp) { + case 'i': + lua_pushinteger(L, va_arg(ap, gint64)); + nargs++; + break; + case 'n': + lua_pushnumber(L, va_arg(ap, gdouble)); + nargs++; + break; + case 's': + lua_pushstring(L, va_arg(ap, const gchar *)); + nargs++; + break; + case 'l': + sz = va_arg(ap, gsize); + lua_pushlstring(L, va_arg(ap, const gchar *), sz); + nargs++; + break; + case 'b': + lua_pushboolean(L, va_arg(ap, gboolean)); + nargs++; + break; + case 'u': + classname = va_arg(ap, const gchar *); + cls_ptr = (gpointer *) lua_newuserdata(L, sizeof(gpointer)); + *cls_ptr = va_arg(ap, gpointer); + rspamd_lua_setclass(L, classname, -1); + nargs++; + break; + case 'f': + case 't': + lua_pushvalue(L, va_arg(ap, gint)); + nargs++; + break; + default: + lua_settop(L, err_idx - 1); + g_set_error(err, lua_error_quark(), EINVAL, + "invalid argument character: %c at %s", + *argp, argp); + va_end(ap); + + return false; + } + + argp++; + } + + if (lua_pcall(L, nargs, nret, err_idx) != 0) { + g_set_error(err, lua_error_quark(), EBADF, + "error when calling lua function from %s: %s", + strloc, lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + va_end(ap); + + return false; + } + + lua_remove(L, err_idx); + va_end(ap); + + return true; +} + +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM <= 502 +gint rspamd_lua_geti(lua_State *L, int pos, int i) +{ + pos = lua_absindex(L, pos); + lua_pushinteger(L, i); + lua_gettable(L, pos); + + return lua_type(L, -1); +} +#endif
\ No newline at end of file diff --git a/src/lua/lua_common.h b/src/lua/lua_common.h new file mode 100644 index 0000000..cc2b943 --- /dev/null +++ b/src/lua/lua_common.h @@ -0,0 +1,729 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSPAMD_LUA_H +#define RSPAMD_LUA_H + +#include "config.h" + +/* Lua headers do not have __cplusplus guards... */ +#ifdef __cplusplus +extern "C" { +#endif + +#include <lua.h> +#include <lauxlib.h> +#include <lualib.h> +#ifdef WITH_LUAJIT +#include <luajit.h> +#endif + +#ifdef __cplusplus +} +#endif +#include <stdbool.h> + + +#include "rspamd.h" +#include "ucl.h" +#include "lua_ucl.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef lua_open +#define lua_open() luaL_newstate() +#endif + +#ifndef luaL_reg +#define luaL_reg luaL_Reg +#endif + +#define LUA_ENUM(L, name, val) \ + lua_pushlstring(L, #name, sizeof(#name) - 1); \ + lua_pushinteger(L, val); \ + lua_settable(L, -3); + +#if LUA_VERSION_NUM > 501 && !defined LUA_COMPAT_MODULE +static inline void +luaL_register(lua_State *L, const gchar *name, const struct luaL_reg *methods) +{ + if (name != NULL) { + lua_newtable(L); + } + luaL_setfuncs(L, methods, 0); + if (name != NULL) { + lua_pushvalue(L, -1); + lua_setglobal(L, name); + } +} +#endif + +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM == 501 + +/* Special hack to work with moonjit of specific version */ +#if !defined(MOONJIT_VERSION) && (!defined(LUAJIT_VERSION_NUM) || LUAJIT_VERSION_NUM != 20200) +static inline int lua_absindex(lua_State *L, int i) +{ + if (i < 0 && i > LUA_REGISTRYINDEX) + i += lua_gettop(L) + 1; + return i; +} +#endif + +#endif + +/* Interface definitions */ +#define LUA_FUNCTION_DEF(class, name) static int lua_##class##_##name(lua_State *L) +#define LUA_PUBLIC_FUNCTION_DEF(class, name) int lua_##class##_##name(lua_State *L) +#define LUA_INTERFACE_DEF(class, name) \ + { \ + #name, lua_##class##_##name \ + } + +extern const luaL_reg null_reg[]; + +#define RSPAMD_LUA_CFG_STATE(cfg) ((lua_State *) ((cfg)->lua_state)) +/** +* Lua IP address structure +*/ +struct rspamd_lua_ip { + rspamd_inet_addr_t *addr; +}; + +#define RSPAMD_TEXT_FLAG_OWN (1u << 0u) +#define RSPAMD_TEXT_FLAG_MMAPED (1u << 1u) +#define RSPAMD_TEXT_FLAG_WIPE (1u << 2u) +#define RSPAMD_TEXT_FLAG_SYSMALLOC (1u << 3u) +#define RSPAMD_TEXT_FLAG_FAKE (1u << 4u) +#define RSPAMD_TEXT_FLAG_BINARY (1u << 5u) +struct rspamd_lua_text { + const gchar *start; + guint len; + guint flags; +}; + +struct rspamd_lua_url { + struct rspamd_url *url; +}; + +struct rspamd_lua_regexp { + rspamd_regexp_t *re; + gchar *module; + gchar *re_pattern; + gint re_flags; +}; + +struct rspamd_map; +struct lua_map_callback_data; +struct radix_tree_compressed; +struct rspamd_mime_header; + +enum rspamd_lua_map_type { + RSPAMD_LUA_MAP_RADIX = 0, + RSPAMD_LUA_MAP_SET, + RSPAMD_LUA_MAP_HASH, + RSPAMD_LUA_MAP_REGEXP, + RSPAMD_LUA_MAP_REGEXP_MULTIPLE, + RSPAMD_LUA_MAP_CALLBACK, + RSPAMD_LUA_MAP_CDB, + RSPAMD_LUA_MAP_UNKNOWN, +}; + +struct rspamd_lua_map { + struct rspamd_map *map; + enum rspamd_lua_map_type type; + guint flags; + + union { + struct rspamd_radix_map_helper *radix; + struct rspamd_hash_map_helper *hash; + struct rspamd_regexp_map_helper *re_map; + struct rspamd_cdb_map_helper *cdb_map; + struct lua_map_callback_data *cbdata; + } data; +}; + +struct rspamd_lua_upstream { + struct upstream *up; + gint upref; +}; + +/* Common utility functions */ + +/** +* Create and register new class +*/ +void rspamd_lua_new_class(lua_State *L, + const gchar *classname, + const struct luaL_reg *methods); + +/** +* Set class name for object at @param objidx position +*/ +void rspamd_lua_setclass(lua_State *L, const gchar *classname, gint objidx); + +/** +* Pushes the metatable for specific class on top of the stack +* @param L +* @param classname +*/ +void rspamd_lua_class_metatable(lua_State *L, const gchar *classname); + +/** +* Adds a new field to the class (metatable) identified by `classname` +* @param L +* @param classname +* @param meth +*/ +void rspamd_lua_add_metamethod(lua_State *L, const gchar *classname, + luaL_Reg *meth); + +/** +* Set index of table to value (like t['index'] = value) +*/ +void rspamd_lua_table_set(lua_State *L, const gchar *index, const gchar *value); + +/** +* Get string value of index in a table (return t['index']) +*/ +const gchar *rspamd_lua_table_get(lua_State *L, const gchar *index); + +/** +* Convert classname to string +*/ +gint rspamd_lua_class_tostring(lua_State *L); + +/** +* Check whether the argument at specified index is of the specified class +*/ +gpointer rspamd_lua_check_class(lua_State *L, gint index, const gchar *name); + +/** +* Initialize lua and bindings +*/ +lua_State *rspamd_lua_init(bool wipe_mem); + +/** + * Close lua_state and free remainders + * @param L + */ +void rspamd_lua_close(lua_State *L); + +void rspamd_lua_start_gc(struct rspamd_config *cfg); + +/** +* Sets field in a global variable +* @param L +* @param global_name +* @param field_name +* @param new_elt +*/ +void rspamd_plugins_table_push_elt(lua_State *L, const gchar *field_name, + const gchar *new_elt); + +/** +* Load and initialize lua plugins +*/ +gboolean +rspamd_init_lua_filters(struct rspamd_config *cfg, bool force_load, bool strict); + + +/** +* Push lua ip address +*/ +void rspamd_lua_ip_push(lua_State *L, rspamd_inet_addr_t *addr); + +/** +* Push rspamd task structure to lua +*/ +void rspamd_lua_task_push(lua_State *L, struct rspamd_task *task); + +/** +* Return lua ip structure at the specified address +*/ +struct rspamd_lua_ip *lua_check_ip(lua_State *L, gint pos); + +struct rspamd_lua_text *lua_check_text(lua_State *L, gint pos); +/** +* Checks for a text or a string. In case of string a pointer to static structure is returned. +* So it should not be reused or placed to Lua stack anyhow! +* However, you can use this function up to 4 times and have distinct static structures +* @param L +* @param pos +* @return +*/ +struct rspamd_lua_text *lua_check_text_or_string(lua_State *L, gint pos); +/** + * Create new text object + * @param L + * @param start + * @param len + * @param own + * @return + */ +struct rspamd_lua_text *lua_new_text(lua_State *L, const gchar *start, + gsize len, gboolean own); +/** + * Create new text object from task pool if allocation is needed + * @param task + * @param L + * @param start + * @param len + * @param own + * @return + */ +struct rspamd_lua_text *lua_new_text_task(lua_State *L, struct rspamd_task *task, + const gchar *start, gsize len, gboolean own); +/** + * Checks if a text has binary characters (non ascii and non-utf8 characters) + * @param t + * @return + */ +bool lua_is_text_binary(struct rspamd_lua_text *t); + +struct rspamd_lua_regexp *lua_check_regexp(lua_State *L, gint pos); + +struct rspamd_lua_upstream *lua_check_upstream(lua_State *L, int pos); + +enum rspamd_lua_task_header_type { + RSPAMD_TASK_HEADER_PUSH_SIMPLE = 0, + RSPAMD_TASK_HEADER_PUSH_RAW, + RSPAMD_TASK_HEADER_PUSH_FULL, + RSPAMD_TASK_HEADER_PUSH_COUNT, + RSPAMD_TASK_HEADER_PUSH_HAS, +}; + +gint rspamd_lua_push_header(lua_State *L, + struct rspamd_mime_header *h, + enum rspamd_lua_task_header_type how); + +/** +* Push specific header to lua +*/ +gint rspamd_lua_push_header_array(lua_State *L, + const gchar *name, + struct rspamd_mime_header *rh, + enum rspamd_lua_task_header_type how, + gboolean strong); + +/** +* Check for task at the specified position +*/ +struct rspamd_task *lua_check_task(lua_State *L, gint pos); + +struct rspamd_task *lua_check_task_maybe(lua_State *L, gint pos); + +struct rspamd_lua_map *lua_check_map(lua_State *L, gint pos); + +/** +* Push ip address from a string (nil is pushed if a string cannot be converted) +*/ +void rspamd_lua_ip_push_fromstring(lua_State *L, const gchar *ip_str); + +/** +* Create type error +*/ +int rspamd_lua_typerror(lua_State *L, int narg, const char *tname); +/** +* Open libraries functions +*/ + +/** +* Add preload function +*/ +void rspamd_lua_add_preload(lua_State *L, const gchar *name, lua_CFunction func); + +void luaopen_task(lua_State *L); + +void luaopen_config(lua_State *L); + +void luaopen_map(lua_State *L); + +void luaopen_trie(lua_State *L); + +void luaopen_textpart(lua_State *L); + +void luaopen_mimepart(lua_State *L); + +void luaopen_image(lua_State *L); + +void luaopen_url(lua_State *L); + +void luaopen_classifier(lua_State *L); + +void luaopen_statfile(lua_State *L); + +void luaopen_regexp(lua_State *L); + +void luaopen_cdb(lua_State *L); + +void luaopen_xmlrpc(lua_State *L); + +void luaopen_http(lua_State *L); + +void luaopen_redis(lua_State *L); + +void luaopen_upstream(lua_State *L); + +void luaopen_mempool(lua_State *L); + +void luaopen_dns_resolver(lua_State *L); + +void luaopen_rsa(lua_State *L); + +void luaopen_ip(lua_State *L); + +void luaopen_expression(lua_State *L); + +void luaopen_logger(lua_State *L); + +void luaopen_text(lua_State *L); + +void luaopen_util(lua_State *L); + +void luaopen_tcp(lua_State *L); + +void luaopen_html(lua_State *L); + +void luaopen_sqlite3(lua_State *L); + +void luaopen_cryptobox(lua_State *L); + +void luaopen_dns(lua_State *L); + +void luaopen_udp(lua_State *L); + +void luaopen_worker(lua_State *L); + +void luaopen_kann(lua_State *L); + +void luaopen_spf(lua_State *L); + +void luaopen_tensor(lua_State *L); + +void luaopen_parsers(lua_State *L); + +void rspamd_lua_dostring(const gchar *line); + +double rspamd_lua_normalize(struct rspamd_config *cfg, + long double score, + void *params); + +/* Config file functions */ +void rspamd_lua_post_load_config(struct rspamd_config *cfg); + +void rspamd_lua_dumpstack(lua_State *L); + +/* Set lua path according to the configuration */ +void rspamd_lua_set_path(lua_State *L, const ucl_object_t *cfg_obj, + GHashTable *vars); + +/* Set some lua globals */ +gboolean rspamd_lua_set_env(lua_State *L, GHashTable *vars, char **lua_env, + GError **err); + +void rspamd_lua_set_globals(struct rspamd_config *cfg, lua_State *L); + +struct memory_pool_s *rspamd_lua_check_mempool(lua_State *L, gint pos); + +struct rspamd_config *lua_check_config(lua_State *L, gint pos); + +struct rspamd_async_session *lua_check_session(lua_State *L, gint pos); + +struct ev_loop *lua_check_ev_base(lua_State *L, gint pos); + +struct rspamd_dns_resolver *lua_check_dns_resolver(lua_State *L, gint pos); + +struct rspamd_lua_url *lua_check_url(lua_State *L, gint pos); + +enum rspamd_lua_parse_arguments_flags { + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT = 0, + RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING, +}; + +/** +* Extract an arguments from lua table according to format string. Supported arguments are: +* [*]key=S|I|N|B|V|U{a-z};[key=...] +* - S - const char * +* - I - gint64_t +* - i - int32_t +* - N - double +* - B - gboolean +* - V - size_t + const char * +* - U{classname} - userdata of the following class (stored in gpointer) +* - F - function +* - O - ucl_object_t * +* - D - same as N but argument is set to NAN not to 0.0 +* - u{classname} - userdata of the following class (stored directly) +* +* If any of keys is prefixed with `*` then it is treated as required argument +* @param L lua state +* @param pos at which pos start extraction +* @param err error pointer +* @param how extraction type (IGNORE_MISSING means that default values will not be set) +* @param extraction_pattern static pattern +* @return TRUE if a table has been parsed +*/ +gboolean rspamd_lua_parse_table_arguments(lua_State *L, gint pos, + GError **err, + enum rspamd_lua_parse_arguments_flags how, + const gchar *extraction_pattern, ...); + + +gint rspamd_lua_traceback(lua_State *L); + +/** +* Returns stack trace as a string. Caller should clear memory. +* @param L +* @return +*/ +void rspamd_lua_get_traceback_string(lua_State *L, luaL_Buffer *buf); + +/** +* Returns size of table at position `tbl_pos` +*/ +guint rspamd_lua_table_size(lua_State *L, gint tbl_pos); + +void lua_push_emails_address_list(lua_State *L, GPtrArray *addrs, int flags); + + +#define TRACE_POINTS 6 + +struct lua_logger_trace { + gint cur_level; + gconstpointer traces[TRACE_POINTS]; +}; + +enum lua_logger_escape_type { + LUA_ESCAPE_NONE = (0u), + LUA_ESCAPE_UNPRINTABLE = (1u << 0u), + LUA_ESCAPE_NEWLINES = (1u << 1u), + LUA_ESCAPE_8BIT = (1u << 2u), +}; + +#define LUA_ESCAPE_LOG (LUA_ESCAPE_UNPRINTABLE | LUA_ESCAPE_NEWLINES) +#define LUA_ESCAPE_ALL (LUA_ESCAPE_UNPRINTABLE | LUA_ESCAPE_NEWLINES | LUA_ESCAPE_8BIT) + +/** +* Log lua object to string +* @param L +* @param pos +* @param outbuf +* @param len +* @return +*/ +gsize lua_logger_out_type(lua_State *L, gint pos, gchar *outbuf, + gsize len, struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type); + +/** +* Safely checks userdata to match specified class +* @param L +* @param pos +* @param classname +*/ +void *rspamd_lua_check_udata(lua_State *L, gint pos, const gchar *classname); + +#define RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, pos, classname, type, dest) \ + do { \ + type **_maybe_ptr = (type **) rspamd_lua_check_udata((L), (pos), (classname)); \ + if (_maybe_ptr == NULL) { \ + return luaL_error(L, "%s: invalid arguments; pos = %d; expected = %s", G_STRFUNC, (pos), (classname)); \ + } \ + (dest) = *(_maybe_ptr); \ + } while (0) + +/** +* Safely checks userdata to match specified class +* @param L +* @param pos +* @param classname +*/ +void *rspamd_lua_check_udata_maybe(lua_State *L, gint pos, const gchar *classname); + +/** +* Call finishing script with the specified task +* @param sc +* @param task +*/ +void lua_call_finish_script(struct rspamd_config_cfg_lua_script *sc, + struct rspamd_task *task); + +/** +* Run post-load operations +* @param L +* @param cfg +* @param ev_base +*/ +void rspamd_lua_run_postloads(lua_State *L, struct rspamd_config *cfg, + struct ev_loop *ev_base, struct rspamd_worker *w); + +void rspamd_lua_run_config_post_init(lua_State *L, struct rspamd_config *cfg); + +void rspamd_lua_run_config_unload(lua_State *L, struct rspamd_config *cfg); + +/** +* Adds new destructor for a local function for specific pool +* @param L +* @param pool +* @param ref +*/ +void rspamd_lua_add_ref_dtor(lua_State *L, rspamd_mempool_t *pool, + gint ref); + +/** + * Returns a lua reference from a function like string, e.g. `return function(...) end` + * @param L + * @param str + * @return + */ +gint rspamd_lua_function_ref_from_str(lua_State *L, const gchar *str, gsize slen, + const gchar *modname, GError **err); + +/** +* Tries to load some module using `require` and get some method from it +* @param L +* @param modname +* @param funcname +* @return TRUE if function exists in that module, the function is pushed in stack, otherwise stack is unchanged and FALSE is returned +*/ +gboolean rspamd_lua_require_function(lua_State *L, const gchar *modname, + const gchar *funcname); + +/** +* Tries to load redis server definition from ucl object specified +* @param L +* @param obj +* @param cfg +* @return +*/ +gboolean rspamd_lua_try_load_redis(lua_State *L, const ucl_object_t *obj, + struct rspamd_config *cfg, gint *ref_id); + +struct rspamd_stat_token_s; + +/** +* Pushes a single word into Lua +* @param L +* @param word +*/ +void rspamd_lua_push_full_word(lua_State *L, struct rspamd_stat_token_s *word); + +enum rspamd_lua_words_type { + RSPAMD_LUA_WORDS_STEM = 0, + RSPAMD_LUA_WORDS_NORM, + RSPAMD_LUA_WORDS_RAW, + RSPAMD_LUA_WORDS_FULL, + RSPAMD_LUA_WORDS_MAX +}; + +/** +* Pushes words (rspamd_stat_token_t) to Lua +* @param L +* @param words +* @param how +*/ +gint rspamd_lua_push_words(lua_State *L, GArray *words, + enum rspamd_lua_words_type how); + +/** +* Returns newly allocated name for caller module name +* @param L +* @return +*/ +gchar *rspamd_lua_get_module_name(lua_State *L); + +/** +* Call Lua function in a universal way. Arguments string: +* - i - lua_integer, argument - gint64 +* - n - lua_number, argument - gdouble +* - s - lua_string, argument - const gchar * (zero terminated) +* - l - lua_lstring, argument - (size_t + const gchar *) pair +* - u - lua_userdata, argument - (const char * + void *) - classname + pointer +* - b - lua_boolean, argument - gboolean (not bool due to varargs promotion) +* - f - lua_function, argument - int - position of the function on stack (not lua_registry) +* - t - lua_text, argument - int - position of the lua_text on stack (not lua_registry) +* @param L lua state +* @param cbref LUA_REGISTRY reference (if it is -1 then a function on top of the stack is called - it must be removed by caller manually) +* @param strloc where this function is called from +* @param nret number of results (or LUA_MULTRET) +* @param args arguments format string +* @param err error to promote +* @param ... arguments +* @return true of pcall returned 0, false + err otherwise +*/ +bool rspamd_lua_universal_pcall(lua_State *L, gint cbref, const gchar *strloc, + gint nret, const gchar *args, GError **err, ...); + +/** + * Returns true if lua is initialised + * @return + */ +bool rspamd_lua_is_initialised(void); + +/** +* Wrapper for lua_geti from lua 5.3 +* @param L +* @param index +* @param i +* @return +*/ +#if defined(LUA_VERSION_NUM) && LUA_VERSION_NUM <= 502 +gint rspamd_lua_geti(lua_State *L, int index, int i); +#else +#define rspamd_lua_geti lua_geti +#endif + +/* Paths defs */ +#define RSPAMD_CONFDIR_INDEX "CONFDIR" +#define RSPAMD_LOCAL_CONFDIR_INDEX "LOCAL_CONFDIR" +#define RSPAMD_RUNDIR_INDEX "RUNDIR" +#define RSPAMD_DBDIR_INDEX "DBDIR" +#define RSPAMD_LOGDIR_INDEX "LOGDIR" +#define RSPAMD_PLUGINSDIR_INDEX "PLUGINSDIR" +#define RSPAMD_SHAREDIR_INDEX "SHAREDIR" +#define RSPAMD_RULESDIR_INDEX "RULESDIR" +#define RSPAMD_LUALIBDIR_INDEX "LUALIBDIR" +#define RSPAMD_WWWDIR_INDEX "WWWDIR" +#define RSPAMD_PREFIX_INDEX "PREFIX" +#define RSPAMD_VERSION_INDEX "VERSION" + +#ifdef WITH_LUA_TRACE +extern ucl_object_t *lua_traces; +#define LUA_TRACE_POINT \ + do { \ + ucl_object_t *func_obj; \ + if (lua_traces == NULL) { lua_traces = ucl_object_typed_new(UCL_OBJECT); } \ + func_obj = (ucl_object_t *) ucl_object_lookup(lua_traces, G_STRFUNC); \ + if (func_obj == NULL) { \ + func_obj = ucl_object_typed_new(UCL_INT); \ + ucl_object_insert_key(lua_traces, func_obj, G_STRFUNC, 0, false); \ + } \ + func_obj->value.iv++; \ + } while (0) +#else +#define LUA_TRACE_POINT \ + do { \ + } while (0) +#endif + +#ifdef __cplusplus +} +#endif + +#endif /* RSPAMD_LUA_H */ diff --git a/src/lua/lua_compress.c b/src/lua/lua_compress.c new file mode 100644 index 0000000..77c82c5 --- /dev/null +++ b/src/lua/lua_compress.c @@ -0,0 +1,622 @@ +/*- + * Copyright 2021 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "unix-std.h" +#include <zlib.h> + +#ifdef SYS_ZSTD +#include "zstd.h" +#include "zstd_errors.h" +#else +#include "contrib/zstd/zstd.h" +#include "contrib/zstd/error_public.h" +#endif + +/*** + * @module rspamd_compress + * This module contains compression/decompression routines (zstd and zlib currently) + */ + +/*** + * @function zstd.compress_ctx() + * Creates new compression ctx + * @return {compress_ctx} new compress ctx + */ +LUA_FUNCTION_DEF(zstd, compress_ctx); + +/*** + * @function zstd.compress_ctx() + * Creates new compression ctx + * @return {compress_ctx} new compress ctx + */ +LUA_FUNCTION_DEF(zstd, decompress_ctx); + +LUA_FUNCTION_DEF(zstd_compress, stream); +LUA_FUNCTION_DEF(zstd_compress, dtor); + +LUA_FUNCTION_DEF(zstd_decompress, stream); +LUA_FUNCTION_DEF(zstd_decompress, dtor); + +static const struct luaL_reg zstd_compress_lib_f[] = { + LUA_INTERFACE_DEF(zstd, compress_ctx), + LUA_INTERFACE_DEF(zstd, decompress_ctx), + {NULL, NULL}}; + +static const struct luaL_reg zstd_compress_lib_m[] = { + LUA_INTERFACE_DEF(zstd_compress, stream), + {"__gc", lua_zstd_compress_dtor}, + {NULL, NULL}}; + +static const struct luaL_reg zstd_decompress_lib_m[] = { + LUA_INTERFACE_DEF(zstd_decompress, stream), + {"__gc", lua_zstd_decompress_dtor}, + {NULL, NULL}}; + +static ZSTD_CStream * +lua_check_zstd_compress_ctx(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{zstd_compress}"); + luaL_argcheck(L, ud != NULL, pos, "'zstd_compress' expected"); + return ud ? *(ZSTD_CStream **) ud : NULL; +} + +static ZSTD_DStream * +lua_check_zstd_decompress_ctx(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{zstd_decompress}"); + luaL_argcheck(L, ud != NULL, pos, "'zstd_decompress' expected"); + return ud ? *(ZSTD_DStream **) ud : NULL; +} + +int lua_zstd_push_error(lua_State *L, int err) +{ + lua_pushnil(L); + lua_pushfstring(L, "zstd error %d (%s)", err, ZSTD_getErrorString(err)); + + return 2; +} + +gint lua_compress_zstd_compress(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = NULL, *res; + gsize sz, r; + gint comp_level = 1; + + t = lua_check_text_or_string(L, 1); + + if (t == NULL || t->start == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TNUMBER) { + comp_level = lua_tointeger(L, 2); + } + + sz = ZSTD_compressBound(t->len); + + if (ZSTD_isError(sz)) { + msg_err("cannot compress data: %s", ZSTD_getErrorName(sz)); + lua_pushnil(L); + + return 1; + } + + res = lua_newuserdata(L, sizeof(*res)); + res->start = g_malloc(sz); + res->flags = RSPAMD_TEXT_FLAG_OWN; + rspamd_lua_setclass(L, "rspamd{text}", -1); + r = ZSTD_compress((void *) res->start, sz, t->start, t->len, comp_level); + + if (ZSTD_isError(r)) { + msg_err("cannot compress data: %s", ZSTD_getErrorName(r)); + lua_pop(L, 1); /* Text will be freed here */ + lua_pushnil(L); + + return 1; + } + + res->len = r; + + return 1; +} + +gint lua_compress_zstd_decompress(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = NULL, *res; + gsize outlen, r; + ZSTD_DStream *zstream; + ZSTD_inBuffer zin; + ZSTD_outBuffer zout; + gchar *out; + + t = lua_check_text_or_string(L, 1); + + if (t == NULL || t->start == NULL) { + return luaL_error(L, "invalid arguments"); + } + + zstream = ZSTD_createDStream(); + ZSTD_initDStream(zstream); + + zin.pos = 0; + zin.src = t->start; + zin.size = t->len; + + if ((outlen = ZSTD_getDecompressedSize(zin.src, zin.size)) == 0) { + outlen = ZSTD_DStreamOutSize(); + } + + out = g_malloc(outlen); + + zout.dst = out; + zout.pos = 0; + zout.size = outlen; + + while (zin.pos < zin.size) { + r = ZSTD_decompressStream(zstream, &zout, &zin); + + if (ZSTD_isError(r)) { + msg_err("cannot decompress data: %s", ZSTD_getErrorName(r)); + ZSTD_freeDStream(zstream); + g_free(out); + lua_pushstring(L, ZSTD_getErrorName(r)); + lua_pushnil(L); + + return 2; + } + + if (zin.pos < zin.size && zout.pos == zout.size) { + /* We need to extend output buffer */ + zout.size = zout.size * 2; + out = g_realloc(zout.dst, zout.size); + zout.dst = out; + } + } + + ZSTD_freeDStream(zstream); + lua_pushnil(L); /* Error */ + res = lua_newuserdata(L, sizeof(*res)); + res->start = out; + res->flags = RSPAMD_TEXT_FLAG_OWN; + rspamd_lua_setclass(L, "rspamd{text}", -1); + res->len = zout.pos; + + return 2; +} + +gint lua_compress_zlib_decompress(lua_State *L, bool is_gzip) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = NULL, *res; + gsize sz; + z_stream strm; + gint rc; + guchar *p; + gsize remain; + gssize size_limit = -1; + + int windowBits = is_gzip ? (MAX_WBITS + 16) : (MAX_WBITS); + + t = lua_check_text_or_string(L, 1); + + if (t == NULL || t->start == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TNUMBER) { + size_limit = lua_tointeger(L, 2); + if (size_limit <= 0) { + return luaL_error(L, "invalid arguments (size_limit)"); + } + + sz = MIN(t->len * 2, size_limit); + } + else { + sz = t->len * 2; + } + + memset(&strm, 0, sizeof(strm)); + /* windowBits +16 to decode gzip, zlib 1.2.0.4+ */ + + /* Here are dragons to distinguish between raw deflate and zlib */ + if (windowBits == MAX_WBITS && t->len > 0) { + if ((int) (unsigned char) ((t->start[0] << 4)) != 0x80) { + /* Assume raw deflate */ + windowBits = -windowBits; + } + } + + rc = inflateInit2(&strm, windowBits); + + if (rc != Z_OK) { + return luaL_error(L, "cannot init zlib"); + } + + strm.avail_in = t->len; + strm.next_in = (guchar *) t->start; + + res = lua_newuserdata(L, sizeof(*res)); + res->start = g_malloc(sz); + res->flags = RSPAMD_TEXT_FLAG_OWN; + rspamd_lua_setclass(L, "rspamd{text}", -1); + + p = (guchar *) res->start; + remain = sz; + + while (strm.avail_in != 0) { + strm.avail_out = remain; + strm.next_out = p; + + rc = inflate(&strm, Z_NO_FLUSH); + + if (rc != Z_OK && rc != Z_BUF_ERROR) { + if (rc == Z_STREAM_END) { + break; + } + else { + msg_err("cannot decompress data: %s (last error: %s)", + zError(rc), strm.msg); + lua_pop(L, 1); /* Text will be freed here */ + lua_pushnil(L); + inflateEnd(&strm); + + return 1; + } + } + + res->len = strm.total_out; + + if (strm.avail_out == 0 && strm.avail_in != 0) { + + if (size_limit > 0 || res->len >= G_MAXUINT32 / 2) { + if (res->len > size_limit || res->len >= G_MAXUINT32 / 2) { + lua_pop(L, 1); /* Text will be freed here */ + lua_pushnil(L); + inflateEnd(&strm); + + return 1; + } + } + + /* Need to allocate more */ + remain = res->len; + res->start = g_realloc((gpointer) res->start, res->len * 2); + sz = res->len * 2; + p = (guchar *) res->start + remain; + remain = sz - remain; + } + } + + inflateEnd(&strm); + res->len = strm.total_out; + + return 1; +} + +gint lua_compress_zlib_compress(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = NULL, *res; + gsize sz; + z_stream strm; + gint rc, comp_level = Z_DEFAULT_COMPRESSION; + guchar *p; + gsize remain; + + t = lua_check_text_or_string(L, 1); + + if (t == NULL || t->start == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_isnumber(L, 2)) { + comp_level = lua_tointeger(L, 2); + + if (comp_level > Z_BEST_COMPRESSION || comp_level < Z_BEST_SPEED) { + return luaL_error(L, "invalid arguments: compression level must be between %d and %d", + Z_BEST_SPEED, Z_BEST_COMPRESSION); + } + } + + + memset(&strm, 0, sizeof(strm)); + rc = deflateInit2(&strm, comp_level, Z_DEFLATED, + MAX_WBITS + 16, MAX_MEM_LEVEL - 1, Z_DEFAULT_STRATEGY); + + if (rc != Z_OK) { + return luaL_error(L, "cannot init zlib: %s", zError(rc)); + } + + sz = deflateBound(&strm, t->len); + + strm.avail_in = t->len; + strm.next_in = (guchar *) t->start; + + res = lua_newuserdata(L, sizeof(*res)); + res->start = g_malloc(sz); + res->flags = RSPAMD_TEXT_FLAG_OWN; + rspamd_lua_setclass(L, "rspamd{text}", -1); + + p = (guchar *) res->start; + remain = sz; + + while (strm.avail_in != 0) { + strm.avail_out = remain; + strm.next_out = p; + + rc = deflate(&strm, Z_FINISH); + + if (rc != Z_OK && rc != Z_BUF_ERROR) { + if (rc == Z_STREAM_END) { + break; + } + else { + msg_err("cannot compress data: %s (last error: %s)", + zError(rc), strm.msg); + lua_pop(L, 1); /* Text will be freed here */ + lua_pushnil(L); + deflateEnd(&strm); + + return 1; + } + } + + res->len = strm.total_out; + + if (strm.avail_out == 0 && strm.avail_in != 0) { + /* Need to allocate more */ + remain = res->len; + res->start = g_realloc((gpointer) res->start, strm.avail_in + sz); + sz = strm.avail_in + sz; + p = (guchar *) res->start + remain; + remain = sz - remain; + } + } + + deflateEnd(&strm); + res->len = strm.total_out; + + return 1; +} + +/* Stream API interface for Zstd: both compression and decompression */ + +/* Operations allowed by zstd stream methods */ +static const char *const zstd_stream_op[] = { + "continue", + "flush", + "end", + NULL}; + +static gint +lua_zstd_compress_ctx(lua_State *L) +{ + ZSTD_CCtx *ctx, **pctx; + + pctx = lua_newuserdata(L, sizeof(*pctx)); + ctx = ZSTD_createCCtx(); + + if (!ctx) { + return luaL_error(L, "context create failed"); + } + + *pctx = ctx; + rspamd_lua_setclass(L, "rspamd{zstd_compress}", -1); + return 1; +} + +static gint +lua_zstd_compress_dtor(lua_State *L) +{ + ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx(L, 1); + + if (ctx) { + ZSTD_freeCCtx(ctx); + } + + return 0; +} + +static gint +lua_zstd_compress_reset(lua_State *L) +{ + ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx(L, 1); + + if (ctx) { + ZSTD_CCtx_reset(ctx, ZSTD_reset_session_and_parameters); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_zstd_compress_stream(lua_State *L) +{ + ZSTD_CStream *ctx = lua_check_zstd_compress_ctx(L, 1); + struct rspamd_lua_text *t = lua_check_text_or_string(L, 2); + int op = luaL_checkoption(L, 3, zstd_stream_op[0], zstd_stream_op); + int err = 0; + ZSTD_inBuffer inb; + ZSTD_outBuffer onb; + + if (ctx && t) { + gsize dlen = 0; + + inb.size = t->len; + inb.pos = 0; + inb.src = (const void *) t->start; + + onb.pos = 0; + onb.size = ZSTD_CStreamInSize(); /* Initial guess */ + onb.dst = NULL; + + for (;;) { + if ((onb.dst = g_realloc(onb.dst, onb.size)) == NULL) { + return lua_zstd_push_error(L, ZSTD_error_memory_allocation); + } + + dlen = onb.size; + + int res = ZSTD_compressStream2(ctx, &onb, &inb, op); + + if (res == 0) { + /* All done */ + break; + } + + if ((err = ZSTD_getErrorCode(res))) { + break; + } + + onb.size *= 2; + res += dlen; /* Hint returned by compression routine */ + + /* Either double the buffer, or use the hint provided */ + if (onb.size < res) { + onb.size = res; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + if (err) { + return lua_zstd_push_error(L, err); + } + + lua_new_text(L, onb.dst, onb.pos, TRUE); + + return 1; +} + +static gint +lua_zstd_decompress_dtor(lua_State *L) +{ + ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx(L, 1); + + if (ctx) { + ZSTD_freeDStream(ctx); + } + + return 0; +} + + +static gint +lua_zstd_decompress_ctx(lua_State *L) +{ + ZSTD_DStream *ctx, **pctx; + + pctx = lua_newuserdata(L, sizeof(*pctx)); + ctx = ZSTD_createDStream(); + + if (!ctx) { + return luaL_error(L, "context create failed"); + } + + *pctx = ctx; + rspamd_lua_setclass(L, "rspamd{zstd_decompress}", -1); + return 1; +} + +static gint +lua_zstd_decompress_stream(lua_State *L) +{ + ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx(L, 1); + struct rspamd_lua_text *t = lua_check_text_or_string(L, 2); + int err = 0; + ZSTD_inBuffer inb; + ZSTD_outBuffer onb; + + if (ctx && t) { + gsize dlen = 0; + + if (t->len == 0) { + return lua_zstd_push_error(L, ZSTD_error_init_missing); + } + + inb.size = t->len; + inb.pos = 0; + inb.src = (const void *) t->start; + + onb.pos = 0; + onb.size = ZSTD_DStreamInSize(); /* Initial guess */ + onb.dst = NULL; + + for (;;) { + if ((onb.dst = g_realloc(onb.dst, onb.size)) == NULL) { + return lua_zstd_push_error(L, ZSTD_error_memory_allocation); + } + + dlen = onb.size; + + int res = ZSTD_decompressStream(ctx, &onb, &inb); + + if (res == 0) { + /* All done */ + break; + } + + if ((err = ZSTD_getErrorCode(res))) { + break; + } + + onb.size *= 2; + res += dlen; /* Hint returned by compression routine */ + + /* Either double the buffer, or use the hint provided */ + if (onb.size < res) { + onb.size = res; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + if (err) { + return lua_zstd_push_error(L, err); + } + + lua_new_text(L, onb.dst, onb.pos, TRUE); + + return 1; +} + +static gint +lua_load_zstd(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, zstd_compress_lib_f); + + return 1; +} + +void luaopen_compress(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{zstd_compress}", zstd_compress_lib_m); + rspamd_lua_new_class(L, "rspamd{zstd_decompress}", zstd_decompress_lib_m); + lua_pop(L, 2); + + rspamd_lua_add_preload(L, "rspamd_zstd", lua_load_zstd); +} diff --git a/src/lua/lua_compress.h b/src/lua/lua_compress.h new file mode 100644 index 0000000..34234de --- /dev/null +++ b/src/lua/lua_compress.h @@ -0,0 +1,37 @@ +/*- + * Copyright 2021 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSPAMD_LUA_COMPRESS_H +#define RSPAMD_LUA_COMPRESS_H + +#include "lua_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +gint lua_compress_zstd_compress(lua_State *L); +gint lua_compress_zstd_decompress(lua_State *L); +gint lua_compress_zlib_compress(lua_State *L); +gint lua_compress_zlib_decompress(lua_State *L, bool is_gzip); + +void luaopen_compress(lua_State *L); + +#ifdef __cplusplus +} +#endif + +#endif//RSPAMD_LUA_COMPRESS_H diff --git a/src/lua/lua_config.c b/src/lua/lua_config.c new file mode 100644 index 0000000..a044827 --- /dev/null +++ b/src/lua/lua_config.c @@ -0,0 +1,4780 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "libmime/message.h" +#include "libutil/expression.h" +#include "src/libserver/composites/composites.h" +#include "libserver/cfg_file_private.h" +#include "libmime/lang_detection.h" +#include "lua/lua_map.h" +#include "lua/lua_thread_pool.h" +#include "utlist.h" +#include <math.h> + +/*** + * This module is used to configure rspamd and is normally available as global + * variable named `rspamd_config`. Unlike other modules, it is not necessary to + * require it before usage. + * @module rspamd_config + * @example +-- Register some callback symbol +local function foo(task) + -- do something +end +rspamd_config:register_symbol('SYMBOL', 1.0, foo) + +-- Get configuration +local tab = rspamd_config:get_all_opt('module') -- get table for module's options +local opts = rspamd_config:get_key('options') -- get content of the specified key in rspamd configuration + */ + +/* Config file methods */ +/*** + * @method rspamd_config:get_module_opt(mname, optname) + * Returns value of specified option `optname` for a module `mname`, + * @param {string} mname name of module + * @param {string} optname option to get + * @return {string or table} value of the option or `nil` if option is not found + */ +LUA_FUNCTION_DEF(config, get_module_opt); +/*** + * @method rspamd_config:get_all_opt(mname) + * Returns value of all options for a module `mname`, flattening values into a single table consisting + * of all sections with such a name. + * @param {string} mname name of module + * @return {table} table of all options for `mname` or `nil` if a module's configuration is not found + */ +LUA_FUNCTION_DEF(config, get_all_opt); + +/*** + * @method rspamd_config:get_ucl() + * Returns full configuration as a native Lua object (ucl to lua conversion). + * This method uses caching if possible. + * @return {table} table of all options in the configuration + */ +LUA_FUNCTION_DEF(config, get_ucl); +/*** + * @method rspamd_config:get_mempool() + * Returns static configuration memory pool. + * @return {mempool} [memory pool](mempool.md) object + */ +LUA_FUNCTION_DEF(config, get_mempool); +/*** + * @method rspamd_config:get_resolver() + * Returns DNS resolver. + * @return {dns_resolver} opaque DNS resolver pointer if any + */ +LUA_FUNCTION_DEF(config, get_resolver); +/*** + * @method rspamd_config:add_radix_map(mapline[, description]) + * Creates new dynamic map of IP/mask addresses. + * @param {string} mapline URL for a map + * @param {string} description optional map description + * @return {map} radix tree object + * @example +local ip_map = rspamd_config:add_radix_map ('file:///path/to/file', 'my radix map') +... +local function foo(task) + local ip = task:get_from_ip() + if ip_map:get_key(ip) then + return true + end + return false +end + */ + +/*** + * @method rspamd_config:radix_from_config(mname, optname) + * Creates new embedded map of IP/mask addresses from config. + * @param {string} mname name of module + * @param {string} optname option to get + * @return {map} radix tree object + * @example +local ip_map = rspamd_config:radix_from_config ('mymodule', 'ips') +... +local function foo(task) + local ip = task:get_from_ip() + if ip_map:get_key(ip) then + return true + end + return false +end + */ +/*** +* @method rspamd_config:radix_from_ucl(obj) +* Creates new embedded map of IP/mask addresses from object. +* @param {ucl} obj object +* @return {map} radix tree object +*/ +/*** + * @method rspamd_config:add_hash_map(mapline[, description]) + * Creates new dynamic map string objects. + * @param {string} mapline URL for a map + * @param {string} description optional map description + * @return {map} hash set object + * @example +local hash_map = rspamd_config:add_hash_map ('file:///path/to/file', 'my hash map') +... +local function foo(task) + local from = task:get_from() + if hash_map:get_key(from['user']) then + return true + end + return false +end + */ +/*** + * @method rspamd_config:add_kv_map(mapline[, description]) + * Creates new dynamic map of key/values associations. + * @param {string} mapline URL for a map + * @param {string} description optional map description + * @return {map} hash table object + * @example +local kv_map = rspamd_config:add_kv_map ('file:///path/to/file', 'my kv map') +... +local function foo(task) + local from = task:get_from() + if from then + local value = kv_map:get_key(from['user']) + if value then + return true,value + end + end + return false +end + */ +/*** + * @method rspamd_config:add_map({args}) + * Creates new dynamic map according to the attributes passed. + * + * - `type`: type of map to be created, can be one of the following set: + * + `set`: set of strings + * + `radix`: map of IP addresses to strings + * + `map`: map of strings to strings + * + `regexp`: map of regexps to strings + * + `callback`: map processed by lua callback + * - `url`: url to load map from + * - `description`: map's description + * - `callback`: lua callback for the map + * + * @return {map} `true` if map has been added + * @example + +local str = '' +local function process_map(in) + str = in +end + +rspamd_config:add_map('http://example.com/map', "settings map", process_map) + */ +/*** +* @method rspamd_config:get_maps() +* Get all maps defined as an array of rspamd{map} objects + * +* @return {table|rspamd{map}} +*/ +/*** + * @method rspamd_config:get_classifier(name) + * Returns classifier config. + * @param {string} name name of classifier (e.g. `bayes`) + * @return {classifier} classifier object or `nil` + */ +LUA_FUNCTION_DEF(config, get_classifier); +/*** + * @method rspamd_config:register_symbol(table) + * Register symbol of a specified type in rspamd. This function accepts table of arguments: + * + * - `name`: name of symbol (can be missing for callback symbols) + * - `callback`: function to be called for symbol's check (can be absent for virtual symbols) + * - `weight`: weight of symbol (should normally be 1 or missing) + * - `priority`: priority of symbol (normally 0 or missing) + * - `type`: type of symbol: + * + `normal`: executed after prefilters, according to dependency graph or in undefined order + * + `callback`: a check that merely inserts virtual symbols + * + `connfilter`: executed early; before message body is available + * + `idempotent`: cannot change result in any way; executed last + * + `postfilter`: executed after most other checks + * + `prefilter`: executed before most other checks + * + `virtual`: a symbol inserted by its parent check + * - `flags`: various flags split by commas or spaces: + * + `nice` if symbol can produce negative score; + * + `empty` if symbol can be called for empty messages + * + `skip` if symbol should be skipped now + * + `nostat` if symbol should be excluded from stat tokens + * + `trivial` symbol is trivial (e.g. no network requests) + * + `explicit_disable` requires explicit disabling (e.g. via settings) + * + `ignore_passthrough` executed even if passthrough result has been set + * - `parent`: id of parent symbol (useful for virtual symbols) + * + * @return {number} id of symbol registered + */ +LUA_FUNCTION_DEF(config, register_symbol); +/*** + * @method rspamd_config:register_symbols(callback, [weight], callback_name, [, symbol, ...]) + * Register callback function to be called for a set of symbols with initial weight. + * @param {function} callback callback function to be called for a specified symbol + * @param {number} weight initial weight of symbol (can be less than zero to specify non-spam symbols) + * @param {string} callback_name symbolic name of callback + * @param {list of strings} symbol list of symbols registered by this function + */ +LUA_FUNCTION_DEF(config, register_symbols); +/*** + * @method rspamd_config:register_virtual_symbol(name, weight,) + * Register virtual symbol that is not associated with any callback. + * + * **This method is deprecated and should not be used in newly written code ** + * @param {string} virtual name symbol's name + * @param {number} weight initial weight of symbol (can be less than zero to specify non-spam symbols) + */ +LUA_FUNCTION_DEF(config, register_virtual_symbol); +/*** + * @method rspamd_config:register_callback_symbol(name, weight, callback) + * Register callback function to be called for a specified symbol with initial weight. Symbol itself is + * not registered in the metric and is not intended to be visible by a user. + * + * **This method is deprecated and should not be used in newly written code ** + * @param {string} name symbol's name (just for unique id purposes) + * @param {number} weight initial weight of symbol (can be less than zero to specify non-spam symbols) + * @param {function} callback callback function to be called for a specified symbol + */ +LUA_FUNCTION_DEF(config, register_callback_symbol); +LUA_FUNCTION_DEF(config, register_callback_symbol_priority); + +/*** + * @method rspamd_config:register_dependency(id|name, depname) + * Create a dependency on symbol identified by name for symbol identified by ID or name. + * This affects order of checks only (a symbol is still checked if its dependencies are disabled). + * @param {number|string} id id or name of source (numeric id is returned by all register_*_symbol) + * @param {string} depname dependency name + * @example +local function cb(task) +... +end + +local id = rspamd_config:register_symbol('SYM', 1.0, cb) +rspamd_config:register_dependency(id, 'OTHER_SYM') +-- Alternative form +-- Symbol MY_RULE needs result from SPF_CHECK +rspamd_config:register_dependency('MY_RULE', 'SPF_CHECK') + */ +LUA_FUNCTION_DEF(config, register_dependency); + +/*** + * @method rspamd_config:get_symbol_flags(name) + * Returns symbol flags + * @param {string} name symbols's name + * @return {table|string} list of flags for symbol or nil + */ +LUA_FUNCTION_DEF(config, get_symbol_flags); + +/*** + * @method rspamd_config:add_symbol_flags(name, flags) + * Adds flags to a symbol + * @param {string} name symbols's name + * @param {table|string} flags flags to add + * @return {table|string} new set of flags + */ +LUA_FUNCTION_DEF(config, add_symbol_flags); + +/** + * @method rspamd_config:register_re_selector(name, selector_str, [delimiter, [flatten]]) + * Registers selector with the specific name to use in regular expressions in form + * name=/re/$ or name=/re/{selector} + * @param {string} name name of the selector + * @param {string} selector_str selector definition + * @param {string} delimiter delimiter to use when joining strings if flatten is false + * @param {bool} flatten if true then selector will return a table of captures instead of a single string + * @return true if selector has been registered + */ +LUA_FUNCTION_DEF(config, register_re_selector); + +/** + * @method rspamd_config:set_symbol({table}) + * Sets the value of a specified symbol in a metric. This function accepts table with the following elements: + * + * - `name`: name of symbol (string) + * - `score`: score for symbol (number) + * - `metric`: name of metric (string, optional) + * - `description`: description of symbol (string, optional) + * - `group`: name of group for symbol (string, optional) + * - `one_shot`: turn off multiple hits for a symbol (boolean, optional) + * - `one_param`: turn off multiple options for a symbol (boolean, optional) + * - `flags`: comma separated string of flags: + * + `ignore`: do not strictly check validity of symbol and corresponding rule + * + `one_shot`: turn off multiple hits for a symbol + * + `one_param`: allow only one parameter for a symbol + * - `priority`: priority of symbol's definition + */ +LUA_FUNCTION_DEF(config, set_metric_symbol); + +/** + * @method rspamd_config:set_action({table}) + * Sets the score of a specified action in a metric. This function accepts table with the following elements: + * + * - `action`: name of action (string) + * - `score`: score for action (number) + * - `metric`: name of metric (string, optional) + * - `priority`: priority of action's definition + */ +LUA_FUNCTION_DEF(config, set_metric_action); + +/** + * @method rspamd_config:get_action(name) + * Gets data for a specific action in config. This function returns number representing action's score + * + * @param {string} name name of action + * @return {number} action's score or nil in case of undefined score or action + */ +LUA_FUNCTION_DEF(config, get_metric_action); + +/** + * @method rspamd_config:get_all_actions() + * Gets data for all action in config + * @return {table|str->num} action's score or nil in case of undefined score or action + */ +LUA_FUNCTION_DEF(config, get_all_actions); + +/** + * @method rspamd_config:add_composite(name, expression) + * @param {string} name name of composite symbol + * @param {string} expression symbolic expression of the composite rule + * @return {bool} true if a composite has been added successfully + */ +LUA_FUNCTION_DEF(config, add_composite); +/*** + * @method rspamd_config:register_pre_filter(callback[, order]) + * Register function to be called prior to symbols processing. + * @param {function} callback callback function + * @param {number} order filters are called from lower orders to higher orders, order is equal to 0 by default + * @example +local function check_function(task) + -- It is possible to manipulate the task object here: set settings, set pre-action and so on + ... +end + +rspamd_config:register_pre_filter(check_function) + */ +LUA_FUNCTION_DEF(config, register_pre_filter); +/*** + * @method rspamd_config:register_post_filter(callback[, order]) + * Register function to be called after symbols are processed. + * + * @param {function} callback callback function + * @param {number} order filters are called from lower orders to higher orders, order is equal to 0 by default + */ +LUA_FUNCTION_DEF(config, register_post_filter); +/* XXX: obsoleted */ +LUA_FUNCTION_DEF(config, register_module_option); +/* XXX: not needed now */ +LUA_FUNCTION_DEF(config, get_api_version); +/*** + * @method rspamd_config:get_key(name) + * Returns configuration section with the specified `name`. + * @param {string} name name of config section + * @return {variant} specific value of section + * @example + +local set_section = rspamd_config:get_key("settings") +if type(set_section) == "string" then + -- Just a map of ucl + if rspamd_config:add_map(set_section, "settings map", process_settings_map) then + rspamd_config:register_pre_filter(check_settings) + end +elseif type(set_section) == "table" then + if process_settings_table(set_section) then + rspamd_config:register_pre_filter(check_settings) + end +end + */ +LUA_FUNCTION_DEF(config, get_key); + +/*** + * @method rspamd_config:add_condition(symbol, condition) + * Adds condition callback for specified symbol + * @param {string} symbol symbol's name + * @param {function} condition condition callback + * @return {boolean} true if condition has been added + * @example + +local condition_map = rspamd_config:add_map{ + type = "hash", + urls = ['file:///path/to/file'], + description = 'SMTP from map that allows FUZZY_DENIED skip for the listed addresses' +} +rspamd_config:add_condition('FUZZY_DENIED', function(task) + local E = {} + -- Check for the smtp from address adding fail safe checks + if condition_map:find_key(((task:get_from('smtp') or E)[1] or E).addr) then + return false + end + -- Allow execution otherwise + return true +end) + */ +LUA_FUNCTION_DEF(config, add_condition); + +/*** + * @method rspamd_config:enable_symbol(symbol) + * Enables execution for the specified symbol + * @param {string} symbol symbol's name + */ +LUA_FUNCTION_DEF(config, enable_symbol); + +/*** + * @method rspamd_config:disable_symbol(symbol, [disable_parent=true]) + * Disables execution for the specified symbol + * @param {string} symbol symbol's name + * @param {boolean} disable_parent if true then disable parent execution in case of a virtual symbol + */ +LUA_FUNCTION_DEF(config, disable_symbol); + +/*** + * @method rspamd_config:get_symbol_parent(symbol) + * Returns a parent symbol for specific symbol (or symbol itself if top level) + * @param {string} symbol symbol's name + */ +LUA_FUNCTION_DEF(config, get_symbol_parent); + +/*** + * @method rspamd_config:get_group_symbols(group) + * Returns list of symbols for a specific group + * @param {string} group group's name + * @available 2.0+ + * @return {list|string} list of all symbols in a specific group + */ +LUA_FUNCTION_DEF(config, get_group_symbols); + +/*** + * @method rspamd_config:get_groups([need_private]) + * Returns list of all groups defined + * @param {boolean} need_private optional flag to include private groups + * @available 2.3+ + * @return {list|table} list of all groups + */ +LUA_FUNCTION_DEF(config, get_groups); + +/*** + * @method rspamd_config:register_settings_id(name, symbols_enabled, symbols_disabled) + * Register new static settings id in config + * @param {string} name id name (not numeric!) + * @param {map|string->string} symbols_enabled map from symbol's name to boolean (currently) + * @param {map|string->string} symbols_disabled map from symbol's name to boolean (currently) + * @available 2.0+ + */ +LUA_FUNCTION_DEF(config, register_settings_id); + +/*** + * @method rspamd_config:__newindex(name, callback) + * This metamethod is called if new indices are added to the `rspamd_config` object. + * Technically, it is the equivalent of @see rspamd_config:register_symbol where `weight` is 1.0. + * There is also table form invocation that allows to control more things: + * + * - `callback`: has the same meaning and acts as function of task + * - `score`: default score for a symbol + * - `group`: default group for a symbol + * - `description`: default symbol's description + * - `priority`: additional priority value + * - `one_shot`: default value for one shot attribute + * - `condition`: function of task that can enable or disable this specific rule's execution + * @param {string} name index name + * @param {function/table} callback callback to be called + * @return {number} id of the new symbol added + * @example +rspamd_config.R_EMPTY_IMAGE = function (task) + parts = task:get_text_parts() + if parts then + for _,part in ipairs(parts) do + if part:is_empty() then + images = task:get_images() + if images then + -- Symbol `R_EMPTY_IMAGE` is inserted + return true + end + return false + end + end + end + return false +end + +rspamd_config.SYMBOL = { + callback = function(task) + ... + end, + score = 5.1, + description = 'sample symbol', + group = 'sample symbols', + condition = function(task) + if task:get_from()[1]['addr'] == 'user@example.com' then + return false + end + return true + end +} + */ +LUA_FUNCTION_DEF(config, newindex); + +/*** + * @method rspamd_config:register_regexp(params) + * Registers new re for further cached usage + * Params is the table with the following fields (mandatory fields are marked with `*`): + * - `re`* : regular expression object + * - `type`*: type of regular expression: + * + `mime`: mime regexp + * + `rawmime`: raw mime regexp + * + `header`: header regexp + * + `rawheader`: raw header expression + * + `body`: raw body regexp + * + `url`: url regexp + * - `header`: for header and rawheader regexp means the name of header + * - `pcre_only`: flag regexp as pcre only regexp + */ +LUA_FUNCTION_DEF(config, register_regexp); + +/*** + * @method rspamd_config:replace_regexp(params) + * Replaces regexp with a new one + * Params is the table with the following fields (mandatory fields are marked with `*`): + * - `old_re`* : old regular expression object (must be in the cache) + * - `new_re`* : old regular expression object (must not be in the cache) + */ +LUA_FUNCTION_DEF(config, replace_regexp); + +/*** + * @method rspamd_config:register_worker_script(worker_type, script) + * Registers the following script for workers of a specified type. The exact type + * of script function depends on worker type + * @param {string} worker_type worker type (e.g. "normal") + * @param {function} script script for a worker + * @return {boolean} `true` if a script has been registered + */ +LUA_FUNCTION_DEF(config, register_worker_script); + +/*** + * @method rspamd_config:add_on_load(function(cfg, ev_base, worker) ... end) + * Registers the following script to be executed when configuration is completely loaded + * and the worker is already started (forked) + * @param {function} script function to be executed + * @example +rspamd_config:add_on_load(function(cfg, ev_base, worker) + rspamd_config:add_periodic(ev_base, 1.0, function(cfg, ev_base) + local logger = require "rspamd_logger" + logger.infox(cfg, "periodic function in worker %s", worker:get_name()) + return true + end) +end) + */ +LUA_FUNCTION_DEF(config, add_on_load); + +/*** + * @method rspamd_config:add_periodic(event_base, timeout, function(cfg, ev_base) ... end, [jitter = false]) + * Registers function to be periodically executed by Rspamd + * @param {ev_base} event_base event base that is needed for async events + * @param {number} timeout time in seconds (could be fractional) + * @param {function} script function to be executed + * @param {boolean} jitter `true` if timeout jittering is needed + * @example +rspamd_config:add_on_load(function(cfg, ev_base) + rspamd_config:add_periodic(ev_base, 1.0, function(cfg, ev_base) + local logger = require "rspamd_logger" + logger.infox(cfg, "periodic function") + return true -- if return numeric, a new interval is set. if return false, then the periodic event is removed + end) +end) + */ +LUA_FUNCTION_DEF(config, add_periodic); + +/*** + * @method rspamd_config:add_post_init(function(cfg) ... end) + * Registers the following script to be executed when configuration is completely loaded + * @available 2.0+ + * @param {function} script function to be executed + */ +LUA_FUNCTION_DEF(config, add_post_init); + +/*** + * @method rspamd_config:add_config_unload(function(cfg) ... end) + * Registers the following script to be executed when configuration is unloaded + * @available 2.0+ + * @param {function} script function to be executed + */ +LUA_FUNCTION_DEF(config, add_config_unload); + +/*** + * @method rspamd_config:get_symbols_count() + * Returns number of symbols registered in rspamd configuration + * @return {number} number of symbols registered in the configuration + */ +LUA_FUNCTION_DEF(config, get_symbols_count); + +/*** + * @method rspamd_config:get_symbols_cksum() + * Returns checksum for all symbols in the cache + * @return {int64} boxed value of the 64 bit checksum + */ +LUA_FUNCTION_DEF(config, get_symbols_cksum); + +/*** + * @method rspamd_config:get_symbols_counters() + * Returns table of all counters in the cache (weights, frequencies etc) + * @return {table|tables} all symbols indexed by name + */ +LUA_FUNCTION_DEF(config, get_symbols_counters); + +/*** + * @method rspamd_config:get_symbols() + * Returns table of all scores defined in config. From version 2.0 returns table: + * - name + * - score + * - flags (e.g. `ignore` or `oneparam`) + * - nshots (== maxhits) + * - group - main group + * - groups - array of all groups + * @available 2.0+ + * @return {table|tables} all symbols indexed by name + */ +LUA_FUNCTION_DEF(config, get_symbols); + +/*** + * @method rspamd_config:get_symbol(sym_name) + * Returns table for a specific symbol getting data from the static config: + * - name + * - score + * - flags (e.g. `ignore` or `oneparam`) + * - nshots (== maxhits) + * - group - main group + * - groups - array of all groups + * @available 3.3+ + * @return {table} symbol data (or nil) + */ +LUA_FUNCTION_DEF(config, get_symbol); + +/*** + * @method rspamd_config:get_symbol_callback(name) + * Returns callback function for the specified symbol if it is a lua registered callback + * @return {function} callback function or nil + */ +LUA_FUNCTION_DEF(config, get_symbol_callback); + +/*** + * @method rspamd_config:get_symbol_stat(name) + * Returns table with statistics for a specific symbol: + * - `frequency`: frequency for symbol's hits + * - `stddev`: standard deviation of `frequency` + * - `time`: average time in seconds (floating point) + * - `count`: total number of hits + * @return {table} symbol stats + */ +LUA_FUNCTION_DEF(config, get_symbol_stat); + +/*** + * @method rspamd_config:set_symbol_callback(name, callback) + * Sets callback for the specified symbol + * @return {boolean} true if function has been replaced + */ +LUA_FUNCTION_DEF(config, set_symbol_callback); + +/*** + * @method rspamd_config:register_finish_script(callback) + * Adds new callback that is called on worker process termination when all + * tasks pending are processed + * + * @param callback {function} a function with one argument (rspamd_task) + */ +LUA_FUNCTION_DEF(config, register_finish_script); + +/*** + * @method rspamd_config:register_monitored(url, type, [{params}]) + * Registers monitored resource to watch its availability. Supported types: + * + * - `dns`: DNS monitored object + * + * Params are optional table specific for each type. For DNS it supports the + * following options: + * + * - `prefix`: prefix to add before making request + * - `type`: type of request (e.g. 'a' or 'txt') + * - `ipnet`: array of ip/networks to expect on reply + * - `rcode`: expected return code (e.g. `nxdomain`) + * + * Returned object has the following methods: + * + * - `alive`: returns `true` if monitored resource is alive + * - `offline`: returns number of seconds of the current offline period (or 0 if alive) + * - `total_offline`: returns number of seconds of the overall offline + * - `latency`: returns the current average latency in seconds (or 0 if offline) + * + * @param {string} url resource to monitor + * @param {string} type type of monitoring + * @param {table} opts optional parameters + * @return {rspamd_monitored} rspamd monitored object + */ +LUA_FUNCTION_DEF(config, register_monitored); + +/*** + * @method rspamd_config:add_doc(path, option, doc_string, [{params}]) + * Adds new documentation string for an option `option` at path `path` + * Options defines optional params, such as: + * + * - `default`: default option value + * - `type`: type of an option (`string`, `number`, `object`, `array` etc) + * - `required`: if an option is required + * + * @param {string} path documentation path (e.g. module name) + * @param {string} option name of the option + * @param {string} doc_string documentation string + * @param {table} params optional parameters + */ +LUA_FUNCTION_DEF(config, add_doc); + +/*** + * @method rspamd_config:add_example(path, option, doc_string, example) + * Adds new documentation + * + * @param {string} path documentation path (e.g. module name or nil for top) + * @param {string} option name of the option + * @param {string} doc_string documentation string + * @param {string} example example in ucl format, comments are also parsed + */ +LUA_FUNCTION_DEF(config, add_example); + +/*** + * @method rspamd_config:set_peak_cb(function) + * Sets a function that will be called when frequency of some symbol goes out of + * stddev * 2 over the last period of refreshment. + * + * @example +rspamd_config:set_peak_cb(function(ev_base, sym, mean, stddev, value, error) + -- ev_base: event base for async events (e.g. redis) + -- sym: symbol's name + -- mean: mean frequency value + -- stddev: standard deviation of frequency + -- value: current frequency value + -- error: squared error + local logger = require "rspamd_logger" + logger.infox(rspamd_config, "symbol %s has changed frequency significantly: %s(%s) over %s(%s)", + sym, value, error, mean, stddev) +end) + */ +LUA_FUNCTION_DEF(config, set_peak_cb); +/*** + * @method rspamd_config:get_cpu_flags() + * Returns architecture dependent flags supported by the CPU + * Currently, only x86 flags are supported: + * - 'ssse3' + * - 'sse42' + * - 'avx' + * - 'avx2' + * @return {table} flag -> true table + */ +LUA_FUNCTION_DEF(config, get_cpu_flags); + +/*** + * @method rspamd_config:has_torch() + * Returns true if Rspamd is compiled with torch support and the runtime CPU + * supports sse4.2 required for torch. + * @return {boolean} true if torch is compiled and supported + */ +LUA_FUNCTION_DEF(config, has_torch); + +/*** + * @method rspamd_config:experimental_enabled() + * Returns true if experimental plugins are enabled + * @return {boolean} true if experimental plugins are enabled + */ +LUA_FUNCTION_DEF(config, experimental_enabled); + +/*** + * @method rspamd_config:load_ucl(filename[, include_trace]) + * Loads config from the UCL file (but does not perform parsing using rcl) + * @param {string} filename file to load + * @return true or false + error message + */ +LUA_FUNCTION_DEF(config, load_ucl); + +/*** + * @method rspamd_config:parse_rcl([skip_sections]) + * Parses RCL using loaded ucl file + * @param {table|string} sections to skip + * @return true or false + error message + */ +LUA_FUNCTION_DEF(config, parse_rcl); + +/*** + * @method rspamd_config:init_modules() + * Initialize lua and internal modules + * @return true or false + */ +LUA_FUNCTION_DEF(config, init_modules); + +/*** + * @method rspamd_config:init_subsystem(str) + * Initialize config subsystem from a comma separated list: + * - `modules` - init modules + * - `langdet` - language detector + * - `dns` - DNS resolver + * - TODO: add more + */ +LUA_FUNCTION_DEF(config, init_subsystem); + +/*** + * @method rspamd_config:get_tld_path() + * Returns path to TLD file + * @return {string} path to tld file + */ +LUA_FUNCTION_DEF(config, get_tld_path); + +/*** + * @method rspamd_config:get_dns_max_requests() + * Returns limit of DNS requests per task + * @return {number} number of dns requests allowed + */ +LUA_FUNCTION_DEF(config, get_dns_max_requests); + +/*** + * @method rspamd_config:get_dns_timeout() + * Returns timeout for a DNS request + * @return {number} DNS timeout in second or 0 if not defined + */ +LUA_FUNCTION_DEF(config, get_dns_timeout); + +static const struct luaL_reg configlib_m[] = { + LUA_INTERFACE_DEF(config, get_module_opt), + LUA_INTERFACE_DEF(config, get_mempool), + LUA_INTERFACE_DEF(config, get_resolver), + LUA_INTERFACE_DEF(config, get_all_opt), + LUA_INTERFACE_DEF(config, get_ucl), + LUA_INTERFACE_DEF(config, add_radix_map), + LUA_INTERFACE_DEF(config, radix_from_config), + LUA_INTERFACE_DEF(config, radix_from_ucl), + LUA_INTERFACE_DEF(config, add_hash_map), + LUA_INTERFACE_DEF(config, add_kv_map), + LUA_INTERFACE_DEF(config, add_map), + LUA_INTERFACE_DEF(config, get_maps), + LUA_INTERFACE_DEF(config, get_classifier), + LUA_INTERFACE_DEF(config, register_symbol), + LUA_INTERFACE_DEF(config, register_symbols), + LUA_INTERFACE_DEF(config, register_virtual_symbol), + LUA_INTERFACE_DEF(config, register_callback_symbol), + LUA_INTERFACE_DEF(config, register_callback_symbol_priority), + LUA_INTERFACE_DEF(config, register_dependency), + LUA_INTERFACE_DEF(config, register_settings_id), + LUA_INTERFACE_DEF(config, get_symbol_flags), + LUA_INTERFACE_DEF(config, set_metric_symbol), + {"set_symbol", lua_config_set_metric_symbol}, + LUA_INTERFACE_DEF(config, set_metric_action), + {"set_action", lua_config_set_metric_action}, + {"get_metric_symbol", lua_config_get_symbol}, + LUA_INTERFACE_DEF(config, get_metric_action), + {"get_action", lua_config_get_metric_action}, + LUA_INTERFACE_DEF(config, get_all_actions), + LUA_INTERFACE_DEF(config, add_composite), + LUA_INTERFACE_DEF(config, register_module_option), + LUA_INTERFACE_DEF(config, register_pre_filter), + LUA_INTERFACE_DEF(config, register_post_filter), + LUA_INTERFACE_DEF(config, get_api_version), + LUA_INTERFACE_DEF(config, get_key), + LUA_INTERFACE_DEF(config, add_condition), + LUA_INTERFACE_DEF(config, enable_symbol), + LUA_INTERFACE_DEF(config, disable_symbol), + LUA_INTERFACE_DEF(config, register_regexp), + LUA_INTERFACE_DEF(config, replace_regexp), + LUA_INTERFACE_DEF(config, register_worker_script), + LUA_INTERFACE_DEF(config, register_re_selector), + LUA_INTERFACE_DEF(config, add_on_load), + LUA_INTERFACE_DEF(config, add_periodic), + LUA_INTERFACE_DEF(config, add_post_init), + LUA_INTERFACE_DEF(config, add_config_unload), + LUA_INTERFACE_DEF(config, get_symbols_count), + LUA_INTERFACE_DEF(config, get_symbols_cksum), + LUA_INTERFACE_DEF(config, get_symbols_counters), + {"get_symbols_scores", lua_config_get_symbols}, + LUA_INTERFACE_DEF(config, get_symbols), + LUA_INTERFACE_DEF(config, get_symbol), + LUA_INTERFACE_DEF(config, get_groups), + LUA_INTERFACE_DEF(config, get_symbol_callback), + LUA_INTERFACE_DEF(config, set_symbol_callback), + LUA_INTERFACE_DEF(config, get_symbol_stat), + LUA_INTERFACE_DEF(config, get_symbol_parent), + LUA_INTERFACE_DEF(config, get_group_symbols), + LUA_INTERFACE_DEF(config, register_finish_script), + LUA_INTERFACE_DEF(config, register_monitored), + LUA_INTERFACE_DEF(config, add_doc), + LUA_INTERFACE_DEF(config, add_example), + LUA_INTERFACE_DEF(config, set_peak_cb), + LUA_INTERFACE_DEF(config, get_cpu_flags), + LUA_INTERFACE_DEF(config, has_torch), + LUA_INTERFACE_DEF(config, experimental_enabled), + LUA_INTERFACE_DEF(config, load_ucl), + LUA_INTERFACE_DEF(config, parse_rcl), + LUA_INTERFACE_DEF(config, init_modules), + LUA_INTERFACE_DEF(config, init_subsystem), + LUA_INTERFACE_DEF(config, get_tld_path), + LUA_INTERFACE_DEF(config, get_dns_max_requests), + LUA_INTERFACE_DEF(config, get_dns_timeout), + {"__tostring", rspamd_lua_class_tostring}, + {"__newindex", lua_config_newindex}, + {NULL, NULL}}; + +LUA_FUNCTION_DEF(monitored, alive); +LUA_FUNCTION_DEF(monitored, latency); +LUA_FUNCTION_DEF(monitored, offline); +LUA_FUNCTION_DEF(monitored, total_offline); + +static const struct luaL_reg monitoredlib_m[] = { + LUA_INTERFACE_DEF(monitored, alive), + LUA_INTERFACE_DEF(monitored, latency), + LUA_INTERFACE_DEF(monitored, offline), + LUA_INTERFACE_DEF(monitored, total_offline), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static const guint64 rspamd_lua_callback_magic = 0x32c118af1e3263c7ULL; + +struct rspamd_config * +lua_check_config(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{config}"); + luaL_argcheck(L, ud != NULL, pos, "'config' expected"); + return ud ? *((struct rspamd_config **) ud) : NULL; +} + +static struct rspamd_monitored * +lua_check_monitored(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{monitored}"); + luaL_argcheck(L, ud != NULL, pos, "'monitored' expected"); + return ud ? *((struct rspamd_monitored **) ud) : NULL; +} + +/*** Config functions ***/ +static gint +lua_config_get_api_version(lua_State *L) +{ + msg_warn("get_api_version is deprecated, do not use it"); + lua_pushnumber(L, 100); + + return 1; +} + +static gint +lua_config_get_module_opt(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *mname, *optname; + const ucl_object_t *obj; + + if (cfg) { + mname = luaL_checkstring(L, 2); + optname = luaL_checkstring(L, 3); + + if (mname && optname) { + obj = rspamd_config_get_module_opt(cfg, mname, optname); + if (obj) { + return ucl_object_push_lua(L, obj, TRUE); + } + } + } + lua_pushnil(L); + return 1; +} + +static int +lua_config_get_mempool(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_mempool_t **ppool; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + ppool = lua_newuserdata(L, sizeof(rspamd_mempool_t *)); + rspamd_lua_setclass(L, "rspamd{mempool}", -1); + *ppool = cfg->cfg_pool; + } + else { + lua_pushnil(L); + } + return 1; +} + +static int +lua_config_get_resolver(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_dns_resolver **pres; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL && cfg->dns_resolver) { + pres = lua_newuserdata(L, sizeof(*pres)); + rspamd_lua_setclass(L, "rspamd{resolver}", -1); + *pres = cfg->dns_resolver; + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_config_get_all_opt(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *mname; + const ucl_object_t *obj, *cur, *cur_elt; + ucl_object_iter_t it = NULL; + gint i; + + if (cfg) { + mname = luaL_checkstring(L, 2); + + if (mname) { + obj = ucl_obj_get_key(cfg->cfg_ucl_obj, mname); + /* Flatten object */ + if (obj != NULL && (ucl_object_type(obj) == UCL_OBJECT || + ucl_object_type(obj) == UCL_ARRAY)) { + + lua_newtable(L); + it = ucl_object_iterate_new(obj); + + LL_FOREACH(obj, cur) + { + it = ucl_object_iterate_reset(it, cur); + + while ((cur_elt = ucl_object_iterate_safe(it, true))) { + lua_pushstring(L, ucl_object_key(cur_elt)); + ucl_object_push_lua(L, cur_elt, true); + lua_settable(L, -3); + } + } + + ucl_object_iterate_free(it); + + return 1; + } + else if (obj != NULL) { + lua_newtable(L); + i = 1; + + LL_FOREACH(obj, cur) + { + lua_pushinteger(L, i++); + ucl_object_push_lua(L, cur, true); + lua_settable(L, -3); + } + + return 1; + } + } + } + lua_pushnil(L); + + return 1; +} + +struct rspamd_lua_cached_config { + lua_State *L; + gint ref; +}; + +static void +lua_config_ucl_dtor(gpointer p) +{ + struct rspamd_lua_cached_config *cached = p; + + luaL_unref(cached->L, LUA_REGISTRYINDEX, cached->ref); +} + +static gint +lua_config_get_ucl(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_lua_cached_config *cached; + + if (cfg) { + cached = rspamd_mempool_get_variable(cfg->cfg_pool, "ucl_cached"); + + if (cached) { + lua_rawgeti(L, LUA_REGISTRYINDEX, cached->ref); + } + else { + if (cfg->cfg_ucl_obj) { + ucl_object_push_lua(L, cfg->cfg_ucl_obj, true); + lua_pushvalue(L, -1); + cached = rspamd_mempool_alloc(cfg->cfg_pool, sizeof(*cached)); + cached->L = L; + cached->ref = luaL_ref(L, LUA_REGISTRYINDEX); + rspamd_mempool_set_variable(cfg->cfg_pool, "ucl_cached", + cached, lua_config_ucl_dtor); + } + else { + lua_pushnil(L); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +static gint +lua_config_get_classifier(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_classifier_config *clc = NULL, **pclc = NULL; + const gchar *name; + GList *cur; + + if (cfg) { + name = luaL_checkstring(L, 2); + + cur = g_list_first(cfg->classifiers); + while (cur) { + clc = cur->data; + if (g_ascii_strcasecmp(clc->name, name) == 0) { + pclc = &clc; + break; + } + cur = g_list_next(cur); + } + if (pclc) { + pclc = lua_newuserdata(L, + sizeof(struct rspamd_classifier_config *)); + rspamd_lua_setclass(L, "rspamd{classifier}", -1); + *pclc = clc; + return 1; + } + } + + lua_pushnil(L); + return 1; +} + +struct lua_callback_data { + guint64 magic; + lua_State *L; + gchar *symbol; + + union { + gchar *name; + gint ref; + } callback; + gboolean cb_is_ref; + + /* Dynamic data */ + gint stack_level; + gint order; + struct rspamd_symcache_dynamic_item *item; +}; + +/* + * Unref symbol if it is local reference + */ +static void +lua_destroy_cfg_symbol(gpointer ud) +{ + struct lua_callback_data *cd = ud; + + /* Unref callback */ + if (cd->cb_is_ref) { + luaL_unref(cd->L, LUA_REGISTRYINDEX, cd->callback.ref); + } +} + +static gint +lua_config_register_module_option(lua_State *L) +{ + return 0; +} + +static gint +rspamd_compare_order_func(gconstpointer a, gconstpointer b) +{ + const struct lua_callback_data *cb1 = a, *cb2 = b; + + /* order of call goes from lower to higher */ + return cb2->order - cb1->order; +} + +static void +lua_metric_symbol_callback(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + gpointer ud) +{ + struct lua_callback_data *cd = ud; + struct rspamd_task **ptask; + gint level = lua_gettop(cd->L), nresults, err_idx, ret; + lua_State *L = cd->L; + struct rspamd_symbol_result *s; + + cd->item = item; + rspamd_symcache_item_async_inc(task, item, "lua symbol"); + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + level++; + + if (cd->cb_is_ref) { + lua_rawgeti(L, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal(L, cd->callback.name); + } + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + rspamd_lua_setclass(L, "rspamd{task}", -1); + *ptask = task; + + if ((ret = lua_pcall(L, 1, LUA_MULTRET, err_idx)) != 0) { + msg_err_task("call to (%s) failed (%d): %s", cd->symbol, ret, + lua_tostring(L, -1)); + lua_settop(L, err_idx); /* Not -1 here, as err_func is popped below */ + } + else { + nresults = lua_gettop(L) - level; + + if (nresults >= 1) { + /* Function returned boolean, so maybe we need to insert result? */ + gint res = 0; + gint i; + gdouble flag = 1.0; + gint type; + + type = lua_type(cd->L, level + 1); + + if (type == LUA_TBOOLEAN) { + res = lua_toboolean(L, level + 1); + } + else if (type == LUA_TNUMBER) { + res = lua_tonumber(L, level + 1); + } + else if (type == LUA_TNIL) { + /* Can happen sometimes... */ + res = FALSE; + } + else { + /* Something bogus has been returned, so we should log it */ + msg_err_task("invalid return value for %s: %s", + cd->symbol, lua_typename(L, type)); + res = FALSE; + } + + if (res) { + gint first_opt = 2; + + if (lua_type(L, level + 2) == LUA_TNUMBER) { + flag = lua_tonumber(L, level + 2); + /* Shift opt index */ + first_opt = 3; + } + else { + flag = res; + } + + s = rspamd_task_insert_result(task, cd->symbol, flag, NULL); + + if (s) { + guint last_pos = lua_gettop(L); + + for (i = level + first_opt; i <= last_pos; i++) { + if (lua_type(L, i) == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, i, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (lua_type(L, i) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, i); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + } + else if (lua_type(L, i) == LUA_TTABLE) { + gsize objlen = rspamd_lua_table_size(L, i); + + for (guint j = 1; j <= objlen; j++) { + lua_rawgeti(L, i, j); + + if (lua_type(L, -1) == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, -1, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, -1); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + } + + lua_pop(L, 1); + } + } + } + } + } + + lua_pop(L, nresults); + } + } + + lua_pop(L, 1); /* Error function */ + rspamd_symcache_item_async_dec_check(task, cd->item, "lua symbol"); + g_assert(lua_gettop(L) == level - 1); +} + +static void lua_metric_symbol_callback_return(struct thread_entry *thread_entry, + int ret); + +static void lua_metric_symbol_callback_error(struct thread_entry *thread_entry, + int ret, + const char *msg); + +static void +lua_metric_symbol_callback_coro(struct rspamd_task *task, + struct rspamd_symcache_dynamic_item *item, + gpointer ud) +{ + struct lua_callback_data *cd = ud; + struct rspamd_task **ptask; + struct thread_entry *thread_entry; + + cd->item = item; + rspamd_symcache_item_async_inc(task, item, "lua coro symbol"); + thread_entry = lua_thread_pool_get_for_task(task); + + g_assert(thread_entry->cd == NULL); + thread_entry->cd = cd; + + lua_State *thread = thread_entry->lua_state; + cd->stack_level = lua_gettop(thread); + + if (cd->cb_is_ref) { + lua_rawgeti(thread, LUA_REGISTRYINDEX, cd->callback.ref); + } + else { + lua_getglobal(thread, cd->callback.name); + } + + ptask = lua_newuserdata(thread, sizeof(struct rspamd_task *)); + rspamd_lua_setclass(thread, "rspamd{task}", -1); + *ptask = task; + + thread_entry->finish_callback = lua_metric_symbol_callback_return; + thread_entry->error_callback = lua_metric_symbol_callback_error; + + lua_thread_call(thread_entry, 1); +} + +static void +lua_metric_symbol_callback_error(struct thread_entry *thread_entry, + int ret, + const char *msg) +{ + struct lua_callback_data *cd = thread_entry->cd; + struct rspamd_task *task = thread_entry->task; + msg_err_task("call to coroutine (%s) failed (%d): %s", cd->symbol, ret, msg); + + rspamd_symcache_item_async_dec_check(task, cd->item, "lua coro symbol"); +} + +static void +lua_metric_symbol_callback_return(struct thread_entry *thread_entry, int ret) +{ + struct lua_callback_data *cd = thread_entry->cd; + struct rspamd_task *task = thread_entry->task; + int nresults; + struct rspamd_symbol_result *s; + + (void) ret; + + lua_State *L = thread_entry->lua_state; + + nresults = lua_gettop(L) - cd->stack_level; + + if (nresults >= 1) { + /* Function returned boolean, so maybe we need to insert result? */ + gint res = 0; + gint i; + gdouble flag = 1.0; + gint type; + + type = lua_type(L, cd->stack_level + 1); + + if (type == LUA_TBOOLEAN) { + res = lua_toboolean(L, cd->stack_level + 1); + } + else if (type == LUA_TFUNCTION) { + g_assert_not_reached(); + } + else { + res = lua_tonumber(L, cd->stack_level + 1); + } + + if (res) { + gint first_opt = 2; + + if (lua_type(L, cd->stack_level + 2) == LUA_TNUMBER) { + flag = lua_tonumber(L, cd->stack_level + 2); + /* Shift opt index */ + first_opt = 3; + } + else { + flag = res; + } + + s = rspamd_task_insert_result(task, cd->symbol, flag, NULL); + + if (s) { + guint last_pos = lua_gettop(L); + + for (i = cd->stack_level + first_opt; i <= last_pos; i++) { + if (lua_type(L, i) == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, i, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (lua_type(L, i) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, i); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + } + else if (lua_type(L, i) == LUA_TTABLE) { + gsize objlen = rspamd_lua_table_size(L, i); + + for (guint j = 1; j <= objlen; j++) { + lua_rawgeti(L, i, j); + + if (lua_type(L, -1) == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, -1, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, -1); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + } + + lua_pop(L, 1); + } + } + } + } + } + + lua_pop(L, nresults); + } + + g_assert(lua_gettop(L) == cd->stack_level); /* we properly cleaned up the stack */ + + cd->stack_level = 0; + rspamd_symcache_item_async_dec_check(task, cd->item, "lua coro symbol"); +} + +static GArray * +rspamd_process_id_list(const gchar *entries) +{ + gchar **sym_elts; + GArray *ret; + + sym_elts = g_strsplit_set(entries, ",;", -1); + guint nids = g_strv_length(sym_elts); + ret = g_array_sized_new(FALSE, FALSE, sizeof(guint32), nids); + + for (guint i = 0; i < nids; i++) { + guint32 v = rspamd_config_name_to_id(sym_elts[i], strlen(sym_elts[i])); + g_array_append_val(ret, v); + } + + g_strfreev(sym_elts); + + return ret; +} + +static gint +rspamd_register_symbol_fromlua(lua_State *L, + struct rspamd_config *cfg, + const gchar *name, + gint ref, + gdouble weight, + gint priority, + enum rspamd_symbol_type type, + gint parent, + GArray *allowed_ids, + GArray *forbidden_ids, + gboolean optional) +{ + struct lua_callback_data *cd; + gint ret = -1; + + if (priority == 0 && weight < 0) { + priority = 1; + } + + if ((ret = rspamd_symcache_find_symbol(cfg->cache, name)) != -1) { + if (optional) { + msg_debug_config("duplicate symbol: %s, skip registering", name); + + return ret; + } + else { + msg_err_config("duplicate symbol: %s, skip registering", name); + + return -1; + } + } + + if (allowed_ids && !(type & SYMBOL_TYPE_EXPLICIT_DISABLE)) { + /* Mark symbol as explicit allow */ + msg_info_config("mark symbol %s as explicit enable as its execution is" + "allowed merely on specific settings ids", + name); + type |= SYMBOL_TYPE_EXPLICIT_ENABLE; + } + + if (ref != -1) { + cd = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(struct lua_callback_data)); + cd->magic = rspamd_lua_callback_magic; + cd->cb_is_ref = TRUE; + cd->callback.ref = ref; + cd->L = L; + + if (name) { + cd->symbol = rspamd_mempool_strdup(cfg->cfg_pool, name); + } + + if (type & SYMBOL_TYPE_USE_CORO) { + ret = rspamd_symcache_add_symbol(cfg->cache, + name, + priority, + lua_metric_symbol_callback_coro, + cd, + type, + parent); + } + else { + ret = rspamd_symcache_add_symbol(cfg->cache, + name, + priority, + lua_metric_symbol_callback, + cd, + type, + parent); + } + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) lua_destroy_cfg_symbol, + cd); + } + else { + /* No callback */ + ret = rspamd_symcache_add_symbol(cfg->cache, + name, + priority, + NULL, + NULL, + type, + parent); + } + + if (allowed_ids) { + rspamd_symcache_set_allowed_settings_ids(cfg->cache, name, + &g_array_index(allowed_ids, guint32, 0), allowed_ids->len); + } + + if (forbidden_ids) { + rspamd_symcache_set_forbidden_settings_ids(cfg->cache, name, + &g_array_index(forbidden_ids, guint32, 0), forbidden_ids->len); + } + + return ret; +} + +static gint +lua_config_register_post_filter(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + gint order = 0, cbref, ret; + + if (cfg) { + if (lua_type(L, 3) == LUA_TNUMBER) { + order = lua_tonumber(L, 3); + } + + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + /* Get a reference */ + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + else { + return luaL_error(L, "invalid type for callback: %s", + lua_typename(L, lua_type(L, 2))); + } + + msg_warn_config("register_post_filter function is deprecated, " + "use register_symbol instead"); + + ret = rspamd_register_symbol_fromlua(L, + cfg, + NULL, + cbref, + 1.0, + order, + SYMBOL_TYPE_POSTFILTER | SYMBOL_TYPE_CALLBACK, + -1, + NULL, NULL, + FALSE); + + lua_pushboolean(L, ret); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_register_pre_filter(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + gint order = 0, cbref, ret; + + if (cfg) { + if (lua_type(L, 3) == LUA_TNUMBER) { + order = lua_tonumber(L, 3); + } + + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + /* Get a reference */ + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + else { + return luaL_error(L, "invalid type for callback: %s", + lua_typename(L, lua_type(L, 2))); + } + + msg_warn_config("register_pre_filter function is deprecated, " + "use register_symbol instead"); + + ret = rspamd_register_symbol_fromlua(L, + cfg, + NULL, + cbref, + 1.0, + order, + SYMBOL_TYPE_PREFILTER | SYMBOL_TYPE_CALLBACK, + -1, + NULL, NULL, + FALSE); + + lua_pushboolean(L, ret); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_key(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name; + size_t namelen; + const ucl_object_t *val; + + name = luaL_checklstring(L, 2, &namelen); + if (name && cfg) { + val = ucl_object_lookup_len(cfg->cfg_ucl_obj, name, namelen); + if (val != NULL) { + ucl_object_push_lua(L, val, val->type != UCL_ARRAY); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static guint +lua_parse_symbol_flags(const gchar *str) +{ + guint ret = 0; + + if (str) { + if (strstr(str, "fine") != NULL) { + ret |= SYMBOL_TYPE_FINE; + } + if (strstr(str, "nice") != NULL) { + ret |= SYMBOL_TYPE_FINE; + } + if (strstr(str, "empty") != NULL) { + ret |= SYMBOL_TYPE_EMPTY; + } + if (strstr(str, "skip") != NULL) { + ret |= SYMBOL_TYPE_SKIPPED; + } + if (strstr(str, "nostat") != NULL) { + ret |= SYMBOL_TYPE_NOSTAT; + } + if (strstr(str, "idempotent") != NULL) { + ret |= SYMBOL_TYPE_IDEMPOTENT; + } + if (strstr(str, "trivial") != NULL) { + ret |= SYMBOL_TYPE_TRIVIAL; + } + if (strstr(str, "ghost") != NULL) { + ret |= SYMBOL_TYPE_GHOST; + } + if (strstr(str, "mime") != NULL) { + ret |= SYMBOL_TYPE_MIME_ONLY; + } + if (strstr(str, "ignore_passthrough") != NULL) { + ret |= SYMBOL_TYPE_IGNORE_PASSTHROUGH; + } + if (strstr(str, "explicit_disable") != NULL) { + ret |= SYMBOL_TYPE_EXPLICIT_DISABLE; + } + if (strstr(str, "explicit_enable") != NULL) { + ret |= SYMBOL_TYPE_EXPLICIT_ENABLE; + } + if (strstr(str, "coro") != NULL) { + ret |= SYMBOL_TYPE_USE_CORO; + } + } + + return ret; +} + +static guint +lua_parse_symbol_type(const gchar *str) +{ + guint ret = SYMBOL_TYPE_NORMAL; + gchar **vec; + guint i, l; + + if (str) { + vec = g_strsplit_set(str, ",;", -1); + + if (vec) { + l = g_strv_length(vec); + + for (i = 0; i < l; i++) { + str = vec[i]; + + /* TODO: total shit, rework some day */ + if (g_ascii_strcasecmp(str, "virtual") == 0) { + ret |= SYMBOL_TYPE_VIRTUAL; + ret &= ~SYMBOL_TYPE_NORMAL; + ret &= ~SYMBOL_TYPE_CALLBACK; + } + else if (g_ascii_strcasecmp(str, "callback") == 0) { + ret |= SYMBOL_TYPE_CALLBACK; + ret &= ~SYMBOL_TYPE_NORMAL; + ret &= ~SYMBOL_TYPE_VIRTUAL; + } + else if (g_ascii_strcasecmp(str, "normal") == 0) { + ret |= SYMBOL_TYPE_NORMAL; + ret &= ~SYMBOL_TYPE_CALLBACK; + ret &= ~SYMBOL_TYPE_VIRTUAL; + } + else if (g_ascii_strcasecmp(str, "prefilter") == 0) { + ret |= SYMBOL_TYPE_PREFILTER | SYMBOL_TYPE_GHOST; + } + else if (g_ascii_strcasecmp(str, "postfilter") == 0) { + ret |= SYMBOL_TYPE_POSTFILTER | SYMBOL_TYPE_GHOST; + } + else if (g_ascii_strcasecmp(str, "connfilter") == 0 || + g_ascii_strcasecmp(str, "conn_filter") == 0) { + ret |= SYMBOL_TYPE_CONNFILTER | SYMBOL_TYPE_GHOST; + } + else if (g_ascii_strcasecmp(str, "idempotent") == 0) { + ret |= SYMBOL_TYPE_GHOST | + SYMBOL_TYPE_IDEMPOTENT | SYMBOL_TYPE_CALLBACK; + } + else { + gint fl = 0; + + fl = lua_parse_symbol_flags(str); + + if (fl == 0) { + msg_warn("bad type: %s", str); + } + else { + ret |= fl; + } + } + } + + g_strfreev(vec); + } + } + + return ret; +} + +enum lua_push_symbol_flags_opts { + LUA_SYMOPT_FLAG_CREATE_ARRAY = 1u << 0u, + LUA_SYMOPT_FLAG_CREATE_MAP = 1u << 1u, + LUA_SYMOPT_FLAG_USE_MAP = 1u << 2u, + LUA_SYMOPT_FLAG_USE_ARRAY = 1u << 3u, +}; + +#define LUA_SYMOPT_IS_ARRAY(f) ((f) & (LUA_SYMOPT_FLAG_CREATE_ARRAY | LUA_SYMOPT_FLAG_USE_ARRAY)) +#define LUA_SYMOPT_IS_CREATE(f) ((f) & (LUA_SYMOPT_FLAG_CREATE_ARRAY | LUA_SYMOPT_FLAG_CREATE_MAP)) +#define LUA_OPTION_PUSH(nm) \ + do { \ + if (LUA_SYMOPT_IS_ARRAY(fl)) { \ + lua_pushstring(L, #nm); \ + lua_rawseti(L, -2, i++); \ + } \ + else { \ + lua_pushboolean(L, true); \ + lua_setfield(L, -2, #nm); \ + } \ + } while (0) + +static void +lua_push_symbol_flags(lua_State *L, guint flags, enum lua_push_symbol_flags_opts fl) +{ + guint i = 1; + + if (LUA_SYMOPT_IS_CREATE(fl)) { + lua_newtable(L); + } + + if (flags & SYMBOL_TYPE_FINE) { + LUA_OPTION_PUSH(fine); + } + + if (flags & SYMBOL_TYPE_EMPTY) { + LUA_OPTION_PUSH(empty); + } + + if (flags & SYMBOL_TYPE_EXPLICIT_DISABLE) { + LUA_OPTION_PUSH(explicit_disable); + } + + if (flags & SYMBOL_TYPE_EXPLICIT_ENABLE) { + LUA_OPTION_PUSH(explicit_enable); + } + + if (flags & SYMBOL_TYPE_IGNORE_PASSTHROUGH) { + LUA_OPTION_PUSH(ignore_passthrough); + } + + if (flags & SYMBOL_TYPE_NOSTAT) { + LUA_OPTION_PUSH(nostat); + } + + if (flags & SYMBOL_TYPE_IDEMPOTENT) { + LUA_OPTION_PUSH(idempotent); + } + + if (flags & SYMBOL_TYPE_MIME_ONLY) { + LUA_OPTION_PUSH(mime); + } + + if (flags & SYMBOL_TYPE_TRIVIAL) { + LUA_OPTION_PUSH(trivial); + } + + if (flags & SYMBOL_TYPE_SKIPPED) { + LUA_OPTION_PUSH(skip); + } + + if (flags & SYMBOL_TYPE_COMPOSITE) { + LUA_OPTION_PUSH(composite); + } +} + +static gint +lua_config_get_symbol_flags(lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name = luaL_checkstring(L, 2); + guint flags; + + if (cfg && name) { + flags = rspamd_symcache_get_symbol_flags(cfg->cache, + name); + + if (flags != 0) { + lua_push_symbol_flags(L, flags, LUA_SYMOPT_FLAG_CREATE_ARRAY); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_register_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name = NULL, *type_str = NULL, + *description = NULL, *group = NULL; + double weight = 0, score = NAN, parent_float = NAN; + gboolean one_shot = FALSE; + gint ret = -1, cbref = -1; + guint type = 0, flags = 0; + gint64 parent = 0, priority = 0, nshots = 0; + GArray *allowed_ids = NULL, *forbidden_ids = NULL; + GError *err = NULL; + int prev_top = lua_gettop(L); + + if (cfg) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "name=S;weight=N;callback=F;type=S;priority=I;parent=D;" + "score=D;description=S;group=S;one_shot=B;nshots=I", + &name, &weight, &cbref, &type_str, + &priority, &parent_float, + &score, &description, &group, &one_shot, &nshots)) { + msg_err_config("bad arguments: %e", err); + g_error_free(err); + lua_settop(L, prev_top); + + return luaL_error(L, "invalid arguments"); + } + + /* Deal with flags and ids */ + lua_pushstring(L, "flags"); + lua_gettable(L, 2); + if (lua_type(L, -1) == LUA_TSTRING) { + flags = lua_parse_symbol_flags(lua_tostring(L, -1)); + } + else if (lua_type(L, -1) == LUA_TTABLE) { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + flags |= lua_parse_symbol_flags(lua_tostring(L, -1)); + } + } + lua_pop(L, 1); /* Clean flags */ + + lua_pushstring(L, "allowed_ids"); + lua_gettable(L, 2); + if (lua_type(L, -1) == LUA_TSTRING) { + allowed_ids = rspamd_process_id_list(lua_tostring(L, -1)); + } + else if (lua_type(L, -1) == LUA_TTABLE) { + allowed_ids = g_array_sized_new(FALSE, FALSE, sizeof(guint32), + rspamd_lua_table_size(L, -1)); + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + guint32 v = lua_tointeger(L, -1); + g_array_append_val(allowed_ids, v); + } + } + lua_pop(L, 1); + + lua_pushstring(L, "forbidden_ids"); + lua_gettable(L, 2); + if (lua_type(L, -1) == LUA_TSTRING) { + forbidden_ids = rspamd_process_id_list(lua_tostring(L, -1)); + } + else if (lua_type(L, -1) == LUA_TTABLE) { + forbidden_ids = g_array_sized_new(FALSE, FALSE, sizeof(guint32), + rspamd_lua_table_size(L, -1)); + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + guint32 v = lua_tointeger(L, -1); + g_array_append_val(forbidden_ids, v); + } + } + lua_pop(L, 1); + + if (nshots == 0) { + nshots = cfg->default_max_shots; + } + + type = lua_parse_symbol_type(type_str); + + if (!name && !(type & SYMBOL_TYPE_CALLBACK)) { + lua_settop(L, prev_top); + return luaL_error(L, "no symbol name but type is not callback"); + } + else if (!(type & SYMBOL_TYPE_VIRTUAL) && cbref == -1) { + lua_settop(L, prev_top); + return luaL_error(L, "no callback for symbol %s", name); + } + + if (isnan(parent_float)) { + parent = -1; + } + else { + parent = parent_float; + } + + ret = rspamd_register_symbol_fromlua(L, + cfg, + name, + cbref, + weight == 0 ? 1.0 : weight, + priority, + type | flags, + parent, + allowed_ids, forbidden_ids, + FALSE); + + if (allowed_ids) { + g_array_free(allowed_ids, TRUE); + } + + if (forbidden_ids) { + g_array_free(forbidden_ids, TRUE); + } + + if (ret != -1) { + if (!isnan(score) || group) { + if (one_shot) { + nshots = 1; + } + + rspamd_config_add_symbol(cfg, name, + score, description, group, flags, + 0, nshots); + + lua_pushstring(L, "groups"); + lua_gettable(L, 2); + + if (lua_istable(L, -1)) { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + if (lua_isstring(L, -1)) { + rspamd_config_add_symbol_group(cfg, name, + lua_tostring(L, -1)); + } + else { + lua_settop(L, prev_top); + return luaL_error(L, "invalid groups element"); + } + } + } + + lua_pop(L, 1); + } + + lua_pushstring(L, "augmentations"); + lua_gettable(L, 2); + + if (lua_type(L, -1) == LUA_TTABLE) { + int tbl_idx = lua_gettop(L); + for (lua_pushnil(L); lua_next(L, tbl_idx); lua_pop(L, 1)) { + size_t len; + const char *augmentation = lua_tolstring(L, -1, &len), *eqsign_pos; + + /* Find `=` symbol and use it as a separator */ + eqsign_pos = memchr(augmentation, '=', len); + if (eqsign_pos != NULL && eqsign_pos + 1 < augmentation + len) { + rspamd_ftok_t tok; + + tok.begin = augmentation; + tok.len = eqsign_pos - augmentation; + char *augentation_name = rspamd_ftokdup(&tok); + + tok.begin = eqsign_pos + 1; + tok.len = (augmentation + len) - tok.begin; + + char *augmentation_value = rspamd_ftokdup(&tok); + + if (!rspamd_symcache_add_symbol_augmentation(cfg->cache, ret, + augentation_name, augmentation_value)) { + lua_settop(L, prev_top); + g_free(augmentation_value); + g_free(augentation_name); + + return luaL_error(L, "unknown or invalid augmentation %s in symbol %s", + augmentation, name); + } + + g_free(augmentation_value); + g_free(augentation_name); + } + else { + /* Just a value */ + if (!rspamd_symcache_add_symbol_augmentation(cfg->cache, ret, + augmentation, NULL)) { + lua_settop(L, prev_top); + + return luaL_error(L, "unknown augmentation %s in symbol %s", + augmentation, name); + } + } + } + } + } + } + else { + lua_settop(L, prev_top); + + return luaL_error(L, "invalid arguments"); + } + + lua_settop(L, prev_top); + lua_pushinteger(L, ret); + + return 1; +} + +static gint +lua_config_register_symbols(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + gint i, top, idx, ret = -1; + const gchar *sym; + gdouble weight = 1.0; + + if (lua_gettop(L) < 3) { + if (cfg) { + msg_err_config("not enough arguments to register a function"); + } + + lua_error(L); + + return 0; + } + if (cfg) { + if (lua_type(L, 2) == LUA_TSTRING) { + lua_getglobal(L, luaL_checkstring(L, 2)); + } + else { + lua_pushvalue(L, 2); + } + idx = luaL_ref(L, LUA_REGISTRYINDEX); + + if (lua_type(L, 3) == LUA_TNUMBER) { + weight = lua_tonumber(L, 3); + top = 4; + } + else { + top = 3; + } + sym = luaL_checkstring(L, top++); + ret = rspamd_register_symbol_fromlua(L, + cfg, + sym, + idx, + weight, + 0, + SYMBOL_TYPE_CALLBACK, + -1, + NULL, NULL, + FALSE); + + for (i = top; i <= lua_gettop(L); i++) { + if (lua_type(L, i) == LUA_TTABLE) { + lua_pushvalue(L, i); + lua_pushnil(L); + while (lua_next(L, -2)) { + lua_pushvalue(L, -2); + sym = luaL_checkstring(L, -2); + rspamd_symcache_add_symbol(cfg->cache, sym, + 0, NULL, NULL, + SYMBOL_TYPE_VIRTUAL, ret); + lua_pop(L, 2); + } + lua_pop(L, 1); + } + else if (lua_type(L, i) == LUA_TSTRING) { + sym = luaL_checkstring(L, i); + rspamd_symcache_add_symbol(cfg->cache, sym, + 0, NULL, NULL, + SYMBOL_TYPE_VIRTUAL, ret); + } + } + } + + lua_pushinteger(L, ret); + + return 1; +} + +static gint +lua_config_register_virtual_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name; + double weight; + gint ret = -1, parent = -1; + + if (cfg) { + name = luaL_checkstring(L, 2); + weight = luaL_checknumber(L, 3); + + if (lua_gettop(L) > 3) { + parent = lua_tonumber(L, 4); + } + + if (name) { + ret = rspamd_symcache_add_symbol(cfg->cache, name, + weight > 0 ? 0 : -1, NULL, NULL, + SYMBOL_TYPE_VIRTUAL, parent); + } + } + + lua_pushinteger(L, ret); + + return 1; +} + +static gint +lua_config_register_callback_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name = NULL; + double weight; + gint ret = -1, top = 2; + + if (cfg) { + if (lua_type(L, 2) == LUA_TSTRING) { + /* Legacy syntax */ + name = luaL_checkstring(L, 2); + top++; + } + + weight = luaL_checknumber(L, top); + + if (lua_type(L, top + 1) == LUA_TSTRING) { + lua_getglobal(L, luaL_checkstring(L, top + 1)); + } + else { + lua_pushvalue(L, top + 1); + } + ret = rspamd_register_symbol_fromlua(L, + cfg, + name, + luaL_ref(L, LUA_REGISTRYINDEX), + weight, + 0, + SYMBOL_TYPE_CALLBACK, + -1, + NULL, NULL, + FALSE); + } + + lua_pushinteger(L, ret); + + return 1; +} + +static gint +lua_config_register_callback_symbol_priority(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name = NULL; + double weight; + gint priority, ret = -1, top = 2; + + if (cfg) { + if (lua_type(L, 2) == LUA_TSTRING) { + /* Legacy syntax */ + name = luaL_checkstring(L, 2); + top++; + } + + weight = luaL_checknumber(L, top); + priority = luaL_checknumber(L, top + 1); + + if (lua_type(L, top + 2) == LUA_TSTRING) { + lua_getglobal(L, luaL_checkstring(L, top + 2)); + } + else { + lua_pushvalue(L, top + 2); + } + + ret = rspamd_register_symbol_fromlua(L, + cfg, + name, + luaL_ref(L, LUA_REGISTRYINDEX), + weight, + priority, + SYMBOL_TYPE_CALLBACK, + -1, + NULL, NULL, + FALSE); + } + + lua_pushinteger(L, ret); + + return 1; +} + + +static gint +lua_config_register_dependency(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *parent = NULL, *child = NULL; + gint child_id; + + if (cfg == NULL) { + lua_error(L); + return 0; + } + + if (lua_type(L, 2) == LUA_TNUMBER) { + child_id = luaL_checknumber(L, 2); + parent = luaL_checkstring(L, 3); + + return luaL_error(L, "calling for obsolete method to register deps for symbol %d->%s", + child_id, parent); + } + else { + child = luaL_checkstring(L, 2); + parent = luaL_checkstring(L, 3); + + if (child != NULL && parent != NULL) { + rspamd_symcache_add_delayed_dependency(cfg->cache, child, + parent); + } + } + + return 0; +} + +static gint +lua_config_set_metric_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *description = NULL, + *group = NULL, *name = NULL, *flags_str = NULL; + double score; + gboolean one_shot = FALSE, one_param = FALSE; + GError *err = NULL; + gdouble priority = 0.0; + guint flags = 0; + gint64 nshots = 0; + + if (cfg) { + + if (lua_type(L, 2) == LUA_TTABLE) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*name=S;score=N;description=S;" + "group=S;one_shot=B;one_param=B;priority=N;flags=S;" + "nshots=I", + &name, &score, &description, + &group, &one_shot, &one_param, + &priority, &flags_str, &nshots)) { + msg_err_config("bad arguments: %e", err); + g_error_free(err); + + return 0; + } + } + else { + name = luaL_checkstring(L, 2); + score = luaL_checknumber(L, 3); + + if (lua_gettop(L) > 3 && lua_type(L, 4) == LUA_TSTRING) { + description = luaL_checkstring(L, 4); + } + if (lua_gettop(L) > 4 && lua_type(L, 5) == LUA_TSTRING) { + /* XXX: metrics */ + } + if (lua_gettop(L) > 5 && lua_type(L, 6) == LUA_TSTRING) { + group = luaL_checkstring(L, 6); + } + if (lua_gettop(L) > 6 && lua_type(L, 7) == LUA_TBOOLEAN) { + one_shot = lua_toboolean(L, 7); + } + } + + if (nshots == 0) { + nshots = cfg->default_max_shots; + } + + if (one_shot) { + nshots = 1; + } + if (one_param) { + flags |= RSPAMD_SYMBOL_FLAG_ONEPARAM; + } + + if (flags_str) { + if (strstr(flags_str, "one_shot") != NULL) { + nshots = 1; + } + if (strstr(flags_str, "ignore") != NULL) { + flags |= RSPAMD_SYMBOL_FLAG_IGNORE_METRIC; + } + if (strstr(flags_str, "one_param") != NULL) { + flags |= RSPAMD_SYMBOL_FLAG_ONEPARAM; + } + } + + rspamd_config_add_symbol(cfg, name, + score, description, group, flags, (guint) priority, nshots); + + + if (lua_type(L, 2) == LUA_TTABLE) { + lua_pushstring(L, "groups"); + lua_gettable(L, 2); + + if (lua_istable(L, -1)) { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + if (lua_isstring(L, -1)) { + rspamd_config_add_symbol_group(cfg, name, + lua_tostring(L, -1)); + } + else { + return luaL_error(L, "invalid groups element"); + } + } + } + + lua_pop(L, 1); + } + } + else { + return luaL_error(L, "invalid arguments, rspamd_config expected"); + } + + return 0; +} + +static gint +lua_config_set_metric_action(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name = NULL; + double threshold = NAN; + GError *err = NULL; + gdouble priority = 0.0; + ucl_object_t *obj_tbl = NULL; + + if (cfg) { + + if (lua_type(L, 2) == LUA_TTABLE) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*action=S;score=N;" + "priority=N", + &name, &threshold, + &priority)) { + msg_err_config("bad arguments: %e", err); + g_error_free(err); + + return 0; + } + } + else if (lua_type(L, 2) == LUA_TSTRING && lua_type(L, 3) == LUA_TTABLE) { + name = lua_tostring(L, 2); + obj_tbl = ucl_object_lua_import(L, 3); + + if (obj_tbl) { + if (name) { + rspamd_config_set_action_score(cfg, name, obj_tbl); + ucl_object_unref(obj_tbl); + } + else { + ucl_object_unref(obj_tbl); + return luaL_error(L, "invalid first argument, action name expected"); + } + } + else { + return luaL_error(L, "invalid second argument, table expected"); + } + } + else { + return luaL_error(L, "invalid arguments, table expected"); + } + + if (name != NULL && !isnan(threshold) && threshold != 0) { + obj_tbl = ucl_object_typed_new(UCL_OBJECT); + ucl_object_insert_key(obj_tbl, ucl_object_fromdouble(threshold), + "score", 0, false); + ucl_object_insert_key(obj_tbl, ucl_object_fromdouble(priority), + "priority", 0, false); + rspamd_config_set_action_score(cfg, name, obj_tbl); + ucl_object_unref(obj_tbl); + } + } + else { + return luaL_error(L, "invalid arguments, rspamd_config expected"); + } + + return 0; +} + +static gint +lua_config_get_metric_action(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *act_name = luaL_checkstring(L, 2); + struct rspamd_action *act; + + if (cfg && act_name) { + act = rspamd_config_get_action(cfg, act_name); + + if (act) { + if (!isnan(act->threshold)) { + lua_pushnumber(L, act->threshold); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments, rspamd_config expected"); + } + + return 1; +} + +static void +lua_config_actions_cb(struct rspamd_action *act, void *cbd) +{ + lua_State *L = (lua_State *) cbd; + + if (!isnan(act->threshold)) { + lua_pushstring(L, act->name); + lua_pushnumber(L, act->threshold); + lua_settable(L, -3); + } +} + +static gint +lua_config_get_all_actions(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg) { + lua_createtable(L, 0, rspamd_config_actions_size(cfg)); + rspamd_config_actions_foreach(cfg, lua_config_actions_cb, L); + } + else { + return luaL_error(L, "invalid arguments, rspamd_config expected"); + } + + return 1; +} + +static gint +lua_config_add_composite(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + gchar *name; + const gchar *expr_str; + struct rspamd_composite *composite; + gboolean ret = FALSE; + + if (cfg) { + name = rspamd_mempool_strdup(cfg->cfg_pool, luaL_checkstring(L, 2)); + expr_str = luaL_checkstring(L, 3); + + if (name && expr_str) { + composite = rspamd_composites_manager_add_from_string(cfg->composites_manager, + name, expr_str); + + if (composite) { + rspamd_symcache_add_symbol(cfg->cache, name, + 0, NULL, composite, SYMBOL_TYPE_COMPOSITE, -1); + ret = TRUE; + } + } + } + + lua_pushboolean(L, ret); + + return 1; +} + +static gint +lua_config_newindex(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name; + GArray *allowed_ids = NULL, *forbidden_ids = NULL; + gint id, nshots; + guint flags = 0; + gboolean optional = FALSE; + + name = luaL_checkstring(L, 2); + + if (cfg != NULL && name != NULL && lua_gettop(L) == 3) { + + if (lua_type(L, 3) == LUA_TFUNCTION) { + /* Normal symbol from just a function */ + lua_pushvalue(L, 3); + rspamd_register_symbol_fromlua(L, + cfg, + name, + luaL_ref(L, LUA_REGISTRYINDEX), + 1.0, + 0, + SYMBOL_TYPE_NORMAL, + -1, + NULL, NULL, + FALSE); + } + else if (lua_type(L, 3) == LUA_TTABLE) { + guint type = SYMBOL_TYPE_NORMAL, priority = 0; + gint idx; + gdouble weight = 1.0, score = NAN; + const char *type_str, *group = NULL, *description = NULL; + + /* + * Table can have the following attributes: + * "callback" - should be a callback function + * "weight" - optional weight + * "priority" - optional priority + * "type" - optional type (normal, virtual, callback) + * "flags" - optional flags + * -- Metric options + * "score" - optional default score (overridden by metric) + * "group" - optional default group + * "one_shot" - optional one shot mode + * "description" - optional description + */ + lua_pushvalue(L, 3); + lua_pushstring(L, "callback"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + lua_pop(L, 2); + msg_info_config("cannot find callback definition for %s", + name); + return 0; + } + idx = luaL_ref(L, LUA_REGISTRYINDEX); + + /* Optional fields */ + lua_pushstring(L, "weight"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TNUMBER) { + weight = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "priority"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TNUMBER) { + priority = lua_tointeger(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "optional"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TBOOLEAN) { + optional = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "type"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TSTRING) { + type_str = lua_tostring(L, -1); + type = lua_parse_symbol_type(type_str); + } + lua_pop(L, 1); + + /* Deal with flags and ids */ + lua_pushstring(L, "flags"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + flags = lua_parse_symbol_flags(lua_tostring(L, -1)); + } + else if (lua_type(L, -1) == LUA_TTABLE) { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + flags |= lua_parse_symbol_flags(lua_tostring(L, -1)); + } + } + lua_pop(L, 1); /* Clean flags */ + + lua_pushstring(L, "allowed_ids"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + allowed_ids = rspamd_process_id_list(lua_tostring(L, -1)); + } + else if (lua_type(L, -1) == LUA_TTABLE) { + allowed_ids = g_array_sized_new(FALSE, FALSE, sizeof(guint32), + rspamd_lua_table_size(L, -1)); + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + guint32 v = lua_tointeger(L, -1); + g_array_append_val(allowed_ids, v); + } + } + lua_pop(L, 1); + + lua_pushstring(L, "forbidden_ids"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + forbidden_ids = rspamd_process_id_list(lua_tostring(L, -1)); + } + else if (lua_type(L, -1) == LUA_TTABLE) { + forbidden_ids = g_array_sized_new(FALSE, FALSE, sizeof(guint32), + rspamd_lua_table_size(L, -1)); + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + guint32 v = lua_tointeger(L, -1); + g_array_append_val(forbidden_ids, v); + } + } + lua_pop(L, 1); + + id = rspamd_register_symbol_fromlua(L, + cfg, + name, + idx, + weight, + priority, + type | flags, + -1, + allowed_ids, forbidden_ids, + optional); + + if (allowed_ids) { + g_array_free(allowed_ids, TRUE); + } + + if (forbidden_ids) { + g_array_free(forbidden_ids, TRUE); + } + + if (id != -1) { + /* Check for condition */ + lua_pushstring(L, "condition"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TFUNCTION) { + gint condref; + + /* Here we pop function from the stack, so no lua_pop is required */ + condref = luaL_ref(L, LUA_REGISTRYINDEX); + g_assert(name != NULL); + rspamd_symcache_add_condition_delayed(cfg->cache, + name, L, condref); + } + else { + lua_pop(L, 1); + } + + /* Check for augmentations */ + lua_pushstring(L, "augmentations"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TTABLE) { + + int tbl_idx = lua_gettop(L); + for (lua_pushnil(L); lua_next(L, tbl_idx); lua_pop(L, 1)) { + rspamd_symcache_add_symbol_augmentation(cfg->cache, id, + lua_tostring(L, -1), NULL); + } + } + + lua_pop(L, 1); + } + + /* + * Now check if a symbol has not been registered in any metric and + * insert default value if applicable + */ + struct rspamd_symbol *sym = g_hash_table_lookup(cfg->symbols, name); + if (sym == NULL || (sym->flags & RSPAMD_SYMBOL_FLAG_UNSCORED)) { + nshots = cfg->default_max_shots; + + lua_pushstring(L, "score"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + score = lua_tonumber(L, -1); + + if (sym) { + /* Reset unscored flag */ + sym->flags &= ~RSPAMD_SYMBOL_FLAG_UNSCORED; + } + } + lua_pop(L, 1); + + lua_pushstring(L, "group"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + group = lua_tostring(L, -1); + } + lua_pop(L, 1); + + if (!isnan(score) || group != NULL) { + lua_pushstring(L, "description"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TSTRING) { + description = lua_tostring(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "one_shot"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TBOOLEAN) { + if (lua_toboolean(L, -1)) { + nshots = 1; + } + } + lua_pop(L, 1); + + lua_pushstring(L, "one_param"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TBOOLEAN) { + if (lua_toboolean(L, -1)) { + flags |= RSPAMD_SYMBOL_FLAG_ONEPARAM; + } + } + lua_pop(L, 1); + + /* + * Do not override the existing symbols (using zero priority), + * since we are defining default values here + */ + if (!isnan(score)) { + rspamd_config_add_symbol(cfg, name, score, + description, group, flags, 0, nshots); + } + else if (group) { + /* Add with zero score */ + rspamd_config_add_symbol(cfg, name, NAN, + description, group, flags, 0, nshots); + } + + lua_pushstring(L, "groups"); + lua_gettable(L, -2); + + if (lua_istable(L, -1)) { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + if (lua_isstring(L, -1)) { + rspamd_config_add_symbol_group(cfg, name, + lua_tostring(L, -1)); + } + else { + return luaL_error(L, "invalid groups element"); + } + } + } + + lua_pop(L, 1); + } + } + else { + /* Fill in missing fields from lua definition if they are not set */ + if (sym->description == NULL) { + lua_pushstring(L, "description"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TSTRING) { + description = lua_tostring(L, -1); + } + lua_pop(L, 1); + + if (description) { + sym->description = rspamd_mempool_strdup(cfg->cfg_pool, description); + } + } + + /* If ungrouped and there is a group defined in lua, change the primary group + * Otherwise, add to the list of groups for this symbol. */ + lua_pushstring(L, "group"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + group = lua_tostring(L, -1); + } + lua_pop(L, 1); + if (group) { + if (sym->flags & RSPAMD_SYMBOL_FLAG_UNGROUPED) { + /* Unset the "ungrouped" group */ + sym->gr = NULL; + } + /* Add the group. If the symbol was ungrouped, this will + * clear RSPAMD_SYMBOL_FLAG_UNGROUPED from the flags. */ + rspamd_config_add_symbol_group(cfg, name, group); + } + } + + /* Remove table from stack */ + lua_pop(L, 1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_config_add_condition(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *sym = luaL_checkstring(L, 2); + gboolean ret = FALSE; + gint condref; + + if (cfg && sym && lua_type(L, 3) == LUA_TFUNCTION) { + lua_pushvalue(L, 3); + condref = luaL_ref(L, LUA_REGISTRYINDEX); + + ret = rspamd_symcache_add_condition_delayed(cfg->cache, sym, L, + condref); + + if (!ret) { + luaL_unref(L, LUA_REGISTRYINDEX, condref); + } + } + + lua_pushboolean(L, ret); + return 1; +} + +static gint +lua_config_set_peak_cb(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + gint condref; + + if (cfg && lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + condref = luaL_ref(L, LUA_REGISTRYINDEX); + rspamd_symcache_set_peak_callback(cfg->cache, + condref); + } + + return 0; +} + +static gint +lua_config_enable_symbol(lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config(L, 1); + const char *sym = luaL_checkstring(L, 2); + + if (!sym || !cfg) { + return luaL_error(L, "invalid arguments"); + } + + rspamd_symcache_enable_symbol_static(cfg->cache, sym); + + return 0; +} + +static gint +lua_config_disable_symbol(lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config(L, 1); + const char *sym = luaL_checkstring(L, 2); + + if (!sym || !cfg) { + return luaL_error(L, "invalid arguments"); + } + + rspamd_symcache_disable_symbol_static(cfg->cache, sym); + + return 0; +} + +static gint +lua_config_register_regexp(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_lua_regexp *re = NULL; + rspamd_regexp_t *cache_re; + const gchar *type_str = NULL, *header_str = NULL; + gsize header_len = 0; + GError *err = NULL; + enum rspamd_re_type type = RSPAMD_RE_BODY; + gboolean pcre_only = FALSE; + + /* + * - `re`* : regular expression object + * - `type`*: type of regular expression: + * + `mime`: mime regexp + * + `rawmime`: raw mime regexp + * + `header`: header regexp + * + `rawheader`: raw header expression + * + `body`: raw body regexp + * + `url`: url regexp + * - `header`: for header and rawheader regexp means the name of header + * - `pcre_only`: allow merely pcre for this regexp + */ + if (cfg != NULL) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*re=U{regexp};*type=S;header=S;pcre_only=B", + &re, &type_str, &header_str, &pcre_only)) { + msg_err_config("cannot get parameters list: %e", err); + + if (err) { + g_error_free(err); + } + } + else { + type = rspamd_re_cache_type_from_string(type_str); + + if ((type == RSPAMD_RE_HEADER || + type == RSPAMD_RE_RAWHEADER || + type == RSPAMD_RE_MIMEHEADER) && + header_str == NULL) { + msg_err_config( + "header argument is mandatory for header/rawheader regexps"); + } + else { + if (pcre_only) { + rspamd_regexp_set_flags(re->re, + rspamd_regexp_get_flags(re->re) | RSPAMD_REGEXP_FLAG_PCRE_ONLY); + } + + if (header_str != NULL) { + /* Include the last \0 */ + header_len = strlen(header_str) + 1; + } + + cache_re = rspamd_re_cache_add(cfg->re_cache, re->re, type, + (gpointer) header_str, header_len, -1); + + /* + * XXX: here are dragons! + * Actually, lua regexp contains internal rspamd_regexp_t + * and it owns it. + * However, after this operation we have some OTHER regexp, + * which we really would like to use. + * So we do the following: + * 1) Remove old re and unref it + * 2) Replace the internal re with cached one + * 3) Increase its refcount to share ownership between cache and + * lua object + */ + if (cache_re != re->re) { + rspamd_regexp_unref(re->re); + re->re = rspamd_regexp_ref(cache_re); + + if (pcre_only) { + rspamd_regexp_set_flags(re->re, + rspamd_regexp_get_flags(re->re) | RSPAMD_REGEXP_FLAG_PCRE_ONLY); + } + } + } + } + } + + return 0; +} + +static gint +lua_config_replace_regexp(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_lua_regexp *old_re = NULL, *new_re = NULL; + gboolean pcre_only = FALSE; + GError *err = NULL; + + if (cfg != NULL) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*old_re=U{regexp};*new_re=U{regexp};pcre_only=B", + &old_re, &new_re, &pcre_only)) { + gint ret = luaL_error(L, "cannot get parameters list: %s", + err ? err->message : "invalid arguments"); + + if (err) { + g_error_free(err); + } + + return ret; + } + else { + + if (pcre_only) { + rspamd_regexp_set_flags(new_re->re, + rspamd_regexp_get_flags(new_re->re) | RSPAMD_REGEXP_FLAG_PCRE_ONLY); + } + + rspamd_re_cache_replace(cfg->re_cache, old_re->re, new_re->re); + } + } + + return 0; +} + +static gint +lua_config_register_worker_script(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *worker_type = luaL_checkstring(L, 2), *wtype; + struct rspamd_worker_conf *cf; + GList *cur; + struct rspamd_worker_lua_script *sc; + gboolean found = FALSE; + + if (cfg == NULL || worker_type == NULL || lua_type(L, 3) != LUA_TFUNCTION) { + return luaL_error(L, "invalid arguments"); + } + + for (cur = g_list_first(cfg->workers); cur != NULL; cur = g_list_next(cur)) { + cf = cur->data; + wtype = g_quark_to_string(cf->type); + + if (g_ascii_strcasecmp(wtype, worker_type) == 0) { + sc = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*sc)); + lua_pushvalue(L, 3); + sc->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + DL_APPEND(cf->scripts, sc); + found = TRUE; + } + } + + lua_pushboolean(L, found); + + return 1; +} + +static gint +lua_config_add_on_load(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_config_cfg_lua_script *sc; + + if (cfg == NULL || lua_type(L, 2) != LUA_TFUNCTION) { + return luaL_error(L, "invalid arguments"); + } + + sc = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*sc)); + lua_pushvalue(L, 2); + sc->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + DL_APPEND(cfg->on_load_scripts, sc); + + return 0; +} + +static inline int +rspamd_post_init_sc_sort(const struct rspamd_config_cfg_lua_script *pra, + const struct rspamd_config_cfg_lua_script *prb) +{ + /* Inverse sort */ + return prb->priority - pra->priority; +} + +static gint +lua_config_add_post_init(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_config_cfg_lua_script *sc; + guint priority = 0; + lua_Debug d; + gchar tmp[256], *p; + + if (cfg == NULL || lua_type(L, 2) != LUA_TFUNCTION) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 3) == LUA_TNUMBER) { + priority = lua_tointeger(L, 3); + } + + if (lua_getstack(L, 1, &d) == 1) { + (void) lua_getinfo(L, "Sl", &d); + if ((p = strrchr(d.short_src, '/')) == NULL) { + p = d.short_src; + } + else { + p++; + } + + if (strlen(p) > 200) { + rspamd_snprintf(tmp, sizeof(tmp), "%10s...]:%d", p, + d.currentline); + } + else { + rspamd_snprintf(tmp, sizeof(tmp), "%s:%d", p, + d.currentline); + } + } + + sc = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*sc)); + lua_pushvalue(L, 2); + sc->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + sc->priority = priority; + sc->lua_src_pos = rspamd_mempool_strdup(cfg->cfg_pool, tmp); + DL_APPEND(cfg->post_init_scripts, sc); + DL_SORT(cfg->post_init_scripts, rspamd_post_init_sc_sort); + + return 0; +} + +static gint +lua_config_add_config_unload(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_config_cfg_lua_script *sc; + lua_Debug d; + gchar tmp[256], *p; + + if (cfg == NULL || lua_type(L, 2) != LUA_TFUNCTION) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_getstack(L, 1, &d) == 1) { + (void) lua_getinfo(L, "Sl", &d); + if ((p = strrchr(d.short_src, '/')) == NULL) { + p = d.short_src; + } + else { + p++; + } + + if (strlen(p) > 20) { + rspamd_snprintf(tmp, sizeof(tmp), "%10s...]:%d", p, + d.currentline); + } + else { + rspamd_snprintf(tmp, sizeof(tmp), "%s:%d", p, + d.currentline); + } + } + + sc = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*sc)); + lua_pushvalue(L, 2); + sc->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + sc->lua_src_pos = rspamd_mempool_strdup(cfg->cfg_pool, tmp); + DL_APPEND(cfg->config_unload_scripts, sc); + + return 0; +} + + +static void lua_periodic_callback_finish(struct thread_entry *thread, int ret); +static void lua_periodic_callback_error(struct thread_entry *thread, int ret, const char *msg); + +struct rspamd_lua_periodic { + struct ev_loop *event_loop; + struct rspamd_config *cfg; + gchar *lua_src_pos; + lua_State *L; + gdouble timeout; + ev_timer ev; + gint cbref; + gboolean need_jitter; + ref_entry_t ref; +}; + +static void +lua_periodic_dtor(struct rspamd_lua_periodic *periodic) +{ + luaL_unref(periodic->L, LUA_REGISTRYINDEX, periodic->cbref); + ev_timer_stop(periodic->event_loop, &periodic->ev); +} + +static void +lua_periodic_fin(gpointer p) +{ + struct rspamd_lua_periodic *periodic = (struct rspamd_lua_periodic *) p; + + REF_RELEASE(periodic); +} + +static void +lua_periodic_callback(struct ev_loop *loop, ev_timer *w, int revents) +{ + struct rspamd_lua_periodic *periodic = (struct rspamd_lua_periodic *) w->data; + struct rspamd_config **pcfg, *cfg; + struct ev_loop **pev_base; + struct thread_entry *thread; + lua_State *L; + + REF_RETAIN(periodic); + thread = lua_thread_pool_get_for_config(periodic->cfg); + thread->cd = periodic; + thread->finish_callback = lua_periodic_callback_finish; + thread->error_callback = lua_periodic_callback_error; + + L = thread->lua_state; + + lua_rawgeti(L, LUA_REGISTRYINDEX, periodic->cbref); + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + cfg = periodic->cfg; + *pcfg = cfg; + pev_base = lua_newuserdata(L, sizeof(*pev_base)); + rspamd_lua_setclass(L, "rspamd{ev_base}", -1); + *pev_base = periodic->event_loop; + lua_pushnumber(L, ev_now(periodic->event_loop)); + + lua_thread_call(thread, 3); +} + +static void +lua_periodic_callback_finish(struct thread_entry *thread, int ret) +{ + lua_State *L; + struct rspamd_lua_periodic *periodic = thread->cd; + gboolean plan_more = FALSE; + gdouble timeout = 0.0; + + L = thread->lua_state; + + ev_now_update(periodic->event_loop); + + if (ret == 0) { + if (lua_type(L, -1) == LUA_TBOOLEAN) { + plan_more = lua_toboolean(L, -1); + timeout = periodic->timeout; + } + else if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + plan_more = timeout > 0 ? TRUE : FALSE; + } + + lua_pop(L, 1); /* Return value */ + } + + if (periodic->cfg->cur_worker) { + if (periodic->cfg->cur_worker->state != rspamd_worker_state_running) { + /* We are terminating, no more periodics */ + plan_more = FALSE; + } + } + + if (plan_more) { + if (periodic->need_jitter) { + timeout = rspamd_time_jitter(timeout, 0.0); + } + + periodic->ev.repeat = timeout; + ev_timer_again(periodic->event_loop, &periodic->ev); + } + else { + ev_timer_stop(periodic->event_loop, &periodic->ev); + } + + REF_RELEASE(periodic); +} + +static void +lua_periodic_callback_error(struct thread_entry *thread, int ret, const char *msg) +{ + struct rspamd_config *cfg; + struct rspamd_lua_periodic *periodic = thread->cd; + cfg = periodic->cfg; + + msg_err_config("call to periodic script (registered at %s) failed: %s", + periodic->lua_src_pos, msg); + + lua_periodic_callback_finish(thread, ret); +} + + +static gint +lua_config_add_periodic(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct ev_loop *ev_base = lua_check_ev_base(L, 2); + gdouble timeout = lua_tonumber(L, 3); + struct rspamd_lua_periodic *periodic; + gboolean need_jitter = FALSE; + lua_Debug d; + gchar tmp[256], *p; + + if (cfg == NULL || timeout < 0 || lua_type(L, 4) != LUA_TFUNCTION) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 5) == LUA_TBOOLEAN) { + need_jitter = lua_toboolean(L, 5); + } + + if (lua_getstack(L, 1, &d) == 1) { + (void) lua_getinfo(L, "Sl", &d); + if ((p = strrchr(d.short_src, '/')) == NULL) { + p = d.short_src; + } + else { + p++; + } + + if (strlen(p) > 20) { + rspamd_snprintf(tmp, sizeof(tmp), "%10s...]:%d", p, + d.currentline); + } + else { + rspamd_snprintf(tmp, sizeof(tmp), "%s:%d", p, + d.currentline); + } + } + + periodic = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*periodic)); + periodic->timeout = timeout; + periodic->L = L; + periodic->cfg = cfg; + periodic->event_loop = ev_base; + periodic->need_jitter = need_jitter; + periodic->lua_src_pos = rspamd_mempool_strdup(cfg->cfg_pool, tmp); + lua_pushvalue(L, 4); + periodic->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + if (need_jitter) { + timeout = rspamd_time_jitter(timeout, 0.0); + } + + ev_timer_init(&periodic->ev, lua_periodic_callback, timeout, 0.0); + periodic->ev.data = periodic; + ev_timer_start(ev_base, &periodic->ev); + REF_INIT_RETAIN(periodic, lua_periodic_dtor); + + rspamd_mempool_add_destructor(cfg->cfg_pool, lua_periodic_fin, + periodic); + + return 0; +} + +static gint +lua_config_get_symbols_count(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + guint res = 0; + + if (cfg != NULL) { + res = rspamd_symcache_stats_symbols_count(cfg->cache); + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushinteger(L, res); + + return 1; +} + +static gint +lua_config_get_symbols_cksum(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + guint64 res = 0, *pres; + + if (cfg != NULL) { + res = rspamd_symcache_get_cksum(cfg->cache); + } + else { + return luaL_error(L, "invalid arguments"); + } + + pres = lua_newuserdata(L, sizeof(res)); + *pres = res; + rspamd_lua_setclass(L, "rspamd{int64}", -1); + + return 1; +} + +static gint +lua_config_get_symbols_counters(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + ucl_object_t *counters; + + if (cfg != NULL) { + counters = rspamd_symcache_counters(cfg->cache); + ucl_object_push_lua(L, counters, true); + ucl_object_unref(counters); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct lua_metric_symbols_cbdata { + lua_State *L; + struct rspamd_config *cfg; + bool is_table; +}; + +static void +lua_metric_symbol_inserter(gpointer k, gpointer v, gpointer ud) +{ + struct lua_metric_symbols_cbdata *cbd = (struct lua_metric_symbols_cbdata *) ud; + lua_State *L; + const gchar *sym = k; + struct rspamd_symbol *s = (struct rspamd_symbol *) v; + struct rspamd_symbols_group *gr; + gint i; + + L = cbd->L; + + if (cbd->is_table) { + lua_pushstring(L, sym); /* Symbol name */ + } + + lua_createtable(L, 0, 6); + lua_pushstring(L, "score"); + lua_pushnumber(L, s->score); + lua_settable(L, -3); + lua_pushstring(L, "description"); + lua_pushstring(L, s->description); + lua_settable(L, -3); + + lua_pushstring(L, "flags"); + lua_createtable(L, 0, 3); + + if (s->flags & RSPAMD_SYMBOL_FLAG_IGNORE_METRIC) { + lua_pushstring(L, "ignore"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (s->flags & RSPAMD_SYMBOL_FLAG_ONEPARAM) { + lua_pushstring(L, "oneparam"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (s->flags & RSPAMD_SYMBOL_FLAG_UNGROUPED) { + lua_pushstring(L, "ungrouped"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (s->flags & RSPAMD_SYMBOL_FLAG_DISABLED) { + lua_pushstring(L, "disabled"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + + if (s->cache_item) { + guint sflags = rspamd_symcache_get_symbol_flags(cbd->cfg->cache, sym); + + lua_push_symbol_flags(L, sflags, LUA_SYMOPT_FLAG_USE_MAP); + + guint nids; + const guint *allowed_ids = rspamd_symcache_get_allowed_settings_ids(cbd->cfg->cache, + sym, &nids); + + if (allowed_ids && nids > 0) { + lua_createtable(L, nids, 0); + + for (i = 0; i < nids; i++) { + lua_pushinteger(L, allowed_ids[i]); + lua_rawseti(L, -2, i + 1); + } + + lua_setfield(L, -2, "allowed_ids"); + } + + const guint *forbidden_ids = rspamd_symcache_get_forbidden_settings_ids( + cbd->cfg->cache, + sym, &nids); + + if (forbidden_ids && nids > 0) { + lua_createtable(L, nids, 0); + + for (i = 0; i < nids; i++) { + lua_pushinteger(L, forbidden_ids[i]); + lua_rawseti(L, -2, i + 1); + } + + lua_setfield(L, -2, "forbidden_ids"); + } + } + + lua_settable(L, -3); /* Flags -> flags_table */ + + lua_pushstring(L, "nshots"); + lua_pushinteger(L, s->nshots); + lua_settable(L, -3); + + if (s->gr) { + lua_pushstring(L, "group"); + lua_pushstring(L, s->gr->name); + lua_settable(L, -3); + } + + if (s->groups && s->groups->len > 0) { + lua_pushstring(L, "groups"); + lua_createtable(L, s->groups->len, 0); + + PTR_ARRAY_FOREACH(s->groups, i, gr) + { + lua_pushstring(L, gr->name); + lua_rawseti(L, -2, i + 1); /* Groups[i + 1] = group_name */ + } + + lua_settable(L, -3); /* Groups -> groups_table */ + } + else { + lua_createtable(L, 0, 0); + lua_setfield(L, -2, "groups"); + } + + if (cbd->is_table) { + lua_settable(L, -3); /* Symname -> table */ + } +} + +static gint +lua_config_get_symbols(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + struct lua_metric_symbols_cbdata cbd; + + cbd.L = L; + cbd.cfg = cfg; + cbd.is_table = true; + + lua_createtable(L, 0, g_hash_table_size(cfg->symbols)); + g_hash_table_foreach(cfg->symbols, + lua_metric_symbol_inserter, + &cbd); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *sym_name = luaL_checkstring(L, 2); + + if (cfg != NULL && sym_name != NULL) { + struct lua_metric_symbols_cbdata cbd; + struct rspamd_symbol *s = g_hash_table_lookup(cfg->symbols, sym_name); + + if (s) { + cbd.L = L; + cbd.cfg = cfg; + cbd.is_table = false; + lua_metric_symbol_inserter((void *) sym_name, s, &cbd); + } + else { + /* No config for a symbol */ + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_symbol_callback(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *sym = luaL_checkstring(L, 2); + struct rspamd_abstract_callback_data *abs_cbdata; + struct lua_callback_data *cbd; + + if (cfg != NULL && sym != NULL) { + abs_cbdata = rspamd_symcache_get_cbdata(cfg->cache, sym); + + if (abs_cbdata == NULL || abs_cbdata->magic != rspamd_lua_callback_magic) { + lua_pushnil(L); + } + else { + cbd = (struct lua_callback_data *) abs_cbdata; + + if (cbd->cb_is_ref) { + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->callback.ref); + } + else { + lua_getglobal(L, cbd->callback.name); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_set_symbol_callback(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *sym = luaL_checkstring(L, 2); + struct rspamd_abstract_callback_data *abs_cbdata; + struct lua_callback_data *cbd; + + if (cfg != NULL && sym != NULL && lua_type(L, 3) == LUA_TFUNCTION) { + abs_cbdata = rspamd_symcache_get_cbdata(cfg->cache, sym); + + if (abs_cbdata == NULL || abs_cbdata->magic != rspamd_lua_callback_magic) { + lua_pushboolean(L, FALSE); + } + else { + cbd = (struct lua_callback_data *) abs_cbdata; + + if (cbd->cb_is_ref) { + luaL_unref(L, LUA_REGISTRYINDEX, cbd->callback.ref); + } + else { + cbd->cb_is_ref = TRUE; + } + + lua_pushvalue(L, 3); + cbd->callback.ref = luaL_ref(L, LUA_REGISTRYINDEX); + lua_pushboolean(L, TRUE); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_symbol_stat(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *sym = luaL_checkstring(L, 2); + gdouble freq, stddev, tm; + guint hits; + + if (cfg != NULL && sym != NULL) { + if (!rspamd_symcache_stat_symbol(cfg->cache, sym, &freq, + &stddev, &tm, &hits)) { + lua_pushnil(L); + } + else { + lua_createtable(L, 0, 4); + lua_pushstring(L, "frequency"); + lua_pushnumber(L, freq); + lua_settable(L, -3); + lua_pushstring(L, "sttdev"); + lua_pushnumber(L, stddev); + lua_settable(L, -3); + lua_pushstring(L, "time"); + lua_pushnumber(L, tm); + lua_settable(L, -3); + lua_pushstring(L, "hits"); + lua_pushinteger(L, hits); + lua_settable(L, -3); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_symbol_parent(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *sym = luaL_checkstring(L, 2), *parent; + + if (cfg != NULL && sym != NULL) { + parent = rspamd_symcache_get_parent(cfg->cache, sym); + + if (parent) { + lua_pushstring(L, parent); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_group_symbols(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *gr_name = luaL_checkstring(L, 2); + + if (cfg != NULL && gr_name != NULL) { + struct rspamd_symbols_group *group; + + group = g_hash_table_lookup(cfg->groups, gr_name); + + if (group == NULL) { + lua_pushnil(L); + } + else { + guint i = 1; + gpointer k, v; + GHashTableIter it; + + lua_createtable(L, g_hash_table_size(group->symbols), 0); + g_hash_table_iter_init(&it, group->symbols); + + while (g_hash_table_iter_next(&it, &k, &v)) { + lua_pushstring(L, k); + lua_rawseti(L, -2, i); + i++; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_groups(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + gboolean need_private; + struct rspamd_symbols_group *gr; + GHashTableIter it; + gpointer k, v; + + if (cfg) { + if (lua_isboolean(L, 2)) { + need_private = lua_toboolean(L, 2); + } + else { + need_private = !(cfg->public_groups_only); + } + + lua_createtable(L, 0, g_hash_table_size(cfg->groups)); + g_hash_table_iter_init(&it, cfg->groups); + + while (g_hash_table_iter_next(&it, &k, &v)) { + gr = (struct rspamd_symbols_group *) v; + + if (need_private || (gr->flags & RSPAMD_SYMBOL_GROUP_PUBLIC)) { + lua_createtable(L, 0, 4); + + lua_pushstring(L, gr->description); + lua_setfield(L, -2, "description"); + lua_pushnumber(L, gr->max_score); + lua_setfield(L, -2, "max_score"); + lua_pushboolean(L, (gr->flags & RSPAMD_SYMBOL_GROUP_PUBLIC) != 0); + lua_setfield(L, -2, "is_public"); + /* TODO: maybe push symbols as well */ + + /* Parent table indexed by group name */ + lua_setfield(L, -2, gr->name); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_register_finish_script(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_config_cfg_lua_script *sc; + + if (cfg != NULL && lua_type(L, 2) == LUA_TFUNCTION) { + sc = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*sc)); + lua_pushvalue(L, 2); + sc->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + DL_APPEND(cfg->on_term_scripts, sc); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static inline bool +rspamd_lua_config_check_settings_symbols_object(const ucl_object_t *obj) +{ + if (obj == NULL) { + /* Semantically valid */ + return true; + } + + if (ucl_object_type(obj) == UCL_OBJECT) { + /* Key-value mapping - should be okay */ + return true; + } + + if (ucl_object_type(obj) == UCL_ARRAY) { + /* Okay if empty */ + if (obj->len == 0) { + return true; + } + } + + /* Everything else not okay */ + return false; +} + +static gint +lua_config_register_settings_id(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *settings_name = luaL_checkstring(L, 2); + + if (cfg != NULL && settings_name) { + ucl_object_t *sym_enabled, *sym_disabled; + enum rspamd_config_settings_policy policy = RSPAMD_SETTINGS_POLICY_DEFAULT; + + sym_enabled = ucl_object_lua_import(L, 3); + + if (!rspamd_lua_config_check_settings_symbols_object(sym_enabled)) { + ucl_object_unref(sym_enabled); + + return luaL_error(L, "invalid symbols enabled"); + } + + sym_disabled = ucl_object_lua_import(L, 4); + + if (!rspamd_lua_config_check_settings_symbols_object(sym_disabled)) { + ucl_object_unref(sym_enabled); + ucl_object_unref(sym_disabled); + + return luaL_error(L, "invalid symbols enabled"); + } + + /* Check policy */ + if (lua_isstring(L, 5)) { + const gchar *policy_str = lua_tostring(L, 5); + + if (strcmp(policy_str, "default") == 0) { + policy = RSPAMD_SETTINGS_POLICY_DEFAULT; + } + else if (strcmp(policy_str, "implicit_allow") == 0) { + policy = RSPAMD_SETTINGS_POLICY_IMPLICIT_ALLOW; + } + else if (strcmp(policy_str, "implicit_deny") == 0) { + policy = RSPAMD_SETTINGS_POLICY_IMPLICIT_DENY; + } + else { + return luaL_error(L, "invalid settings policy: %s", policy_str); + } + } + else { + /* Apply heuristic */ + if (!sym_enabled) { + policy = RSPAMD_SETTINGS_POLICY_IMPLICIT_ALLOW; + } + } + + rspamd_config_register_settings_id(cfg, settings_name, sym_enabled, + sym_disabled, policy); + + if (sym_enabled) { + ucl_object_unref(sym_enabled); + } + + if (sym_disabled) { + ucl_object_unref(sym_disabled); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_config_register_monitored(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_monitored *m, **pm; + const gchar *url, *type; + ucl_object_t *params = NULL; + + url = lua_tostring(L, 2); + type = lua_tostring(L, 3); + + if (cfg != NULL && url != NULL && type != NULL) { + if (g_ascii_strcasecmp(type, "dns") == 0) { + lua_Debug ar; + + if (lua_type(L, 4) == LUA_TTABLE) { + params = ucl_object_lua_import(L, 4); + } + + /* Get lua line and source */ + lua_getstack(L, 1, &ar); + lua_getinfo(L, "nSl", &ar); + + m = rspamd_monitored_create_(cfg->monitored_ctx, url, + RSPAMD_MONITORED_DNS, RSPAMD_MONITORED_DEFAULT, + params, ar.short_src); + + if (m) { + pm = lua_newuserdata(L, sizeof(*pm)); + *pm = m; + rspamd_lua_setclass(L, "rspamd{monitored}", -1); + } + else { + lua_pushnil(L); + } + + if (params) { + ucl_object_unref(params); + } + } + else { + return luaL_error(L, "invalid monitored type: %s", type); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_add_doc(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg; + const gchar *path = NULL, *option, *doc_string; + const gchar *type_str = NULL, *default_value = NULL; + ucl_type_t type = UCL_NULL; + gboolean required = FALSE; + GError *err = NULL; + + cfg = lua_check_config(L, 1); + + if (lua_type(L, 2) == LUA_TSTRING) { + path = luaL_checkstring(L, 2); + } + + option = luaL_checkstring(L, 3); + doc_string = luaL_checkstring(L, 4); + + if (cfg && option && doc_string) { + if (lua_type(L, 5) == LUA_TTABLE) { + if (!rspamd_lua_parse_table_arguments(L, 5, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "type=S;default=S;required=B", + &type_str, &default_value, &required)) { + msg_err_config("cannot get parameters list: %e", err); + + if (err) { + g_error_free(err); + } + + if (type_str) { + if (!ucl_object_string_to_type(type_str, &type)) { + msg_err_config("invalid type: %s", type_str); + } + } + } + } + + rspamd_rcl_add_doc_by_path(cfg, path, doc_string, option, + type, NULL, 0, default_value, required); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_config_add_example(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg; + const gchar *path = NULL, *option, *doc_string, *example; + gsize example_len; + + cfg = lua_check_config(L, 1); + + if (lua_type(L, 2) == LUA_TSTRING) { + path = luaL_checkstring(L, 2); + } + + option = luaL_checkstring(L, 3); + doc_string = luaL_checkstring(L, 4); + example = luaL_checklstring(L, 5, &example_len); + + if (cfg && option && doc_string && example) { + + rspamd_rcl_add_doc_by_example(cfg, path, doc_string, option, + example, example_len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_config_get_cpu_flags(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_cryptobox_library_ctx *crypto_ctx; + + if (cfg != NULL) { + crypto_ctx = cfg->libs_ctx->crypto_ctx; + lua_newtable(L); + + if (crypto_ctx->cpu_config & CPUID_SSSE3) { + lua_pushstring(L, "ssse3"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (crypto_ctx->cpu_config & CPUID_SSE41) { + lua_pushstring(L, "sse41"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (crypto_ctx->cpu_config & CPUID_SSE42) { + lua_pushstring(L, "sse42"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (crypto_ctx->cpu_config & CPUID_SSE2) { + lua_pushstring(L, "sse2"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (crypto_ctx->cpu_config & CPUID_SSE3) { + lua_pushstring(L, "sse3"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (crypto_ctx->cpu_config & CPUID_AVX) { + lua_pushstring(L, "avx"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (crypto_ctx->cpu_config & CPUID_AVX2) { + lua_pushstring(L, "avx2"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_has_torch(lua_State *L) +{ + msg_warn("use of the obsoleted `has_torch` function"); + lua_pushboolean(L, false); + + return 1; +} + +static gint +lua_config_experimental_enabled(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + lua_pushboolean(L, cfg->enable_experimental); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct rspamd_lua_include_trace_cbdata { + lua_State *L; + gint cbref; +}; + +static void +lua_include_trace_cb(struct ucl_parser *parser, + const ucl_object_t *parent, + const ucl_object_t *args, + const char *path, + size_t pathlen, + void *user_data) +{ + struct rspamd_lua_include_trace_cbdata *cbdata = + (struct rspamd_lua_include_trace_cbdata *) user_data; + gint err_idx; + lua_State *L; + + L = cbdata->L; + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, cbdata->cbref); + /* Current filename */ + lua_pushstring(L, ucl_parser_get_cur_file(parser)); + /* Included filename */ + lua_pushlstring(L, path, pathlen); + /* Params */ + if (args) { + ucl_object_push_lua(L, args, true); + } + else { + lua_newtable(L); + } + /* Parent */ + if (parent) { + lua_pushstring(L, ucl_object_key(parent)); + } + else { + lua_pushnil(L); + } + + if (lua_pcall(L, 4, 0, err_idx) != 0) { + msg_err("lua call to local include trace failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); +} + +#define LUA_TABLE_TO_HASH(htb, idx) \ + do { \ + lua_pushstring(L, (idx)); \ + lua_gettable(L, -2); \ + if (lua_isstring(L, -1)) { \ + g_hash_table_insert((htb), (idx), g_strdup(lua_tostring(L, -1))); \ + } \ + lua_pop(L, 1); \ + } while (0) + +static gint +lua_config_load_ucl(lua_State *L) +{ + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *filename; + GHashTable *paths = g_hash_table_new_full(rspamd_str_hash, rspamd_str_equal, + NULL, g_free); + GError *err = NULL; + + if (cfg) { + if (lua_isstring(L, 2)) { + filename = lua_tostring(L, 2); + } + else { + filename = RSPAMD_CONFDIR "/rspamd.conf"; + } + + /* Convert rspamd_paths */ + lua_getglobal(L, "rspamd_paths"); + + if (lua_istable(L, -1)) { + LUA_TABLE_TO_HASH(paths, RSPAMD_CONFDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_LOCAL_CONFDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_RUNDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_DBDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_LOGDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_WWWDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_PLUGINSDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_RULESDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_LUALIBDIR_INDEX); + LUA_TABLE_TO_HASH(paths, RSPAMD_PREFIX_INDEX); + } + + lua_pop(L, 1); + + if (lua_isfunction(L, 3)) { + struct rspamd_lua_include_trace_cbdata cbd; + + lua_pushvalue(L, 3); + cbd.cbref = luaL_ref(L, LUA_REGISTRYINDEX); + cbd.L = L; + + if (!rspamd_config_parse_ucl(cfg, filename, paths, + lua_include_trace_cb, &cbd, lua_toboolean(L, 4), &err)) { + luaL_unref(L, LUA_REGISTRYINDEX, cbd.cbref); + lua_pushboolean(L, false); + lua_pushfstring(L, "failed to load config: %s", err->message); + g_error_free(err); + g_hash_table_unref(paths); + + return 2; + } + + luaL_unref(L, LUA_REGISTRYINDEX, cbd.cbref); + } + else { + if (!rspamd_config_parse_ucl(cfg, filename, paths, NULL, NULL, + lua_toboolean(L, 3), &err)) { + lua_pushboolean(L, false); + lua_pushfstring(L, "failed to load config: %s", err->message); + g_error_free(err); + g_hash_table_unref(paths); + + return 2; + } + } + + rspamd_rcl_maybe_apply_lua_transform(cfg); + rspamd_config_calculate_cksum(cfg); + } + else { + return luaL_error(L, "invalid arguments"); + } + + g_hash_table_unref(paths); + lua_pushboolean(L, true); + + return 1; +} + +#undef IDX_TO_HASH + +static gint +lua_config_parse_rcl(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + GHashTable *excluded = g_hash_table_new_full(rspamd_str_hash, rspamd_str_equal, + g_free, NULL); + GError *err = NULL; + struct rspamd_rcl_sections_map *top; + + if (cfg) { + if (lua_istable(L, 2)) { + lua_pushvalue(L, 2); + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + g_hash_table_insert(excluded, g_strdup(lua_tostring(L, -1)), + GINT_TO_POINTER(-1)); + } + + lua_pop(L, 1); + } + + top = rspamd_rcl_config_init(cfg, excluded); + + if (!rspamd_rcl_parse(top, cfg, cfg, cfg->cfg_pool, cfg->cfg_ucl_obj, &err)) { + lua_pushboolean(L, false); + lua_pushfstring(L, "failed to load config: %s", err->message); + g_error_free(err); + g_hash_table_unref(excluded); + rspamd_rcl_sections_free(top); + + return 2; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + g_hash_table_unref(excluded); + rspamd_rcl_sections_free(top); + lua_pushboolean(L, true); + + return 1; +} + +static gint +lua_config_init_modules(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + rspamd_lua_post_load_config(cfg); + lua_pushboolean(L, rspamd_init_filters(cfg, false, false)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_init_subsystem(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *subsystem = luaL_checkstring(L, 2); + gchar **parts; + guint nparts, i; + + if (cfg != NULL && subsystem != NULL) { + parts = g_strsplit_set(subsystem, ";,", -1); + nparts = g_strv_length(parts); + + for (i = 0; i < nparts; i++) { + if (strcmp(parts[i], "filters") == 0) { + rspamd_lua_post_load_config(cfg); + rspamd_init_filters(cfg, false, false); + } + else if (strcmp(parts[i], "langdet") == 0) { + if (!cfg->lang_det) { + cfg->lang_det = rspamd_language_detector_init(cfg); + rspamd_mempool_add_destructor(cfg->cfg_pool, + (rspamd_mempool_destruct_t) rspamd_language_detector_unref, + cfg->lang_det); + } + } + else if (strcmp(parts[i], "stat") == 0) { + rspamd_stat_init(cfg, NULL); + } + else if (strcmp(parts[i], "dns") == 0) { + struct ev_loop *ev_base = lua_check_ev_base(L, 3); + + if (ev_base) { + cfg->dns_resolver = rspamd_dns_resolver_init(rspamd_log_default_logger(), + ev_base, + cfg); + } + else { + g_strfreev(parts); + + return luaL_error(L, "no event base specified"); + } + } + else if (strcmp(parts[i], "symcache") == 0) { + rspamd_symcache_init(cfg->cache); + } + else { + int ret = luaL_error(L, "invalid param: %s", parts[i]); + g_strfreev(parts); + + return ret; + } + } + + g_strfreev(parts); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_config_register_re_selector(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *name = luaL_checkstring(L, 2); + const gchar *selector_str = luaL_checkstring(L, 3); + const gchar *delimiter = ""; + bool flatten = false; + gint top = lua_gettop(L); + bool res = false; + + if (cfg && name && selector_str) { + if (lua_gettop(L) >= 4) { + delimiter = luaL_checkstring(L, 4); + + if (lua_isboolean(L, 5)) { + flatten = lua_toboolean(L, 5); + } + } + + if (luaL_dostring(L, "return require \"lua_selectors\"") != 0) { + msg_warn_config("cannot require lua_selectors: %s", + lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) != LUA_TTABLE) { + msg_warn_config("lua selectors must return " + "table and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + lua_pushstring(L, "create_selector_closure"); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_warn_config("create_selector_closure must return " + "function and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + gint err_idx, ret; + struct rspamd_config **pcfg; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + /* Push function */ + lua_pushvalue(L, -2); + + pcfg = lua_newuserdata(L, sizeof(*pcfg)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + *pcfg = cfg; + lua_pushstring(L, selector_str); + lua_pushstring(L, delimiter); + lua_pushboolean(L, flatten); + + if ((ret = lua_pcall(L, 4, 1, err_idx)) != 0) { + msg_err_config("call to create_selector_closure lua " + "script failed (%d): %s", + ret, + lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) != LUA_TFUNCTION) { + msg_warn_config("create_selector_closure " + "invocation must return " + "function and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + ret = luaL_ref(L, LUA_REGISTRYINDEX); + rspamd_re_cache_add_selector(cfg->re_cache, + name, ret); + res = true; + } + } + } + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_settop(L, top); + lua_pushboolean(L, res); + + if (res) { + msg_info_config("registered regexp selector %s", name); + } + + return 1; +} + +static gint +lua_config_get_tld_path(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + lua_pushstring(L, cfg->tld_file); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_dns_max_requests(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + lua_pushinteger(L, cfg->dns_max_requests); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_config_get_dns_timeout(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + + if (cfg != NULL) { + lua_pushnumber(L, cfg->dns_timeout); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_monitored_alive(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_monitored *m = lua_check_monitored(L, 1); + + if (m) { + lua_pushboolean(L, rspamd_monitored_alive(m)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_monitored_offline(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_monitored *m = lua_check_monitored(L, 1); + + if (m) { + lua_pushnumber(L, rspamd_monitored_offline_time(m)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_monitored_total_offline(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_monitored *m = lua_check_monitored(L, 1); + + if (m) { + lua_pushnumber(L, rspamd_monitored_total_offline_time(m)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_monitored_latency(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_monitored *m = lua_check_monitored(L, 1); + + if (m) { + lua_pushnumber(L, rspamd_monitored_latency(m)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +void luaopen_config(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{config}", configlib_m); + + lua_pop(L, 1); + + rspamd_lua_new_class(L, "rspamd{monitored}", monitoredlib_m); + + lua_pop(L, 1); +} + +void lua_call_finish_script(struct rspamd_config_cfg_lua_script *sc, + struct rspamd_task *task) +{ + + struct rspamd_task **ptask; + struct thread_entry *thread; + + thread = lua_thread_pool_get_for_task(task); + thread->task = task; + + lua_State *L = thread->lua_state; + + lua_rawgeti(L, LUA_REGISTRYINDEX, sc->cbref); + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + rspamd_lua_setclass(L, "rspamd{task}", -1); + *ptask = task; + + lua_thread_call(thread, 1); +} diff --git a/src/lua/lua_cryptobox.c b/src/lua/lua_cryptobox.c new file mode 100644 index 0000000..70c6f0a --- /dev/null +++ b/src/lua/lua_cryptobox.c @@ -0,0 +1,3065 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** + * @module rspamd_cryptobox + * Rspamd cryptobox is a module that operates with digital signatures and + * hashes. + * @example + * local hash = require "rspamd_cryptobox_hash" + * + * local h = hash.create() + * h:update('hello world') + * print(h:hex()) + */ + + +#include "lua_common.h" +#include "libcryptobox/cryptobox.h" +#include "libcryptobox/keypair.h" +#include "libcryptobox/keypair_private.h" +#include "unix-std.h" +#include "contrib/libottery/ottery.h" +#include "libutil/ref.h" + +#include <stdalign.h> +#include <openssl/hmac.h> + + +enum lua_cryptobox_hash_type { + LUA_CRYPTOBOX_HASH_BLAKE2 = 0, + LUA_CRYPTOBOX_HASH_SSL, + LUA_CRYPTOBOX_HASH_HMAC, + LUA_CRYPTOBOX_HASH_XXHASH64, + LUA_CRYPTOBOX_HASH_XXHASH32, + LUA_CRYPTOBOX_HASH_XXHASH3, + LUA_CRYPTOBOX_HASH_MUM, + LUA_CRYPTOBOX_HASH_T1HA, +}; + +struct rspamd_lua_cryptobox_hash { + union { + rspamd_cryptobox_hash_state_t *h; + EVP_MD_CTX *c; + HMAC_CTX *hmac_c; + rspamd_cryptobox_fast_hash_state_t *fh; + } content; + + unsigned char out[rspamd_cryptobox_HASHBYTES]; + + uint8_t type; + uint8_t out_len; + uint8_t is_finished; + + ref_entry_t ref; +}; + +LUA_FUNCTION_DEF(cryptobox_pubkey, load); +LUA_FUNCTION_DEF(cryptobox_pubkey, create); +LUA_FUNCTION_DEF(cryptobox_pubkey, gc); +LUA_FUNCTION_DEF(cryptobox_keypair, load); +LUA_FUNCTION_DEF(cryptobox_keypair, create); +LUA_FUNCTION_DEF(cryptobox_keypair, gc); +LUA_FUNCTION_DEF(cryptobox_keypair, totable); +LUA_FUNCTION_DEF(cryptobox_keypair, get_type); +LUA_FUNCTION_DEF(cryptobox_keypair, get_alg); +LUA_FUNCTION_DEF(cryptobox_keypair, get_pk); +LUA_FUNCTION_DEF(cryptobox_signature, create); +LUA_FUNCTION_DEF(cryptobox_signature, load); +LUA_FUNCTION_DEF(cryptobox_signature, save); +LUA_FUNCTION_DEF(cryptobox_signature, gc); +LUA_FUNCTION_DEF(cryptobox_signature, hex); +LUA_FUNCTION_DEF(cryptobox_signature, base32); +LUA_FUNCTION_DEF(cryptobox_signature, base64); +LUA_FUNCTION_DEF(cryptobox_signature, bin); +LUA_FUNCTION_DEF(cryptobox_hash, create); +LUA_FUNCTION_DEF(cryptobox_hash, create_specific); +LUA_FUNCTION_DEF(cryptobox_hash, create_specific_keyed); +LUA_FUNCTION_DEF(cryptobox_hash, create_keyed); +LUA_FUNCTION_DEF(cryptobox_hash, update); +LUA_FUNCTION_DEF(cryptobox_hash, reset); +LUA_FUNCTION_DEF(cryptobox_hash, hex); +LUA_FUNCTION_DEF(cryptobox_hash, base32); +LUA_FUNCTION_DEF(cryptobox_hash, base64); +LUA_FUNCTION_DEF(cryptobox_hash, bin); +LUA_FUNCTION_DEF(cryptobox_hash, gc); +LUA_FUNCTION_DEF(cryptobox, verify_memory); +LUA_FUNCTION_DEF(cryptobox, verify_file); +LUA_FUNCTION_DEF(cryptobox, sign_file); +LUA_FUNCTION_DEF(cryptobox, sign_memory); +LUA_FUNCTION_DEF(cryptobox, encrypt_memory); +LUA_FUNCTION_DEF(cryptobox, encrypt_file); +LUA_FUNCTION_DEF(cryptobox, decrypt_memory); +LUA_FUNCTION_DEF(cryptobox, decrypt_file); +LUA_FUNCTION_DEF(cryptobox, encrypt_cookie); +LUA_FUNCTION_DEF(cryptobox, decrypt_cookie); +LUA_FUNCTION_DEF(cryptobox, pbkdf); +LUA_FUNCTION_DEF(cryptobox, gen_dkim_keypair); + +/* Secretbox API: uses libsodium secretbox and blake2b for key derivation */ +LUA_FUNCTION_DEF(cryptobox_secretbox, create); +LUA_FUNCTION_DEF(cryptobox_secretbox, encrypt); +LUA_FUNCTION_DEF(cryptobox_secretbox, decrypt); +LUA_FUNCTION_DEF(cryptobox_secretbox, gc); + +static const struct luaL_reg cryptoboxlib_f[] = { + LUA_INTERFACE_DEF(cryptobox, verify_memory), + LUA_INTERFACE_DEF(cryptobox, verify_file), + LUA_INTERFACE_DEF(cryptobox, sign_memory), + LUA_INTERFACE_DEF(cryptobox, sign_file), + LUA_INTERFACE_DEF(cryptobox, encrypt_memory), + LUA_INTERFACE_DEF(cryptobox, encrypt_file), + LUA_INTERFACE_DEF(cryptobox, decrypt_memory), + LUA_INTERFACE_DEF(cryptobox, decrypt_file), + LUA_INTERFACE_DEF(cryptobox, encrypt_cookie), + LUA_INTERFACE_DEF(cryptobox, decrypt_cookie), + LUA_INTERFACE_DEF(cryptobox, pbkdf), + LUA_INTERFACE_DEF(cryptobox, gen_dkim_keypair), + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxpubkeylib_f[] = { + LUA_INTERFACE_DEF(cryptobox_pubkey, load), + LUA_INTERFACE_DEF(cryptobox_pubkey, create), + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxpubkeylib_m[] = { + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_cryptobox_pubkey_gc}, + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxkeypairlib_f[] = { + LUA_INTERFACE_DEF(cryptobox_keypair, load), + LUA_INTERFACE_DEF(cryptobox_keypair, create), + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxkeypairlib_m[] = { + {"__tostring", rspamd_lua_class_tostring}, + {"totable", lua_cryptobox_keypair_totable}, + {"get_type", lua_cryptobox_keypair_get_type}, + {"get_alg", lua_cryptobox_keypair_get_alg}, + {"type", lua_cryptobox_keypair_get_type}, + {"alg", lua_cryptobox_keypair_get_alg}, + {"pk", lua_cryptobox_keypair_get_pk}, + {"pubkey", lua_cryptobox_keypair_get_pk}, + {"__gc", lua_cryptobox_keypair_gc}, + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxsignlib_f[] = { + LUA_INTERFACE_DEF(cryptobox_signature, load), + LUA_INTERFACE_DEF(cryptobox_signature, create), + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxsignlib_m[] = { + LUA_INTERFACE_DEF(cryptobox_signature, save), + LUA_INTERFACE_DEF(cryptobox_signature, hex), + LUA_INTERFACE_DEF(cryptobox_signature, base32), + LUA_INTERFACE_DEF(cryptobox_signature, base64), + LUA_INTERFACE_DEF(cryptobox_signature, bin), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_cryptobox_signature_gc}, + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxhashlib_f[] = { + LUA_INTERFACE_DEF(cryptobox_hash, create), + LUA_INTERFACE_DEF(cryptobox_hash, create_keyed), + LUA_INTERFACE_DEF(cryptobox_hash, create_specific), + LUA_INTERFACE_DEF(cryptobox_hash, create_specific_keyed), + {NULL, NULL}}; + +static const struct luaL_reg cryptoboxhashlib_m[] = { + LUA_INTERFACE_DEF(cryptobox_hash, update), + LUA_INTERFACE_DEF(cryptobox_hash, reset), + LUA_INTERFACE_DEF(cryptobox_hash, hex), + LUA_INTERFACE_DEF(cryptobox_hash, base32), + LUA_INTERFACE_DEF(cryptobox_hash, base64), + LUA_INTERFACE_DEF(cryptobox_hash, bin), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_cryptobox_hash_gc}, + {NULL, NULL}}; + + +static const struct luaL_reg cryptoboxsecretboxlib_f[] = { + LUA_INTERFACE_DEF(cryptobox_secretbox, create), + {NULL, NULL}, +}; + +static const struct luaL_reg cryptoboxsecretboxlib_m[] = { + LUA_INTERFACE_DEF(cryptobox_secretbox, encrypt), + LUA_INTERFACE_DEF(cryptobox_secretbox, decrypt), + {"__gc", lua_cryptobox_secretbox_gc}, + {NULL, NULL}, +}; + +struct rspamd_lua_cryptobox_secretbox { + guchar sk[crypto_secretbox_KEYBYTES]; +}; + +static struct rspamd_cryptobox_pubkey * +lua_check_cryptobox_pubkey(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{cryptobox_pubkey}"); + + luaL_argcheck(L, ud != NULL, 1, "'cryptobox_pubkey' expected"); + return ud ? *((struct rspamd_cryptobox_pubkey **) ud) : NULL; +} + +static struct rspamd_cryptobox_keypair * +lua_check_cryptobox_keypair(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{cryptobox_keypair}"); + + luaL_argcheck(L, ud != NULL, 1, "'cryptobox_keypair' expected"); + return ud ? *((struct rspamd_cryptobox_keypair **) ud) : NULL; +} + +static rspamd_fstring_t * +lua_check_cryptobox_sign(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{cryptobox_signature}"); + + luaL_argcheck(L, ud != NULL, 1, "'cryptobox_signature' expected"); + return ud ? *((rspamd_fstring_t **) ud) : NULL; +} + +struct rspamd_lua_cryptobox_hash * +lua_check_cryptobox_hash(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{cryptobox_hash}"); + + luaL_argcheck(L, ud != NULL, 1, "'cryptobox_hash' expected"); + return ud ? *((struct rspamd_lua_cryptobox_hash **) ud) : NULL; +} + +static struct rspamd_lua_cryptobox_secretbox * +lua_check_cryptobox_secretbox(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{cryptobox_secretbox}"); + + luaL_argcheck(L, ud != NULL, 1, "'cryptobox_secretbox' expected"); + return ud ? *((struct rspamd_lua_cryptobox_secretbox **) ud) : NULL; +} + +/*** + * @function rspamd_cryptobox_pubkey.load(file[, type[, alg]]) + * Loads public key from base32 encoded file + * @param {string} file filename to load + * @param {string} type optional 'sign' or 'kex' for signing and encryption + * @param {string} alg optional 'default' or 'nist' for curve25519/nistp256 keys + * @return {cryptobox_pubkey} new public key + */ +static gint +lua_cryptobox_pubkey_load(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_pubkey *pkey = NULL, **ppkey; + const gchar *filename, *arg; + gint type = RSPAMD_KEYPAIR_SIGN; + gint alg = RSPAMD_CRYPTOBOX_MODE_25519; + guchar *map; + gsize len; + + filename = luaL_checkstring(L, 1); + if (filename != NULL) { + map = rspamd_file_xmap(filename, PROT_READ, &len, TRUE); + + if (map == NULL) { + msg_err("cannot open pubkey from file: %s, %s", + filename, + strerror(errno)); + lua_pushnil(L); + } + else { + if (lua_type(L, 2) == LUA_TSTRING) { + /* keypair type */ + arg = lua_tostring(L, 2); + + if (strcmp(arg, "sign") == 0) { + type = RSPAMD_KEYPAIR_SIGN; + } + else if (strcmp(arg, "kex") == 0) { + type = RSPAMD_KEYPAIR_KEX; + } + } + if (lua_type(L, 3) == LUA_TSTRING) { + /* algorithm */ + arg = lua_tostring(L, 3); + + if (strcmp(arg, "default") == 0 || strcmp(arg, "curve25519") == 0) { + type = RSPAMD_CRYPTOBOX_MODE_25519; + } + else if (strcmp(arg, "nist") == 0) { + type = RSPAMD_CRYPTOBOX_MODE_NIST; + } + } + + pkey = rspamd_pubkey_from_base32(map, len, type, alg); + + if (pkey == NULL) { + msg_err("cannot open pubkey from file: %s", filename); + munmap(map, len); + lua_pushnil(L); + } + else { + munmap(map, len); + ppkey = lua_newuserdata(L, sizeof(void *)); + rspamd_lua_setclass(L, "rspamd{cryptobox_pubkey}", -1); + *ppkey = pkey; + } + } + } + else { + return luaL_error(L, "bad input arguments"); + } + + return 1; +} + + +/*** + * @function rspamd_cryptobox_pubkey.create(data[, type[, alg]]) + * Loads public key from base32 encoded string + * @param {base32 string} base32 string with the key + * @param {string} type optional 'sign' or 'kex' for signing and encryption + * @param {string} alg optional 'default' or 'nist' for curve25519/nistp256 keys + * @return {cryptobox_pubkey} new public key + */ +static gint +lua_cryptobox_pubkey_create(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_pubkey *pkey = NULL, **ppkey; + const gchar *buf, *arg; + gsize len; + gint type = RSPAMD_KEYPAIR_SIGN; + gint alg = RSPAMD_CRYPTOBOX_MODE_25519; + + buf = luaL_checklstring(L, 1, &len); + if (buf != NULL) { + if (lua_type(L, 2) == LUA_TSTRING) { + /* keypair type */ + arg = lua_tostring(L, 2); + + if (strcmp(arg, "sign") == 0) { + type = RSPAMD_KEYPAIR_SIGN; + } + else if (strcmp(arg, "kex") == 0) { + type = RSPAMD_KEYPAIR_KEX; + } + } + if (lua_type(L, 3) == LUA_TSTRING) { + /* algorithm */ + arg = lua_tostring(L, 3); + + if (strcmp(arg, "default") == 0 || strcmp(arg, "curve25519") == 0) { + type = RSPAMD_CRYPTOBOX_MODE_25519; + } + else if (strcmp(arg, "nist") == 0) { + type = RSPAMD_CRYPTOBOX_MODE_NIST; + } + } + + pkey = rspamd_pubkey_from_base32(buf, len, type, alg); + + if (pkey == NULL) { + msg_err("cannot load pubkey from string"); + lua_pushnil(L); + } + else { + ppkey = lua_newuserdata(L, sizeof(void *)); + rspamd_lua_setclass(L, "rspamd{cryptobox_pubkey}", -1); + *ppkey = pkey; + } + } + else { + return luaL_error(L, "bad input arguments"); + } + + return 1; +} + +static gint +lua_cryptobox_pubkey_gc(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_pubkey *pkey = lua_check_cryptobox_pubkey(L, 1); + + if (pkey != NULL) { + rspamd_pubkey_unref(pkey); + } + + return 0; +} + +/*** + * @function rspamd_cryptobox_keypair.load(file|table) + * Loads public key from UCL file or directly from Lua + * @param {string} file filename to load + * @return {cryptobox_keypair} new keypair + */ +static gint +lua_cryptobox_keypair_load(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp, **pkp; + const gchar *buf; + gsize len; + struct ucl_parser *parser; + ucl_object_t *obj; + + if (lua_type(L, 1) == LUA_TSTRING) { + buf = luaL_checklstring(L, 1, &len); + if (buf != NULL) { + parser = ucl_parser_new(0); + + if (!ucl_parser_add_chunk(parser, buf, len)) { + msg_err("cannot open keypair from data: %s", + ucl_parser_get_error(parser)); + ucl_parser_free(parser); + lua_pushnil(L); + } + else { + obj = ucl_parser_get_object(parser); + kp = rspamd_keypair_from_ucl(obj); + ucl_parser_free(parser); + + if (kp == NULL) { + msg_err("cannot load keypair from data"); + ucl_object_unref(obj); + lua_pushnil(L); + } + else { + pkp = lua_newuserdata(L, sizeof(gpointer)); + *pkp = kp; + rspamd_lua_setclass(L, "rspamd{cryptobox_keypair}", -1); + ucl_object_unref(obj); + } + } + } + else { + luaL_error(L, "bad input arguments"); + } + } + else { + /* Directly import from lua */ + obj = ucl_object_lua_import(L, 1); + kp = rspamd_keypair_from_ucl(obj); + + if (kp == NULL) { + msg_err("cannot load keypair from data"); + ucl_object_unref(obj); + lua_pushnil(L); + } + else { + pkp = lua_newuserdata(L, sizeof(gpointer)); + *pkp = kp; + rspamd_lua_setclass(L, "rspamd{cryptobox_keypair}", -1); + ucl_object_unref(obj); + } + } + + return 1; +} + +/*** + * @function rspamd_cryptobox_keypair.create([type='encryption'[, alg='curve25519']]) + * Generates new keypair + * @param {string} type type of keypair: 'encryption' (default) or 'sign' + * @param {string} alg algorithm of keypair: 'curve25519' (default) or 'nist' + * @return {cryptobox_keypair} new keypair + */ +static gint +lua_cryptobox_keypair_create(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp, **pkp; + enum rspamd_cryptobox_keypair_type type = RSPAMD_KEYPAIR_KEX; + enum rspamd_cryptobox_mode alg = RSPAMD_CRYPTOBOX_MODE_25519; + + if (lua_isstring(L, 1)) { + const gchar *str = lua_tostring(L, 1); + + if (strcmp(str, "sign") == 0) { + type = RSPAMD_KEYPAIR_SIGN; + } + else if (strcmp(str, "encryption") == 0) { + type = RSPAMD_KEYPAIR_KEX; + } + else { + return luaL_error(L, "invalid keypair type: %s", str); + } + } + + if (lua_isstring(L, 2)) { + const gchar *str = lua_tostring(L, 2); + + if (strcmp(str, "nist") == 0 || strcmp(str, "openssl") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_NIST; + } + else if (strcmp(str, "curve25519") == 0 || strcmp(str, "default") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_25519; + } + else { + return luaL_error(L, "invalid keypair algorithm: %s", str); + } + } + + kp = rspamd_keypair_new(type, alg); + + pkp = lua_newuserdata(L, sizeof(gpointer)); + *pkp = kp; + rspamd_lua_setclass(L, "rspamd{cryptobox_keypair}", -1); + + return 1; +} + +static gint +lua_cryptobox_keypair_gc(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp = lua_check_cryptobox_keypair(L, 1); + + if (kp != NULL) { + rspamd_keypair_unref(kp); + } + + return 0; +} + +/*** + * @method keypair:totable([hex=false]]) + * Converts keypair to table (not very safe due to memory leftovers) + */ +static gint +lua_cryptobox_keypair_totable(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp = lua_check_cryptobox_keypair(L, 1); + ucl_object_t *obj; + gboolean hex = FALSE; + gint ret = 1; + + if (kp != NULL) { + + if (lua_isboolean(L, 2)) { + hex = lua_toboolean(L, 2); + } + + obj = rspamd_keypair_to_ucl(kp, hex ? RSPAMD_KEYPAIR_DUMP_HEX : RSPAMD_KEYPAIR_DUMP_DEFAULT); + + ret = ucl_object_push_lua(L, obj, true); + ucl_object_unref(obj); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return ret; +} +/*** + * @method keypair:type() + * Returns type of keypair as a string: 'encryption' or 'sign' + * @return {string} type of keypair as a string + */ +static gint +lua_cryptobox_keypair_get_type(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp = lua_check_cryptobox_keypair(L, 1); + + if (kp) { + if (kp->type == RSPAMD_KEYPAIR_KEX) { + lua_pushstring(L, "encryption"); + } + else { + lua_pushstring(L, "sign"); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method keypair:alg() + * Returns algorithm of keypair as a string: 'encryption' or 'sign' + * @return {string} type of keypair as a string + */ +static gint +lua_cryptobox_keypair_get_alg(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp = lua_check_cryptobox_keypair(L, 1); + + if (kp) { + if (kp->alg == RSPAMD_CRYPTOBOX_MODE_25519) { + lua_pushstring(L, "curve25519"); + } + else { + lua_pushstring(L, "nist"); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method keypair:pk() + * Returns pubkey for a specific keypair + * @return {rspamd_pubkey} pubkey for a keypair + */ +static gint +lua_cryptobox_keypair_get_pk(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp = lua_check_cryptobox_keypair(L, 1); + struct rspamd_cryptobox_pubkey *pk, **ppk; + const guchar *data; + guint dlen; + + if (kp) { + data = rspamd_keypair_component(kp, RSPAMD_KEYPAIR_COMPONENT_PK, &dlen); + pk = rspamd_pubkey_from_bin(data, dlen, kp->type, kp->alg); + + if (pk == NULL) { + return luaL_error(L, "invalid keypair"); + } + + ppk = lua_newuserdata(L, sizeof(*ppk)); + *ppk = pk; + rspamd_lua_setclass(L, "rspamd{cryptobox_pubkey}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_cryptobox_signature.load(file, [alg = 'curve25519']) + * Loads signature from raw file + * @param {string} file filename to load + * @return {cryptobox_signature} new signature + */ +static gint +lua_cryptobox_signature_load(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig, **psig; + const gchar *filename; + gpointer data; + int fd; + struct stat st; + enum rspamd_cryptobox_mode alg = RSPAMD_CRYPTOBOX_MODE_25519; + + filename = luaL_checkstring(L, 1); + if (filename != NULL) { + fd = open(filename, O_RDONLY); + if (fd == -1) { + msg_err("cannot open signature file: %s, %s", filename, + strerror(errno)); + lua_pushnil(L); + } + else { + if (fstat(fd, &st) == -1 || + (data = + mmap(NULL, st.st_size, PROT_READ, MAP_SHARED, fd, 0)) == MAP_FAILED) { + msg_err("cannot mmap file %s: %s", filename, strerror(errno)); + lua_pushnil(L); + } + else { + if (lua_isstring(L, 2)) { + const gchar *str = lua_tostring(L, 2); + + if (strcmp(str, "nist") == 0 || strcmp(str, "openssl") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_NIST; + } + else if (strcmp(str, "curve25519") == 0 || strcmp(str, "default") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_25519; + } + else { + munmap(data, st.st_size); + close(fd); + + return luaL_error(L, "invalid keypair algorithm: %s", str); + } + } + if (st.st_size > 0) { + sig = rspamd_fstring_new_init(data, st.st_size); + psig = lua_newuserdata(L, sizeof(rspamd_fstring_t *)); + rspamd_lua_setclass(L, "rspamd{cryptobox_signature}", -1); + *psig = sig; + } + else { + msg_err("size of %s mismatches: %d while %d is expected", + filename, (int) st.st_size, + rspamd_cryptobox_signature_bytes(alg)); + lua_pushnil(L); + } + + munmap(data, st.st_size); + } + close(fd); + } + } + else { + luaL_error(L, "bad input arguments"); + } + + return 1; +} + +/*** + * @method rspamd_cryptobox_signature:save(file) + * Stores signature in raw file + * @param {string} file filename to use + * @return {boolean} true if signature has been saved + */ +static gint +lua_cryptobox_signature_save(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig; + gint fd, flags; + const gchar *filename; + gboolean forced = FALSE, res = TRUE; + + sig = lua_check_cryptobox_sign(L, 1); + filename = luaL_checkstring(L, 2); + + if (!sig || !filename) { + luaL_error(L, "bad input arguments"); + return 1; + } + + if (lua_gettop(L) > 2) { + forced = lua_toboolean(L, 3); + } + + if (sig != NULL && filename != NULL) { + flags = O_WRONLY | O_CREAT; + if (forced) { + flags |= O_TRUNC; + } + else { + flags |= O_EXCL; + } + fd = open(filename, flags, 00644); + if (fd == -1) { + msg_err("cannot create a signature file: %s, %s", + filename, + strerror(errno)); + lua_pushboolean(L, FALSE); + } + else { + while (write(fd, sig->str, sig->len) == -1) { + if (errno == EINTR) { + continue; + } + msg_err("cannot write to a signature file: %s, %s", + filename, + strerror(errno)); + res = FALSE; + break; + } + lua_pushboolean(L, res); + close(fd); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_cryptobox_signature.create(data) + * Creates signature object from raw data + * @param {data} raw signature data + * @return {cryptobox_signature} signature object + */ +static gint +lua_cryptobox_signature_create(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig, **psig; + struct rspamd_lua_text *t; + const gchar *data; + gsize dlen; + + if (lua_isuserdata(L, 1)) { + t = lua_check_text(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + dlen = t->len; + } + else { + data = luaL_checklstring(L, 1, &dlen); + } + + if (data != NULL) { + if (dlen == rspamd_cryptobox_signature_bytes(RSPAMD_CRYPTOBOX_MODE_25519)) { + sig = rspamd_fstring_new_init(data, dlen); + psig = lua_newuserdata(L, sizeof(rspamd_fstring_t *)); + rspamd_lua_setclass(L, "rspamd{cryptobox_signature}", -1); + *psig = sig; + } + } + else { + return luaL_error(L, "bad input arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_signature:hex() + * Return hex encoded signature string + * @return {string} raw value of signature + */ +static gint +lua_cryptobox_signature_hex(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig = lua_check_cryptobox_sign(L, 1); + gchar *encoded; + + if (sig) { + encoded = rspamd_encode_hex(sig->str, sig->len); + lua_pushstring(L, encoded); + g_free(encoded); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_signature:base32([b32type='default']) + * Return base32 encoded signature string + * @param {string} b32type base32 type (default, bleach, rfc) + * @return {string} raw value of signature + */ +static gint +lua_cryptobox_signature_base32(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig = lua_check_cryptobox_sign(L, 1); + gchar *encoded; + enum rspamd_base32_type btype = RSPAMD_BASE32_DEFAULT; + + if (lua_type(L, 2) == LUA_TSTRING) { + btype = rspamd_base32_decode_type_from_str(lua_tostring(L, 2)); + + if (btype == RSPAMD_BASE32_INVALID) { + return luaL_error(L, "invalid b32 type: %s", lua_tostring(L, 2)); + } + } + + if (sig) { + encoded = rspamd_encode_base32(sig->str, sig->len, btype); + lua_pushstring(L, encoded); + g_free(encoded); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_signature:base64() + * Return base64 encoded signature string + * @return {string} raw value of signature + */ +static gint +lua_cryptobox_signature_base64(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig = lua_check_cryptobox_sign(L, 1); + gsize dlen; + gchar *encoded; + + if (sig) { + encoded = rspamd_encode_base64(sig->str, sig->len, 0, &dlen); + lua_pushlstring(L, encoded, dlen); + g_free(encoded); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_signature:bin() + * Return raw signature string + * @return {string} raw value of signature + */ +static gint +lua_cryptobox_signature_bin(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig = lua_check_cryptobox_sign(L, 1); + + if (sig) { + lua_pushlstring(L, sig->str, sig->len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_cryptobox_signature_gc(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_fstring_t *sig = lua_check_cryptobox_sign(L, 1); + + rspamd_fstring_free(sig); + + return 0; +} + +static void +rspamd_lua_hash_update(struct rspamd_lua_cryptobox_hash *h, + const void *p, gsize len) +{ + if (h) { + switch (h->type) { + case LUA_CRYPTOBOX_HASH_BLAKE2: + rspamd_cryptobox_hash_update(h->content.h, p, len); + break; + case LUA_CRYPTOBOX_HASH_SSL: + EVP_DigestUpdate(h->content.c, p, len); + break; + case LUA_CRYPTOBOX_HASH_HMAC: + HMAC_Update(h->content.hmac_c, p, len); + break; + case LUA_CRYPTOBOX_HASH_XXHASH64: + case LUA_CRYPTOBOX_HASH_XXHASH32: + case LUA_CRYPTOBOX_HASH_XXHASH3: + case LUA_CRYPTOBOX_HASH_MUM: + case LUA_CRYPTOBOX_HASH_T1HA: + rspamd_cryptobox_fast_hash_update(h->content.fh, p, len); + break; + default: + g_assert_not_reached(); + } + } +} + +static void +lua_cryptobox_hash_dtor(struct rspamd_lua_cryptobox_hash *h) +{ + if (h->type == LUA_CRYPTOBOX_HASH_SSL) { +#if OPENSSL_VERSION_NUMBER < 0x10100000L || defined(LIBRESSL_VERSION_NUMBER) + EVP_MD_CTX_cleanup(h->content.c); +#else + EVP_MD_CTX_reset(h->content.c); +#endif + EVP_MD_CTX_destroy(h->content.c); + } + else if (h->type == LUA_CRYPTOBOX_HASH_HMAC) { +#if OPENSSL_VERSION_NUMBER < 0x10100000L || \ + (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x30500000) + HMAC_CTX_cleanup(h->content.hmac_c); + g_free(h->content.hmac_c); +#else + HMAC_CTX_free(h->content.hmac_c); +#endif + } + else if (h->type == LUA_CRYPTOBOX_HASH_BLAKE2) { + rspamd_explicit_memzero(h->content.h, sizeof(*h->content.h)); + free(h->content.h); /* Allocated by posix_memalign */ + } + else { + rspamd_cryptobox_fast_hash_free(h->content.fh); + } + + g_free(h); +} + +static inline void +rspamd_lua_hash_init_default(struct rspamd_lua_cryptobox_hash *h, + const gchar *key, gsize keylen) +{ + h->type = LUA_CRYPTOBOX_HASH_BLAKE2; + if (posix_memalign((void **) &h->content.h, + RSPAMD_ALIGNOF(rspamd_cryptobox_hash_state_t), + sizeof(*h->content.h)) != 0) { + g_assert_not_reached(); + } + + rspamd_cryptobox_hash_init(h->content.h, key, keylen); + h->out_len = rspamd_cryptobox_HASHBYTES; +} + +static void +rspamd_lua_ssl_hash_create(struct rspamd_lua_cryptobox_hash *h, const EVP_MD *htype, + bool insecure) +{ + h->type = LUA_CRYPTOBOX_HASH_SSL; + h->content.c = EVP_MD_CTX_create(); + h->out_len = EVP_MD_size(htype); + + if (insecure) { + /* Should never ever be used for crypto/security purposes! */ +#ifdef EVP_MD_CTX_FLAG_NON_FIPS_ALLOW + EVP_MD_CTX_set_flags(h->content.c, EVP_MD_CTX_FLAG_NON_FIPS_ALLOW); +#endif + } + + EVP_DigestInit_ex(h->content.c, htype, NULL); +} + +static void +rspamd_lua_ssl_hmac_create(struct rspamd_lua_cryptobox_hash *h, const EVP_MD *htype, + const gchar *key, gsize keylen, + bool insecure) +{ + h->type = LUA_CRYPTOBOX_HASH_HMAC; + +#if OPENSSL_VERSION_NUMBER < 0x10100000L || \ + (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x30500000) + h->content.hmac_c = g_malloc0(sizeof(*h->content.hmac_c)); +#else + h->content.hmac_c = HMAC_CTX_new(); +#endif + h->out_len = EVP_MD_size(htype); + +#if OPENSSL_VERSION_NUMBER > 0x10100000L + if (insecure) { + /* Should never ever be used for crypto/security purposes! */ +#ifdef EVP_MD_CTX_FLAG_NON_FIPS_ALLOW + HMAC_CTX_set_flags(h->content.hmac_c, EVP_MD_CTX_FLAG_NON_FIPS_ALLOW); +#endif + } +#endif + + HMAC_Init_ex(h->content.hmac_c, key, keylen, htype, NULL); +} + +static struct rspamd_lua_cryptobox_hash * +rspamd_lua_hash_create(const gchar *type, const gchar *key, gsize keylen) +{ + struct rspamd_lua_cryptobox_hash *h; + + h = g_malloc0(sizeof(*h)); + REF_INIT_RETAIN(h, lua_cryptobox_hash_dtor); + + if (type) { + if (g_ascii_strcasecmp(type, "md5") == 0) { + if (keylen > 0) { + rspamd_lua_ssl_hmac_create(h, EVP_md5(), key, keylen, true); + } + else { + rspamd_lua_ssl_hash_create(h, EVP_md5(), true); + } + } + else if (g_ascii_strcasecmp(type, "sha1") == 0 || + g_ascii_strcasecmp(type, "sha") == 0) { + if (keylen > 0) { + rspamd_lua_ssl_hmac_create(h, EVP_sha1(), key, keylen, true); + } + else { + rspamd_lua_ssl_hash_create(h, EVP_sha1(), true); + } + } + else if (g_ascii_strcasecmp(type, "sha256") == 0) { + if (keylen > 0) { + rspamd_lua_ssl_hmac_create(h, EVP_sha256(), key, keylen, true); + } + else { + rspamd_lua_ssl_hash_create(h, EVP_sha256(), true); + } + } + else if (g_ascii_strcasecmp(type, "sha512") == 0) { + if (keylen > 0) { + rspamd_lua_ssl_hmac_create(h, EVP_sha512(), key, keylen, true); + } + else { + rspamd_lua_ssl_hash_create(h, EVP_sha512(), true); + } + } + else if (g_ascii_strcasecmp(type, "sha384") == 0) { + if (keylen > 0) { + rspamd_lua_ssl_hmac_create(h, EVP_sha384(), key, keylen, true); + } + else { + rspamd_lua_ssl_hash_create(h, EVP_sha384(), true); + } + } + else if (g_ascii_strcasecmp(type, "xxh64") == 0) { + h->type = LUA_CRYPTOBOX_HASH_XXHASH64; + h->content.fh = rspamd_cryptobox_fast_hash_new(); + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_XXHASH64, 0); + h->out_len = sizeof(guint64); + } + else if (g_ascii_strcasecmp(type, "xxh32") == 0) { + h->type = LUA_CRYPTOBOX_HASH_XXHASH32; + h->content.fh = rspamd_cryptobox_fast_hash_new(); + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_XXHASH32, 0); + h->out_len = sizeof(guint32); + } + else if (g_ascii_strcasecmp(type, "xxh3") == 0) { + h->type = LUA_CRYPTOBOX_HASH_XXHASH3; + h->content.fh = rspamd_cryptobox_fast_hash_new(); + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_XXHASH3, 0); + h->out_len = sizeof(guint64); + } + else if (g_ascii_strcasecmp(type, "mum") == 0) { + h->type = LUA_CRYPTOBOX_HASH_MUM; + h->content.fh = rspamd_cryptobox_fast_hash_new(); + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_MUMHASH, 0); + h->out_len = sizeof(guint64); + } + else if (g_ascii_strcasecmp(type, "t1ha") == 0) { + h->type = LUA_CRYPTOBOX_HASH_T1HA; + h->content.fh = rspamd_cryptobox_fast_hash_new(); + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_T1HA, 0); + h->out_len = sizeof(guint64); + } + else if (g_ascii_strcasecmp(type, "blake2") == 0) { + rspamd_lua_hash_init_default(h, key, keylen); + } + else { + g_free(h); + + return NULL; + } + } + else { + /* Default hash type */ + rspamd_lua_hash_init_default(h, key, keylen); + } + + return h; +} + +/*** + * @function rspamd_cryptobox_hash.create([string]) + * Creates new hash context + * @param {string} data optional string to hash + * @return {cryptobox_hash} hash object + */ +static gint +lua_cryptobox_hash_create(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h, **ph; + const gchar *s = NULL; + struct rspamd_lua_text *t; + gsize len = 0; + + h = rspamd_lua_hash_create(NULL, NULL, 0); + + if (lua_type(L, 1) == LUA_TSTRING) { + s = lua_tolstring(L, 1, &len); + } + else if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (!t) { + REF_RELEASE(h); + return luaL_error(L, "invalid arguments"); + } + + s = t->start; + len = t->len; + } + + if (s) { + rspamd_lua_hash_update(h, s, len); + } + + ph = lua_newuserdata(L, sizeof(void *)); + *ph = h; + rspamd_lua_setclass(L, "rspamd{cryptobox_hash}", -1); + + return 1; +} + +/*** + * @function rspamd_cryptobox_hash.create_specific(type, [string]) + * Creates new hash context + * @param {string} type type of hash (blake2, sha256, md5, sha512, mum, xxh64, xxh32, t1ha) + * @param {string} string initial data + * @return {cryptobox_hash} hash object + */ +static gint +lua_cryptobox_hash_create_specific(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h, **ph; + const gchar *s = NULL, *type = luaL_checkstring(L, 1); + gsize len = 0; + struct rspamd_lua_text *t; + + if (!type) { + return luaL_error(L, "invalid arguments"); + } + + h = rspamd_lua_hash_create(type, NULL, 0); + + if (h == NULL) { + return luaL_error(L, "invalid hash type: %s", type); + } + + if (lua_type(L, 2) == LUA_TSTRING) { + s = lua_tolstring(L, 2, &len); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + + if (!t) { + REF_RELEASE(h); + return luaL_error(L, "invalid arguments"); + } + + s = t->start; + len = t->len; + } + + if (s) { + rspamd_lua_hash_update(h, s, len); + } + + ph = lua_newuserdata(L, sizeof(void *)); + *ph = h; + rspamd_lua_setclass(L, "rspamd{cryptobox_hash}", -1); + + return 1; +} + +/*** + * @function rspamd_cryptobox_hash.create_keyed(key, [string]) + * Creates new hash context with specified key + * @param {string} key key + * @return {cryptobox_hash} hash object + */ +static gint +lua_cryptobox_hash_create_keyed(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h, **ph; + const gchar *key, *s = NULL; + struct rspamd_lua_text *t; + gsize len = 0; + gsize keylen; + + key = luaL_checklstring(L, 1, &keylen); + + if (key != NULL) { + h = rspamd_lua_hash_create(NULL, key, keylen); + + if (lua_type(L, 2) == LUA_TSTRING) { + s = lua_tolstring(L, 2, &len); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + + if (!t) { + REF_RELEASE(h); + return luaL_error(L, "invalid arguments"); + } + + s = t->start; + len = t->len; + } + + if (s) { + rspamd_lua_hash_update(h, s, len); + } + + ph = lua_newuserdata(L, sizeof(void *)); + *ph = h; + rspamd_lua_setclass(L, "rspamd{cryptobox_hash}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_cryptobox_hash.create_specific_keyed(key, type, [string]) + * Creates new hash context with specified key + * @param {string} key key + * @return {cryptobox_hash} hash object + */ +static gint +lua_cryptobox_hash_create_specific_keyed(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h, **ph; + const gchar *key, *s = NULL, *type = luaL_checkstring(L, 2); + struct rspamd_lua_text *t; + gsize len = 0; + gsize keylen; + + key = luaL_checklstring(L, 1, &keylen); + + if (key != NULL && type != NULL) { + h = rspamd_lua_hash_create(type, key, keylen); + + if (h == NULL) { + return luaL_error(L, "invalid hash type: %s", type); + } + + if (lua_type(L, 3) == LUA_TSTRING) { + s = lua_tolstring(L, 3, &len); + } + else if (lua_type(L, 3) == LUA_TUSERDATA) { + t = lua_check_text(L, 3); + + if (!t) { + REF_RELEASE(h); + + return luaL_error(L, "invalid arguments"); + } + + s = t->start; + len = t->len; + } + + if (s) { + rspamd_lua_hash_update(h, s, len); + } + + ph = lua_newuserdata(L, sizeof(void *)); + *ph = h; + rspamd_lua_setclass(L, "rspamd{cryptobox_hash}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_hash:update(data) + * Updates hash with the specified data (hash should not be finalized using `hex` or `bin` methods) + * @param {string} data data to hash + */ +static gint +lua_cryptobox_hash_update(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h = lua_check_cryptobox_hash(L, 1), **ph; + const gchar *data; + struct rspamd_lua_text *t; + gsize len; + + if (lua_isuserdata(L, 2)) { + t = lua_check_text(L, 2); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 2, &len); + } + + if (lua_isnumber(L, 3)) { + gsize nlen = lua_tonumber(L, 3); + + if (nlen > len) { + return luaL_error(L, "invalid length: %d while %d is available", + (int) nlen, (int) len); + } + + len = nlen; + } + + if (h && data) { + if (!h->is_finished) { + rspamd_lua_hash_update(h, data, len); + } + else { + return luaL_error(L, "hash is already finalized"); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + ph = lua_newuserdata(L, sizeof(void *)); + *ph = h; + REF_RETAIN(h); + rspamd_lua_setclass(L, "rspamd{cryptobox_hash}", -1); + + return 1; +} + +/*** + * @method cryptobox_hash:reset() + * Resets hash to the initial state + */ +static gint +lua_cryptobox_hash_reset(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h = lua_check_cryptobox_hash(L, 1), **ph; + + if (h) { + switch (h->type) { + case LUA_CRYPTOBOX_HASH_BLAKE2: + memset(h->content.h, 0, sizeof(*h->content.h)); + rspamd_cryptobox_hash_init(h->content.h, NULL, 0); + break; + case LUA_CRYPTOBOX_HASH_SSL: + EVP_DigestInit(h->content.c, EVP_MD_CTX_md(h->content.c)); + break; + case LUA_CRYPTOBOX_HASH_HMAC: +#if OPENSSL_VERSION_NUMBER < 0x10100000L || \ + (defined(LIBRESSL_VERSION_NUMBER) && LIBRESSL_VERSION_NUMBER < 0x30500000) + /* Old openssl is awesome... */ + HMAC_Init_ex(h->content.hmac_c, NULL, 0, h->content.hmac_c->md, NULL); +#else + HMAC_CTX_reset(h->content.hmac_c); +#endif + break; + case LUA_CRYPTOBOX_HASH_XXHASH64: + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_XXHASH64, 0); + break; + case LUA_CRYPTOBOX_HASH_XXHASH32: + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_XXHASH32, 0); + break; + case LUA_CRYPTOBOX_HASH_XXHASH3: + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_XXHASH3, 0); + break; + case LUA_CRYPTOBOX_HASH_MUM: + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_MUMHASH, 0); + break; + case LUA_CRYPTOBOX_HASH_T1HA: + rspamd_cryptobox_fast_hash_init_specific(h->content.fh, + RSPAMD_CRYPTOBOX_T1HA, 0); + break; + default: + g_assert_not_reached(); + } + h->is_finished = FALSE; + } + else { + return luaL_error(L, "invalid arguments"); + } + + ph = lua_newuserdata(L, sizeof(void *)); + *ph = h; + REF_RETAIN(h); + rspamd_lua_setclass(L, "rspamd{cryptobox_hash}", -1); + + return 1; +} + +static void +lua_cryptobox_hash_finish(struct rspamd_lua_cryptobox_hash *h) +{ + guint64 ll; + guchar out[rspamd_cryptobox_HASHBYTES]; + guint ssl_outlen = sizeof(out); + + switch (h->type) { + case LUA_CRYPTOBOX_HASH_BLAKE2: + rspamd_cryptobox_hash_final(h->content.h, out); + memcpy(h->out, out, sizeof(out)); + break; + case LUA_CRYPTOBOX_HASH_SSL: + EVP_DigestFinal_ex(h->content.c, out, &ssl_outlen); + h->out_len = ssl_outlen; + g_assert(ssl_outlen <= sizeof(h->out)); + memcpy(h->out, out, ssl_outlen); + break; + case LUA_CRYPTOBOX_HASH_HMAC: + HMAC_Final(h->content.hmac_c, out, &ssl_outlen); + h->out_len = ssl_outlen; + g_assert(ssl_outlen <= sizeof(h->out)); + memcpy(h->out, out, ssl_outlen); + break; + case LUA_CRYPTOBOX_HASH_XXHASH64: + case LUA_CRYPTOBOX_HASH_XXHASH32: + case LUA_CRYPTOBOX_HASH_XXHASH3: + case LUA_CRYPTOBOX_HASH_MUM: + case LUA_CRYPTOBOX_HASH_T1HA: + ll = rspamd_cryptobox_fast_hash_final(h->content.fh); + memcpy(h->out, &ll, sizeof(ll)); + break; + default: + g_assert_not_reached(); + } + + h->is_finished = TRUE; +} + +/*** + * @method cryptobox_hash:hex() + * Finalizes hash and return it as hex string + * @return {string} hex value of hash + */ +static gint +lua_cryptobox_hash_hex(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h = lua_check_cryptobox_hash(L, 1); + guchar out_hex[rspamd_cryptobox_HASHBYTES * 2 + 1], *r; + guint dlen; + + if (h) { + if (!h->is_finished) { + lua_cryptobox_hash_finish(h); + } + + memset(out_hex, 0, sizeof(out_hex)); + r = h->out; + dlen = h->out_len; + + if (lua_isnumber(L, 2)) { + guint lim = lua_tonumber(L, 2); + + if (lim < dlen) { + r += dlen - lim; + dlen = lim; + } + } + + rspamd_encode_hex_buf(r, dlen, out_hex, sizeof(out_hex)); + lua_pushstring(L, out_hex); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_hash:base32([b32type]) + * Finalizes hash and return it as zbase32 (by default) string + * @param {string} b32type base32 type (default, bleach, rfc) + * @return {string} base32 value of hash + */ +static gint +lua_cryptobox_hash_base32(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h = lua_check_cryptobox_hash(L, 1); + guchar out_b32[rspamd_cryptobox_HASHBYTES * 2], *r; + guint dlen; + + if (h) { + enum rspamd_base32_type btype = RSPAMD_BASE32_DEFAULT; + + if (lua_type(L, 2) == LUA_TSTRING) { + btype = rspamd_base32_decode_type_from_str(lua_tostring(L, 2)); + + if (btype == RSPAMD_BASE32_INVALID) { + return luaL_error(L, "invalid b32 type: %s", lua_tostring(L, 2)); + } + } + + if (!h->is_finished) { + lua_cryptobox_hash_finish(h); + } + + memset(out_b32, 0, sizeof(out_b32)); + r = h->out; + dlen = h->out_len; + + if (lua_isnumber(L, 2)) { + guint lim = lua_tonumber(L, 2); + + if (lim < dlen) { + r += dlen - lim; + dlen = lim; + } + } + + rspamd_encode_base32_buf(r, dlen, out_b32, sizeof(out_b32), btype); + lua_pushstring(L, out_b32); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_hash:base64() + * Finalizes hash and return it as base64 string + * @return {string} base64 value of hash + */ +static gint +lua_cryptobox_hash_base64(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h = lua_check_cryptobox_hash(L, 1); + guchar *b64, *r; + gsize len; + guint dlen; + + if (h) { + if (!h->is_finished) { + lua_cryptobox_hash_finish(h); + } + + r = h->out; + dlen = h->out_len; + + if (lua_isnumber(L, 2)) { + guint lim = lua_tonumber(L, 2); + + if (lim < dlen) { + r += dlen - lim; + dlen = lim; + } + } + + b64 = rspamd_encode_base64(r, dlen, 0, &len); + lua_pushlstring(L, b64, len); + g_free(b64); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method cryptobox_hash:bin() + * Finalizes hash and return it as raw string + * @return {string} raw value of hash + */ +static gint +lua_cryptobox_hash_bin(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h = lua_check_cryptobox_hash(L, 1); + guchar *r; + guint dlen; + + if (h) { + if (!h->is_finished) { + lua_cryptobox_hash_finish(h); + } + + r = h->out; + dlen = h->out_len; + + if (lua_isnumber(L, 2)) { + guint lim = lua_tonumber(L, 2); + + if (lim < dlen) { + r += dlen - lim; + dlen = lim; + } + } + + lua_pushlstring(L, r, dlen); + h->is_finished = TRUE; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_cryptobox_hash_gc(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_cryptobox_hash *h = lua_check_cryptobox_hash(L, 1); + + REF_RELEASE(h); + + return 0; +} + +/*** + * @function rspamd_cryptobox.verify_memory(pk, sig, data, [alg = 'curve25519']) + * Check memory using specified cryptobox key and signature + * @param {pubkey} pk public key to verify + * @param {sig} signature to check + * @param {string} data data to check signature against + * @return {boolean} `true` - if string matches cryptobox signature + */ +static gint +lua_cryptobox_verify_memory(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_pubkey *pk; + rspamd_fstring_t *signature; + struct rspamd_lua_text *t; + const gchar *data; + enum rspamd_cryptobox_mode alg = RSPAMD_CRYPTOBOX_MODE_25519; + gsize len; + gint ret; + + pk = lua_check_cryptobox_pubkey(L, 1); + signature = lua_check_cryptobox_sign(L, 2); + + if (lua_isuserdata(L, 3)) { + t = lua_check_text(L, 3); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 3, &len); + } + + if (lua_isstring(L, 4)) { + const gchar *str = lua_tostring(L, 4); + + if (strcmp(str, "nist") == 0 || strcmp(str, "openssl") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_NIST; + } + else if (strcmp(str, "curve25519") == 0 || strcmp(str, "default") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_25519; + } + else { + return luaL_error(L, "invalid algorithm: %s", str); + } + } + + if (pk != NULL && signature != NULL && data != NULL) { + ret = rspamd_cryptobox_verify(signature->str, signature->len, data, len, + rspamd_pubkey_get_pk(pk, NULL), alg); + + if (ret) { + lua_pushboolean(L, 1); + } + else { + lua_pushboolean(L, 0); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_cryptobox.verify_file(pk, sig, file, [alg = 'curve25519']) + * Check file using specified cryptobox key and signature + * @param {pubkey} pk public key to verify + * @param {sig} signature to check + * @param {string} file to load data from + * @return {boolean} `true` - if string matches cryptobox signature + */ +static gint +lua_cryptobox_verify_file(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *fname; + struct rspamd_cryptobox_pubkey *pk; + rspamd_fstring_t *signature; + guchar *map = NULL; + enum rspamd_cryptobox_mode alg = RSPAMD_CRYPTOBOX_MODE_25519; + gsize len; + gint ret; + + pk = lua_check_cryptobox_pubkey(L, 1); + signature = lua_check_cryptobox_sign(L, 2); + fname = luaL_checkstring(L, 3); + + if (lua_isstring(L, 4)) { + const gchar *str = lua_tostring(L, 4); + + if (strcmp(str, "nist") == 0 || strcmp(str, "openssl") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_NIST; + } + else if (strcmp(str, "curve25519") == 0 || strcmp(str, "default") == 0) { + alg = RSPAMD_CRYPTOBOX_MODE_25519; + } + else { + return luaL_error(L, "invalid algorithm: %s", str); + } + } + + map = rspamd_file_xmap(fname, PROT_READ, &len, TRUE); + + if (map != NULL && pk != NULL && signature != NULL) { + ret = rspamd_cryptobox_verify(signature->str, signature->len, + map, len, + rspamd_pubkey_get_pk(pk, NULL), alg); + + if (ret) { + lua_pushboolean(L, 1); + } + else { + lua_pushboolean(L, 0); + } + } + else { + if (map != NULL) { + munmap(map, len); + } + + return luaL_error(L, "invalid arguments"); + } + + if (map != NULL) { + munmap(map, len); + } + + return 1; +} + +/*** + * @function rspamd_cryptobox.sign_memory(kp, data) + * Sign data using specified keypair + * @param {keypair} kp keypair to sign + * @param {string} data + * @return {cryptobox_signature} signature object + */ +static gint +lua_cryptobox_sign_memory(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp; + const gchar *data; + struct rspamd_lua_text *t; + gsize len = 0; + rspamd_fstring_t *sig, **psig; + + kp = lua_check_cryptobox_keypair(L, 1); + + if (lua_isuserdata(L, 2)) { + t = lua_check_text(L, 2); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 2, &len); + } + + + if (!kp || !data || kp->type == RSPAMD_KEYPAIR_KEX) { + return luaL_error(L, "invalid arguments"); + } + + sig = rspamd_fstring_sized_new(rspamd_cryptobox_signature_bytes( + rspamd_keypair_alg(kp))); + + unsigned long long siglen = sig->len; + rspamd_cryptobox_sign(sig->str, &siglen, data, + len, rspamd_keypair_component(kp, RSPAMD_KEYPAIR_COMPONENT_SK, NULL), rspamd_keypair_alg(kp)); + + sig->len = siglen; + psig = lua_newuserdata(L, sizeof(void *)); + *psig = sig; + rspamd_lua_setclass(L, "rspamd{cryptobox_signature}", -1); + + return 1; +} + +/*** + * @function rspamd_cryptobox.sign_file(kp, file) + * Sign file using specified keypair + * @param {keypair} kp keypair to sign + * @param {string} filename + * @return {cryptobox_signature} signature object + */ +static gint +lua_cryptobox_sign_file(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp; + const gchar *filename; + gchar *data; + gsize len = 0; + rspamd_fstring_t *sig, **psig; + + kp = lua_check_cryptobox_keypair(L, 1); + filename = luaL_checkstring(L, 2); + + if (!kp || !filename) { + return luaL_error(L, "invalid arguments"); + } + + data = rspamd_file_xmap(filename, PROT_READ, &len, TRUE); + + if (data == NULL) { + msg_err("cannot mmap file %s: %s", filename, strerror(errno)); + lua_pushnil(L); + } + else { + sig = rspamd_fstring_sized_new(rspamd_cryptobox_signature_bytes( + rspamd_keypair_alg(kp))); + + unsigned long long siglen = sig->len; + + rspamd_cryptobox_sign(sig->str, &siglen, data, + len, rspamd_keypair_component(kp, RSPAMD_KEYPAIR_COMPONENT_SK, NULL), rspamd_keypair_alg(kp)); + + sig->len = siglen; + psig = lua_newuserdata(L, sizeof(void *)); + *psig = sig; + rspamd_lua_setclass(L, "rspamd{cryptobox_signature}", -1); + munmap(data, len); + } + + return 1; +} + +/*** + * @function rspamd_cryptobox.encrypt_memory(kp, data[, nist=false]) + * Encrypt data using specified keypair/pubkey + * @param {keypair|string} kp keypair or pubkey in base32 to use + * @param {string|text} data + * @return {rspamd_text} encrypted text + */ +static gint +lua_cryptobox_encrypt_memory(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp = NULL; + struct rspamd_cryptobox_pubkey *pk = NULL; + const gchar *data; + guchar *out = NULL; + struct rspamd_lua_text *t, *res; + gsize len = 0, outlen = 0; + GError *err = NULL; + bool owned_pk = false; + + if (lua_type(L, 1) == LUA_TUSERDATA) { + if (rspamd_lua_check_udata_maybe(L, 1, "rspamd{cryptobox_keypair}")) { + kp = lua_check_cryptobox_keypair(L, 1); + } + else if (rspamd_lua_check_udata_maybe(L, 1, "rspamd{cryptobox_pubkey}")) { + pk = lua_check_cryptobox_pubkey(L, 1); + } + } + else if (lua_type(L, 1) == LUA_TSTRING) { + const gchar *b32; + gsize blen; + + b32 = lua_tolstring(L, 1, &blen); + pk = rspamd_pubkey_from_base32(b32, blen, RSPAMD_KEYPAIR_KEX, + lua_toboolean(L, 3) ? RSPAMD_CRYPTOBOX_MODE_NIST : RSPAMD_CRYPTOBOX_MODE_25519); + owned_pk = true; + } + + if (lua_isuserdata(L, 2)) { + t = lua_check_text(L, 2); + + if (!t) { + goto err; + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 2, &len); + } + + + if (!(kp || pk) || !data) { + goto err; + } + + if (kp) { + if (!rspamd_keypair_encrypt(kp, data, len, &out, &outlen, &err)) { + gint ret = luaL_error(L, "cannot encrypt data: %s", err->message); + g_error_free(err); + + if (owned_pk) { + rspamd_pubkey_unref(pk); + } + + return ret; + } + } + else { + if (!rspamd_pubkey_encrypt(pk, data, len, &out, &outlen, &err)) { + gint ret = luaL_error(L, "cannot encrypt data: %s", err->message); + g_error_free(err); + + if (owned_pk) { + rspamd_pubkey_unref(pk); + } + + return ret; + } + } + + res = lua_newuserdata(L, sizeof(*res)); + res->flags = RSPAMD_TEXT_FLAG_OWN; + res->start = out; + res->len = outlen; + rspamd_lua_setclass(L, "rspamd{text}", -1); + + if (owned_pk) { + rspamd_pubkey_unref(pk); + } + + return 1; +err: + + if (owned_pk) { + rspamd_pubkey_unref(pk); + } + + return luaL_error(L, "invalid arguments"); +} + +/*** + * @function rspamd_cryptobox.encrypt_file(kp|pk_string, filename[, nist=false]) + * Encrypt data using specified keypair/pubkey + * @param {keypair|string} kp keypair or pubkey in base32 to use + * @param {string} filename + * @return {rspamd_text} encrypted text + */ +static gint +lua_cryptobox_encrypt_file(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp = NULL; + struct rspamd_cryptobox_pubkey *pk = NULL; + const gchar *filename; + gchar *data = NULL; + guchar *out = NULL; + struct rspamd_lua_text *res; + gsize len = 0, outlen = 0; + GError *err = NULL; + bool own_pk = false; + + if (lua_type(L, 1) == LUA_TUSERDATA) { + if (rspamd_lua_check_udata_maybe(L, 1, "rspamd{cryptobox_keypair}")) { + kp = lua_check_cryptobox_keypair(L, 1); + } + else if (rspamd_lua_check_udata_maybe(L, 1, "rspamd{cryptobox_pubkey}")) { + pk = lua_check_cryptobox_pubkey(L, 1); + } + } + else if (lua_type(L, 1) == LUA_TSTRING) { + const gchar *b32; + gsize blen; + + b32 = lua_tolstring(L, 1, &blen); + pk = rspamd_pubkey_from_base32(b32, blen, RSPAMD_KEYPAIR_KEX, + lua_toboolean(L, 3) ? RSPAMD_CRYPTOBOX_MODE_NIST : RSPAMD_CRYPTOBOX_MODE_25519); + own_pk = true; + } + + filename = luaL_checkstring(L, 2); + data = rspamd_file_xmap(filename, PROT_READ, &len, TRUE); + + if (!(kp || pk) || !data) { + goto err; + } + + if (kp) { + if (!rspamd_keypair_encrypt(kp, data, len, &out, &outlen, &err)) { + gint ret = luaL_error(L, "cannot encrypt file %s: %s", filename, + err->message); + g_error_free(err); + munmap(data, len); + if (own_pk) { + rspamd_pubkey_unref(pk); + } + + return ret; + } + } + else if (pk) { + if (!rspamd_pubkey_encrypt(pk, data, len, &out, &outlen, &err)) { + gint ret = luaL_error(L, "cannot encrypt file %s: %s", filename, + err->message); + g_error_free(err); + munmap(data, len); + + if (own_pk) { + rspamd_pubkey_unref(pk); + } + + return ret; + } + } + + res = lua_newuserdata(L, sizeof(*res)); + res->flags = RSPAMD_TEXT_FLAG_OWN; + res->start = out; + res->len = outlen; + rspamd_lua_setclass(L, "rspamd{text}", -1); + munmap(data, len); + if (own_pk) { + rspamd_pubkey_unref(pk); + } + + return 1; + +err: + if (data) { + munmap(data, len); + } + if (own_pk) { + rspamd_pubkey_unref(pk); + } + return luaL_error(L, "invalid arguments"); +} + +/*** + * @function rspamd_cryptobox.decrypt_memory(kp, data[, nist = false]) + * Encrypt data using specified keypair + * @param {keypair} kp keypair to use + * @param {string} data + * @return status,{rspamd_text}|error status is boolean variable followed by either unencrypted data or an error message + */ +static gint +lua_cryptobox_decrypt_memory(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp; + const gchar *data; + guchar *out; + struct rspamd_lua_text *t, *res; + gsize len = 0, outlen; + GError *err = NULL; + + kp = lua_check_cryptobox_keypair(L, 1); + + if (lua_isuserdata(L, 2)) { + t = lua_check_text(L, 2); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 2, &len); + } + + + if (!kp || !data) { + return luaL_error(L, "invalid arguments"); + } + + if (!rspamd_keypair_decrypt(kp, data, len, &out, &outlen, &err)) { + lua_pushboolean(L, false); + lua_pushstring(L, err->message); + g_error_free(err); + } + else { + lua_pushboolean(L, true); + res = lua_newuserdata(L, sizeof(*res)); + res->flags = RSPAMD_TEXT_FLAG_OWN; + res->start = out; + res->len = outlen; + rspamd_lua_setclass(L, "rspamd{text}", -1); + } + + return 2; +} + +/*** + * @function rspamd_cryptobox.decrypt_file(kp, filename) + * Encrypt data using specified keypair + * @param {keypair} kp keypair to use + * @param {string} filename + * @return status,{rspamd_text}|error status is boolean variable followed by either unencrypted data or an error message + */ +static gint +lua_cryptobox_decrypt_file(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_cryptobox_keypair *kp; + const gchar *filename; + gchar *data; + guchar *out; + struct rspamd_lua_text *res; + gsize len = 0, outlen; + GError *err = NULL; + + kp = lua_check_cryptobox_keypair(L, 1); + if (!kp) { + return luaL_error(L, "invalid arguments; keypair is expected"); + } + + filename = luaL_checkstring(L, 2); + data = rspamd_file_xmap(filename, PROT_READ, &len, TRUE); + if (!data) { + return luaL_error(L, "invalid arguments; cannot mmap %s: %s", + filename, strerror(errno)); + } + + if (!rspamd_keypair_decrypt(kp, data, len, &out, &outlen, &err)) { + lua_pushboolean(L, false); + lua_pushstring(L, err->message); + g_error_free(err); + } + else { + lua_pushboolean(L, true); + res = lua_newuserdata(L, sizeof(*res)); + res->flags = RSPAMD_TEXT_FLAG_OWN; + res->start = out; + res->len = outlen; + rspamd_lua_setclass(L, "rspamd{text}", -1); + } + + munmap(data, len); + + return 2; +} + +#define RSPAMD_CRYPTOBOX_AES_BLOCKSIZE 16 +#define RSPAMD_CRYPTOBOX_AES_KEYSIZE 16 + +/*** + * @function rspamd_cryptobox.encrypt_cookie(secret_key, secret_cookie) + * Specialised function that performs AES-CTR encryption of the provided cookie + * ``` + * e := base64(nonce||aesencrypt(nonce, secret_cookie)) + * nonce := uint32_le(unix_timestamp)||random_64bit + * aesencrypt := aes_ctr(nonce, secret_key) ^ pad(secret_cookie) + * pad := secret_cookie || 0^(32-len(secret_cookie)) + * ``` + * @param {string} secret_key secret key as a hex string (must be 16 bytes in raw or 32 in hex) + * @param {string} secret_cookie secret cookie as a string for up to 31 character + * @return {string} e function value for this sk and cookie + */ +static gint +lua_cryptobox_encrypt_cookie(lua_State *L) +{ + guchar aes_block[RSPAMD_CRYPTOBOX_AES_BLOCKSIZE], *blk; + guchar padded_cookie[RSPAMD_CRYPTOBOX_AES_BLOCKSIZE]; + guchar nonce[RSPAMD_CRYPTOBOX_AES_BLOCKSIZE]; + guchar aes_key[RSPAMD_CRYPTOBOX_AES_KEYSIZE]; + guchar result[RSPAMD_CRYPTOBOX_AES_BLOCKSIZE * 2]; + guint32 ts; + + const gchar *sk, *cookie; + gsize sklen, cookie_len; + gint bklen; + + sk = lua_tolstring(L, 1, &sklen); + cookie = lua_tolstring(L, 2, &cookie_len); + + if (sk && cookie) { + if (sklen == 32) { + /* Hex */ + rspamd_decode_hex_buf(sk, sklen, aes_key, sizeof(aes_key)); + } + else if (sklen == RSPAMD_CRYPTOBOX_AES_KEYSIZE) { + /* Raw */ + memcpy(aes_key, sk, sizeof(aes_key)); + } + else { + return luaL_error(L, "invalid keysize %d", (gint) sklen); + } + + if (cookie_len > sizeof(padded_cookie) - 1) { + return luaL_error(L, "cookie is too long %d", (gint) cookie_len); + } + + /* Fill nonce */ + ottery_rand_bytes(nonce, sizeof(guint64) + sizeof(guint32)); + ts = (guint32) rspamd_get_calendar_ticks(); + ts = GUINT32_TO_LE(ts); + memcpy(nonce + sizeof(guint64) + sizeof(guint32), &ts, sizeof(ts)); + + /* Prepare padded cookie */ + memset(padded_cookie, 0, sizeof(padded_cookie)); + memcpy(padded_cookie, cookie, cookie_len); + + /* Perform AES CTR via AES ECB on nonce */ + EVP_CIPHER_CTX *ctx; + ctx = EVP_CIPHER_CTX_new(); + EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, aes_key, NULL); + EVP_CIPHER_CTX_set_padding(ctx, 0); + + bklen = sizeof(aes_block); + blk = aes_block; + g_assert(EVP_EncryptUpdate(ctx, blk, &bklen, nonce, sizeof(nonce))); + blk += bklen; + g_assert(EVP_EncryptFinal_ex(ctx, blk, &bklen)); + EVP_CIPHER_CTX_free(ctx); + + /* Encode result */ + memcpy(result, nonce, sizeof(nonce)); + for (guint i = 0; i < sizeof(aes_block); i++) { + result[i + sizeof(nonce)] = padded_cookie[i] ^ aes_block[i]; + } + + gsize rlen; + gchar *res = rspamd_encode_base64(result, sizeof(result), + 0, &rlen); + + lua_pushlstring(L, res, rlen); + g_free(res); + rspamd_explicit_memzero(aes_key, sizeof(aes_key)); + rspamd_explicit_memzero(aes_block, sizeof(aes_block)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_cryptobox.decrypt_cookie(secret_key, encrypted_cookie) + * Specialised function that performs AES-CTR decryption of the provided cookie in form + * ``` + * e := base64(nonce||aesencrypt(nonce, secret_cookie)) + * nonce := int32_le(unix_timestamp)||random_96bit + * aesencrypt := aes_ctr(nonce, secret_key) ^ pad(secret_cookie) + * pad := secret_cookie || 0^(32-len(secret_cookie)) + * ``` + * @param {string} secret_key secret key as a hex string (must be 16 bytes in raw or 32 in hex) + * @param {string} encrypted_cookie encrypted cookie as a base64 encoded string + * @return {string+number} decrypted value of the cookie and the cookie timestamp + */ +static gint +lua_cryptobox_decrypt_cookie(lua_State *L) +{ + guchar *blk; + guchar nonce[RSPAMD_CRYPTOBOX_AES_BLOCKSIZE]; + guchar aes_key[RSPAMD_CRYPTOBOX_AES_KEYSIZE]; + guchar *src; + guint32 ts; + + const gchar *sk, *cookie; + gsize sklen, cookie_len; + gint bklen; + + sk = lua_tolstring(L, 1, &sklen); + cookie = lua_tolstring(L, 2, &cookie_len); + + if (sk && cookie) { + if (sklen == 32) { + /* Hex */ + rspamd_decode_hex_buf(sk, sklen, aes_key, sizeof(aes_key)); + } + else if (sklen == RSPAMD_CRYPTOBOX_AES_KEYSIZE) { + /* Raw */ + memcpy(aes_key, sk, sizeof(aes_key)); + } + else { + return luaL_error(L, "invalid keysize %d", (gint) sklen); + } + + src = g_malloc(cookie_len); + + rspamd_cryptobox_base64_decode(cookie, cookie_len, src, &cookie_len); + + if (cookie_len != RSPAMD_CRYPTOBOX_AES_BLOCKSIZE * 2) { + g_free(src); + lua_pushnil(L); + + return 1; + } + + /* Perform AES CTR via AES ECB on nonce */ + EVP_CIPHER_CTX *ctx; + ctx = EVP_CIPHER_CTX_new(); + /* As per CTR definition, we use encrypt for both encrypt and decrypt */ + EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, aes_key, NULL); + EVP_CIPHER_CTX_set_padding(ctx, 0); + + /* Copy time */ + memcpy(&ts, src + sizeof(guint64) + sizeof(guint32), sizeof(ts)); + ts = GUINT32_FROM_LE(ts); + bklen = sizeof(nonce); + blk = nonce; + g_assert(EVP_EncryptUpdate(ctx, blk, &bklen, src, + RSPAMD_CRYPTOBOX_AES_BLOCKSIZE)); + blk += bklen; + g_assert(EVP_EncryptFinal_ex(ctx, blk, &bklen)); + EVP_CIPHER_CTX_free(ctx); + + /* Decode result */ + for (guint i = 0; i < RSPAMD_CRYPTOBOX_AES_BLOCKSIZE; i++) { + src[i + sizeof(nonce)] ^= nonce[i]; + } + + if (src[RSPAMD_CRYPTOBOX_AES_BLOCKSIZE * 2 - 1] != '\0') { + /* Bad cookie */ + lua_pushnil(L); + lua_pushnil(L); + } + else { + lua_pushstring(L, src + sizeof(nonce)); + lua_pushnumber(L, ts); + } + + rspamd_explicit_memzero(src, RSPAMD_CRYPTOBOX_AES_BLOCKSIZE * 2); + g_free(src); + rspamd_explicit_memzero(aes_key, sizeof(aes_key)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 2; +} + +/*** + * @function rspamd_cryptobox.pbkdf([password, [kdf_alg]]) + * Function that encrypts password using PBKDF function. + * This function either reads password from STDIN or accepts prepared password as + * an argument + * @param {string} password optional password string + * @param {string} kdf_alg algorithm to use (catena or pbkdf2) + * @return {string} encrypted password or nil if error occurs + */ +static gint +lua_cryptobox_pbkdf(lua_State *L) +{ + const struct rspamd_controller_pbkdf *pbkdf = NULL; + const gchar *pbkdf_str = "catena"; + gchar *password; + gsize pwlen; + + if (lua_type(L, 2) == LUA_TSTRING) { + pbkdf_str = lua_tostring(L, 2); + } + + for (guint i = 0; i < RSPAMD_PBKDF_ID_MAX - 1; i++) { + pbkdf = &pbkdf_list[i]; + + if (g_ascii_strcasecmp(pbkdf_str, pbkdf->alias) == 0) { + break; + } + if (g_ascii_strcasecmp(pbkdf_str, pbkdf->name) == 0) { + break; + } + + pbkdf = NULL; + } + + if (pbkdf == NULL) { + return luaL_error(L, "invalid pbkdf algorithm: %s", pbkdf_str); + } + + if (lua_type(L, 1) == LUA_TSTRING) { + password = g_strdup(lua_tolstring(L, 1, &pwlen)); + } + else { + pwlen = 8192; + password = g_malloc0(pwlen); + pwlen = rspamd_read_passphrase(password, pwlen, 0, NULL); + } + + if (pwlen == 0) { + lua_pushnil(L); + g_free(password); + + return 1; + } + + guchar *salt, *key; + gchar *encoded_salt, *encoded_key; + GString *result; + + salt = g_alloca(pbkdf->salt_len); + key = g_alloca(pbkdf->key_len); + ottery_rand_bytes(salt, pbkdf->salt_len); + /* Derive key */ + rspamd_cryptobox_pbkdf(password, pwlen, + salt, pbkdf->salt_len, key, pbkdf->key_len, pbkdf->complexity, + pbkdf->type); + + encoded_salt = rspamd_encode_base32(salt, pbkdf->salt_len, RSPAMD_BASE32_DEFAULT); + encoded_key = rspamd_encode_base32(key, pbkdf->key_len, RSPAMD_BASE32_DEFAULT); + + result = g_string_new(""); + rspamd_printf_gstring(result, "$%d$%s$%s", pbkdf->id, encoded_salt, + encoded_key); + + g_free(encoded_salt); + g_free(encoded_key); + rspamd_explicit_memzero(password, pwlen); + g_free(password); + lua_pushlstring(L, result->str, result->len); + g_string_free(result, TRUE); + + return 1; +} + +/*** + * @function rspamd_cryptobox.gen_dkim_keypair([alg, [nbits]]) + * Generates DKIM keypair. Returns 2 base64 strings as rspamd_text: privkey and pubkey + * @param {string} alg optional algorithm (rsa default, can be ed25519) + * @param {number} nbits optional number of bits for rsa (default 1024) + * @return {rspamd_text,rspamd_text} private key and public key as base64 encoded strings + */ +static gint +lua_cryptobox_gen_dkim_keypair(lua_State *L) +{ + const gchar *alg_str = "rsa"; + guint nbits = 1024; + struct rspamd_lua_text *priv_out, *pub_out; + + if (lua_type(L, 1) == LUA_TSTRING) { + alg_str = lua_tostring(L, 1); + } + + if (lua_type(L, 2) == LUA_TNUMBER) { + nbits = lua_tointeger(L, 2); + } + + if (strcmp(alg_str, "rsa") == 0) { + BIGNUM *e; + RSA *r; + EVP_PKEY *pk; + + e = BN_new(); + r = RSA_new(); + pk = EVP_PKEY_new(); + + if (BN_set_word(e, RSA_F4) != 1) { + BN_free(e); + RSA_free(r); + EVP_PKEY_free(pk); + + return luaL_error(L, "BN_set_word failed"); + } + + if (RSA_generate_key_ex(r, nbits, e, NULL) != 1) { + BN_free(e); + RSA_free(r); + EVP_PKEY_free(pk); + + return luaL_error(L, "RSA_generate_key_ex failed"); + } + + if (EVP_PKEY_set1_RSA(pk, r) != 1) { + BN_free(e); + RSA_free(r); + EVP_PKEY_free(pk); + + return luaL_error(L, "EVP_PKEY_set1_RSA failed"); + } + + BIO *mbio; + gint rc, len; + guchar *data; + gchar *b64_data; + gsize b64_len; + + mbio = BIO_new(BIO_s_mem()); + + /* Process private key */ + rc = i2d_RSAPrivateKey_bio(mbio, r); + + if (rc == 0) { + BIO_free(mbio); + BN_free(e); + RSA_free(r); + EVP_PKEY_free(pk); + + return luaL_error(L, "i2d_RSAPrivateKey_bio failed"); + } + + len = BIO_get_mem_data(mbio, &data); + + b64_data = rspamd_encode_base64(data, len, -1, &b64_len); + + priv_out = lua_newuserdata(L, sizeof(*priv_out)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + priv_out->start = b64_data; + priv_out->len = b64_len; + priv_out->flags = RSPAMD_TEXT_FLAG_OWN | RSPAMD_TEXT_FLAG_WIPE; + + /* Process public key */ + BIO_reset(mbio); + rc = i2d_RSA_PUBKEY_bio(mbio, r); + + if (rc == 0) { + BIO_free(mbio); + BN_free(e); + RSA_free(r); + EVP_PKEY_free(pk); + + return luaL_error(L, "i2d_RSA_PUBKEY_bio failed"); + } + + len = BIO_get_mem_data(mbio, &data); + + b64_data = rspamd_encode_base64(data, len, -1, &b64_len); + + pub_out = lua_newuserdata(L, sizeof(*pub_out)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + pub_out->start = b64_data; + pub_out->len = b64_len; + pub_out->flags = RSPAMD_TEXT_FLAG_OWN; + + BN_free(e); + RSA_free(r); + EVP_PKEY_free(pk); + BIO_free(mbio); + } + else if (strcmp(alg_str, "ed25519") == 0) { + rspamd_sig_pk_t pk; + rspamd_sig_sk_t sk; + gchar *b64_data; + gsize b64_len; + + rspamd_cryptobox_keypair_sig(pk, sk, RSPAMD_CRYPTOBOX_MODE_25519); + + /* Process private key */ + b64_data = rspamd_encode_base64(sk, + rspamd_cryptobox_sk_sig_bytes(RSPAMD_CRYPTOBOX_MODE_25519), + -1, &b64_len); + + priv_out = lua_newuserdata(L, sizeof(*priv_out)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + priv_out->start = b64_data; + priv_out->len = b64_len; + priv_out->flags = RSPAMD_TEXT_FLAG_OWN | RSPAMD_TEXT_FLAG_WIPE; + + /* Process public key */ + b64_data = rspamd_encode_base64(pk, + rspamd_cryptobox_pk_sig_bytes(RSPAMD_CRYPTOBOX_MODE_25519), + -1, &b64_len); + + pub_out = lua_newuserdata(L, sizeof(*pub_out)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + pub_out->start = b64_data; + pub_out->len = b64_len; + pub_out->flags = RSPAMD_TEXT_FLAG_OWN; + + rspamd_explicit_memzero(pk, sizeof(pk)); + rspamd_explicit_memzero(sk, sizeof(sk)); + } + else if (strcmp(alg_str, "ed25519-seed") == 0) { + rspamd_sig_pk_t pk; + rspamd_sig_sk_t sk; + gchar *b64_data; + gsize b64_len; + + rspamd_cryptobox_keypair_sig(pk, sk, RSPAMD_CRYPTOBOX_MODE_25519); + + /* Process private key */ + b64_data = rspamd_encode_base64(sk, + 32, + -1, &b64_len); + + priv_out = lua_newuserdata(L, sizeof(*priv_out)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + priv_out->start = b64_data; + priv_out->len = b64_len; + priv_out->flags = RSPAMD_TEXT_FLAG_OWN | RSPAMD_TEXT_FLAG_WIPE; + + /* Process public key */ + b64_data = rspamd_encode_base64(pk, + rspamd_cryptobox_pk_sig_bytes(RSPAMD_CRYPTOBOX_MODE_25519), + -1, &b64_len); + + pub_out = lua_newuserdata(L, sizeof(*pub_out)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + pub_out->start = b64_data; + pub_out->len = b64_len; + pub_out->flags = RSPAMD_TEXT_FLAG_OWN; + + rspamd_explicit_memzero(pk, sizeof(pk)); + rspamd_explicit_memzero(sk, sizeof(sk)); + } + else { + return luaL_error(L, "invalid algorithm %s", alg_str); + } + + return 2; +} + +/* + * Secretbox API + */ +/* Ensure that KDF output is suitable for crypto_secretbox_KEYBYTES */ +#ifdef crypto_generichash_BYTES_MIN +G_STATIC_ASSERT(crypto_secretbox_KEYBYTES >= crypto_generichash_BYTES_MIN); +#endif + +/*** + * @function rspamd_cryptobox_secretbox.create(secret_string, [params]) + * Generates a secretbox state by expanding secret string + * @param {string/text} secret_string secret string (should have high enough entropy) + * @param {table} params optional parameters - NYI + * @return {rspamd_cryptobox_secretbox} opaque object with the key expanded + */ +static gint +lua_cryptobox_secretbox_create(lua_State *L) +{ + const gchar *in; + gsize inlen; + + + if (lua_isstring(L, 1)) { + in = lua_tolstring(L, 1, &inlen); + } + else if (lua_isuserdata(L, 1)) { + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments; userdata is not text"); + } + + in = t->start; + inlen = t->len; + } + else { + return luaL_error(L, "invalid arguments; userdata or string are expected"); + } + + if (in == NULL || inlen == 0) { + return luaL_error(L, "invalid arguments; non empty secret expected"); + } + + struct rspamd_lua_cryptobox_secretbox *sbox, **psbox; + + sbox = g_malloc0(sizeof(*sbox)); + crypto_generichash(sbox->sk, sizeof(sbox->sk), in, inlen, NULL, 0); + psbox = lua_newuserdata(L, sizeof(*psbox)); + *psbox = sbox; + rspamd_lua_setclass(L, "rspamd{cryptobox_secretbox}", -1); + + return 1; +} + + +static gint +lua_cryptobox_secretbox_gc(lua_State *L) +{ + struct rspamd_lua_cryptobox_secretbox *sbox = + lua_check_cryptobox_secretbox(L, 1); + + if (sbox != NULL) { + sodium_memzero(sbox->sk, sizeof(sbox->sk)); + g_free(sbox); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +/*** + * @method rspamd_cryptobox_secretbox:encrypt(input, [nonce]) + * Encrypts data using secretbox. MAC is prepended to the message + * @param {string/text} input input to encrypt + * @param {string/text} nonce optional nonce (must be 1 - 192 bits length) + * @param {table} params optional parameters - NYI + * @return {rspamd_text},{rspamd_text} output with mac + nonce or just output if nonce is there + */ +static gint +lua_cryptobox_secretbox_encrypt(lua_State *L) +{ + const gchar *in, *nonce; + gsize inlen, nlen; + struct rspamd_lua_cryptobox_secretbox *sbox = + lua_check_cryptobox_secretbox(L, 1); + struct rspamd_lua_text *out; + + if (sbox == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_isstring(L, 2)) { + in = lua_tolstring(L, 2, &inlen); + } + else if (lua_isuserdata(L, 2)) { + struct rspamd_lua_text *t = lua_check_text(L, 2); + + if (!t) { + return luaL_error(L, "invalid arguments; userdata is not text"); + } + + in = t->start; + inlen = t->len; + } + else { + return luaL_error(L, "invalid arguments; userdata or string are expected"); + } + + /* Nonce part */ + if (!lua_isnoneornil(L, 3)) { + if (lua_isstring(L, 3)) { + nonce = lua_tolstring(L, 3, &nlen); + } + else if (lua_isuserdata(L, 3)) { + struct rspamd_lua_text *t = lua_check_text(L, 3); + + if (!t) { + return luaL_error(L, "invalid arguments; userdata is not text"); + } + + nonce = t->start; + nlen = t->len; + } + else { + return luaL_error(L, "invalid arguments; userdata or string are expected"); + } + + if (nlen < 1 || nlen > crypto_secretbox_NONCEBYTES) { + return luaL_error(L, "bad nonce"); + } + + guchar real_nonce[crypto_secretbox_NONCEBYTES]; + + memset(real_nonce, 0, sizeof(real_nonce)); + memcpy(real_nonce, nonce, nlen); + + out = lua_new_text(L, NULL, inlen + crypto_secretbox_MACBYTES, + TRUE); + crypto_secretbox_easy((guchar *) out->start, in, inlen, + nonce, sbox->sk); + + return 1; + } + else { + /* Random nonce */ + struct rspamd_lua_text *random_nonce; + + out = lua_new_text(L, NULL, inlen + crypto_secretbox_MACBYTES, + TRUE); + random_nonce = lua_new_text(L, NULL, crypto_secretbox_NONCEBYTES, TRUE); + + randombytes_buf((guchar *) random_nonce->start, random_nonce->len); + crypto_secretbox_easy((guchar *) out->start, in, inlen, + random_nonce->start, sbox->sk); + + return 2; /* output + random nonce */ + } +} + +/*** + * @method rspamd_cryptobox_secretbox:decrypt(input, nonce) + * Decrypts data using secretbox + * @param {string/text} nonce nonce used to encrypt + * @param {string/text} input input to decrypt + * @param {table} params optional parameters - NYI + * @return {boolean},{rspamd_text} decryption result + decrypted text + */ +static gint +lua_cryptobox_secretbox_decrypt(lua_State *L) +{ + const gchar *in, *nonce; + gsize inlen, nlen; + struct rspamd_lua_cryptobox_secretbox *sbox = + lua_check_cryptobox_secretbox(L, 1); + struct rspamd_lua_text *out; + + if (sbox == NULL) { + return luaL_error(L, "invalid arguments"); + } + + /* Input argument */ + if (lua_isstring(L, 2)) { + in = lua_tolstring(L, 2, &inlen); + } + else if (lua_isuserdata(L, 2)) { + struct rspamd_lua_text *t = lua_check_text(L, 2); + + if (!t) { + return luaL_error(L, "invalid arguments; userdata is not text"); + } + + in = t->start; + inlen = t->len; + } + else { + return luaL_error(L, "invalid arguments; userdata or string are expected"); + } + + /* Nonce argument */ + if (lua_isstring(L, 3)) { + nonce = lua_tolstring(L, 3, &nlen); + } + else if (lua_isuserdata(L, 3)) { + struct rspamd_lua_text *t = lua_check_text(L, 3); + + if (!t) { + return luaL_error(L, "invalid arguments; userdata is not text"); + } + + nonce = t->start; + nlen = t->len; + } + else { + return luaL_error(L, "invalid arguments; userdata or string are expected"); + } + + + if (nlen < 1 || nlen > crypto_secretbox_NONCEBYTES) { + lua_pushboolean(L, false); + lua_pushstring(L, "invalid nonce"); + return 2; + } + + if (inlen < crypto_secretbox_MACBYTES) { + lua_pushboolean(L, false); + lua_pushstring(L, "too short"); + return 2; + } + + guchar real_nonce[crypto_secretbox_NONCEBYTES]; + + memset(real_nonce, 0, sizeof(real_nonce)); + memcpy(real_nonce, nonce, nlen); + + out = lua_new_text(L, NULL, inlen - crypto_secretbox_MACBYTES, + TRUE); + gint text_pos = lua_gettop(L); + + if (crypto_secretbox_open_easy((guchar *) out->start, in, inlen, + nonce, sbox->sk) == 0) { + lua_pushboolean(L, true); + lua_pushvalue(L, text_pos); /* Prevent gc by copying in stack */ + } + else { + lua_pushboolean(L, false); + lua_pushstring(L, "authentication error"); + } + + /* This causes gc method if decryption has failed */ + lua_remove(L, text_pos); + + return 2; +} + +static gint +lua_load_pubkey(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, cryptoboxpubkeylib_f); + + return 1; +} + +static gint +lua_load_keypair(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, cryptoboxkeypairlib_f); + + return 1; +} + +static gint +lua_load_signature(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, cryptoboxsignlib_f); + + return 1; +} + +static gint +lua_load_hash(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, cryptoboxhashlib_f); + + return 1; +} + +static gint +lua_load_cryptobox_secretbox(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, cryptoboxsecretboxlib_f); + + return 1; +} + +static gint +lua_load_cryptobox(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, cryptoboxlib_f); + + return 1; +} + +void luaopen_cryptobox(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{cryptobox_pubkey}", cryptoboxpubkeylib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_cryptobox_pubkey", lua_load_pubkey); + + rspamd_lua_new_class(L, "rspamd{cryptobox_keypair}", cryptoboxkeypairlib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_cryptobox_keypair", lua_load_keypair); + + rspamd_lua_new_class(L, "rspamd{cryptobox_signature}", cryptoboxsignlib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_cryptobox_signature", lua_load_signature); + + rspamd_lua_new_class(L, "rspamd{cryptobox_hash}", cryptoboxhashlib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_cryptobox_hash", lua_load_hash); + + rspamd_lua_new_class(L, "rspamd{cryptobox_secretbox}", + cryptoboxsecretboxlib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_cryptobox_secretbox", + lua_load_cryptobox_secretbox); + + rspamd_lua_add_preload(L, "rspamd_cryptobox", lua_load_cryptobox); + + lua_settop(L, 0); +} diff --git a/src/lua/lua_dns.c b/src/lua/lua_dns.c new file mode 100644 index 0000000..cffa312 --- /dev/null +++ b/src/lua/lua_dns.c @@ -0,0 +1,198 @@ +/*- + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_dns_resolver.h" +#include "lua_thread_pool.h" + +LUA_FUNCTION_DEF(dns, request); + +static const struct luaL_reg dns_f[] = { + LUA_INTERFACE_DEF(dns, request), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static const gchar *M = "rspamd lua dns"; + +void lua_dns_callback(struct rdns_reply *reply, void *arg); + +struct lua_rspamd_dns_cbdata { + struct thread_entry *thread; + struct rspamd_task *task; + struct rspamd_dns_resolver *resolver; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_async_session *s; +}; + +static gint +lua_dns_request(lua_State *L) +{ + GError *err = NULL; + struct rspamd_async_session *session = NULL; + struct rspamd_config *cfg = NULL; + struct lua_rspamd_dns_cbdata *cbdata = NULL; + const gchar *to_resolve = NULL; + const gchar *type_str = NULL; + struct rspamd_task *task = NULL; + rspamd_mempool_t *pool = NULL; + gint ret = 0; + gboolean forced = FALSE; + + /* Check arguments */ + if (!rspamd_lua_parse_table_arguments(L, 1, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*name=S;task=U{task};*type=S;forced=B;session=U{session};config=U{config}", + &to_resolve, + &task, + &type_str, + &forced, + &session, + &cfg)) { + + if (err) { + ret = luaL_error(L, "invalid arguments: %s", err->message); + g_error_free(err); + + return ret; + } + + return luaL_error(L, "invalid arguments"); + } + + if (task) { + session = task->s; + pool = task->task_pool; + cfg = task->cfg; + } + else if (session && cfg) { + pool = cfg->cfg_pool; + } + else { + return luaL_error(L, "invalid arguments: either task or session/config should be set"); + } + + enum rdns_request_type type = rdns_type_fromstr(type_str); + + if (type == RDNS_REQUEST_INVALID) { + return luaL_error(L, "invalid arguments: this record type is not supported"); + } + + cbdata = rspamd_mempool_alloc0(pool, sizeof(*cbdata)); + + cbdata->task = task; + + if (type == RDNS_REQUEST_PTR) { + char *ptr_str; + + ptr_str = rdns_generate_ptr_from_str(to_resolve); + + if (ptr_str == NULL) { + msg_err_task_check("wrong resolve string to PTR request: %s", + to_resolve); + lua_pushnil(L); + + return 1; + } + + to_resolve = rspamd_mempool_strdup(pool, ptr_str); + free(ptr_str); + } + + + if (task == NULL) { + ret = (rspamd_dns_resolver_request(cfg->dns_resolver, + session, + pool, + lua_dns_callback, + cbdata, + type, + to_resolve) != NULL); + } + else { + if (forced) { + ret = rspamd_dns_resolver_request_task_forced(task, + lua_dns_callback, + cbdata, + type, + to_resolve); + } + else { + ret = rspamd_dns_resolver_request_task(task, + lua_dns_callback, + cbdata, + type, + to_resolve); + } + } + + if (ret) { + cbdata->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool); + cbdata->s = session; + + if (task) { + cbdata->item = rspamd_symcache_get_cur_item(task); + rspamd_symcache_item_async_inc(task, cbdata->item, M); + } + + return lua_thread_yield(cbdata->thread, 0); + } + else { + lua_pushnil(L); + return 1; + } +} + +void lua_dns_callback(struct rdns_reply *reply, void *arg) +{ + struct lua_rspamd_dns_cbdata *cbdata = arg; + lua_State *L = cbdata->thread->lua_state; + + if (reply->code != RDNS_RC_NOERROR) { + lua_pushboolean(L, false); + lua_pushstring(L, rdns_strerror(reply->code)); + } + else { + lua_push_dns_reply(L, reply); + + lua_pushboolean(L, reply->flags & RDNS_AUTH); + lua_setfield(L, -3, "authenticated"); + + lua_pushboolean(L, reply->flags & RDNS_TRUNCATED); + lua_setfield(L, -3, "truncated"); + + /* result 1 - not and error */ + lua_pushboolean(L, true); + /* push table into stack, result 2 - results itself */ + lua_pushvalue(L, -3); + } + + lua_thread_resume(cbdata->thread, 2); + + if (cbdata->item) { + rspamd_symcache_item_async_dec_check(cbdata->task, cbdata->item, M); + } +} + +static gint +lua_load_dns(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, dns_f); + + return 1; +} + +void luaopen_dns(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_dns", lua_load_dns); +} diff --git a/src/lua/lua_dns_resolver.c b/src/lua/lua_dns_resolver.c new file mode 100644 index 0000000..b022e13 --- /dev/null +++ b/src/lua/lua_dns_resolver.c @@ -0,0 +1,754 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_thread_pool.h" +#include "utlist.h" + + +/*** + * @module rspamd_resolver + * This module allows to resolve DNS names from LUA code. All resolving is executed + * asynchronously. Here is an example of name resolution: + * @example +local function symbol_callback(task) + local host = 'example.com' + + local function dns_cb(resolver, to_resolve, results, err, _, authenticated) + if not results then + rspamd_logger.infox('DNS resolving of %1 failed: %2', host, err) + return + end + for _,r in ipairs(results) do + -- r is of type rspamd{ip} here, but it can be converted to string + rspamd_logger.infox('Resolved %1 to %2', host, tostring(r)) + end + end + + task:get_resolver():resolve_a({task = task, name = host, callback = dns_cb}) +end + */ + +static const gchar *M = "rspamd lua dns resolver"; + +/* Lua bindings */ +LUA_FUNCTION_DEF(dns_resolver, init); +LUA_FUNCTION_DEF(dns_resolver, resolve_a); +LUA_FUNCTION_DEF(dns_resolver, resolve_ptr); +LUA_FUNCTION_DEF(dns_resolver, resolve_txt); +LUA_FUNCTION_DEF(dns_resolver, resolve_mx); +LUA_FUNCTION_DEF(dns_resolver, resolve_ns); +LUA_FUNCTION_DEF(dns_resolver, resolve); +LUA_FUNCTION_DEF(dns_resolver, idna_convert_utf8); + +void lua_push_dns_reply(lua_State *L, const struct rdns_reply *reply); + +static const struct luaL_reg dns_resolverlib_f[] = { + LUA_INTERFACE_DEF(dns_resolver, init), + {NULL, NULL}}; + +static const struct luaL_reg dns_resolverlib_m[] = { + LUA_INTERFACE_DEF(dns_resolver, resolve_a), + LUA_INTERFACE_DEF(dns_resolver, resolve_ptr), + LUA_INTERFACE_DEF(dns_resolver, resolve_txt), + LUA_INTERFACE_DEF(dns_resolver, resolve_mx), + LUA_INTERFACE_DEF(dns_resolver, resolve_ns), + LUA_INTERFACE_DEF(dns_resolver, resolve), + LUA_INTERFACE_DEF(dns_resolver, idna_convert_utf8), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +struct rspamd_dns_resolver * +lua_check_dns_resolver(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{resolver}"); + luaL_argcheck(L, ud != NULL, pos, "'resolver' expected"); + return ud ? *((struct rspamd_dns_resolver **) ud) : NULL; +} + +struct lua_dns_cbdata { + struct rspamd_task *task; + rspamd_mempool_t *pool; + struct rspamd_dns_resolver *resolver; + gint cbref; + gchar *to_resolve; + gchar *user_str; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_async_session *s; +}; + +static int +lua_dns_get_type(lua_State *L, int argno) +{ + int type = RDNS_REQUEST_A; + const gchar *strtype; + + if (lua_type(L, argno) != LUA_TSTRING) { + lua_pushvalue(L, argno); + lua_gettable(L, lua_upvalueindex(1)); + + type = lua_tonumber(L, -1); + lua_pop(L, 1); + if (type == 0) { + rspamd_lua_typerror(L, argno, "dns_request_type"); + } + } + else { + strtype = lua_tostring(L, argno); + type = rdns_type_fromstr(strtype); + } + + return type; +} + +static void +lua_dns_resolver_callback(struct rdns_reply *reply, gpointer arg) +{ + struct lua_dns_cbdata *cd = arg; + struct rspamd_dns_resolver **presolver; + lua_State *L; + struct lua_callback_state cbs; + rspamd_mempool_t *pool; + gint err_idx; + + pool = cd->pool; + lua_thread_pool_prepare_callback(cd->resolver->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, cd->cbref); + + presolver = lua_newuserdata(L, sizeof(gpointer)); + rspamd_lua_setclass(L, "rspamd{resolver}", -1); + + *presolver = cd->resolver; + lua_pushstring(L, cd->to_resolve); + + lua_push_dns_reply(L, reply); + + /* + * 1 - resolver + * 2 - to_resolve + * 3 - entries | nil + * 4 - error | nil + * 5 - user_str + * 6 - reply->flags & RDNS_AUTH + * 7 - server + */ + if (reply->code != RDNS_RC_NOERROR) { + lua_pushnil(L); + lua_pushstring(L, rdns_strerror(reply->code)); + } + if (cd->user_str != NULL) { + lua_pushstring(L, cd->user_str); + } + else { + lua_pushnil(L); + } + + lua_pushboolean(L, reply->flags & RDNS_AUTH); + + const gchar *servname = rdns_request_get_server(reply->request); + + if (servname) { + lua_pushstring(L, servname); + } + else { + lua_pushnil(L); + } + + if (cd->item) { + /* We also need to restore the item in case there are some chains */ + rspamd_symcache_set_cur_item(cd->task, cd->item); + } + + if (lua_pcall(L, 7, 0, err_idx) != 0) { + msg_err_pool_check("call to dns callback failed: %s", + lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); + + /* Unref function */ + luaL_unref(L, LUA_REGISTRYINDEX, cd->cbref); + lua_thread_pool_restore_callback(&cbs); + + if (cd->item) { + rspamd_symcache_item_async_dec_check(cd->task, cd->item, M); + } + + if (!cd->pool) { + g_free(cd->to_resolve); + g_free(cd->user_str); + g_free(cd); + } +} + +void lua_push_dns_reply(lua_State *L, const struct rdns_reply *reply) +{ + gint i = 0, naddrs = 0; + struct rdns_reply_entry *elt; + rspamd_inet_addr_t *addr; + + if (reply->code == RDNS_RC_NOERROR) { + LL_FOREACH(reply->entries, elt) + { + naddrs++; + } + + lua_createtable(L, naddrs, 0); + + LL_FOREACH(reply->entries, elt) + { + if (!rdns_request_has_type(reply->request, elt->type)) { + /* Unrequested type has been returned, ignore it */ + continue; + } + + switch (elt->type) { + case RDNS_REQUEST_A: + addr = rspamd_inet_address_new(AF_INET, &elt->content.a.addr); + rspamd_lua_ip_push(L, addr); + rspamd_inet_address_free(addr); + lua_rawseti(L, -2, ++i); + break; + case RDNS_REQUEST_AAAA: + addr = rspamd_inet_address_new(AF_INET6, &elt->content.aaa.addr); + rspamd_lua_ip_push(L, addr); + rspamd_inet_address_free(addr); + lua_rawseti(L, -2, ++i); + break; + case RDNS_REQUEST_NS: + lua_pushstring(L, elt->content.ns.name); + lua_rawseti(L, -2, ++i); + break; + case RDNS_REQUEST_PTR: + lua_pushstring(L, elt->content.ptr.name); + lua_rawseti(L, -2, ++i); + break; + case RDNS_REQUEST_TXT: + case RDNS_REQUEST_SPF: + lua_pushstring(L, elt->content.txt.data); + lua_rawseti(L, -2, ++i); + break; + case RDNS_REQUEST_MX: + /* mx['name'], mx['priority'] */ + lua_createtable(L, 0, 2); + rspamd_lua_table_set(L, "name", elt->content.mx.name); + lua_pushstring(L, "priority"); + lua_pushinteger(L, elt->content.mx.priority); + lua_settable(L, -3); + + lua_rawseti(L, -2, ++i); + break; + case RDNS_REQUEST_SOA: + lua_createtable(L, 0, 7); + rspamd_lua_table_set(L, "ns", elt->content.soa.mname); + rspamd_lua_table_set(L, "contact", elt->content.soa.admin); + lua_pushstring(L, "serial"); + lua_pushinteger(L, elt->content.soa.serial); + lua_settable(L, -3); + lua_pushstring(L, "refresh"); + lua_pushinteger(L, elt->content.soa.refresh); + lua_settable(L, -3); + lua_pushstring(L, "retry"); + lua_pushinteger(L, elt->content.soa.retry); + lua_settable(L, -3); + lua_pushstring(L, "expiry"); + lua_pushinteger(L, elt->content.soa.expire); + lua_settable(L, -3); + /* Negative TTL */ + lua_pushstring(L, "nx"); + lua_pushinteger(L, elt->content.soa.minimum); + lua_settable(L, -3); + + lua_rawseti(L, -2, ++i); + break; + case RDNS_REQUEST_CNAME: + lua_pushstring(L, elt->content.cname.name); + lua_rawseti(L, -2, ++i); + break; + default: + continue; + } + } + lua_pushnil(L); + } +} + +/*** + * @function rspamd_resolver.init(ev_base, config) + * @param {event_base} ev_base event base used for asynchronous events + * @param {rspamd_config} config rspamd configuration parameters + * @return {rspamd_resolver} new resolver object associated with the specified base + */ +static int +lua_dns_resolver_init(lua_State *L) +{ + struct rspamd_dns_resolver *resolver, **presolver; + struct rspamd_config *cfg, **pcfg; + struct ev_loop *base, **pbase; + + /* Check args */ + pbase = rspamd_lua_check_udata(L, 1, "rspamd{ev_base}"); + luaL_argcheck(L, pbase != NULL, 1, "'ev_base' expected"); + base = pbase ? *(pbase) : NULL; + pcfg = rspamd_lua_check_udata(L, 2, "rspamd{config}"); + luaL_argcheck(L, pcfg != NULL, 2, "'config' expected"); + cfg = pcfg ? *(pcfg) : NULL; + + if (base != NULL && cfg != NULL) { + resolver = rspamd_dns_resolver_init(NULL, base, cfg); + if (resolver) { + presolver = lua_newuserdata(L, sizeof(gpointer)); + rspamd_lua_setclass(L, "rspamd{resolver}", -1); + *presolver = resolver; + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static int +lua_dns_resolver_resolve_common(lua_State *L, + struct rspamd_dns_resolver *resolver, + enum rdns_request_type type, + int first) +{ + LUA_TRACE_POINT; + struct rspamd_async_session *session = NULL; + rspamd_mempool_t *pool = NULL; + const gchar *to_resolve = NULL, *user_str = NULL; + struct lua_dns_cbdata *cbdata; + gint cbref = -1, ret; + struct rspamd_task *task = NULL; + GError *err = NULL; + gboolean forced = FALSE; + struct rspamd_symcache_dynamic_item *item = NULL; + + /* Check arguments */ + if (!rspamd_lua_parse_table_arguments(L, first, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "session=U{session};mempool=U{mempool};*name=S;*callback=F;" + "option=S;task=U{task};forced=B", + &session, &pool, &to_resolve, &cbref, &user_str, &task, &forced)) { + + if (err) { + ret = luaL_error(L, "invalid arguments: %s", err->message); + g_error_free(err); + + return ret; + } + + return luaL_error(L, "invalid arguments"); + } + + if (task) { + pool = task->task_pool; + session = task->s; + item = rspamd_symcache_get_cur_item(task); + } + + if (to_resolve != NULL) { + if (pool != NULL) { + cbdata = rspamd_mempool_alloc0(pool, sizeof(struct lua_dns_cbdata)); + cbdata->user_str = rspamd_mempool_strdup(pool, user_str); + + if (type != RDNS_REQUEST_PTR) { + cbdata->to_resolve = rspamd_mempool_strdup(pool, to_resolve); + } + else { + char *ptr_str; + + ptr_str = rdns_generate_ptr_from_str(to_resolve); + + if (ptr_str == NULL) { + msg_err_task_check("wrong resolve string to PTR request: %s", + to_resolve); + goto err; + } + + cbdata->to_resolve = rspamd_mempool_strdup(pool, ptr_str); + to_resolve = cbdata->to_resolve; + free(ptr_str); + } + } + else { + cbdata = g_malloc0(sizeof(struct lua_dns_cbdata)); + cbdata->user_str = user_str ? g_strdup(user_str) : NULL; + + if (type != RDNS_REQUEST_PTR) { + cbdata->to_resolve = g_strdup(to_resolve); + } + else { + char *ptr_str; + + ptr_str = rdns_generate_ptr_from_str(to_resolve); + + if (ptr_str == NULL) { + msg_err_task_check("wrong resolve string to PTR request: %s", + to_resolve); + goto err; + } + + cbdata->to_resolve = g_strdup(ptr_str); + free(ptr_str); + } + } + + cbdata->resolver = resolver; + cbdata->cbref = cbref; + cbdata->task = task; + cbdata->pool = pool; + + if (task == NULL) { + if (rspamd_dns_resolver_request(resolver, + session, + pool, + lua_dns_resolver_callback, + cbdata, + type, + to_resolve)) { + + lua_pushboolean(L, TRUE); + + if (session) { + cbdata->s = session; + } + } + else { + goto err; + } + } + else { + /* Fail-safety as this function can, in theory, call + * lua_dns_resolver_callback without switching to the event loop + */ + if (item) { + rspamd_symcache_item_async_inc(task, item, M); + } + + if (forced) { + ret = rspamd_dns_resolver_request_task_forced(task, + lua_dns_resolver_callback, + cbdata, + type, + to_resolve); + } + else { + ret = rspamd_dns_resolver_request_task(task, + lua_dns_resolver_callback, + cbdata, + type, + to_resolve); + } + + if (ret) { + cbdata->s = session; + + if (item) { + cbdata->item = item; + rspamd_symcache_item_async_inc(task, item, M); + } + /* callback was set up */ + lua_pushboolean(L, TRUE); + } + else { + if (item) { + rspamd_symcache_item_async_dec_check(task, item, M); + } + + goto err; + } + + if (item) { + rspamd_symcache_item_async_dec_check(task, item, M); + } + } + } + else { + return luaL_error(L, "invalid arguments to lua_resolve"); + } + + return 1; + +err: + /* Callback is not called in this case */ + if (cbdata->cbref != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->cbref); + } + + if (!pool) { + /* Free resources */ + g_free(cbdata->to_resolve); + g_free(cbdata->user_str); + g_free(cbdata); + } + + lua_pushnil(L); + + return 1; +} + +/*** + * @method resolver:resolve_a(table) + * Resolve A record for a specified host. + * Table elements: + * * `task` - task element (preferred, required to track dependencies) -or- + * * `session` - asynchronous session normally associated with rspamd task (`task:get_session()`) + * * `mempool` - pool memory pool for storing intermediate data + * * `name` - host name to resolve + * * `callback` - callback callback function to be called upon name resolution is finished; must be of type `function (resolver, to_resolve, results, err)` + * * `forced` - true if needed to override normal limit for DNS requests + * @return {boolean} `true` if DNS request has been scheduled + */ +static int +lua_dns_resolver_resolve_a(lua_State *L) +{ + struct rspamd_dns_resolver *dns_resolver = lua_check_dns_resolver(L, 1); + + if (dns_resolver) { + return lua_dns_resolver_resolve_common(L, + dns_resolver, + RDNS_REQUEST_A, + 2); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method resolver:resolve_ptr(table) + * Resolve PTR record for a specified host. + * Table elements: + * * `task` - task element (preferred, required to track dependencies) -or- + * * `session` - asynchronous session normally associated with rspamd task (`task:get_session()`) + * * `mempool` - pool memory pool for storing intermediate data + * * `name` - host name to resolve + * * `callback` - callback callback function to be called upon name resolution is finished; must be of type `function (resolver, to_resolve, results, err)` + * * `forced` - true if needed to override normal limit for DNS requests + * @return {boolean} `true` if DNS request has been scheduled + */ +static int +lua_dns_resolver_resolve_ptr(lua_State *L) +{ + struct rspamd_dns_resolver *dns_resolver = lua_check_dns_resolver(L, 1); + + if (dns_resolver) { + return lua_dns_resolver_resolve_common(L, + dns_resolver, + RDNS_REQUEST_PTR, + 2); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method resolver:resolve_txt(table) + * Resolve TXT record for a specified host. + * Table elements: + * * `task` - task element (preferred, required to track dependencies) -or- + * * `session` - asynchronous session normally associated with rspamd task (`task:get_session()`) + * * `mempool` - pool memory pool for storing intermediate data + * * `name` - host name to resolve + * * `callback` - callback callback function to be called upon name resolution is finished; must be of type `function (resolver, to_resolve, results, err)` + * * `forced` - true if needed to override normal limit for DNS requests + * @return {boolean} `true` if DNS request has been scheduled + */ +static int +lua_dns_resolver_resolve_txt(lua_State *L) +{ + struct rspamd_dns_resolver *dns_resolver = lua_check_dns_resolver(L, 1); + + if (dns_resolver) { + return lua_dns_resolver_resolve_common(L, + dns_resolver, + RDNS_REQUEST_TXT, + 2); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method resolver:resolve_mx(table) + * Resolve MX record for a specified host. + * Table elements: + * * `task` - task element (preferred, required to track dependencies) -or- + * * `session` - asynchronous session normally associated with rspamd task (`task:get_session()`) + * * `mempool` - pool memory pool for storing intermediate data + * * `name` - host name to resolve + * * `callback` - callback callback function to be called upon name resolution is finished; must be of type `function (resolver, to_resolve, results, err)` + * * `forced` - true if needed to override normal limit for DNS requests + * @return {boolean} `true` if DNS request has been scheduled + */ +static int +lua_dns_resolver_resolve_mx(lua_State *L) +{ + struct rspamd_dns_resolver *dns_resolver = lua_check_dns_resolver(L, 1); + + if (dns_resolver) { + return lua_dns_resolver_resolve_common(L, + dns_resolver, + RDNS_REQUEST_MX, + 2); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method resolver:resolve_ns(table) + * Resolve NS records for a specified host. + * Table elements: + * * `task` - task element (preferred, required to track dependencies) -or- + * * `session` - asynchronous session normally associated with rspamd task (`task:get_session()`) + * * `mempool` - pool memory pool for storing intermediate data + * * `name` - host name to resolve + * * `callback` - callback callback function to be called upon name resolution is finished; must be of type `function (resolver, to_resolve, results, err)` + * * `forced` - true if needed to override normal limit for DNS requests + * @return {boolean} `true` if DNS request has been scheduled + */ +static int +lua_dns_resolver_resolve_ns(lua_State *L) +{ + struct rspamd_dns_resolver *dns_resolver = lua_check_dns_resolver(L, 1); + + if (dns_resolver) { + return lua_dns_resolver_resolve_common(L, + dns_resolver, + RDNS_REQUEST_NS, + 2); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/* XXX: broken currently */ +static int +lua_dns_resolver_resolve(lua_State *L) +{ + struct rspamd_dns_resolver *dns_resolver = lua_check_dns_resolver(L, 1); + int type; + + type = lua_dns_get_type(L, 2); + + if (dns_resolver && type != 0) { + return lua_dns_resolver_resolve_common(L, dns_resolver, type, 3); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method resolver:idna_convert_utf8(hostname[, pool]) + * Converts domain name from IDN (in utf8 format) to punycode + * @return {string} new name converted + */ +static int +lua_dns_resolver_idna_convert_utf8(lua_State *L) +{ + struct rspamd_dns_resolver *dns_resolver = lua_check_dns_resolver(L, 1); + gsize hlen; + guint conv_len = 0; + const gchar *hname = luaL_checklstring(L, 2, &hlen); + gchar *converted; + rspamd_mempool_t *pool = rspamd_lua_check_udata_maybe(L, 3, "rspamd{mempool}"); + + + if (dns_resolver && hname) { + if (!rspamd_str_has_8bit(hname, hlen)) { + /* No 8 bit, no reasons to call idna */ + lua_pushlstring(L, hname, hlen); + } + else { + converted = rspamd_dns_resolver_idna_convert_utf8(dns_resolver, pool, + hname, hlen, &conv_len); + + if (converted == NULL) { + lua_pushnil(L); + } + else { + lua_pushlstring(L, converted, conv_len); + + if (pool == NULL) { + g_free(converted); + } + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_load_dns_resolver(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, dns_resolverlib_f); + + return 1; +} + +void luaopen_dns_resolver(lua_State *L) +{ + + rspamd_lua_new_class(L, "rspamd{resolver}", dns_resolverlib_m); + { + LUA_ENUM(L, DNS_A, RDNS_REQUEST_A); + LUA_ENUM(L, DNS_PTR, RDNS_REQUEST_PTR); + LUA_ENUM(L, DNS_MX, RDNS_REQUEST_MX); + LUA_ENUM(L, DNS_TXT, RDNS_REQUEST_TXT); + LUA_ENUM(L, DNS_SRV, RDNS_REQUEST_SRV); + LUA_ENUM(L, DNS_SPF, RDNS_REQUEST_SPF); + LUA_ENUM(L, DNS_AAAA, RDNS_REQUEST_AAAA); + LUA_ENUM(L, DNS_SOA, RDNS_REQUEST_SOA); + LUA_ENUM(L, DNS_CNAME, RDNS_REQUEST_CNAME); + } + + lua_pop(L, 1); + + rspamd_lua_add_preload(L, "rspamd_resolver", lua_load_dns_resolver); +} diff --git a/src/lua/lua_dns_resolver.h b/src/lua/lua_dns_resolver.h new file mode 100644 index 0000000..515e9ac --- /dev/null +++ b/src/lua/lua_dns_resolver.h @@ -0,0 +1,15 @@ +#ifndef RSPAMD_LUA_DNS_H +#define RSPAMD_LUA_DNS_H + +struct lua_State; +struct rdns_reply; + +/** + * Pushes dns reply onto Lua stack + * + * @param L + * @param reply + */ +void lua_push_dns_reply(struct lua_State *L, const struct rdns_reply *reply); + +#endif diff --git a/src/lua/lua_expression.c b/src/lua/lua_expression.c new file mode 100644 index 0000000..1ac6f86 --- /dev/null +++ b/src/lua/lua_expression.c @@ -0,0 +1,512 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "expression.h" + +/*** + * @module rspamd_expression + * This module can be used to implement different logic expressions in lua using + * rspamd AST optimizer. There are some examples in individual methods definitions to help understanding of this module. + */ + +/*** + * @function rspamd_expression.create(line, {parse_func, [process_func]}, pool) + * Create expression from the line using atom parsing routines and the specified memory pool + * @param {string} line expression line + * @param {table} atom_functions parse_atom function and optional process_atom function + * @param {rspamd_mempool} memory pool to use for this function + * @return {expr, err} expression object and error message of `expr` is nil + * @example +require "fun" () +local rspamd_expression = require "rspamd_expression" +local rspamd_mempool = require "rspamd_mempool" + +local function parse_func(str) + -- extract token till the first space character + local token = table.concat(totable(take_while(function(s) return s ~= ' ' end, iter(str)))) + -- Return token name + return token +end + +local function process_func(token) + -- Do something using token and task +end + +local pool = rspamd_mempool.create() +local expr,err = rspamd_expression.create('A & B | !C', {parse_func, process_func}, pool) +-- Expression is destroyed when the corresponding pool is destroyed +pool:destroy() + */ +LUA_FUNCTION_DEF(expr, create); + +/*** + * @method rspamd_expression:to_string() + * Converts rspamd expression to string + * @return {string} string representation of rspamd expression + */ +LUA_FUNCTION_DEF(expr, to_string); + +/*** + * @method rspamd_expression:process([callback, ]input[, flags]) + * Executes the expression and pass input to process atom callbacks + * @param {function} callback if not specified on process, then callback must be here + * @param {any} input input data for processing callbacks + * @return {number} result of the expression evaluation + */ +LUA_FUNCTION_DEF(expr, process); + +/*** + * @method rspamd_expression:process_traced([callback, ]input[, flags]) + * Executes the expression and pass input to process atom callbacks. This function also saves the full trace + * @param {function} callback if not specified on process, then callback must be here + * @param {any} input input data for processing callbacks + * @return {number, table of matched atoms} result of the expression evaluation + */ +LUA_FUNCTION_DEF(expr, process_traced); + +/*** + * @method rspamd_expression:atoms() + * Extract all atoms from the expression as table of strings + * @return {table/strings} list of all atoms in the expression + */ +LUA_FUNCTION_DEF(expr, atoms); + +static const struct luaL_reg exprlib_m[] = { + LUA_INTERFACE_DEF(expr, to_string), + LUA_INTERFACE_DEF(expr, atoms), + LUA_INTERFACE_DEF(expr, process), + LUA_INTERFACE_DEF(expr, process_traced), + {"__tostring", lua_expr_to_string}, + {NULL, NULL}}; + +static const struct luaL_reg exprlib_f[] = { + LUA_INTERFACE_DEF(expr, create), + {NULL, NULL}}; + +static rspamd_expression_atom_t *lua_atom_parse(const gchar *line, gsize len, + rspamd_mempool_t *pool, gpointer ud, GError **err); +static gdouble lua_atom_process(gpointer runtime_ud, rspamd_expression_atom_t *atom); + +static const struct rspamd_atom_subr lua_atom_subr = { + .parse = lua_atom_parse, + .process = lua_atom_process, + .priority = NULL, + .destroy = NULL}; + +struct lua_expression { + struct rspamd_expression *expr; + gint parse_idx; + gint process_idx; + lua_State *L; + rspamd_mempool_t *pool; +}; + +static GQuark +lua_expr_quark(void) +{ + return g_quark_from_static_string("lua-expression"); +} + +struct lua_expression * +rspamd_lua_expression(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{expr}"); + luaL_argcheck(L, ud != NULL, pos, "'expr' expected"); + return ud ? *((struct lua_expression **) ud) : NULL; +} + +static void +lua_expr_dtor(gpointer p) +{ + struct lua_expression *e = (struct lua_expression *) p; + + if (e->parse_idx != -1) { + luaL_unref(e->L, LUA_REGISTRYINDEX, e->parse_idx); + } + + if (e->process_idx != -1) { + luaL_unref(e->L, LUA_REGISTRYINDEX, e->process_idx); + } +} + +static rspamd_expression_atom_t * +lua_atom_parse(const gchar *line, gsize len, + rspamd_mempool_t *pool, gpointer ud, GError **err) +{ + struct lua_expression *e = (struct lua_expression *) ud; + rspamd_expression_atom_t *atom; + gsize rlen; + const gchar *tok; + + lua_rawgeti(e->L, LUA_REGISTRYINDEX, e->parse_idx); + lua_pushlstring(e->L, line, len); + + if (lua_pcall(e->L, 1, 1, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(e->L, -1)); + lua_pop(e->L, 1); + return NULL; + } + + if (lua_type(e->L, -1) != LUA_TSTRING) { + g_set_error(err, lua_expr_quark(), 500, "cannot parse lua atom"); + lua_pop(e->L, 1); + return NULL; + } + + tok = lua_tolstring(e->L, -1, &rlen); + atom = rspamd_mempool_alloc0(e->pool, sizeof(*atom)); + atom->str = rspamd_mempool_strdup(e->pool, tok); + atom->len = rlen; + atom->data = ud; + + lua_pop(e->L, 1); + + return atom; +} + +struct lua_atom_process_data { + lua_State *L; + struct lua_expression *e; + gint process_cb_pos; + gint stack_item; +}; + +static gdouble +lua_atom_process(gpointer runtime_ud, rspamd_expression_atom_t *atom) +{ + struct lua_atom_process_data *pd = (struct lua_atom_process_data *) runtime_ud; + gdouble ret = 0; + guint nargs; + gint err_idx; + + if (pd->stack_item != -1) { + nargs = 2; + } + else { + nargs = 1; + } + + lua_pushcfunction(pd->L, &rspamd_lua_traceback); + err_idx = lua_gettop(pd->L); + + /* Function position */ + lua_pushvalue(pd->L, pd->process_cb_pos); + /* Atom name */ + lua_pushlstring(pd->L, atom->str, atom->len); + + /* If we have data passed */ + if (pd->stack_item != -1) { + lua_pushvalue(pd->L, pd->stack_item); + } + + if (lua_pcall(pd->L, nargs, 1, err_idx) != 0) { + msg_info("expression process callback failed: %s", lua_tostring(pd->L, -1)); + } + else { + ret = lua_tonumber(pd->L, -1); + } + + lua_settop(pd->L, err_idx - 1); + + return ret; +} + +static gint +lua_expr_process(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_expression *e = rspamd_lua_expression(L, 1); + struct lua_atom_process_data pd; + gdouble res; + gint flags = 0, old_top; + + pd.L = L; + pd.e = e; + old_top = lua_gettop(L); + + if (e->process_idx == -1) { + if (!lua_isfunction(L, 2)) { + return luaL_error(L, "expression process is called with no callback function"); + } + + pd.process_cb_pos = 2; + + if (lua_type(L, 3) != LUA_TNONE && lua_type(L, 3) != LUA_TNIL) { + pd.stack_item = 3; + } + else { + pd.stack_item = -1; + } + + if (lua_isnumber(L, 4)) { + flags = lua_tointeger(L, 4); + } + } + else { + lua_rawgeti(L, LUA_REGISTRYINDEX, e->process_idx); + pd.process_cb_pos = lua_gettop(L); + + if (lua_type(L, 2) != LUA_TNONE && lua_type(L, 2) != LUA_TNIL) { + pd.stack_item = 2; + } + else { + pd.stack_item = -1; + } + + if (lua_isnumber(L, 3)) { + flags = lua_tointeger(L, 3); + } + } + + res = rspamd_process_expression(e->expr, flags, &pd); + + lua_settop(L, old_top); + lua_pushnumber(L, res); + + return 1; +} + +static gint +lua_expr_process_traced(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_expression *e = rspamd_lua_expression(L, 1); + struct lua_atom_process_data pd; + gdouble res; + gint flags = 0, old_top; + GPtrArray *trace; + + pd.L = L; + pd.e = e; + old_top = lua_gettop(L); + + if (e->process_idx == -1) { + if (!lua_isfunction(L, 2)) { + return luaL_error(L, "expression process is called with no callback function"); + } + + pd.process_cb_pos = 2; + pd.stack_item = 3; + + if (lua_isnumber(L, 4)) { + flags = lua_tointeger(L, 4); + } + } + else { + lua_rawgeti(L, LUA_REGISTRYINDEX, e->process_idx); + pd.process_cb_pos = lua_gettop(L); + pd.stack_item = 2; + + if (lua_isnumber(L, 3)) { + flags = lua_tointeger(L, 3); + } + } + + res = rspamd_process_expression_track(e->expr, flags, &pd, &trace); + + lua_settop(L, old_top); + lua_pushnumber(L, res); + + lua_createtable(L, trace->len, 0); + + for (guint i = 0; i < trace->len; i++) { + struct rspamd_expression_atom_s *atom = g_ptr_array_index(trace, i); + + lua_pushlstring(L, atom->str, atom->len); + lua_rawseti(L, -2, i + 1); + } + + g_ptr_array_free(trace, TRUE); + + return 2; +} + +static gint +lua_expr_create(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_expression *e, **pe; + const char *line; + gsize len; + gboolean no_process = FALSE; + GError *err = NULL; + rspamd_mempool_t *pool; + + /* Check sanity of the arguments */ + if (lua_type(L, 1) != LUA_TSTRING || + (lua_type(L, 2) != LUA_TTABLE && lua_type(L, 2) != LUA_TFUNCTION) || + rspamd_lua_check_mempool(L, 3) == NULL) { + lua_pushnil(L); + lua_pushstring(L, "bad arguments"); + } + else { + line = lua_tolstring(L, 1, &len); + pool = rspamd_lua_check_mempool(L, 3); + + e = rspamd_mempool_alloc(pool, sizeof(*e)); + e->L = L; + e->pool = pool; + + /* Check callbacks */ + if (lua_istable(L, 2)) { + lua_pushvalue(L, 2); + lua_pushnumber(L, 1); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + lua_pop(L, 1); + lua_pushnil(L); + lua_pushstring(L, "bad parse callback"); + + return 2; + } + + lua_pop(L, 1); + + lua_pushnumber(L, 2); + lua_gettable(L, -2); + + if (lua_type(L, -1) != LUA_TFUNCTION) { + if (lua_type(L, -1) != LUA_TNIL && lua_type(L, -1) != LUA_TNONE) { + lua_pop(L, 1); + lua_pushnil(L); + lua_pushstring(L, "bad process callback"); + + return 2; + } + else { + no_process = TRUE; + } + } + + lua_pop(L, 1); + /* Table is still on the top of stack */ + + lua_pushnumber(L, 1); + lua_gettable(L, -2); + e->parse_idx = luaL_ref(L, LUA_REGISTRYINDEX); + + if (!no_process) { + lua_pushnumber(L, 2); + lua_gettable(L, -2); + e->process_idx = luaL_ref(L, LUA_REGISTRYINDEX); + } + else { + e->process_idx = -1; + } + + lua_pop(L, 1); /* Table */ + } + else { + /* Process function is just a function, not a table */ + lua_pushvalue(L, 2); + e->parse_idx = luaL_ref(L, LUA_REGISTRYINDEX); + e->process_idx = -1; + } + + if (!rspamd_parse_expression(line, len, &lua_atom_subr, e, pool, &err, + &e->expr)) { + lua_pushnil(L); + lua_pushstring(L, err->message); + g_error_free(err); + lua_expr_dtor(e); + + return 2; + } + + rspamd_mempool_add_destructor(pool, lua_expr_dtor, e); + pe = lua_newuserdata(L, sizeof(struct lua_expression *)); + rspamd_lua_setclass(L, "rspamd{expr}", -1); + *pe = e; + lua_pushnil(L); + } + + return 2; +} + +static gint +lua_expr_to_string(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_expression *e = rspamd_lua_expression(L, 1); + GString *str; + + if (e != NULL && e->expr != NULL) { + str = rspamd_expression_tostring(e->expr); + if (str) { + lua_pushlstring(L, str->str, str->len); + g_string_free(str, TRUE); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +struct lua_expr_atoms_cbdata { + lua_State *L; + gint idx; +}; + +static void +lua_exr_atom_cb(const rspamd_ftok_t *tok, gpointer ud) +{ + struct lua_expr_atoms_cbdata *cbdata = ud; + + lua_pushlstring(cbdata->L, tok->begin, tok->len); + lua_rawseti(cbdata->L, -2, cbdata->idx++); +} + +static gint +lua_expr_atoms(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_expression *e = rspamd_lua_expression(L, 1); + struct lua_expr_atoms_cbdata cbdata; + + if (e != NULL && e->expr != NULL) { + lua_newtable(L); + cbdata.L = L; + cbdata.idx = 1; + rspamd_expression_atom_foreach(e->expr, lua_exr_atom_cb, &cbdata); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_load_expression(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, exprlib_f); + + return 1; +} + +void luaopen_expression(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{expr}", exprlib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_expression", lua_load_expression); +} diff --git a/src/lua/lua_html.cxx b/src/lua/lua_html.cxx new file mode 100644 index 0000000..6613337 --- /dev/null +++ b/src/lua/lua_html.cxx @@ -0,0 +1,738 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "message.h" +#include "libserver/html/html.h" +#include "libserver/html/html.hxx" +#include "libserver/html/html_tag.hxx" +#include "libserver/html/html_block.hxx" +#include "images.h" + +#include "contrib/ankerl/unordered_dense.h" +#include <frozen/string.h> +#include <frozen/unordered_map.h> + +/*** + * @module rspamd_html + * This module provides different methods to access HTML tags. To get HTML context + * from an HTML part you could use method `part:get_html()` + * @example +rspamd_config.R_EMPTY_IMAGE = function(task) + local tp = task:get_text_parts() -- get text parts in a message + + for _,p in ipairs(tp) do -- iterate over text parts array using `ipairs` + if p:is_html() then -- if the current part is html part + local hc = p:get_html() -- we get HTML context + local len = p:get_length() -- and part's length + + if len < 50 then -- if we have a part that has less than 50 bytes of text + local images = hc:get_images() -- then we check for HTML images + + if images then -- if there are images + for _,i in ipairs(images) do -- then iterate over images in the part + if i['height'] + i['width'] >= 400 then -- if we have a large image + return true -- add symbol + end + end + end + end + end + end +end + */ + +/*** + * @method html:has_tag(name) + * Checks if a specified tag `name` is presented in a part + * @param {string} name name of tag to check + * @return {boolean} `true` if the tag exists in HTML tree + */ +LUA_FUNCTION_DEF(html, has_tag); + +/*** + * @method html:check_property(name) + * Checks if the HTML has a specific property. Here is the list of available properties: + * + * - `no_html` - no html tag presented + * - `bad_element` - part has some broken elements + * - `xml` - part is xhtml + * - `unknown_element` - part has some unknown elements + * - `duplicate_element` - part has some duplicate elements that should be unique (namely, `title` tag) + * - `unbalanced` - part has unbalanced tags + * @param {string} name name of property + * @return {boolean} true if the part has the specified property + */ +LUA_FUNCTION_DEF(html, has_property); + +/*** + * @method html:get_images() + * Returns a table of images found in html. Each image is, in turn, a table with the following fields: + * + * - `src` - link to the source + * - `height` - height in pixels + * - `width` - width in pixels + * - `embedded` - `true` if an image is embedded in a message + * @return {table} table of images in html part + */ +LUA_FUNCTION_DEF(html, get_images); + +/*** + * @method html:foreach_tag(tagname, callback) + * Processes HTML tree calling the specified callback for each tag of the specified + * type. + * + * Callback is called with the following attributes: + * + * - `tag`: html tag structure + * - `content_length`: length of content within a tag + * + * Callback function should return `true` to **stop** processing and `false` to continue + * @return nothing + */ +LUA_FUNCTION_DEF(html, foreach_tag); + +/*** + * @method html:get_invisible() + * Returns invisible content of the HTML data + * @return + */ +LUA_FUNCTION_DEF(html, get_invisible); + +static const struct luaL_reg htmllib_m[] = { + LUA_INTERFACE_DEF(html, has_tag), + LUA_INTERFACE_DEF(html, has_property), + LUA_INTERFACE_DEF(html, get_images), + LUA_INTERFACE_DEF(html, foreach_tag), + LUA_INTERFACE_DEF(html, get_invisible), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +/*** + * @method html_tag:get_type() + * Returns string representation of HTML type for a tag + * @return {string} type of tag + */ +LUA_FUNCTION_DEF(html_tag, get_type); +/*** + * @method html_tag:get_extra() + * Returns extra data associated with the tag + * @return {url|image|nil} extra data associated with the tag + */ +LUA_FUNCTION_DEF(html_tag, get_extra); +/*** + * @method html_tag:get_parent() + * Returns parent node for a specified tag + * @return {html_tag} parent object for a specified tag + */ +LUA_FUNCTION_DEF(html_tag, get_parent); + +/*** + * @method html_tag:get_flags() + * Returns flags a specified tag: + * + * - `closed`: tag is properly closed + * - `closing`: tag is a closing tag + * - `broken`: tag is somehow broken + * - `unbalanced`: tag is unbalanced + * - `xml`: tag is xml tag + * @return {table} table of flags + */ +LUA_FUNCTION_DEF(html_tag, get_flags); +/*** + * @method html_tag:get_content() + * Returns content of tag (approximate for some cases) + * @return {rspamd_text} rspamd text with tag's content + */ +LUA_FUNCTION_DEF(html_tag, get_content); +/*** + * @method html_tag:get_content_length() + * Returns length of a tag's content + * @return {number} size of content enclosed within a tag + */ +LUA_FUNCTION_DEF(html_tag, get_content_length); + +/*** + * @method html_tag:get_style() + * Returns style calculated for the element + * @return {table} table associated with the style + */ +LUA_FUNCTION_DEF(html_tag, get_style); + +/*** + * @method html_tag:get_attribute(name) + * Returns value of attribute for the element + * Refer to `html_components_map` in `src/libserver/html/html.cxx` for recognised names + * @return {string|nil} value of the attribute + */ +LUA_FUNCTION_DEF(html_tag, get_attribute); + +static const struct luaL_reg taglib_m[] = { + LUA_INTERFACE_DEF(html_tag, get_type), + LUA_INTERFACE_DEF(html_tag, get_extra), + LUA_INTERFACE_DEF(html_tag, get_parent), + LUA_INTERFACE_DEF(html_tag, get_flags), + LUA_INTERFACE_DEF(html_tag, get_content), + LUA_INTERFACE_DEF(html_tag, get_content_length), + LUA_INTERFACE_DEF(html_tag, get_style), + LUA_INTERFACE_DEF(html_tag, get_attribute), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static struct rspamd::html::html_content * +lua_check_html(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{html}"); + luaL_argcheck(L, ud != NULL, pos, "'html' expected"); + return ud ? *((struct rspamd::html::html_content **) ud) : NULL; +} + +struct lua_html_tag { + rspamd::html::html_content *html; + const rspamd::html::html_tag *tag; +}; + +static struct lua_html_tag * +lua_check_html_tag(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{html_tag}"); + luaL_argcheck(L, ud != NULL, pos, "'html_tag' expected"); + return ud ? ((struct lua_html_tag *) ud) : NULL; +} + +static gint +lua_html_has_tag(lua_State *L) +{ + LUA_TRACE_POINT; + auto *hc = lua_check_html(L, 1); + const gchar *tagname = luaL_checkstring(L, 2); + gboolean ret = FALSE; + + if (hc && tagname) { + if (rspamd_html_tag_seen(hc, tagname)) { + ret = TRUE; + } + } + + lua_pushboolean(L, ret); + + return 1; +} + +constexpr const auto prop_map = frozen::make_unordered_map<frozen::string, int>({ + {"no_html", RSPAMD_HTML_FLAG_BAD_START}, + {"bad_start", RSPAMD_HTML_FLAG_BAD_START}, + {"bad_element", RSPAMD_HTML_FLAG_BAD_ELEMENTS}, + {"bad_elements", RSPAMD_HTML_FLAG_BAD_ELEMENTS}, + {"xml", RSPAMD_HTML_FLAG_XML}, + {"unknown_element", RSPAMD_HTML_FLAG_UNKNOWN_ELEMENTS}, + {"unknown_elements", RSPAMD_HTML_FLAG_UNKNOWN_ELEMENTS}, + {"duplicate_element", RSPAMD_HTML_FLAG_DUPLICATE_ELEMENTS}, + {"duplicate_elements", RSPAMD_HTML_FLAG_DUPLICATE_ELEMENTS}, + {"unbalanced", RSPAMD_HTML_FLAG_UNBALANCED}, + {"data_urls", RSPAMD_HTML_FLAG_HAS_DATA_URLS}, +}); + +static gint +lua_html_has_property(lua_State *L) +{ + LUA_TRACE_POINT; + auto *hc = lua_check_html(L, 1); + const gchar *propname = luaL_checkstring(L, 2); + gboolean ret = FALSE; + + if (hc && propname) { + auto found_prop = prop_map.find(frozen::string(propname)); + + if (found_prop != prop_map.end()) { + ret = hc->flags & found_prop->second; + } + } + + lua_pushboolean(L, ret); + + return 1; +} + +static void +lua_html_push_image(lua_State *L, const struct html_image *img) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag; + struct rspamd_url **purl; + + lua_createtable(L, 0, 7); + + if (img->src) { + lua_pushstring(L, "src"); + + if (img->flags & RSPAMD_HTML_FLAG_IMAGE_DATA) { + struct rspamd_lua_text *t; + + t = static_cast<rspamd_lua_text *>(lua_newuserdata(L, sizeof(*t))); + t->start = img->src; + t->len = strlen(img->src); + t->flags = 0; + + rspamd_lua_setclass(L, "rspamd{text}", -1); + } + else { + lua_pushstring(L, img->src); + } + + lua_settable(L, -3); + } + + if (img->url) { + lua_pushstring(L, "url"); + purl = static_cast<rspamd_url **>(lua_newuserdata(L, sizeof(gpointer))); + *purl = img->url; + rspamd_lua_setclass(L, "rspamd{url}", -1); + lua_settable(L, -3); + } + + if (img->tag) { + lua_pushstring(L, "tag"); + ltag = static_cast<lua_html_tag *>(lua_newuserdata(L, sizeof(struct lua_html_tag))); + ltag->tag = static_cast<rspamd::html::html_tag *>(img->tag); + ltag->html = NULL; + rspamd_lua_setclass(L, "rspamd{html_tag}", -1); + lua_settable(L, -3); + } + + lua_pushstring(L, "height"); + lua_pushinteger(L, img->height); + lua_settable(L, -3); + lua_pushstring(L, "width"); + lua_pushinteger(L, img->width); + lua_settable(L, -3); + lua_pushstring(L, "embedded"); + lua_pushboolean(L, img->flags & RSPAMD_HTML_FLAG_IMAGE_EMBEDDED); + lua_settable(L, -3); + lua_pushstring(L, "data"); + lua_pushboolean(L, img->flags & RSPAMD_HTML_FLAG_IMAGE_DATA); + lua_settable(L, -3); +} + +static gint +lua_html_get_images(lua_State *L) +{ + LUA_TRACE_POINT; + auto *hc = lua_check_html(L, 1); + guint i = 1; + + if (hc != NULL) { + lua_createtable(L, hc->images.size(), 0); + + for (const auto *img: hc->images) { + lua_html_push_image(L, img); + lua_rawseti(L, -2, i++); + } + } + else { + lua_newtable(L); + } + + return 1; +} + +static void +lua_html_push_block(lua_State *L, const struct rspamd::html::html_block *bl) +{ + LUA_TRACE_POINT; + + lua_createtable(L, 0, 6); + + if (bl->fg_color_mask) { + lua_pushstring(L, "color"); + lua_createtable(L, 4, 0); + lua_pushinteger(L, bl->fg_color.r); + lua_rawseti(L, -2, 1); + lua_pushinteger(L, bl->fg_color.g); + lua_rawseti(L, -2, 2); + lua_pushinteger(L, bl->fg_color.b); + lua_rawseti(L, -2, 3); + lua_pushinteger(L, bl->fg_color.alpha); + lua_rawseti(L, -2, 4); + lua_settable(L, -3); + } + if (bl->bg_color_mask) { + lua_pushstring(L, "bgcolor"); + lua_createtable(L, 4, 0); + lua_pushinteger(L, bl->bg_color.r); + lua_rawseti(L, -2, 1); + lua_pushinteger(L, bl->bg_color.g); + lua_rawseti(L, -2, 2); + lua_pushinteger(L, bl->bg_color.b); + lua_rawseti(L, -2, 3); + lua_pushinteger(L, bl->bg_color.alpha); + lua_rawseti(L, -2, 4); + lua_settable(L, -3); + } + + if (bl->font_mask) { + lua_pushstring(L, "font_size"); + lua_pushinteger(L, bl->font_size); + lua_settable(L, -3); + } + + lua_pushstring(L, "visible"); + lua_pushboolean(L, bl->is_visible()); + lua_settable(L, -3); + + lua_pushstring(L, "transparent"); + lua_pushboolean(L, bl->is_transparent()); + lua_settable(L, -3); +} + +static gint +lua_html_foreach_tag(lua_State *L) +{ + LUA_TRACE_POINT; + auto *hc = lua_check_html(L, 1); + const gchar *tagname; + gint id; + auto any = false; + ankerl::unordered_dense::set<int> tags; + + + if (lua_type(L, 2) == LUA_TSTRING) { + tagname = luaL_checkstring(L, 2); + if (strcmp(tagname, "any") == 0) { + any = true; + } + else { + id = rspamd_html_tag_by_name(tagname); + + if (id == -1) { + return luaL_error(L, "invalid tagname: %s", tagname); + } + + + tags.insert(id); + } + } + else if (lua_type(L, 2) == LUA_TTABLE) { + lua_pushvalue(L, 2); + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + tagname = luaL_checkstring(L, -1); + if (strcmp(tagname, "any") == 0) { + any = TRUE; + } + else { + id = rspamd_html_tag_by_name(tagname); + + if (id == -1) { + return luaL_error(L, "invalid tagname: %s", tagname); + } + tags.insert(id); + } + } + + lua_pop(L, 1); + } + + if (hc && (any || !tags.empty()) && lua_isfunction(L, 3)) { + hc->traverse_all_tags([&](const rspamd::html::html_tag *tag) -> bool { + if (tag && (any || tags.contains(tag->id))) { + lua_pushcfunction(L, &rspamd_lua_traceback); + auto err_idx = lua_gettop(L); + lua_pushvalue(L, 3); + + auto *ltag = static_cast<lua_html_tag *>(lua_newuserdata(L, sizeof(lua_html_tag))); + ltag->tag = tag; + ltag->html = hc; + auto ct = ltag->tag->get_content(hc); + rspamd_lua_setclass(L, "rspamd{html_tag}", -1); + lua_pushinteger(L, ct.size()); + + /* Leaf flag */ + if (tag->children.empty()) { + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + + if (lua_pcall(L, 3, 1, err_idx) != 0) { + msg_err("error in foreach_tag callback: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return false; + } + + if (lua_toboolean(L, -1)) { + lua_settop(L, err_idx - 1); + return false; + } + + lua_settop(L, err_idx - 1); + } + + return true; + }); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_html_get_invisible(lua_State *L) +{ + LUA_TRACE_POINT; + auto *hc = lua_check_html(L, 1); + + if (hc != NULL) { + lua_new_text(L, hc->invisible.c_str(), hc->invisible.size(), false); + } + else { + lua_newtable(L); + } + + return 1; +} + +static gint +lua_html_tag_get_type(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1); + const gchar *tagname; + + if (ltag != NULL) { + tagname = rspamd_html_tag_by_id(ltag->tag->id); + + if (tagname) { + lua_pushstring(L, tagname); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_html_tag_get_parent(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1), *ptag; + + if (ltag != NULL) { + auto *parent = ltag->tag->parent; + + if (parent) { + ptag = static_cast<lua_html_tag *>(lua_newuserdata(L, sizeof(*ptag))); + ptag->tag = static_cast<rspamd::html::html_tag *>(parent); + ptag->html = ltag->html; + rspamd_lua_setclass(L, "rspamd{html_tag}", -1); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_html_tag_get_flags(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1); + gint i = 1; + + if (ltag && ltag->tag) { + /* Push flags */ + lua_createtable(L, 4, 0); + if (ltag->tag->flags & FL_HREF) { + lua_pushstring(L, "href"); + lua_rawseti(L, -2, i++); + } + if (ltag->tag->flags & FL_CLOSED) { + lua_pushstring(L, "closed"); + lua_rawseti(L, -2, i++); + } + if (ltag->tag->flags & FL_BROKEN) { + lua_pushstring(L, "broken"); + lua_rawseti(L, -2, i++); + } + if (ltag->tag->flags & FL_XML) { + lua_pushstring(L, "xml"); + lua_rawseti(L, -2, i++); + } + if (ltag->tag->flags & RSPAMD_HTML_FLAG_UNBALANCED) { + lua_pushstring(L, "unbalanced"); + lua_rawseti(L, -2, i++); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_html_tag_get_content(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1); + struct rspamd_lua_text *t; + + if (ltag) { + + if (ltag->html) { + auto ct = ltag->tag->get_content(ltag->html); + if (ct.size() > 0) { + t = static_cast<rspamd_lua_text *>(lua_newuserdata(L, sizeof(*t))); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = ct.data(); + t->len = ct.size(); + t->flags = 0; + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_html_tag_get_content_length(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1); + + if (ltag) { + if (ltag->html) { + auto ct = ltag->tag->get_content(ltag->html); + lua_pushinteger(L, ct.size()); + } + else { + lua_pushinteger(L, ltag->tag->get_content_length()); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_html_tag_get_extra(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1); + struct html_image *img; + + if (ltag) { + if (!std::holds_alternative<std::monostate>(ltag->tag->extra)) { + if (std::holds_alternative<struct html_image *>(ltag->tag->extra)) { + img = std::get<struct html_image *>(ltag->tag->extra); + lua_html_push_image(L, img); + } + else if (std::holds_alternative<struct rspamd_url *>(ltag->tag->extra)) { + /* For A that's URL */ + auto *lua_url = static_cast<rspamd_lua_url *>(lua_newuserdata(L, sizeof(rspamd_lua_url))); + lua_url->url = std::get<struct rspamd_url *>(ltag->tag->extra); + rspamd_lua_setclass(L, "rspamd{url}", -1); + } + else { + /* Unknown extra ? */ + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_html_tag_get_style(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1); + + if (ltag) { + if (ltag->tag->block) { + lua_html_push_block(L, ltag->tag->block); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_html_tag_get_attribute(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_html_tag *ltag = lua_check_html_tag(L, 1); + gsize slen; + const gchar *attr_name = luaL_checklstring(L, 2, &slen); + + if (ltag && attr_name) { + auto maybe_attr = ltag->tag->find_component( + rspamd::html::html_component_from_string({attr_name, slen})); + + if (maybe_attr) { + lua_pushlstring(L, maybe_attr->data(), maybe_attr->size()); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +void luaopen_html(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{html}", htmllib_m); + lua_pop(L, 1); + rspamd_lua_new_class(L, "rspamd{html_tag}", taglib_m); + lua_pop(L, 1); +} diff --git a/src/lua/lua_http.c b/src/lua/lua_http.c new file mode 100644 index 0000000..713082a --- /dev/null +++ b/src/lua/lua_http.c @@ -0,0 +1,1270 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_thread_pool.h" +#include "libserver/http/http_private.h" +#include "libutil/upstream.h" +#include "ref.h" +#include "unix-std.h" +#include "zlib.h" +#include "utlist.h" + +/*** + * @module rspamd_http + * Rspamd HTTP module represents HTTP asynchronous client available from LUA code. + * This module hides all complexity: DNS resolving, sessions management, zero-copy + * text transfers and so on under the hood. + * @example +local rspamd_http = require "rspamd_http" + +local function symbol_callback(task) + local function http_callback(err_message, code, body, headers) + task:insert_result('SYMBOL', 1) -- task is available via closure + end + + rspamd_http.request({ + task=task, + url='http://example.com/data', + body=task:get_content(), + callback=http_callback, + headers={Header='Value', OtherHeader='Value'}, + mime_type='text/plain', + }) + end + */ + +#define MAX_HEADERS_SIZE 8192 + +static const gchar *M = "rspamd lua http"; + +LUA_FUNCTION_DEF(http, request); + +static const struct luaL_reg httplib_m[] = { + LUA_INTERFACE_DEF(http, request), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +#define RSPAMD_LUA_HTTP_FLAG_TEXT (1 << 0) +#define RSPAMD_LUA_HTTP_FLAG_NOVERIFY (1 << 1) +#define RSPAMD_LUA_HTTP_FLAG_RESOLVED (1 << 2) +#define RSPAMD_LUA_HTTP_FLAG_KEEP_ALIVE (1 << 3) +#define RSPAMD_LUA_HTTP_FLAG_YIELDED (1 << 4) + +struct lua_http_cbdata { + struct rspamd_http_connection *conn; + struct rspamd_async_session *session; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_http_message *msg; + struct ev_loop *event_loop; + struct rspamd_config *cfg; + struct rspamd_task *task; + ev_tstamp timeout; + struct rspamd_cryptobox_keypair *local_kp; + struct rspamd_cryptobox_pubkey *peer_pk; + rspamd_inet_addr_t *addr; + gchar *mime_type; + gchar *host; + gchar *auth; + struct upstream *up; + const gchar *url; + gsize max_size; + gint flags; + gint fd; + gint cbref; + struct thread_entry *thread; + ref_entry_t ref; +}; + +static const gdouble default_http_timeout = 5.0; + +static struct rspamd_dns_resolver * +lua_http_global_resolver(struct ev_loop *ev_base) +{ + static struct rspamd_dns_resolver *global_resolver; + + if (global_resolver == NULL) { + global_resolver = rspamd_dns_resolver_init(NULL, ev_base, NULL); + } + + return global_resolver; +} + +static void +lua_http_fin(gpointer arg) +{ + struct lua_http_cbdata *cbd = (struct lua_http_cbdata *) arg; + + if (cbd->cbref != -1) { + luaL_unref(cbd->cfg->lua_state, LUA_REGISTRYINDEX, cbd->cbref); + } + + if (cbd->conn) { + /* Here we already have a connection, so we need to unref it */ + rspamd_http_connection_unref(cbd->conn); + } + else if (cbd->msg != NULL) { + /* We need to free message */ + rspamd_http_message_unref(cbd->msg); + } + + if (cbd->fd != -1) { + close(cbd->fd); + } + + if (cbd->addr) { + rspamd_inet_address_free(cbd->addr); + } + + if (cbd->up) { + rspamd_upstream_unref(cbd->up); + } + + if (cbd->mime_type) { + g_free(cbd->mime_type); + } + + if (cbd->auth) { + g_free(cbd->auth); + } + + if (cbd->host) { + g_free(cbd->host); + } + + if (cbd->local_kp) { + rspamd_keypair_unref(cbd->local_kp); + } + + if (cbd->peer_pk) { + rspamd_pubkey_unref(cbd->peer_pk); + } + + g_free(cbd); +} + +static void +lua_http_cbd_dtor(struct lua_http_cbdata *cbd) +{ + if (cbd->session) { + + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_RESOLVED) { + /* Event is added merely for resolved events */ + if (cbd->item) { + rspamd_symcache_item_async_dec_check(cbd->task, cbd->item, M); + } + + rspamd_session_remove_event(cbd->session, lua_http_fin, cbd); + } + } + else { + lua_http_fin(cbd); + } +} + +static void +lua_http_push_error(struct lua_http_cbdata *cbd, const char *err) +{ + struct lua_callback_state lcbd; + lua_State *L; + + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &lcbd); + + L = lcbd.L; + + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref); + lua_pushstring(L, err); + + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 1, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + lua_pop(L, 1); + } + + lua_thread_pool_restore_callback(&lcbd); +} + +static void lua_http_resume_handler(struct rspamd_http_connection *conn, + struct rspamd_http_message *msg, const char *err); + +static void +lua_http_error_handler(struct rspamd_http_connection *conn, GError *err) +{ + struct lua_http_cbdata *cbd = (struct lua_http_cbdata *) conn->ud; + + if (cbd->up) { + rspamd_upstream_fail(cbd->up, false, err ? err->message : "unknown error"); + } + + if (cbd->cbref == -1) { + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_YIELDED) { + cbd->flags &= ~RSPAMD_LUA_HTTP_FLAG_YIELDED; + lua_http_resume_handler(conn, NULL, err->message); + } + else { + /* TODO: kill me please */ + msg_info("lost HTTP error from %s in coroutines mess: %s", + rspamd_inet_address_to_string_pretty(cbd->addr), + err->message); + } + } + else { + lua_http_push_error(cbd, err->message); + } + + REF_RELEASE(cbd); +} + +static int +lua_http_finish_handler(struct rspamd_http_connection *conn, + struct rspamd_http_message *msg) +{ + struct lua_http_cbdata *cbd = (struct lua_http_cbdata *) conn->ud; + struct rspamd_http_header *h; + const gchar *body; + gsize body_len; + + struct lua_callback_state lcbd; + lua_State *L; + + if (cbd->cbref == -1) { + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_YIELDED) { + cbd->flags &= ~RSPAMD_LUA_HTTP_FLAG_YIELDED; + lua_http_resume_handler(conn, msg, NULL); + } + else { + /* TODO: kill me please */ + msg_err("lost HTTP data from %s in coroutines mess", + rspamd_inet_address_to_string_pretty(cbd->addr)); + } + + REF_RELEASE(cbd); + + return 0; + } + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &lcbd); + + if (cbd->up) { + rspamd_upstream_ok(cbd->up); + } + + L = lcbd.L; + + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref); + /* Error */ + lua_pushnil(L); + /* Reply code */ + lua_pushinteger(L, msg->code); + /* Body */ + body = rspamd_http_message_get_body(msg, &body_len); + + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_TEXT) { + struct rspamd_lua_text *t; + + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = body; + t->len = body_len; + t->flags = 0; + } + else { + if (body_len > 0) { + lua_pushlstring(L, body, body_len); + } + else { + lua_pushnil(L); + } + } + /* Headers */ + lua_newtable(L); + + kh_foreach_value(msg->headers, h, { + /* + * Lowercase header name, as Lua cannot search in caseless matter + */ + rspamd_str_lc(h->combined->str, h->name.len); + lua_pushlstring(L, h->name.begin, h->name.len); + lua_pushlstring(L, h->value.begin, h->value.len); + lua_settable(L, -3); + }); + + if (cbd->item) { + /* Replace watcher to deal with nested calls */ + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 4, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + lua_pop(L, 1); + } + + REF_RELEASE(cbd); + + lua_thread_pool_restore_callback(&lcbd); + + return 0; +} + +/* + * resumes yielded thread + */ +static void +lua_http_resume_handler(struct rspamd_http_connection *conn, + struct rspamd_http_message *msg, const char *err) +{ + struct lua_http_cbdata *cbd = (struct lua_http_cbdata *) conn->ud; + lua_State *L = cbd->thread->lua_state; + const gchar *body; + gsize body_len; + struct rspamd_http_header *h; + + if (err) { + lua_pushstring(L, err); + lua_pushnil(L); + } + else { + /* + * 1 - nil (error) + * 2 - table: + * code (int) + * content (string) + * headers (table: header -> value) + */ + lua_pushnil(L);// error code + + lua_createtable(L, 0, 3); + + /* code */ + lua_pushliteral(L, "code"); + lua_pushinteger(L, msg->code); + lua_settable(L, -3); + + /* content */ + lua_pushliteral(L, "content"); + + body = rspamd_http_message_get_body(msg, &body_len); + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_TEXT) { + struct rspamd_lua_text *t; + + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = body; + t->len = body_len; + t->flags = 0; + } + else { + if (body_len > 0) { + lua_pushlstring(L, body, body_len); + } + else { + lua_pushnil(L); + } + } + lua_settable(L, -3); + + /* headers */ + lua_pushliteral(L, "headers"); + lua_newtable(L); + + kh_foreach_value(msg->headers, h, { + /* + * Lowercase header name, as Lua cannot search in caseless matter + */ + rspamd_str_lc(h->combined->str, h->name.len); + lua_pushlstring(L, h->name.begin, h->name.len); + lua_pushlstring(L, h->value.begin, h->value.len); + lua_settable(L, -3); + }); + + lua_settable(L, -3); + } + + if (cbd->item) { + /* Replace watcher to deal with nested calls */ + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + lua_thread_resume(cbd->thread, 2); +} + +static gboolean +lua_http_make_connection(struct lua_http_cbdata *cbd) +{ + rspamd_inet_address_set_port(cbd->addr, cbd->msg->port); + unsigned http_opts = RSPAMD_HTTP_CLIENT_SIMPLE; + + if (cbd->msg->flags & RSPAMD_HTTP_FLAG_WANT_SSL) { + http_opts |= RSPAMD_HTTP_CLIENT_SSL; + } + + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_KEEP_ALIVE) { + cbd->fd = -1; /* FD is owned by keepalive connection */ + cbd->conn = rspamd_http_connection_new_client_keepalive( + NULL, /* Default context */ + NULL, + lua_http_error_handler, + lua_http_finish_handler, + http_opts, + cbd->addr, + cbd->host); + } + else { + cbd->fd = -1; + cbd->conn = rspamd_http_connection_new_client( + NULL, /* Default context */ + NULL, + lua_http_error_handler, + lua_http_finish_handler, + http_opts, + cbd->addr); + } + + if (cbd->conn) { + if (cbd->local_kp) { + rspamd_http_connection_set_key(cbd->conn, cbd->local_kp); + } + + if (cbd->peer_pk) { + rspamd_http_message_set_peer_key(cbd->msg, cbd->peer_pk); + } + + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_NOVERIFY) { + cbd->msg->flags |= RSPAMD_HTTP_FLAG_SSL_NOVERIFY; + } + + if (cbd->max_size) { + rspamd_http_connection_set_max_size(cbd->conn, cbd->max_size); + } + + if (cbd->auth) { + rspamd_http_message_add_header(cbd->msg, "Authorization", + cbd->auth); + } + + if (cbd->session) { + if (cbd->item) { + rspamd_session_add_event_full(cbd->session, + (event_finalizer_t) lua_http_fin, cbd, + M, + rspamd_symcache_dyn_item_name(cbd->task, cbd->item)); + } + else { + rspamd_session_add_event(cbd->session, + (event_finalizer_t) lua_http_fin, cbd, + M); + } + cbd->flags |= RSPAMD_LUA_HTTP_FLAG_RESOLVED; + } + + if (cbd->task) { + cbd->conn->log_tag = cbd->task->task_pool->tag.uid; + + if (cbd->item) { + rspamd_symcache_item_async_inc(cbd->task, cbd->item, M); + } + } + else if (cbd->cfg) { + cbd->conn->log_tag = cbd->cfg->cfg_pool->tag.uid; + } + + struct rspamd_http_message *msg = cbd->msg; + + /* Message is now owned by a connection object */ + cbd->msg = NULL; + + return rspamd_http_connection_write_message(cbd->conn, msg, + cbd->host, cbd->mime_type, cbd, + cbd->timeout); + } + + return FALSE; +} + +static void +lua_http_dns_handler(struct rdns_reply *reply, gpointer ud) +{ + struct lua_http_cbdata *cbd = (struct lua_http_cbdata *) ud; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_task *task; + + task = cbd->task; + item = cbd->item; + + if (reply->code != RDNS_RC_NOERROR) { + lua_http_push_error(cbd, "unable to resolve host"); + REF_RELEASE(cbd); + } + else { + struct rdns_reply_entry *entry; + + DL_FOREACH(reply->entries, entry) + { + if (entry->type == RDNS_REQUEST_A) { + cbd->addr = rspamd_inet_address_new(AF_INET, + &entry->content.a.addr); + break; + } + else if (entry->type == RDNS_REQUEST_AAAA) { + cbd->addr = rspamd_inet_address_new(AF_INET6, + &entry->content.aaa.addr); + break; + } + } + + if (cbd->addr == NULL) { + lua_http_push_error(cbd, "unable to resolve host: no records with such name"); + REF_RELEASE(cbd); + } + else { + REF_RETAIN(cbd); + if (!lua_http_make_connection(cbd)) { + lua_http_push_error(cbd, "unable to make connection to the host"); + + if (cbd->ref.refcount > 1) { + REF_RELEASE(cbd); + } + + REF_RELEASE(cbd); + + return; + } + REF_RELEASE(cbd); + } + } + + if (item) { + rspamd_symcache_item_async_dec_check(task, item, M); + } +} + +static void +lua_http_push_headers(lua_State *L, struct rspamd_http_message *msg) +{ + const char *name, *value; + gint i, sz; + + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + + lua_pushvalue(L, -2); + name = lua_tostring(L, -1); + sz = rspamd_lua_table_size(L, -2); + if (sz != 0 && name != NULL) { + for (i = 1; i <= sz; i++) { + lua_rawgeti(L, -2, i); + value = lua_tostring(L, -1); + if (value != NULL) { + rspamd_http_message_add_header(msg, name, value); + } + lua_pop(L, 1); + } + } + else { + value = lua_tostring(L, -2); + if (name != NULL && value != NULL) { + rspamd_http_message_add_header(msg, name, value); + } + } + lua_pop(L, 2); + } +} + +/*** + * @function rspamd_http.request({params...}) + * This function creates HTTP request and accepts several parameters as a table using key=value syntax. + * Required params are: + * + * - `url` + * - `task` + * + * In taskless mode, instead of `task` required are: + * + * - `ev_base` + * - `config` + * + * @param {string} url specifies URL for a request in the standard URI form (e.g. 'http://example.com/path') + * @param {function} callback specifies callback function in format `function (err_message, code, body, headers)` that is called on HTTP request completion. if this parameter is missing, the function performs "pseudo-synchronous" call (see [Synchronous and Asynchronous API overview](/doc/lua/sync_async.html#API-example-http-module) + * @param {task} task if called from symbol handler it is generally a good idea to use the common task objects: event base, DNS resolver and events session + * @param {table} headers optional headers in form `[name='value', name='value']` + * @param {string} mime_type MIME type of the HTTP content (for example, `text/html`) + * @param {string/text} body full body content, can be opaque `rspamd{text}` to avoid data copying + * @param {number} timeout floating point request timeout value in seconds (default is 5.0 seconds) + * @param {resolver} resolver to perform DNS-requests. Usually got from either `task` or `config` + * @param {boolean} gzip if true, body of the requests will be compressed + * @param {boolean} no_ssl_verify disable SSL peer checks + * @param {boolean} keepalive enable keep-alive pool + * @param {string} user for HTTP authentication + * @param {string} password for HTTP authentication, only if "user" present + * @return {boolean} `true`, in **async** mode, if a request has been successfully scheduled. If this value is `false` then some error occurred, the callback thus will not be called. + * @return In **sync** mode `string|nil, nil|table` In sync mode error message if any and response as table: `int` _code_, `string` _content_ and `table` _headers_ (header -> value) + */ +static gint +lua_http_request(lua_State *L) +{ + LUA_TRACE_POINT; + struct ev_loop *ev_base; + struct rspamd_http_message *msg; + struct lua_http_cbdata *cbd; + struct rspamd_dns_resolver *resolver; + struct rspamd_async_session *session = NULL; + struct rspamd_lua_text *t; + struct rspamd_task *task = NULL; + struct rspamd_config *cfg = NULL; + struct rspamd_cryptobox_pubkey *peer_key = NULL; + struct rspamd_cryptobox_keypair *local_kp = NULL; + struct upstream *up = NULL; + const gchar *url, *lua_body; + rspamd_fstring_t *body = NULL; + gint cbref = -1; + gsize bodylen; + gdouble timeout = default_http_timeout; + gint flags = 0; + gchar *mime_type = NULL; + gchar *auth = NULL; + gsize max_size = 0; + gboolean gzip = FALSE; + + if (lua_gettop(L) >= 2) { + /* url, callback and event_base format */ + url = luaL_checkstring(L, 1); + + if (url == NULL || lua_type(L, 2) != LUA_TFUNCTION) { + msg_err("http request has bad params"); + lua_pushboolean(L, FALSE); + return 1; + } + + lua_pushvalue(L, 2); + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + if (lua_gettop(L) >= 3 && rspamd_lua_check_udata_maybe(L, 3, "rspamd{ev_base}")) { + ev_base = *(struct ev_loop **) lua_touserdata(L, 3); + } + else { + ev_base = NULL; + } + + if (lua_gettop(L) >= 4 && rspamd_lua_check_udata_maybe(L, 4, "rspamd{resolver}")) { + resolver = *(struct rspamd_dns_resolver **) lua_touserdata(L, 4); + } + else { + resolver = lua_http_global_resolver(ev_base); + } + + if (lua_gettop(L) >= 5 && rspamd_lua_check_udata_maybe(L, 5, "rspamd{session}")) { + session = *(struct rspamd_async_session **) lua_touserdata(L, 5); + } + else { + session = NULL; + } + + msg = rspamd_http_message_from_url(url); + + if (msg == NULL) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref); + msg_err("cannot create HTTP message from url %s", url); + lua_pushboolean(L, FALSE); + return 1; + } + } + else if (lua_type(L, 1) == LUA_TTABLE) { + lua_pushstring(L, "url"); + lua_gettable(L, 1); + url = luaL_checkstring(L, -1); + lua_pop(L, 1); + + if (url == NULL) { + msg_err("cannot create HTTP message without url"); + lua_pushboolean(L, FALSE); + return 1; + } + + lua_pushstring(L, "callback"); + lua_gettable(L, 1); + if (url == NULL || lua_type(L, -1) != LUA_TFUNCTION) { + lua_pop(L, 1); + } + else { + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + lua_pushstring(L, "task"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TUSERDATA) { + task = lua_check_task(L, -1); + + if (task) { + ev_base = task->event_loop; + resolver = task->resolver; + session = task->s; + cfg = task->cfg; + } + } + lua_pop(L, 1); + + if (task == NULL) { + lua_pushstring(L, "ev_base"); + lua_gettable(L, 1); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{ev_base}")) { + ev_base = *(struct ev_loop **) lua_touserdata(L, -1); + } + else { + ev_base = NULL; + } + lua_pop(L, 1); + + + lua_pushstring(L, "session"); + lua_gettable(L, 1); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{session}")) { + session = *(struct rspamd_async_session **) lua_touserdata(L, -1); + } + else { + session = NULL; + } + lua_pop(L, 1); + + lua_pushstring(L, "config"); + lua_gettable(L, 1); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{config}")) { + cfg = *(struct rspamd_config **) lua_touserdata(L, -1); + } + else { + cfg = NULL; + } + + lua_pop(L, 1); + + lua_pushstring(L, "resolver"); + lua_gettable(L, 1); + + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{resolver}")) { + resolver = *(struct rspamd_dns_resolver **) lua_touserdata(L, -1); + } + else { + if (cfg && cfg->dns_resolver) { + resolver = cfg->dns_resolver; + } + else { + resolver = lua_http_global_resolver(ev_base); + } + } + lua_pop(L, 1); + } + + msg = rspamd_http_message_from_url(url); + if (msg == NULL) { + msg_err_task_check("cannot create HTTP message from url %s", url); + lua_pushboolean(L, FALSE); + return 1; + } + + lua_pushstring(L, "headers"); + lua_gettable(L, 1); + if (lua_type(L, -1) == LUA_TTABLE) { + lua_http_push_headers(L, msg); + } + lua_pop(L, 1); + + lua_pushstring(L, "timeout"); + lua_gettable(L, 1); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "mime_type"); + lua_gettable(L, 1); + if (lua_type(L, -1) == LUA_TSTRING) { + mime_type = g_strdup(lua_tostring(L, -1)); + } + lua_pop(L, 1); + + lua_pushstring(L, "body"); + lua_gettable(L, 1); + if (lua_type(L, -1) == LUA_TSTRING) { + lua_body = lua_tolstring(L, -1, &bodylen); + body = rspamd_fstring_new_init(lua_body, bodylen); + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + t = lua_check_text(L, -1); + /* TODO: think about zero-copy possibilities */ + if (t) { + body = rspamd_fstring_new_init(t->start, t->len); + } + else { + rspamd_http_message_unref(msg); + g_free(mime_type); + + return luaL_error(L, "invalid body argument type: %s", + lua_typename(L, lua_type(L, -1))); + } + } + else if (lua_type(L, -1) == LUA_TTABLE) { + gsize total_len = 0, nelts = rspamd_lua_table_size(L, -1); + + /* Calculate length and check types */ + for (gsize i = 0; i < nelts; i++) { + lua_rawgeti(L, -1, i + 1); + + if (lua_type(L, -1) == LUA_TSTRING) { +#if LUA_VERSION_NUM >= 502 + total_len += lua_rawlen(L, -1); +#else + total_len += lua_objlen(L, -1); +#endif + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + t = lua_check_text(L, -1); + + if (t) { + total_len += t->len; + } + else { + rspamd_http_message_unref(msg); + if (mime_type) { + g_free(mime_type); + } + + return luaL_error(L, "invalid body argument: %s", + lua_typename(L, lua_type(L, -1))); + } + } + else { + rspamd_http_message_unref(msg); + if (mime_type) { + g_free(mime_type); + } + + return luaL_error(L, "invalid body argument type: %s", + lua_typename(L, lua_type(L, -1))); + } + + lua_pop(L, 1); + } + + /* Preallocate body */ + if (total_len > 0) { + body = rspamd_fstring_sized_new(total_len); + } + else { + rspamd_http_message_unref(msg); + if (mime_type) { + g_free(mime_type); + } + + return luaL_error(L, "empty body specified"); + } + + /* Fill elements */ + for (gsize i = 0; i < nelts; i++) { + lua_rawgeti(L, -1, i + 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + lua_body = lua_tolstring(L, -1, &bodylen); + body = rspamd_fstring_append(body, lua_body, bodylen); + } + else { + t = lua_check_text(L, -1); + + if (t) { + body = rspamd_fstring_append(body, t->start, t->len); + } + } + + lua_pop(L, 1); + } + } + else if (lua_type(L, -1) != LUA_TNONE && lua_type(L, -1) != LUA_TNIL) { + rspamd_http_message_unref(msg); + return luaL_error(L, "invalid body argument type: %s", + lua_typename(L, lua_type(L, -1))); + } + lua_pop(L, 1); + + lua_pushstring(L, "peer_key"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *in; + gsize inlen; + + in = lua_tolstring(L, -1, &inlen); + peer_key = rspamd_pubkey_from_base32(in, inlen, + RSPAMD_KEYPAIR_KEX, RSPAMD_CRYPTOBOX_MODE_25519); + } + + lua_pop(L, 1); + + lua_pushstring(L, "keypair"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TTABLE) { + ucl_object_t *kp_ucl = ucl_object_lua_import(L, -1); + + local_kp = rspamd_keypair_from_ucl(kp_ucl); + ucl_object_unref(kp_ucl); + } + + lua_pop(L, 1); + + lua_pushstring(L, "opaque_body"); + lua_gettable(L, 1); + + if (!!lua_toboolean(L, -1)) { + flags |= RSPAMD_LUA_HTTP_FLAG_TEXT; + } + + lua_pop(L, 1); + + lua_pushstring(L, "gzip"); + lua_gettable(L, 1); + + if (!!lua_toboolean(L, -1)) { + gzip = TRUE; + } + + lua_pop(L, 1); + + lua_pushstring(L, "no_ssl_verify"); + lua_gettable(L, 1); + + if (!!lua_toboolean(L, -1)) { + flags |= RSPAMD_LUA_HTTP_FLAG_NOVERIFY; + } + + lua_pop(L, 1); + + lua_pushstring(L, "keepalive"); + lua_gettable(L, 1); + + if (!!lua_toboolean(L, -1)) { + flags |= RSPAMD_LUA_HTTP_FLAG_KEEP_ALIVE; + } + + lua_pop(L, 1); + + lua_pushstring(L, "max_size"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TNUMBER) { + max_size = lua_tointeger(L, -1); + } + + lua_pop(L, 1); + + lua_pushstring(L, "method"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + rspamd_http_message_set_method(msg, lua_tostring(L, -1)); + } + + lua_pop(L, 1); + + lua_pushstring(L, "upstream"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_upstream *lup = lua_check_upstream(L, -1); + + if (lup) { + /* Preserve pointer in case if lup is destructed */ + up = lup->up; + } + } + + lua_pop(L, 1); + + lua_pushstring(L, "user"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *user = lua_tostring(L, -1); + + lua_pushstring(L, "password"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *password = lua_tostring(L, -1); + gchar *tmpbuf; + gsize tlen; + + tlen = strlen(user) + strlen(password) + 1; + tmpbuf = g_malloc(tlen + 1); + rspamd_snprintf(tmpbuf, tlen + 1, "%s:%s", user, password); + tlen *= 2; + tlen += sizeof("Basic ") - 1; + auth = g_malloc(tlen + 1); + rspamd_snprintf(auth, tlen + 1, "Basic %Bs", tmpbuf); + g_free(tmpbuf); + } + else { + msg_warn("HTTP user must have password, disabling auth"); + } + + lua_pop(L, 1); /* password */ + } + + lua_pop(L, 1); /* username */ + } + else { + msg_err("http request has bad params"); + lua_pushboolean(L, FALSE); + + return 1; + } + + if (session && rspamd_session_blocked(session)) { + lua_pushboolean(L, FALSE); + + g_free(auth); + rspamd_http_message_unref(msg); + if (body) { + rspamd_fstring_free(body); + } + if (local_kp) { + rspamd_keypair_unref(local_kp); + } + + return 1; + } + if (task == NULL && cfg == NULL) { + g_free(auth); + rspamd_http_message_unref(msg); + if (body) { + rspamd_fstring_free(body); + } + if (local_kp) { + rspamd_keypair_unref(local_kp); + } + + return luaL_error(L, + "Bad params to rspamd_http:request(): either task or config should be set"); + } + + if (ev_base == NULL) { + g_free(auth); + rspamd_http_message_unref(msg); + if (body) { + rspamd_fstring_free(body); + } + if (local_kp) { + rspamd_keypair_unref(local_kp); + } + + return luaL_error(L, + "Bad params to rspamd_http:request(): ev_base isn't passed"); + } + + cbd = g_malloc0(sizeof(*cbd)); + cbd->cbref = cbref; + cbd->msg = msg; + cbd->event_loop = ev_base; + cbd->mime_type = mime_type; + cbd->timeout = timeout; + cbd->fd = -1; + cbd->cfg = cfg; + cbd->peer_pk = peer_key; + cbd->local_kp = local_kp; + cbd->flags = flags; + cbd->max_size = max_size; + cbd->url = url; + cbd->auth = auth; + cbd->task = task; + + if (up) { + cbd->up = rspamd_upstream_ref(up); + } + + if (cbd->cbref == -1) { + cbd->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool); + } + + REF_INIT_RETAIN(cbd, lua_http_cbd_dtor); + + if (task) { + cbd->item = rspamd_symcache_get_cur_item(task); + } + + + if (body) { + if (gzip) { + if (rspamd_fstring_gzip(&body)) { + rspamd_http_message_add_header(msg, "Content-Encoding", "gzip"); + } + } + + rspamd_http_message_set_body_from_fstring_steal(msg, body); + } + + if (session) { + cbd->session = session; + } + + bool numeric_ip = false; + + /* Check if we can skip resolving */ + + gsize hostlen = 0; + const gchar *host = rspamd_http_message_get_http_host(msg, &hostlen); + + if (host) { + cbd->host = g_malloc(hostlen + 1); + rspamd_strlcpy(cbd->host, host, hostlen + 1); + + /* Keep-alive entry is available */ + if (cbd->flags & RSPAMD_LUA_HTTP_FLAG_KEEP_ALIVE) { + const rspamd_inet_addr_t *ka_addr = rspamd_http_context_has_keepalive(NULL, + cbd->host, + msg->port, + msg->flags & RSPAMD_HTTP_FLAG_WANT_SSL); + + if (ka_addr) { + cbd->addr = rspamd_inet_address_copy(ka_addr, NULL); + numeric_ip = true; + } + } + + /* + * No keep-alive stuff, check if we have upstream or if we can parse host as + * a numeric address + */ + if (!cbd->addr) { + if (cbd->up) { + numeric_ip = true; + cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL); + } + else { + /* We use msg->host here, not cbd->host ! */ + if (rspamd_parse_inet_address(&cbd->addr, + msg->host->str, msg->host->len, + RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + numeric_ip = true; + } + } + } + } + else { + if (cbd->up) { + numeric_ip = true; + cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL); + } + cbd->host = NULL; + } + + if (numeric_ip) { + /* Host is numeric IP, no need to resolve */ + gboolean ret; + + REF_RETAIN(cbd); + ret = lua_http_make_connection(cbd); + + if (!ret) { + if (cbd->up) { + rspamd_upstream_fail(cbd->up, true, "HTTP connection failed"); + } + if (cbd->ref.refcount > 1) { + /* Not released by make_connection */ + REF_RELEASE(cbd); + } + + REF_RELEASE(cbd); + lua_pushboolean(L, FALSE); + + return 1; + } + + REF_RELEASE(cbd); + } + else { + if (!cbd->host) { + REF_RELEASE(cbd); + + return luaL_error(L, "no host has been specified"); + } + if (task == NULL) { + + REF_RETAIN(cbd); + if (!rspamd_dns_resolver_request(resolver, session, NULL, lua_http_dns_handler, cbd, + RDNS_REQUEST_A, + cbd->host)) { + if (cbd->ref.refcount > 1) { + /* Not released by make_connection */ + REF_RELEASE(cbd); + } + + REF_RELEASE(cbd); + lua_pushboolean(L, FALSE); + + return 1; + } + + REF_RELEASE(cbd); + } + else { + REF_RETAIN(cbd); + + if (!rspamd_dns_resolver_request_task_forced(task, lua_http_dns_handler, cbd, + RDNS_REQUEST_A, cbd->host)) { + if (cbd->ref.refcount > 1) { + /* Not released by make_connection */ + REF_RELEASE(cbd); + } + + REF_RELEASE(cbd); + lua_pushboolean(L, FALSE); + + return 1; + } + else if (cbd->item) { + rspamd_symcache_item_async_inc(cbd->task, cbd->item, M); + } + + REF_RELEASE(cbd); + } + } + + if (cbd->cbref == -1) { + cbd->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool); + cbd->flags |= RSPAMD_LUA_HTTP_FLAG_YIELDED; + + return lua_thread_yield(cbd->thread, 0); + } + else { + lua_pushboolean(L, TRUE); + } + + return 1; +} + +static gint +lua_load_http(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, httplib_m); + + return 1; +} + +void luaopen_http(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_http", lua_load_http); +} diff --git a/src/lua/lua_ip.c b/src/lua/lua_ip.c new file mode 100644 index 0000000..ac24dc5 --- /dev/null +++ b/src/lua/lua_ip.c @@ -0,0 +1,637 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "libserver/maps/map_helpers.h" + +/*** + * @module rspamd_ip + * `rspamd_ip` is a helper module to simplify IP addresses manipulations. + * @example +local print_octets = function(ip) + print('Normal order octets:') + for _,o in ipairs(ip:str_octets()) do + print(o) + end + print('Reversed order octets:') + for _,o in ipairs(ip:inversed_str_octets()) do + print(o) + end + print('Numeric octets:') + for _,o in ipairs(ip:to_table()) do + print(o) + end +end + +local rspamd_ip = require "rspamd_ip" +-- Create ipv4 +local ip4 = rspamd_ip.from_string('127.0.0.1') +-- Implicit conversion to string +print(ip4) +-- Numeric version +print(ip4:get_version()) +print_octets(ip4) + +-- Create a sample ipv6 address +local ip6 = rspamd_ip.from_string('2001:41d0:8:dd9a::100') +print(ip6) +print(ip6:get_version()) +print_octets(ip6) + */ + +/*** + * @method ip:to_string([pretty=false]) + * Converts valid IP address to string + * @param {bool} pretty print IP address with port and braces (for IPv6) + * @return {string or nil} string representation of IP or `nil` if IP is invalid + */ +LUA_FUNCTION_DEF(ip, to_string); +/*** + * @method ip:to_number() + * Converts valid IP address to number or list of numbers in case of IPv6 + * @return {integer(s) or nil} numeric representation of IP in *host* byte order or `nil` if IP is invalid + */ +LUA_FUNCTION_DEF(ip, to_number); + +/*** + * @method ip:to_table() + * Converts valid IP address to the table of numeric octets + * @return {table or nil} numeric octets of IP address or `nil` if IP is invalid + * @example +local ip = rspamd_ip.from_string('127.0.0.1') +for _,o in ipairs(ip:to_table()) do + print(o) +end +-- Output: +-- 127 +-- 0 +-- 0 +-- 1 + */ +LUA_FUNCTION_DEF(ip, to_table); +/*** + * @method ip:str_octets() + * Converts valid IP address to the table of string octets. The difference from + * @see ip:to_table() is that this method returns just hex strings for ipv6 + * addresses. + * @return {table or nil} string octets of IP address or `nil` if IP is invalid + * @example +local ip = rspamd_ip.from_string('fe80::11') +print(table.concat(ip:str_octets(), ".")) +-- Output: +-- f.e.8.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.1.1 + */ +LUA_FUNCTION_DEF(ip, str_octets); +/*** + * @method ip:inversed_str_octets() + * Converts valid IP address to the table of string octets in reversed order. The difference from + * @see ip:to_table() is that this method returns just hex strings for ipv6 + * addresses in reversed order. + * @return {table or nil} string octets of IP address or `nil` if IP is invalid + * @example +local ip = rspamd_ip.from_string('fe80::11') +print(table.concat(ip:inversed_str_octets(), ".")) +-- Output: +-- 1.1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.e.f + */ +LUA_FUNCTION_DEF(ip, inversed_str_octets); +/*** + * @function rspamd_ip.from_string(line) + * Create IP address from its string representation. + * @param {string} line valid IP address string (either ipv4 or ipv6) + * @return {ip} new ip object or `nil` if input is invalid + */ +LUA_FUNCTION_DEF(ip, from_string); +/*** + * @method ip:__gc() + * Automatically destroys IP object. + */ +LUA_FUNCTION_DEF(ip, destroy); +/*** + * @method ip:get_version() + * Gets numeric version of ip address + * @return {number} `4` for IPv4 and `6` for IPv6 + */ +LUA_FUNCTION_DEF(ip, get_version); +/*** + * @method ip:is_valid() + * Checks if an IP object is a valid IP address. + * @return {boolean} `true` if IP is valid and `false` otherwise + */ +LUA_FUNCTION_DEF(ip, is_valid); +/*** + * @method ip:apply_mask(mask) + * Applies mask to IP address, resetting up to `mask` least significant bits to zero. + * @param {integer} mask how many bits to reset + * @return {ip} new IP object with `mask` bits reset + */ +LUA_FUNCTION_DEF(ip, apply_mask); +/*** + * @method ip:__eq(other) + * Compares two IP addresses + * @param {ip} other IP to compare + * @return {boolean} `true` if two objects are the same + */ +LUA_FUNCTION_DEF(ip, equal); +/*** + * @method ip:copy() + * Performs deep copy of IP address. + * @return {ip} a fresh copy of IP address + */ +LUA_FUNCTION_DEF(ip, copy); + +/** + * @method ip:get_port() + * Returns associated port for this IP address + * @return {number} port number or nil + */ +LUA_FUNCTION_DEF(ip, get_port); +/*** + * @method ip:is_local() + * Returns true if address is local one + * @return {boolean} `true` if address is local + */ +LUA_FUNCTION_DEF(ip, is_local); + +/*** + * @method ip:less_than(other) + * Returns true if address is less than other + * @return {boolean} + */ +LUA_FUNCTION_DEF(ip, less_than); + +static const struct luaL_reg iplib_m[] = { + LUA_INTERFACE_DEF(ip, to_string), + LUA_INTERFACE_DEF(ip, to_table), + LUA_INTERFACE_DEF(ip, to_number), + LUA_INTERFACE_DEF(ip, str_octets), + LUA_INTERFACE_DEF(ip, inversed_str_octets), + LUA_INTERFACE_DEF(ip, get_version), + LUA_INTERFACE_DEF(ip, get_port), + LUA_INTERFACE_DEF(ip, is_valid), + LUA_INTERFACE_DEF(ip, apply_mask), + LUA_INTERFACE_DEF(ip, copy), + LUA_INTERFACE_DEF(ip, is_local), + {"tostring", lua_ip_to_string}, + {"totable", lua_ip_to_table}, + {"tonumber", lua_ip_to_number}, + {"__tostring", lua_ip_to_string}, + {"__eq", lua_ip_equal}, + {"__gc", lua_ip_destroy}, + {"__lt", lua_ip_less_than}, + {NULL, NULL}}; + +static const struct luaL_reg iplib_f[] = { + LUA_INTERFACE_DEF(ip, from_string), + {"fromstring", lua_ip_from_string}, + {"fromip", lua_ip_copy}, + {"from_ip", lua_ip_copy}, + {NULL, NULL}}; + +static struct rspamd_lua_ip * +lua_ip_new(lua_State *L, struct rspamd_lua_ip *old) +{ + struct rspamd_lua_ip *ip, **pip; + + ip = g_malloc0(sizeof(*ip)); + + if (old != NULL && old->addr != NULL) { + ip->addr = rspamd_inet_address_copy(old->addr, NULL); + } + + pip = lua_newuserdata(L, sizeof(struct rspamd_lua_ip *)); + rspamd_lua_setclass(L, "rspamd{ip}", -1); + *pip = ip; + + + return ip; +} + +struct rspamd_lua_ip * +lua_check_ip(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{ip}"); + + luaL_argcheck(L, ud != NULL, pos, "'ip' expected"); + return ud ? *((struct rspamd_lua_ip **) ud) : NULL; +} + +static gint +lua_ip_to_table(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + guint max, i; + guint8 *ptr; + + if (ip != NULL && ip->addr) { + ptr = rspamd_inet_address_get_hash_key(ip->addr, &max); + lua_createtable(L, max, 0); + + for (i = 1; i <= max; i++, ptr++) { + lua_pushinteger(L, *ptr); + lua_rawseti(L, -2, i); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_str_octets(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + guint max, i; + guint8 *ptr; + gint af; + char numbuf[8]; + + if (ip != NULL && ip->addr) { + af = rspamd_inet_address_get_af(ip->addr); + ptr = rspamd_inet_address_get_hash_key(ip->addr, &max); + lua_createtable(L, max * 2, 0); + + for (i = 1; i <= max; i++, ptr++) { + if (af == AF_INET) { + rspamd_snprintf(numbuf, sizeof(numbuf), "%d", *ptr); + lua_pushstring(L, numbuf); + lua_rawseti(L, -2, i); + } + else { + rspamd_snprintf(numbuf, + sizeof(numbuf), + "%xd", + (*ptr & 0xf0) >> 4); + lua_pushstring(L, numbuf); + lua_rawseti(L, -2, i * 2 - 1); + rspamd_snprintf(numbuf, sizeof(numbuf), "%xd", *ptr & 0x0f); + lua_pushstring(L, numbuf); + lua_rawseti(L, -2, i * 2); + } + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_inversed_str_octets(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + guint max, i; + guint8 *ptr; + char numbuf[4]; + gint af; + + if (ip != NULL && ip->addr) { + ptr = rspamd_inet_address_get_hash_key(ip->addr, &max); + af = rspamd_inet_address_get_af(ip->addr); + lua_createtable(L, max * 2, 0); + + ptr += max - 1; + for (i = 1; i <= max; i++, ptr--) { + if (af == AF_INET) { + rspamd_snprintf(numbuf, sizeof(numbuf), "%d", *ptr); + lua_pushstring(L, numbuf); + lua_rawseti(L, -2, i); + } + else { + rspamd_snprintf(numbuf, sizeof(numbuf), "%xd", *ptr & 0x0f); + lua_pushstring(L, numbuf); + lua_rawseti(L, -2, i * 2 - 1); + rspamd_snprintf(numbuf, + sizeof(numbuf), + "%xd", + (*ptr & 0xf0) >> 4); + lua_pushstring(L, numbuf); + lua_rawseti(L, -2, i * 2); + } + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_to_string(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + + if (ip != NULL && ip->addr) { + if (lua_isboolean(L, 2) && lua_toboolean(L, 2) == true) { + lua_pushstring(L, rspamd_inet_address_to_string_pretty(ip->addr)); + } + else { + lua_pushstring(L, rspamd_inet_address_to_string(ip->addr)); + } + } + else { + luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_ip_get_port(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + + if (ip != NULL && ip->addr) { + lua_pushinteger(L, rspamd_inet_address_get_port(ip->addr)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_from_string(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip; + const gchar *ip_str; + gsize len; + + ip_str = luaL_checklstring(L, 1, &len); + if (ip_str) { + ip = lua_ip_new(L, NULL); + + if (!rspamd_parse_inet_address(&ip->addr, + ip_str, len, RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + msg_warn("cannot parse ip: %*s", (gint) len, ip_str); + ip->addr = NULL; + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_to_number(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + guint32 c; + guint max, i; + guchar *ptr; + + if (ip != NULL && ip->addr) { + ptr = rspamd_inet_address_get_hash_key(ip->addr, &max); + + for (i = 0; i < max / sizeof(c); i++) { + memcpy(&c, ptr + i * sizeof(c), sizeof(c)); + lua_pushinteger(L, ntohl(c)); + } + + return max / sizeof(c); + } + else { + lua_pushnil(L); + } + + return 1; +} + + +static gint +lua_ip_destroy(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + + if (ip) { + if (ip->addr) { + rspamd_inet_address_free(ip->addr); + } + g_free(ip); + } + + return 0; +} + +static gint +lua_ip_get_version(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + + if (ip && ip->addr) { + lua_pushinteger(L, rspamd_inet_address_get_af(ip->addr) == AF_INET6 ? 6 : 4); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_is_valid(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + + if (ip) { + lua_pushboolean(L, ip->addr != NULL); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_apply_mask(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1), *nip; + gint mask; + + mask = lua_tonumber(L, 2); + if (mask > 0 && ip != NULL && ip->addr) { + nip = lua_ip_new(L, ip); + rspamd_inet_address_apply_mask(nip->addr, mask); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_equal(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip1 = lua_check_ip(L, 1), + *ip2 = lua_check_ip(L, 2); + gboolean res = FALSE; + + if (ip1 && ip2 && ip1->addr && ip2->addr) { + res = rspamd_inet_address_compare(ip1->addr, ip2->addr, TRUE) == 0; + } + + lua_pushboolean(L, res); + + return 1; +} + +static gint +lua_ip_copy(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + + if (ip) { + lua_ip_new(L, ip); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_is_local(lua_State *L) +{ + struct rspamd_lua_ip *ip = lua_check_ip(L, 1); + gboolean check_laddrs = TRUE; + + if (ip && ip->addr) { + + if (lua_type(L, 2) == LUA_TBOOLEAN) { + check_laddrs = lua_toboolean(L, 2); + } + + if (rspamd_inet_address_is_local(ip->addr)) { + lua_pushboolean(L, true); + + return 1; + } + else if (check_laddrs) { + struct rspamd_radix_map_helper *local_addrs = + rspamd_inet_library_get_lib_ctx(); + if (local_addrs) { + if (rspamd_match_radix_map_addr(local_addrs, ip->addr) != NULL) { + lua_pushboolean(L, true); + + return 1; + } + } + } + + lua_pushboolean(L, false); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_ip_less_than(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *ip = lua_check_ip(L, 1), + *other = lua_check_ip(L, 2); + + if (ip && other) { + lua_pushboolean(L, + rspamd_inet_address_compare(ip->addr, other->addr, true) < 0); + } + else { + lua_pushnil(L); + } + + return 1; +} + +void rspamd_lua_ip_push(lua_State *L, rspamd_inet_addr_t *addr) +{ + struct rspamd_lua_ip *ip, **pip; + + if (addr) { + ip = g_malloc0(sizeof(struct rspamd_lua_ip)); + ip->addr = rspamd_inet_address_copy(addr, NULL); + pip = lua_newuserdata(L, sizeof(struct rspamd_lua_ip *)); + rspamd_lua_setclass(L, "rspamd{ip}", -1); + *pip = ip; + } + else { + lua_pushnil(L); + } +} + +void rspamd_lua_ip_push_fromstring(lua_State *L, const gchar *ip_str) +{ + struct rspamd_lua_ip *ip, **pip; + + if (ip_str == NULL) { + lua_pushnil(L); + } + else { + ip = g_malloc0(sizeof(struct rspamd_lua_ip)); + + if (rspamd_parse_inet_address(&ip->addr, + ip_str, strlen(ip_str), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + + pip = lua_newuserdata(L, sizeof(struct rspamd_lua_ip *)); + rspamd_lua_setclass(L, "rspamd{ip}", -1); + *pip = ip; + } + else { + g_free(ip); + lua_pushnil(L); + } + } +} + +static gint +lua_load_ip(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, iplib_f); + + return 1; +} + +void luaopen_ip(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{ip}", iplib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_ip", lua_load_ip); +} diff --git a/src/lua/lua_kann.c b/src/lua/lua_kann.c new file mode 100644 index 0000000..e42fbfb --- /dev/null +++ b/src/lua/lua_kann.c @@ -0,0 +1,1361 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "lua_tensor.h" +#include "contrib/kann/kann.h" + +/*** + * @module rspamd_kann + * `rspamd_kann` is a Lua interface to kann library + */ + +#define KANN_NODE_CLASS "rspamd{kann_node}" +#define KANN_NETWORK_CLASS "rspamd{kann}" + +/* Simple macros to define behaviour */ +#define KANN_LAYER_DEF(name) static int lua_kann_layer_##name(lua_State *L) +#define KANN_LAYER_INTERFACE(name) \ + { \ + #name, lua_kann_layer_##name \ + } + +#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name(lua_State *L) +#define KANN_TRANSFORM_INTERFACE(name) \ + { \ + #name, lua_kann_transform_##name \ + } + +#define KANN_LOSS_DEF(name) static int lua_kann_loss_##name(lua_State *L) +#define KANN_LOSS_INTERFACE(name) \ + { \ + #name, lua_kann_loss_##name \ + } + +#define KANN_NEW_DEF(name) static int lua_kann_new_##name(lua_State *L) +#define KANN_NEW_INTERFACE(name) \ + { \ + #name, lua_kann_new_##name \ + } + + +/* + * Forwarded declarations + */ +static kad_node_t *lua_check_kann_node(lua_State *L, int pos); + +/* Layers */ +KANN_LAYER_DEF(input); +KANN_LAYER_DEF(dense); +KANN_LAYER_DEF(layernorm); +KANN_LAYER_DEF(rnn); +KANN_LAYER_DEF(lstm); +KANN_LAYER_DEF(gru); +KANN_LAYER_DEF(conv2d); +KANN_LAYER_DEF(conv1d); +KANN_LAYER_DEF(cost); + +static luaL_reg rspamd_kann_layers_f[] = { + KANN_LAYER_INTERFACE(input), + KANN_LAYER_INTERFACE(dense), + KANN_LAYER_INTERFACE(layernorm), + KANN_LAYER_INTERFACE(rnn), + KANN_LAYER_INTERFACE(lstm), + KANN_LAYER_INTERFACE(gru), + KANN_LAYER_INTERFACE(conv2d), + KANN_LAYER_INTERFACE(conv1d), + KANN_LAYER_INTERFACE(cost), + {NULL, NULL}, +}; + +/* Transition and composition functions */ + +/* General transform */ +KANN_TRANSFORM_DEF(add); +KANN_TRANSFORM_DEF(sub); +KANN_TRANSFORM_DEF(mul); +KANN_TRANSFORM_DEF(cmul); +KANN_TRANSFORM_DEF(matmul); + +KANN_TRANSFORM_DEF(square); +KANN_TRANSFORM_DEF(sigm); +KANN_TRANSFORM_DEF(tanh); +KANN_TRANSFORM_DEF(relu); +KANN_TRANSFORM_DEF(softmax); +KANN_TRANSFORM_DEF(1minus); +KANN_TRANSFORM_DEF(exp); +KANN_TRANSFORM_DEF(log); +KANN_TRANSFORM_DEF(sin); +static luaL_reg rspamd_kann_transform_f[] = { + KANN_TRANSFORM_INTERFACE(add), + KANN_TRANSFORM_INTERFACE(sub), + KANN_TRANSFORM_INTERFACE(mul), + KANN_TRANSFORM_INTERFACE(cmul), + KANN_TRANSFORM_INTERFACE(matmul), + + KANN_TRANSFORM_INTERFACE(square), + KANN_TRANSFORM_INTERFACE(sigm), + KANN_TRANSFORM_INTERFACE(tanh), + KANN_TRANSFORM_INTERFACE(relu), + KANN_TRANSFORM_INTERFACE(softmax), + KANN_TRANSFORM_INTERFACE(1minus), + KANN_TRANSFORM_INTERFACE(exp), + KANN_TRANSFORM_INTERFACE(log), + KANN_TRANSFORM_INTERFACE(sin), + {NULL, NULL}, +}; + +/* Loss functions */ +KANN_LOSS_DEF(mse); +KANN_LOSS_DEF(ce_multi); +KANN_LOSS_DEF(ce_bin); +KANN_LOSS_DEF(ce_bin_neg); +KANN_LOSS_DEF(ce_multi_weighted); +static luaL_reg rspamd_kann_loss_f[] = { + KANN_LOSS_INTERFACE(mse), + KANN_LOSS_INTERFACE(ce_multi), + KANN_LOSS_INTERFACE(ce_bin), + KANN_LOSS_INTERFACE(ce_bin_neg), + KANN_LOSS_INTERFACE(ce_multi_weighted), + {NULL, NULL}, +}; + +/* Creation functions */ +KANN_NEW_DEF(leaf); +KANN_NEW_DEF(scalar); +KANN_NEW_DEF(weight); +KANN_NEW_DEF(bias); +KANN_NEW_DEF(weight_conv2d); +KANN_NEW_DEF(weight_conv1d); +KANN_NEW_DEF(kann); + +static luaL_reg rspamd_kann_new_f[] = { + KANN_NEW_INTERFACE(leaf), + KANN_NEW_INTERFACE(scalar), + KANN_NEW_INTERFACE(weight), + KANN_NEW_INTERFACE(bias), + KANN_NEW_INTERFACE(weight_conv2d), + KANN_NEW_INTERFACE(weight_conv1d), + KANN_NEW_INTERFACE(kann), + {NULL, NULL}, +}; + +LUA_FUNCTION_DEF(kann, load); +LUA_FUNCTION_DEF(kann, destroy); +LUA_FUNCTION_DEF(kann, save); +LUA_FUNCTION_DEF(kann, train1); +LUA_FUNCTION_DEF(kann, apply1); + +static luaL_reg rspamd_kann_m[] = { + LUA_INTERFACE_DEF(kann, save), + LUA_INTERFACE_DEF(kann, train1), + LUA_INTERFACE_DEF(kann, apply1), + {"__gc", lua_kann_destroy}, + {NULL, NULL}, +}; + +static int +rspamd_kann_table_to_flags(lua_State *L, int table_pos) +{ + int result = 0; + + lua_pushvalue(L, table_pos); + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + int fl = lua_tointeger(L, -1); + + result |= fl; + } + + lua_pop(L, 1); + + return result; +} + +static gint +lua_load_kann(lua_State *L) +{ + lua_newtable(L); + + /* Flags */ + lua_pushstring(L, "flag"); + lua_newtable(L); + lua_pushinteger(L, KANN_F_IN); + lua_setfield(L, -2, "in"); + lua_pushinteger(L, KANN_F_COST); + lua_setfield(L, -2, "cost"); + lua_pushinteger(L, KANN_F_OUT); + lua_setfield(L, -2, "out"); + lua_pushinteger(L, KANN_F_TRUTH); + lua_setfield(L, -2, "truth"); + lua_settable(L, -3); + + /* Cost type */ + lua_pushstring(L, "cost"); + lua_newtable(L); + /* binary cross-entropy cost, used with sigmoid */ + lua_pushinteger(L, KANN_C_CEB); + lua_setfield(L, -2, "ceb"); + /* multi-class cross-entropy cost, used with softmax */ + lua_pushinteger(L, KANN_C_CEM); + lua_setfield(L, -2, "cem"); + /* binary cross-entropy-like cost, used with tanh */ + lua_pushinteger(L, KANN_C_CEB_NEG); + lua_setfield(L, -2, "ceb_neg"); + lua_pushinteger(L, KANN_C_MSE); + lua_setfield(L, -2, "mse"); + lua_settable(L, -3); + + /* RNN flag */ + lua_pushstring(L, "rnn"); + lua_newtable(L); + /* apply layer normalization */ + lua_pushinteger(L, KANN_RNN_NORM); + lua_setfield(L, -2, "norm"); + /* take the initial hidden values as variables */ + lua_pushinteger(L, KANN_RNN_VAR_H0); + lua_setfield(L, -2, "var_h0"); + lua_settable(L, -3); + + /* Layers */ + lua_pushstring(L, "layer"); + lua_newtable(L); + luaL_register(L, NULL, rspamd_kann_layers_f); + lua_settable(L, -3); + + /* Transforms */ + lua_pushstring(L, "transform"); + lua_newtable(L); + luaL_register(L, NULL, rspamd_kann_transform_f); + lua_settable(L, -3); + + /* Cost */ + lua_pushstring(L, "loss"); + lua_newtable(L); + luaL_register(L, NULL, rspamd_kann_loss_f); + lua_settable(L, -3); + + /* Create functions */ + lua_pushstring(L, "new"); + lua_newtable(L); + luaL_register(L, NULL, rspamd_kann_new_f); + lua_settable(L, -3); + + /* Load ann from memory or file */ + lua_pushstring(L, "load"); + lua_pushcfunction(L, lua_kann_load); + lua_settable(L, -3); + + return 1; +} + +static kad_node_t * +lua_check_kann_node(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, KANN_NODE_CLASS); + luaL_argcheck(L, ud != NULL, pos, "'kann_node' expected"); + return ud ? *((kad_node_t **) ud) : NULL; +} + +static kann_t * +lua_check_kann(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, KANN_NETWORK_CLASS); + luaL_argcheck(L, ud != NULL, pos, "'kann' expected"); + return ud ? *((kann_t **) ud) : NULL; +} + +void luaopen_kann(lua_State *L) +{ + /* Metatables */ + rspamd_lua_new_class(L, KANN_NODE_CLASS, NULL); /* TODO: add methods */ + lua_pop(L, 1); /* No need in metatable... */ + rspamd_lua_new_class(L, KANN_NETWORK_CLASS, rspamd_kann_m); + lua_pop(L, 1); /* No need in metatable... */ + rspamd_lua_add_preload(L, "rspamd_kann", lua_load_kann); + lua_settop(L, 0); +} + +/* Layers implementation */ +#define PUSH_KAD_NODE(n) \ + do { \ + kad_node_t **pt; \ + pt = lua_newuserdata(L, sizeof(kad_node_t *)); \ + *pt = (n); \ + rspamd_lua_setclass(L, KANN_NODE_CLASS, -1); \ + } while (0) + +#define PUSH_KAN_NETWORK(n) \ + do { \ + kann_t **pn; \ + pn = lua_newuserdata(L, sizeof(kann_t *)); \ + *pn = (n); \ + rspamd_lua_setclass(L, KANN_NETWORK_CLASS, -1); \ + } while (0) + +#define PROCESS_KAD_FLAGS(n, pos) \ + do { \ + int fl = 0; \ + if (lua_type(L, (pos)) == LUA_TTABLE) { fl = rspamd_kann_table_to_flags(L, (pos)); } \ + else if (lua_type(L, (pos)) == LUA_TNUMBER) { \ + fl = lua_tointeger(L, (pos)); \ + } \ + (n)->ext_flag |= fl; \ + } while (0) + +/*** + * @function kann.layer.input(ninputs[, flags]) + * Creates an input layer for ANN + * @param {int} ninputs number of inputs + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_input(lua_State *L) +{ + gint nnodes = luaL_checkinteger(L, 1); + + if (nnodes > 0) { + kad_node_t *t; + + t = kann_layer_input(nnodes); + + PROCESS_KAD_FLAGS(t, 2); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, nnodes required"); + } + + return 1; +} + +/*** + * @function kann.layer.dense(in, ninputs[, flags]) + * Creates a dense layer (e.g. for hidden layer) + * @param {kann_node} in kann node + * @param {int} ninputs number of dense nodes + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_dense(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + gint nnodes = luaL_checkinteger(L, 2); + + if (in != NULL && nnodes > 0) { + kad_node_t *t; + + t = kann_layer_dense(in, nnodes); + + PROCESS_KAD_FLAGS(t, 3); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input + nnodes required"); + } + + return 1; +} + +/*** + * @function kann.layer.dropout(in, ratio[, flags]) + * Creates a dropout layer + * @param {kann_node} in kann node + * @param {float} ratio drop ratio + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_layerdropout(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + double r = luaL_checknumber(L, 2); + + if (in != NULL) { + kad_node_t *t; + + t = kann_layer_dropout(in, r); + + PROCESS_KAD_FLAGS(t, 3); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input + rate required"); + } + + return 1; +} + +/*** + * @function kann.layer.dropout(in [, flags]) + * Creates a normalisation layer + * @param {kann_node} in kann node + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_layernorm(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + + if (in != NULL) { + kad_node_t *t; + + t = kann_layer_layernorm(in); + + PROCESS_KAD_FLAGS(t, 2); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input required"); + } + + return 1; +} + +/*** + * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]]) + * Creates a recursive NN layer + * @param {kann_node} in kann node + * @param {int} nnodes number of cells + * @param {int} rnnflags rnn flags + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_rnn(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + gint nnodes = luaL_checkinteger(L, 2); + gint rnnflags = 0; + + if (in != NULL && nnodes > 0) { + kad_node_t *t; + + if (lua_type(L, 3) == LUA_TNUMBER) { + rnnflags = lua_tointeger(L, 3); + } + + t = kann_layer_rnn(in, nnodes, rnnflags); + + PROCESS_KAD_FLAGS(t, 4); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input + nnodes required"); + } + + return 1; +} + +/*** + * @function kann.layer.lstm(in, nnodes[, rnn_flags, [, flags]]) + * Creates a recursive NN layer using LSTM cells + * @param {kann_node} in kann node + * @param {int} nnodes number of cells + * @param {int} rnnflags rnn flags + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_lstm(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + gint nnodes = luaL_checkinteger(L, 2); + gint rnnflags = 0; + + if (in != NULL && nnodes > 0) { + kad_node_t *t; + + if (lua_type(L, 3) == LUA_TNUMBER) { + rnnflags = lua_tointeger(L, 3); + } + + t = kann_layer_lstm(in, nnodes, rnnflags); + + PROCESS_KAD_FLAGS(t, 4); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input + nnodes required"); + } + + return 1; +} + +/*** + * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]]) + * Creates a recursive NN layer using GRU cells + * @param {kann_node} in kann node + * @param {int} nnodes number of cells + * @param {int} rnnflags rnn flags + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_gru(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + gint nnodes = luaL_checkinteger(L, 2); + gint rnnflags = 0; + + if (in != NULL && nnodes > 0) { + kad_node_t *t; + + if (lua_type(L, 3) == LUA_TNUMBER) { + rnnflags = lua_tointeger(L, 3); + } + + t = kann_layer_gru(in, nnodes, rnnflags); + + PROCESS_KAD_FLAGS(t, 4); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input + nnodes required"); + } + + return 1; +} + +/*** + * @function kann.layer.conv2d(in, n_flt, k_rows, k_cols, stride_rows, stride_cols, pad_rows, pad_columns[, flags]) + * Creates a 2D convolution layer + * @param {kann_node} in kann node + * @param {int} n_flt number of filters + * @param {int} k_rows kernel rows + * @param {int} k_cols kernel columns + * @param {int} stride_rows stride rows + * @param {int} stride_cols stride columns + * @param {int} pad_rows padding rows + * @param {int} pad_columns padding columns + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_conv2d(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + int n_flt = luaL_checkinteger(L, 2); + int k_rows = luaL_checkinteger(L, 3); + int k_cols = luaL_checkinteger(L, 4); + int stride_r = luaL_checkinteger(L, 5); + int stride_c = luaL_checkinteger(L, 6); + int pad_r = luaL_checkinteger(L, 7); + int pad_c = luaL_checkinteger(L, 8); + + if (in != NULL) { + kad_node_t *t; + t = kann_layer_conv2d(in, n_flt, k_rows, k_cols, stride_r, stride_c, + pad_r, pad_c); + + PROCESS_KAD_FLAGS(t, 9); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required"); + } + + return 1; +} + +/*** + * @function kann.layer.conv1d(in, n_flt, kern_size, stride_size, pad_size[, flags]) + * Creates 1D convolution layer + * @param {kann_node} in kann node + * @param {int} n_flt number of filters + * @param {int} kern_size kernel rows + * @param {int} stride_size stride rows + * @param {int} pad_size padding rows + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_conv1d(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + int n_flt = luaL_checkinteger(L, 2); + int k_size = luaL_checkinteger(L, 3); + int stride = luaL_checkinteger(L, 4); + int pad = luaL_checkinteger(L, 5); + + if (in != NULL) { + kad_node_t *t; + t = kann_layer_conv1d(in, n_flt, k_size, stride, pad); + + PROCESS_KAD_FLAGS(t, 6); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input, nflt, k, stride, pad required"); + } + + return 1; +} + +/*** + * @function kann.layer.cost(in, nout, cost_type[, flags]) + * Creates 1D convolution layer + * @param {kann_node} in kann node + * @param {int} nout number of outputs + * @param {int} cost_type see kann.cost table + * @param {table|int} flags optional flags + * @return {kann_node} kann node object (should be used to combine ANN) +*/ +static int +lua_kann_layer_cost(lua_State *L) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + int nout = luaL_checkinteger(L, 2); + int cost_type = luaL_checkinteger(L, 3); + + if (in != NULL && nout > 0) { + kad_node_t *t; + t = kann_layer_cost(in, nout, cost_type); + + PROCESS_KAD_FLAGS(t, 4); + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments, input, nout and cost_type are required"); + } + + return 1; +} + +/* Generic helpers */ +static int +lua_kann_call_unary_function(lua_State *L, const char *name, + kad_node_t *(*func)(kad_node_t *) ) +{ + kad_node_t *in = lua_check_kann_node(L, 1); + + if (in != NULL) { + kad_node_t *t; + t = func(in); + + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments for %s, input required", name); + } + + return 1; +} +static int +lua_kann_call_binary_function(lua_State *L, const char *name, + kad_node_t *(*func)(kad_node_t *, kad_node_t *) ) +{ + kad_node_t *x = lua_check_kann_node(L, 1); + kad_node_t *y = lua_check_kann_node(L, 2); + + if (x != NULL && y != NULL) { + kad_node_t *t; + t = func(x, y); + + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments for %s, 2 inputs required", name); + } + + return 1; +} + +#define LUA_UNARY_TRANSFORM_FUNC_IMPL(name) \ + static int lua_kann_transform_##name(lua_State *L) \ + { \ + return lua_kann_call_unary_function(L, #name, kad_##name); \ + } + +#define LUA_BINARY_TRANSFORM_FUNC_IMPL(name) \ + static int lua_kann_transform_##name(lua_State *L) \ + { \ + return lua_kann_call_binary_function(L, #name, kad_##name); \ + } + +#define LUA_LOSS_FUNC_IMPL(name) \ + static int lua_kann_loss_##name(lua_State *L) \ + { \ + return lua_kann_call_binary_function(L, #name, kad_##name); \ + } + +/* Transform functions registered via macro helpers */ +LUA_BINARY_TRANSFORM_FUNC_IMPL(add) +LUA_BINARY_TRANSFORM_FUNC_IMPL(sub) +LUA_BINARY_TRANSFORM_FUNC_IMPL(mul) +LUA_BINARY_TRANSFORM_FUNC_IMPL(cmul) +LUA_BINARY_TRANSFORM_FUNC_IMPL(matmul) + +LUA_UNARY_TRANSFORM_FUNC_IMPL(square) +LUA_UNARY_TRANSFORM_FUNC_IMPL(sigm) +LUA_UNARY_TRANSFORM_FUNC_IMPL(tanh) +LUA_UNARY_TRANSFORM_FUNC_IMPL(relu) +LUA_UNARY_TRANSFORM_FUNC_IMPL(softmax) +LUA_UNARY_TRANSFORM_FUNC_IMPL(1minus) +LUA_UNARY_TRANSFORM_FUNC_IMPL(exp) +LUA_UNARY_TRANSFORM_FUNC_IMPL(log) +LUA_UNARY_TRANSFORM_FUNC_IMPL(sin) + +/* Generic cost functions */ +LUA_LOSS_FUNC_IMPL(mse) +LUA_LOSS_FUNC_IMPL(ce_multi) +LUA_LOSS_FUNC_IMPL(ce_bin) +LUA_LOSS_FUNC_IMPL(ce_bin_neg) + +/* The only case of ternary weight function */ +static int +lua_kann_loss_ce_multi_weighted(lua_State *L) +{ + kad_node_t *pred = lua_check_kann_node(L, 1); + kad_node_t *truth = lua_check_kann_node(L, 2); + kad_node_t *weight = lua_check_kann_node(L, 3); + + if (pred != NULL && truth != NULL && weight != NULL) { + kad_node_t *t; + t = kad_ce_multi_weighted(pred, truth, weight); + + PUSH_KAD_NODE(t); + } + else { + return luaL_error(L, "invalid arguments for ce_multi_weighted, 3 inputs required"); + } + + return 1; +} + +/* Creation functions */ +static int +lua_kann_new_scalar(lua_State *L) +{ + gint flag = luaL_checkinteger(L, 1); + double x = luaL_checknumber(L, 2); + kad_node_t *t; + + t = kann_new_scalar(flag, x); + + PROCESS_KAD_FLAGS(t, 3); + PUSH_KAD_NODE(t); + + return 1; +} + +static int +lua_kann_new_weight(lua_State *L) +{ + gint nrow = luaL_checkinteger(L, 1); + gint ncol = luaL_checkinteger(L, 2); + kad_node_t *t; + + t = kann_new_weight(nrow, ncol); + + PROCESS_KAD_FLAGS(t, 3); + PUSH_KAD_NODE(t); + + return 1; +} + +static int +lua_kann_new_bias(lua_State *L) +{ + gint n = luaL_checkinteger(L, 1); + kad_node_t *t; + + t = kann_new_bias(n); + + PROCESS_KAD_FLAGS(t, 2); + PUSH_KAD_NODE(t); + + return 1; +} + +static int +lua_kann_new_weight_conv2d(lua_State *L) +{ + gint nout = luaL_checkinteger(L, 1); + gint nin = luaL_checkinteger(L, 2); + gint krow = luaL_checkinteger(L, 3); + gint kcol = luaL_checkinteger(L, 4); + kad_node_t *t; + + t = kann_new_weight_conv2d(nout, nin, krow, kcol); + + PROCESS_KAD_FLAGS(t, 5); + PUSH_KAD_NODE(t); + + return 1; +} + +static int +lua_kann_new_weight_conv1d(lua_State *L) +{ + gint nout = luaL_checkinteger(L, 1); + gint nin = luaL_checkinteger(L, 2); + gint klen = luaL_checkinteger(L, 3); + kad_node_t *t; + + t = kann_new_weight_conv1d(nout, nin, klen); + + PROCESS_KAD_FLAGS(t, 4); + PUSH_KAD_NODE(t); + + return 1; +} + +static int +lua_kann_new_leaf(lua_State *L) +{ + int dim = luaL_checkinteger(L, 1), i, *ar; + kad_node_t *t; + + if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable(L, 2)) { + ar = g_new0(int, KAD_MAX_DIM); + + for (i = 0; i < dim; i++) { + lua_rawgeti(L, 2, i + 1); + ar[i] = lua_tointeger(L, -1); + lua_pop(L, 1); + } + + t = kann_new_leaf_array(NULL, NULL, 0, 0.0, dim, ar); + + PROCESS_KAD_FLAGS(t, 3); + PUSH_KAD_NODE(t); + + g_free(ar); + } + else { + return luaL_error(L, "invalid arguments for new.leaf, " + "dim and vector of elements are required"); + } + + return 1; +} + +static int +lua_kann_new_kann(lua_State *L) +{ + kad_node_t *cost = lua_check_kann_node(L, 1); + kann_t *k; + + if (cost) { + k = kann_new(cost, 0); + + PUSH_KAN_NETWORK(k); + } + else { + return luaL_error(L, "invalid arguments for new.kann, " + "cost node is required"); + } + + return 1; +} + +static int +lua_kann_destroy(lua_State *L) +{ + kann_t *k = lua_check_kann(L, 1); + + kann_delete(k); + + return 0; +} + +static int +lua_kann_save(lua_State *L) +{ + kann_t *k = lua_check_kann(L, 1); + + if (k) { + if (lua_istable(L, 2)) { + lua_getfield(L, 2, "filename"); + + if (lua_isstring(L, -1)) { + const gchar *fname = lua_tostring(L, -1); + FILE *f; + + f = fopen(fname, "w"); + + if (!f) { + lua_pop(L, 1); + + return luaL_error(L, "cannot open %s for writing: %s", + fname, strerror(errno)); + } + + kann_save_fp(f, k); + fclose(f); + + lua_pushboolean(L, true); + } + else { + lua_pop(L, 1); + + return luaL_error(L, "invalid arguments: missing filename"); + } + + lua_pop(L, 1); + } + else { + /* Save to Rspamd text */ +#ifndef HAVE_OPENMEMSTREAM + return luaL_error(L, "no support of saving to memory on your system"); +#endif + FILE *f; + char *buf = NULL; + size_t buflen; + struct rspamd_lua_text *t; + + f = open_memstream(&buf, &buflen); + g_assert(f != NULL); + + kann_save_fp(f, k); + fclose(f); + + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->flags = RSPAMD_TEXT_FLAG_OWN; + t->start = (const gchar *) buf; + t->len = buflen; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static int +lua_kann_load(lua_State *L) +{ + kann_t *k; + FILE *f = NULL; + + if (lua_istable(L, 1)) { + lua_getfield(L, 2, "filename"); + + if (lua_isstring(L, -1)) { + const gchar *fname = lua_tostring(L, -1); + + f = fopen(fname, "rb"); + } + else { + lua_pop(L, 1); + + return luaL_error(L, "invalid arguments: missing filename"); + } + + lua_pop(L, 1); + } + else if (lua_isstring(L, 1)) { + gsize dlen; + const gchar *data; + + data = lua_tolstring(L, 1, &dlen); + +#ifndef HAVE_FMEMOPEN + return luaL_error(L, "no support of loading from memory on your system"); +#endif + f = fmemopen((void *) data, dlen, "rb"); + } + else if (lua_isuserdata(L, 1)) { + struct rspamd_lua_text *t; + + t = lua_check_text(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + +#ifndef HAVE_FMEMOPEN + return luaL_error(L, "no support of loading from memory on your system"); +#endif + f = fmemopen((void *) t->start, t->len, "rb"); + } + + if (f == NULL) { + return luaL_error(L, "invalid arguments or cannot open file"); + } + + k = kann_load_fp(f); + fclose(f); + + if (k == NULL) { + lua_pushnil(L); + } + else { + PUSH_KAN_NETWORK(k); + } + + return 1; +} + +struct rspamd_kann_train_cbdata { + lua_State *L; + kann_t *k; + gint cbref; +}; + +static void +lua_kann_train_cb(int iter, float train_cost, float val_cost, void *ud) +{ + struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *) ud; + + if (cbd->cbref != -1) { + gint err_idx; + lua_State *L = cbd->L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref); + lua_pushinteger(L, iter); + lua_pushnumber(L, train_cost); + lua_pushnumber(L, val_cost); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err("cannot run lua train callback: %s", + lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); + } +} + +#define FREE_VEC(a, n) \ + do { \ + for (int i = 0; i < (n); i++) g_free((a)[i]); \ + g_free(a); \ + } while (0) + +static int +lua_kann_train1(lua_State *L) +{ + kann_t *k = lua_check_kann(L, 1); + struct rspamd_lua_tensor *pca = NULL; + + /* Default train params */ + double lr = 0.001; + gint64 mini_size = 64; + gint64 max_epoch = 25; + gint64 max_drop_streak = 10; + double frac_val = 0.1; + gint cbref = -1; + + if (k && lua_istable(L, 2) && lua_istable(L, 3)) { + int n = rspamd_lua_table_size(L, 2); + int n_in = kann_dim_in(k); + int n_out = kann_dim_out(k); + + if (n_in <= 0) { + return luaL_error(L, "invalid inputs count: %d", n_in); + } + + if (n_out <= 0) { + return luaL_error(L, "invalid outputs count: %d", n_out); + } + + if (n != rspamd_lua_table_size(L, 3) || n == 0) { + return luaL_error(L, "invalid dimensions: outputs size must be " + "equal to inputs and non zero"); + } + + if (lua_istable(L, 4)) { + GError *err = NULL; + + if (!rspamd_lua_parse_table_arguments(L, 4, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING, + "lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}", + &lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) { + n = luaL_error(L, "invalid params: %s", + err ? err->message : "unknown error"); + g_error_free(err); + + return n; + } + } + + if (pca) { + /* Check pca matrix validity */ + if (pca->ndims != 2) { + return luaL_error(L, "invalid pca tensor: matrix expected, got a row"); + } + + if (pca->dim[0] != n_in) { + return luaL_error(L, "invalid pca tensor: " + "matrix must have %d rows and it has %d rows instead", + n_in, pca->dim[0]); + } + } + + float **x, **y, *tmp_row = NULL; + + /* Fill vectors row by row */ + x = (float **) g_malloc0(sizeof(float *) * n); + y = (float **) g_malloc0(sizeof(float *) * n); + + if (pca) { + tmp_row = g_malloc(sizeof(float) * pca->dim[1]); + } + + for (int s = 0; s < n; s++) { + /* Inputs */ + lua_rawgeti(L, 2, s + 1); + x[s] = (float *) g_malloc(sizeof(float) * n_in); + + if (pca == NULL) { + if (rspamd_lua_table_size(L, -1) != n_in) { + FREE_VEC(x, n); + FREE_VEC(y, n); + + n = luaL_error(L, "invalid params at pos %d: " + "bad input dimension %d; %d expected", + s + 1, + (int) rspamd_lua_table_size(L, -1), + n_in); + lua_pop(L, 1); + + return n; + } + + for (int i = 0; i < n_in; i++) { + lua_rawgeti(L, -1, i + 1); + x[s][i] = lua_tonumber(L, -1); + + lua_pop(L, 1); + } + } + else { + if (rspamd_lua_table_size(L, -1) != pca->dim[1]) { + FREE_VEC(x, n); + FREE_VEC(y, n); + g_free(tmp_row); + + n = luaL_error(L, "(pca on) invalid params at pos %d: " + "bad input dimension %d; %d expected", + s + 1, + (int) rspamd_lua_table_size(L, -1), + pca->dim[1]); + lua_pop(L, 1); + + return n; + } + + + for (int i = 0; i < pca->dim[1]; i++) { + lua_rawgeti(L, -1, i + 1); + tmp_row[i] = lua_tonumber(L, -1); + + lua_pop(L, 1); + } + + kad_sgemm_simple(0, 1, 1, n_in, + pca->dim[1], tmp_row, pca->data, + x[s]); + } + + lua_pop(L, 1); + + /* Outputs */ + y[s] = (float *) g_malloc(sizeof(float) * n_out); + lua_rawgeti(L, 3, s + 1); + + if (rspamd_lua_table_size(L, -1) != n_out) { + FREE_VEC(x, n); + FREE_VEC(y, n); + g_free(tmp_row); + + n = luaL_error(L, "invalid params at pos %d: " + "bad output dimension %d; " + "%d expected", + s + 1, + (int) rspamd_lua_table_size(L, -1), + n_out); + lua_pop(L, 1); + + return n; + } + + for (int i = 0; i < n_out; i++) { + lua_rawgeti(L, -1, i + 1); + y[s][i] = lua_tonumber(L, -1); + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + struct rspamd_kann_train_cbdata cbd; + + cbd.cbref = cbref; + cbd.k = k; + cbd.L = L; + + int niters = kann_train_fnn1(k, lr, + mini_size, max_epoch, max_drop_streak, + frac_val, n, x, y, lua_kann_train_cb, &cbd); + + lua_pushinteger(L, niters); + + FREE_VEC(x, n); + FREE_VEC(y, n); + g_free(tmp_row); + } + else { + return luaL_error(L, "invalid arguments: kann, inputs, outputs and" + " optional params are expected"); + } + + return 1; +} + +static int +lua_kann_apply1(lua_State *L) +{ + kann_t *k = lua_check_kann(L, 1); + struct rspamd_lua_tensor *pca = NULL; + + if (k) { + if (lua_istable(L, 2)) { + gsize vec_len = rspamd_lua_table_size(L, 2); + float *vec = (float *) g_malloc(sizeof(float) * vec_len), + *pca_out = NULL; + int i_out; + int n_in = kann_dim_in(k); + + if (n_in <= 0) { + g_free(vec); + return luaL_error(L, "invalid inputs count: %d", n_in); + } + + if (lua_isuserdata(L, 3)) { + pca = lua_check_tensor(L, 3); + + if (pca) { + if (pca->ndims != 2) { + g_free(vec); + return luaL_error(L, "invalid pca tensor: matrix expected, got a row"); + } + + if (pca->dim[0] != n_in) { + g_free(vec); + return luaL_error(L, "invalid pca tensor: " + "matrix must have %d rows and it has %d rows instead", + n_in, pca->dim[0]); + } + } + else { + g_free(vec); + return luaL_error(L, "invalid params: pca matrix expected"); + } + } + else { + if (n_in != vec_len) { + g_free(vec); + return luaL_error(L, "invalid params: bad input dimension %d; %d expected", + (int) vec_len, n_in); + } + } + + for (gsize i = 0; i < vec_len; i++) { + lua_rawgeti(L, 2, i + 1); + vec[i] = lua_tonumber(L, -1); + lua_pop(L, 1); + } + + i_out = kann_find(k, KANN_F_OUT, 0); + + if (i_out <= 0) { + g_free(vec); + return luaL_error(L, "invalid ANN: output layer is missing or is " + "at the input pos"); + } + + kann_set_batch_size(k, 1); + if (pca) { + pca_out = g_malloc(sizeof(float) * n_in); + + kad_sgemm_simple(0, 1, 1, n_in, + vec_len, vec, pca->data, + pca_out); + + kann_feed_bind(k, KANN_F_IN, 0, &pca_out); + } + else { + kann_feed_bind(k, KANN_F_IN, 0, &vec); + } + + kad_eval_at(k->n, k->v, i_out); + + gsize outlen = kad_len(k->v[i_out]); + lua_createtable(L, outlen, 0); + + for (gsize i = 0; i < outlen; i++) { + lua_pushnumber(L, k->v[i_out]->x[i]); + lua_rawseti(L, -2, i + 1); + } + + g_free(vec); + g_free(pca_out); + } + else if (lua_isuserdata(L, 2)) { + struct rspamd_lua_tensor *t = lua_check_tensor(L, 2); + + if (t && t->ndims == 1) { + int i_out; + int n_in = kann_dim_in(k); + + if (n_in != t->dim[0]) { + return luaL_error(L, "invalid params: bad input dimension %d; %d expected", + (int) t->dim[0], n_in); + } + + i_out = kann_find(k, KANN_F_OUT, 0); + + if (i_out <= 0) { + return luaL_error(L, "invalid ANN: output layer is missing or is " + "at the input pos"); + } + + kann_set_batch_size(k, 1); + kann_feed_bind(k, KANN_F_IN, 0, &t->data); + kad_eval_at(k->n, k->v, i_out); + + gint outlen = kad_len(k->v[i_out]); + struct rspamd_lua_tensor *out; + out = lua_newtensor(L, 1, &outlen, false, false); + /* Ensure that kann and tensor have the same understanding of floats */ + G_STATIC_ASSERT(sizeof(float) == sizeof(rspamd_tensor_num_t)); + memcpy(out->data, k->v[i_out]->x, outlen * sizeof(float)); + } + else { + return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected"); + } + } + else { + return luaL_error(L, "invalid arguments: 1D rspamd{tensor} expected"); + } + } + else { + return luaL_error(L, "invalid arguments: rspamd{kann} expected"); + } + + return 1; +}
\ No newline at end of file diff --git a/src/lua/lua_logger.c b/src/lua/lua_logger.c new file mode 100644 index 0000000..f4f8f3d --- /dev/null +++ b/src/lua/lua_logger.c @@ -0,0 +1,1068 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "libserver/maps/map.h" +#include "libserver/maps/map_private.h" + +/*** + * @module rspamd_logger + * Rspamd logger module is used to log messages from LUA API to the main rspamd logger. + * It supports legacy and modern interfaces allowing highly customized an convenient log functions. + * Here is an example of logger usage: + * @example +local rspamd_logger = require "rspamd_logger" + +local a = 'string' +local b = 1.5 +local c = 1 +local d = { + 'aa', + 1, + 'bb' +} +local e = { + key = 'value', + key2 = 1.0 +} + +-- New extended interface +-- %<number> means numeric arguments and %s means the next argument +-- for example %1, %2, %s: %s would mean the third argument + +rspamd_logger.info('a=%1, b=%2, c=%3, d=%4, e=%s', a, b, c, d, e) +-- Output: a=string, b=1.50000, c=1, d={[1] = aa, [2] = 1, [3] = bb} e={[key]=value, [key2]=1.0} + +-- Create string using logger API +local str = rspamd_logger.slog('a=%1, b=%2, c=%3, d=%4, e=%5', a, b, c, d, e) + +print(str) +-- Output: a=string, b=1.50000, c=1, d={[1] = aa, [2] = 1, [3] = bb} e={[key]=value, [key2]=1.0} + */ + +/* Logger methods */ +/*** + * @function logger.err(msg) + * Log message as an error + * @param {string} msg string to be logged + */ +LUA_FUNCTION_DEF(logger, err); +/*** + * @function logger.warn(msg) + * Log message as a warning + * @param {string} msg string to be logged + */ +LUA_FUNCTION_DEF(logger, warn); +/*** + * @function logger.info(msg) + * Log message as an informational message + * @param {string} msg string to be logged + */ +LUA_FUNCTION_DEF(logger, info); +/*** + * @function logger.message(msg) + * Log message as an notice message + * @param {string} msg string to be logged + */ +LUA_FUNCTION_DEF(logger, message); +/*** + * @function logger.debug(msg) + * Log message as a debug message + * @param {string} msg string to be logged + */ +LUA_FUNCTION_DEF(logger, debug); +/*** + * @function logger.errx(fmt[, args) + * Extended interface to make an error log message + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + */ +LUA_FUNCTION_DEF(logger, errx); +/*** + * @function logger.warn(fmt[, args) + * Extended interface to make a warning log message + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + */ +LUA_FUNCTION_DEF(logger, warnx); +/*** + * @function logger.infox(fmt[, args) + * Extended interface to make an informational log message + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + */ +LUA_FUNCTION_DEF(logger, infox); +/*** + * @function logger.infox(fmt[, args) + * Extended interface to make an informational log message + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + */ +LUA_FUNCTION_DEF(logger, messagex); +/*** + * @function logger.debugx(fmt[, args) + * Extended interface to make a debug log message + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + */ +LUA_FUNCTION_DEF(logger, debugx); + +/*** + * @function logger.debugm(module, id, fmt[, args) + * Extended interface to make a debug log message + * @param {string} module debug module + * @param {task|cfg|pool|string} id id to log + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + */ +LUA_FUNCTION_DEF(logger, debugm); +/*** + * @function logger.slog(fmt[, args) + * Create string replacing percent params with corresponding arguments + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + * @return {string} string with percent parameters substituted + */ +LUA_FUNCTION_DEF(logger, slog); + +/*** + * @function logger.logx(level, module, id, fmt[, args) + * Extended interface to make a generic log message on any level + * @param {number} log level as a number (see GLogLevelFlags enum for values) + * @param {task|cfg|pool|string} id id to log + * @param {string} fmt format string, arguments are encoded as %<number> + * @param {any} args list of arguments to be replaced in %<number> positions + */ +LUA_FUNCTION_DEF(logger, logx); + +/*** + * @function logger.log_level() + * Returns log level for a logger + * @return {string} current log level + */ +LUA_FUNCTION_DEF(logger, log_level); + +static const struct luaL_reg loggerlib_f[] = { + LUA_INTERFACE_DEF(logger, err), + LUA_INTERFACE_DEF(logger, warn), + LUA_INTERFACE_DEF(logger, message), + {"msg", lua_logger_message}, + LUA_INTERFACE_DEF(logger, info), + LUA_INTERFACE_DEF(logger, debug), + LUA_INTERFACE_DEF(logger, errx), + LUA_INTERFACE_DEF(logger, warnx), + LUA_INTERFACE_DEF(logger, infox), + LUA_INTERFACE_DEF(logger, messagex), + {"msgx", lua_logger_messagex}, + LUA_INTERFACE_DEF(logger, debugx), + LUA_INTERFACE_DEF(logger, debugm), + LUA_INTERFACE_DEF(logger, slog), + LUA_INTERFACE_DEF(logger, logx), + LUA_INTERFACE_DEF(logger, log_level), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static void +lua_common_log_line(GLogLevelFlags level, + lua_State *L, + const gchar *msg, + const gchar *uid, + const gchar *module, + gint stack_level) +{ + lua_Debug d; + gchar func_buf[128], *p; + + if (lua_getstack(L, stack_level, &d) == 1) { + (void) lua_getinfo(L, "Sl", &d); + if ((p = strrchr(d.short_src, '/')) == NULL) { + p = d.short_src; + } + else { + p++; + } + + if (strlen(p) > 30) { + rspamd_snprintf(func_buf, sizeof(func_buf), "%27s...:%d", p, + d.currentline); + } + else { + rspamd_snprintf(func_buf, sizeof(func_buf), "%s:%d", p, + d.currentline); + } + + rspamd_common_log_function(NULL, + level, + module, + uid, + func_buf, + "%s", + msg); + } + else { + rspamd_common_log_function(NULL, + level, + module, + uid, + G_STRFUNC, + "%s", + msg); + } +} + +/*** Logger interface ***/ +static gint +lua_logger_err(lua_State *L) +{ + return lua_logger_errx(L); +} + +static gint +lua_logger_warn(lua_State *L) +{ + return lua_logger_warnx(L); +} + +static gint +lua_logger_info(lua_State *L) +{ + return lua_logger_infox(L); +} + +static gint +lua_logger_message(lua_State *L) +{ + return lua_logger_messagex(L); +} + +static gint +lua_logger_debug(lua_State *L) +{ + return lua_logger_debugx(L); +} + +static inline bool +lua_logger_char_safe(int t, unsigned int esc_type) +{ + if (t & 0x80) { + if (esc_type & LUA_ESCAPE_8BIT) { + return false; + } + + return true; + } + + if (esc_type & LUA_ESCAPE_UNPRINTABLE) { + if (!g_ascii_isprint(t) && !g_ascii_isspace(t)) { + return false; + } + } + + if (esc_type & LUA_ESCAPE_NEWLINES) { + if (t == '\r' || t == '\n') { + return false; + } + } + + return true; +} + +static gsize +lua_logger_out_str(lua_State *L, gint pos, + gchar *outbuf, gsize len, + struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type) +{ + gsize slen, flen; + const gchar *str = lua_tolstring(L, pos, &slen); + static const gchar hexdigests[16] = "0123456789abcdef"; + gsize r = 0, s; + + if (str) { + gboolean normal = TRUE; + flen = MIN(slen, len - 1); + + for (r = 0; r < flen; r++) { + if (!lua_logger_char_safe(str[r], esc_type)) { + normal = FALSE; + break; + } + } + + if (normal) { + r = rspamd_strlcpy(outbuf, str, flen + 1); + } + else { + /* Need to escape non-printed characters */ + r = 0; + s = 0; + + while (slen > 0 && len > 1) { + if (!lua_logger_char_safe(str[s], esc_type)) { + if (len >= 3) { + outbuf[r++] = '\\'; + outbuf[r++] = hexdigests[((str[s] >> 4) & 0xF)]; + outbuf[r++] = hexdigests[((str[s]) & 0xF)]; + + len -= 2; + } + else { + outbuf[r++] = '?'; + } + } + else { + outbuf[r++] = str[s]; + } + + s++; + slen--; + len--; + } + + outbuf[r] = '\0'; + } + } + + return r; +} + +static gsize +lua_logger_out_num(lua_State *L, gint pos, gchar *outbuf, gsize len, + struct lua_logger_trace *trace) +{ + gdouble num = lua_tonumber(L, pos); + glong inum; + gsize r = 0; + + if ((gdouble) (glong) num == num) { + inum = num; + r = rspamd_snprintf(outbuf, len + 1, "%l", inum); + } + else { + r = rspamd_snprintf(outbuf, len + 1, "%f", num); + } + + return r; +} + +static gsize +lua_logger_out_boolean(lua_State *L, gint pos, gchar *outbuf, gsize len, + struct lua_logger_trace *trace) +{ + gboolean val = lua_toboolean(L, pos); + gsize r = 0; + + r = rspamd_strlcpy(outbuf, val ? "true" : "false", len + 1); + + return r; +} + +static gsize +lua_logger_out_userdata(lua_State *L, gint pos, gchar *outbuf, gsize len, + struct lua_logger_trace *trace) +{ + gint r = 0, top; + const gchar *str = NULL; + gboolean converted_to_str = FALSE; + + top = lua_gettop(L); + + if (!lua_getmetatable(L, pos)) { + return 0; + } + + lua_pushstring(L, "__index"); + lua_gettable(L, -2); + + if (!lua_istable(L, -1)) { + + if (lua_isfunction(L, -1)) { + /* Functional metatable, try to get __tostring directly */ + lua_pushstring(L, "__tostring"); + lua_gettable(L, -3); + + if (lua_isfunction(L, -1)) { + lua_pushvalue(L, pos); + + if (lua_pcall(L, 1, 1, 0) != 0) { + lua_settop(L, top); + + return 0; + } + + str = lua_tostring(L, -1); + + if (str) { + r = rspamd_snprintf(outbuf, len, "%s", str); + } + + lua_settop(L, top); + + return r; + } + } + lua_settop(L, top); + + return 0; + } + + lua_pushstring(L, "__tostring"); + lua_gettable(L, -2); + + if (lua_isfunction(L, -1)) { + lua_pushvalue(L, pos); + + if (lua_pcall(L, 1, 1, 0) != 0) { + lua_settop(L, top); + + return 0; + } + + str = lua_tostring(L, -1); + + if (str) { + converted_to_str = TRUE; + } + } + else { + lua_pop(L, 1); + lua_pushstring(L, "class"); + lua_gettable(L, -2); + + if (lua_isstring(L, -1)) { + str = lua_tostring(L, -1); + converted_to_str = TRUE; + } + } + + if (converted_to_str) { + r = rspamd_snprintf(outbuf, len, "%s", str); + } + else { + /* Print raw pointer */ + r = rspamd_snprintf(outbuf, len, "%s(%p)", str, lua_touserdata(L, pos)); + } + + lua_settop(L, top); + + return r; +} + +#define MOVE_BUF(d, remain, r) \ + (d) += (r); \ + (remain) -= (r); \ + if ((remain) == 0) { \ + lua_settop(L, old_top); \ + break; \ + } + +static gsize +lua_logger_out_table(lua_State *L, gint pos, gchar *outbuf, gsize len, + struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type) +{ + gchar *d = outbuf; + gsize remain = len, r; + gboolean first = TRUE; + gconstpointer self = NULL; + gint i, tpos, last_seq = -1, old_top; + + if (!lua_istable(L, pos) || remain == 0) { + return 0; + } + + old_top = lua_gettop(L); + self = lua_topointer(L, pos); + + /* Check if we have seen this pointer */ + for (i = 0; i < TRACE_POINTS; i++) { + if (trace->traces[i] == self) { + r = rspamd_snprintf(d, remain + 1, "ref(%p)", self); + + d += r; + + return (d - outbuf); + } + } + + trace->traces[trace->cur_level % TRACE_POINTS] = self; + + lua_pushvalue(L, pos); + r = rspamd_snprintf(d, remain + 1, "{"); + remain -= r; + d += r; + + /* Get numeric keys (ipairs) */ + for (i = 1;; i++) { + lua_rawgeti(L, -1, i); + + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + break; + } + + last_seq = i; + + if (!first) { + r = rspamd_snprintf(d, remain + 1, ", "); + MOVE_BUF(d, remain, r); + } + + r = rspamd_snprintf(d, remain + 1, "[%d] = ", i); + MOVE_BUF(d, remain, r); + tpos = lua_gettop(L); + + if (lua_topointer(L, tpos) == self) { + r = rspamd_snprintf(d, remain + 1, "__self"); + } + else { + r = lua_logger_out_type(L, tpos, d, remain, trace, esc_type); + } + MOVE_BUF(d, remain, r); + + first = FALSE; + lua_pop(L, 1); + } + + /* Get string keys (pairs) */ + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + /* 'key' is at index -2 and 'value' is at index -1 */ + if (lua_type(L, -2) == LUA_TNUMBER) { + if (last_seq > 0) { + lua_pushvalue(L, -2); + + if (lua_tonumber(L, -1) <= last_seq + 1) { + lua_pop(L, 1); + /* Already seen */ + continue; + } + + lua_pop(L, 1); + } + } + + if (!first) { + r = rspamd_snprintf(d, remain + 1, ", "); + MOVE_BUF(d, remain, r); + } + + /* Preserve key */ + lua_pushvalue(L, -2); + r = rspamd_snprintf(d, remain + 1, "[%s] = ", + lua_tostring(L, -1)); + lua_pop(L, 1); /* Remove key */ + MOVE_BUF(d, remain, r); + tpos = lua_gettop(L); + + if (lua_topointer(L, tpos) == self) { + r = rspamd_snprintf(d, remain + 1, "__self"); + } + else { + r = lua_logger_out_type(L, tpos, d, remain, trace, esc_type); + } + MOVE_BUF(d, remain, r); + + first = FALSE; + } + + lua_settop(L, old_top); + + r = rspamd_snprintf(d, remain + 1, "}"); + d += r; + + return (d - outbuf); +} + +#undef MOVE_BUF + +gsize lua_logger_out_type(lua_State *L, gint pos, + gchar *outbuf, gsize len, + struct lua_logger_trace *trace, + enum lua_logger_escape_type esc_type) +{ + gint type; + gsize r = 0; + + if (len == 0) { + return 0; + } + + type = lua_type(L, pos); + trace->cur_level++; + + switch (type) { + case LUA_TNUMBER: + r = lua_logger_out_num(L, pos, outbuf, len, trace); + break; + case LUA_TBOOLEAN: + r = lua_logger_out_boolean(L, pos, outbuf, len, trace); + break; + case LUA_TTABLE: + r = lua_logger_out_table(L, pos, outbuf, len, trace, esc_type); + break; + case LUA_TUSERDATA: + r = lua_logger_out_userdata(L, pos, outbuf, len, trace); + break; + case LUA_TFUNCTION: + r = rspamd_snprintf(outbuf, len + 1, "function"); + break; + case LUA_TLIGHTUSERDATA: + r = rspamd_snprintf(outbuf, len + 1, "0x%p", lua_topointer(L, pos)); + break; + case LUA_TNIL: + r = rspamd_snprintf(outbuf, len + 1, "nil"); + break; + case LUA_TNONE: + r = rspamd_snprintf(outbuf, len + 1, "no value"); + break; + default: + /* Try to push everything as string using tostring magic */ + r = lua_logger_out_str(L, pos, outbuf, len, trace, esc_type); + break; + } + + trace->cur_level--; + + return r; +} + +static const gchar * +lua_logger_get_id(lua_State *L, gint pos, GError **err) +{ + const gchar *uid = NULL, *clsname; + + if (lua_getmetatable(L, pos) != 0) { + uid = ""; + lua_pushstring(L, "__index"); + lua_gettable(L, -2); + + lua_pushstring(L, "class"); + lua_gettable(L, -2); + + clsname = lua_tostring(L, -1); + + if (strcmp(clsname, "rspamd{task}") == 0) { + struct rspamd_task *task = lua_check_task(L, pos); + + if (task) { + uid = task->task_pool->tag.uid; + } + else { + g_set_error(err, g_quark_from_static_string("lua_logger"), + EINVAL, "invalid rspamd{task}"); + } + } + else if (strcmp(clsname, "rspamd{mempool}") == 0) { + rspamd_mempool_t *pool; + + pool = rspamd_lua_check_mempool(L, pos); + + if (pool) { + uid = pool->tag.uid; + } + else { + g_set_error(err, g_quark_from_static_string("lua_logger"), + EINVAL, "invalid rspamd{mempool}"); + } + } + else if (strcmp(clsname, "rspamd{config}") == 0) { + struct rspamd_config *cfg; + + cfg = lua_check_config(L, pos); + + if (cfg) { + if (cfg->checksum) { + uid = cfg->checksum; + } + } + else { + g_set_error(err, g_quark_from_static_string("lua_logger"), + EINVAL, "invalid rspamd{config}"); + } + } + else if (strcmp(clsname, "rspamd{map}") == 0) { + struct rspamd_lua_map *map; + + map = lua_check_map(L, pos); + + if (map) { + if (map->map) { + uid = map->map->tag; + } + else { + uid = "embedded"; + } + } + else { + g_set_error(err, g_quark_from_static_string("lua_logger"), + EINVAL, "invalid rspamd{map}"); + } + } + else { + g_set_error(err, g_quark_from_static_string("lua_logger"), + EINVAL, "unknown class: %s", clsname); + } + + + /* Metatable, __index, classname */ + lua_pop(L, 3); + } + else { + g_set_error(err, g_quark_from_static_string("lua_logger"), + EINVAL, "no metatable found for userdata"); + } + + return uid; +} + +static gboolean +lua_logger_log_format(lua_State *L, gint fmt_pos, gboolean is_string, + gchar *logbuf, gsize remain) +{ + gchar *d; + const gchar *s, *c; + gsize r, cpylen = 0; + guint arg_num = 0, cur_arg; + bool num_arg = false; + struct lua_logger_trace tr; + enum { + copy_char = 0, + got_percent, + parse_arg_num + } state = copy_char; + + d = logbuf; + s = lua_tostring(L, fmt_pos); + c = s; + cur_arg = fmt_pos; + + if (s == NULL) { + return FALSE; + } + + while (remain > 0 && *s != '\0') { + switch (state) { + case copy_char: + if (*s == '%') { + state = got_percent; + s++; + if (cpylen > 0) { + memcpy(d, c, cpylen); + d += cpylen; + } + cpylen = 0; + } + else { + s++; + cpylen++; + remain--; + } + break; + case got_percent: + if (g_ascii_isdigit(*s) || *s == 's') { + state = parse_arg_num; + c = s; + } + else { + *d++ = *s++; + c = s; + state = copy_char; + } + break; + case parse_arg_num: + if (g_ascii_isdigit(*s)) { + s++; + num_arg = true; + } + else { + if (num_arg) { + arg_num = strtoul(c, NULL, 10); + arg_num += fmt_pos - 1; + /* Update the current argument */ + cur_arg = arg_num; + } + else { + /* We have non numeric argument, e.g. %s */ + arg_num = cur_arg++; + s++; + } + + if (arg_num < 1 || arg_num > (guint) lua_gettop(L) + 1) { + msg_err("wrong argument number: %ud", arg_num); + + return FALSE; + } + + memset(&tr, 0, sizeof(tr)); + r = lua_logger_out_type(L, arg_num + 1, d, remain, &tr, + is_string ? LUA_ESCAPE_UNPRINTABLE : LUA_ESCAPE_LOG); + g_assert(r <= remain); + remain -= r; + d += r; + state = copy_char; + c = s; + } + break; + } + } + + if (state == parse_arg_num) { + if (num_arg) { + arg_num = strtoul(c, NULL, 10); + arg_num += fmt_pos - 1; + } + else { + /* We have non numeric argument, e.g. %s */ + arg_num = cur_arg; + } + + if (arg_num < 1 || arg_num > (guint) lua_gettop(L) + 1) { + msg_err("wrong argument number: %ud", arg_num); + + return FALSE; + } + + memset(&tr, 0, sizeof(tr)); + r = lua_logger_out_type(L, arg_num + 1, d, remain, &tr, + is_string ? LUA_ESCAPE_UNPRINTABLE : LUA_ESCAPE_LOG); + g_assert(r <= remain); + remain -= r; + d += r; + } + else if (state == copy_char) { + if (cpylen > 0 && remain > 0) { + memcpy(d, c, cpylen); + d += cpylen; + } + } + + *d = '\0'; + + + return TRUE; +} + +static gint +lua_logger_do_log(lua_State *L, + GLogLevelFlags level, + gboolean is_string, + gint start_pos) +{ + gchar logbuf[RSPAMD_LOGBUF_SIZE - 128]; + const gchar *uid = NULL; + gint fmt_pos = start_pos; + gint ret; + GError *err = NULL; + + if (lua_type(L, start_pos) == LUA_TSTRING) { + fmt_pos = start_pos; + } + else if (lua_type(L, start_pos) == LUA_TUSERDATA) { + fmt_pos = start_pos + 1; + + uid = lua_logger_get_id(L, start_pos, &err); + + if (uid == NULL) { + ret = luaL_error(L, "bad userdata for logging: %s", + err ? err->message : "unknown error"); + + if (err) { + g_error_free(err); + } + + return ret; + } + } + else { + /* Bad argument type */ + return luaL_error(L, "bad format string type: %s", + lua_typename(L, lua_type(L, start_pos))); + } + + ret = lua_logger_log_format(L, fmt_pos, is_string, + logbuf, sizeof(logbuf) - 1); + + if (ret) { + if (is_string) { + lua_pushstring(L, logbuf); + return 1; + } + else { + lua_common_log_line(level, L, logbuf, uid, "lua", 1); + } + } + else { + if (is_string) { + lua_pushnil(L); + + return 1; + } + } + + return 0; +} + +static gint +lua_logger_errx(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_logger_do_log(L, G_LOG_LEVEL_CRITICAL, FALSE, 1); +} + +static gint +lua_logger_warnx(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_logger_do_log(L, G_LOG_LEVEL_WARNING, FALSE, 1); +} + +static gint +lua_logger_infox(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_logger_do_log(L, G_LOG_LEVEL_INFO, FALSE, 1); +} + +static gint +lua_logger_messagex(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_logger_do_log(L, G_LOG_LEVEL_MESSAGE, FALSE, 1); +} + +static gint +lua_logger_debugx(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_logger_do_log(L, G_LOG_LEVEL_DEBUG, FALSE, 1); +} + +static gint +lua_logger_logx(lua_State *L) +{ + LUA_TRACE_POINT; + GLogLevelFlags flags = lua_tonumber(L, 1); + const gchar *modname = lua_tostring(L, 2), *uid = NULL; + gchar logbuf[RSPAMD_LOGBUF_SIZE - 128]; + gboolean ret; + gint stack_pos = 1; + + if (lua_type(L, 3) == LUA_TSTRING) { + uid = luaL_checkstring(L, 3); + } + else if (lua_type(L, 3) == LUA_TUSERDATA) { + uid = lua_logger_get_id(L, 3, NULL); + } + else { + uid = "???"; + } + + if (uid && modname) { + if (lua_type(L, 4) == LUA_TSTRING) { + ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf) - 1); + } + else if (lua_type(L, 4) == LUA_TNUMBER) { + stack_pos = lua_tonumber(L, 4); + ret = lua_logger_log_format(L, 5, FALSE, logbuf, sizeof(logbuf) - 1); + } + else { + return luaL_error(L, "invalid argument on pos 4"); + } + + if (ret) { + lua_common_log_line(flags, L, logbuf, uid, modname, stack_pos); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + + +static gint +lua_logger_debugm(lua_State *L) +{ + LUA_TRACE_POINT; + gchar logbuf[RSPAMD_LOGBUF_SIZE - 128]; + const gchar *uid = NULL, *module = NULL; + gint stack_pos = 1; + gboolean ret; + + module = luaL_checkstring(L, 1); + + if (lua_type(L, 2) == LUA_TSTRING) { + uid = luaL_checkstring(L, 2); + } + else { + uid = lua_logger_get_id(L, 2, NULL); + } + + if (uid && module) { + if (lua_type(L, 3) == LUA_TSTRING) { + ret = lua_logger_log_format(L, 3, FALSE, logbuf, sizeof(logbuf) - 1); + } + else if (lua_type(L, 3) == LUA_TNUMBER) { + stack_pos = lua_tonumber(L, 3); + ret = lua_logger_log_format(L, 4, FALSE, logbuf, sizeof(logbuf) - 1); + } + else { + return luaL_error(L, "invalid argument on pos 3"); + } + + if (ret) { + lua_common_log_line(G_LOG_LEVEL_DEBUG, L, logbuf, uid, module, stack_pos); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + + +static gint +lua_logger_slog(lua_State *L) +{ + return lua_logger_do_log(L, 0, TRUE, 1); +} + +static gint +lua_logger_log_level(lua_State *L) +{ + gint log_level = rspamd_log_get_log_level(NULL); + + lua_pushstring(L, rspamd_get_log_severity_string(log_level)); + + return 1; +} + +/*** Init functions ***/ + +static gint +lua_load_logger(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, loggerlib_f); + + return 1; +} + +void luaopen_logger(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_logger", lua_load_logger); +} diff --git a/src/lua/lua_map.c b/src/lua/lua_map.c new file mode 100644 index 0000000..54cfb4b --- /dev/null +++ b/src/lua/lua_map.c @@ -0,0 +1,1421 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "libserver/maps/map.h" +#include "libserver/maps/map_helpers.h" +#include "libserver/maps/map_private.h" +#include "contrib/libucl/lua_ucl.h" + +/*** + * This module is used to manage rspamd maps and map like objects + * + * @module rspamd_map + * + * All maps could be obtained by function `rspamd_config:get_maps()` + * Also see [`lua_maps` module description](lua_maps.html). + * + * **Important notice** maps cannot be queried outside of the worker context. + * For example, you cannot add even a file map and query some keys from it during + * some module initialisation, you need to add the appropriate event loop context + * for a worker (e.g. you cannot use `get_key` outside of the symbols callbacks or + * a worker `on_load` scripts). + * +@example + +local hash_map = rspamd_config:add_map{ + type = "hash", + urls = ['file:///path/to/file'], + description = 'sample map' +} + +local function sample_symbol_cb(task) + -- Check whether hash map contains from address of message + if hash_map:get_key((task:get_from() or {})[1]) then + -- key found + end +end + +rspamd_config:register_symbol{ + name = 'SAMPLE_SYMBOL', + type = 'normal', + score = 1.0, + description = "A sample symbol", + callback = sample_symbol_cb, +} + */ + +/*** + * @method map:get_key(in) + * Variable method for different types of maps: + * + * - For hash maps it returns boolean and accepts string + * - For kv maps it returns string (or nil) and accepts string + * - For radix maps it returns boolean and accepts IP address (as object, string or number) + * + * @param {vary} in input to check + * @return {bool|string} if a value is found then this function returns string or `True` if not - then it returns `nil` or `False` + */ +LUA_FUNCTION_DEF(map, get_key); + + +/*** + * @method map:is_signed() + * Returns `True` if a map is signed + * @return {bool} signed value + */ +LUA_FUNCTION_DEF(map, is_signed); + +/*** + * @method map:get_proto() + * Returns protocol of map as string: + * + * - `http`: for HTTP map + * - `file`: for file map + * @return {string} string representation of the map protocol + */ +LUA_FUNCTION_DEF(map, get_proto); + +/*** + * @method map:get_sign_key() + * Returns pubkey used for signing as base32 string or nil + * @return {string} base32 encoded string or nil + */ +LUA_FUNCTION_DEF(map, get_sign_key); + +/*** + * @method map:set_sign_key(key) + * Set trusted key for signatures for this map + * @param {string} key base32 encoded string or nil + */ +LUA_FUNCTION_DEF(map, set_sign_key); + +/*** + * @method map:set_callback(cb) + * Set callback for a specified callback map. + * @param {function} cb map callback function + */ +LUA_FUNCTION_DEF(map, set_callback); + +/*** + * @method map:get_uri() + * Get uri for a specified map + * @return {string} map's URI + */ +LUA_FUNCTION_DEF(map, get_uri); + +/*** + * @method map:get_stats(reset) + * Get statistics for specific map. It returns table in form: + * [key] => [nhits] + * @param {boolean} reset reset stats if true + * @return {table} map's stat + */ +LUA_FUNCTION_DEF(map, get_stats); + +/*** + * @method map:foreach(callback, is_text) + * Iterate over map elements and call callback for each element. + * @param {function} callback callback function, that accepts two arguments: key and value, if it returns true then iteration is stopped + * @param {boolean} is_text if true then callback accepts rspamd_text instead of Lua strings + * @return {number} number of elements iterated + */ +LUA_FUNCTION_DEF(map, foreach); + +/*** + * @method map:on_load(callback) + * Sets a callback for a map that is called when map is loaded + * @param {function} callback callback function, that accepts no arguments (pass maps in a closure if needed) + */ +LUA_FUNCTION_DEF(map, on_load); + +/*** + * @method map:get_data_digest() + * Get data digest for specific map + * @return {string} 64 bit number represented as string (due to Lua limitations) + */ +LUA_FUNCTION_DEF(map, get_data_digest); + +/*** + * @method map:get_nelts() + * Get number of elements for specific map + * @return {number} number of elements in the map + */ +LUA_FUNCTION_DEF(map, get_nelts); + +static const struct luaL_reg maplib_m[] = { + LUA_INTERFACE_DEF(map, get_key), + LUA_INTERFACE_DEF(map, is_signed), + LUA_INTERFACE_DEF(map, get_proto), + LUA_INTERFACE_DEF(map, get_sign_key), + LUA_INTERFACE_DEF(map, set_sign_key), + LUA_INTERFACE_DEF(map, set_callback), + LUA_INTERFACE_DEF(map, get_uri), + LUA_INTERFACE_DEF(map, get_stats), + LUA_INTERFACE_DEF(map, foreach), + LUA_INTERFACE_DEF(map, on_load), + LUA_INTERFACE_DEF(map, get_data_digest), + LUA_INTERFACE_DEF(map, get_nelts), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +struct lua_map_callback_data { + lua_State *L; + gint ref; + gboolean opaque; + rspamd_fstring_t *data; + struct rspamd_lua_map *lua_map; +}; + +struct rspamd_lua_map * +lua_check_map(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{map}"); + luaL_argcheck(L, ud != NULL, pos, "'map' expected"); + return ud ? *((struct rspamd_lua_map **) ud) : NULL; +} + +gint lua_config_add_radix_map(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *map_line, *description; + struct rspamd_lua_map *map, **pmap; + struct rspamd_map *m; + + if (cfg) { + map_line = luaL_checkstring(L, 2); + description = lua_tostring(L, 3); + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.radix = NULL; + map->type = RSPAMD_LUA_MAP_RADIX; + + if ((m = rspamd_map_add(cfg, map_line, description, + rspamd_radix_read, + rspamd_radix_fin, + rspamd_radix_dtor, + (void **) &map->data.radix, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + msg_warn_config("invalid radix map %s", map_line); + lua_pushnil(L); + + return 1; + } + + map->map = m; + m->lua_map = map; + pmap = lua_newuserdata(L, sizeof(void *)); + *pmap = map; + rspamd_lua_setclass(L, "rspamd{map}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +gint lua_config_radix_from_config(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *mname, *optname; + const ucl_object_t *obj; + struct rspamd_lua_map *map, **pmap; + ucl_object_t *fake_obj; + struct rspamd_map *m; + + if (!cfg) { + return luaL_error(L, "invalid arguments"); + } + + mname = luaL_checkstring(L, 2); + optname = luaL_checkstring(L, 3); + + if (mname && optname) { + obj = rspamd_config_get_module_opt(cfg, mname, optname); + + if (obj) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.radix = NULL; + map->type = RSPAMD_LUA_MAP_RADIX; + + fake_obj = ucl_object_typed_new(UCL_OBJECT); + ucl_object_insert_key(fake_obj, ucl_object_ref(obj), + "data", 0, false); + ucl_object_insert_key(fake_obj, ucl_object_fromstring("static"), + "url", 0, false); + + if ((m = rspamd_map_add_from_ucl(cfg, fake_obj, "static radix map", + rspamd_radix_read, + rspamd_radix_fin, + rspamd_radix_dtor, + (void **) &map->data.radix, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + msg_err_config("invalid radix map static"); + lua_pushnil(L); + ucl_object_unref(fake_obj); + + return 1; + } + + ucl_object_unref(fake_obj); + pmap = lua_newuserdata(L, sizeof(void *)); + map->map = m; + m->lua_map = map; + *pmap = map; + rspamd_lua_setclass(L, "rspamd{map}", -1); + } + else { + msg_warn_config("Couldnt find config option [%s][%s]", mname, + optname); + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +gint lua_config_radix_from_ucl(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + ucl_object_t *obj; + struct rspamd_lua_map *map, **pmap; + ucl_object_t *fake_obj; + struct rspamd_map *m; + + if (!cfg) { + return luaL_error(L, "invalid arguments"); + } + + obj = ucl_object_lua_import(L, 2); + + if (obj) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.radix = NULL; + map->type = RSPAMD_LUA_MAP_RADIX; + + fake_obj = ucl_object_typed_new(UCL_OBJECT); + ucl_object_insert_key(fake_obj, ucl_object_ref(obj), + "data", 0, false); + ucl_object_insert_key(fake_obj, ucl_object_fromstring("static"), + "url", 0, false); + + if ((m = rspamd_map_add_from_ucl(cfg, fake_obj, "static radix map", + rspamd_radix_read, + rspamd_radix_fin, + rspamd_radix_dtor, + (void **) &map->data.radix, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + msg_err_config("invalid radix map static"); + lua_pushnil(L); + ucl_object_unref(fake_obj); + ucl_object_unref(obj); + + return 1; + } + + ucl_object_unref(fake_obj); + ucl_object_unref(obj); + pmap = lua_newuserdata(L, sizeof(void *)); + map->map = m; + m->lua_map = map; + *pmap = map; + rspamd_lua_setclass(L, "rspamd{map}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +gint lua_config_add_hash_map(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *map_line, *description; + struct rspamd_lua_map *map, **pmap; + struct rspamd_map *m; + + if (cfg) { + map_line = luaL_checkstring(L, 2); + description = lua_tostring(L, 3); + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.hash = NULL; + map->type = RSPAMD_LUA_MAP_SET; + + if ((m = rspamd_map_add(cfg, map_line, description, + rspamd_kv_list_read, + rspamd_kv_list_fin, + rspamd_kv_list_dtor, + (void **) &map->data.hash, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + msg_warn_config("invalid set map %s", map_line); + lua_pushnil(L); + return 1; + } + + map->map = m; + m->lua_map = map; + pmap = lua_newuserdata(L, sizeof(void *)); + *pmap = map; + rspamd_lua_setclass(L, "rspamd{map}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +gint lua_config_add_kv_map(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *map_line, *description; + struct rspamd_lua_map *map, **pmap; + struct rspamd_map *m; + + if (cfg) { + map_line = luaL_checkstring(L, 2); + description = lua_tostring(L, 3); + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.hash = NULL; + map->type = RSPAMD_LUA_MAP_HASH; + + if ((m = rspamd_map_add(cfg, map_line, description, + rspamd_kv_list_read, + rspamd_kv_list_fin, + rspamd_kv_list_dtor, + (void **) &map->data.hash, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + msg_warn_config("invalid hash map %s", map_line); + lua_pushnil(L); + + return 1; + } + + map->map = m; + m->lua_map = map; + pmap = lua_newuserdata(L, sizeof(void *)); + *pmap = map; + rspamd_lua_setclass(L, "rspamd{map}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +static gchar * +lua_map_read(gchar *chunk, gint len, + struct map_cb_data *data, + gboolean final) +{ + struct lua_map_callback_data *cbdata, *old; + + if (data->cur_data == NULL) { + old = (struct lua_map_callback_data *) data->prev_data; + cbdata = old; + cbdata->L = old->L; + cbdata->ref = old->ref; + cbdata->lua_map = old->lua_map; + data->cur_data = cbdata; + data->prev_data = NULL; + } + else { + cbdata = (struct lua_map_callback_data *) data->cur_data; + } + + if (cbdata->data == NULL) { + cbdata->data = rspamd_fstring_new_init(chunk, len); + } + else { + cbdata->data = rspamd_fstring_append(cbdata->data, chunk, len); + } + + return NULL; +} + +static void +lua_map_fin(struct map_cb_data *data, void **target) +{ + struct lua_map_callback_data *cbdata; + struct rspamd_lua_map **pmap; + struct rspamd_map *map; + + map = data->map; + + if (data->errored) { + if (data->cur_data) { + cbdata = (struct lua_map_callback_data *) data->cur_data; + if (cbdata->ref != -1) { + luaL_unref(cbdata->L, LUA_REGISTRYINDEX, cbdata->ref); + } + + if (cbdata->data) { + rspamd_fstring_free(cbdata->data); + } + + data->cur_data = NULL; + } + } + else { + if (data->cur_data) { + cbdata = (struct lua_map_callback_data *) data->cur_data; + } + else { + msg_err_map("no data read for map"); + return; + } + + if (cbdata->ref == -1) { + msg_err_map("map has no callback set"); + } + else if (cbdata->data != NULL && cbdata->data->len != 0) { + + lua_pushcfunction(cbdata->L, &rspamd_lua_traceback); + int err_idx = lua_gettop(cbdata->L); + + lua_rawgeti(cbdata->L, LUA_REGISTRYINDEX, cbdata->ref); + + if (!cbdata->opaque) { + lua_pushlstring(cbdata->L, cbdata->data->str, cbdata->data->len); + } + else { + struct rspamd_lua_text *t; + + t = lua_newuserdata(cbdata->L, sizeof(*t)); + rspamd_lua_setclass(cbdata->L, "rspamd{text}", -1); + t->flags = 0; + t->len = cbdata->data->len; + t->start = cbdata->data->str; + } + + pmap = lua_newuserdata(cbdata->L, sizeof(void *)); + *pmap = cbdata->lua_map; + rspamd_lua_setclass(cbdata->L, "rspamd{map}", -1); + + gint ret = lua_pcall(cbdata->L, 2, 0, err_idx); + + if (ret != 0) { + msg_info_map("call to %s failed (%d): %s", "map fin function", + ret, + lua_tostring(cbdata->L, -1)); + } + + lua_settop(cbdata->L, err_idx - 1); + } + + cbdata->data = rspamd_fstring_assign(cbdata->data, "", 0); + + if (target) { + *target = data->cur_data; + } + + if (data->prev_data) { + data->prev_data = NULL; + } + } +} + +static void +lua_map_dtor(struct map_cb_data *data) +{ + struct lua_map_callback_data *cbdata; + + if (data->cur_data) { + cbdata = (struct lua_map_callback_data *) data->cur_data; + if (cbdata->ref != -1) { + luaL_unref(cbdata->L, LUA_REGISTRYINDEX, cbdata->ref); + } + + if (cbdata->data) { + rspamd_fstring_free(cbdata->data); + } + } +} + +gint lua_config_add_map(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const char *description = NULL; + const gchar *type = NULL; + ucl_object_t *map_obj = NULL; + struct lua_map_callback_data *cbdata; + struct rspamd_lua_map *map, **pmap; + struct rspamd_map *m; + gboolean opaque_data = FALSE; + int cbidx = -1, ret; + GError *err = NULL; + + if (cfg) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*url=O;description=S;callback=F;type=S;opaque_data=B", + &map_obj, &description, &cbidx, &type, &opaque_data)) { + ret = luaL_error(L, "invalid table arguments: %s", err->message); + g_error_free(err); + if (map_obj) { + ucl_object_unref(map_obj); + } + + return ret; + } + + g_assert(map_obj != NULL); + + if (type == NULL && cbidx != -1) { + type = "callback"; + } + else if (type == NULL) { + return luaL_error(L, "invalid map type"); + } + + if (strcmp(type, "callback") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->type = RSPAMD_LUA_MAP_CALLBACK; + map->data.cbdata = rspamd_mempool_alloc0(cfg->cfg_pool, + sizeof(*map->data.cbdata)); + cbdata = map->data.cbdata; + cbdata->L = L; + cbdata->data = NULL; + cbdata->lua_map = map; + cbdata->ref = cbidx; + cbdata->opaque = opaque_data; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + lua_map_read, + lua_map_fin, + lua_map_dtor, + (void **) &map->data.cbdata, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + + if (cbidx != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbidx); + } + + if (map_obj) { + ucl_object_unref(map_obj); + } + + lua_pushnil(L); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "set") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.hash = NULL; + map->type = RSPAMD_LUA_MAP_SET; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_kv_list_read, + rspamd_kv_list_fin, + rspamd_kv_list_dtor, + (void **) &map->data.hash, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "map") == 0 || strcmp(type, "hash") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.hash = NULL; + map->type = RSPAMD_LUA_MAP_HASH; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_kv_list_read, + rspamd_kv_list_fin, + rspamd_kv_list_dtor, + (void **) &map->data.hash, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "radix") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.radix = NULL; + map->type = RSPAMD_LUA_MAP_RADIX; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_radix_read, + rspamd_radix_fin, + rspamd_radix_dtor, + (void **) &map->data.radix, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "regexp") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.re_map = NULL; + map->type = RSPAMD_LUA_MAP_REGEXP; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_regexp_list_read_single, + rspamd_regexp_list_fin, + rspamd_regexp_list_dtor, + (void **) &map->data.re_map, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "regexp_multi") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.re_map = NULL; + map->type = RSPAMD_LUA_MAP_REGEXP_MULTIPLE; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_regexp_list_read_multiple, + rspamd_regexp_list_fin, + rspamd_regexp_list_dtor, + (void **) &map->data.re_map, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "glob") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.re_map = NULL; + map->type = RSPAMD_LUA_MAP_REGEXP; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_glob_list_read_single, + rspamd_regexp_list_fin, + rspamd_regexp_list_dtor, + (void **) &map->data.re_map, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "glob_multi") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.re_map = NULL; + map->type = RSPAMD_LUA_MAP_REGEXP_MULTIPLE; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_glob_list_read_multiple, + rspamd_regexp_list_fin, + rspamd_regexp_list_dtor, + (void **) &map->data.re_map, + NULL, RSPAMD_MAP_DEFAULT)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else if (strcmp(type, "cdb") == 0) { + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + map->data.cdb_map = NULL; + map->type = RSPAMD_LUA_MAP_CDB; + + if ((m = rspamd_map_add_from_ucl(cfg, map_obj, description, + rspamd_cdb_list_read, + rspamd_cdb_list_fin, + rspamd_cdb_list_dtor, + (void **) &map->data.cdb_map, + NULL, RSPAMD_MAP_FILE_ONLY | RSPAMD_MAP_FILE_NO_READ)) == NULL) { + lua_pushnil(L); + ucl_object_unref(map_obj); + + return 1; + } + m->lua_map = map; + } + else { + ret = luaL_error(L, "invalid arguments: unknown type '%s'", type); + ucl_object_unref(map_obj); + + return ret; + } + + map->map = m; + pmap = lua_newuserdata(L, sizeof(void *)); + *pmap = map; + rspamd_lua_setclass(L, "rspamd{map}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + ucl_object_unref(map_obj); + + return 1; +} + +gint lua_config_get_maps(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + struct rspamd_lua_map *map, **pmap; + struct rspamd_map *m; + gint i = 1; + GList *cur; + + if (cfg) { + lua_newtable(L); + cur = g_list_first(cfg->maps); + + while (cur) { + m = cur->data; + + if (m->lua_map) { + map = m->lua_map; + } + else { + /* Implement heuristic */ + map = rspamd_mempool_alloc0(cfg->cfg_pool, sizeof(*map)); + + if (m->read_callback == rspamd_radix_read) { + map->type = RSPAMD_LUA_MAP_RADIX; + map->data.radix = *m->user_data; + } + else if (m->read_callback == rspamd_kv_list_read) { + map->type = RSPAMD_LUA_MAP_HASH; + map->data.hash = *m->user_data; + } + else { + map->type = RSPAMD_LUA_MAP_UNKNOWN; + } + + map->map = m; + m->lua_map = map; + } + + pmap = lua_newuserdata(L, sizeof(*pmap)); + *pmap = map; + rspamd_lua_setclass(L, "rspamd{map}", -1); + lua_rawseti(L, -2, i); + + cur = g_list_next(cur); + i++; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static const gchar * +lua_map_process_string_key(lua_State *L, gint pos, gsize *len) +{ + struct rspamd_lua_text *t; + + if (lua_type(L, pos) == LUA_TSTRING) { + return lua_tolstring(L, pos, len); + } + else if (lua_type(L, pos) == LUA_TUSERDATA) { + t = lua_check_text(L, pos); + + if (t) { + *len = t->len; + return t->start; + } + } + + return NULL; +} + +/* Radix and hash table functions */ +static gint +lua_map_get_key(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + struct rspamd_radix_map_helper *radix; + struct rspamd_lua_ip *addr = NULL; + const gchar *key, *value = NULL; + gpointer ud; + gsize len; + guint32 key_num = 0; + gboolean ret = FALSE; + + if (map) { + if (map->type == RSPAMD_LUA_MAP_RADIX) { + radix = map->data.radix; + + if (lua_type(L, 2) == LUA_TSTRING) { + const gchar *addr_str; + + addr_str = luaL_checklstring(L, 2, &len); + addr = g_alloca(sizeof(*addr)); + addr->addr = g_alloca(rspamd_inet_address_storage_size()); + + if (!rspamd_parse_inet_address_ip(addr_str, len, addr->addr)) { + addr = NULL; + } + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + ud = rspamd_lua_check_udata(L, 2, "rspamd{ip}"); + if (ud != NULL) { + addr = *((struct rspamd_lua_ip **) ud); + + if (addr->addr == NULL) { + addr = NULL; + } + } + else { + msg_err("invalid userdata type provided, rspamd{ip} expected"); + } + } + else if (lua_type(L, 2) == LUA_TNUMBER) { + key_num = luaL_checkinteger(L, 2); + key_num = htonl(key_num); + } + + if (radix) { + gconstpointer p = NULL; + + if (addr != NULL) { + if ((p = rspamd_match_radix_map_addr(radix, addr->addr)) != NULL) { + ret = TRUE; + } + else { + p = 0; + } + } + else if (key_num != 0) { + if ((p = rspamd_match_radix_map(radix, + (guint8 *) &key_num, sizeof(key_num))) != NULL) { + ret = TRUE; + } + else { + p = 0; + } + } + + value = (const char *) p; + } + + if (ret) { + lua_pushstring(L, value); + return 1; + } + } + else if (map->type == RSPAMD_LUA_MAP_SET) { + key = lua_map_process_string_key(L, 2, &len); + + if (key && map->data.hash) { + ret = rspamd_match_hash_map(map->data.hash, key, len) != NULL; + } + } + else if (map->type == RSPAMD_LUA_MAP_REGEXP) { + key = lua_map_process_string_key(L, 2, &len); + + if (key && map->data.re_map) { + value = rspamd_match_regexp_map_single(map->data.re_map, key, + len); + + if (value) { + lua_pushstring(L, value); + return 1; + } + } + } + else if (map->type == RSPAMD_LUA_MAP_REGEXP_MULTIPLE) { + GPtrArray *ar; + guint i; + const gchar *val; + + key = lua_map_process_string_key(L, 2, &len); + + if (key && map->data.re_map) { + ar = rspamd_match_regexp_map_all(map->data.re_map, key, + len); + + if (ar) { + lua_createtable(L, ar->len, 0); + + PTR_ARRAY_FOREACH(ar, i, val) + { + lua_pushstring(L, val); + lua_rawseti(L, -2, i + 1); + } + + g_ptr_array_free(ar, TRUE); + + return 1; + } + } + } + else if (map->type == RSPAMD_LUA_MAP_HASH) { + /* key-value map */ + key = lua_map_process_string_key(L, 2, &len); + + if (key && map->data.hash) { + value = rspamd_match_hash_map(map->data.hash, key, len); + } + + if (value) { + lua_pushstring(L, value); + return 1; + } + } + else if (map->type == RSPAMD_LUA_MAP_CDB) { + /* cdb map */ + const rspamd_ftok_t *tok = NULL; + + key = lua_map_process_string_key(L, 2, &len); + + if (key && map->data.cdb_map) { + tok = rspamd_match_cdb_map(map->data.cdb_map, key, len); + } + + if (tok) { + lua_pushlstring(L, tok->begin, tok->len); + return 1; + } + } + else { + /* callback map or unknown type map */ + lua_pushnil(L); + return 1; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, ret); + return 1; +} + +static gboolean +lua_map_traverse_cb(gconstpointer key, + gconstpointer value, gsize hits, gpointer ud) +{ + lua_State *L = (lua_State *) ud; + + lua_pushstring(L, key); + lua_pushinteger(L, hits); + lua_settable(L, -3); + + return TRUE; +} + +static gint +lua_map_get_stats(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + gboolean do_reset = FALSE; + + if (map != NULL) { + if (lua_isboolean(L, 2)) { + do_reset = lua_toboolean(L, 2); + } + + lua_createtable(L, 0, map->map->nelts); + + if (map->map->traverse_function) { + rspamd_map_traverse(map->map, lua_map_traverse_cb, L, do_reset); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct lua_map_traverse_cbdata { + lua_State *L; + gint cbref; + gboolean use_text; +}; + +static gboolean +lua_map_foreach_cb(gconstpointer key, gconstpointer value, gsize _hits, gpointer ud) +{ + struct lua_map_traverse_cbdata *cbdata = ud; + lua_State *L = cbdata->L; + + lua_pushvalue(L, cbdata->cbref); + + if (cbdata->use_text) { + lua_new_text(L, key, strlen(key), 0); + lua_new_text(L, value, strlen(value), 0); + } + else { + lua_pushstring(L, key); + lua_pushstring(L, value); + } + + if (lua_pcall(L, 2, 1, 0) != 0) { + msg_err("call to map foreach callback failed: %s", lua_tostring(L, -1)); + lua_pop(L, 1); /* error */ + + return FALSE; + } + else { + if (lua_isboolean(L, -1)) { + lua_pop(L, 2); + + return lua_toboolean(L, -1); + } + + lua_pop(L, 1); /* result */ + } + + return TRUE; +} + +static gint +lua_map_foreach(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + gboolean use_text = FALSE; + + if (map != NULL && lua_isfunction(L, 2)) { + if (lua_isboolean(L, 3)) { + use_text = lua_toboolean(L, 3); + } + + struct lua_map_traverse_cbdata cbdata; + cbdata.L = L; + lua_pushvalue(L, 2); /* func */ + cbdata.cbref = lua_gettop(L); + cbdata.use_text = use_text; + + if (map->map->traverse_function) { + rspamd_map_traverse(map->map, lua_map_foreach_cb, &cbdata, FALSE); + } + + /* Remove callback */ + lua_pop(L, 1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_map_get_data_digest(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + gchar numbuf[64]; + + if (map != NULL) { + rspamd_snprintf(numbuf, sizeof(numbuf), "%uL", map->map->digest); + lua_pushstring(L, numbuf); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_map_get_nelts(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + + if (map != NULL) { + lua_pushinteger(L, map->map->nelts); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static int +lua_map_is_signed(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + gboolean ret = FALSE; + struct rspamd_map_backend *bk; + guint i; + + if (map != NULL) { + if (map->map) { + for (i = 0; i < map->map->backends->len; i++) { + bk = g_ptr_array_index(map->map->backends, i); + if (bk->is_signed && bk->protocol == MAP_PROTO_FILE) { + ret = TRUE; + break; + } + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, ret); + return 1; +} + +static int +lua_map_get_proto(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + const gchar *ret = "undefined"; + struct rspamd_map_backend *bk; + guint i; + + if (map != NULL) { + for (i = 0; i < map->map->backends->len; i++) { + bk = g_ptr_array_index(map->map->backends, i); + switch (bk->protocol) { + case MAP_PROTO_FILE: + ret = "file"; + break; + case MAP_PROTO_HTTP: + ret = "http"; + break; + case MAP_PROTO_HTTPS: + ret = "https"; + break; + case MAP_PROTO_STATIC: + ret = "static"; + break; + } + lua_pushstring(L, ret); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return map->map->backends->len; +} + +static int +lua_map_get_sign_key(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + struct rspamd_map_backend *bk; + guint i; + GString *ret = NULL; + + if (map != NULL) { + for (i = 0; i < map->map->backends->len; i++) { + bk = g_ptr_array_index(map->map->backends, i); + + if (bk->trusted_pubkey) { + ret = rspamd_pubkey_print(bk->trusted_pubkey, + RSPAMD_KEYPAIR_PUBKEY | RSPAMD_KEYPAIR_BASE32); + } + else { + ret = NULL; + } + + if (ret) { + lua_pushlstring(L, ret->str, ret->len); + g_string_free(ret, TRUE); + } + else { + lua_pushnil(L); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return map->map->backends->len; +} + +static int +lua_map_set_sign_key(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + struct rspamd_map_backend *bk; + const gchar *pk_str; + struct rspamd_cryptobox_pubkey *pk; + gsize len; + guint i; + + pk_str = lua_tolstring(L, 2, &len); + + if (map && pk_str) { + pk = rspamd_pubkey_from_base32(pk_str, len, RSPAMD_KEYPAIR_SIGN, + RSPAMD_CRYPTOBOX_MODE_25519); + + if (!pk) { + return luaL_error(L, "invalid pubkey string"); + } + + for (i = 0; i < map->map->backends->len; i++) { + bk = g_ptr_array_index(map->map->backends, i); + if (bk->trusted_pubkey) { + /* Unref old pk */ + rspamd_pubkey_unref(bk->trusted_pubkey); + } + + bk->trusted_pubkey = rspamd_pubkey_ref(pk); + } + + rspamd_pubkey_unref(pk); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static int +lua_map_set_callback(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + + if (!map || map->type != RSPAMD_LUA_MAP_CALLBACK || map->data.cbdata == NULL) { + return luaL_error(L, "invalid map"); + } + + if (lua_type(L, 2) != LUA_TFUNCTION) { + return luaL_error(L, "invalid callback"); + } + + lua_pushvalue(L, 2); + /* Get a reference */ + map->data.cbdata->ref = luaL_ref(L, LUA_REGISTRYINDEX); + + return 0; +} + +static int +lua_map_get_uri(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + struct rspamd_map_backend *bk; + guint i; + + if (map != NULL) { + for (i = 0; i < map->map->backends->len; i++) { + bk = g_ptr_array_index(map->map->backends, i); + lua_pushstring(L, bk->uri); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return map->map->backends->len; +} + +struct lua_map_on_load_cbdata { + lua_State *L; + gint ref; +}; + +static void +lua_map_on_load_dtor(gpointer p) +{ + struct lua_map_on_load_cbdata *cbd = p; + + luaL_unref(cbd->L, LUA_REGISTRYINDEX, cbd->ref); + g_free(cbd); +} + +static void +lua_map_on_load_handler(struct rspamd_map *map, gpointer ud) +{ + struct lua_map_on_load_cbdata *cbd = ud; + lua_State *L = cbd->L; + + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->ref); + + if (lua_pcall(L, 0, 0, 0) != 0) { + msg_err_map("call to on_load function failed: %s", lua_tostring(L, -1)); + } +} + +static gint +lua_map_on_load(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_map *map = lua_check_map(L, 1); + + if (map == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + struct lua_map_on_load_cbdata *cbd = g_malloc(sizeof(struct lua_map_on_load_cbdata)); + cbd->L = L; + cbd->ref = luaL_ref(L, LUA_REGISTRYINDEX); + + rspamd_map_set_on_load_function(map->map, lua_map_on_load_handler, cbd, lua_map_on_load_dtor); + } + else { + return luaL_error(L, "invalid callback"); + } + + return 0; +} + +void luaopen_map(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{map}", maplib_m); + + lua_pop(L, 1); +} diff --git a/src/lua/lua_map.h b/src/lua/lua_map.h new file mode 100644 index 0000000..70677de --- /dev/null +++ b/src/lua/lua_map.h @@ -0,0 +1,38 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SRC_LUA_LUA_MAP_H_ +#define SRC_LUA_LUA_MAP_H_ + +#include "lua_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +LUA_PUBLIC_FUNCTION_DEF(config, add_radix_map); +LUA_PUBLIC_FUNCTION_DEF(config, radix_from_config); +LUA_PUBLIC_FUNCTION_DEF(config, radix_from_ucl); +LUA_PUBLIC_FUNCTION_DEF(config, add_map); +LUA_PUBLIC_FUNCTION_DEF(config, add_hash_map); +LUA_PUBLIC_FUNCTION_DEF(config, add_kv_map); +LUA_PUBLIC_FUNCTION_DEF(config, add_map); +LUA_PUBLIC_FUNCTION_DEF(config, get_maps); + +#ifdef __cplusplus +} +#endif + +#endif /* SRC_LUA_LUA_MAP_H_ */ diff --git a/src/lua/lua_mempool.c b/src/lua/lua_mempool.c new file mode 100644 index 0000000..4897d15 --- /dev/null +++ b/src/lua/lua_mempool.c @@ -0,0 +1,612 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" + +/*** + * @module rspamd_mempool + * Rspamd memory pool is used to allocate memory attached to specific objects, + * namely it was initially used for memory allocation for rspamd_task. + * + * All memory allocated by the pool is destroyed when the associated object is + * destroyed. This allows a sort of controlled garbage collection for memory + * allocated from the pool. Memory pools are extensively used by rspamd internal + * components and provide some powerful features, such as destructors or + * persistent variables. + * @example +local mempool = require "rspamd_mempool" +local pool = mempool.create() + +pool:set_variable('a', 'bcd', 1, 1.01, false) +local v1, v2, v3, v4 = pool:get_variable('a', 'string,double,double,bool') +pool:destroy() + */ + +/* Lua bindings */ +/*** + * @function mempool.create([size]) + * Creates a memory pool of a specified `size` or platform dependent optimal size (normally, a page size) + * @param {number} size size of a page inside pool + * @return {rspamd_mempool} new pool object (that should be removed by explicit call to `pool:destroy()`) + */ +LUA_FUNCTION_DEF(mempool, create); +/*** + * @method mempool:add_destructor(func) + * Adds new destructor function to the pool + * @param {function} func function to be called when the pool is destroyed + */ +LUA_FUNCTION_DEF(mempool, add_destructor); +/*** + * @method mempool:destroy() + * Destroys memory pool cleaning all variables and calling all destructors registered (both C and Lua ones) + */ +LUA_FUNCTION_DEF(mempool, delete); +LUA_FUNCTION_DEF(mempool, stat); +LUA_FUNCTION_DEF(mempool, suggest_size); +/*** + * @method mempool:set_variable(name, [value1[, value2 ...]]) + * Sets a variable that's valid during memory pool lifetime. This function allows + * to pack multiple values inside a single variable. Currently supported types are: + * + * - `string`: packed as null terminated C string (so no `\0` are allowed) + * - `number`: packed as C double + * - `boolean`: packed as bool + * @param {string} name variable's name to set + */ +LUA_FUNCTION_DEF(mempool, set_variable); +/*** + * @method mempool:set_bucket(name, num_values, [value1...valuen]|[table]) + * Stores a variable bucket of numbers where the first number is number of elements to pack + * and then there should be either n numeric values or a plain table of numeric values + * @param {string} name variable's name to set + * @param {number} num_values number of variables in the bucket + * @param {table|list} values values + */ +LUA_FUNCTION_DEF(mempool, set_bucket); +/*** + * @method mempool:get_variable(name[, type]) + * Unpacks mempool variable to lua If `type` is not specified, then a variable is + * assumed to be zero-terminated C string. Otherwise, `type` is a comma separated (spaces are ignored) + * list of types that should be unpacked from a variable's content. The following types + * are supported: + * + * - `string`: null terminated C string (so no `\0` are allowed) + * - `double`: returned as lua number + * - `int`: unpack a single integer + * - `int64`: unpack 64-bits integer + * - `boolean`: unpack boolean + * - `bucket`: bucket of numbers represented as a Lua table + * - `fstrings`: list of rspamd_fstring_t (GList) represented as a Lua table + * @param {string} name variable's name to get + * @param {string} type list of types to be extracted + * @return {variable list} list of variables extracted (but **not** a table) + */ +LUA_FUNCTION_DEF(mempool, get_variable); +/*** + * @method mempool:has_variable(name) + * Checks if the specified variable `name` exists in the memory pool + * @param {string} name variable's name to get + * @return {boolean} `true` if variable exists and `false` otherwise + */ +LUA_FUNCTION_DEF(mempool, has_variable); + +/*** + * @method mempool:delete_variable(name) + * Removes the specified variable `name` from the memory pool + * @param {string} name variable's name to remove + * @return {boolean} `true` if variable exists and has been removed + */ +LUA_FUNCTION_DEF(mempool, delete_variable); +/** + * @method mempool:topointer() + * + * Returns raw C pointer (lightuserdata) associated with mempool. This might be + * broken with luajit and GC64, use with caution. + */ +LUA_FUNCTION_DEF(mempool, topointer); + +static const struct luaL_reg mempoollib_m[] = { + LUA_INTERFACE_DEF(mempool, add_destructor), + LUA_INTERFACE_DEF(mempool, stat), + LUA_INTERFACE_DEF(mempool, suggest_size), + LUA_INTERFACE_DEF(mempool, set_variable), + LUA_INTERFACE_DEF(mempool, set_bucket), + LUA_INTERFACE_DEF(mempool, get_variable), + LUA_INTERFACE_DEF(mempool, has_variable), + LUA_INTERFACE_DEF(mempool, delete_variable), + LUA_INTERFACE_DEF(mempool, topointer), + LUA_INTERFACE_DEF(mempool, delete), + {"destroy", lua_mempool_delete}, + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static const struct luaL_reg mempoollib_f[] = { + LUA_INTERFACE_DEF(mempool, create), + {NULL, NULL}}; + +/* + * Struct for lua destructor + */ + +struct lua_mempool_udata { + lua_State *L; + gint cbref; + rspamd_mempool_t *mempool; +}; + +struct memory_pool_s * +rspamd_lua_check_mempool(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{mempool}"); + luaL_argcheck(L, ud != NULL, pos, "'mempool' expected"); + return ud ? *((struct memory_pool_s **) ud) : NULL; +} + + +static int +lua_mempool_create(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_mempool_new( + rspamd_mempool_suggest_size(), "lua", 0), + **pmempool; + + if (mempool) { + pmempool = lua_newuserdata(L, sizeof(struct memory_pool_s *)); + rspamd_lua_setclass(L, "rspamd{mempool}", -1); + *pmempool = mempool; + } + else { + lua_pushnil(L); + } + + return 1; +} + +static void +lua_mempool_destructor_func(gpointer p) +{ + struct lua_mempool_udata *ud = p; + + lua_rawgeti(ud->L, LUA_REGISTRYINDEX, ud->cbref); + if (lua_pcall(ud->L, 0, 0, 0) != 0) { + msg_info("call to destructor failed: %s", lua_tostring(ud->L, -1)); + lua_pop(ud->L, 1); + } + luaL_unref(ud->L, LUA_REGISTRYINDEX, ud->cbref); +} + +static int +lua_mempool_add_destructor(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + struct lua_mempool_udata *ud; + + if (mempool) { + if (lua_isfunction(L, 2)) { + ud = rspamd_mempool_alloc(mempool, + sizeof(struct lua_mempool_udata)); + lua_pushvalue(L, 2); + /* Get a reference */ + ud->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + ud->L = L; + ud->mempool = mempool; + rspamd_mempool_add_destructor(mempool, + lua_mempool_destructor_func, + ud); + } + else { + msg_err("trying to add destructor without function"); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static int +lua_mempool_delete(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + + if (mempool) { + rspamd_mempool_delete(mempool); + return 0; + } + else { + lua_pushnil(L); + } + + return 1; +} + +static int +lua_mempool_stat(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + + if (mempool) { + } + else { + lua_pushnil(L); + } + + return 1; +} + +static int +lua_mempool_suggest_size(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + + if (mempool) { + lua_pushinteger(L, rspamd_mempool_suggest_size()); + return 0; + } + else { + lua_pushnil(L); + } + + return 1; +} + +struct lua_numbers_bucket { + guint nelts; + gdouble elts[0]; +}; + +static int +lua_mempool_set_bucket(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + const gchar *var = luaL_checkstring(L, 2); + struct lua_numbers_bucket *bucket; + gint nelts = luaL_checknumber(L, 3), i; + + if (var && nelts > 0) { + bucket = rspamd_mempool_alloc(mempool, + sizeof(*bucket) + sizeof(gdouble) * nelts); + bucket->nelts = nelts; + + if (lua_type(L, 4) == LUA_TTABLE) { + /* Table version */ + for (i = 1; i <= nelts; i++) { + lua_rawgeti(L, 4, i); + bucket->elts[i - 1] = lua_tonumber(L, -1); + lua_pop(L, 1); + } + } + else { + for (i = 0; i <= nelts; i++) { + bucket->elts[i] = lua_tonumber(L, 4 + i); + } + } + + rspamd_mempool_set_variable(mempool, var, bucket, NULL); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static int +lua_mempool_set_variable(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + const gchar *var = luaL_checkstring(L, 2); + gpointer value; + struct lua_numbers_bucket *bucket; + gchar *vp; + union { + gdouble d; + const gchar *s; + gboolean b; + } val; + gsize slen; + gint i, j, len = 0, type; + + if (mempool && var) { + + for (i = 3; i <= lua_gettop(L); i++) { + type = lua_type(L, i); + + if (type == LUA_TNUMBER) { + /* We have some ambiguity here between integer and double */ + len += sizeof(gdouble); + } + else if (type == LUA_TBOOLEAN) { + len += sizeof(gboolean); + } + else if (type == LUA_TSTRING) { + (void) lua_tolstring(L, i, &slen); + len += slen + 1; + } + else if (type == LUA_TTABLE) { + /* We assume it as a bucket of numbers so far */ + slen = rspamd_lua_table_size(L, i); + len += sizeof(gdouble) * slen + sizeof(*bucket); + } + else { + msg_err("cannot handle lua type %s", lua_typename(L, type)); + } + } + + if (len == 0) { + msg_err("no values specified"); + } + else { + value = rspamd_mempool_alloc(mempool, len); + vp = value; + + for (i = 3; i <= lua_gettop(L); i++) { + type = lua_type(L, i); + + if (type == LUA_TNUMBER) { + val.d = lua_tonumber(L, i); + memcpy(vp, &val, sizeof(gdouble)); + vp += sizeof(gdouble); + } + else if (type == LUA_TBOOLEAN) { + val.b = lua_toboolean(L, i); + memcpy(vp, &val, sizeof(gboolean)); + vp += sizeof(gboolean); + } + else if (type == LUA_TSTRING) { + val.s = lua_tolstring(L, i, &slen); + memcpy(vp, val.s, slen + 1); + vp += slen + 1; + } + else if (type == LUA_TTABLE) { + slen = rspamd_lua_table_size(L, i); + /* XXX: Ret, ret, ret: alignment issues */ + bucket = (struct lua_numbers_bucket *) vp; + bucket->nelts = slen; + + for (j = 0; j < slen; j++) { + lua_rawgeti(L, i, j + 1); + bucket->elts[j] = lua_tonumber(L, -1); + lua_pop(L, 1); + } + + vp += sizeof(gdouble) * slen + sizeof(*bucket); + } + else { + msg_err("cannot handle lua type %s", lua_typename(L, type)); + } + } + + rspamd_mempool_set_variable(mempool, var, value, NULL); + } + + return 0; + } + else { + lua_pushnil(L); + } + + return 1; +} + + +static int +lua_mempool_get_variable(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + const gchar *var = luaL_checkstring(L, 2); + const gchar *type = NULL, *pt; + struct lua_numbers_bucket bucket; + const gchar *value, *pv; + guint len, nvar, slen, i; + + if (mempool && var) { + value = rspamd_mempool_get_variable(mempool, var); + + if (lua_gettop(L) >= 3) { + type = luaL_checkstring(L, 3); + } + + if (value) { + + if (type) { + pt = type; + pv = value; + nvar = 0; + + while ((len = strcspn(pt, ", ")) > 0) { + if (len == sizeof("double") - 1 && + g_ascii_strncasecmp(pt, "double", len) == 0) { + gdouble num; + memcpy(&num, pv, sizeof(gdouble)); + lua_pushnumber(L, num); + pv += sizeof(gdouble); + } + else if (len == sizeof("int") - 1 && + g_ascii_strncasecmp(pt, "int", len) == 0) { + gint num; + memcpy(&num, pv, sizeof(gint)); + lua_pushinteger(L, num); + pv += sizeof(gint); + } + else if (len == sizeof("int64") - 1 && + g_ascii_strncasecmp(pt, "int64", len) == 0) { + gint64 num; + memcpy(&num, pv, sizeof(gint64)); + lua_pushinteger(L, num); + pv += sizeof(gint64); + } + else if (len == sizeof("bool") - 1 && + g_ascii_strncasecmp(pt, "bool", len) == 0) { + gboolean num; + memcpy(&num, pv, sizeof(gboolean)); + lua_pushboolean(L, num); + pv += sizeof(gboolean); + } + else if (len == sizeof("string") - 1 && + g_ascii_strncasecmp(pt, "string", len) == 0) { + slen = strlen((const gchar *) pv); + lua_pushlstring(L, (const gchar *) pv, slen); + pv += slen + 1; + } + else if (len == sizeof("gstring") - 1 && + g_ascii_strncasecmp(pt, "gstring", len) == 0) { + GString *st = (GString *) pv; + lua_pushlstring(L, st->str, st->len); + pv += sizeof(GString *); + } + else if (len == sizeof("bucket") - 1 && + g_ascii_strncasecmp(pt, "bucket", len) == 0) { + memcpy(&bucket, pv, sizeof(bucket)); + lua_createtable(L, bucket.nelts, 0); + pv += sizeof(struct lua_numbers_bucket); + + for (i = 0; i < bucket.nelts; i++) { + gdouble num; + memcpy(&num, pv, sizeof(num)); + lua_pushnumber(L, num); + lua_rawseti(L, -2, i + 1); + pv += sizeof(num); + } + } + else if (len == sizeof("fstrings") - 1 && + g_ascii_strncasecmp(pt, "fstrings", len) == 0) { + GList *cur; + rspamd_fstring_t *fstr; + + cur = (GList *) pv; + lua_newtable(L); + + i = 1; + while (cur != NULL) { + fstr = cur->data; + lua_pushlstring(L, fstr->str, fstr->len); + lua_rawseti(L, -2, i); + i++; + cur = g_list_next(cur); + } + + pv += sizeof(GList *); + } + else { + msg_err("unknown type for get_variable: %s", pt); + lua_pushnil(L); + } + + pt += len; + pt += strspn(pt, ", "); + + nvar++; + } + + return nvar; + } + else { + /* No type specified, return string */ + lua_pushstring(L, value); + } + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static int +lua_mempool_has_variable(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + const gchar *var = luaL_checkstring(L, 2); + gboolean ret = FALSE; + + if (mempool && var) { + if (rspamd_mempool_get_variable(mempool, var) != NULL) { + ret = TRUE; + } + } + + lua_pushboolean(L, ret); + + return 1; +} + +static int +lua_mempool_delete_variable(lua_State *L) +{ + LUA_TRACE_POINT; + struct memory_pool_s *mempool = rspamd_lua_check_mempool(L, 1); + const gchar *var = luaL_checkstring(L, 2); + gboolean ret = FALSE; + + if (mempool && var) { + if (rspamd_mempool_get_variable(mempool, var) != NULL) { + ret = TRUE; + + rspamd_mempool_remove_variable(mempool, var); + } + } + + lua_pushboolean(L, ret); + + return 1; +} + +static gint +lua_mempool_topointer(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_mempool_t *pool = rspamd_lua_check_mempool(L, 1); + + if (pool) { + /* XXX: this might cause issues on arm64 and LuaJIT */ + lua_pushlightuserdata(L, pool); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_load_mempool(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, mempoollib_f); + + return 1; +} + +void luaopen_mempool(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{mempool}", mempoollib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_mempool", lua_load_mempool); +} diff --git a/src/lua/lua_mimepart.c b/src/lua/lua_mimepart.c new file mode 100644 index 0000000..5d4b8b7 --- /dev/null +++ b/src/lua/lua_mimepart.c @@ -0,0 +1,2304 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "lua_url.h" +#include "libmime/message.h" +#include "libmime/lang_detection.h" +#include "libstat/stat_api.h" +#include "libcryptobox/cryptobox.h" +#include "libutil/shingles.h" + +#include "contrib/uthash/utlist.h" + +/* Textpart methods */ +/*** + * @module rspamd_textpart + * This module provides different methods to manipulate text parts data. Text parts + * could be obtained from the `rspamd_task` by using of method `task:get_text_parts()` +@example +rspamd_config.R_EMPTY_IMAGE = function (task) + parts = task:get_text_parts() + if parts then + for _,part in ipairs(parts) do + if part:is_empty() then + texts = task:get_texts() + if texts then + return true + end + return false + end + end + end + return false +end + */ + +/*** + * @method text_part:is_utf() + * Return TRUE if part is a valid utf text + * @return {boolean} true if part is valid `UTF8` part + */ +LUA_FUNCTION_DEF(textpart, is_utf); + +/*** + * @method text_part:has_8bit_raw() + * Return TRUE if a part has raw 8bit characters + * @return {boolean} true if a part has raw 8bit characters + */ +LUA_FUNCTION_DEF(textpart, has_8bit_raw); + +/*** + * @method text_part:has_8bit() + * Return TRUE if a part has raw 8bit characters + * @return {boolean} true if a part has encoded 8bit characters + */ +LUA_FUNCTION_DEF(textpart, has_8bit); + +/*** + * @method text_part:get_content([type]) + * Get the text of the part (html tags stripped). Optional `type` defines type of content to get: + * - `content` (default): utf8 content with HTML tags stripped and newlines preserved + * - `content_oneline`: utf8 content with HTML tags and newlines stripped + * - `raw`: raw content, not mime decoded nor utf8 converted + * - `raw_parsed`: raw content, mime decoded, not utf8 converted + * - `raw_utf`: raw content, mime decoded, utf8 converted (but with HTML tags and newlines) + * @return {text} `UTF8` encoded content of the part (zero-copy if not converted to a lua string) + */ +LUA_FUNCTION_DEF(textpart, get_content); +/*** + * @method text_part:get_raw_content() + * Get the original text of the part + * @return {text} `UTF8` encoded content of the part (zero-copy if not converted to a lua string) + */ +LUA_FUNCTION_DEF(textpart, get_raw_content); +/*** + * @method text_part:get_content_oneline() + *Get the text of the part (html tags and newlines stripped) + * @return {text} `UTF8` encoded content of the part (zero-copy if not converted to a lua string) + */ +LUA_FUNCTION_DEF(textpart, get_content_oneline); +/*** + * @method text_part:get_length() + * Get length of the text of the part + * @return {integer} length of part in **bytes** + */ +LUA_FUNCTION_DEF(textpart, get_length); +/*** + * @method mime_part:get_raw_length() + * Get length of the **raw** content of the part (e.g. HTML with tags unstripped) + * @return {integer} length of part in **bytes** + */ +LUA_FUNCTION_DEF(textpart, get_raw_length); +/*** + * @method mime_part:get_urls_length() + * Get length of the urls within the part + * @return {integer} length of urls in **bytes** + */ +LUA_FUNCTION_DEF(textpart, get_urls_length); +/*** + * @method mime_part:get_lines_count() + * Get lines number in the part + * @return {integer} number of lines in the part + */ +LUA_FUNCTION_DEF(textpart, get_lines_count); +/*** + * @method mime_part:get_stats() + * Returns a table with the following data: + * - `lines`: number of lines + * - `spaces`: number of spaces + * - `double_spaces`: double spaces + * - `empty_lines`: number of empty lines + * - `non_ascii_characters`: number of non ascii characters + * - `ascii_characters`: number of ascii characters + * @return {table} table of stats + */ +LUA_FUNCTION_DEF(textpart, get_stats); +/*** + * @method mime_part:get_words_count() + * Get words number in the part + * @return {integer} number of words in the part + */ +LUA_FUNCTION_DEF(textpart, get_words_count); + +/*** + * @method mime_part:get_words([how]) + * Get words in the part. Optional `how` argument defines type of words returned: + * - `stem`: stemmed words (default) + * - `norm`: normalised words (utf normalised + lowercased) + * - `raw`: raw words in utf (if possible) + * - `full`: list of tables, each table has the following fields: + * - [1] - stemmed word + * - [2] - normalised word + * - [3] - raw word + * - [4] - flags (table of strings) + * @return {table/strings} words in the part + */ +LUA_FUNCTION_DEF(textpart, get_words); + +/*** + * @method mime_part:filter_words(regexp, [how][, max]]) + * Filter words using some regexp: + * - `stem`: stemmed words (default) + * - `norm`: normalised words (utf normalised + lowercased) + * - `raw`: raw words in utf (if possible) + * - `full`: list of tables, each table has the following fields: + * - [1] - stemmed word + * - [2] - normalised word + * - [3] - raw word + * - [4] - flags (table of strings) + * @param {rspamd_regexp} regexp regexp to match + * @param {string} how what words to extract + * @param {number} max maximum number of hits returned (all hits if <= 0 or nil) + * @return {table/strings} words matching regexp + */ +LUA_FUNCTION_DEF(textpart, filter_words); + +/*** + * @method text_part:is_empty() + * Returns `true` if the specified part is empty + * @return {bool} whether a part is empty + */ +LUA_FUNCTION_DEF(textpart, is_empty); +/*** + * @method text_part:is_html() + * Returns `true` if the specified part has HTML content + * @return {bool} whether a part is HTML part + */ +LUA_FUNCTION_DEF(textpart, is_html); +/*** + * @method text_part:get_html() + * Returns html content of the specified part + * @return {html} html content + */ +LUA_FUNCTION_DEF(textpart, get_html); +/*** + * @method text_part:get_language() + * Returns the code of the most used unicode script in the text part. Does not work with raw parts + * @return {string} short abbreviation (such as `ru`) for the script's language + */ +LUA_FUNCTION_DEF(textpart, get_language); + +/*** + * @method text_part:get_charset() + * Returns part real charset + * @return {string} charset of the part + */ +LUA_FUNCTION_DEF(textpart, get_charset); +/*** + * @method text_part:get_languages() + * Returns array of tables of all languages detected for a part: + * - 'code': language code (short string) + * - 'prob': logarithm of probability + * @return {array|tables} all languages detected for the part + */ +LUA_FUNCTION_DEF(textpart, get_languages); +/*** + * @method text_part:get_fuzzy_hashes(mempool) + * @param {rspamd_mempool} mempool - memory pool (usually task pool) + * Returns direct hash of textpart as a string and array [1..32] of shingles each represented as a following table: + * - [1] - 64 bit fuzzy hash represented as a string + * - [2..4] - strings used to generate this hash + * @return {string,array|tables} fuzzy hashes calculated + */ +LUA_FUNCTION_DEF(textpart, get_fuzzy_hashes); +/*** + * @method text_part:get_mimepart() + * Returns the mime part object corresponding to this text part + * @return {mimepart} mimepart object + */ +LUA_FUNCTION_DEF(textpart, get_mimepart); + +static const struct luaL_reg textpartlib_m[] = { + LUA_INTERFACE_DEF(textpart, is_utf), + LUA_INTERFACE_DEF(textpart, has_8bit_raw), + LUA_INTERFACE_DEF(textpart, has_8bit), + LUA_INTERFACE_DEF(textpart, get_content), + LUA_INTERFACE_DEF(textpart, get_raw_content), + LUA_INTERFACE_DEF(textpart, get_content_oneline), + LUA_INTERFACE_DEF(textpart, get_length), + LUA_INTERFACE_DEF(textpart, get_raw_length), + LUA_INTERFACE_DEF(textpart, get_urls_length), + LUA_INTERFACE_DEF(textpart, get_lines_count), + LUA_INTERFACE_DEF(textpart, get_words_count), + LUA_INTERFACE_DEF(textpart, get_words), + LUA_INTERFACE_DEF(textpart, filter_words), + LUA_INTERFACE_DEF(textpart, is_empty), + LUA_INTERFACE_DEF(textpart, is_html), + LUA_INTERFACE_DEF(textpart, get_html), + LUA_INTERFACE_DEF(textpart, get_language), + LUA_INTERFACE_DEF(textpart, get_charset), + LUA_INTERFACE_DEF(textpart, get_languages), + LUA_INTERFACE_DEF(textpart, get_mimepart), + LUA_INTERFACE_DEF(textpart, get_stats), + LUA_INTERFACE_DEF(textpart, get_fuzzy_hashes), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +/* Mimepart methods */ + +/*** + * @module rspamd_mimepart + * This module provides access to mime parts found in a message +@example +rspamd_config.MISSING_CONTENT_TYPE = function(task) + local parts = task:get_parts() + if parts and #parts > 1 then + -- We have more than one part + for _,p in ipairs(parts) do + local ct = p:get_header('Content-Type') + -- And some parts have no Content-Type header + if not ct then + return true + end + end + end + return false +end + */ + +/*** + * @method mime_part:get_header(name[, case_sensitive]) + * Get decoded value of a header specified with optional case_sensitive flag. + * By default headers are searched in caseless matter. + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {string} decoded value of a header + */ +LUA_FUNCTION_DEF(mimepart, get_header); +/*** + * @method mime_part:get_header_raw(name[, case_sensitive]) + * Get raw value of a header specified with optional case_sensitive flag. + * By default headers are searched in caseless matter. + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {string} raw value of a header + */ +LUA_FUNCTION_DEF(mimepart, get_header_raw); +/*** + * @method mime_part:get_header_full(name[, case_sensitive]) + * Get raw value of a header specified with optional case_sensitive flag. + * By default headers are searched in caseless matter. This method returns more + * information about the header as a list of tables with the following structure: + * + * - `name` - name of a header + * - `value` - raw value of a header + * - `decoded` - decoded value of a header + * - `tab_separated` - `true` if a header and a value are separated by `tab` character + * - `empty_separator` - `true` if there are no separator between a header and a value + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {list of tables} all values of a header as specified above +@example +function check_header_delimiter_tab(task, header_name) + for _,rh in ipairs(task:get_header_full(header_name)) do + if rh['tab_separated'] then return true end + end + return false +end + */ +LUA_FUNCTION_DEF(mimepart, get_header_full); +/*** + * @method mimepart:get_header_count(name[, case_sensitive]) + * Lightweight version if you need just a header's count + * * By default headers are searched in caseless matter. + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {number} number of header's occurrences or 0 if not found + */ +LUA_FUNCTION_DEF(mimepart, get_header_count); + +/*** + * @method mimepart:get_raw_headers() + * Get all undecoded headers of a mime part as a string + * @return {rspamd_text} all raw headers for a message as opaque text + */ +LUA_FUNCTION_DEF(mimepart, get_raw_headers); + +/*** + * @method mimepart:get_headers() + * Get all undecoded headers of a mime part as a string + * @return {rspamd_text} all raw headers for a message as opaque text + */ +LUA_FUNCTION_DEF(mimepart, get_headers); + +/*** + * @method mime_part:get_content() + * Get the parsed content of part + * @return {text} opaque text object (zero-copy if not casted to lua string) + */ +LUA_FUNCTION_DEF(mimepart, get_content); +/*** + * @method mime_part:get_raw_content() + * Get the raw content of part + * @return {text} opaque text object (zero-copy if not casted to lua string) + */ +LUA_FUNCTION_DEF(mimepart, get_raw_content); +/*** + * @method mime_part:get_length() + * Get length of the content of the part + * @return {integer} length of part in **bytes** + */ +LUA_FUNCTION_DEF(mimepart, get_length); +/*** + * @method mime_part:get_type() + * Extract content-type string of the mime part + * @return {string,string} content type in form 'type','subtype' + */ +LUA_FUNCTION_DEF(mimepart, get_type); + +/*** + * @method mime_part:get_type_full() + * Extract content-type string of the mime part with all attributes + * @return {string,string,table} content type in form 'type','subtype', {attrs} + */ +LUA_FUNCTION_DEF(mimepart, get_type_full); + +/*** + * @method mime_part:get_detected_type() + * Extract content-type string of the mime part. Use lua_magic detection + * @return {string,string} content type in form 'type','subtype' + */ +LUA_FUNCTION_DEF(mimepart, get_detected_type); + +/*** + * @method mime_part:get_detected_type_full() + * Extract content-type string of the mime part with all attributes. Use lua_magic detection + * @return {string,string,table} content type in form 'type','subtype', {attrs} + */ +LUA_FUNCTION_DEF(mimepart, get_detected_type_full); + +/*** + * @method mime_part:get_detected_ext() + * Returns a msdos extension name according to lua_magic detection + * @return {string} detected extension (see lua_magic.types) + */ +LUA_FUNCTION_DEF(mimepart, get_detected_ext); + +/*** + * @method mime_part:get_cte() + * Extract content-transfer-encoding for a part + * @return {string} content transfer encoding (e.g. `base64` or `7bit`) + */ +LUA_FUNCTION_DEF(mimepart, get_cte); + +/*** + * @method mime_part:get_filename() + * Extract filename associated with mime part if it is an attachment + * @return {string} filename or `nil` if no file is associated with this part + */ +LUA_FUNCTION_DEF(mimepart, get_filename); +/*** + * @method mime_part:is_image() + * Returns true if mime part is an image + * @return {bool} true if a part is an image + */ +LUA_FUNCTION_DEF(mimepart, is_image); +/*** + * @method mime_part:get_image() + * Returns rspamd_image structure associated with this part. This structure has + * the following methods: + * + * * `get_width` - return width of an image in pixels + * * `get_height` - return height of an image in pixels + * * `get_type` - return string representation of image's type (e.g. 'jpeg') + * * `get_filename` - return string with image's file name + * * `get_size` - return size in bytes + * @return {rspamd_image} image structure or nil if a part is not an image + */ +LUA_FUNCTION_DEF(mimepart, get_image); +/*** + * @method mime_part:is_archive() + * Returns true if mime part is an archive + * @return {bool} true if a part is an archive + */ +LUA_FUNCTION_DEF(mimepart, is_archive); +/*** + * @method mime_part:is_attachment() + * Returns true if mime part looks like an attachment + * @return {bool} true if a part looks like an attachment + */ +LUA_FUNCTION_DEF(mimepart, is_attachment); + +/*** + * @method mime_part:get_archive() + * Returns rspamd_archive structure associated with this part. This structure has + * the following methods: + * + * * `get_files` - return list of strings with filenames inside archive + * * `get_files_full` - return list of tables with all information about files + * * `is_encrypted` - return true if an archive is encrypted + * * `get_type` - return string representation of image's type (e.g. 'zip') + * * `get_filename` - return string with archive's file name + * * `get_size` - return size in bytes + * @return {rspamd_archive} archive structure or nil if a part is not an archive + */ +LUA_FUNCTION_DEF(mimepart, get_archive); +/*** + * @method mime_part:is_multipart() + * Returns true if mime part is a multipart part + * @return {bool} true if a part is is a multipart part + */ +LUA_FUNCTION_DEF(mimepart, is_multipart); +/*** + * @method mime_part:is_message() + * Returns true if mime part is a message part (message/rfc822) + * @return {bool} true if a part is is a message part + */ +LUA_FUNCTION_DEF(mimepart, is_message); +/*** + * @method mime_part:get_boundary() + * Returns boundary for a part (extracted from parent multipart for normal parts and + * from the part itself for multipart) + * @return {string} boundary value or nil + */ +LUA_FUNCTION_DEF(mimepart, get_boundary); + +/*** + * @method mime_part:get_enclosing_boundary() + * Returns an enclosing boundary for a part even for multiparts. For normal parts + * this method is identical to `get_boundary` + * @return {string} boundary value or nil + */ +LUA_FUNCTION_DEF(mimepart, get_enclosing_boundary); + +/*** + * @method mime_part:get_children() + * Returns rspamd_mimepart table of part's childer. Returns nil if mime part is not multipart + * or a message part. + * @return {rspamd_mimepart} table of children + */ +LUA_FUNCTION_DEF(mimepart, get_children); +/*** + * @method mime_part:is_text() + * Returns true if mime part is a text part + * @return {bool} true if a part is a text part + */ +LUA_FUNCTION_DEF(mimepart, is_text); +/*** + * @method mime_part:get_text() + * Returns rspamd_textpart structure associated with this part. + * @return {rspamd_textpart} textpart structure or nil if a part is not an text + */ +LUA_FUNCTION_DEF(mimepart, get_text); + +/*** + * @method mime_part:get_digest() + * Returns the unique digest for this mime part + * @return {string} 128 characters hex string with digest of the part + */ +LUA_FUNCTION_DEF(mimepart, get_digest); + +/*** + * @method mime_part:get_id() + * Returns the order of the part in parts list + * @return {number} index of the part (starting from 1 as it is Lua API) + */ +LUA_FUNCTION_DEF(mimepart, get_id); +/*** + * @method mime_part:is_broken() + * Returns true if mime part has incorrectly specified content type + * @return {bool} true if a part has bad content type + */ +LUA_FUNCTION_DEF(mimepart, is_broken); +/*** + * @method mime_part:headers_foreach(callback, [params]) + * This method calls `callback` for each header that satisfies some condition. + * By default, all headers are iterated unless `callback` returns `true`. Nil or + * false means continue of iterations. + * Params could be as following: + * + * - `full`: header value is full table of all attributes @see task:get_header_full for details + * - `regexp`: return headers that satisfies the specified regexp + * @param {function} callback function from header name and header value + * @param {table} params optional parameters + */ +LUA_FUNCTION_DEF(mimepart, headers_foreach); +/*** + * @method mime_part:get_parent() + * Returns parent part for this part + * @return {rspamd_mimepart} parent part or nil + */ +LUA_FUNCTION_DEF(mimepart, get_parent); + +/*** + * @method mime_part:get_specific() + * Returns specific lua content for this part + * @return {any} specific lua content + */ +LUA_FUNCTION_DEF(mimepart, get_specific); + +/*** + * @method mime_part:set_specific(<any>) + * Sets a specific content for this part + * @return {any} previous specific lua content (or nil) + */ +LUA_FUNCTION_DEF(mimepart, set_specific); + +/*** + * @method mime_part:is_specific(<any>) + * Returns true if part has specific lua content + * @return {boolean} flag + */ +LUA_FUNCTION_DEF(mimepart, is_specific); + +/*** + * @method mime_part:get_urls([need_emails|list_protos][, need_images]) + * Get all URLs found in a mime part. Telephone urls and emails are not included unless explicitly asked in `list_protos` + * @param {boolean} need_emails if `true` then return also email urls, this can be a comma separated string of protocols desired or a table (e.g. `mailto` or `telephone`) + * @param {boolean} need_images return urls from images (<img src=...>) as well + * @return {table rspamd_url} list of all urls found + */ +LUA_FUNCTION_DEF(mimepart, get_urls); + +static const struct luaL_reg mimepartlib_m[] = { + LUA_INTERFACE_DEF(mimepart, get_content), + LUA_INTERFACE_DEF(mimepart, get_raw_content), + LUA_INTERFACE_DEF(mimepart, get_length), + LUA_INTERFACE_DEF(mimepart, get_type), + LUA_INTERFACE_DEF(mimepart, get_type_full), + LUA_INTERFACE_DEF(mimepart, get_detected_type), + LUA_INTERFACE_DEF(mimepart, get_detected_ext), + LUA_INTERFACE_DEF(mimepart, get_detected_type_full), + LUA_INTERFACE_DEF(mimepart, get_cte), + LUA_INTERFACE_DEF(mimepart, get_filename), + LUA_INTERFACE_DEF(mimepart, get_boundary), + LUA_INTERFACE_DEF(mimepart, get_enclosing_boundary), + LUA_INTERFACE_DEF(mimepart, get_header), + LUA_INTERFACE_DEF(mimepart, get_header_raw), + LUA_INTERFACE_DEF(mimepart, get_header_full), + LUA_INTERFACE_DEF(mimepart, get_header_count), + LUA_INTERFACE_DEF(mimepart, get_raw_headers), + LUA_INTERFACE_DEF(mimepart, get_headers), + LUA_INTERFACE_DEF(mimepart, is_image), + LUA_INTERFACE_DEF(mimepart, get_image), + LUA_INTERFACE_DEF(mimepart, is_archive), + LUA_INTERFACE_DEF(mimepart, get_archive), + LUA_INTERFACE_DEF(mimepart, is_multipart), + LUA_INTERFACE_DEF(mimepart, is_message), + LUA_INTERFACE_DEF(mimepart, get_children), + LUA_INTERFACE_DEF(mimepart, get_parent), + LUA_INTERFACE_DEF(mimepart, get_urls), + LUA_INTERFACE_DEF(mimepart, is_text), + LUA_INTERFACE_DEF(mimepart, is_broken), + LUA_INTERFACE_DEF(mimepart, is_attachment), + LUA_INTERFACE_DEF(mimepart, get_text), + LUA_INTERFACE_DEF(mimepart, get_digest), + LUA_INTERFACE_DEF(mimepart, get_id), + LUA_INTERFACE_DEF(mimepart, headers_foreach), + LUA_INTERFACE_DEF(mimepart, get_specific), + LUA_INTERFACE_DEF(mimepart, set_specific), + LUA_INTERFACE_DEF(mimepart, is_specific), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + + +static struct rspamd_mime_text_part * +lua_check_textpart(lua_State *L) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{textpart}"); + luaL_argcheck(L, ud != NULL, 1, "'textpart' expected"); + return ud ? *((struct rspamd_mime_text_part **) ud) : NULL; +} + +static struct rspamd_mime_part * +lua_check_mimepart(lua_State *L) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{mimepart}"); + luaL_argcheck(L, ud != NULL, 1, "'mimepart' expected"); + return ud ? *((struct rspamd_mime_part **) ud) : NULL; +} + + +static gint +lua_textpart_is_utf(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL || IS_TEXT_PART_EMPTY(part)) { + lua_pushboolean(L, FALSE); + return 1; + } + + lua_pushboolean(L, IS_TEXT_PART_UTF(part)); + + return 1; +} + + +static gint +lua_textpart_has_8bit_raw(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part) { + if (part->flags & RSPAMD_MIME_TEXT_PART_FLAG_8BIT_RAW) { + lua_pushboolean(L, TRUE); + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_textpart_has_8bit(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part) { + if (part->flags & RSPAMD_MIME_TEXT_PART_FLAG_8BIT_ENCODED) { + lua_pushboolean(L, TRUE); + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +static gint +lua_textpart_get_content(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + struct rspamd_lua_text *t; + gsize len; + const gchar *start, *type = NULL; + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + if (lua_type(L, 2) == LUA_TSTRING) { + type = lua_tostring(L, 2); + } + + if (!type) { + if (IS_TEXT_PART_EMPTY(part)) { + lua_pushnil(L); + return 1; + } + start = part->utf_content.begin; + len = part->utf_content.len; + } + else if (strcmp(type, "content") == 0) { + if (IS_TEXT_PART_EMPTY(part)) { + lua_pushnil(L); + return 1; + } + + start = part->utf_content.begin; + len = part->utf_content.len; + } + else if (strcmp(type, "content_oneline") == 0) { + if (IS_TEXT_PART_EMPTY(part)) { + lua_pushnil(L); + return 1; + } + + start = part->utf_stripped_content->data; + len = part->utf_stripped_content->len; + } + else if (strcmp(type, "raw_parsed") == 0) { + if (part->parsed.len == 0) { + lua_pushnil(L); + return 1; + } + + start = part->parsed.begin; + len = part->parsed.len; + } + else if (strcmp(type, "raw_utf") == 0) { + if (part->utf_raw_content == NULL || part->utf_raw_content->len == 0) { + lua_pushnil(L); + return 1; + } + + start = part->utf_raw_content->data; + len = part->utf_raw_content->len; + } + else if (strcmp(type, "raw") == 0) { + if (part->raw.len == 0) { + lua_pushnil(L); + return 1; + } + + start = part->raw.begin; + len = part->raw.len; + } + else { + return luaL_error(L, "invalid content type: %s", type); + } + + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + + t->start = start; + t->len = len; + t->flags = 0; + + return 1; +} + +static gint +lua_textpart_get_raw_content(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + struct rspamd_lua_text *t; + + if (part == NULL || IS_TEXT_PART_EMPTY(part)) { + lua_pushnil(L); + return 1; + } + + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = part->raw.begin; + t->len = part->raw.len; + t->flags = 0; + + return 1; +} + +static gint +lua_textpart_get_content_oneline(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL || IS_TEXT_PART_EMPTY(part)) { + lua_pushnil(L); + return 1; + } + + lua_new_text(L, part->utf_stripped_content->data, part->utf_stripped_content->len, FALSE); + + return 1; +} + +static gint +lua_textpart_get_length(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + if (IS_TEXT_PART_EMPTY(part) || part->utf_content.len == 0) { + lua_pushinteger(L, 0); + } + else { + lua_pushinteger(L, part->utf_content.len); + } + + return 1; +} + +static gint +lua_textpart_get_raw_length(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + lua_pushinteger(L, part->raw.len); + + return 1; +} + +static gint +lua_textpart_get_urls_length(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + GList *cur; + guint total = 0; + struct rspamd_process_exception *ex; + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + for (cur = part->exceptions; cur != NULL; cur = g_list_next(cur)) { + ex = cur->data; + + if (ex->type == RSPAMD_EXCEPTION_URL) { + total += ex->len; + } + } + + lua_pushinteger(L, total); + + return 1; +} + +static gint +lua_textpart_get_lines_count(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + if (IS_TEXT_PART_EMPTY(part)) { + lua_pushinteger(L, 0); + } + else { + lua_pushinteger(L, part->nlines); + } + + return 1; +} + +static gint +lua_textpart_get_words_count(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + lua_pushinteger(L, 0); + } + else { + lua_pushinteger(L, part->nwords); + } + + return 1; +} + +static inline enum rspamd_lua_words_type +word_extract_type_from_string(const gchar *how_str) +{ + enum rspamd_lua_words_type how = RSPAMD_LUA_WORDS_MAX; + + if (strcmp(how_str, "stem") == 0) { + how = RSPAMD_LUA_WORDS_STEM; + } + else if (strcmp(how_str, "norm") == 0) { + how = RSPAMD_LUA_WORDS_NORM; + } + else if (strcmp(how_str, "raw") == 0) { + how = RSPAMD_LUA_WORDS_RAW; + } + else if (strcmp(how_str, "full") == 0) { + how = RSPAMD_LUA_WORDS_FULL; + } + + return how; +} + +static gint +lua_textpart_get_words(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + enum rspamd_lua_words_type how = RSPAMD_LUA_WORDS_STEM; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + lua_createtable(L, 0, 0); + } + else { + if (lua_type(L, 2) == LUA_TSTRING) { + const gchar *how_str = lua_tostring(L, 2); + + how = word_extract_type_from_string(how_str); + + if (how == RSPAMD_LUA_WORDS_MAX) { + return luaL_error(L, "invalid extraction type: %s", how_str); + } + } + + return rspamd_lua_push_words(L, part->utf_words, how); + } + + return 1; +} + +static gint +lua_textpart_filter_words(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + struct rspamd_lua_regexp *re = lua_check_regexp(L, 2); + gint lim = -1; + enum rspamd_lua_words_type how = RSPAMD_LUA_WORDS_STEM; + + if (part == NULL || re == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + lua_createtable(L, 0, 0); + } + else { + if (lua_type(L, 3) == LUA_TSTRING) { + const gchar *how_str = lua_tostring(L, 3); + + how = word_extract_type_from_string(how_str); + + if (how == RSPAMD_LUA_WORDS_MAX) { + return luaL_error(L, "invalid extraction type: %s", how_str); + } + } + + if (lua_type(L, 4) == LUA_TNUMBER) { + lim = lua_tointeger(L, 4); + } + + guint cnt, i; + + lua_createtable(L, 8, 0); + + for (i = 0, cnt = 1; i < part->utf_words->len; i++) { + rspamd_stat_token_t *w = &g_array_index(part->utf_words, + rspamd_stat_token_t, i); + + switch (how) { + case RSPAMD_LUA_WORDS_STEM: + if (w->stemmed.len > 0) { + if (rspamd_regexp_match(re->re, w->stemmed.begin, + w->stemmed.len, FALSE)) { + lua_pushlstring(L, w->stemmed.begin, w->stemmed.len); + lua_rawseti(L, -2, cnt++); + } + } + break; + case RSPAMD_LUA_WORDS_NORM: + if (w->normalized.len > 0) { + if (rspamd_regexp_match(re->re, w->normalized.begin, + w->normalized.len, FALSE)) { + lua_pushlstring(L, w->normalized.begin, w->normalized.len); + lua_rawseti(L, -2, cnt++); + } + } + break; + case RSPAMD_LUA_WORDS_RAW: + if (w->original.len > 0) { + if (rspamd_regexp_match(re->re, w->original.begin, + w->original.len, TRUE)) { + lua_pushlstring(L, w->original.begin, w->original.len); + lua_rawseti(L, -2, cnt++); + } + } + break; + case RSPAMD_LUA_WORDS_FULL: + if (rspamd_regexp_match(re->re, w->normalized.begin, + w->normalized.len, FALSE)) { + rspamd_lua_push_full_word(L, w); + /* Push to the resulting vector */ + lua_rawseti(L, -2, cnt++); + } + break; + default: + break; + } + + if (lim > 0 && cnt >= lim) { + break; + } + } + } + + return 1; +} + +static gint +lua_textpart_is_empty(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + lua_pushboolean(L, IS_TEXT_PART_EMPTY(part)); + + return 1; +} + +static gint +lua_textpart_is_html(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + lua_pushboolean(L, IS_TEXT_PART_HTML(part)); + + return 1; +} + +static gint +lua_textpart_get_html(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + struct html_content **phc; + + if (part == NULL || part->html == NULL) { + lua_pushnil(L); + } + else { + phc = lua_newuserdata(L, sizeof(*phc)); + rspamd_lua_setclass(L, "rspamd{html}", -1); + *phc = part->html; + } + + return 1; +} + +static gint +lua_textpart_get_language(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part != NULL) { + if (part->language != NULL && part->language[0] != '\0') { + lua_pushstring(L, part->language); + return 1; + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_textpart_get_charset(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part != NULL) { + if (part->real_charset != NULL) { + lua_pushstring(L, part->real_charset); + return 1; + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_textpart_get_languages(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + guint i; + struct rspamd_lang_detector_res *cur; + + if (part != NULL) { + if (part->languages != NULL) { + lua_createtable(L, part->languages->len, 0); + + PTR_ARRAY_FOREACH(part->languages, i, cur) + { + lua_createtable(L, 0, 2); + lua_pushstring(L, "code"); + lua_pushstring(L, cur->lang); + lua_settable(L, -3); + lua_pushstring(L, "prob"); + lua_pushnumber(L, cur->prob); + lua_settable(L, -3); + + lua_rawseti(L, -2, i + 1); + } + } + else { + lua_newtable(L); + } + } + else { + luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct lua_shingle_data { + guint64 hash; + rspamd_ftok_t t1; + rspamd_ftok_t t2; + rspamd_ftok_t t3; +}; + +struct lua_shingle_filter_cbdata { + struct rspamd_mime_text_part *part; + rspamd_mempool_t *pool; +}; + +#define STORE_TOKEN(i, t) \ + do { \ + if ((i) < part->utf_words->len) { \ + word = &g_array_index(part->utf_words, rspamd_stat_token_t, (i)); \ + sd->t.begin = word->stemmed.begin; \ + sd->t.len = word->stemmed.len; \ + } \ + } while (0) + +static guint64 +lua_shingles_filter(guint64 *input, gsize count, + gint shno, const guchar *key, gpointer ud) +{ + guint64 minimal = G_MAXUINT64; + gsize i, min_idx = 0; + struct lua_shingle_data *sd; + rspamd_stat_token_t *word; + struct lua_shingle_filter_cbdata *cbd = (struct lua_shingle_filter_cbdata *) ud; + struct rspamd_mime_text_part *part; + + part = cbd->part; + + for (i = 0; i < count; i++) { + if (minimal > input[i]) { + minimal = input[i]; + min_idx = i; + } + } + + sd = rspamd_mempool_alloc0(cbd->pool, sizeof(*sd)); + sd->hash = minimal; + + + STORE_TOKEN(min_idx, t1); + STORE_TOKEN(min_idx + 1, t2); + STORE_TOKEN(min_idx + 2, t3); + + return GPOINTER_TO_SIZE(sd); +} + +#undef STORE_TOKEN + +static gint +lua_textpart_get_fuzzy_hashes(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + rspamd_mempool_t *pool = rspamd_lua_check_mempool(L, 2); + guchar key[rspamd_cryptobox_HASHBYTES], digest[rspamd_cryptobox_HASHBYTES], + hexdigest[rspamd_cryptobox_HASHBYTES * 2 + 1], numbuf[64]; + struct rspamd_shingle *sgl; + guint i; + struct lua_shingle_data *sd; + rspamd_cryptobox_hash_state_t st; + rspamd_stat_token_t *word; + struct lua_shingle_filter_cbdata cbd; + + + if (part == NULL || pool == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (IS_TEXT_PART_EMPTY(part) || part->utf_words == NULL) { + lua_pushnil(L); + lua_pushnil(L); + } + else { + /* TODO: add keys and algorithms support */ + rspamd_cryptobox_hash(key, "rspamd", strlen("rspamd"), NULL, 0); + + /* TODO: add short text support */ + + /* Calculate direct hash */ + rspamd_cryptobox_hash_init(&st, key, rspamd_cryptobox_HASHKEYBYTES); + + for (i = 0; i < part->utf_words->len; i++) { + word = &g_array_index(part->utf_words, rspamd_stat_token_t, i); + rspamd_cryptobox_hash_update(&st, + word->stemmed.begin, word->stemmed.len); + } + + rspamd_cryptobox_hash_final(&st, digest); + + rspamd_encode_hex_buf(digest, sizeof(digest), hexdigest, + sizeof(hexdigest)); + lua_pushlstring(L, hexdigest, sizeof(hexdigest) - 1); + + cbd.pool = pool; + cbd.part = part; + sgl = rspamd_shingles_from_text(part->utf_words, key, + pool, lua_shingles_filter, &cbd, RSPAMD_SHINGLES_MUMHASH); + + if (sgl == NULL) { + lua_pushnil(L); + } + else { + lua_createtable(L, G_N_ELEMENTS(sgl->hashes), 0); + + for (i = 0; i < G_N_ELEMENTS(sgl->hashes); i++) { + sd = GSIZE_TO_POINTER(sgl->hashes[i]); + + lua_createtable(L, 4, 0); + rspamd_snprintf(numbuf, sizeof(numbuf), "%uL", sd->hash); + lua_pushstring(L, numbuf); + lua_rawseti(L, -2, 1); + + /* Tokens */ + lua_pushlstring(L, sd->t1.begin, sd->t1.len); + lua_rawseti(L, -2, 2); + + lua_pushlstring(L, sd->t2.begin, sd->t2.len); + lua_rawseti(L, -2, 3); + + lua_pushlstring(L, sd->t3.begin, sd->t3.len); + lua_rawseti(L, -2, 4); + + lua_rawseti(L, -2, i + 1); /* Store table */ + } + } + } + + return 2; +} + +static gint +lua_textpart_get_mimepart(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + struct rspamd_mime_part **pmime; + + if (part != NULL) { + if (part->mime_part != NULL) { + pmime = lua_newuserdata(L, sizeof(struct rspamd_mime_part *)); + rspamd_lua_setclass(L, "rspamd{mimepart}", -1); + *pmime = part->mime_part; + + return 1; + } + } + + lua_pushnil(L); + return 1; +} + +/*** + * @method mime_part:get_stats() + * Returns a table with the following data: + * - + * - `lines`: number of lines + * - `spaces`: number of spaces + * - `double_spaces`: double spaces + * - `empty_lines`: number of empty lines + * - `non_ascii_characters`: number of non ascii characters + * - `ascii_characters`: number of ascii characters + * @return {table} table of stats + */ +static gint +lua_textpart_get_stats(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_text_part *part = lua_check_textpart(L); + + if (part != NULL) { + lua_createtable(L, 0, 9); + + lua_pushstring(L, "lines"); + lua_pushinteger(L, part->nlines); + lua_settable(L, -3); + lua_pushstring(L, "empty_lines"); + lua_pushinteger(L, part->empty_lines); + lua_settable(L, -3); + lua_pushstring(L, "spaces"); + lua_pushinteger(L, part->spaces); + lua_settable(L, -3); + lua_pushstring(L, "non_spaces"); + lua_pushinteger(L, part->non_spaces); + lua_settable(L, -3); + lua_pushstring(L, "double_spaces"); + lua_pushinteger(L, part->double_spaces); + lua_settable(L, -3); + lua_pushstring(L, "ascii_characters"); + lua_pushinteger(L, part->ascii_chars); + lua_settable(L, -3); + lua_pushstring(L, "non_ascii_characters"); + lua_pushinteger(L, part->non_ascii_chars); + lua_settable(L, -3); + lua_pushstring(L, "capital_letters"); + lua_pushinteger(L, part->capital_letters); + lua_settable(L, -3); + lua_pushstring(L, "numeric_characters"); + lua_pushinteger(L, part->numeric_characters); + lua_settable(L, -3); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/* Mimepart implementation */ + +static gint +lua_mimepart_get_content(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_lua_text *t; + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = part->parsed_data.begin; + t->len = part->parsed_data.len; + t->flags = 0; + + if (lua_is_text_binary(t)) { + t->flags |= RSPAMD_TEXT_FLAG_BINARY; + } + + return 1; +} + +static gint +lua_mimepart_get_raw_content(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_lua_text *t; + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = part->raw_data.begin; + t->len = part->raw_data.len; + t->flags = 0; + + return 1; +} + +static gint +lua_mimepart_get_length(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + lua_pushinteger(L, part->parsed_data.len); + + return 1; +} + +static gint +lua_mimepart_get_type_common(lua_State *L, struct rspamd_content_type *ct, + gboolean full) +{ + + GHashTableIter it; + gpointer k, v; + struct rspamd_content_type_param *param; + + if (ct == NULL) { + lua_pushnil(L); + lua_pushnil(L); + return 2; + } + + lua_pushlstring(L, ct->type.begin, ct->type.len); + lua_pushlstring(L, ct->subtype.begin, ct->subtype.len); + + if (!full) { + return 2; + } + + lua_createtable(L, 0, 2 + (ct->attrs ? g_hash_table_size(ct->attrs) : 0)); + + if (ct->charset.len > 0) { + lua_pushstring(L, "charset"); + lua_pushlstring(L, ct->charset.begin, ct->charset.len); + lua_settable(L, -3); + } + + if (ct->boundary.len > 0) { + lua_pushstring(L, "boundary"); + lua_pushlstring(L, ct->boundary.begin, ct->boundary.len); + lua_settable(L, -3); + } + + if (ct->attrs) { + g_hash_table_iter_init(&it, ct->attrs); + + while (g_hash_table_iter_next(&it, &k, &v)) { + param = v; + + if (param->name.len > 0 && param->value.len > 0) { + /* TODO: think about multiple values here */ + lua_pushlstring(L, param->name.begin, param->name.len); + lua_pushlstring(L, param->value.begin, param->value.len); + lua_settable(L, -3); + } + } + } + + return 3; +} + +static gint +lua_mimepart_get_type(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + return lua_mimepart_get_type_common(L, part->ct, FALSE); +} + +static gint +lua_mimepart_get_type_full(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + return lua_mimepart_get_type_common(L, part->ct, TRUE); +} + +static gint +lua_mimepart_get_detected_type(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + return lua_mimepart_get_type_common(L, part->detected_ct, FALSE); +} + +static gint +lua_mimepart_get_detected_type_full(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + return lua_mimepart_get_type_common(L, part->detected_ct, TRUE); +} + +static gint +lua_mimepart_get_detected_ext(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->detected_ext) { + lua_pushstring(L, part->detected_ext); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_mimepart_get_cte(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + lua_pushnil(L); + return 1; + } + + lua_pushstring(L, rspamd_cte_to_string(part->cte)); + + return 1; +} + +static gint +lua_mimepart_get_filename(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL || part->cd == NULL || part->cd->filename.len == 0) { + lua_pushnil(L); + return 1; + } + + lua_pushlstring(L, part->cd->filename.begin, part->cd->filename.len); + + return 1; +} + +static gint +lua_mimepart_get_boundary(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L), *parent; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (IS_PART_MULTIPART(part)) { + lua_pushlstring(L, part->specific.mp->boundary.begin, + part->specific.mp->boundary.len); + } + else { + parent = part->parent_part; + + if (!parent || !IS_PART_MULTIPART(parent)) { + lua_pushnil(L); + } + else { + lua_pushlstring(L, parent->specific.mp->boundary.begin, + parent->specific.mp->boundary.len); + } + } + + return 1; +} + +static gint +lua_mimepart_get_enclosing_boundary(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L), *parent; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + parent = part->parent_part; + + if (!parent || !IS_PART_MULTIPART(parent)) { + lua_pushnil(L); + } + else { + lua_pushlstring(L, parent->specific.mp->boundary.begin, + parent->specific.mp->boundary.len); + } + + return 1; +} + +static gint +lua_mimepart_get_header_common(lua_State *L, enum rspamd_lua_task_header_type how) +{ + struct rspamd_mime_part *part = lua_check_mimepart(L); + const gchar *name; + gboolean strong = FALSE; + + name = luaL_checkstring(L, 2); + + if (name && part) { + + if (lua_isboolean(L, 3)) { + strong = lua_toboolean(L, 3); + } + + return rspamd_lua_push_header_array(L, + name, + rspamd_message_get_header_from_hash(part->raw_headers, name, FALSE), + how, + strong); + } + + lua_pushnil(L); + + return 1; +} + +static gint +lua_mimepart_get_header_full(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_mimepart_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_FULL); +} + +static gint +lua_mimepart_get_header(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_mimepart_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_SIMPLE); +} + +static gint +lua_mimepart_get_header_raw(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_mimepart_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_RAW); +} + +static gint +lua_mimepart_get_header_count(lua_State *L) +{ + LUA_TRACE_POINT; + return lua_mimepart_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_COUNT); +} + +static gint +lua_mimepart_get_raw_headers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_lua_text *t; + + if (part) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = part->raw_headers_str; + t->len = part->raw_headers_len; + t->flags = 0; + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return 1; +} + +static gint +lua_mimepart_get_headers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + bool need_modified = lua_isnoneornil(L, 2) ? false : lua_toboolean(L, 2); + + if (part) { + struct rspamd_mime_header *cur; + int i = 1; + + lua_createtable(L, rspamd_mime_headers_count(part->raw_headers), 0); + LL_FOREACH2(part->headers_order, cur, ord_next) + { + if (need_modified && cur->modified_chain) { + struct rspamd_mime_header *cur_modified; + + LL_FOREACH(cur->modified_chain, cur_modified) + { + rspamd_lua_push_header(L, cur_modified, RSPAMD_TASK_HEADER_PUSH_FULL); + lua_rawseti(L, -2, i++); + } + } + else { + rspamd_lua_push_header(L, cur, RSPAMD_TASK_HEADER_PUSH_FULL); + lua_rawseti(L, -2, i++); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return 1; +} + + +static gint +lua_mimepart_is_image(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, part->part_type == RSPAMD_MIME_PART_IMAGE); + + return 1; +} + +static gint +lua_mimepart_is_archive(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, part->part_type == RSPAMD_MIME_PART_ARCHIVE); + + return 1; +} + +static gint +lua_mimepart_is_multipart(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, IS_PART_MULTIPART(part) ? true : false); + + return 1; +} + +static gint +lua_mimepart_is_message(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, IS_PART_MESSAGE(part) ? true : false); + + return 1; +} + +static gint +lua_mimepart_is_attachment(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->cd && part->cd->type == RSPAMD_CT_ATTACHMENT) { + lua_pushboolean(L, true); + } + else { + /* if has_name and not (image and Content-ID_header_present) */ + if (part->cd && part->cd->filename.len > 0) { + if (part->part_type != RSPAMD_MIME_PART_IMAGE && + rspamd_message_get_header_from_hash(part->raw_headers, + "Content-Id", FALSE) == NULL) { + /* Filename is presented but no content id and not image */ + lua_pushboolean(L, true); + } + else { + /* Image or an embedded object */ + lua_pushboolean(L, false); + } + } + else { + /* No filename */ + lua_pushboolean(L, false); + } + } + + return 1; +} + +static gint +lua_mimepart_is_text(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, part->part_type == RSPAMD_MIME_PART_TEXT); + + return 1; +} + +static gint +lua_mimepart_is_broken(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->ct) { + lua_pushboolean(L, (part->ct->flags & RSPAMD_CONTENT_TYPE_BROKEN) ? true : false); + } + else { + lua_pushboolean(L, false); + } + + return 1; +} + +static gint +lua_mimepart_get_image(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_image **pimg; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->part_type != RSPAMD_MIME_PART_IMAGE || part->specific.img == NULL) { + lua_pushnil(L); + } + else { + pimg = lua_newuserdata(L, sizeof(*pimg)); + *pimg = part->specific.img; + rspamd_lua_setclass(L, "rspamd{image}", -1); + } + + return 1; +} + +static gint +lua_mimepart_get_archive(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_archive **parch; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->part_type != RSPAMD_MIME_PART_ARCHIVE || part->specific.arch == NULL) { + lua_pushnil(L); + } + else { + parch = lua_newuserdata(L, sizeof(*parch)); + *parch = part->specific.arch; + rspamd_lua_setclass(L, "rspamd{archive}", -1); + } + + return 1; +} + +static gint +lua_mimepart_get_children(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_mime_part **pcur, *cur; + guint i; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (!IS_PART_MULTIPART(part) || part->specific.mp->children == NULL) { + lua_pushnil(L); + } + else { + lua_createtable(L, part->specific.mp->children->len, 0); + + PTR_ARRAY_FOREACH(part->specific.mp->children, i, cur) + { + pcur = lua_newuserdata(L, sizeof(*pcur)); + *pcur = cur; + rspamd_lua_setclass(L, "rspamd{mimepart}", -1); + lua_rawseti(L, -2, i + 1); + } + } + + return 1; +} + +static gint +lua_mimepart_get_parent(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_mime_part **pparent; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->parent_part) { + pparent = lua_newuserdata(L, sizeof(*pparent)); + *pparent = part->parent_part; + rspamd_lua_setclass(L, "rspamd{mimepart}", -1); + } + else { + lua_pushnil(L); + } + + return 1; +} + + +static gint +lua_mimepart_get_text(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + struct rspamd_mime_text_part **ppart; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->part_type != RSPAMD_MIME_PART_TEXT || part->specific.txt == NULL) { + lua_pushnil(L); + } + else { + ppart = lua_newuserdata(L, sizeof(*ppart)); + *ppart = part->specific.txt; + rspamd_lua_setclass(L, "rspamd{textpart}", -1); + } + + return 1; +} + +static gint +lua_mimepart_get_digest(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + gchar digestbuf[rspamd_cryptobox_HASHBYTES * 2 + 1]; + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + memset(digestbuf, 0, sizeof(digestbuf)); + rspamd_encode_hex_buf(part->digest, sizeof(part->digest), + digestbuf, sizeof(digestbuf)); + lua_pushstring(L, digestbuf); + + return 1; +} + +static gint +lua_mimepart_get_id(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_pushinteger(L, part->part_number); + + return 1; +} + +static gint +lua_mimepart_headers_foreach(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + enum rspamd_lua_task_header_type how = RSPAMD_TASK_HEADER_PUSH_SIMPLE; + struct rspamd_lua_regexp *re = NULL; + struct rspamd_mime_header *hdr, *cur; + gint old_top; + + if (part && lua_isfunction(L, 2)) { + if (lua_istable(L, 3)) { + lua_pushstring(L, "full"); + lua_gettable(L, 3); + + if (lua_isboolean(L, -1) && lua_toboolean(L, -1)) { + how = RSPAMD_TASK_HEADER_PUSH_FULL; + } + + lua_pop(L, 1); + + lua_pushstring(L, "raw"); + lua_gettable(L, 3); + + if (lua_isboolean(L, -1) && lua_toboolean(L, -1)) { + how = RSPAMD_TASK_HEADER_PUSH_RAW; + } + + lua_pop(L, 1); + + lua_pushstring(L, "regexp"); + lua_gettable(L, 3); + + if (lua_isuserdata(L, -1)) { + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, -1, "rspamd{regexp}", + struct rspamd_lua_regexp, re); + } + + lua_pop(L, 1); + } + + if (part->headers_order) { + hdr = part->headers_order; + + LL_FOREACH2(hdr, cur, ord_next) + { + if (re && re->re) { + if (!rspamd_regexp_match(re->re, cur->name, + strlen(cur->name), FALSE)) { + continue; + } + } + + old_top = lua_gettop(L); + lua_pushvalue(L, 2); + lua_pushstring(L, cur->name); + rspamd_lua_push_header(L, cur, how); + + if (lua_pcall(L, 2, LUA_MULTRET, 0) != 0) { + msg_err("call to header_foreach failed: %s", + lua_tostring(L, -1)); + lua_settop(L, old_top); + break; + } + else { + if (lua_gettop(L) > old_top) { + if (lua_isboolean(L, old_top + 1)) { + if (lua_toboolean(L, old_top + 1)) { + lua_settop(L, old_top); + break; + } + } + } + } + + lua_settop(L, old_top); + } + } + } + + return 0; +} + +static gint +lua_mimepart_get_specific(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (part->part_type != RSPAMD_MIME_PART_CUSTOM_LUA) { + lua_pushnil(L); + } + else { + lua_rawgeti(L, LUA_REGISTRYINDEX, part->specific.lua_specific.cbref); + } + + return 1; +} + +static gint +lua_mimepart_get_urls(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + struct lua_tree_cb_data cb; + struct rspamd_url *u; + static const gint default_protocols_mask = PROTOCOL_HTTP | PROTOCOL_HTTPS | + PROTOCOL_FILE | PROTOCOL_FTP; + gsize sz, max_urls = 0, i; + + if (part->urls == NULL) { + lua_newtable(L); + + return 1; + } + + if (!lua_url_cbdata_fill(L, 2, &cb, default_protocols_mask, + ~(0), max_urls)) { + return luaL_error(L, "invalid arguments"); + } + + sz = part->urls->len; + + lua_createtable(L, sz, 0); + + PTR_ARRAY_FOREACH(part->urls, i, u) + { + lua_tree_url_callback(u, u, &cb); + } + + lua_url_cbdata_dtor(&cb); + + return 1; +} + +static gint +lua_mimepart_is_specific(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, part->part_type == RSPAMD_MIME_PART_CUSTOM_LUA); + + return 1; +} + +static gint +lua_mimepart_set_specific(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_mime_part *part = lua_check_mimepart(L); + + if (part == NULL || lua_isnil(L, 2)) { + return luaL_error(L, "invalid arguments"); + } + + if (part->part_type != RSPAMD_MIME_PART_UNDEFINED && + part->part_type != RSPAMD_MIME_PART_CUSTOM_LUA) { + return luaL_error(L, + "internal error: trying to set specific lua content on part of type %d", + part->part_type); + } + + if (part->part_type == RSPAMD_MIME_PART_CUSTOM_LUA) { + /* Push old specific data */ + lua_rawgeti(L, LUA_REGISTRYINDEX, part->specific.lua_specific.cbref); + luaL_unref(L, LUA_REGISTRYINDEX, part->specific.lua_specific.cbref); + } + else { + part->part_type = RSPAMD_MIME_PART_CUSTOM_LUA; + lua_pushnil(L); + } + + /* Now, we push argument on the position 2 and save its reference */ + lua_pushvalue(L, 2); + part->specific.lua_specific.cbref = luaL_ref(L, LUA_REGISTRYINDEX); + /* Now stack has just a return value as luaL_ref removes value from stack */ + + gint ltype = lua_type(L, 2); + + switch (ltype) { + case LUA_TTABLE: + part->specific.lua_specific.type = RSPAMD_LUA_PART_TABLE; + break; + case LUA_TSTRING: + part->specific.lua_specific.type = RSPAMD_LUA_PART_STRING; + break; + case LUA_TUSERDATA: + if (rspamd_lua_check_udata_maybe(L, 2, "rspamd{text}")) { + part->specific.lua_specific.type = RSPAMD_LUA_PART_TEXT; + } + else { + part->specific.lua_specific.type = RSPAMD_LUA_PART_UNKNOWN; + } + break; + case LUA_TFUNCTION: + part->specific.lua_specific.type = RSPAMD_LUA_PART_FUNCTION; + break; + default: + part->specific.lua_specific.type = RSPAMD_LUA_PART_UNKNOWN; + break; + } + + return 1; +} + +void luaopen_textpart(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{textpart}", textpartlib_m); + lua_pop(L, 1); +} + +void luaopen_mimepart(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{mimepart}", mimepartlib_m); + lua_pop(L, 1); +} diff --git a/src/lua/lua_parsers.c b/src/lua/lua_parsers.c new file mode 100644 index 0000000..1fc71db --- /dev/null +++ b/src/lua/lua_parsers.c @@ -0,0 +1,410 @@ +/*- + * Copyright 2020 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "tokenizers/tokenizers.h" +#include "contrib/uthash/utlist.h" +#include "libserver/html/html.h" +#include "libmime/email_addr.h" +#include "libmime/content_type.h" +#include "libmime/mime_headers.h" +#include "libmime/smtp_parsers.h" +#include "lua_parsers.h" + +/*** + * @module rspamd_parsers + * This module contains Lua-C interfaces to Rspamd parsers of different kind. + */ + +/*** + * @function parsers.tokenize_text(input[, exceptions]) + * Create tokens from a text using optional exceptions list + * @param {text/string} input input data + * @param {table} exceptions, a table of pairs containing <start_pos,length> of exceptions in the input + * @return {table/strings} list of strings representing words in the text + */ + + +/*** + * @function parsers.parse_html(input) + * Parses HTML and returns the according text + * @param {string|text} in input HTML + * @return {rspamd_text} processed text with no HTML tags + */ + +/*** + * @function parsers.parse_mail_address(str, [pool]) + * Parses email address and returns a table of tables in the following format: + * + * - `raw` - the original value without any processing + * - `name` - name of internet address in UTF8, e.g. for `Vsevolod Stakhov <blah@foo.com>` it returns `Vsevolod Stakhov` + * - `addr` - address part of the address + * - `user` - user part (if present) of the address, e.g. `blah` + * - `domain` - domain part (if present), e.g. `foo.com` + * - `flags` - table with following keys set to true if given condition fulfilled: + * - [valid] - valid SMTP address in conformity with https://tools.ietf.org/html/rfc5321#section-4.1. + * - [ip] - domain is IPv4/IPv6 address + * - [braced] - angled `<blah@foo.com>` address + * - [quoted] - quoted user part + * - [empty] - empty address + * - [backslash] - user part contains backslash + * - [8bit] - contains 8bit characters + * + * @param {string} str input string + * @param {rspamd_mempool} pool memory pool to use + * @return {table/tables} parsed list of mail addresses + */ + +/*** + * @function parsers.parse_content_type(ct_string, mempool) + * Parses content-type string to a table: + * - `type` + * - `subtype` + * - `charset` + * - `boundary` + * - other attributes + * + * @param {string} ct_string content type as string + * @param {rspamd_mempool} mempool needed to store temporary data (e.g. task pool) + * @return table or nil if cannot parse content type + */ + +/*** + * @function parsers.parse_smtp_date(str[, local_tz]) + * Converts an SMTP date string to unix timestamp + * @param {string} str input string + * @param {boolean} local_tz convert to local tz if `true` + * @return {number} time as unix timestamp (converted to float) + */ + +static const struct luaL_reg parserslib_f[] = { + LUA_INTERFACE_DEF(parsers, tokenize_text), + LUA_INTERFACE_DEF(parsers, parse_html), + LUA_INTERFACE_DEF(parsers, parse_mail_address), + LUA_INTERFACE_DEF(parsers, parse_content_type), + LUA_INTERFACE_DEF(parsers, parse_smtp_date), + + {NULL, NULL}}; + +gint lua_parsers_tokenize_text(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *in = NULL; + gsize len = 0, pos, ex_len, i; + GList *exceptions = NULL, *cur; + struct rspamd_lua_text *t; + struct rspamd_process_exception *ex; + UText utxt = UTEXT_INITIALIZER; + GArray *res; + rspamd_stat_token_t *w; + + if (lua_type(L, 1) == LUA_TSTRING) { + in = luaL_checklstring(L, 1, &len); + } + else if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (t) { + in = t->start; + len = t->len; + } + } + + if (in == NULL) { + lua_pushnil(L); + return 1; + } + + if (lua_gettop(L) > 1 && lua_type(L, 2) == LUA_TTABLE) { + lua_pushvalue(L, 2); + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + if (lua_type(L, -1) == LUA_TTABLE) { + lua_rawgeti(L, -1, 1); + pos = luaL_checknumber(L, -1); + lua_pop(L, 1); + lua_rawgeti(L, -1, 2); + ex_len = luaL_checknumber(L, -1); + lua_pop(L, 1); + + if (ex_len > 0) { + ex = g_malloc0(sizeof(*ex)); + ex->pos = pos; + ex->len = ex_len; + ex->type = RSPAMD_EXCEPTION_GENERIC; + exceptions = g_list_prepend(exceptions, ex); + } + } + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + if (exceptions) { + exceptions = g_list_reverse(exceptions); + } + + UErrorCode uc_err = U_ZERO_ERROR; + utext_openUTF8(&utxt, + in, + len, + &uc_err); + + res = rspamd_tokenize_text((gchar *) in, len, + &utxt, + RSPAMD_TOKENIZE_UTF, NULL, + exceptions, + NULL, NULL, NULL); + + if (res == NULL) { + lua_pushnil(L); + } + else { + lua_createtable(L, res->len, 0); + + for (i = 0; i < res->len; i++) { + w = &g_array_index(res, rspamd_stat_token_t, i); + lua_pushlstring(L, w->original.begin, w->original.len); + lua_rawseti(L, -2, i + 1); + } + } + + cur = exceptions; + while (cur) { + ex = cur->data; + g_free(ex); + cur = g_list_next(cur); + } + + g_list_free(exceptions); + utext_close(&utxt); + + return 1; +} + +gint lua_parsers_parse_html(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + const gchar *start = NULL; + gsize len; + GByteArray *in; + rspamd_mempool_t *pool; + void *hc; + + if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (t != NULL) { + start = t->start; + len = t->len; + } + } + else if (lua_type(L, 1) == LUA_TSTRING) { + start = luaL_checklstring(L, 1, &len); + } + + if (start != NULL) { + pool = rspamd_mempool_new(rspamd_mempool_suggest_size(), NULL, 0); + in = g_byte_array_sized_new(len); + g_byte_array_append(in, start, len); + + hc = rspamd_html_process_part(pool, in); + + rspamd_ftok_t res; + rspamd_html_get_parsed_content(hc, &res); + lua_new_text(L, res.begin, res.len, TRUE); + + g_byte_array_free(in, TRUE); + rspamd_mempool_delete(pool); + } + else { + lua_pushnil(L); + } + + return 1; +} + +gint lua_parsers_parse_mail_address(lua_State *L) +{ + LUA_TRACE_POINT; + GPtrArray *addrs; + gsize len; + const gchar *str = luaL_checklstring(L, 1, &len); + gint max_addrs = luaL_optinteger(L, 3, 10240); + rspamd_mempool_t *pool; + gboolean own_pool = FALSE; + + if (str) { + + if (lua_type(L, 2) == LUA_TUSERDATA) { + pool = rspamd_lua_check_mempool(L, 2); + + if (pool == NULL) { + return luaL_error(L, "invalid arguments"); + } + } + else { + pool = rspamd_mempool_new(rspamd_mempool_suggest_size(), + "lua parsers", 0); + own_pool = TRUE; + } + + addrs = rspamd_email_address_from_mime(pool, str, len, NULL, max_addrs); + + if (addrs == NULL) { + lua_pushnil(L); + } + else { + lua_push_emails_address_list(L, addrs, 0); + } + + if (own_pool) { + rspamd_mempool_delete(pool); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +gint lua_parsers_parse_content_type(lua_State *L) +{ + LUA_TRACE_POINT; + gsize len; + const gchar *ct_str = luaL_checklstring(L, 1, &len); + rspamd_mempool_t *pool = rspamd_lua_check_mempool(L, 2); + struct rspamd_content_type *ct; + + if (!ct_str || !pool) { + return luaL_error(L, "invalid arguments"); + } + + ct = rspamd_content_type_parse(ct_str, len, pool); + + if (ct == NULL) { + lua_pushnil(L); + } + else { + GHashTableIter it; + gpointer k, v; + + lua_createtable(L, 0, 4 + (ct->attrs ? g_hash_table_size(ct->attrs) : 0)); + + if (ct->type.len > 0) { + lua_pushstring(L, "type"); + lua_pushlstring(L, ct->type.begin, ct->type.len); + lua_settable(L, -3); + } + + if (ct->subtype.len > 0) { + lua_pushstring(L, "subtype"); + lua_pushlstring(L, ct->subtype.begin, ct->subtype.len); + lua_settable(L, -3); + } + + if (ct->charset.len > 0) { + lua_pushstring(L, "charset"); + lua_pushlstring(L, ct->charset.begin, ct->charset.len); + lua_settable(L, -3); + } + + if (ct->orig_boundary.len > 0) { + lua_pushstring(L, "boundary"); + lua_pushlstring(L, ct->orig_boundary.begin, ct->orig_boundary.len); + lua_settable(L, -3); + } + + if (ct->attrs) { + g_hash_table_iter_init(&it, ct->attrs); + + while (g_hash_table_iter_next(&it, &k, &v)) { + struct rspamd_content_type_param *param = + (struct rspamd_content_type_param *) v, + *cur; + guint i = 1; + + lua_pushlstring(L, param->name.begin, param->name.len); + lua_createtable(L, 1, 0); + + DL_FOREACH(param, cur) + { + lua_pushlstring(L, cur->value.begin, cur->value.len); + lua_rawseti(L, -2, i++); + } + + lua_settable(L, -3); + } + } + } + + return 1; +} + +int lua_parsers_parse_smtp_date(lua_State *L) +{ + gsize slen; + const gchar *str = lua_tolstring(L, 1, &slen); + GError *err = NULL; + + if (str == NULL) { + return luaL_argerror(L, 1, "invalid argument"); + } + + time_t tt = rspamd_parse_smtp_date(str, slen, &err); + + if (err == NULL) { + if (lua_isboolean(L, 2) && !!lua_toboolean(L, 2)) { + struct tm t; + + rspamd_localtime(tt, &t); +#if !defined(__sun) + t.tm_gmtoff = 0; +#endif + t.tm_isdst = 0; + tt = mktime(&t); + } + + lua_pushnumber(L, tt); + } + else { + lua_pushnil(L); + lua_pushstring(L, err->message); + g_error_free(err); + + return 2; + } + + return 1; +} + +static gint +lua_load_parsers(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, parserslib_f); + + return 1; +} + +void luaopen_parsers(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_parsers", lua_load_parsers); +}
\ No newline at end of file diff --git a/src/lua/lua_parsers.h b/src/lua/lua_parsers.h new file mode 100644 index 0000000..2466938 --- /dev/null +++ b/src/lua/lua_parsers.h @@ -0,0 +1,88 @@ +/*- + * Copyright 2020 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSPAMD_LUA_PARSERS_H +#define RSPAMD_LUA_PARSERS_H + +#include "lua_common.h" + +/*** + * @function parsers.tokenize_text(input[, exceptions]) + * Create tokens from a text using optional exceptions list + * @param {text/string} input input data + * @param {table} exceptions, a table of pairs containing <start_pos,length> of exceptions in the input + * @return {table/strings} list of strings representing words in the text + */ +LUA_PUBLIC_FUNCTION_DEF(parsers, tokenize_text); + +/*** + * @function parsers.parse_html(input) + * Parses HTML and returns the according text + * @param {string|text} in input HTML + * @return {rspamd_text} processed text with no HTML tags + */ +LUA_PUBLIC_FUNCTION_DEF(parsers, parse_html); + +/*** + * @function parsers.parse_mail_address(str, [pool]) + * Parses email address and returns a table of tables in the following format: + * + * - `raw` - the original value without any processing + * - `name` - name of internet address in UTF8, e.g. for `Vsevolod Stakhov <blah@foo.com>` it returns `Vsevolod Stakhov` + * - `addr` - address part of the address + * - `user` - user part (if present) of the address, e.g. `blah` + * - `domain` - domain part (if present), e.g. `foo.com` + * - `flags` - table with following keys set to true if given condition fulfilled: + * - [valid] - valid SMTP address in conformity with https://tools.ietf.org/html/rfc5321#section-4.1. + * - [ip] - domain is IPv4/IPv6 address + * - [braced] - angled `<blah@foo.com>` address + * - [quoted] - quoted user part + * - [empty] - empty address + * - [backslash] - user part contains backslash + * - [8bit] - contains 8bit characters + * + * @param {string} str input string + * @param {rspamd_mempool} pool memory pool to use + * @return {table/tables} parsed list of mail addresses + */ +LUA_PUBLIC_FUNCTION_DEF(parsers, parse_mail_address); + +/*** + * @function parsers.parse_content_type(ct_string, mempool) + * Parses content-type string to a table: + * - `type` + * - `subtype` + * - `charset` + * - `boundary` + * - other attributes + * + * @param {string} ct_string content type as string + * @param {rspamd_mempool} mempool needed to store temporary data (e.g. task pool) + * @return table or nil if cannot parse content type + */ +LUA_PUBLIC_FUNCTION_DEF(parsers, parse_content_type); + +/*** + * @function parsers.parse_smtp_date(str[, local_tz]) + * Converts an SMTP date string to unix timestamp + * @param {string} str input string + * @param {boolean} local_tz convert to local tz if `true` + * @return {number} time as unix timestamp (converted to float) + */ +LUA_PUBLIC_FUNCTION_DEF(parsers, parse_smtp_date); + + +#endif//RSPAMD_LUA_PARSERS_H diff --git a/src/lua/lua_redis.c b/src/lua/lua_redis.c new file mode 100644 index 0000000..1ad3b3d --- /dev/null +++ b/src/lua/lua_redis.c @@ -0,0 +1,1662 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_thread_pool.h" +#include "utlist.h" + +#include "contrib/hiredis/hiredis.h" +#include "contrib/hiredis/async.h" + +#define REDIS_DEFAULT_TIMEOUT 1.0 + +static const gchar *M = "rspamd lua redis"; +static void *redis_null; + +/*** + * @module rspamd_redis + * This module implements redis asynchronous client for rspamd LUA API. + * Here is an example of using of this module: + * @example +local rspamd_redis = require "rspamd_redis" +local rspamd_logger = require "rspamd_logger" + +local function symbol_callback(task) + local redis_key = 'some_key' + local function redis_cb(err, data) + if not err then + rspamd_logger.infox('redis returned %1=%2', redis_key, data) + end + end + + rspamd_redis.make_request(task, "127.0.0.1:6379", redis_cb, + 'GET', {redis_key}) + -- or in table form: + -- rspamd_redis.make_request({task=task, host="127.0.0.1:6379, + -- callback=redis_cb, timeout=2.0, cmd='GET', args={redis_key}}) +end + */ + +LUA_FUNCTION_DEF(redis, make_request); +LUA_FUNCTION_DEF(redis, make_request_sync); +LUA_FUNCTION_DEF(redis, connect); +LUA_FUNCTION_DEF(redis, connect_sync); +LUA_FUNCTION_DEF(redis, add_cmd); +LUA_FUNCTION_DEF(redis, exec); +LUA_FUNCTION_DEF(redis, gc); + +static const struct luaL_reg redislib_f[] = { + LUA_INTERFACE_DEF(redis, make_request), + LUA_INTERFACE_DEF(redis, make_request_sync), + LUA_INTERFACE_DEF(redis, connect), + LUA_INTERFACE_DEF(redis, connect_sync), + {NULL, NULL}}; + +static const struct luaL_reg redislib_m[] = { + LUA_INTERFACE_DEF(redis, add_cmd), + LUA_INTERFACE_DEF(redis, exec), + {"__gc", lua_redis_gc}, + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +#undef REDIS_DEBUG_REFS +#ifdef REDIS_DEBUG_REFS +#define REDIS_RETAIN(x) \ + do { \ + msg_err("retain ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RETAIN(x); \ + } while (0) + +#define REDIS_RELEASE(x) \ + do { \ + msg_err("release ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RELEASE(x); \ + } while (0) +#else +#define REDIS_RETAIN REF_RETAIN +#define REDIS_RELEASE REF_RELEASE +#endif + +struct lua_redis_request_specific_userdata; +/** + * Struct for userdata representation + */ +struct lua_redis_userdata { + redisAsyncContext *ctx; + struct rspamd_task *task; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_async_session *s; + struct ev_loop *event_loop; + struct rspamd_config *cfg; + struct rspamd_redis_pool *pool; + gchar *server; + gchar log_tag[RSPAMD_LOG_ID_LEN + 1]; + struct lua_redis_request_specific_userdata *specific; + gdouble timeout; + guint16 port; + guint16 terminated; +}; + +#define msg_debug_lua_redis(...) rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_lua_redis_log_id, "lua_redis", ud->log_tag, \ + G_STRFUNC, \ + __VA_ARGS__) +INIT_LOG_MODULE(lua_redis) + +#define LUA_REDIS_SPECIFIC_REPLIED (1 << 0) +/* session was finished */ +#define LUA_REDIS_SPECIFIC_FINISHED (1 << 1) +#define LUA_REDIS_ASYNC (1 << 0) +#define LUA_REDIS_TEXTDATA (1 << 1) +#define LUA_REDIS_TERMINATED (1 << 2) +#define LUA_REDIS_NO_POOL (1 << 3) +#define LUA_REDIS_SUBSCRIBED (1 << 4) +#define IS_ASYNC(ctx) ((ctx)->flags & LUA_REDIS_ASYNC) + +struct lua_redis_request_specific_userdata { + gint cbref; + guint nargs; + gchar **args; + gsize *arglens; + struct lua_redis_userdata *c; + struct lua_redis_ctx *ctx; + struct lua_redis_request_specific_userdata *next; + ev_timer timeout_ev; + guint flags; +}; + +struct lua_redis_ctx { + guint flags; + struct lua_redis_userdata async; + guint cmds_pending; + ref_entry_t ref; + GQueue *replies; /* for sync connection only */ + GQueue *events_cleanup; /* for sync connection only */ + struct thread_entry *thread; /* for sync mode, set only if there was yield */ +}; + +struct lua_redis_result { + gboolean is_error; + gint result_ref; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_async_session *s; + struct rspamd_task *task; + struct lua_redis_request_specific_userdata *sp_ud; +}; + +static struct lua_redis_ctx * +lua_check_redis(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{redis}"); + luaL_argcheck(L, ud != NULL, pos, "'redis' expected"); + return ud ? *((struct lua_redis_ctx **) ud) : NULL; +} + +static void +lua_redis_free_args(char **args, gsize *arglens, guint nargs) +{ + guint i; + + if (args) { + for (i = 0; i < nargs; i++) { + g_free(args[i]); + } + + g_free(args); + g_free(arglens); + } +} + +static void +lua_redis_dtor(struct lua_redis_ctx *ctx) +{ + struct lua_redis_userdata *ud; + struct lua_redis_request_specific_userdata *cur, *tmp; + gboolean is_successful = TRUE; + struct redisAsyncContext *ac; + + ud = &ctx->async; + msg_debug_lua_redis("destructing %p", ctx); + + if (ud->ctx) { + + LL_FOREACH_SAFE(ud->specific, cur, tmp) + { + ev_timer_stop(ud->event_loop, &cur->timeout_ev); + + if (!(cur->flags & LUA_REDIS_SPECIFIC_REPLIED)) { + is_successful = FALSE; + } + + cur->flags |= LUA_REDIS_SPECIFIC_FINISHED; + } + + ctx->flags |= LUA_REDIS_TERMINATED; + + ud->terminated = 1; + ac = ud->ctx; + ud->ctx = NULL; + + if (!is_successful) { + rspamd_redis_pool_release_connection(ud->pool, ac, + RSPAMD_REDIS_RELEASE_FATAL); + } + else { + rspamd_redis_pool_release_connection(ud->pool, ac, + (ctx->flags & LUA_REDIS_NO_POOL) ? RSPAMD_REDIS_RELEASE_ENFORCE : RSPAMD_REDIS_RELEASE_DEFAULT); + } + } + + LL_FOREACH_SAFE(ud->specific, cur, tmp) + { + lua_redis_free_args(cur->args, cur->arglens, cur->nargs); + + if (cur->cbref != -1) { + luaL_unref(ud->cfg->lua_state, LUA_REGISTRYINDEX, cur->cbref); + } + + g_free(cur); + } + + if (ctx->events_cleanup) { + g_queue_free(ctx->events_cleanup); + ctx->events_cleanup = NULL; + } + if (ctx->replies) { + g_queue_free(ctx->replies); + ctx->replies = NULL; + } + + g_free(ctx); +} + +static gint +lua_redis_gc(lua_State *L) +{ + struct lua_redis_ctx *ctx = lua_check_redis(L, 1); + + if (ctx) { + REDIS_RELEASE(ctx); + } + + return 0; +} + +static void +lua_redis_fin(void *arg) +{ + struct lua_redis_request_specific_userdata *sp_ud = arg; + struct lua_redis_userdata *ud; + struct lua_redis_ctx *ctx; + + ctx = sp_ud->ctx; + ud = sp_ud->c; + + if (ev_can_stop(&sp_ud->timeout_ev)) { + ev_timer_stop(sp_ud->ctx->async.event_loop, &sp_ud->timeout_ev); + } + + msg_debug_lua_redis("finished redis query %p from session %p; refcount=%d", + sp_ud, ctx, ctx->ref.refcount); + sp_ud->flags |= LUA_REDIS_SPECIFIC_FINISHED; + + REDIS_RELEASE(ctx); +} + +/** + * Push error of redis request to lua callback + * @param code + * @param ud + */ +static void +lua_redis_push_error(const gchar *err, + struct lua_redis_ctx *ctx, + struct lua_redis_request_specific_userdata *sp_ud, + gboolean connected) +{ + struct lua_redis_userdata *ud = sp_ud->c; + struct lua_callback_state cbs; + lua_State *L; + + if (!(sp_ud->flags & (LUA_REDIS_SPECIFIC_REPLIED | LUA_REDIS_SPECIFIC_FINISHED))) { + if (sp_ud->cbref != -1) { + + lua_thread_pool_prepare_callback(ud->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + int err_idx = lua_gettop(L); + /* Push error */ + lua_rawgeti(cbs.L, LUA_REGISTRYINDEX, sp_ud->cbref); + + /* String of error */ + lua_pushstring(cbs.L, err); + /* Data is nil */ + lua_pushnil(cbs.L); + + if (ud->item) { + rspamd_symcache_set_cur_item(ud->task, ud->item); + } + + if (lua_pcall(cbs.L, 2, 0, err_idx) != 0) { + msg_info("call to callback failed: %s", lua_tostring(cbs.L, -1)); + } + + lua_settop(L, err_idx - 1); + lua_thread_pool_restore_callback(&cbs); + } + + sp_ud->flags |= LUA_REDIS_SPECIFIC_REPLIED; + + if (connected && ud->s) { + if (ud->item) { + rspamd_symcache_item_async_dec_check(ud->task, ud->item, M); + } + + rspamd_session_remove_event(ud->s, lua_redis_fin, sp_ud); + } + else { + lua_redis_fin(sp_ud); + } + } +} + +static void +lua_redis_push_reply(lua_State *L, const redisReply *r, gboolean text_data) +{ + guint i; + struct rspamd_lua_text *t; + + switch (r->type) { + case REDIS_REPLY_INTEGER: + lua_pushinteger(L, r->integer); + break; + case REDIS_REPLY_NIL: + lua_getfield(L, LUA_REGISTRYINDEX, "redis.null"); + break; + case REDIS_REPLY_STRING: + case REDIS_REPLY_STATUS: + if (text_data) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->flags = 0; + t->start = r->str; + t->len = r->len; + } + else { + lua_pushlstring(L, r->str, r->len); + } + break; + case REDIS_REPLY_ARRAY: + lua_createtable(L, r->elements, 0); + for (i = 0; i < r->elements; ++i) { + lua_redis_push_reply(L, r->element[i], text_data); + lua_rawseti(L, -2, i + 1); /* Store sub-reply */ + } + break; + default: /* should not happen */ + msg_info("unknown reply type: %d", r->type); + break; + } +} + +/** + * Push data of redis request to lua callback + * @param r redis reply data + * @param ud + */ +static void +lua_redis_push_data(const redisReply *r, struct lua_redis_ctx *ctx, + struct lua_redis_request_specific_userdata *sp_ud) +{ + struct lua_redis_userdata *ud = sp_ud->c; + struct lua_callback_state cbs; + lua_State *L; + + if (!(sp_ud->flags & (LUA_REDIS_SPECIFIC_REPLIED | LUA_REDIS_SPECIFIC_FINISHED)) || + (sp_ud->flags & LUA_REDIS_SUBSCRIBED)) { + if (sp_ud->cbref != -1) { + lua_thread_pool_prepare_callback(ud->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + lua_pushcfunction(L, &rspamd_lua_traceback); + int err_idx = lua_gettop(L); + /* Push error */ + lua_rawgeti(cbs.L, LUA_REGISTRYINDEX, sp_ud->cbref); + /* Error is nil */ + lua_pushnil(cbs.L); + /* Data */ + lua_redis_push_reply(cbs.L, r, ctx->flags & LUA_REDIS_TEXTDATA); + + if (ud->item) { + rspamd_symcache_set_cur_item(ud->task, ud->item); + } + + gint ret = lua_pcall(cbs.L, 2, 0, err_idx); + + if (ret != 0) { + msg_info("call to lua_redis callback failed (%d): %s", + ret, lua_tostring(cbs.L, -1)); + } + + lua_settop(L, err_idx - 1); + lua_thread_pool_restore_callback(&cbs); + } + + if (sp_ud->flags & LUA_REDIS_SUBSCRIBED) { + if (!(sp_ud->flags & LUA_REDIS_SPECIFIC_REPLIED)) { + if (ev_can_stop(&sp_ud->timeout_ev)) { + ev_timer_stop(sp_ud->ctx->async.event_loop, + &sp_ud->timeout_ev); + } + } + } + + sp_ud->flags |= LUA_REDIS_SPECIFIC_REPLIED; + + if (!(sp_ud->flags & LUA_REDIS_SUBSCRIBED)) { + if (ud->s) { + if (ud->item) { + rspamd_symcache_item_async_dec_check(ud->task, + ud->item, M); + } + + rspamd_session_remove_event(ud->s, lua_redis_fin, sp_ud); + } + else { + lua_redis_fin(sp_ud); + } + } + } +} + +/** + * Callback for redis replies + * @param c context of redis connection + * @param r redis reply + * @param priv userdata + */ +static void +lua_redis_callback(redisAsyncContext *c, gpointer r, gpointer priv) +{ + redisReply *reply = r; + struct lua_redis_request_specific_userdata *sp_ud = priv; + struct lua_redis_ctx *ctx; + struct lua_redis_userdata *ud; + redisAsyncContext *ac; + + ctx = sp_ud->ctx; + ud = sp_ud->c; + + if (ud->terminated || !rspamd_lua_is_initialised()) { + /* We are already at the termination stage, just go out */ + return; + } + + msg_debug_lua_redis("got reply from redis %p for query %p", sp_ud->c->ctx, + sp_ud); + + REDIS_RETAIN(ctx); + + /* If session is finished, we cannot call lua callbacks */ + if (!(sp_ud->flags & LUA_REDIS_SPECIFIC_FINISHED) || + (sp_ud->flags & LUA_REDIS_SUBSCRIBED)) { + if (c->err == 0) { + if (r != NULL) { + if (reply->type != REDIS_REPLY_ERROR) { + lua_redis_push_data(reply, ctx, sp_ud); + } + else { + lua_redis_push_error(reply->str, ctx, sp_ud, TRUE); + } + } + else { + lua_redis_push_error("received no data from server", ctx, sp_ud, TRUE); + } + } + else { + if (c->err == REDIS_ERR_IO) { + lua_redis_push_error(strerror(errno), ctx, sp_ud, TRUE); + } + else { + lua_redis_push_error(c->errstr, ctx, sp_ud, TRUE); + } + } + } + + if (!(sp_ud->flags & LUA_REDIS_SUBSCRIBED)) { + ctx->cmds_pending--; + + if (ctx->cmds_pending == 0 && !ud->terminated) { + /* Disconnect redis early as we don't need it anymore */ + ud->terminated = 1; + ac = ud->ctx; + ud->ctx = NULL; + + if (ac) { + msg_debug_lua_redis("release redis connection ud=%p; ctx=%p; refcount=%d", + ud, ctx, ctx->ref.refcount); + rspamd_redis_pool_release_connection(ud->pool, ac, + (ctx->flags & LUA_REDIS_NO_POOL) ? RSPAMD_REDIS_RELEASE_ENFORCE : RSPAMD_REDIS_RELEASE_DEFAULT); + } + } + } + + REDIS_RELEASE(ctx); +} + +static gint +lua_redis_push_results(struct lua_redis_ctx *ctx, lua_State *L) +{ + gint results = g_queue_get_length(ctx->replies); + gint i; + gboolean can_use_lua = TRUE; + + results = g_queue_get_length(ctx->replies); + + if (!lua_checkstack(L, (results * 2) + 1)) { + luaL_error(L, "cannot resize stack to fit %d commands", + ctx->cmds_pending); + + can_use_lua = FALSE; + } + + for (i = 0; i < results; i++) { + struct lua_redis_result *result = g_queue_pop_head(ctx->replies); + + if (can_use_lua) { + lua_pushboolean(L, !result->is_error); + lua_rawgeti(L, LUA_REGISTRYINDEX, result->result_ref); + } + + luaL_unref(L, LUA_REGISTRYINDEX, result->result_ref); + + g_queue_push_tail(ctx->events_cleanup, result); + } + + return can_use_lua ? results * 2 : 0; +} + +static void +lua_redis_cleanup_events(struct lua_redis_ctx *ctx) +{ + REDIS_RETAIN(ctx); /* To avoid preliminary destruction */ + + while (!g_queue_is_empty(ctx->events_cleanup)) { + struct lua_redis_result *result = g_queue_pop_head(ctx->events_cleanup); + + if (result->item) { + rspamd_symcache_item_async_dec_check(result->task, result->item, M); + } + + if (result->s) { + rspamd_session_remove_event(result->s, lua_redis_fin, result->sp_ud); + } + else { + lua_redis_fin(result->sp_ud); + } + + g_free(result); + } + + REDIS_RELEASE(ctx); +} + +/** + * Callback for redis replies + * @param c context of redis connection + * @param r redis reply + * @param priv userdata + */ +static void +lua_redis_callback_sync(redisAsyncContext *ac, gpointer r, gpointer priv) +{ + redisReply *reply = r; + + struct lua_redis_request_specific_userdata *sp_ud = priv; + struct lua_redis_ctx *ctx; + struct lua_redis_userdata *ud; + struct thread_entry *thread; + gint results; + + ctx = sp_ud->ctx; + ud = sp_ud->c; + lua_State *L = ctx->async.cfg->lua_state; + + sp_ud->flags |= LUA_REDIS_SPECIFIC_REPLIED; + + if (ud->terminated) { + /* We are already at the termination stage, just go out */ + /* TODO: + if somebody is waiting for us (ctx->thread), return result, + otherwise, indeed, ignore + */ + return; + } + + if (ev_can_stop(&sp_ud->timeout_ev)) { + ev_timer_stop(ud->event_loop, &sp_ud->timeout_ev); + } + + if (!(sp_ud->flags & LUA_REDIS_SPECIFIC_FINISHED)) { + msg_debug_lua_redis("got reply from redis: %p for query %p", ac, sp_ud); + + struct lua_redis_result *result = g_malloc0(sizeof *result); + + if (ac->err == 0) { + if (r != NULL) { + if (reply->type != REDIS_REPLY_ERROR) { + result->is_error = FALSE; + lua_redis_push_reply(L, reply, ctx->flags & LUA_REDIS_TEXTDATA); + } + else { + result->is_error = TRUE; + lua_pushstring(L, reply->str); + } + } + else { + result->is_error = TRUE; + lua_pushliteral(L, "received no data from server"); + } + } + else { + result->is_error = TRUE; + if (ac->err == REDIS_ERR_IO) { + lua_pushstring(L, strerror(errno)); + } + else { + lua_pushstring(L, ac->errstr); + } + } + + /* if error happened, we should terminate the connection, + and release it */ + + if (result->is_error && sp_ud->c->ctx) { + ac = sp_ud->c->ctx; + /* Set to NULL to avoid double free in dtor */ + sp_ud->c->ctx = NULL; + ctx->flags |= LUA_REDIS_TERMINATED; + + /* + * This will call all callbacks pending so the entire context + * will be destructed + */ + rspamd_redis_pool_release_connection(sp_ud->c->pool, ac, + RSPAMD_REDIS_RELEASE_FATAL); + } + + result->result_ref = luaL_ref(L, LUA_REGISTRYINDEX); + result->s = ud->s; + result->item = ud->item; + result->task = ud->task; + result->sp_ud = sp_ud; + + g_queue_push_tail(ctx->replies, result); + } + + ctx->cmds_pending--; + + if (ctx->cmds_pending == 0) { + if (ctx->thread) { + if (!(sp_ud->flags & LUA_REDIS_SPECIFIC_FINISHED)) { + /* somebody yielded and waits for results */ + thread = ctx->thread; + ctx->thread = NULL; + + results = lua_redis_push_results(ctx, thread->lua_state); + + if (ud->item) { + rspamd_symcache_set_cur_item(ud->task, ud->item); + } + + lua_thread_resume(thread, results); + lua_redis_cleanup_events(ctx); + } + else { + /* We cannot resume the thread as the associated task has gone */ + lua_thread_pool_terminate_entry_full(ud->cfg->lua_thread_pool, + ctx->thread, G_STRLOC, true); + ctx->thread = NULL; + } + } + } +} + +static void +lua_redis_timeout_sync(EV_P_ ev_timer *w, int revents) +{ + struct lua_redis_request_specific_userdata *sp_ud = + (struct lua_redis_request_specific_userdata *) w->data; + struct lua_redis_ctx *ctx; + struct lua_redis_userdata *ud; + redisAsyncContext *ac; + + if (sp_ud->flags & LUA_REDIS_SPECIFIC_FINISHED) { + return; + } + + ud = sp_ud->c; + ctx = sp_ud->ctx; + msg_debug_lua_redis("timeout while querying redis server: %p, redis: %p", sp_ud, + sp_ud->c->ctx); + + if (sp_ud->c->ctx) { + ac = sp_ud->c->ctx; + + /* Set to NULL to avoid double free in dtor */ + sp_ud->c->ctx = NULL; + ac->err = REDIS_ERR_IO; + errno = ETIMEDOUT; + ctx->flags |= LUA_REDIS_TERMINATED; + + /* + * This will call all callbacks pending so the entire context + * will be destructed + */ + rspamd_redis_pool_release_connection(sp_ud->c->pool, ac, + RSPAMD_REDIS_RELEASE_FATAL); + } +} + +static void +lua_redis_timeout(EV_P_ ev_timer *w, int revents) +{ + struct lua_redis_request_specific_userdata *sp_ud = + (struct lua_redis_request_specific_userdata *) w->data; + struct lua_redis_userdata *ud; + struct lua_redis_ctx *ctx; + redisAsyncContext *ac; + + if (sp_ud->flags & LUA_REDIS_SPECIFIC_FINISHED) { + return; + } + + ctx = sp_ud->ctx; + ud = sp_ud->c; + + REDIS_RETAIN(ctx); + msg_debug_lua_redis("timeout while querying redis server: %p, redis: %p", sp_ud, + sp_ud->c->ctx); + lua_redis_push_error("timeout while connecting the server", ctx, sp_ud, TRUE); + + if (sp_ud->c->ctx) { + ac = sp_ud->c->ctx; + /* Set to NULL to avoid double free in dtor */ + sp_ud->c->ctx = NULL; + ac->err = REDIS_ERR_IO; + errno = ETIMEDOUT; + /* + * This will call all callbacks pending so the entire context + * will be destructed + */ + rspamd_redis_pool_release_connection(sp_ud->c->pool, ac, + RSPAMD_REDIS_RELEASE_FATAL); + } + + REDIS_RELEASE(ctx); +} + + +static void +lua_redis_parse_args(lua_State *L, gint idx, const gchar *cmd, + gchar ***pargs, gsize **parglens, guint *nargs) +{ + gchar **args = NULL; + gsize *arglens; + gint top; + + if (idx != 0 && lua_type(L, idx) == LUA_TTABLE) { + /* Get all arguments */ + lua_pushvalue(L, idx); + lua_pushnil(L); + top = 0; + + while (lua_next(L, -2) != 0) { + gint type = lua_type(L, -1); + + if (type == LUA_TNUMBER || type == LUA_TSTRING || + type == LUA_TUSERDATA) { + top++; + } + lua_pop(L, 1); + } + + args = g_malloc((top + 1) * sizeof(gchar *)); + arglens = g_malloc((top + 1) * sizeof(gsize)); + arglens[0] = strlen(cmd); + args[0] = g_malloc(arglens[0]); + memcpy(args[0], cmd, arglens[0]); + top = 1; + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + gint type = lua_type(L, -1); + + if (type == LUA_TSTRING) { + const gchar *s; + + s = lua_tolstring(L, -1, &arglens[top]); + args[top] = g_malloc(arglens[top]); + memcpy(args[top], s, arglens[top]); + top++; + } + else if (type == LUA_TUSERDATA) { + struct rspamd_lua_text *t; + + t = lua_check_text(L, -1); + + if (t && t->start) { + arglens[top] = t->len; + args[top] = g_malloc(arglens[top]); + memcpy(args[top], t->start, arglens[top]); + top++; + } + } + else if (type == LUA_TNUMBER) { + gdouble val = lua_tonumber(L, -1); + gint r; + gchar numbuf[64]; + + if (val == (gdouble) ((gint64) val)) { + r = rspamd_snprintf(numbuf, sizeof(numbuf), "%L", + (gint64) val); + } + else { + r = rspamd_snprintf(numbuf, sizeof(numbuf), "%f", + val); + } + + arglens[top] = r; + args[top] = g_malloc(arglens[top]); + memcpy(args[top], numbuf, arglens[top]); + top++; + } + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + else { + /* Use merely cmd */ + + args = g_malloc(sizeof(gchar *)); + arglens = g_malloc(sizeof(gsize)); + arglens[0] = strlen(cmd); + args[0] = g_malloc(arglens[0]); + memcpy(args[0], cmd, arglens[0]); + top = 1; + } + + *pargs = args; + *parglens = arglens; + *nargs = top; +} + +static struct lua_redis_ctx * +rspamd_lua_redis_prepare_connection(lua_State *L, gint *pcbref, gboolean is_async) +{ + struct lua_redis_ctx *ctx = NULL; + rspamd_inet_addr_t *ip = NULL; + struct lua_redis_userdata *ud = NULL; + struct rspamd_lua_ip *addr = NULL; + struct rspamd_task *task = NULL; + const gchar *host = NULL; + const gchar *username = NULL, *password = NULL, *dbname = NULL, *log_tag = NULL; + gint cbref = -1; + struct rspamd_config *cfg = NULL; + struct rspamd_async_session *session = NULL; + struct ev_loop *ev_base = NULL; + gboolean ret = FALSE; + guint flags = 0; + + if (lua_istable(L, 1)) { + /* Table version */ + lua_pushvalue(L, 1); + lua_pushstring(L, "task"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + task = lua_check_task_maybe(L, -1); + } + lua_pop(L, 1); + + if (!task) { + /* We need to get ev_base, config and session separately */ + lua_pushstring(L, "config"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + cfg = lua_check_config(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "session"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + session = lua_check_session(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "ev_base"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + ev_base = lua_check_ev_base(L, -1); + } + lua_pop(L, 1); + + if (cfg && ev_base) { + ret = TRUE; + } + else if (!cfg) { + msg_err_task_check("config is not passed"); + } + else { + msg_err_task_check("ev_base is not set"); + } + } + else { + cfg = task->cfg; + session = task->s; + ev_base = task->event_loop; + log_tag = task->task_pool->tag.uid; + ret = TRUE; + } + + if (pcbref) { + lua_pushstring(L, "callback"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TFUNCTION) { + /* This also pops function from the stack */ + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + *pcbref = cbref; + } + else { + *pcbref = -1; + lua_pop(L, 1); + } + } + + lua_pushstring(L, "host"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TUSERDATA) { + addr = lua_check_ip(L, -1); + host = rspamd_inet_address_to_string_pretty(addr->addr); + } + else if (lua_type(L, -1) == LUA_TSTRING) { + host = lua_tostring(L, -1); + + if (rspamd_parse_inet_address(&ip, + host, strlen(host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + addr = g_alloca(sizeof(*addr)); + addr->addr = ip; + + if (rspamd_inet_address_get_port(ip) == 0) { + rspamd_inet_address_set_port(ip, 6379); + } + } + } + lua_pop(L, 1); + + lua_pushstring(L, "username"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + username = lua_tostring(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "password"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + password = lua_tostring(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "dbname"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + dbname = lua_tostring(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "opaque_data"); + lua_gettable(L, -2); + if (!!lua_toboolean(L, -1)) { + flags |= LUA_REDIS_TEXTDATA; + } + lua_pop(L, 1); + + lua_pushstring(L, "no_pool"); + lua_gettable(L, -2); + if (!!lua_toboolean(L, -1)) { + flags |= LUA_REDIS_NO_POOL; + } + lua_pop(L, 1); + + lua_pop(L, 1); /* table */ + + if (session && rspamd_session_blocked(session)) { + msg_err_task_check("Session is being destroying"); + ret = FALSE; + } + + if (ret && addr != NULL) { + ctx = g_malloc0(sizeof(struct lua_redis_ctx)); + REF_INIT_RETAIN(ctx, lua_redis_dtor); + if (is_async) { + ctx->flags |= flags | LUA_REDIS_ASYNC; + ud = &ctx->async; + } + else { + ud = &ctx->async; + ctx->replies = g_queue_new(); + ctx->events_cleanup = g_queue_new(); + } + + ud->s = session; + ud->cfg = cfg; + ud->pool = cfg->redis_pool; + ud->event_loop = ev_base; + ud->task = task; + + if (log_tag) { + rspamd_strlcpy(ud->log_tag, log_tag, sizeof(ud->log_tag)); + } + else { + /* Use pointer itself as a tag */ + rspamd_snprintf(ud->log_tag, sizeof(ud->log_tag), + "%ud", + (int) rspamd_cryptobox_fast_hash(&ud, sizeof(ud), 0)); + } + + if (task) { + ud->item = rspamd_symcache_get_cur_item(task); + } + + ret = TRUE; + } + else { + if (cbref != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref); + } + + msg_err_task_check("incorrect function invocation"); + ret = FALSE; + } + } + + if (ret) { + ud->terminated = 0; + ud->ctx = rspamd_redis_pool_connect(ud->pool, + dbname, username, password, + rspamd_inet_address_to_string(addr->addr), + rspamd_inet_address_get_port(addr->addr)); + + if (ip) { + rspamd_inet_address_free(ip); + } + + if (ud->ctx == NULL || ud->ctx->err) { + if (ud->ctx) { + msg_err_task_check("cannot connect to redis: %s", + ud->ctx->errstr); + rspamd_redis_pool_release_connection(ud->pool, ud->ctx, + RSPAMD_REDIS_RELEASE_FATAL); + ud->ctx = NULL; + } + else { + msg_err_task_check("cannot connect to redis (OS error): %s", + strerror(errno)); + } + + REDIS_RELEASE(ctx); + + return NULL; + } + + msg_debug_lua_redis("opened redis connection host=%s; ctx=%p; ud=%p", + host, ctx, ud); + + return ctx; + } + + if (ip) { + rspamd_inet_address_free(ip); + } + + return NULL; +} + +/*** + * @function rspamd_redis.make_request({params}) + * Make request to redis server, params is a table of key=value arguments in any order + * @param {task} task worker task object + * @param {ip|string} host server address + * @param {function} callback callback to be called in form `function (task, err, data)` + * @param {string} cmd command to be sent to redis + * @param {table} args numeric array of strings used as redis arguments + * @param {number} timeout timeout in seconds for request (1.0 by default) + * @return {boolean} `true` if a request has been scheduled + */ +static int +lua_redis_make_request(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_redis_request_specific_userdata *sp_ud; + struct lua_redis_userdata *ud; + struct lua_redis_ctx *ctx, **pctx; + const gchar *cmd = NULL; + gdouble timeout = REDIS_DEFAULT_TIMEOUT; + gint cbref = -1; + gboolean ret = FALSE; + + ctx = rspamd_lua_redis_prepare_connection(L, &cbref, TRUE); + + if (ctx) { + ud = &ctx->async; + sp_ud = g_malloc0(sizeof(*sp_ud)); + sp_ud->cbref = cbref; + sp_ud->c = ud; + sp_ud->ctx = ctx; + + lua_pushstring(L, "cmd"); + lua_gettable(L, -2); + cmd = lua_tostring(L, -1); + lua_pop(L, 1); + + lua_pushstring(L, "timeout"); + lua_gettable(L, 1); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + lua_pop(L, 1); + ud->timeout = timeout; + + + lua_pushstring(L, "args"); + lua_gettable(L, 1); + lua_redis_parse_args(L, -1, cmd, &sp_ud->args, &sp_ud->arglens, + &sp_ud->nargs); + lua_pop(L, 1); + LL_PREPEND(ud->specific, sp_ud); + + ret = redisAsyncCommandArgv(ud->ctx, + lua_redis_callback, + sp_ud, + sp_ud->nargs, + (const gchar **) sp_ud->args, + sp_ud->arglens); + + if (ret == REDIS_OK) { + if (ud->s) { + rspamd_session_add_event(ud->s, + lua_redis_fin, sp_ud, + M); + + if (ud->item) { + rspamd_symcache_item_async_inc(ud->task, ud->item, M); + } + } + + REDIS_RETAIN(ctx); /* Cleared by fin event */ + ctx->cmds_pending++; + + if (ud->ctx->c.flags & REDIS_SUBSCRIBED) { + msg_debug_lua_redis("subscribe command, never unref/timeout"); + sp_ud->flags |= LUA_REDIS_SUBSCRIBED; + } + + sp_ud->timeout_ev.data = sp_ud; + ev_now_update_if_cheap((struct ev_loop *) ud->event_loop); + ev_timer_init(&sp_ud->timeout_ev, lua_redis_timeout, timeout, 0.0); + ev_timer_start(ud->event_loop, &sp_ud->timeout_ev); + + ret = TRUE; + } + else { + msg_info("call to redis failed: %s", ud->ctx->errstr); + rspamd_redis_pool_release_connection(ud->pool, ud->ctx, + RSPAMD_REDIS_RELEASE_FATAL); + ud->ctx = NULL; + REDIS_RELEASE(ctx); + ret = FALSE; + } + } + else { + lua_pushboolean(L, FALSE); + lua_pushnil(L); + + return 2; + } + + lua_pushboolean(L, ret); + + if (ret) { + pctx = lua_newuserdata(L, sizeof(ctx)); + *pctx = ctx; + rspamd_lua_setclass(L, "rspamd{redis}", -1); + } + else { + lua_pushnil(L); + } + + return 2; +} + +/*** + * @function rspamd_redis.make_request_sync({params}) + * Make blocking request to redis server, params is a table of key=value arguments in any order + * @param {ip|string} host server address + * @param {string} cmd command to be sent to redis + * @param {table} args numeric array of strings used as redis arguments + * @param {number} timeout timeout in seconds for request (1.0 by default) + * @return {boolean + result} `true` and a result if a request has been successful + */ +static int +lua_redis_make_request_sync(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_ip *addr = NULL; + rspamd_inet_addr_t *ip = NULL; + const gchar *cmd = NULL, *host; + struct timeval tv; + gboolean ret = FALSE; + gdouble timeout = REDIS_DEFAULT_TIMEOUT; + gchar **args = NULL; + gsize *arglens = NULL; + guint nargs = 0, flags = 0; + redisContext *ctx; + redisReply *r; + + if (lua_istable(L, 1)) { + lua_pushvalue(L, 1); + + lua_pushstring(L, "cmd"); + lua_gettable(L, -2); + cmd = lua_tostring(L, -1); + lua_pop(L, 1); + + lua_pushstring(L, "host"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + addr = lua_check_ip(L, -1); + } + else if (lua_type(L, -1) == LUA_TSTRING) { + host = lua_tostring(L, -1); + if (rspamd_parse_inet_address(&ip, + host, strlen(host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + addr = g_alloca(sizeof(*addr)); + addr->addr = ip; + + if (rspamd_inet_address_get_port(ip) == 0) { + rspamd_inet_address_set_port(ip, 6379); + } + } + } + lua_pop(L, 1); + + lua_pushstring(L, "timeout"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "opaque_data"); + lua_gettable(L, -2); + if (!!lua_toboolean(L, -1)) { + flags |= LUA_REDIS_TEXTDATA; + } + lua_pop(L, 1); + + + if (cmd) { + lua_pushstring(L, "args"); + lua_gettable(L, -2); + lua_redis_parse_args(L, -1, cmd, &args, &arglens, &nargs); + lua_pop(L, 1); + } + + lua_pop(L, 1); + + if (addr && cmd) { + ret = TRUE; + } + } + + if (ret) { + double_to_tv(timeout, &tv); + + if (rspamd_inet_address_get_af(addr->addr) == AF_UNIX) { + ctx = redisConnectUnixWithTimeout( + rspamd_inet_address_to_string(addr->addr), tv); + } + else { + ctx = redisConnectWithTimeout( + rspamd_inet_address_to_string(addr->addr), + rspamd_inet_address_get_port(addr->addr), tv); + } + + if (ip) { + rspamd_inet_address_free(ip); + } + + if (ctx == NULL || ctx->err) { + redisFree(ctx); + lua_redis_free_args(args, arglens, nargs); + lua_pushboolean(L, FALSE); + + return 1; + } + + r = redisCommandArgv(ctx, + nargs, + (const gchar **) args, + arglens); + + if (r != NULL) { + if (r->type != REDIS_REPLY_ERROR) { + lua_pushboolean(L, TRUE); + lua_redis_push_reply(L, r, flags & LUA_REDIS_TEXTDATA); + } + else { + lua_pushboolean(L, FALSE); + lua_pushstring(L, r->str); + } + + freeReplyObject(r); + redisFree(ctx); + lua_redis_free_args(args, arglens, nargs); + + return 2; + } + else { + msg_info("call to redis failed: %s", ctx->errstr); + redisFree(ctx); + lua_redis_free_args(args, arglens, nargs); + lua_pushboolean(L, FALSE); + } + } + else { + if (ip) { + rspamd_inet_address_free(ip); + } + msg_err("bad arguments for redis request"); + lua_redis_free_args(args, arglens, nargs); + + lua_pushboolean(L, FALSE); + } + + return 1; +} + +/*** + * @function rspamd_redis.connect({params}) + * Make request to redis server, params is a table of key=value arguments in any order + * @param {task} task worker task object + * @param {ip|string} host server address + * @param {number} timeout timeout in seconds for request (1.0 by default) + * @return {boolean,redis} new connection object or nil if connection failed + */ +static int +lua_redis_connect(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_redis_userdata *ud; + struct lua_redis_ctx *ctx, **pctx; + gdouble timeout = REDIS_DEFAULT_TIMEOUT; + + ctx = rspamd_lua_redis_prepare_connection(L, NULL, TRUE); + + if (ctx) { + ud = &ctx->async; + + lua_pushstring(L, "timeout"); + lua_gettable(L, 1); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + + lua_pop(L, 1); + ud->timeout = timeout; + } + else { + lua_pushboolean(L, FALSE); + lua_pushnil(L); + + return 2; + } + + lua_pushboolean(L, TRUE); + pctx = lua_newuserdata(L, sizeof(ctx)); + *pctx = ctx; + rspamd_lua_setclass(L, "rspamd{redis}", -1); + + return 2; +} + +/*** + * @function rspamd_redis.connect_sync({params}) + * Make blocking request to redis server, params is a table of key=value arguments in any order + * @param {ip|string} host server address + * @param {number} timeout timeout in seconds for request (1.0 by default) + * @return {redis} redis object if a request has been successful + */ +static int +lua_redis_connect_sync(lua_State *L) +{ + LUA_TRACE_POINT; + gdouble timeout = REDIS_DEFAULT_TIMEOUT; + struct lua_redis_ctx *ctx, **pctx; + + ctx = rspamd_lua_redis_prepare_connection(L, NULL, FALSE); + + if (ctx) { + if (lua_istable(L, 1)) { + lua_pushstring(L, "timeout"); + lua_gettable(L, 1); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + lua_pop(L, 1); + } + + ctx->async.timeout = timeout; + + lua_pushboolean(L, TRUE); + pctx = lua_newuserdata(L, sizeof(ctx)); + *pctx = ctx; + rspamd_lua_setclass(L, "rspamd{redis}", -1); + } + else { + lua_pushboolean(L, FALSE); + lua_pushstring(L, "bad arguments for redis request"); + return 2; + } + + return 2; +} + +/*** + * @method rspamd_redis:add_cmd(cmd, {args}) + * Append new cmd to redis pipeline + * @param {string} cmd command to be sent to redis + * @param {table} args array of strings used as redis arguments + * @return {boolean} `true` if a request has been successful + */ +static int +lua_redis_add_cmd(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_redis_ctx *ctx = lua_check_redis(L, 1); + struct lua_redis_request_specific_userdata *sp_ud; + struct lua_redis_userdata *ud; + const gchar *cmd = NULL; + gint args_pos = 2; + gint cbref = -1, ret; + + if (ctx) { + if (ctx->flags & LUA_REDIS_TERMINATED) { + lua_pushboolean(L, FALSE); + lua_pushstring(L, "Connection is terminated"); + + return 2; + } + + /* Async version */ + if (lua_type(L, 2) == LUA_TSTRING) { + /* No callback version */ + cmd = lua_tostring(L, 2); + args_pos = 3; + } + else if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + cmd = lua_tostring(L, 3); + args_pos = 4; + } + else { + return luaL_error(L, "invalid arguments"); + } + + sp_ud = g_malloc0(sizeof(*sp_ud)); + if (IS_ASYNC(ctx)) { + sp_ud->c = &ctx->async; + ud = &ctx->async; + sp_ud->cbref = cbref; + } + else { + sp_ud->c = &ctx->async; + ud = &ctx->async; + } + sp_ud->ctx = ctx; + + lua_redis_parse_args(L, args_pos, cmd, &sp_ud->args, + &sp_ud->arglens, &sp_ud->nargs); + + LL_PREPEND(sp_ud->c->specific, sp_ud); + + if (ud->s && rspamd_session_blocked(ud->s)) { + lua_pushboolean(L, 0); + lua_pushstring(L, "session is terminating"); + + return 2; + } + + if (IS_ASYNC(ctx)) { + ret = redisAsyncCommandArgv(sp_ud->c->ctx, + lua_redis_callback, + sp_ud, + sp_ud->nargs, + (const gchar **) sp_ud->args, + sp_ud->arglens); + } + else { + ret = redisAsyncCommandArgv(sp_ud->c->ctx, + lua_redis_callback_sync, + sp_ud, + sp_ud->nargs, + (const gchar **) sp_ud->args, + sp_ud->arglens); + } + + if (ret == REDIS_OK) { + if (ud->s) { + rspamd_session_add_event(ud->s, + lua_redis_fin, + sp_ud, + M); + + if (ud->item) { + rspamd_symcache_item_async_inc(ud->task, ud->item, M); + } + } + + sp_ud->timeout_ev.data = sp_ud; + + if (IS_ASYNC(ctx)) { + ev_timer_init(&sp_ud->timeout_ev, lua_redis_timeout, + sp_ud->c->timeout, 0.0); + } + else { + ev_timer_init(&sp_ud->timeout_ev, lua_redis_timeout_sync, + sp_ud->c->timeout, 0.0); + } + + ev_timer_start(ud->event_loop, &sp_ud->timeout_ev); + REDIS_RETAIN(ctx); + ctx->cmds_pending++; + } + else { + msg_info("call to redis failed: %s", + sp_ud->c->ctx->errstr); + lua_pushboolean(L, 0); + lua_pushstring(L, sp_ud->c->ctx->errstr); + + return 2; + } + } + + lua_pushboolean(L, true); + + return 1; +} + +/*** + * @method rspamd_redis:exec() + * Executes pending commands (suitable for blocking IO only for now) + * @return {boolean}, {table}, ...: pairs in format [bool, result] for each request pending + */ +static int +lua_redis_exec(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_redis_ctx *ctx = lua_check_redis(L, 1); + + if (ctx == NULL) { + lua_error(L); + + return 1; + } + + if (IS_ASYNC(ctx)) { + lua_pushstring(L, "Async redis pipelining is not implemented"); + lua_error(L); + return 0; + } + else { + if (ctx->cmds_pending == 0 && g_queue_get_length(ctx->replies) == 0) { + lua_pushstring(L, "No pending commands to execute"); + lua_error(L); + } + if (ctx->cmds_pending == 0 && g_queue_get_length(ctx->replies) > 0) { + gint results = lua_redis_push_results(ctx, L); + return results; + } + else { + ctx->thread = lua_thread_pool_get_running_entry(ctx->async.cfg->lua_thread_pool); + return lua_thread_yield(ctx->thread, 0); + } + } +} + +static gint +lua_load_redis(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, redislib_f); + + return 1; +} + +static gint +lua_redis_null_idx(lua_State *L) +{ + lua_pushnil(L); + + return 1; +} + +static void +lua_redis_null_mt(lua_State *L) +{ + luaL_newmetatable(L, "redis{null}"); + + lua_pushcfunction(L, lua_redis_null_idx); + lua_setfield(L, -2, "__index"); + lua_pushcfunction(L, lua_redis_null_idx); + lua_setfield(L, -2, "__tostring"); + + lua_pop(L, 1); +} + +/** + * Open redis library + * @param L lua stack + * @return + */ +void luaopen_redis(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{redis}", redislib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_redis", lua_load_redis); + + /* Set null element */ + lua_redis_null_mt(L); + redis_null = lua_newuserdata(L, 0); + luaL_getmetatable(L, "redis{null}"); + lua_setmetatable(L, -2); + lua_setfield(L, LUA_REGISTRYINDEX, "redis.null"); +} diff --git a/src/lua/lua_regexp.c b/src/lua/lua_regexp.c new file mode 100644 index 0000000..7e638ca --- /dev/null +++ b/src/lua/lua_regexp.c @@ -0,0 +1,858 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" + +/*** + * @module rspamd_regexp + * Rspamd regexp is an utility module that handles rspamd perl compatible + * regular expressions + * @example + * local rspamd_regexp = require "rspamd_regexp" + * + * local re = rspamd_regexp.create_cached('/^\\s*some_string\\s*$/i') + * re:match('some_string') + * local re = rspamd_regexp.create_cached('/\\s+/i') + * re:split('word word word') -- returns ['word', 'word', 'word'] + */ + +LUA_FUNCTION_DEF(regexp, create); +LUA_FUNCTION_DEF(regexp, import_glob); +LUA_FUNCTION_DEF(regexp, import_plain); +LUA_FUNCTION_DEF(regexp, create_cached); +LUA_FUNCTION_DEF(regexp, get_cached); +LUA_FUNCTION_DEF(regexp, get_pattern); +LUA_FUNCTION_DEF(regexp, set_limit); +LUA_FUNCTION_DEF(regexp, set_max_hits); +LUA_FUNCTION_DEF(regexp, get_max_hits); +LUA_FUNCTION_DEF(regexp, search); +LUA_FUNCTION_DEF(regexp, match); +LUA_FUNCTION_DEF(regexp, matchn); +LUA_FUNCTION_DEF(regexp, split); +LUA_FUNCTION_DEF(regexp, destroy); +LUA_FUNCTION_DEF(regexp, gc); + +static const struct luaL_reg regexplib_m[] = { + LUA_INTERFACE_DEF(regexp, get_pattern), + LUA_INTERFACE_DEF(regexp, set_limit), + LUA_INTERFACE_DEF(regexp, set_max_hits), + LUA_INTERFACE_DEF(regexp, get_max_hits), + LUA_INTERFACE_DEF(regexp, match), + LUA_INTERFACE_DEF(regexp, matchn), + LUA_INTERFACE_DEF(regexp, search), + LUA_INTERFACE_DEF(regexp, split), + LUA_INTERFACE_DEF(regexp, destroy), + {"__tostring", lua_regexp_get_pattern}, + {"__gc", lua_regexp_gc}, + {NULL, NULL}}; +static const struct luaL_reg regexplib_f[] = { + LUA_INTERFACE_DEF(regexp, create), + LUA_INTERFACE_DEF(regexp, import_glob), + LUA_INTERFACE_DEF(regexp, import_plain), + LUA_INTERFACE_DEF(regexp, get_cached), + LUA_INTERFACE_DEF(regexp, create_cached), + {NULL, NULL}}; + +#define LUA_REGEXP_FLAG_DESTROYED (1 << 0) +#define IS_DESTROYED(re) ((re)->re_flags & LUA_REGEXP_FLAG_DESTROYED) + +rspamd_mempool_t *regexp_static_pool = NULL; + +struct rspamd_lua_regexp * +lua_check_regexp(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{regexp}"); + + luaL_argcheck(L, ud != NULL, pos, "'regexp' expected"); + return ud ? *((struct rspamd_lua_regexp **) ud) : NULL; +} + +/*** + * @function rspamd_regexp.create(pattern[, flags]) + * Creates new rspamd_regexp + * @param {string} pattern pattern to build regexp. If this pattern is enclosed in `//` then it is possible to specify flags after it + * @param {string} flags optional flags to create regular expression + * @return {regexp} regexp argument that is *not* automatically destroyed + * @example + * local regexp = require "rspamd_regexp" + * + * local re = regexp.create('/^test.*[0-9]\\s*$/i') + */ +static int +lua_regexp_create(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_regexp_t *re; + struct rspamd_lua_regexp *new, **pnew; + const gchar *string, *flags_str = NULL; + GError *err = NULL; + + string = luaL_checkstring(L, 1); + if (lua_gettop(L) == 2) { + flags_str = luaL_checkstring(L, 2); + } + + if (string) { + re = rspamd_regexp_new(string, flags_str, &err); + if (re == NULL) { + lua_pushnil(L); + msg_info("cannot parse regexp: %s, error: %s", + string, + err == NULL ? "undefined" : err->message); + g_error_free(err); + } + else { + new = g_malloc0(sizeof(struct rspamd_lua_regexp)); + new->re = re; + new->re_pattern = g_strdup(string); + new->module = rspamd_lua_get_module_name(L); + pnew = lua_newuserdata(L, sizeof(struct rspamd_lua_regexp *)); + rspamd_lua_setclass(L, "rspamd{regexp}", -1); + *pnew = new; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_regexp.import_glob(glob_pattern[, flags]) + * Creates new rspamd_regexp from glob + * @param {string} pattern pattern to build regexp. + * @param {string} flags optional flags to create regular expression + * @return {regexp} regexp argument that is *not* automatically destroyed + * @example + * local regexp = require "rspamd_regexp" + * + * local re = regexp.import_glob('ab*', 'i') + */ +static int +lua_regexp_import_glob(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_regexp_t *re; + struct rspamd_lua_regexp *new, **pnew; + const gchar *string, *flags_str = NULL; + gchar *escaped; + gsize pat_len; + GError *err = NULL; + + string = luaL_checklstring(L, 1, &pat_len); + + if (lua_gettop(L) == 2) { + flags_str = luaL_checkstring(L, 2); + } + + if (string) { + escaped = rspamd_str_regexp_escape(string, pat_len, NULL, + RSPAMD_REGEXP_ESCAPE_GLOB | RSPAMD_REGEXP_ESCAPE_UTF); + + re = rspamd_regexp_new(escaped, flags_str, &err); + + if (re == NULL) { + lua_pushnil(L); + msg_info("cannot parse regexp: %s, error: %s", + string, + err == NULL ? "undefined" : err->message); + g_error_free(err); + g_free(escaped); + } + else { + new = g_malloc0(sizeof(struct rspamd_lua_regexp)); + new->re = re; + new->re_pattern = escaped; + new->module = rspamd_lua_get_module_name(L); + pnew = lua_newuserdata(L, sizeof(struct rspamd_lua_regexp *)); + rspamd_lua_setclass(L, "rspamd{regexp}", -1); + *pnew = new; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_regexp.import_plain(plain_string[, flags]) + * Creates new rspamd_regexp from plain string (escaping specials) + * @param {string} pattern pattern to build regexp. + * @param {string} flags optional flags to create regular expression + * @return {regexp} regexp argument that is *not* automatically destroyed + * @example + * local regexp = require "rspamd_regexp" + * + * local re = regexp.import_plain('exact_string_with*', 'i') + */ +static int +lua_regexp_import_plain(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_regexp_t *re; + struct rspamd_lua_regexp *new, **pnew; + const gchar *string, *flags_str = NULL; + gchar *escaped; + gsize pat_len; + GError *err = NULL; + + string = luaL_checklstring(L, 1, &pat_len); + + if (lua_gettop(L) == 2) { + flags_str = luaL_checkstring(L, 2); + } + + if (string) { + escaped = rspamd_str_regexp_escape(string, pat_len, NULL, + RSPAMD_REGEXP_ESCAPE_ASCII); + + re = rspamd_regexp_new(escaped, flags_str, &err); + + if (re == NULL) { + lua_pushnil(L); + msg_info("cannot parse regexp: %s, error: %s", + string, + err == NULL ? "undefined" : err->message); + g_error_free(err); + g_free(escaped); + } + else { + new = g_malloc0(sizeof(struct rspamd_lua_regexp)); + new->re = re; + new->re_pattern = escaped; + new->module = rspamd_lua_get_module_name(L); + pnew = lua_newuserdata(L, sizeof(struct rspamd_lua_regexp *)); + rspamd_lua_setclass(L, "rspamd{regexp}", -1); + *pnew = new; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_regexp.get_cached(pattern) + * This function gets cached and pre-compiled regexp created by either `create` + * or `create_cached` methods. If no cached regexp is found then `nil` is returned. + * + * @param {string} pattern regexp pattern + * @return {regexp} cached regexp structure or `nil` + */ +static int +lua_regexp_get_cached(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_regexp_t *re; + struct rspamd_lua_regexp *new, **pnew; + const gchar *string, *flags_str = NULL; + + string = luaL_checkstring(L, 1); + if (lua_gettop(L) == 2) { + flags_str = luaL_checkstring(L, 2); + } + + if (string) { + re = rspamd_regexp_cache_query(NULL, string, flags_str); + + if (re) { + new = g_malloc0(sizeof(struct rspamd_lua_regexp)); + new->re = rspamd_regexp_ref(re); + new->re_pattern = g_strdup(string); + new->module = rspamd_lua_get_module_name(L); + pnew = lua_newuserdata(L, sizeof(struct rspamd_lua_regexp *)); + rspamd_lua_setclass(L, "rspamd{regexp}", -1); + *pnew = new; + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_regexp.create_cached(pattern[, flags]) + * This function is similar to `create` but it tries to search for regexp in the + * cache first. + * @param {string} pattern pattern to build regexp. If this pattern is enclosed in `//` then it is possible to specify flags after it + * @param {string} flags optional flags to create regular expression + * @return {regexp} regexp argument that is *not* automatically destroyed + * @example + * local regexp = require "rspamd_regexp" + * + * local re = regexp.create_cached('/^test.*[0-9]\\s*$/i') + * ... + * -- This doesn't create new regexp object + * local other_re = regexp.create_cached('/^test.*[0-9]\\s*$/i') + */ +static int +lua_regexp_create_cached(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_regexp_t *re; + struct rspamd_lua_regexp *new, **pnew; + const gchar *string, *flags_str = NULL; + GError *err = NULL; + + string = luaL_checkstring(L, 1); + if (lua_gettop(L) == 2) { + flags_str = luaL_checkstring(L, 2); + } + + if (string) { + re = rspamd_regexp_cache_query(NULL, string, flags_str); + + if (re) { + new = g_malloc0(sizeof(struct rspamd_lua_regexp)); + new->re = rspamd_regexp_ref(re); + new->re_pattern = g_strdup(string); + new->module = rspamd_lua_get_module_name(L); + pnew = lua_newuserdata(L, sizeof(struct rspamd_lua_regexp *)); + + rspamd_lua_setclass(L, "rspamd{regexp}", -1); + *pnew = new; + } + else { + re = rspamd_regexp_cache_create(NULL, string, flags_str, &err); + if (re == NULL) { + lua_pushnil(L); + msg_info("cannot parse regexp: %s, error: %s", + string, + err == NULL ? "undefined" : err->message); + g_error_free(err); + } + else { + new = g_malloc0(sizeof(struct rspamd_lua_regexp)); + new->re = rspamd_regexp_ref(re); + new->re_pattern = g_strdup(string); + new->module = rspamd_lua_get_module_name(L); + pnew = lua_newuserdata(L, sizeof(struct rspamd_lua_regexp *)); + rspamd_lua_setclass(L, "rspamd{regexp}", -1); + *pnew = new; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method re:get_pattern() + * Get a pattern for specified regexp object + * @return {string} pattern line + */ +static int +lua_regexp_get_pattern(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + + if (re && re->re && !IS_DESTROYED(re)) { + lua_pushstring(L, rspamd_regexp_get_pattern(re->re)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method re:set_limit(lim) + * Set maximum size of text length to be matched with this regexp (if `lim` is + * less or equal to zero then all texts are checked) + * @param {number} lim limit in bytes + */ +static int +lua_regexp_set_limit(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + gint64 lim; + + lim = lua_tointeger(L, 2); + + if (re && re->re && !IS_DESTROYED(re)) { + if (lim > 0) { + rspamd_regexp_set_match_limit(re->re, lim); + } + else { + rspamd_regexp_set_match_limit(re->re, 0); + } + } + + return 0; +} + +/*** + * @method re:set_max_hits(lim) + * Set maximum number of hits returned by a regexp + * @param {number} lim limit in hits count + * @return {number} old number of max hits + */ +static int +lua_regexp_set_max_hits(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + guint lim; + + lim = luaL_checkinteger(L, 2); + + if (re && re->re && !IS_DESTROYED(re)) { + lua_pushinteger(L, rspamd_regexp_set_maxhits(re->re, lim)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method re:get_max_hits(lim) + * Get maximum number of hits returned by a regexp + * @return {number} number of max hits + */ +static int +lua_regexp_get_max_hits(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + + if (re && re->re && !IS_DESTROYED(re)) { + lua_pushinteger(L, rspamd_regexp_get_maxhits(re->re)); + } + else { + lua_pushinteger(L, 1); + } + + return 1; +} + +/*** + * @method re:search(line[, raw[, capture]]) + * Search line in regular expression object. If line matches then this + * function returns the table of captured strings. Otherwise, nil is returned. + * If `raw` is specified, then input is treated as raw data not encoded in `utf-8`. + * If `capture` is true, then this function saves all captures to the table of + * values, so the first element is the whole matched string and the + * subsequent elements are ordered captures defined within pattern. + * + * @param {string} line match the specified line against regexp object + * @param {bool} match raw regexp instead of utf8 one + * @param {bool} capture perform subpatterns capturing + * @return {table or nil} table of strings or tables (if `capture` is true) or nil if not matched + * @example + * local re = regexp.create_cached('/^\s*([0-9]+)\s*$/') + * -- returns nil + * local m1 = re:search('blah') + * local m2 = re:search(' 190 ') + * -- prints ' 190 ' + * print(m2[1]) + * + * local m3 = re:search(' 100500 ') + * -- prints ' 100500 ' + * print(m3[1][1]) + * -- prints '100500' capture + * print(m3[1][2]) + */ +static int +lua_regexp_search(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + const gchar *data = NULL; + struct rspamd_lua_text *t; + const gchar *start = NULL, *end = NULL; + gint i; + gsize len = 0, capn; + gboolean matched = FALSE, capture = FALSE, raw = FALSE; + GArray *captures = NULL; + struct rspamd_re_capture *cap; + + if (re && !IS_DESTROYED(re)) { + if (lua_type(L, 2) == LUA_TSTRING) { + data = luaL_checklstring(L, 2, &len); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + if (t != NULL) { + data = t->start; + len = t->len; + } + } + + if (lua_gettop(L) >= 3) { + raw = lua_toboolean(L, 3); + } + + if (data && len > 0) { + if (lua_gettop(L) >= 4 && lua_toboolean(L, 4)) { + capture = TRUE; + captures = g_array_new(FALSE, TRUE, + sizeof(struct rspamd_re_capture)); + } + + lua_newtable(L); + i = 0; + + while (rspamd_regexp_search(re->re, data, len, &start, &end, raw, + captures)) { + + if (capture) { + lua_createtable(L, captures->len, 0); + + for (capn = 0; capn < captures->len; capn++) { + cap = &g_array_index(captures, struct rspamd_re_capture, + capn); + lua_pushlstring(L, cap->p, cap->len); + lua_rawseti(L, -2, capn + 1); + } + + lua_rawseti(L, -2, ++i); + } + else { + lua_pushlstring(L, start, end - start); + lua_rawseti(L, -2, ++i); + } + + matched = TRUE; + } + + if (!matched) { + lua_pop(L, 1); + lua_pushnil(L); + } + + if (capture) { + g_array_free(captures, TRUE); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method re:match(line[, raw_match]) + * Matches line against the regular expression and return true if line matches + * (partially or completely) + * + * @param {string} line match the specified line against regexp object + * @param {bool} match raw regexp instead of utf8 one + * @return {bool} true if `line` matches + */ +static int +lua_regexp_match(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + struct rspamd_lua_text *t; + const gchar *data = NULL; + gsize len = 0; + gboolean raw = FALSE; + + if (re && !IS_DESTROYED(re)) { + if (lua_type(L, 2) == LUA_TSTRING) { + data = luaL_checklstring(L, 2, &len); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + if (t != NULL) { + data = t->start; + len = t->len; + } + } + + if (lua_gettop(L) == 3) { + raw = lua_toboolean(L, 3); + } + + if (data && len > 0) { + if (rspamd_regexp_search(re->re, data, len, NULL, NULL, raw, NULL)) { + lua_pushboolean(L, TRUE); + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method re:matchn(line, max_matches, [, raw_match]) + * Matches line against the regular expression and return number of matches if line matches + * (partially or completely). This process stop when `max_matches` is reached. + * If `max_matches` is zero, then only a single match is counted which is equal to + * @see re:match If `max_matches` is negative, then all matches are considered. + * + * @param {string} line match the specified line against regexp object + * @param {number} max_matches maximum number of matches + * @param {bool} match raw regexp instead of utf8 one + * @return {number} number of matches found in the `line` argument + */ +static int +lua_regexp_matchn(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + struct rspamd_lua_text *t; + const gchar *data = NULL, *start = NULL, *end = NULL; + gint max_matches, matches; + gsize len = 0; + gboolean raw = FALSE; + + if (re && !IS_DESTROYED(re)) { + if (lua_type(L, 2) == LUA_TSTRING) { + data = luaL_checklstring(L, 2, &len); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + if (t != NULL) { + data = t->start; + len = t->len; + } + } + + max_matches = lua_tointeger(L, 3); + matches = 0; + + if (lua_gettop(L) == 4) { + raw = lua_toboolean(L, 4); + } + + if (data && len > 0) { + for (;;) { + if (rspamd_regexp_search(re->re, data, len, &start, &end, raw, + NULL)) { + matches++; + } + else { + break; + } + + if (max_matches >= 0 && matches >= max_matches) { + break; + } + } + } + + lua_pushinteger(L, matches); + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return 1; +} + +/*** + * @method re:split(line) + * Split line using the specified regular expression. + * Breaks the string on the pattern, and returns an array of the tokens. + * If the pattern contains capturing parentheses, then the text for each + * of the substrings will also be returned. If the pattern does not match + * anywhere in the string, then the whole string is returned as the first + * token. + * @param {string/text} line line to split + * @return {table} table of split line portions (if text was the input, then text is used for return parts) + */ +static int +lua_regexp_split(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *re = lua_check_regexp(L, 1); + const gchar *data = NULL; + struct rspamd_lua_text *t; + gboolean matched = FALSE, is_text = FALSE; + gsize len = 0; + const gchar *start = NULL, *end = NULL, *old_start; + gint i; + + if (re && !IS_DESTROYED(re)) { + if (lua_type(L, 2) == LUA_TSTRING) { + data = luaL_checklstring(L, 2, &len); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + + if (t == NULL) { + lua_error(L); + return 0; + } + + data = t->start; + len = t->len; + is_text = TRUE; + } + + if (data && len > 0) { + lua_newtable(L); + i = 0; + old_start = data; + + while (rspamd_regexp_search(re->re, data, len, &start, &end, FALSE, + NULL)) { + if (start - old_start > 0) { + if (!is_text) { + lua_pushlstring(L, old_start, start - old_start); + } + else { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = old_start; + t->len = start - old_start; + t->flags = 0; + } + + lua_rawseti(L, -2, ++i); + matched = TRUE; + } + else if (start == end) { + break; + } + old_start = end; + } + + if (len > 0 && (end == NULL || end < data + len)) { + if (end == NULL) { + end = data; + } + + if (!is_text) { + lua_pushlstring(L, end, (data + len) - end); + } + else { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = end; + t->len = (data + len) - end; + t->flags = 0; + } + + lua_rawseti(L, -2, ++i); + matched = TRUE; + } + + if (!matched) { + lua_pop(L, 1); + lua_pushnil(L); + } + return 1; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushnil(L); + return 1; +} + +/*** + * @method re:destroy() + * Destroy regexp from caches if needed (the pointer is removed by garbage collector) + */ +static gint +lua_regexp_destroy(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *to_del = lua_check_regexp(L, 1); + + if (to_del) { + rspamd_regexp_cache_remove(NULL, to_del->re); + rspamd_regexp_unref(to_del->re); + to_del->re = NULL; + to_del->re_flags |= LUA_REGEXP_FLAG_DESTROYED; + } + + return 0; +} + +static gint +lua_regexp_gc(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_regexp *to_del = lua_check_regexp(L, 1); + + if (to_del) { + if (!IS_DESTROYED(to_del)) { + rspamd_regexp_unref(to_del->re); + } + + g_free(to_del->re_pattern); + g_free(to_del->module); + g_free(to_del); + } + + return 0; +} + +static gint +lua_load_regexp(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, regexplib_f); + + return 1; +} + +void luaopen_regexp(lua_State *L) +{ + if (!regexp_static_pool) { + regexp_static_pool = rspamd_mempool_new(rspamd_mempool_suggest_size(), + "regexp_lua_pool", 0); + } + + rspamd_lua_new_class(L, "rspamd{regexp}", regexplib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_regexp", lua_load_regexp); +} + +RSPAMD_DESTRUCTOR(lua_re_static_pool_dtor) +{ + if (regexp_static_pool) { + rspamd_mempool_delete(regexp_static_pool); + } +}
\ No newline at end of file diff --git a/src/lua/lua_rsa.c b/src/lua/lua_rsa.c new file mode 100644 index 0000000..ae5acc8 --- /dev/null +++ b/src/lua/lua_rsa.c @@ -0,0 +1,867 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file lua_rsa.c + * This module exports routines to load rsa keys, check inline or external + * rsa signatures. It assumes sha256 based signatures. + */ + +#include "lua_common.h" +#include "unix-std.h" +#include <openssl/err.h> +#include <openssl/pem.h> +#include <openssl/sha.h> +#include <openssl/rsa.h> + +LUA_FUNCTION_DEF(rsa_pubkey, load); +LUA_FUNCTION_DEF(rsa_pubkey, create); +LUA_FUNCTION_DEF(rsa_pubkey, gc); +LUA_FUNCTION_DEF(rsa_pubkey, tostring); + +LUA_FUNCTION_DEF(rsa_privkey, load_file); +LUA_FUNCTION_DEF(rsa_privkey, load_pem); +LUA_FUNCTION_DEF(rsa_privkey, load_raw); +LUA_FUNCTION_DEF(rsa_privkey, load_base64); +LUA_FUNCTION_DEF(rsa_privkey, create); +LUA_FUNCTION_DEF(rsa_privkey, gc); +LUA_FUNCTION_DEF(rsa_privkey, save); + +LUA_FUNCTION_DEF(rsa_signature, create); +LUA_FUNCTION_DEF(rsa_signature, load); +LUA_FUNCTION_DEF(rsa_signature, save); +LUA_FUNCTION_DEF(rsa_signature, base64); +LUA_FUNCTION_DEF(rsa_signature, gc); + +LUA_FUNCTION_DEF(rsa, verify_memory); +LUA_FUNCTION_DEF(rsa, sign_memory); +LUA_FUNCTION_DEF(rsa, keypair); + +static const struct luaL_reg rsalib_f[] = { + LUA_INTERFACE_DEF(rsa, verify_memory), + LUA_INTERFACE_DEF(rsa, sign_memory), + LUA_INTERFACE_DEF(rsa, keypair), + {NULL, NULL}}; + +static const struct luaL_reg rsapubkeylib_f[] = { + LUA_INTERFACE_DEF(rsa_pubkey, load), + LUA_INTERFACE_DEF(rsa_pubkey, create), + {NULL, NULL}}; + +static const struct luaL_reg rsapubkeylib_m[] = { + {"__tostring", lua_rsa_pubkey_tostring}, + {"__gc", lua_rsa_pubkey_gc}, + {NULL, NULL}}; + +static const struct luaL_reg rsaprivkeylib_f[] = { + LUA_INTERFACE_DEF(rsa_privkey, load_file), + LUA_INTERFACE_DEF(rsa_privkey, load_pem), + LUA_INTERFACE_DEF(rsa_privkey, load_raw), + LUA_INTERFACE_DEF(rsa_privkey, load_base64), + LUA_INTERFACE_DEF(rsa_privkey, create), + {NULL, NULL}}; + +static const struct luaL_reg rsaprivkeylib_m[] = { + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_rsa_privkey_gc}, + LUA_INTERFACE_DEF(rsa_privkey, save), + {NULL, NULL}}; + +static const struct luaL_reg rsasignlib_f[] = { + LUA_INTERFACE_DEF(rsa_signature, load), + LUA_INTERFACE_DEF(rsa_signature, create), + {NULL, NULL}}; + +static const struct luaL_reg rsasignlib_m[] = { + LUA_INTERFACE_DEF(rsa_signature, save), + LUA_INTERFACE_DEF(rsa_signature, base64), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_rsa_signature_gc}, + {NULL, NULL}}; + +static RSA * +lua_check_rsa_pubkey(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{rsa_pubkey}"); + + luaL_argcheck(L, ud != NULL, 1, "'rsa_pubkey' expected"); + return ud ? *((RSA **) ud) : NULL; +} + +static RSA * +lua_check_rsa_privkey(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{rsa_privkey}"); + + luaL_argcheck(L, ud != NULL, 1, "'rsa_privkey' expected"); + return ud ? *((RSA **) ud) : NULL; +} + +static rspamd_fstring_t * +lua_check_rsa_sign(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{rsa_signature}"); + + luaL_argcheck(L, ud != NULL, 1, "'rsa_signature' expected"); + return ud ? *((rspamd_fstring_t **) ud) : NULL; +} + +static gint +lua_rsa_pubkey_load(lua_State *L) +{ + RSA *rsa = NULL, **prsa; + const gchar *filename; + FILE *f; + + filename = luaL_checkstring(L, 1); + if (filename != NULL) { + f = fopen(filename, "r"); + if (f == NULL) { + msg_err("cannot open pubkey from file: %s, %s", + filename, + strerror(errno)); + lua_pushnil(L); + } + else { + if (!PEM_read_RSA_PUBKEY(f, &rsa, NULL, NULL)) { + msg_err("cannot open pubkey from file: %s, %s", filename, + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + else { + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_pubkey}", -1); + *prsa = rsa; + } + fclose(f); + } + } + else { + lua_pushnil(L); + } + return 1; +} + +static gint +lua_rsa_privkey_save(lua_State *L) +{ + const gchar *filename; + const gchar *type = "pem"; + FILE *f; + int ret; + + RSA *rsa = lua_check_rsa_privkey(L, 1); + + filename = luaL_checkstring(L, 2); + if (lua_gettop(L) > 2) { + type = luaL_checkstring(L, 3); + } + + if (rsa != NULL && filename != NULL) { + if (strcmp(filename, "-") == 0) { + f = stdout; + } + else { + f = fopen(filename, "wb"); + } + if (f == NULL) { + msg_err("cannot save privkey to file: %s, %s", + filename, + strerror(errno)); + lua_pushboolean(L, FALSE); + } + else { + if (f != stdout) { + /* Set secure permissions for the private key file */ + chmod(filename, S_IRUSR | S_IWUSR); + } + + if (strcmp(type, "der") == 0) { + ret = i2d_RSAPrivateKey_fp(f, rsa); + } + else { + ret = PEM_write_RSAPrivateKey(f, rsa, NULL, NULL, 0, NULL, NULL); + } + + if (!ret) { + msg_err("cannot save privkey to file: %s, %s", filename, + ERR_error_string(ERR_get_error(), NULL)); + lua_pushboolean(L, FALSE); + } + else { + lua_pushboolean(L, TRUE); + } + + if (f != stdout) { + fclose(f); + } + else { + fflush(f); + } + } + } + else { + lua_pushboolean(L, FALSE); + } + + return 1; +} + + +static gint +lua_rsa_pubkey_create(lua_State *L) +{ + RSA *rsa = NULL, **prsa; + const gchar *buf; + BIO *bp; + + buf = luaL_checkstring(L, 1); + if (buf != NULL) { + bp = BIO_new_mem_buf((void *) buf, -1); + + if (!PEM_read_bio_RSA_PUBKEY(bp, &rsa, NULL, NULL)) { + msg_err("cannot parse pubkey: %s", + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + else { + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_pubkey}", -1); + *prsa = rsa; + } + BIO_free(bp); + } + else { + lua_pushnil(L); + } + return 1; +} + +static gint +lua_rsa_pubkey_gc(lua_State *L) +{ + RSA *rsa = lua_check_rsa_pubkey(L, 1); + + if (rsa != NULL) { + RSA_free(rsa); + } + + return 0; +} + +static gint +lua_rsa_pubkey_tostring(lua_State *L) +{ + RSA *rsa = lua_check_rsa_pubkey(L, 1); + + if (rsa != NULL) { + BIO *pubout = BIO_new(BIO_s_mem()); + const gchar *pubdata; + gsize publen; + int rc = i2d_RSA_PUBKEY_bio(pubout, rsa); + + if (rc != 1) { + BIO_free(pubout); + + return luaL_error(L, "i2d_RSA_PUBKEY_bio failed"); + } + + publen = BIO_get_mem_data(pubout, &pubdata); + lua_pushlstring(L, pubdata, publen); + BIO_free(pubout); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_rsa_privkey_load_file(lua_State *L) +{ + RSA *rsa = NULL, **prsa; + const gchar *filename; + FILE *f; + + filename = luaL_checkstring(L, 1); + if (filename != NULL) { + f = fopen(filename, "r"); + if (f == NULL) { + msg_err("cannot open private key from file: %s, %s", + filename, + strerror(errno)); + lua_pushnil(L); + } + else { + if (!PEM_read_RSAPrivateKey(f, &rsa, NULL, NULL)) { + msg_err("cannot open private key from file: %s, %s", filename, + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + else { + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_privkey}", -1); + *prsa = rsa; + } + fclose(f); + } + } + else { + lua_pushnil(L); + } + return 1; +} + +static gint +lua_rsa_privkey_load_pem(lua_State *L) +{ + RSA *rsa = NULL, **prsa; + BIO *b; + struct rspamd_lua_text *t; + const gchar *data; + gsize len; + + if (lua_isuserdata(L, 1)) { + t = lua_check_text(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 1, &len); + } + + if (data != NULL) { + b = BIO_new_mem_buf(data, len); + + if (!PEM_read_bio_RSAPrivateKey(b, &rsa, NULL, NULL)) { + msg_err("cannot open private key from data, %s", + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + else { + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_privkey}", -1); + *prsa = rsa; + } + + BIO_free(b); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_rsa_privkey_load_raw(lua_State *L) +{ + RSA *rsa = NULL, **prsa; + BIO *b; + struct rspamd_lua_text *t; + const gchar *data; + gsize len; + + if (lua_isuserdata(L, 1)) { + t = lua_check_text(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 1, &len); + } + + if (data != NULL) { + b = BIO_new_mem_buf(data, len); + rsa = d2i_RSAPrivateKey_bio(b, NULL); + + if (rsa == NULL) { + msg_err("cannot open private key from data, %s", + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + else { + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_privkey}", -1); + *prsa = rsa; + } + + BIO_free(b); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_rsa_privkey_load_base64(lua_State *L) +{ + RSA *rsa = NULL, **prsa; + BIO *b; + EVP_PKEY *evp = NULL; + struct rspamd_lua_text *t; + const gchar *data; + guchar *decoded; + gsize len, dec_len; + + if (lua_isuserdata(L, 1)) { + t = lua_check_text(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + len = t->len; + } + else { + data = luaL_checklstring(L, 1, &len); + } + + if (data != NULL) { + decoded = g_malloc(len); + + if (!rspamd_cryptobox_base64_decode(data, len, decoded, &dec_len)) { + g_free(decoded); + + return luaL_error(L, "invalid base64 encoding"); + } + + b = BIO_new_mem_buf(decoded, dec_len); + + if (d2i_PrivateKey_bio(b, &evp) != NULL) { + rsa = EVP_PKEY_get1_RSA(evp); + + if (rsa == NULL) { + msg_err("cannot open RSA private key from data, %s", + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + else { + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_privkey}", -1); + *prsa = rsa; + } + + EVP_PKEY_free(evp); + } + else { + msg_err("cannot open EVP private key from data, %s", + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + + BIO_free(b); + g_free(decoded); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_rsa_privkey_create(lua_State *L) +{ + RSA *rsa = NULL, **prsa; + const gchar *buf; + BIO *bp; + + buf = luaL_checkstring(L, 1); + if (buf != NULL) { + bp = BIO_new_mem_buf((void *) buf, -1); + + if (!PEM_read_bio_RSAPrivateKey(bp, &rsa, NULL, NULL)) { + msg_err("cannot parse private key: %s", + ERR_error_string(ERR_get_error(), NULL)); + lua_pushnil(L); + } + else { + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_privkey}", -1); + *prsa = rsa; + } + BIO_free(bp); + } + else { + lua_pushnil(L); + } + return 1; +} + +static gint +lua_rsa_privkey_gc(lua_State *L) +{ + RSA *rsa = lua_check_rsa_privkey(L, 1); + + if (rsa != NULL) { + RSA_free(rsa); + } + + return 0; +} + +static gint +lua_rsa_signature_load(lua_State *L) +{ + rspamd_fstring_t *sig, **psig; + const gchar *filename; + gpointer data; + int fd; + struct stat st; + + filename = luaL_checkstring(L, 1); + if (filename != NULL) { + fd = open(filename, O_RDONLY); + if (fd == -1) { + msg_err("cannot open signature file: %s, %s", filename, + strerror(errno)); + lua_pushnil(L); + } + else { + if (fstat(fd, &st) == -1 || + (data = + mmap(NULL, st.st_size, PROT_READ, MAP_SHARED, fd, 0)) == MAP_FAILED) { + msg_err("cannot mmap file %s: %s", filename, strerror(errno)); + lua_pushnil(L); + } + else { + sig = rspamd_fstring_new_init(data, st.st_size); + psig = lua_newuserdata(L, sizeof(rspamd_fstring_t *)); + rspamd_lua_setclass(L, "rspamd{rsa_signature}", -1); + *psig = sig; + munmap(data, st.st_size); + } + close(fd); + } + } + else { + lua_pushnil(L); + } + return 1; +} + +static gint +lua_rsa_signature_save(lua_State *L) +{ + rspamd_fstring_t *sig; + gint fd, flags; + const gchar *filename; + gboolean forced = FALSE, res = TRUE; + + sig = lua_check_rsa_sign(L, 1); + filename = luaL_checkstring(L, 2); + if (lua_gettop(L) > 2) { + forced = lua_toboolean(L, 3); + } + + if (sig != NULL && filename != NULL) { + flags = O_WRONLY | O_CREAT; + if (forced) { + flags |= O_TRUNC; + } + else { + flags |= O_EXCL; + } + fd = open(filename, flags, 00644); + if (fd == -1) { + msg_err("cannot create a signature file: %s, %s", + filename, + strerror(errno)); + lua_pushboolean(L, FALSE); + } + else { + while (write(fd, sig->str, sig->len) == -1) { + if (errno == EINTR) { + continue; + } + msg_err("cannot write to a signature file: %s, %s", + filename, + strerror(errno)); + res = FALSE; + break; + } + lua_pushboolean(L, res); + close(fd); + } + } + else { + lua_pushboolean(L, FALSE); + } + + return 1; +} + +static gint +lua_rsa_signature_create(lua_State *L) +{ + rspamd_fstring_t *sig, **psig; + const gchar *data; + gsize dlen; + + data = luaL_checklstring(L, 1, &dlen); + if (data != NULL) { + sig = rspamd_fstring_new_init(data, dlen); + psig = lua_newuserdata(L, sizeof(rspamd_fstring_t *)); + rspamd_lua_setclass(L, "rspamd{rsa_signature}", -1); + *psig = sig; + } + + return 1; +} + +static gint +lua_rsa_signature_gc(lua_State *L) +{ + rspamd_fstring_t *sig = lua_check_rsa_sign(L, 1); + + rspamd_fstring_free(sig); + + return 0; +} + +static gint +lua_rsa_signature_base64(lua_State *L) +{ + rspamd_fstring_t *sig = lua_check_rsa_sign(L, 1); + guint boundary = 0; + gchar *b64; + gsize outlen; + enum rspamd_newlines_type how = RSPAMD_TASK_NEWLINES_CRLF; + + if (lua_isnumber(L, 2)) { + boundary = lua_tonumber(L, 2); + } + + if (lua_isstring(L, 3)) { + const gchar *how_str = lua_tostring(L, 3); + + if (strcmp(how_str, "cr") == 0) { + how = RSPAMD_TASK_NEWLINES_CR; + } + else if (strcmp(how_str, "lf") == 0) { + how = RSPAMD_TASK_NEWLINES_LF; + } + else { + how = RSPAMD_TASK_NEWLINES_CRLF; + } + } + + b64 = rspamd_encode_base64_fold(sig->str, sig->len, boundary, &outlen, how); + + if (b64) { + lua_pushlstring(L, b64, outlen); + g_free(b64); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/** + * Check memory using specified rsa key and signature + * + * arguments: + * (rsa_pubkey, rsa_signature, string) + * + * returns: + * true - if string match rsa signature + * false - otherwise + */ +static gint +lua_rsa_verify_memory(lua_State *L) +{ + RSA *rsa; + rspamd_fstring_t *signature; + const gchar *data; + gsize sz; + gint ret; + + rsa = lua_check_rsa_pubkey(L, 1); + signature = lua_check_rsa_sign(L, 2); + data = luaL_checklstring(L, 3, &sz); + + if (rsa != NULL && signature != NULL && data != NULL) { + ret = RSA_verify(NID_sha256, data, sz, + signature->str, signature->len, rsa); + + if (ret == 0) { + lua_pushboolean(L, FALSE); + lua_pushstring(L, ERR_error_string(ERR_get_error(), NULL)); + + return 2; + } + else { + lua_pushboolean(L, TRUE); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +/** + * Sign memory using specified rsa key and signature + * + * arguments: + * (rsa_privkey, string) + * + * returns: + * rspamd_signature object + * nil - otherwise + */ +static gint +lua_rsa_sign_memory(lua_State *L) +{ + RSA *rsa; + rspamd_fstring_t *signature, **psig; + const gchar *data; + gsize sz; + gint ret; + + rsa = lua_check_rsa_privkey(L, 1); + data = luaL_checklstring(L, 2, &sz); + + if (rsa != NULL && data != NULL) { + signature = rspamd_fstring_sized_new(RSA_size(rsa)); + + guint siglen = signature->len; + ret = RSA_sign(NID_sha256, data, sz, + signature->str, &siglen, rsa); + + if (ret != 1) { + rspamd_fstring_free(signature); + + return luaL_error(L, "cannot sign: %s", + ERR_error_string(ERR_get_error(), NULL)); + } + else { + signature->len = siglen; + psig = lua_newuserdata(L, sizeof(rspamd_fstring_t *)); + rspamd_lua_setclass(L, "rspamd{rsa_signature}", -1); + *psig = signature; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_rsa_keypair(lua_State *L) +{ + BIGNUM *e; + RSA *rsa, *pub_rsa, *priv_rsa, **prsa; + gint bits = lua_gettop(L) > 0 ? lua_tointeger(L, 1) : 1024; + + if (bits > 4096 || bits < 512) { + return luaL_error(L, "invalid bits count"); + } + + e = BN_new(); + rsa = RSA_new(); + g_assert(BN_set_word(e, RSA_F4) == 1); + g_assert(RSA_generate_key_ex(rsa, bits, e, NULL) == 1); + + priv_rsa = RSAPrivateKey_dup(rsa); + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_privkey}", -1); + *prsa = priv_rsa; + + pub_rsa = RSAPublicKey_dup(rsa); + prsa = lua_newuserdata(L, sizeof(RSA *)); + rspamd_lua_setclass(L, "rspamd{rsa_pubkey}", -1); + *prsa = pub_rsa; + + RSA_free(rsa); + BN_free(e); + + return 2; +} + +static gint +lua_load_pubkey(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, rsapubkeylib_f); + + return 1; +} + +static gint +lua_load_privkey(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, rsaprivkeylib_f); + + return 1; +} + +static gint +lua_load_signature(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, rsasignlib_f); + + return 1; +} + +static gint +lua_load_rsa(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, rsalib_f); + + return 1; +} + +void luaopen_rsa(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{rsa_pubkey}", rsapubkeylib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_rsa_pubkey", lua_load_pubkey); + + rspamd_lua_new_class(L, "rspamd{rsa_privkey}", rsaprivkeylib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_rsa_privkey", lua_load_privkey); + + rspamd_lua_new_class(L, "rspamd{rsa_signature}", rsasignlib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_rsa_signature", lua_load_signature); + + rspamd_lua_add_preload(L, "rspamd_rsa", lua_load_rsa); + + lua_settop(L, 0); +} diff --git a/src/lua/lua_spf.c b/src/lua/lua_spf.c new file mode 100644 index 0000000..a67a267 --- /dev/null +++ b/src/lua/lua_spf.c @@ -0,0 +1,620 @@ +/*- + * Copyright 2019 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file lua_spf.c + * This module exports spf functions to Lua + */ + +#include "lua_common.h" +#include "libserver/spf.h" +#include "libutil/ref.h" + +#define SPF_RECORD_CLASS "rspamd{spf_record}" + +LUA_FUNCTION_DEF(spf, resolve); +LUA_FUNCTION_DEF(spf, config); + +LUA_FUNCTION_DEF(spf_record, check_ip); +LUA_FUNCTION_DEF(spf_record, dtor); +LUA_FUNCTION_DEF(spf_record, get_domain); +LUA_FUNCTION_DEF(spf_record, get_elts); +LUA_FUNCTION_DEF(spf_record, get_ttl); +LUA_FUNCTION_DEF(spf_record, get_timestamp); +LUA_FUNCTION_DEF(spf_record, get_digest); + +static luaL_reg rspamd_spf_f[] = { + LUA_INTERFACE_DEF(spf, resolve), + LUA_INTERFACE_DEF(spf, config), + {NULL, NULL}, +}; + +static luaL_reg rspamd_spf_record_m[] = { + LUA_INTERFACE_DEF(spf_record, check_ip), + LUA_INTERFACE_DEF(spf_record, get_domain), + LUA_INTERFACE_DEF(spf_record, get_ttl), + LUA_INTERFACE_DEF(spf_record, get_digest), + LUA_INTERFACE_DEF(spf_record, get_elts), + LUA_INTERFACE_DEF(spf_record, get_timestamp), + {"__gc", lua_spf_record_dtor}, + {NULL, NULL}, +}; + +struct rspamd_lua_spf_cbdata { + struct rspamd_task *task; + lua_State *L; + struct rspamd_symcache_dynamic_item *item; + gint cbref; + ref_entry_t ref; +}; + +static gint +lua_load_spf(lua_State *L) +{ + lua_newtable(L); + + /* Create integer arguments to check SPF results */ + lua_newtable(L); + lua_pushinteger(L, SPF_FAIL); + lua_setfield(L, -2, "fail"); + lua_pushinteger(L, SPF_PASS); + lua_setfield(L, -2, "pass"); + lua_pushinteger(L, SPF_NEUTRAL); + lua_setfield(L, -2, "neutral"); + lua_pushinteger(L, SPF_SOFT_FAIL); + lua_setfield(L, -2, "soft_fail"); + + lua_setfield(L, -2, "policy"); + + /* Flags stuff */ + lua_newtable(L); + + lua_pushinteger(L, RSPAMD_SPF_RESOLVED_TEMP_FAILED); + lua_setfield(L, -2, "temp_fail"); + lua_pushinteger(L, RSPAMD_SPF_RESOLVED_NA); + lua_setfield(L, -2, "na"); + lua_pushinteger(L, RSPAMD_SPF_RESOLVED_PERM_FAILED); + lua_setfield(L, -2, "perm_fail"); + lua_pushinteger(L, RSPAMD_SPF_FLAG_CACHED); + lua_setfield(L, -2, "cached"); + + lua_setfield(L, -2, "flags"); + + luaL_register(L, NULL, rspamd_spf_f); + + return 1; +} + +void luaopen_spf(lua_State *L) +{ + rspamd_lua_new_class(L, SPF_RECORD_CLASS, rspamd_spf_record_m); + lua_pop(L, 1); /* No need in metatable... */ + + rspamd_lua_add_preload(L, "rspamd_spf", lua_load_spf); + lua_settop(L, 0); +} + +static void +lua_spf_push_result(struct rspamd_lua_spf_cbdata *cbd, gint code_flags, + struct spf_resolved *resolved, const gchar *err) +{ + g_assert(cbd != NULL); + REF_RETAIN(cbd); + + lua_pushcfunction(cbd->L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(cbd->L); + + lua_rawgeti(cbd->L, LUA_REGISTRYINDEX, cbd->cbref); + + if (resolved) { + struct spf_resolved **presolved; + + presolved = lua_newuserdata(cbd->L, sizeof(*presolved)); + rspamd_lua_setclass(cbd->L, SPF_RECORD_CLASS, -1); + *presolved = spf_record_ref(resolved); + } + else { + lua_pushnil(cbd->L); + } + + lua_pushinteger(cbd->L, code_flags); + + if (err) { + lua_pushstring(cbd->L, err); + } + else { + lua_pushnil(cbd->L); + } + + if (lua_pcall(cbd->L, 3, 0, err_idx) != 0) { + struct rspamd_task *task = cbd->task; + + msg_err_task("cannot call callback function for spf: %s", + lua_tostring(cbd->L, -1)); + } + + lua_settop(cbd->L, err_idx - 1); + + REF_RELEASE(cbd); +} + +static void +lua_spf_dtor(struct rspamd_lua_spf_cbdata *cbd) +{ + if (cbd) { + luaL_unref(cbd->L, LUA_REGISTRYINDEX, cbd->cbref); + if (cbd->item) { + rspamd_symcache_item_async_dec_check(cbd->task, cbd->item, + "lua_spf"); + } + } +} + +static void +spf_lua_lib_callback(struct spf_resolved *record, struct rspamd_task *task, + gpointer ud) +{ + struct rspamd_lua_spf_cbdata *cbd = (struct rspamd_lua_spf_cbdata *) ud; + + if (record) { + if ((record->flags & RSPAMD_SPF_RESOLVED_NA)) { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_NA, NULL, + "no SPF record"); + } + else if (record->elts->len == 0) { + if (record->flags & RSPAMD_SPF_RESOLVED_PERM_FAILED) { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_PERM_FAILED, NULL, + "bad SPF record"); + } + else if ((record->flags & RSPAMD_SPF_RESOLVED_TEMP_FAILED)) { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_TEMP_FAILED, NULL, + "temporary DNS error"); + } + else { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_PERM_FAILED, NULL, + "empty SPF record"); + } + } + else if (record->domain) { + spf_record_ref(record); + lua_spf_push_result(cbd, record->flags, record, NULL); + spf_record_unref(record); + } + else { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_PERM_FAILED, NULL, + "internal error: non empty record for no domain"); + } + } + else { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_PERM_FAILED, NULL, + "internal error: no record"); + } + + REF_RELEASE(cbd); +} + +/*** + * @function rspamd_spf.resolve(task, callback) + * Resolves SPF credentials for a task + * @param {rspamd_task} task task + * @param {function} callback callback that is called on spf resolution +*/ +gint lua_spf_resolve(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + + if (task && lua_isfunction(L, 2)) { + struct rspamd_lua_spf_cbdata *cbd = rspamd_mempool_alloc0(task->task_pool, + sizeof(*cbd)); + struct rspamd_spf_cred *spf_cred; + + cbd->task = task; + cbd->L = L; + lua_pushvalue(L, 2); + cbd->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + /* TODO: make it as an optional parameter */ + spf_cred = rspamd_spf_get_cred(task); + cbd->item = rspamd_symcache_get_cur_item(task); + + if (cbd->item) { + rspamd_symcache_item_async_inc(task, cbd->item, "lua_spf"); + } + REF_INIT_RETAIN(cbd, lua_spf_dtor); + + if (!rspamd_spf_resolve(task, spf_lua_lib_callback, cbd, spf_cred)) { + msg_info_task("cannot make spf request for %s", + spf_cred ? spf_cred->domain : "empty domain"); + if (spf_cred) { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_TEMP_FAILED, + NULL, "DNS failed"); + } + else { + lua_spf_push_result(cbd, RSPAMD_SPF_RESOLVED_NA, + NULL, "No domain"); + } + REF_RELEASE(cbd); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_spf_record_dtor(lua_State *L) +{ + struct spf_resolved *record; + + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, 1, SPF_RECORD_CLASS, + struct spf_resolved, + record); + + if (record) { + spf_record_unref(record); + } + + return 0; +} + +static void +lua_spf_push_spf_addr(lua_State *L, struct spf_addr *addr) +{ + gchar *addr_mask; + + lua_createtable(L, 0, 4); + + lua_pushinteger(L, addr->mech); + lua_setfield(L, -2, "result"); + lua_pushinteger(L, addr->flags); + lua_setfield(L, -2, "flags"); + + if (addr->spf_string) { + lua_pushstring(L, addr->spf_string); + lua_setfield(L, -2, "str"); + } + + addr_mask = spf_addr_mask_to_string(addr); + + if (addr_mask) { + lua_pushstring(L, addr_mask); + lua_setfield(L, -2, "addr"); + g_free(addr_mask); + } +} + +static gint +spf_check_element(lua_State *L, struct spf_resolved *rec, struct spf_addr *addr, + struct rspamd_lua_ip *ip) +{ + gboolean res = FALSE; + const guint8 *s, *d; + guint af, mask, bmask, addrlen; + + + if (addr->flags & RSPAMD_SPF_FLAG_TEMPFAIL) { + /* Ignore failed addresses */ + + return -1; + } + + af = rspamd_inet_address_get_af(ip->addr); + /* Basic comparing algorithm */ + if (((addr->flags & RSPAMD_SPF_FLAG_IPV6) && af == AF_INET6) || + ((addr->flags & RSPAMD_SPF_FLAG_IPV4) && af == AF_INET)) { + d = rspamd_inet_address_get_hash_key(ip->addr, &addrlen); + + if (af == AF_INET6) { + s = (const guint8 *) addr->addr6; + mask = addr->m.dual.mask_v6; + } + else { + s = (const guint8 *) addr->addr4; + mask = addr->m.dual.mask_v4; + } + + /* Compare the first bytes */ + bmask = mask / CHAR_BIT; + if (mask > addrlen * CHAR_BIT) { + /* XXX: add logging */ + } + else if (memcmp(s, d, bmask) == 0) { + if (bmask * CHAR_BIT < mask) { + /* Compare the remaining bits */ + s += bmask; + d += bmask; + mask = (0xff << (CHAR_BIT - (mask - bmask * 8))) & 0xff; + + if ((*s & mask) == (*d & mask)) { + res = TRUE; + } + } + else { + res = TRUE; + } + } + } + else { + if (addr->flags & RSPAMD_SPF_FLAG_ANY) { + res = TRUE; + } + else { + res = FALSE; + } + } + + if (res) { + if (addr->flags & RSPAMD_SPF_FLAG_ANY) { + if (rec->flags & RSPAMD_SPF_RESOLVED_PERM_FAILED) { + lua_pushboolean(L, false); + lua_pushinteger(L, RSPAMD_SPF_RESOLVED_PERM_FAILED); + lua_pushfstring(L, "%cany", spf_mech_char(addr->mech)); + } + else if (rec->flags & RSPAMD_SPF_RESOLVED_TEMP_FAILED) { + lua_pushboolean(L, false); + lua_pushinteger(L, RSPAMD_SPF_RESOLVED_TEMP_FAILED); + lua_pushfstring(L, "%cany", spf_mech_char(addr->mech)); + } + else { + lua_pushboolean(L, true); + lua_pushinteger(L, addr->mech); + lua_spf_push_spf_addr(L, addr); + } + } + else { + lua_pushboolean(L, true); + lua_pushinteger(L, addr->mech); + lua_spf_push_spf_addr(L, addr); + } + + return 3; + } + + return -1; +} + +/*** + * @method rspamd_spf_record:check_ip(ip) + * Checks the processed record versus a specific IP address. This function + * returns 3 values normally: + * 1. Boolean check result + * 2. If result is `false` then the second value is the error flag (e.g. rspamd_spf.flags.temp_fail), otherwise it will be an SPF method + * 3. If result is `false` then this will be an error string, otherwise - an SPF string (e.g. `mx` or `ip4:x.y.z.1`) + * @param {rspamd_ip|string} ip address + * @return {result,flag_or_policy,error_or_addr} - triplet +*/ +static gint +lua_spf_record_check_ip(lua_State *L) +{ + struct spf_resolved *record; + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, 1, SPF_RECORD_CLASS, + struct spf_resolved, + record); + struct rspamd_lua_ip *ip = NULL; + gint nres = 0; + gboolean need_free_ip = FALSE; + + if (lua_type(L, 2) == LUA_TUSERDATA) { + ip = lua_check_ip(L, 2); + } + else if (lua_type(L, 2) == LUA_TSTRING) { + const gchar *ip_str; + gsize iplen; + + ip = g_malloc0(sizeof(struct rspamd_lua_ip)); + ip_str = lua_tolstring(L, 2, &iplen); + + if (!rspamd_parse_inet_address(&ip->addr, + ip_str, iplen, RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + g_free(ip); + ip = NULL; + } + else { + need_free_ip = TRUE; + } + } + + if (record && ip && ip->addr) { + for (guint i = 0; i < record->elts->len; i++) { + struct spf_addr *addr = &g_array_index(record->elts, struct spf_addr, i); + if ((nres = spf_check_element(L, record, addr, ip)) > 0) { + if (need_free_ip) { + g_free(ip); + } + + return nres; + } + } + } + else { + if (need_free_ip) { + g_free(ip); + } + + return luaL_error(L, "invalid arguments"); + } + + if (need_free_ip) { + g_free(ip); + } + + /* If we are here it means that there is no ALL record */ + /* + * According to https://tools.ietf.org/html/rfc7208#section-4.7 it means + * SPF neutral + */ + struct spf_addr fake_all; + + fake_all.mech = SPF_NEUTRAL; + fake_all.flags = RSPAMD_SPF_FLAG_ANY; + fake_all.spf_string = "all"; + + lua_pushboolean(L, true); + lua_pushinteger(L, SPF_NEUTRAL); + lua_spf_push_spf_addr(L, &fake_all); + + return 3; +} + +/*** + * @method rspamd_spf_record:get_domain() + * Returns domain for the specific spf record +*/ +static gint +lua_spf_record_get_domain(lua_State *L) +{ + struct spf_resolved *record; + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, 1, SPF_RECORD_CLASS, + struct spf_resolved, + record); + + if (record) { + lua_pushstring(L, record->domain); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method rspamd_spf_record:get_ttl() + * Returns ttl for the specific spf record +*/ +static gint +lua_spf_record_get_ttl(lua_State *L) +{ + struct spf_resolved *record; + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, 1, SPF_RECORD_CLASS, + struct spf_resolved, + record); + + if (record) { + lua_pushinteger(L, record->ttl); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method rspamd_spf_record:get_timestamp() + * Returns ttl for the specific spf record +*/ +static gint +lua_spf_record_get_timestamp(lua_State *L) +{ + struct spf_resolved *record; + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, 1, SPF_RECORD_CLASS, + struct spf_resolved, + record); + + if (record) { + lua_pushnumber(L, record->timestamp); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method rspamd_spf_record:get_digest() + * Returns string hex representation of the record digest (fast hash function) +*/ +static gint +lua_spf_record_get_digest(lua_State *L) +{ + struct spf_resolved *record; + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, 1, SPF_RECORD_CLASS, + struct spf_resolved, + record); + + if (record) { + gchar hexbuf[64]; + + rspamd_snprintf(hexbuf, sizeof(hexbuf), "%xuL", record->digest); + lua_pushstring(L, hexbuf); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method rspamd_spf_record:get_elts() + * Returns a list of all elements in an SPF record. Each element is a table with the + * following fields: + * + * - result - mech flag from rspamd_spf.results + * - flags - all flags + * - addr - address and mask as a string + * - str - string representation (if available) +*/ +static gint +lua_spf_record_get_elts(lua_State *L) +{ + struct spf_resolved *record; + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, 1, SPF_RECORD_CLASS, + struct spf_resolved, + record); + + if (record) { + guint i; + struct spf_addr *addr; + + lua_createtable(L, record->elts->len, 0); + + for (i = 0; i < record->elts->len; i++) { + addr = (struct spf_addr *) &g_array_index(record->elts, + struct spf_addr, i); + lua_spf_push_spf_addr(L, addr); + + lua_rawseti(L, -2, i + 1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function rspamd_spf.config(object) + * Configures SPF library according to the UCL config + * @param {table} object configuration object +*/ +gint lua_spf_config(lua_State *L) +{ + ucl_object_t *config_obj = ucl_object_lua_import(L, 1); + + if (config_obj) { + spf_library_config(config_obj); + ucl_object_unref(config_obj); /* As we copy data all the time */ + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +}
\ No newline at end of file diff --git a/src/lua/lua_sqlite3.c b/src/lua/lua_sqlite3.c new file mode 100644 index 0000000..be7a9ae --- /dev/null +++ b/src/lua/lua_sqlite3.c @@ -0,0 +1,379 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "sqlite_utils.h" + +/*** + * @module rspamd_sqlite3 + * This module provides routines to query sqlite3 databases +@example +local sqlite3 = require "rspamd_sqlite3" + +local db = sqlite3.open("/tmp/db.sqlite") + +if db then + db:exec([[ CREATE TABLE x (id INT, value TEXT); ]]) + + db:exec([[ INSERT INTO x VALUES (?1, ?2); ]], 1, 'test') + + for row in db:rows([[ SELECT * FROM x ]]) do + print(string.format('%d -> %s', row.id, row.value)) + end +end + */ + +LUA_FUNCTION_DEF(sqlite3, open); +LUA_FUNCTION_DEF(sqlite3, sql); +LUA_FUNCTION_DEF(sqlite3, rows); +LUA_FUNCTION_DEF(sqlite3, close); +LUA_FUNCTION_DEF(sqlite3_stmt, close); + +static const struct luaL_reg sqlitelib_f[] = { + LUA_INTERFACE_DEF(sqlite3, open), + {NULL, NULL}}; + +static const struct luaL_reg sqlitelib_m[] = { + LUA_INTERFACE_DEF(sqlite3, sql), + {"query", lua_sqlite3_sql}, + {"exec", lua_sqlite3_sql}, + LUA_INTERFACE_DEF(sqlite3, rows), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_sqlite3_close}, + {NULL, NULL}}; + +static const struct luaL_reg sqlitestmtlib_m[] = { + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_sqlite3_stmt_close}, + {NULL, NULL}}; + +static void lua_sqlite3_push_row(lua_State *L, sqlite3_stmt *stmt); + +static sqlite3 * +lua_check_sqlite3(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{sqlite3}"); + luaL_argcheck(L, ud != NULL, pos, "'sqlite3' expected"); + return ud ? *((sqlite3 **) ud) : NULL; +} + +static sqlite3_stmt * +lua_check_sqlite3_stmt(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{sqlite3_stmt}"); + luaL_argcheck(L, ud != NULL, pos, "'sqlite3_stmt' expected"); + return ud ? *((sqlite3_stmt **) ud) : NULL; +} + + +/*** + * @function rspamd_sqlite3.open(path) + * Opens sqlite3 database at the specified path. DB is created if not exists. + * @param {string} path path to db + * @return {sqlite3} sqlite3 handle + */ +static gint +lua_sqlite3_open(lua_State *L) +{ + const gchar *path = luaL_checkstring(L, 1); + sqlite3 *db, **pdb; + GError *err = NULL; + + if (path == NULL) { + lua_pushnil(L); + return 1; + } + + db = rspamd_sqlite3_open_or_create(NULL, path, NULL, 0, &err); + + if (db == NULL) { + if (err) { + msg_err("cannot open db: %e", err); + g_error_free(err); + } + lua_pushnil(L); + + return 1; + } + + pdb = lua_newuserdata(L, sizeof(db)); + *pdb = db; + rspamd_lua_setclass(L, "rspamd{sqlite3}", -1); + + return 1; +} + +static void +lua_sqlite3_bind_statements(lua_State *L, gint start, gint end, + sqlite3_stmt *stmt) +{ + gint i, type, num = 1; + const gchar *str; + gsize slen; + gdouble n; + + g_assert(start <= end && start > 0 && end > 0); + + for (i = start; i <= end; i++) { + type = lua_type(L, i); + + switch (type) { + case LUA_TNUMBER: + n = lua_tonumber(L, i); + + if (n == (gdouble) ((gint64) n)) { + sqlite3_bind_int64(stmt, num, n); + } + else { + sqlite3_bind_double(stmt, num, n); + } + num++; + break; + case LUA_TSTRING: + str = lua_tolstring(L, i, &slen); + sqlite3_bind_text(stmt, num, str, slen, SQLITE_TRANSIENT); + num++; + break; + default: + msg_err("invalid type at position %d: %s", i, lua_typename(L, type)); + break; + } + } +} + +/*** + * @function rspamd_sqlite3:sql(query[, args..]) + * Performs sqlite3 query replacing '?1', '?2' and so on with the subsequent args + * of the function + * + * @param {string} query SQL query + * @param {string|number} args... variable number of arguments + * @return {boolean} `true` if a statement has been successfully executed + */ +static gint +lua_sqlite3_sql(lua_State *L) +{ + LUA_TRACE_POINT; + sqlite3 *db = lua_check_sqlite3(L, 1); + const gchar *query = luaL_checkstring(L, 2); + sqlite3_stmt *stmt; + gboolean ret = FALSE; + gint top = 1, rc; + + if (db && query) { + if (sqlite3_prepare_v2(db, query, -1, &stmt, NULL) != SQLITE_OK) { + msg_err("cannot prepare query %s: %s", query, sqlite3_errmsg(db)); + return luaL_error(L, sqlite3_errmsg(db)); + } + else { + top = lua_gettop(L); + + if (top > 2) { + /* Push additional arguments to sqlite3 */ + lua_sqlite3_bind_statements(L, 3, top, stmt); + } + + rc = sqlite3_step(stmt); + top = 1; + + if (rc == SQLITE_ROW || rc == SQLITE_OK || rc == SQLITE_DONE) { + ret = TRUE; + + if (rc == SQLITE_ROW) { + lua_sqlite3_push_row(L, stmt); + top = 2; + } + } + else { + msg_warn("sqlite3 error: %s", sqlite3_errmsg(db)); + } + + sqlite3_finalize(stmt); + } + } + + lua_pushboolean(L, ret); + + return top; +} + +static void +lua_sqlite3_push_row(lua_State *L, sqlite3_stmt *stmt) +{ + const gchar *str; + gsize slen; + gint64 num; + gchar numbuf[32]; + gint nresults, i, type; + + nresults = sqlite3_column_count(stmt); + lua_createtable(L, 0, nresults); + + for (i = 0; i < nresults; i++) { + lua_pushstring(L, sqlite3_column_name(stmt, i)); + type = sqlite3_column_type(stmt, i); + + switch (type) { + case SQLITE_INTEGER: + /* + * XXX: we represent int64 as strings, as we can nothing else to do + * about it portably + */ + num = sqlite3_column_int64(stmt, i); + rspamd_snprintf(numbuf, sizeof(numbuf), "%uL", num); + lua_pushstring(L, numbuf); + break; + case SQLITE_FLOAT: + lua_pushnumber(L, sqlite3_column_double(stmt, i)); + break; + case SQLITE_TEXT: + slen = sqlite3_column_bytes(stmt, i); + str = sqlite3_column_text(stmt, i); + lua_pushlstring(L, str, slen); + break; + case SQLITE_BLOB: + slen = sqlite3_column_bytes(stmt, i); + str = sqlite3_column_blob(stmt, i); + lua_pushlstring(L, str, slen); + break; + default: + lua_pushboolean(L, 0); + break; + } + + lua_settable(L, -3); + } +} + +static gint +lua_sqlite3_next_row(lua_State *L) +{ + LUA_TRACE_POINT; + sqlite3_stmt *stmt = *(sqlite3_stmt **) lua_touserdata(L, lua_upvalueindex(1)); + gint rc; + + if (stmt != NULL) { + rc = sqlite3_step(stmt); + + if (rc == SQLITE_ROW) { + lua_sqlite3_push_row(L, stmt); + return 1; + } + } + + lua_pushnil(L); + + return 1; +} + +/*** + * @function rspamd_sqlite3:rows(query[, args..]) + * Performs sqlite3 query replacing '?1', '?2' and so on with the subsequent args + * of the function. This function returns iterator suitable for loop construction: + * + * @param {string} query SQL query + * @param {string|number} args... variable number of arguments + * @return {function} iterator to get all rows +@example +for row in db:rows([[ SELECT * FROM x ]]) do + print(string.format('%d -> %s', row.id, row.value)) +end + */ +static gint +lua_sqlite3_rows(lua_State *L) +{ + LUA_TRACE_POINT; + sqlite3 *db = lua_check_sqlite3(L, 1); + const gchar *query = luaL_checkstring(L, 2); + sqlite3_stmt *stmt, **pstmt; + gint top; + + if (db && query) { + if (sqlite3_prepare_v2(db, query, -1, &stmt, NULL) != SQLITE_OK) { + msg_err("cannot prepare query %s: %s", query, sqlite3_errmsg(db)); + lua_pushstring(L, sqlite3_errmsg(db)); + return lua_error(L); + } + else { + top = lua_gettop(L); + + if (top > 2) { + /* Push additional arguments to sqlite3 */ + lua_sqlite3_bind_statements(L, 3, top, stmt); + } + + /* Create C closure */ + pstmt = lua_newuserdata(L, sizeof(stmt)); + *pstmt = stmt; + rspamd_lua_setclass(L, "rspamd{sqlite3_stmt}", -1); + + lua_pushcclosure(L, lua_sqlite3_next_row, 1); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_sqlite3_close(lua_State *L) +{ + LUA_TRACE_POINT; + sqlite3 *db = lua_check_sqlite3(L, 1); + + if (db) { + sqlite3_close(db); + } + + return 0; +} + +static gint +lua_sqlite3_stmt_close(lua_State *L) +{ + sqlite3_stmt *stmt = lua_check_sqlite3_stmt(L, 1); + + if (stmt) { + sqlite3_finalize(stmt); + } + + return 0; +} + +static gint +lua_load_sqlite3(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, sqlitelib_f); + + return 1; +} +/** + * Open redis library + * @param L lua stack + * @return + */ +void luaopen_sqlite3(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{sqlite3}", sqlitelib_m); + lua_pop(L, 1); + + rspamd_lua_new_class(L, "rspamd{sqlite3_stmt}", sqlitestmtlib_m); + lua_pop(L, 1); + + rspamd_lua_add_preload(L, "rspamd_sqlite3", lua_load_sqlite3); +} diff --git a/src/lua/lua_task.c b/src/lua/lua_task.c new file mode 100644 index 0000000..7278602 --- /dev/null +++ b/src/lua/lua_task.c @@ -0,0 +1,7295 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_url.h" + +#include "message.h" +#include "images.h" +#include "archives.h" +#include "utlist.h" +#include "unix-std.h" +#include "libmime/smtp_parsers.h" +#include "libserver/mempool_vars_internal.h" +#include "libserver/dkim.h" +#include "libserver/task.h" +#include "libserver/cfg_file_private.h" +#include "libmime/scan_result_private.h" +#include "libstat/stat_api.h" +#include "libserver/maps/map_helpers.h" + +#include <math.h> +#include "libmime/received.h" + +/*** + * @module rspamd_task + * This module provides routines for tasks manipulation in rspamd. Tasks usually + * represent messages being scanned, and this API provides access to such elements + * as headers, symbols, metrics and so on and so forth. Normally, task objects + * are passed to the lua callbacks allowing to check specific properties of messages + * and add the corresponding symbols to the scan's results. +@example +rspamd_config.DATE_IN_PAST = function(task) + local dm = task:get_date{format = 'message', gmt = true} + local dt = task:get_date{format = 'connect', gmt = true} + -- A day + if dt - dm > 86400 then + return true + end + + return false +end + */ + +/* Task methods */ + +/*** + * @function rspamd_task.create([cfg]) + * Create a new empty task + * @return {rspamd_task} new task + */ +LUA_FUNCTION_DEF(task, create); +/*** + * @function rspamd_task.load_from_file(filename[, cfg]) + * Loads a message from specific file + * @return {boolean,rspamd_task|error} status + new task or error message + */ +LUA_FUNCTION_DEF(task, load_from_file); +/*** + * @function rspamd_task.load_from_string(message[, cfg]) + * Loads a message from specific file + * @return {boolean,rspamd_task|error} status + new task or error message + */ +LUA_FUNCTION_DEF(task, load_from_string); +/*** + * @method task:get_message() + * Returns task raw message content as opaque text + * @return {rspamd_text} task raw content + */ +LUA_FUNCTION_DEF(task, get_message); +/*** + * @method task:set_message(msg) + * Updates task message with another message; It also parses a message to + * fill the internal structures. + * Input might be a string, a lua_text or a table of the former stuff. + * @param {string/text/table} msg new message to set + * @return {boolean,number} if a message has been set + its raw size + */ +LUA_FUNCTION_DEF(task, set_message); +/*** + * @method task:process_message() + * Parses message + */ +LUA_FUNCTION_DEF(task, process_message); +/*** + * @method task:get_cfg() + * Get configuration object for a task. + * @return {rspamd_config} (config.md)[configuration object] for the task + */ +LUA_FUNCTION_DEF(task, get_cfg); +LUA_FUNCTION_DEF(task, set_cfg); +LUA_FUNCTION_DEF(task, destroy); +/*** + * @method task:get_mempool() + * Returns memory pool valid for a lifetime of task. It is used internally by + * many rspamd routines. + * @return {rspamd_mempool} memory pool object + */ +LUA_FUNCTION_DEF(task, get_mempool); +/*** + * @method task:get_session() + * Returns asynchronous session object that is used by many rspamd asynchronous + * utilities internally. + * @return {rspamd_session} session object + */ +LUA_FUNCTION_DEF(task, get_session); +/*** + * @method task:set_session(session) + * Sets new async session for a task + */ +LUA_FUNCTION_DEF(task, set_session); +/*** + * @method task:get_ev_base() + * Return asynchronous event base for using in callbacks and resolver. + * @return {rspamd_ev_base} event base + */ +LUA_FUNCTION_DEF(task, get_ev_base); +/*** + * @method task:get_worker() + * Returns a worker object associated with the task + * @return {rspamd_worker} worker object + */ +LUA_FUNCTION_DEF(task, get_worker); +/*** + * @method task:insert_result([enforce_symbol,]symbol, weight[, option1, ...]) + * Insert specific symbol to the tasks scanning results assigning the initial + * weight to it. + * @param {boolean} enforce_symbol if represented and true, then insert symbol even if it is not registered in the metric + * @param {string} symbol symbol to insert + * @param {number} weight initial weight (this weight is multiplied by the metric weight) + * @param {string} options list of optional options attached to a symbol inserted +@example +local function cb(task) + if task:get_header('Some header') then + task:insert_result('SOME_HEADER', 1.0, 'Got some header') + end +end + */ +LUA_FUNCTION_DEF(task, insert_result); +/*** + * @method task:insert_result_named(shadow_result, [enforce_symbol,]symbol, weight[, option1, ...]) + * Insert specific symbol to the tasks scanning results assigning the initial + * weight to it. + * @param {string} shadow_result name of shadow result + * @param {boolean} enforce_symbol if represented and true, then insert symbol even if it is not registered in the metric + * @param {string} symbol symbol to insert + * @param {number} weight initial weight (this weight is multiplied by the metric weight) + * @param {string} options list of optional options attached to a symbol inserted + */ +LUA_FUNCTION_DEF(task, insert_result_named); + +/*** + * @method task:adjust_result(symbol, score[, option1, ...]) + * Alters the existing symbol's score to a new score. It is not affected by + * metric score or grow factor. You can also add new options + * using this method. Symbol must be already inserted into metric or an error + * will be emitted. + * @param {string} symbol symbol to adjust + * @param {number} score this value is NOT multiplied by the metric score + * @param {string/table} options list of optional options attached to a symbol adjusted + */ +LUA_FUNCTION_DEF(task, adjust_result); + +/*** + * @method task:remove_result(symbol[, shadow_result]) + * Removes the symbol from a named or unnamed/default result + * @param {string} symbol symbol to remove + * @param {string} shadow_result name of shadow result + * @return {boolean} true if a symbol has been removed + */ +LUA_FUNCTION_DEF(task, remove_result); +/*** + * @method task:set_pre_result(action, [message, [module], [score], [priority], [flags]) + * Sets pre-result for a task. It is used in pre-filters to specify early result + * of the task scanned. If a pre-filter sets some result, then further processing + * may be skipped. For selecting action it is possible to use global table + * `rspamd_actions` or a string value: + * + * - `reject`: reject message permanently + * - `add header`: add spam header + * - `rewrite subject`: rewrite subject to spam subject + * - `greylist`: greylist message + * - `accept` or `no action`: whitelist message + * + * This function also accepts a table from Rspamd 2.6 with the following keys: + * - action: string required + * - message: string + * - module: string + * - score: number + * - priority: integer + * - flags: flags string + * - result: named result if needed + * + * @param {rspamd_action or string} action a numeric or string action value + * @param {string} message action message + * @param {string} module optional module name + * @param {number/nil} score optional explicit score + * @param {number/nil} priority optional explicit priority + * @param {string/nil} flags optional flags (e.g. 'least' for least action) +@example +local function cb(task) + local gr = task:get_header('Greylist') + if gr and gr == 'greylist' then + task:set_pre_result('soft reject', 'Greylisting required') + end +end + */ +LUA_FUNCTION_DEF(task, set_pre_result); + +/*** + * @method task:has_pre_result() + * Returns true if task has some pre-result being set. + * If result has been set this function also returns pre result action, + * message and module as strings in this order. + * + * @return {boolean,[string,string,string]} true if task has some pre-result being set + */ +LUA_FUNCTION_DEF(task, has_pre_result); +/*** + * @method task:append_message(message, [category]) + * Adds a message to scanning output. + * @param {string} message + * @param {category} message category +@example +local function cb(task) + task:append_message('Example message') +end + */ +LUA_FUNCTION_DEF(task, append_message); +/*** + * @method task:get_urls([need_emails|list_protos][, need_images]) + * Get all URLs found in a message. Telephone urls and emails are not included unless explicitly asked in `list_protos` + * @param {boolean} need_emails if `true` then return also email urls, this can be a comma separated string of protocols desired or a table (e.g. `mailto` or `telephone`) + * @param {boolean} need_images return urls from images (<img src=...>) as well + * @return {table rspamd_url} list of all urls found +@example +local function phishing_cb(task) + local urls = task:get_urls({'https', 'http'}); + + if urls then + for _,url in ipairs(urls) do + if url:is_phished() then + return true + end + end + end + return false +end + */ +LUA_FUNCTION_DEF(task, get_urls); +/*** + * @method task:get_urls_filtered([{flags_include}, [{flags_exclude}]], [{protocols_mask}]) + * Get urls managed by either exclude or include flags list + * - If flags include are nil then all but excluded urls are returned + * - If flags exclude are nil then only included explicitly urls are returned + * - If both parameters are nil then all urls are included + * @param {table} flags_include included flags + * @param {table} flags_exclude excluded flags + * @param {table} protocols_mask include only specific protocols + * @return {table rspamd_url} list of urls matching conditions + */ +LUA_FUNCTION_DEF(task, get_urls_filtered); +/*** + * @method task:has_urls([need_emails]) + * Returns 'true' if a task has urls listed + * @param {boolean} need_emails if `true` then return also email urls + * @return {boolean} true if a task has urls (urls or emails if `need_emails` is true) + */ +LUA_FUNCTION_DEF(task, has_urls); +/*** + * @method task:inject_url(url) + * Inserts an url into a task (useful for redirected urls) + * @param {lua_url} url url to inject + */ +LUA_FUNCTION_DEF(task, inject_url); +/*** + * @method task:get_content() + * Get raw content for the specified task + * @return {text} the data contained in the task + */ +LUA_FUNCTION_DEF(task, get_content); + +/*** + * @method task:get_filename() + * Returns filename for a specific task + * @return {string|nil} filename or nil if unknown + */ +LUA_FUNCTION_DEF(task, get_filename); + +/*** + * @method task:get_rawbody() + * Get raw body for the specified task + * @return {text} the data contained in the task + */ +LUA_FUNCTION_DEF(task, get_rawbody); + +/*** + * @method task:get_emails() + * Get all email addresses found in a message. + * @return {table rspamd_url} list of all email addresses found + */ +LUA_FUNCTION_DEF(task, get_emails); +/*** + * @method task:get_text_parts() + * Get all text (and HTML) parts found in a message + * @return {table rspamd_text_part} list of text parts + */ +LUA_FUNCTION_DEF(task, get_text_parts); +/*** + * @method task:get_parts() + * Get all mime parts found in a message + * @return {table rspamd_mime_part} list of mime parts + */ +LUA_FUNCTION_DEF(task, get_parts); +/*** + * @method task:get_meta_words([how='stem']) + * Get meta words from task (subject and displayed names) + * - `stem`: stemmed words (default) + * - `norm`: normalised words (utf normalised + lowercased) + * - `raw`: raw words in utf (if possible) + * - `full`: list of tables, each table has the following fields: + * - [1] - stemmed word + * - [2] - normalised word + * - [3] - raw word + * - [4] - flags (table of strings) + */ +LUA_FUNCTION_DEF(task, get_meta_words); +/*** + * @method task:get_request_header(name) + * Get value of a HTTP request header. + * @param {string} name name of header to get + * @return {rspamd_text} value of an HTTP header + */ +LUA_FUNCTION_DEF(task, get_request_header); +/*** + * @method task:set_request_header(name, value) + * Set value of a HTTP request header. If value is omitted, then a header is removed + * @param {string} name name of header to get + * @param {rspamd_text/string} value new header's value + */ +LUA_FUNCTION_DEF(task, set_request_header); +/*** + * @method task:get_subject() + * Returns task subject (either from the protocol override or from a header) + * @return {string} value of a subject (decoded) + */ +LUA_FUNCTION_DEF(task, get_subject); +/*** + * @method task:get_header(name[, case_sensitive]) + * Get decoded value of a header specified with optional case_sensitive flag. + * By default headers are searched in caseless matter. + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {string} decoded value of a header + */ +LUA_FUNCTION_DEF(task, get_header); +/*** + * @method task:has_header(name[, case_sensitive]) + * Get decoded value of a header specified with optional case_sensitive flag. + * By default headers are searched in the case insensitive matter. + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {boolean} true if header exists + */ +LUA_FUNCTION_DEF(task, has_header); +/*** + * @method task:get_header_raw(name[, case_sensitive]) + * Get raw value of a header specified with optional case_sensitive flag. + * By default headers are searched in caseless matter. + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {string} raw value of a header + */ +LUA_FUNCTION_DEF(task, get_header_raw); +/*** + * @method task:get_header_full(name[, case_sensitive[, need_modified]]) + * Get raw value of a header specified with optional case_sensitive flag. + * By default headers are searched in caseless matter. This method returns more + * information about the header as a list of tables with the following structure: + * + * - `name` - name of a header + * - `value` - raw value of a header + * - `decoded` - decoded value of a header + * - `tab_separated` - `true` if a header and a value are separated by `tab` character + * - `empty_separator` - `true` if there are no separator between a header and a value + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @param {boolean} need_modified return a modified value of a header if presented + * @return {list of tables} all values of a header as specified above +@example +function check_header_delimiter_tab(task, header_name) + for _,rh in ipairs(task:get_header_full(header_name)) do + if rh['tab_separated'] then return true end + end + return false +end + */ +LUA_FUNCTION_DEF(task, get_header_full); +/*** + * @method task:get_header_count(name[, case_sensitive]) + * Lightweight version if you need just a header's count + * * By default headers are searched in caseless matter. + * @param {string} name name of header to get + * @param {boolean} case_sensitive case sensitiveness flag to search for a header + * @return {number} number of header's occurrences or 0 if not found + */ +LUA_FUNCTION_DEF(task, get_header_count); +/*** + * @method task:get_raw_headers() + * Get all undecoded headers of a message as a string + * @return {rspamd_text} all raw headers for a message as opaque text + */ +LUA_FUNCTION_DEF(task, get_raw_headers); + +/*** + * @method task:get_headers([need_modified=false]) + * Get all headers of a message in the same format as get_header_full + * @return {table of headers data} all headers for a message + */ +LUA_FUNCTION_DEF(task, get_headers); + +/*** + * @method task:modify_header(name, mods) + * Modify an existing or non-existing header with the name `name` + * Mods is a table with the following structure: + * { + * "add" = { {order, value}, {order, value} }, + * "remove" = {order, order, order} + * } + * Modifications are evaluated in order: remove, add, so headers are first + * removed (if any) and then added + * Order in remove starts from 1, where 0 means 'remove all', and negative value means + * remove from the end + * Order in addition means addition from the top: 0 means the most top header, 1 one after, etc + * negative order means addition to the end, e.g. -1 is appending header. + * @return {bool} true if header could be modified (always true unless we don't have an unparsed message) + */ +LUA_FUNCTION_DEF(task, modify_header); + +/*** + * @method task:get_received_headers() + * Returns a list of tables of parsed received headers. A tables returned have + * the following structure: + * + * - `from_hostname` - string that represents hostname provided by a peer + * - `from_ip` - string representation of sending IP address + * - `real_hostname` - hostname as resolved by MTA + * - `real_ip` - rspamd_ip object representing sending IP address + * - `by_hostname` - MTA hostname + * - `proto` - protocol, e.g. ESMTP or ESMTPS + * - `timestamp` - received timestamp + * - `for` - for value (unparsed mailbox) + * + * Please note that in some situations rspamd cannot parse all the fields of received headers. + * In that case you should check all strings for validity. + * @return {table of tables} list of received headers described above + */ +LUA_FUNCTION_DEF(task, get_received_headers); +/*** + * @method task:get_queue_id() + * Returns queue ID of the message being processed. + */ +LUA_FUNCTION_DEF(task, get_queue_id); +/*** + * @method task:get_uid() + * Returns ID of the task being processed. + */ +LUA_FUNCTION_DEF(task, get_uid); +/*** + * @method task:get_resolver() + * Returns ready to use rspamd_resolver object suitable for making asynchronous DNS requests. + * @return {rspamd_resolver} resolver object associated with the task's session + * @example +local logger = require "rspamd_logger" + +local function task_cb(task) + local function dns_cb(resolver, to_resolve, results, err) + -- task object is available due to closure + task:inc_dns_req() + if results then + logger.info(string.format('<%s> [%s] resolved for symbol: %s', + task:get_message_id(), to_resolve, 'EXAMPLE_SYMBOL')) + task:insert_result('EXAMPLE_SYMBOL', 1) + end + end + local r = task:get_resolver() + r:resolve_a(task:get_session(), task:get_mempool(), 'example.com', dns_cb) +end + */ +LUA_FUNCTION_DEF(task, get_resolver); +/*** + * @method task:set_resolver(resolver) + * Sets rspamd_resolver for a specified task. + */ +LUA_FUNCTION_DEF(task, set_resolver); +/*** +* @method task:inc_dns_req() +* Increment number of DNS requests for the task. Is used just for logging purposes. +*/ +LUA_FUNCTION_DEF(task, inc_dns_req); +/*** + * @method task:get_dns_req() + * Get number of dns requests being sent in the task + * @return {number} number of DNS requests + */ +LUA_FUNCTION_DEF(task, get_dns_req); + +/*** + * @method task:has_recipients([type]) + * Return true if there are SMTP or MIME recipients for a task. + * @param {integer|string} type if specified has the following meaning: `0` or `any` means try SMTP recipients and fallback to MIME if failed, `1` or `smtp` means checking merely SMTP recipients and `2` or `mime` means MIME recipients only + * @return {bool,integer} `true` if there are recipients of the following type and a number of such a recipients excluding artificial ones + */ +LUA_FUNCTION_DEF(task, has_recipients); + +/*** + * @method task:get_recipients([type]) + * Return SMTP or MIME recipients for a task. This function returns list of internet addresses each one is a table with the following structure: + * + * - `name` - name of internet address in UTF8, e.g. for `Vsevolod Stakhov <blah@foo.com>` it returns `Vsevolod Stakhov` + * - `addr` - address part of the address + * - `user` - user part (if present) of the address, e.g. `blah` + * - `domain` - domain part (if present), e.g. `foo.com` + * @param {integer|string} type if specified has the following meaning: `0` or `any` means try SMTP recipients and fallback to MIME if failed, `1` or `smtp` means checking merely SMTP recipients and `2` or `mime` means MIME recipients only + * @return {list of addresses} list of recipients or `nil` + */ +LUA_FUNCTION_DEF(task, get_recipients); + +/*** + * @method task:get_principal_recipient() + * Returns a single string with so called `principal recipient` for a message. The order + * of check is the following: + * + * - deliver-to request header + * - the first recipient (envelope) + * - the first recipient (mime) + * @return {string} principal recipient + */ +LUA_FUNCTION_DEF(task, get_principal_recipient); +/*** + * @method task:get_reply_sender() + * Returns a single string with address that should be used to reply on a message + * + * - reply-to header + * - from header + * - smtp from as a last resort + * @return {address} email address + */ +LUA_FUNCTION_DEF(task, get_reply_sender); + +/*** + * @method task:set_recipients([type], {rcpt1, rcpt2...}, [how='add']) + * Sets recipients for a task. This function accepts table that will be converted to the address. + * If some fields are missing they are subsequently reconstructed by this function. E.g. if you + * specify 'user' and 'domain', then address and raw string will be reconstructed + * + * - `name` - name of internet address in UTF8, e.g. for `Vsevolod Stakhov <blah@foo.com>` it returns `Vsevolod Stakhov` + * - `addr` - address part of the address + * - `user` - user part (if present) of the address, e.g. `blah` + * - `domain` - domain part (if present), e.g. `foo.com` + * @param {integer|string} type if specified has the following meaning: `0` or `any` means try SMTP recipients and fallback to MIME if failed, `1` or `smtp` means checking merely SMTP recipients and `2` or `mime` means MIME recipients only + * @param {list of tables} recipients recipients to set + * @param {string} how define how to set recipients: `rewrite` - rewrite existing recipients, `alias` - treat existing recipients as aliased recipients, `add` - add new recipients + * @return {boolean} result of the operation + */ +LUA_FUNCTION_DEF(task, set_recipients); + +/*** + * @method task:has_from([type]) + * Return true if there is SMTP or MIME sender for a task. + * @param {integer|string} type if specified has the following meaning: `0` or `any` means try SMTP recipients and fallback to MIME if failed, `1` or `smtp` means checking merely SMTP recipients and `2` or `mime` means MIME recipients only + * @return {bool} `true` if there is sender of the following type + */ +LUA_FUNCTION_DEF(task, has_from); + +/*** + * @method task:get_from([type]) + * Return SMTP or MIME sender for a task. This function returns an internet address which one is a table with the following structure: + * + * - `raw` - the original value without any processing + * - `name` - name of internet address in UTF8, e.g. for `Vsevolod Stakhov <blah@foo.com>` it returns `Vsevolod Stakhov` + * - `addr` - address part of the address + * - `user` - user part (if present) of the address, e.g. `blah` + * - `domain` - domain part (if present), e.g. `foo.com` + * - `flags` - table with following keys set to true if given condition fulfilled: + * - [valid] - valid SMTP address in conformity with https://tools.ietf.org/html/rfc5321#section-4.1. + * - [ip] - domain is IPv4/IPv6 address + * - [braced] - angled `<blah@foo.com>` address + * - [quoted] - quoted user part + * - [empty] - empty address + * - [backslash] - user part contains backslash + * - [8bit] - contains 8bit characters + * @param {integer|string} type if specified has the following meaning: `0` or `any` means try SMTP sender and fallback to MIME if failed, `1` or `smtp` means checking merely SMTP sender and `2` or `mime` means MIME `From:` only + * @return {address} sender or `nil` + */ +LUA_FUNCTION_DEF(task, get_from); + +/*** + * @method task:set_from(type, addr) + * Sets sender for a task. This function accepts table that will be converted to the address. + * If some fields are missing they are subsequently reconstructed by this function. E.g. if you + * specify 'user' and 'domain', then address and raw string will be reconstructed + * + * - `name` - name of internet address in UTF8, e.g. for `Vsevolod Stakhov <blah@foo.com>` it returns `Vsevolod Stakhov` + * - `addr` - address part of the address + * - `user` - user part (if present) of the address, e.g. `blah` + * - `domain` - domain part (if present), e.g. `foo.com` + * @param {integer|string} type if specified has the following meaning: `0` or `any` means try SMTP sender and fallback to MIME if failed, `1` or `smtp` means checking merely SMTP sender and `2` or `mime` means MIME `From:` only + * @param {table + * @return {boolean} success or not + */ +LUA_FUNCTION_DEF(task, set_from); +/*** + * @method task:get_user() + * Returns authenticated user name for this task if specified by an MTA. + * @return {string} username or nil + */ +LUA_FUNCTION_DEF(task, get_user); +/*** + * @method task:set_user([username]) + * Sets or resets (if username is not specified) authenticated user name for this task. + * @return {string} the previously set username or nil + */ +LUA_FUNCTION_DEF(task, set_user); +/*** + * @method task:get_from_ip() + * Returns [ip_addr](ip.md) object of a sender that is provided by MTA + * @return {rspamd_ip} ip address object + */ +LUA_FUNCTION_DEF(task, get_from_ip); +/*** + * @method task:set_from_ip(str) + * Set tasks's IP address based on the passed string + * @param {string} str string representation of ip + */ +LUA_FUNCTION_DEF(task, set_from_ip); +LUA_FUNCTION_DEF(task, get_from_ip_num); +/*** + * @method task:get_client_ip() + * Returns [ip_addr](ip.md) object of a client connected to rspamd (normally, it is an IP address of MTA) + * @return {rspamd_ip} ip address object + */ +LUA_FUNCTION_DEF(task, get_client_ip); +/*** + * @method task:get_helo() + * Returns the value of SMTP helo provided by MTA. + * @return {string} HELO value + */ +LUA_FUNCTION_DEF(task, get_helo); +LUA_FUNCTION_DEF(task, set_helo); +/*** + * @method task:get_hostname() + * Returns the value of sender's hostname provided by MTA + * @return {string} hostname value + */ +LUA_FUNCTION_DEF(task, get_hostname); +LUA_FUNCTION_DEF(task, set_hostname); +/*** + * @method task:get_images() + * Returns list of all images found in a task as a table of `rspamd_image`. + * Each image has the following methods: + * + * * `get_width` - return width of an image in pixels + * * `get_height` - return height of an image in pixels + * * `get_type` - return string representation of image's type (e.g. 'jpeg') + * * `get_filename` - return string with image's file name + * * `get_size` - return size in bytes + * @return {list of rspamd_image} images found in a message + */ +LUA_FUNCTION_DEF(task, get_images); +/*** + * @method task:get_archives() + * Returns list of all archives found in a task as a table of `rspamd_archive`. + * Each archive has the following methods available: + * + * * `get_files` - return list of strings with filenames inside archive + * * `get_files_full` - return list of tables with all information about files + * * `is_encrypted` - return true if an archive is encrypted + * * `get_type` - return string representation of image's type (e.g. 'zip') + * * `get_filename` - return string with archive's file name + * * `get_size` - return size in bytes + * @return {list of rspamd_archive} archives found in a message + */ +LUA_FUNCTION_DEF(task, get_archives); +/*** + * @method task:get_dkim_results() + * Returns list of all dkim check results as table of maps. Callee must ensure that + * dkim checks have been completed by adding dependency on `DKIM_TRACE` symbol. + * Fields in map: + * + * * `result` - string result of check: + * - `reject` + * - `allow` + * - `tempfail` + * - `permfail` + * - `not found` + * - `bad record` + * * `domain` - dkim domain + * * `selector` - dkim selector + * * `bhash` - short version of b tag (8 base64 symbols) + * * `fail_reason` - reason of failure (if applicable) + * @return {list of maps} dkim check results + */ +LUA_FUNCTION_DEF(task, get_dkim_results); +/*** + * @method task:get_symbol(name, [shadow_result_name]) + * Searches for a symbol `name` in all metrics results and returns a list of tables + * one per metric that describes the symbol inserted. + * Please note, that for using this function you need to ensure that the symbol + * being queried is already checked. This is guaranteed if there is a dependency + * between the caller symbol and the checked symbol (either virtual or real). + * Please check `rspamd_config:register_dependency` method for details. + * The symbols are returned as the list of the following tables: + * + * - `metric` - name of metric + * - `score` - score of a symbol in that metric + * - `options` - a table of strings representing options of a symbol + * - `group` - a group of symbol (or 'ungrouped') + * @param {string} name symbol's name + * @return {list of tables} list of tables or nil if symbol was not found + */ +LUA_FUNCTION_DEF(task, get_symbol); +/*** + * @method task:get_symbols_all() + * Returns array of symbols matched in default metric with all metadata + * @return {table} table of tables formatted as in `task:get_symbol()` except that `metric` is absent and `name` is added + */ +LUA_FUNCTION_DEF(task, get_symbols_all); +/*** + * @method task:get_symbols([shadow_result_name]) + * Returns array of all symbols matched for this task + * @return {table, table} table of strings with symbols names + table of theirs scores + */ +LUA_FUNCTION_DEF(task, get_symbols); + +/*** + * @method task:get_groups([need_private]) + * Returns a map [group -> group_score] for matched group. If `need_private` is + * unspecified, then the global option `public_groups_only` is used for default. + * @return {table, number} a map [group -> group_score] + */ +LUA_FUNCTION_DEF(task, get_groups); + +/*** + * @method task:get_symbols_numeric() + * Returns array of all symbols matched for this task + * @return {table|number, table|number} table of numbers with symbols ids + table of theirs scores + */ +LUA_FUNCTION_DEF(task, get_symbols_numeric); + +/*** + * @method task:get_symbols_tokens() + * Returns array of all symbols as statistical tokens + * @return {table|number} table of numbers + */ +LUA_FUNCTION_DEF(task, get_symbols_tokens); + +/*** + * @method task:process_ann_tokens(symbols, ann_tokens, offset, [min]) + * Processes ann tokens + * @param {table|string} symbols list of symbols in this profile + * @param {table|number} ann_tokens list of tokens (including metatokens) + * @param {integer} offset offset for symbols token (#metatokens) + * @param {number} min minimum value for symbols found (e.g. for 0 score symbols) + * @return nothing + */ +LUA_FUNCTION_DEF(task, process_ann_tokens); + +/*** + * @method task:has_symbol(name, [shadow_result_name]) + * Fast path to check if a specified symbol is in the task's results. + * Please note, that for using this function you need to ensure that the symbol + * being queried is already checked. This is guaranteed if there is a dependency + * between the caller symbol and the checked symbol (either virtual or real). + * Please check `rspamd_config:register_dependency` method for details. + * @param {string} name symbol's name + * @return {boolean} `true` if symbol has been found + */ +LUA_FUNCTION_DEF(task, has_symbol); +/*** + * @method task:enable_symbol(name) + * Enable specified symbol for this particular task + * @param {string} name symbol's name + * @return {boolean} `true` if symbol has been found + */ +LUA_FUNCTION_DEF(task, enable_symbol); +/*** + * @method task:disable_symbol(name) + * Disable specified symbol for this particular task + * @param {string} name symbol's name + * @return {boolean} `true` if symbol has been found + */ +LUA_FUNCTION_DEF(task, disable_symbol); +/*** + * @method task:get_date(type[, gmt]) + * Returns timestamp for a connection or for a MIME message. This function can be called with a + * single table arguments with the following fields: + * + * * `format` - a format of date returned: + * - `message` - returns a mime date as integer (unix timestamp) + * - `connect` - returns a unix timestamp of a connection to rspamd + * * `gmt` - returns date in `GMT` timezone (normal for unix timestamps) + * + * By default this function returns connection time in numeric format. + * @param {string} type date format as described above + * @param {boolean} gmt gmt flag as described above + * @return {string/number} date representation according to format + * @example +rspamd_config.DATE_IN_PAST = function(task) + local dm = task:get_date{format = 'message', gmt = true} + local dt = task:get_date{format = 'connect', gmt = true} + -- A day + if dt - dm > 86400 then + return true + end + + return false +end + */ +LUA_FUNCTION_DEF(task, get_date); +/*** + * @method task:get_message_id() + * Returns message identifier from the `Message-ID` header. Angle brackets (`<>`) + * are stripped off if present. If a Message-ID header is missing `undef` is + * returned. + * @return {string} ID of the message + */ +LUA_FUNCTION_DEF(task, get_message_id); +/*** + * @method task:get_timeval([raw]) + * Returns the timestamp for a task start processing time. + * @param {boolean} raw if true then two float numbers are returned: task start timestamp and timeout event timestamp + * @return {table} table with fields as described in `struct timeval` in C + */ +LUA_FUNCTION_DEF(task, get_timeval); +/*** + * @method task:get_scan_time([set]) + * Returns 2 floating point numbers: scan real time and scan virtual time. + * If `set` is `true`, then the finishing time is also set (enabled by default). + * This function should be normally called on idempotent phase. + * @return {number,number} real and virtual times in seconds with floating point + */ +LUA_FUNCTION_DEF(task, get_scan_time); +/*** + * @method task:get_metric_result() + * Get full result of a metric as a table: + * - `score`: current score + * - `action`: current action as a string + * - `nnegative`: number of negative rules matched + * - `npositive`: number of positive rules matched + * - `positive_score`: total score for positive rules + * - `negative_score`: total score for negative rules + * - `passthrough`: set to true if message has a passthrough result + * @return {table} metric result + */ +LUA_FUNCTION_DEF(task, get_metric_result); +/*** + * @method task:get_metric_score(name) + * Get the current score of metric `name` (must be nil or 'default') . Should be used in idempotent filters only. + * @param {string} name name of a metric + * @return {number,number} 2 numbers containing the current score and required score of the metric + */ +LUA_FUNCTION_DEF(task, get_metric_score); +/*** + * @method task:get_metric_action(name) + * Get the current action of metric `name` (must be nil or 'default'). Should be used in idempotent filters only. + * @param {string} name name of a metric + * @return {string} the current action of the metric as a string + */ +LUA_FUNCTION_DEF(task, get_metric_action); +/*** + * @method task:set_metric_score(name, score) + * Set the current score of metric `name`. Should be used in post-filters only. + * @param {string} name name of a metric + * @param {number} score the current score of the metric + */ +LUA_FUNCTION_DEF(task, set_metric_score); +/*** + * @method task:set_metric_subject(subject) + * Set the subject in the default metric + * @param {string} subject subject to set + */ +LUA_FUNCTION_DEF(task, set_metric_subject); + +/*** + * @method task:learn(is_spam[, classifier) + * Learn classifier `classifier` with the task. If `is_spam` is true then message + * is learnt as spam. Otherwise HAM is learnt. By default, this function learns + * `bayes` classifier. + * @param {boolean} is_spam learn spam or ham + * @param {string} classifier classifier's name + * @return {boolean} `true` if classifier has been learnt successfully + */ +LUA_FUNCTION_DEF(task, learn); +/*** + * @method task:set_settings(obj) + * Set users settings object for a task. The format of this object is described + * [here](https://rspamd.com/doc/configuration/settings.html). + * @param {any} obj any lua object that corresponds to the settings format + */ +LUA_FUNCTION_DEF(task, set_settings); + +/*** + * @method task:set_settings_id(id) + * Set users settings id for a task (must be registered previously) + * @available 2.0+ + * @param {number} id numeric id + * @return {boolean} true if settings id has been replaced from the existing one + */ +LUA_FUNCTION_DEF(task, set_settings_id); + +/*** + * @method task:get_settings() + * Gets users settings object for a task. The format of this object is described + * [here](https://rspamd.com/doc/configuration/settings.html). + * @return {lua object} lua object generated from UCL + */ +LUA_FUNCTION_DEF(task, get_settings); + +/*** + * @method task:lookup_settings(key) + * Gets users settings object with the specified key for a task. + * @param {string} key key to lookup + * @return {lua object} lua object generated from UCL + */ +LUA_FUNCTION_DEF(task, lookup_settings); + +/*** + * @method task:get_settings_id() + * Get numeric hash of settings id if specified for this task. 0 is returned otherwise. + * @return {number} settings-id hash + */ +LUA_FUNCTION_DEF(task, get_settings_id); + +/*** + * @method task:set_milter_reply(obj) + * Set special reply for milter + * @param {any} obj any lua object that corresponds to the settings format + * @example +task:set_milter_reply({ + add_headers = {['X-Lua'] = 'test'}, + -- 1 is the position of header to remove + remove_headers = {['DKIM-Signature'] = 1}, +}) + */ +LUA_FUNCTION_DEF(task, set_milter_reply); + +/*** + * @method task:process_re(params) + * Processes the specified regexp and returns number of captures (cached or new) + * Params is the table with the following fields (mandatory fields are marked with `*`): + * - `re`* : regular expression object + * - `type`*: type of regular expression: + * + `mime`: mime regexp + * + `header`: header regexp + * + `rawheader`: raw header expression + * + `rawmime`: raw mime regexp + * + `body`: raw body regexp + * + `url`: url regexp + * - `header`: for header and rawheader regexp means the name of header + * - `strong`: case sensitive match for headers + * @return {number} number of regexp occurrences in the task (limited by 255 so far) + */ +LUA_FUNCTION_DEF(task, process_regexp); + +/*** + * @method task:cache_set(key, value) + * Store some value to the task cache + * @param {string} key key to use + * @param {any} value any value (including functions and tables) + */ +LUA_FUNCTION_DEF(task, cache_set); +/*** + * @method task:cache_get(key) + * Returns cached value or nil if nothing is cached + * @param {string} key key to use + * @return {any} cached value + */ +LUA_FUNCTION_DEF(task, cache_get); + +/*** + * @method task:get_size() + * Returns size of the task in bytes (that includes headers + parts size) + * @return {number} size in bytes + */ +LUA_FUNCTION_DEF(task, get_size); + +/*** + * @method task:set_flag(flag_name[, set]) + * Set specific flag for task: + * + * - `no_log`: do not log task summary + * - `no_stat`: do not include task into scanned stats + * - `pass_all`: check all filters for task + * - `extended_urls`: output extended info about urls + * - `skip`: skip task processing + * - `learn_spam`: learn message as spam + * - `learn_ham`: learn message as ham + * - `broken_headers`: header data is broken for a message + * @param {string} flag to set + * @param {boolean} set set or clear flag (default is set) +@example +--[[ +For messages with undefined queue ID (scanned with rspamc or WebUI) +do not include results into statistics and do not log task summary +(it will not appear in the WebUI history as well). +]]-- + +-- Callback function to set flags +local function no_log_stat_cb(task) + if not task:get_queue_id() then + task:set_flag('no_log') + task:set_flag('no_stat') + end +end + +rspamd_config:register_symbol({ + name = 'LOCAL_NO_LOG_STAT', + type = 'postfilter', + callback = no_log_stat_cb +}) + */ +LUA_FUNCTION_DEF(task, set_flag); + + +/*** + * @method task:has_flag(flag_name) + * Checks for a specific flag in task: + * + * - `no_log`: do not log task summary + * - `no_stat`: do not include task into scanned stats + * - `pass_all`: check all filters for task + * - `extended_urls`: output extended info about urls + * - `skip`: skip task processing + * - `learn_spam`: learn message as spam + * - `learn_ham`: learn message as ham + * - `broken_headers`: header data is broken for a message + * @param {string} flag to check + * @return {boolean} true if flags is set + */ +LUA_FUNCTION_DEF(task, has_flag); + + +/*** + * @method task:get_flags() + * Get list of flags for task: + * + * - `no_log`: do not log task summary + * - `no_stat`: do not include task into scanned stats + * - `pass_all`: check all filters for task + * - `extended_urls`: output extended info about urls + * - `skip`: skip task processing + * - `learn_spam`: learn message as spam + * - `learn_ham`: learn message as ham + * - `broken_headers`: header data is broken for a message + * - `milter`: task is initiated by milter connection + * @return {array of strings} table with all flags as strings + */ +LUA_FUNCTION_DEF(task, get_flags); + +/*** + * @method task:get_digest() + * Returns message's unique digest (32 hex symbols) + * @return {string} hex digest + */ +LUA_FUNCTION_DEF(task, get_digest); + +/*** + * @method task:store_in_file([mode|table]) + * If task was loaded using file scan, then this method just returns its name, + * otherwise, a fresh temporary file is created and its name is returned. Default + * mode is 0600. To convert lua number to the octal mode you can use the following + * trick: `tonumber("0644", 8)`. The file is automatically removed when task is + * destroyed. + * + * If table argument is specified, the following extra fields are allowed: + * + * - `mode`: same as mode argument + * - `force_new`: always create a new file + * - `filename`: use specific filename instead of a temporary one + * - `tmpmask`: use specific tempmask, e.g. '/tmp/file-XXXXX', where XXXX will be replaced by some random letters + * - `keep`: do not remove file after task is processed + * + * @param {number} mode mode for new file + * @return {string} file name with task content + */ +LUA_FUNCTION_DEF(task, store_in_file); + +/*** + * @method task:get_protocol_reply([flags]) + * This method being called from a **postfilter** will return reply for a message + * as it is returned to a client. This method returns the Lua table corresponding + * to the UCL object. Flags is a table that specify which information should be + * there in a reply: + * + * - `basic`: basic info, such as message-id + * - `metrics`: metrics and symbols + * - `messages`: messages + * - `dkim`: dkim signature + * - `milter`: milter control block + * - `extra`: extra data, such as profiling + * - `urls`: list of all urls in a message + * + * @param {table} flags table of flags (default is all flags but `urls`) + * @return {table} ucl object corresponding to the reply + */ +LUA_FUNCTION_DEF(task, get_protocol_reply); + +/*** + * @method task:headers_foreach(callback, [params]) + * This method calls `callback` for each header that satisfies some condition. + * By default, all headers are iterated unless `callback` returns `true`. Nil or + * false means continue of iterations. + * Params could be as following: + * + * - `full`: header value is full table of all attributes @see task:get_header_full for details + * - `regexp`: return headers that satisfies the specified regexp + * @param {function} callback function from header name and header value + * @param {table} params optional parameters + */ +LUA_FUNCTION_DEF(task, headers_foreach); + +/*** + * @method task:disable_action(action) + * Disables some action for this task (e.g. 'greylist') + * + * @param {string} action action to disable + * @return {boolean} true if an action was enabled and is disabled after the method call + */ +LUA_FUNCTION_DEF(task, disable_action); + +/*** + * @method task:get_newlines_type() + * Returns the most frequent newlines type met in a task + * + * @return {string} "cr" for \r, "lf" for \n, "crlf" for \r\n + */ +LUA_FUNCTION_DEF(task, get_newlines_type); + +/*** + * @method task:get_stat_tokens() + * Returns list of tables the statistical tokens: + * - `data`: 64 bit number encoded as a string + * - `t1`: the first token (if any) + * - `t2`: the second token (if any) + * - `win`: window index + * - `flag`: table of strings: + * - `text`: text token + * - `meta`: meta token + * - `lua`: lua meta token + * - `exception`: exception + * - `subject`: subject token + * - `unigram`: unigram token + * + * @return {table of tables} + */ +LUA_FUNCTION_DEF(task, get_stat_tokens); + +/*** + * @method task:lookup_words(map, function({o, n, s, f}) ... end) + * Matches words in a task (including meta words) against some map (set, regexp and so on) + * and call the specified function with a table containing 4 values: + * - [1] - stemmed word + * - [2] - normalised word + * - [3] - raw word + * - [4] - flags (table of strings) + */ +LUA_FUNCTION_DEF(task, lookup_words); + +/** + * @method task:topointer() + * + * Returns raw C pointer (lightuserdata) associated with task. This might be + * broken with luajit and GC64, use with caution. + */ +LUA_FUNCTION_DEF(task, topointer); + +/** + * @method task:add_named_result(name, symbol_control_function) + * + * Adds a new named result for this task. symbol_control_function is a callback + * called with 3 parameters: + * `function(task, symbol, result_name)` and it should return boolean that + * specifies if this symbol should be added to this named result. + * @param {string} name for this result + * @param {function} symbol_control_function predicate for symbols + */ +LUA_FUNCTION_DEF(task, add_named_result); + +/** + * @method task:get_all_named_results() + * + * Returns all named results registered for the task as a table of strings + * @return {table|string} all named results starting from `default` + */ +LUA_FUNCTION_DEF(task, get_all_named_results); + +/*** + * @method task:get_dns_req() + * Get number of dns requests being sent in the task + * @return {number} number of DNS requests + */ +LUA_FUNCTION_DEF(task, get_dns_req); + +static const struct luaL_reg tasklib_f[] = { + LUA_INTERFACE_DEF(task, create), + LUA_INTERFACE_DEF(task, load_from_file), + LUA_INTERFACE_DEF(task, load_from_string), + {NULL, NULL}}; + +static const struct luaL_reg tasklib_m[] = { + LUA_INTERFACE_DEF(task, get_message), + LUA_INTERFACE_DEF(task, set_message), + LUA_INTERFACE_DEF(task, destroy), + LUA_INTERFACE_DEF(task, process_message), + LUA_INTERFACE_DEF(task, set_cfg), + LUA_INTERFACE_DEF(task, get_cfg), + LUA_INTERFACE_DEF(task, get_mempool), + LUA_INTERFACE_DEF(task, get_session), + LUA_INTERFACE_DEF(task, set_session), + LUA_INTERFACE_DEF(task, get_ev_base), + LUA_INTERFACE_DEF(task, get_worker), + LUA_INTERFACE_DEF(task, insert_result), + LUA_INTERFACE_DEF(task, insert_result_named), + LUA_INTERFACE_DEF(task, adjust_result), + LUA_INTERFACE_DEF(task, remove_result), + LUA_INTERFACE_DEF(task, set_pre_result), + LUA_INTERFACE_DEF(task, has_pre_result), + LUA_INTERFACE_DEF(task, append_message), + LUA_INTERFACE_DEF(task, has_urls), + LUA_INTERFACE_DEF(task, get_urls), + LUA_INTERFACE_DEF(task, get_urls_filtered), + LUA_INTERFACE_DEF(task, inject_url), + LUA_INTERFACE_DEF(task, get_content), + LUA_INTERFACE_DEF(task, get_filename), + LUA_INTERFACE_DEF(task, get_rawbody), + LUA_INTERFACE_DEF(task, get_emails), + LUA_INTERFACE_DEF(task, get_text_parts), + LUA_INTERFACE_DEF(task, get_parts), + LUA_INTERFACE_DEF(task, get_request_header), + LUA_INTERFACE_DEF(task, set_request_header), + LUA_INTERFACE_DEF(task, get_header), + LUA_INTERFACE_DEF(task, has_header), + LUA_INTERFACE_DEF(task, get_header_raw), + LUA_INTERFACE_DEF(task, get_header_full), + LUA_INTERFACE_DEF(task, get_header_count), + LUA_INTERFACE_DEF(task, get_raw_headers), + LUA_INTERFACE_DEF(task, get_headers), + LUA_INTERFACE_DEF(task, modify_header), + LUA_INTERFACE_DEF(task, get_received_headers), + LUA_INTERFACE_DEF(task, get_queue_id), + LUA_INTERFACE_DEF(task, get_uid), + LUA_INTERFACE_DEF(task, get_resolver), + LUA_INTERFACE_DEF(task, set_resolver), + LUA_INTERFACE_DEF(task, inc_dns_req), + LUA_INTERFACE_DEF(task, get_dns_req), + LUA_INTERFACE_DEF(task, has_recipients), + LUA_INTERFACE_DEF(task, get_recipients), + LUA_INTERFACE_DEF(task, set_recipients), + LUA_INTERFACE_DEF(task, get_principal_recipient), + LUA_INTERFACE_DEF(task, get_reply_sender), + LUA_INTERFACE_DEF(task, has_from), + LUA_INTERFACE_DEF(task, get_from), + LUA_INTERFACE_DEF(task, set_from), + LUA_INTERFACE_DEF(task, get_user), + LUA_INTERFACE_DEF(task, set_user), + {"get_addr", lua_task_get_from_ip}, + {"get_ip", lua_task_get_from_ip}, + {"get_from_addr", lua_task_get_from_ip}, + LUA_INTERFACE_DEF(task, get_from_ip), + LUA_INTERFACE_DEF(task, set_from_ip), + LUA_INTERFACE_DEF(task, get_from_ip_num), + LUA_INTERFACE_DEF(task, get_client_ip), + LUA_INTERFACE_DEF(task, get_subject), + LUA_INTERFACE_DEF(task, get_helo), + LUA_INTERFACE_DEF(task, set_helo), + LUA_INTERFACE_DEF(task, get_hostname), + LUA_INTERFACE_DEF(task, set_hostname), + LUA_INTERFACE_DEF(task, get_images), + LUA_INTERFACE_DEF(task, get_archives), + LUA_INTERFACE_DEF(task, get_dkim_results), + LUA_INTERFACE_DEF(task, get_symbol), + LUA_INTERFACE_DEF(task, get_symbols), + LUA_INTERFACE_DEF(task, get_symbols_all), + LUA_INTERFACE_DEF(task, get_symbols_numeric), + LUA_INTERFACE_DEF(task, get_symbols_tokens), + LUA_INTERFACE_DEF(task, get_groups), + LUA_INTERFACE_DEF(task, process_ann_tokens), + LUA_INTERFACE_DEF(task, has_symbol), + LUA_INTERFACE_DEF(task, enable_symbol), + LUA_INTERFACE_DEF(task, disable_symbol), + LUA_INTERFACE_DEF(task, get_date), + LUA_INTERFACE_DEF(task, get_message_id), + LUA_INTERFACE_DEF(task, get_timeval), + LUA_INTERFACE_DEF(task, get_scan_time), + LUA_INTERFACE_DEF(task, get_metric_result), + LUA_INTERFACE_DEF(task, get_metric_score), + LUA_INTERFACE_DEF(task, get_metric_action), + LUA_INTERFACE_DEF(task, set_metric_score), + LUA_INTERFACE_DEF(task, set_metric_subject), + LUA_INTERFACE_DEF(task, learn), + LUA_INTERFACE_DEF(task, set_settings), + LUA_INTERFACE_DEF(task, get_settings), + LUA_INTERFACE_DEF(task, lookup_settings), + LUA_INTERFACE_DEF(task, get_settings_id), + LUA_INTERFACE_DEF(task, set_settings_id), + LUA_INTERFACE_DEF(task, cache_get), + LUA_INTERFACE_DEF(task, cache_set), + LUA_INTERFACE_DEF(task, process_regexp), + LUA_INTERFACE_DEF(task, get_size), + LUA_INTERFACE_DEF(task, set_flag), + LUA_INTERFACE_DEF(task, get_flags), + LUA_INTERFACE_DEF(task, has_flag), + {"set_rmilter_reply", lua_task_set_milter_reply}, + LUA_INTERFACE_DEF(task, set_milter_reply), + LUA_INTERFACE_DEF(task, get_digest), + LUA_INTERFACE_DEF(task, store_in_file), + LUA_INTERFACE_DEF(task, get_protocol_reply), + LUA_INTERFACE_DEF(task, headers_foreach), + LUA_INTERFACE_DEF(task, disable_action), + LUA_INTERFACE_DEF(task, get_newlines_type), + LUA_INTERFACE_DEF(task, get_stat_tokens), + LUA_INTERFACE_DEF(task, get_meta_words), + LUA_INTERFACE_DEF(task, lookup_words), + LUA_INTERFACE_DEF(task, add_named_result), + LUA_INTERFACE_DEF(task, get_all_named_results), + LUA_INTERFACE_DEF(task, topointer), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +/* Image methods */ +LUA_FUNCTION_DEF(image, get_width); +LUA_FUNCTION_DEF(image, get_height); +LUA_FUNCTION_DEF(image, get_type); +LUA_FUNCTION_DEF(image, get_filename); +LUA_FUNCTION_DEF(image, get_size); + +static const struct luaL_reg imagelib_m[] = { + LUA_INTERFACE_DEF(image, get_width), + LUA_INTERFACE_DEF(image, get_height), + LUA_INTERFACE_DEF(image, get_type), + LUA_INTERFACE_DEF(image, get_filename), + LUA_INTERFACE_DEF(image, get_size), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +/* Archive methods */ +LUA_FUNCTION_DEF(archive, get_type); +LUA_FUNCTION_DEF(archive, get_files); +LUA_FUNCTION_DEF(archive, get_files_full); +/* TODO: Export archive flags as integers to use bitops for that */ +LUA_FUNCTION_DEF(archive, is_encrypted); +LUA_FUNCTION_DEF(archive, is_obfuscated); +LUA_FUNCTION_DEF(archive, is_unreadable); +LUA_FUNCTION_DEF(archive, get_filename); +LUA_FUNCTION_DEF(archive, get_size); + +static const struct luaL_reg archivelib_m[] = { + LUA_INTERFACE_DEF(archive, get_type), + LUA_INTERFACE_DEF(archive, get_files), + LUA_INTERFACE_DEF(archive, get_files_full), + LUA_INTERFACE_DEF(archive, is_encrypted), + LUA_INTERFACE_DEF(archive, is_obfuscated), + LUA_INTERFACE_DEF(archive, is_unreadable), + LUA_INTERFACE_DEF(archive, get_filename), + LUA_INTERFACE_DEF(archive, get_size), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +/* Utility functions */ +struct rspamd_task * +lua_check_task(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{task}"); + luaL_argcheck(L, ud != NULL, pos, "'task' expected"); + return ud ? *((struct rspamd_task **) ud) : NULL; +} + +struct rspamd_task * +lua_check_task_maybe(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata_maybe(L, pos, "rspamd{task}"); + + return ud ? *((struct rspamd_task **) ud) : NULL; +} + +static struct rspamd_image * +lua_check_image(lua_State *L) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{image}"); + luaL_argcheck(L, ud != NULL, 1, "'image' expected"); + return ud ? *((struct rspamd_image **) ud) : NULL; +} + +static struct rspamd_archive * +lua_check_archive(lua_State *L) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{archive}"); + luaL_argcheck(L, ud != NULL, 1, "'archive' expected"); + return ud ? *((struct rspamd_archive **) ud) : NULL; +} + +static void +lua_task_set_cached(lua_State *L, struct rspamd_task *task, const gchar *key, + gint pos) +{ + LUA_TRACE_POINT; + khiter_t k; + + lua_pushvalue(L, pos); + + k = kh_get(rspamd_task_lua_cache, &task->lua_cache, (char *) key); + + if (G_UNLIKELY(k != kh_end(&task->lua_cache))) { + /* Unref previous value */ + luaL_unref(L, LUA_REGISTRYINDEX, kh_value(&task->lua_cache, k).ref); + } + else { + int r; + + k = kh_put(rspamd_task_lua_cache, &task->lua_cache, rspamd_mempool_strdup(task->task_pool, key), &r); + } + + kh_value(&task->lua_cache, k).ref = luaL_ref(L, LUA_REGISTRYINDEX); + kh_value(&task->lua_cache, k).id = GPOINTER_TO_UINT(task->message); +} + + +static gboolean +lua_task_get_cached(lua_State *L, struct rspamd_task *task, const gchar *key) +{ + LUA_TRACE_POINT; + khiter_t k; + + k = kh_get(rspamd_task_lua_cache, &task->lua_cache, (char *) key); + + if (k != kh_end(&task->lua_cache) && (kh_value(&task->lua_cache, k).id == GPOINTER_TO_UINT(task->message))) { + lua_rawgeti(L, LUA_REGISTRYINDEX, kh_value(&task->lua_cache, k).ref); + + return TRUE; + } + + return FALSE; +} + +/* Task methods */ +static int +lua_task_process_message(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gboolean enforce = FALSE; + + if (task != NULL) { + if (task->msg.len > 0) { + if (lua_isboolean(L, 2)) { + enforce = lua_toboolean(L, 2); + } + + if (rspamd_message_parse(task)) { + if (enforce || + (!(task->flags & RSPAMD_TASK_FLAG_SKIP_PROCESS) && + !(task->processed_stages & RSPAMD_TASK_STAGE_PROCESS_MESSAGE))) { + + lua_pushboolean(L, TRUE); + rspamd_message_process(task); + task->processed_stages |= RSPAMD_TASK_STAGE_PROCESS_MESSAGE; + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static int +lua_task_get_cfg(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_config **pcfg; + + if (task) { + pcfg = lua_newuserdata(L, sizeof(gpointer)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + *pcfg = task->cfg; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static int +lua_task_set_cfg(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + void *ud = rspamd_lua_check_udata(L, 2, "rspamd{config}"); + + if (task) { + luaL_argcheck(L, ud != NULL, 1, "'config' expected"); + task->cfg = ud ? *((struct rspamd_config **) ud) : NULL; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static int +lua_task_destroy(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + rspamd_task_free(task); + } + + return 0; +} + +static int +lua_task_get_message(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->flags = 0; + t->start = task->msg.begin; + t->len = task->msg.len; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static int +lua_task_set_message(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + struct rspamd_task *task = lua_check_task(L, 1); + gboolean message_set = FALSE; + + if (task) { + gsize final_len = 0; + gchar *buf = NULL; + + if (lua_type(L, 2) == LUA_TTABLE) { + /* Piecewise construct */ + guint vec_len = rspamd_lua_table_size(L, 2); + + + for (guint i = 0; i < vec_len; i++) { + lua_rawgeti(L, 2, i + 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + gsize l; + + (void) lua_tolstring(L, -1, &l); + final_len += l; + } + else { + t = lua_check_text(L, -1); + + if (t) { + final_len += t->len; + } + } + + lua_pop(L, 1); + } + + if (final_len > 0) { + gchar *pos; + + buf = rspamd_mempool_alloc(task->task_pool, final_len); + pos = buf; + + for (guint i = 0; i < vec_len; i++) { + lua_rawgeti(L, 2, i + 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + gsize l; + const gchar *s; + + s = lua_tolstring(L, -1, &l); + memcpy(pos, s, l); + pos += l; + } + else { + t = lua_check_text(L, -1); + + if (t) { + memcpy(pos, t->start, t->len); + pos += t->len; + } + } + + lua_pop(L, 1); + } + + task->flags |= RSPAMD_TASK_FLAG_MESSAGE_REWRITE; + task->msg.begin = buf; + task->msg.len = final_len; + message_set = TRUE; + } + } + else { + if (lua_type(L, 2) == LUA_TSTRING) { + const gchar *s; + + s = lua_tolstring(L, -1, &final_len); + buf = rspamd_mempool_alloc(task->task_pool, final_len); + memcpy(buf, s, final_len); + } + else { + t = lua_check_text(L, -1); + + if (t) { + final_len = t->len; + buf = rspamd_mempool_alloc(task->task_pool, final_len); + memcpy(buf, t->start, final_len); + } + } + + if (buf) { + task->msg.begin = buf; + task->msg.len = final_len; + task->flags |= RSPAMD_TASK_FLAG_MESSAGE_REWRITE; + message_set = TRUE; + } + } + + if (message_set) { + if (rspamd_message_parse(task)) { + rspamd_message_process(task); + lua_pushboolean(L, TRUE); + lua_pushinteger(L, final_len); + + return 2; + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + lua_pushboolean(L, FALSE); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static void +lua_task_unmap_dtor(gpointer p) +{ + struct rspamd_task *task = (struct rspamd_task *) p; + + if (task->msg.begin) { + munmap((gpointer) task->msg.begin, task->msg.len); + } +} + +static void +lua_task_free_dtor(gpointer p) +{ + g_free(p); +} + +static gint +lua_task_load_from_file(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = NULL, **ptask; + const gchar *fname = luaL_checkstring(L, 1), *err = NULL; + struct rspamd_config *cfg = NULL; + gboolean res = FALSE; + gpointer map; + gsize sz; + + if (fname) { + + if (lua_type(L, 2) == LUA_TUSERDATA) { + gpointer p; + p = rspamd_lua_check_udata_maybe(L, 2, "rspamd{config}"); + + if (p) { + cfg = *(struct rspamd_config **) p; + } + } + + if (strcmp(fname, "-") == 0) { + /* Read from stdin */ + gint fd = STDIN_FILENO; + GString *data = g_string_sized_new(BUFSIZ); + gchar buf[BUFSIZ]; + gssize r; + + for (;;) { + r = read(fd, buf, sizeof(buf)); + + if (r == -1) { + err = strerror(errno); + break; + } + else if (r == 0) { + break; + } + else { + g_string_append_len(data, buf, r); + } + } + + task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE); + task->msg.begin = data->str; + task->msg.len = data->len; + rspamd_mempool_add_destructor(task->task_pool, + lua_task_free_dtor, data->str); + res = TRUE; + g_string_free(data, FALSE); /* Buffer is still valid */ + } + else { + map = rspamd_file_xmap(fname, PROT_READ, &sz, TRUE); + + if (!map) { + err = strerror(errno); + } + else { + task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE); + task->msg.begin = map; + task->msg.len = sz; + rspamd_mempool_add_destructor(task->task_pool, + lua_task_unmap_dtor, task); + res = TRUE; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, res); + + if (res) { + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + } + else { + if (err) { + lua_pushstring(L, err); + } + else { + lua_pushnil(L); + } + } + + return 2; +} + +static gint +lua_task_load_from_string(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = NULL, **ptask; + const gchar *str_message; + gsize message_len; + struct rspamd_config *cfg = NULL; + + str_message = luaL_checklstring(L, 1, &message_len); + + if (str_message) { + + if (lua_type(L, 2) == LUA_TUSERDATA) { + gpointer p; + p = rspamd_lua_check_udata_maybe(L, 2, "rspamd{config}"); + + if (p) { + cfg = *(struct rspamd_config **) p; + } + } + + task = rspamd_task_new(NULL, cfg, NULL, NULL, NULL, FALSE); + task->msg.begin = g_malloc(message_len); + memcpy((gchar *) task->msg.begin, str_message, message_len); + task->msg.len = message_len; + rspamd_mempool_add_destructor(task->task_pool, lua_task_free_dtor, + (gpointer) task->msg.begin); + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, true); + + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + return 2; +} + +static gint +lua_task_create(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = NULL, **ptask; + struct rspamd_config *cfg = NULL; + struct ev_loop *ev_base = NULL; + + if (lua_type(L, 1) == LUA_TUSERDATA) { + gpointer p; + p = rspamd_lua_check_udata_maybe(L, 1, "rspamd{config}"); + + if (p) { + cfg = *(struct rspamd_config **) p; + } + } + + if (lua_type(L, 2) == LUA_TUSERDATA) { + gpointer p; + p = rspamd_lua_check_udata_maybe(L, 2, "rspamd{ev_base}"); + + if (p) { + ev_base = *(struct ev_loop **) p; + } + } + + task = rspamd_task_new(NULL, cfg, NULL, NULL, ev_base, FALSE); + task->flags |= RSPAMD_TASK_FLAG_EMPTY; + + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + return 1; +} + +static int +lua_task_get_mempool(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_mempool_t **ppool; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + ppool = lua_newuserdata(L, sizeof(rspamd_mempool_t *)); + rspamd_lua_setclass(L, "rspamd{mempool}", -1); + *ppool = task->task_pool; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static int +lua_task_get_session(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_async_session **psession; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + psession = lua_newuserdata(L, sizeof(void *)); + rspamd_lua_setclass(L, "rspamd{session}", -1); + *psession = task->s; + } + else { + return luaL_error(L, "invalid arguments"); + } + return 1; +} + +static int +lua_task_set_session(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_async_session *session = lua_check_session(L, 2); + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL && session != NULL) { + task->s = session; + } + else { + return luaL_error(L, "invalid arguments"); + } + return 1; +} + +static int +lua_task_get_ev_base(lua_State *L) +{ + LUA_TRACE_POINT; + struct ev_loop **pbase; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + pbase = lua_newuserdata(L, sizeof(struct ev_loop *)); + rspamd_lua_setclass(L, "rspamd{ev_base}", -1); + *pbase = task->event_loop; + } + else { + return luaL_error(L, "invalid arguments"); + } + return 1; +} + +static int +lua_task_get_worker(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_worker **pworker; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + if (task->worker) { + pworker = lua_newuserdata(L, sizeof(struct rspamd_worker *)); + rspamd_lua_setclass(L, "rspamd{worker}", -1); + *pworker = task->worker; + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + return 1; +} + + +static gint +lua_task_insert_result_common(lua_State *L, struct rspamd_scan_result *result, + gint common_args_pos) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *symbol_name; + double weight; + struct rspamd_symbol_result *s; + enum rspamd_symbol_insert_flags flags = RSPAMD_SYMBOL_INSERT_DEFAULT; + gint i, top, args_start; + + if (task != NULL) { + if (lua_isboolean(L, common_args_pos)) { + args_start = common_args_pos + 1; + + if (lua_toboolean(L, common_args_pos)) { + flags |= RSPAMD_SYMBOL_INSERT_ENFORCE; + } + } + else { + args_start = common_args_pos; + } + + symbol_name = rspamd_mempool_strdup(task->task_pool, + luaL_checkstring(L, args_start)); + weight = luaL_checknumber(L, args_start + 1); + top = lua_gettop(L); + s = rspamd_task_insert_result_full(task, symbol_name, weight, + NULL, flags, result); + + /* Get additional options */ + if (s) { + if (s->sym == NULL) { + /* Unknown symbol, print traceback */ + lua_pushfstring(L, "unknown symbol %s", symbol_name); + rspamd_lua_traceback(L); + + msg_info_task("symbol insertion issue: %s", lua_tostring(L, -1)); + + lua_pop(L, 1); /* Traceback string */ + } + for (i = args_start + 2; i <= top; i++) { + gint ltype = lua_type(L, i); + + if (ltype == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, i, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (ltype == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, i); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + } + else if (ltype == LUA_TTABLE) { + gsize objlen = rspamd_lua_table_size(L, i); + + for (guint j = 1; j <= objlen; j++) { + lua_rawgeti(L, i, j); + + if (lua_type(L, -1) == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, -1, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, -1); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + else { + return luaL_error(L, "not rspamd_text option in a table " + "when adding symbol %s: %s type", + s->name); + } + } + else { + const gchar *tname = lua_typename(L, lua_type(L, -1)); + lua_pop(L, 2); + + return luaL_error(L, "not a string option in a table " + "when adding symbol %s: %s type", + s->name, tname); + } + + lua_pop(L, 1); + } + } + else if (ltype == LUA_TNIL) { + /* We have received a NULL option, it is not good but not a fatal error */ + msg_info_task("nil option when adding symbol %s at pos %d", + s->name, i); + continue; + } + else { + const gchar *tname = lua_typename(L, ltype); + + return luaL_error(L, "not a string/table option " + "when adding symbol %s: %s type", + s->name, tname); + } + } + } + else if (task->settings == NULL && task->settings_elt == NULL) { + lua_pushfstring(L, "insertion failed for %s", symbol_name); + rspamd_lua_traceback(L); + + msg_info_task("symbol insertion issue: %s", lua_tostring(L, -1)); + + lua_pop(L, 2); /* Traceback string + error string */ + } + else { + /* Usually denied by settings */ + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_insert_result(lua_State *L) +{ + return lua_task_insert_result_common(L, NULL, 2); +} + +static gint +lua_task_insert_result_named(lua_State *L) +{ + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *named_result = luaL_checkstring(L, 2); + struct rspamd_scan_result *res; + + if (task && named_result) { + res = rspamd_find_metric_result(task, named_result); + + if (res == NULL) { + return luaL_error(L, "invalid arguments: bad named result: %s", + named_result); + } + + return lua_task_insert_result_common(L, res, 3); + } + + return luaL_error(L, "invalid arguments"); +} + +static gint +lua_task_adjust_result(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *symbol_name; + struct rspamd_scan_result *metric_res; + struct rspamd_symbol_result *s = NULL; + double weight; + gint i, top; + + if (task != NULL) { + + symbol_name = luaL_checkstring(L, 2); + weight = luaL_checknumber(L, 3); + top = lua_gettop(L); + metric_res = task->result; + + if (metric_res) { + s = rspamd_task_find_symbol_result(task, symbol_name, NULL); + } + else { + return luaL_error(L, "no metric result"); + } + + if (s) { + if (!isnan(weight)) { + metric_res->score -= s->score; + s->score = weight; + metric_res->score += s->score; + } + } + else { + return luaL_error(L, "symbol not found: %s", symbol_name); + } + + /* Get additional options */ + if (s) { + for (i = 4; i <= top; i++) { + if (lua_type(L, i) == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, i, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (lua_type(L, i) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, i); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + } + else if (lua_type(L, i) == LUA_TTABLE) { + gsize objlen = rspamd_lua_table_size(L, i); + + for (guint j = 1; j <= objlen; j++) { + lua_rawgeti(L, i, j); + + if (lua_type(L, -1) == LUA_TSTRING) { + gsize optlen; + const char *opt = lua_tolstring(L, -1, &optlen); + + rspamd_task_add_result_option(task, s, opt, optlen); + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, -1); + + if (t) { + rspamd_task_add_result_option(task, s, t->start, + t->len); + } + } + + lua_pop(L, 1); + } + } + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_remove_result(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *symbol_name = luaL_checkstring(L, 2); + struct rspamd_scan_result *metric_res; + const gchar *named_result = luaL_optstring(L, 3, NULL); + + if (task != NULL) { + metric_res = rspamd_find_metric_result(task, named_result); + + if (metric_res == NULL) { + return luaL_error(L, "invalid arguments: bad named result: %s", + named_result); + } + + lua_pushboolean(L, (rspamd_task_remove_symbol_result(task, symbol_name, + metric_res)) != NULL); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_pre_result(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *message = NULL, *module = NULL, *fl_str = NULL, *act_str = NULL, + *res_name = NULL; + gdouble score = NAN; + struct rspamd_action *action; + guint priority = RSPAMD_PASSTHROUGH_NORMAL, flags = 0; + + if (task != NULL) { + + if (RSPAMD_TASK_IS_SKIPPED(task)) { + /* Do not set pre-result for a skipped task */ + return 0; + } + + if (lua_type(L, 2) == LUA_TTABLE) { + GError *err = NULL; + + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*action=S;message=S;module=S;score=D;priority=i;flags=S;result=S", + &act_str, &message, &module, &score, &priority, &fl_str, &res_name)) { + gint ret = luaL_error(L, "invalid arguments: %s", err->message); + g_error_free(err); + + return ret; + } + } + else { + if (lua_type(L, 2) == LUA_TSTRING) { + act_str = lua_tostring(L, 2); + } + else { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 3) == LUA_TSTRING) { + message = lua_tostring(L, 3); + } + + if (lua_type(L, 4) == LUA_TSTRING) { + module = lua_tostring(L, 4); + } + + if (lua_type(L, 5) == LUA_TNUMBER) { + score = lua_tonumber(L, 5); + } + + if (lua_type(L, 6) == LUA_TNUMBER) { + priority = lua_tointeger(L, 6); + } + + if (lua_type(L, 7) == LUA_TSTRING) { + fl_str = lua_tostring(L, 7); + } + } + + enum rspamd_action_type internal_type; + + if (strcmp(act_str, "accept") == 0) { + /* Compatibility! */ + act_str = "no action"; + } + else if (rspamd_action_from_str(act_str, &internal_type)) { + /* Compatibility! */ + act_str = rspamd_action_to_str(internal_type); + } + + action = rspamd_config_get_action(task->cfg, act_str); + + if (action == NULL) { + return luaL_error(L, "unknown action %s", act_str); + } + + if (module == NULL) { + module = "Unknown lua"; + } + + if (message == NULL) { + message = "unknown reason"; + flags |= RSPAMD_PASSTHROUGH_NO_SMTP_MESSAGE; + } + + if (fl_str != NULL) { + /* + * TODO: convert to a set of string and split by `,` + add table support + * once this legacy code is migrated to C++ + */ + if (strstr(fl_str, "least") != NULL) { + flags |= RSPAMD_PASSTHROUGH_LEAST; + } + else if (strstr(fl_str, "no_smtp_message") != NULL) { + flags |= RSPAMD_PASSTHROUGH_NO_SMTP_MESSAGE; + } + else if (strstr(fl_str, "process_all") != NULL) { + flags |= RSPAMD_PASSTHROUGH_PROCESS_ALL; + } + } + + + rspamd_add_passthrough_result(task, + action, + priority, + score, + rspamd_mempool_strdup(task->task_pool, message), + rspamd_mempool_strdup(task->task_pool, module), + flags, + rspamd_find_metric_result(task, res_name)); + + /* Don't classify or filter message if pre-filter sets results */ + + if (res_name == NULL && !(flags & (RSPAMD_PASSTHROUGH_LEAST | RSPAMD_PASSTHROUGH_PROCESS_ALL))) { + task->processed_stages |= (RSPAMD_TASK_STAGE_CLASSIFIERS | + RSPAMD_TASK_STAGE_CLASSIFIERS_PRE | + RSPAMD_TASK_STAGE_CLASSIFIERS_POST); + rspamd_symcache_disable_all_symbols(task, task->cfg->cache, + SYMBOL_TYPE_IDEMPOTENT | SYMBOL_TYPE_IGNORE_PASSTHROUGH); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_has_pre_result(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gint nret = 1; + + if (task) { + if (task->result->passthrough_result) { + struct rspamd_passthrough_result *pr = task->result->passthrough_result; + + lua_pushboolean(L, true); + nret = 4; + /* bool, action, message, module */ + + if (pr->action) { + lua_pushstring(L, rspamd_action_to_str(pr->action->action_type)); + } + else { + lua_pushnil(L); + } + + if (pr->message) { + lua_pushstring(L, pr->message); + } + else { + lua_pushnil(L); + } + if (pr->module) { + lua_pushstring(L, pr->module); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushboolean(L, false); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return nret; +} + +static gint +lua_task_append_message(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *category; + + if (task != NULL) { + if (lua_type(L, 3) == LUA_TSTRING) { + category = luaL_checkstring(L, 3); + } + else { + category = "unknown"; + } + + ucl_object_insert_key(task->messages, + ucl_object_lua_import(L, 2), + category, 0, + true); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + + +static gint +lua_task_get_urls(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct lua_tree_cb_data cb; + struct rspamd_url *u; + static const gint default_protocols_mask = PROTOCOL_HTTP | PROTOCOL_HTTPS | + PROTOCOL_FILE | PROTOCOL_FTP; + gsize sz, max_urls = 0; + + if (task) { + if (task->cfg) { + max_urls = task->cfg->max_lua_urls; + } + + if (task->message == NULL) { + lua_newtable(L); + + return 1; + } + + /* Exclude RSPAMD_URL_FLAG_CONTENT to preserve backward compatibility */ + if (!lua_url_cbdata_fill(L, 2, &cb, default_protocols_mask, + ~(RSPAMD_URL_FLAG_CONTENT | RSPAMD_URL_FLAG_IMAGE), + max_urls)) { + return luaL_error(L, "invalid arguments"); + } + + sz = kh_size(MESSAGE_FIELD(task, urls)); + sz = lua_url_adjust_skip_prob(task->task_timestamp, + MESSAGE_FIELD(task, digest), &cb, sz); + + lua_createtable(L, sz, 0); + + if (cb.sort) { + struct rspamd_url **urls_sorted; + gint i = 0; + + urls_sorted = g_new0(struct rspamd_url *, sz); + + kh_foreach_key(MESSAGE_FIELD(task, urls), u, { + if (i < sz) { + urls_sorted[i] = u; + i++; + } + }); + + qsort(urls_sorted, i, sizeof(struct rspamd_url *), rspamd_url_cmp_qsort); + + for (int j = 0; j < i; j++) { + lua_tree_url_callback(urls_sorted[j], urls_sorted[j], &cb); + } + + g_free(urls_sorted); + } + else { + kh_foreach_key(MESSAGE_FIELD(task, urls), u, { + lua_tree_url_callback(u, u, &cb); + }); + } + + lua_url_cbdata_dtor(&cb); + } + else { + return luaL_error(L, "invalid arguments, no task"); + } + + return 1; +} + +static gint +lua_task_get_urls_filtered(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct lua_tree_cb_data cb; + struct rspamd_url *u; + static const gint default_protocols_mask = PROTOCOL_HTTP | PROTOCOL_HTTPS | + PROTOCOL_FILE | PROTOCOL_FTP; + gsize sz, max_urls = 0; + + if (task) { + if (task->cfg) { + max_urls = task->cfg->max_lua_urls; + } + + if (task->message == NULL) { + lua_newtable(L); + + return 1; + } + + if (!lua_url_cbdata_fill_exclude_include(L, 2, &cb, default_protocols_mask, max_urls)) { + return luaL_error(L, "invalid arguments"); + } + + sz = kh_size(MESSAGE_FIELD(task, urls)); + sz = lua_url_adjust_skip_prob(task->task_timestamp, + MESSAGE_FIELD(task, digest), &cb, sz); + + lua_createtable(L, sz, 0); + + if (cb.sort) { + struct rspamd_url **urls_sorted; + gint i = 0; + + urls_sorted = g_new0(struct rspamd_url *, sz); + + kh_foreach_key(MESSAGE_FIELD(task, urls), u, { + if (i < sz) { + urls_sorted[i] = u; + i++; + } + }); + + qsort(urls_sorted, i, sizeof(struct rspamd_url *), rspamd_url_cmp_qsort); + + for (int j = 0; j < i; j++) { + lua_tree_url_callback(urls_sorted[j], urls_sorted[j], &cb); + } + + g_free(urls_sorted); + } + else { + kh_foreach_key(MESSAGE_FIELD(task, urls), u, { + lua_tree_url_callback(u, u, &cb); + }); + } + + lua_url_cbdata_dtor(&cb); + } + else { + return luaL_error(L, "invalid arguments, no task"); + } + + return 1; +} + +static gint +lua_task_has_urls(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + bool need_emails = false; + gboolean ret = FALSE; + gsize sz = 0; + + if (task) { + if (task->message) { + if (lua_gettop(L) >= 2) { + need_emails = lua_toboolean(L, 2); + } + + if (need_emails) { + /* Simplified check */ + if (kh_size(MESSAGE_FIELD(task, urls)) > 0) { + sz += kh_size(MESSAGE_FIELD(task, urls)); + ret = TRUE; + } + } + else { + /* Linear scan */ + struct rspamd_url *u; + kh_foreach_key(MESSAGE_FIELD(task, urls), u, { + if (u->protocol != PROTOCOL_MAILTO) { + sz++; + ret = TRUE; + } + }); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, ret); + lua_pushinteger(L, sz); + + return 2; +} + +static gint +lua_task_inject_url(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_lua_url *url = lua_check_url(L, 2); + struct rspamd_mime_part *mpart = NULL; + + if (lua_isuserdata(L, 3)) { + /* We also have a mime part there */ + mpart = *((struct rspamd_mime_part **) rspamd_lua_check_udata_maybe(L, + 3, "rspamd{mimepart}")); + } + + if (task && task->message && url && url->url) { + if (rspamd_url_set_add_or_increase(MESSAGE_FIELD(task, urls), url->url, false)) { + if (mpart && mpart->urls) { + /* Also add url to the mime part */ + g_ptr_array_add(mpart->urls, url->url); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_get_content(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_lua_text *t; + + if (task) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->len = task->msg.len; + t->start = task->msg.begin; + t->flags = 0; + + if (lua_is_text_binary(t)) { + t->flags |= RSPAMD_TEXT_FLAG_BINARY; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_filename(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->msg.fpath) { + lua_pushstring(L, task->msg.fpath); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_rawbody(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_lua_text *t; + + if (task) { + if (task->message != NULL) { + + + if (MESSAGE_FIELD(task, raw_headers_content).len > 0) { + g_assert(MESSAGE_FIELD(task, raw_headers_content).len <= task->msg.len); + t = lua_new_text_task(L, task, task->msg.begin + MESSAGE_FIELD(task, raw_headers_content).len, + task->msg.len - MESSAGE_FIELD(task, raw_headers_content).len, FALSE); + } + else { + t = lua_new_text_task(L, task, task->msg.begin, + task->msg.len, FALSE); + } + + t->flags = 0; + } + else { + /* Push body it it is there */ + if (task->msg.len > 0 && task->msg.begin != NULL) { + lua_new_text_task(L, task, task->msg.begin, task->msg.len, FALSE); + } + else { + lua_pushnil(L); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_emails(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct lua_tree_cb_data cb; + struct rspamd_url *u; + gsize max_urls = 0, sz; + + if (task) { + if (task->message) { + if (task->cfg) { + max_urls = task->cfg->max_lua_urls; + } + + if (!lua_url_cbdata_fill(L, 2, &cb, PROTOCOL_MAILTO, + ~(RSPAMD_URL_FLAG_CONTENT | RSPAMD_URL_FLAG_IMAGE), + max_urls)) { + return luaL_error(L, "invalid arguments"); + } + + sz = kh_size(MESSAGE_FIELD(task, urls)); + sz = lua_url_adjust_skip_prob(task->task_timestamp, + MESSAGE_FIELD(task, digest), &cb, sz); + + lua_createtable(L, sz, 0); + + kh_foreach_key(MESSAGE_FIELD(task, urls), u, { + lua_tree_url_callback(u, u, &cb); + }); + + lua_url_cbdata_dtor(&cb); + } + else { + lua_newtable(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_text_parts(lua_State *L) +{ + LUA_TRACE_POINT; + guint i; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_mime_text_part *part, **ppart; + + if (task != NULL) { + + if (task->message) { + if (!lua_task_get_cached(L, task, "text_parts")) { + lua_createtable(L, MESSAGE_FIELD(task, text_parts)->len, 0); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) + { + ppart = lua_newuserdata(L, sizeof(struct rspamd_mime_text_part *)); + *ppart = part; + rspamd_lua_setclass(L, "rspamd{textpart}", -1); + /* Make it array */ + lua_rawseti(L, -2, i + 1); + } + + lua_task_set_cached(L, task, "text_parts", -1); + } + } + else { + lua_newtable(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_parts(lua_State *L) +{ + LUA_TRACE_POINT; + guint i; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_mime_part *part, **ppart; + + if (task != NULL) { + if (task->message) { + lua_createtable(L, MESSAGE_FIELD(task, parts)->len, 0); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, parts), i, part) + { + ppart = lua_newuserdata(L, sizeof(struct rspamd_mime_part *)); + *ppart = part; + rspamd_lua_setclass(L, "rspamd{mimepart}", -1); + /* Make it array */ + lua_rawseti(L, -2, i + 1); + } + } + else { + lua_newtable(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_request_header(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_ftok_t *hdr; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *s; + struct rspamd_lua_text *t; + + s = luaL_checkstring(L, 2); + + if (s && task) { + hdr = rspamd_task_get_request_header(task, s); + + if (hdr) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = hdr->begin; + t->len = hdr->len; + t->flags = 0; + + return 1; + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_request_header(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *s, *v = NULL; + rspamd_fstring_t *buf; + struct rspamd_lua_text *t; + rspamd_ftok_t *hdr, *new_name; + gsize len, vlen = 0; + + s = luaL_checklstring(L, 2, &len); + + if (s && task) { + if (lua_type(L, 3) == LUA_TSTRING) { + v = luaL_checklstring(L, 3, &vlen); + } + else if (lua_type(L, 3) == LUA_TUSERDATA) { + t = lua_check_text(L, 3); + + if (t != NULL) { + v = t->start; + vlen = t->len; + } + } + + if (v != NULL) { + buf = rspamd_fstring_new_init(v, vlen); + hdr = rspamd_ftok_map(buf); + buf = rspamd_fstring_new_init(s, len); + new_name = rspamd_ftok_map(buf); + + rspamd_task_add_request_header(task, new_name, hdr); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return 0; +} + + +gint rspamd_lua_push_header(lua_State *L, struct rspamd_mime_header *rh, + enum rspamd_lua_task_header_type how) +{ + LUA_TRACE_POINT; + switch (how) { + case RSPAMD_TASK_HEADER_PUSH_FULL: + /* Create new associated table for a header */ + lua_createtable(L, 0, 7); + rspamd_lua_table_set(L, "name", rh->name); + + if (rh->value) { + rspamd_lua_table_set(L, "value", rh->value); + } + + if (rh->raw_len > 0) { + lua_pushstring(L, "raw"); + lua_pushlstring(L, rh->raw_value, rh->raw_len); + lua_settable(L, -3); + } + + if (rh->decoded) { + rspamd_lua_table_set(L, "decoded", rh->decoded); + } + + lua_pushstring(L, "tab_separated"); + lua_pushboolean(L, rh->flags & RSPAMD_HEADER_TAB_SEPARATED); + lua_settable(L, -3); + lua_pushstring(L, "empty_separator"); + lua_pushboolean(L, rh->flags & RSPAMD_HEADER_EMPTY_SEPARATOR); + lua_settable(L, -3); + rspamd_lua_table_set(L, "separator", rh->separator); + lua_pushstring(L, "order"); + lua_pushinteger(L, rh->order); + lua_settable(L, -3); + break; + case RSPAMD_TASK_HEADER_PUSH_RAW: + if (rh->value) { + lua_pushstring(L, rh->value); + } + else { + lua_pushnil(L); + } + break; + case RSPAMD_TASK_HEADER_PUSH_SIMPLE: + if (rh->decoded) { + lua_pushstring(L, rh->decoded); + } + else { + lua_pushnil(L); + } + break; + case RSPAMD_TASK_HEADER_PUSH_COUNT: + default: + g_assert_not_reached(); + break; + } + + return 1; +} + +gint rspamd_lua_push_header_array(lua_State *L, + const gchar *name, + struct rspamd_mime_header *rh, + enum rspamd_lua_task_header_type how, + gboolean strong) +{ + LUA_TRACE_POINT; + struct rspamd_mime_header *cur; + guint i; + gint nret = 1; + + if (rh == NULL) { + if (how == RSPAMD_TASK_HEADER_PUSH_HAS) { + lua_pushboolean(L, false); + nret = 1; + } + else if (how == RSPAMD_TASK_HEADER_PUSH_COUNT) { + lua_pushnumber(L, 0); + } + else { + lua_pushnil(L); + } + + return nret; + } + + if (how == RSPAMD_TASK_HEADER_PUSH_FULL) { + lua_createtable(L, 0, 0); + i = 0; + + DL_FOREACH(rh, cur) + { + if (!strong || strcmp(name, cur->name) == 0) { + rspamd_lua_push_header(L, cur, how); + lua_rawseti(L, -2, ++i); + } + } + } + else if (how == RSPAMD_TASK_HEADER_PUSH_COUNT) { + i = 0; + + DL_FOREACH(rh, cur) + { + if (!strong || strcmp(name, cur->name) == 0) { + i++; + } + } + + lua_pushinteger(L, i); + } + else if (how == RSPAMD_TASK_HEADER_PUSH_HAS) { + nret = 1; + bool found = false; + + if (strong) { + /* We still have to check all headers in the chain */ + DL_FOREACH(rh, cur) + { + if (strcmp(name, cur->name) == 0) { + found = true; + break; + } + } + } + else { + found = true; + } + + lua_pushboolean(L, found); + } + else { + DL_FOREACH(rh, cur) + { + if (!strong || strcmp(name, cur->name) == 0) { + return rspamd_lua_push_header(L, cur, how); + } + } + + /* Not found with this case */ + lua_pushnil(L); + } + + return nret; +} + +static gint +lua_task_get_header_common(lua_State *L, enum rspamd_lua_task_header_type how) +{ + LUA_TRACE_POINT; + gboolean strong = FALSE, need_modified = FALSE; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_mime_header *rh; + const gchar *name; + + name = luaL_checkstring(L, 2); + + if (name && task) { + if (lua_gettop(L) >= 3) { + strong = lua_toboolean(L, 3); + if (lua_isboolean(L, 4)) { + need_modified = lua_toboolean(L, 4); + } + } + + + rh = rspamd_message_get_header_array(task, name, need_modified); + + return rspamd_lua_push_header_array(L, name, rh, how, strong); + } + else { + return luaL_error(L, "invalid arguments"); + } +} + +static gint +lua_task_get_header_full(lua_State *L) +{ + return lua_task_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_FULL); +} + +static gint +lua_task_get_header(lua_State *L) +{ + return lua_task_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_SIMPLE); +} + +static gint +lua_task_get_header_raw(lua_State *L) +{ + return lua_task_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_RAW); +} + +static gint +lua_task_get_header_count(lua_State *L) +{ + return lua_task_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_COUNT); +} + +static gint +lua_task_has_header(lua_State *L) +{ + return lua_task_get_header_common(L, RSPAMD_TASK_HEADER_PUSH_HAS); +} + +static gint +lua_task_get_headers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + bool need_modified = lua_isnoneornil(L, 2) ? false : lua_toboolean(L, 2); + + if (task && task->message) { + struct rspamd_mime_header *cur; + int i = 1; + + lua_createtable(L, rspamd_mime_headers_count(MESSAGE_FIELD(task, raw_headers)), 0); + LL_FOREACH2(MESSAGE_FIELD(task, headers_order), cur, ord_next) + { + if (need_modified && cur->modified_chain) { + struct rspamd_mime_header *cur_modified; + + LL_FOREACH(cur->modified_chain, cur_modified) + { + rspamd_lua_push_header(L, cur_modified, RSPAMD_TASK_HEADER_PUSH_FULL); + lua_rawseti(L, -2, i++); + } + } + else { + rspamd_lua_push_header(L, cur, RSPAMD_TASK_HEADER_PUSH_FULL); + lua_rawseti(L, -2, i++); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return 1; +} + +static gint +lua_task_get_raw_headers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_lua_text *t; + + if (task && task->message) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = MESSAGE_FIELD(task, raw_headers_content).begin; + t->len = MESSAGE_FIELD(task, raw_headers_content).len; + t->flags = 0; + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return 1; +} + +static gint +lua_task_get_received_headers(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (!task->message) { + /* No message - no received */ + lua_newtable(L); + return 1; + } + + if (!lua_task_get_cached(L, task, "received")) { + + if (rspamd_received_export_to_lua(task, L)) { + lua_task_set_cached(L, task, "received", -1); + } + else { + /* no received, preserve compatibility */ + lua_newtable(L); + return 1; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_queue_id(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->queue_id != NULL && strcmp(task->queue_id, "undef") != 0) { + lua_pushstring(L, task->queue_id); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_uid(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + lua_pushstring(L, task->task_pool->tag.uid); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_resolver(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_dns_resolver **presolver; + + if (task != NULL && task->resolver != NULL) { + presolver = lua_newuserdata(L, sizeof(void *)); + rspamd_lua_setclass(L, "rspamd{resolver}", -1); + *presolver = task->resolver; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_resolver(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_dns_resolver *resolver = lua_check_dns_resolver(L, 2); + + if (task != NULL && resolver != NULL) { + task->resolver = resolver; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_inc_dns_req(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + static guint warning_shown = 0; + + if (warning_shown < 100) { + warning_shown++; + msg_warn_task_check("task:inc_dns_req is deprecated and should not be used"); + } + + if (task != NULL) { + /* Deprecation: already done in rspamd_dns_resolver_request */ + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_get_dns_req(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + lua_pushinteger(L, task->dns_requests); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +enum lua_email_address_type { + LUA_ADDRESS_ANY = 0u, + LUA_ADDRESS_SMTP = 1, + LUA_ADDRESS_MIME = 2, + LUA_ADDRESS_MASK = 0x3FF, + LUA_ADDRESS_RAW = (1u << 10), + LUA_ADDRESS_ORIGINAL = (1u << 11), + LUA_ADDRESS_MAX = LUA_ADDRESS_MASK, +}; + +/* + * Convert element at the specified position to the type + * for get_from/get_recipients + */ +static enum lua_email_address_type +lua_task_str_to_get_type(lua_State *L, struct rspamd_task *task, gint pos, gint last_pos) +{ + const gchar *type = NULL; + gint ret = LUA_ADDRESS_ANY; + guint64 h; + gsize sz; + + /* Get what value */ + do { + if (lua_type(L, pos) == LUA_TNUMBER) { + ret = lua_tonumber(L, pos); + + if (ret >= LUA_ADDRESS_ANY && ret < LUA_ADDRESS_MAX) { + return ret; + } + + return LUA_ADDRESS_ANY; + } + else if (lua_type(L, pos) == LUA_TSTRING) { + type = lua_tolstring(L, pos, &sz); + + if (type && sz > 0) { + h = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_XXHASH64, + type, sz, 0xdeadbabe); + + switch (h) { + case 0xDA081341FB600389ULL: /* mime */ + ret = LUA_ADDRESS_MIME; + break; + case 0xEEC8A7832F8C43ACULL: /* any */ + ret = LUA_ADDRESS_ANY; + break; + case 0x472274D5193B2A80ULL: /* smtp */ + case 0xEFE0F586CC9F14A9ULL: /* envelope */ + ret = LUA_ADDRESS_SMTP; + break; + default: + msg_err_task("invalid email type: %*s", (gint) sz, type); + break; + } + } + } + else if (lua_type(L, pos) == LUA_TTABLE) { + for (lua_pushnil(L); lua_next(L, pos); lua_pop(L, 1)) { + type = lua_tolstring(L, -1, &sz); + + if (type && sz > 0) { + h = rspamd_cryptobox_fast_hash_specific(RSPAMD_CRYPTOBOX_XXHASH64, + type, sz, 0xdeadbabe); + + switch (h) { + case 0xDA081341FB600389ULL: /* mime */ + ret |= LUA_ADDRESS_MIME; + break; + case 0xEEC8A7832F8C43ACULL: /* any */ + ret |= LUA_ADDRESS_ANY; + break; + case 0x472274D5193B2A80ULL: /* smtp */ + case 0xEFE0F586CC9F14A9ULL: /* envelope */ + ret |= LUA_ADDRESS_SMTP; + break; + case 0xAF4DE083D9AD0132: /* raw */ + ret |= LUA_ADDRESS_RAW; + break; + case 0xC7AB6C7B7B0F5A8A: /* orig */ + case 0x1778AE905589E431: /* original */ + ret |= LUA_ADDRESS_ORIGINAL; + break; + default: + msg_err_task("invalid email type: %*s", (gint) sz, type); + break; + } + } + } + } + pos++; + } while (pos <= last_pos); + + return ret; +} + +#define EMAIL_CHECK_FLAG(fl, str) \ + do { \ + if (addr->flags & (fl)) { \ + lua_pushstring(L, (str)); \ + lua_pushboolean(L, true); \ + lua_settable(L, -3); \ + } \ + } while (0) + +static void +lua_push_email_address(lua_State *L, struct rspamd_email_address *addr) +{ + if (addr) { + lua_createtable(L, 0, 5); + + if (addr->raw_len > 0) { + lua_pushstring(L, "raw"); + lua_pushlstring(L, addr->raw, addr->raw_len); + lua_settable(L, -3); + } + else { + lua_pushstring(L, "raw"); + lua_pushstring(L, ""); + lua_settable(L, -3); + } + if (addr->addr_len > 0) { + lua_pushstring(L, "addr"); + lua_pushlstring(L, addr->addr, addr->addr_len); + lua_settable(L, -3); + } + else { + lua_pushstring(L, "addr"); + lua_pushstring(L, ""); + lua_settable(L, -3); + } + if (addr->domain_len > 0) { + lua_pushstring(L, "domain"); + lua_pushlstring(L, addr->domain, addr->domain_len); + lua_settable(L, -3); + } + else { + lua_pushstring(L, "domain"); + lua_pushstring(L, ""); + lua_settable(L, -3); + } + if (addr->user_len > 0) { + lua_pushstring(L, "user"); + lua_pushlstring(L, addr->user, addr->user_len); + lua_settable(L, -3); + } + else { + lua_pushstring(L, "user"); + lua_pushstring(L, ""); + lua_settable(L, -3); + } + + if (addr->name) { + lua_pushstring(L, "name"); + lua_pushstring(L, addr->name); + lua_settable(L, -3); + } + else { + lua_pushstring(L, "name"); + lua_pushstring(L, ""); + lua_settable(L, -3); + } + + lua_pushstring(L, "flags"); + lua_createtable(L, 0, 7); + + EMAIL_CHECK_FLAG(RSPAMD_EMAIL_ADDR_VALID, "valid"); + EMAIL_CHECK_FLAG(RSPAMD_EMAIL_ADDR_IP, "ip"); + EMAIL_CHECK_FLAG(RSPAMD_EMAIL_ADDR_BRACED, "braced"); + EMAIL_CHECK_FLAG(RSPAMD_EMAIL_ADDR_QUOTED, "quoted"); + EMAIL_CHECK_FLAG(RSPAMD_EMAIL_ADDR_EMPTY, "empty"); + EMAIL_CHECK_FLAG(RSPAMD_EMAIL_ADDR_HAS_BACKSLASH, "backslash"); + EMAIL_CHECK_FLAG(RSPAMD_EMAIL_ADDR_HAS_8BIT, "8bit"); + + lua_settable(L, -3); + } +} + +void lua_push_emails_address_list(lua_State *L, GPtrArray *addrs, int flags) +{ + struct rspamd_email_address *addr; + guint i, pos = 1; + + lua_createtable(L, addrs->len, 0); + + for (i = 0; i < addrs->len; i++) { + addr = g_ptr_array_index(addrs, i); + + if (addr->flags & RSPAMD_EMAIL_ADDR_ORIGINAL) { + if (flags & LUA_ADDRESS_ORIGINAL) { + lua_push_email_address(L, addr); + lua_rawseti(L, -2, pos); + pos++; + } + } + else { + lua_push_email_address(L, addr); + lua_rawseti(L, -2, pos); + pos++; + } + } +} + +static gboolean +lua_import_email_address(lua_State *L, struct rspamd_task *task, + gint pos, + struct rspamd_email_address **paddr) +{ + struct rspamd_email_address *addr; + const gchar *p; + gchar *dst; + gsize len; + + g_assert(paddr != NULL); + + if (!lua_istable(L, pos)) { + return FALSE; + } + + addr = g_malloc0(sizeof(*addr)); + + lua_pushstring(L, "name"); + lua_gettable(L, pos); + + if (lua_type(L, -1) == LUA_TSTRING) { + p = lua_tolstring(L, -1, &len); + dst = rspamd_mempool_alloc(task->task_pool, len + 1); + rspamd_strlcpy(dst, p, len + 1); + addr->name = dst; + } + + lua_pop(L, 1); + + lua_pushstring(L, "user"); + lua_gettable(L, pos); + + if (lua_type(L, -1) == LUA_TSTRING) { + p = lua_tolstring(L, -1, &len); + addr->user = (const gchar *) rspamd_mempool_alloc(task->task_pool, len); + memcpy((gchar *) addr->user, p, len); + addr->user_len = len; + } + + lua_pop(L, 1); + + lua_pushstring(L, "domain"); + lua_gettable(L, pos); + + if (lua_type(L, -1) == LUA_TSTRING) { + p = lua_tolstring(L, -1, &len); + addr->domain = (const gchar *) rspamd_mempool_alloc(task->task_pool, len); + memcpy((gchar *) addr->domain, p, len); + addr->domain_len = len; + } + + lua_pop(L, 1); + + lua_pushstring(L, "addr"); + lua_gettable(L, pos); + + if (lua_type(L, -1) == LUA_TSTRING) { + p = lua_tolstring(L, -1, &len); + addr->addr = (const gchar *) rspamd_mempool_alloc(task->task_pool, len); + memcpy((gchar *) addr->addr, p, len); + addr->addr_len = len; + } + else { + /* Construct addr */ + len = addr->domain_len + addr->user_len + 1; + addr->addr = (const gchar *) rspamd_mempool_alloc(task->task_pool, len); + addr->addr_len = rspamd_snprintf((gchar *) addr->addr, len, "%*s@%*s", + (int) addr->user_len, addr->user, + (int) addr->domain_len, addr->domain); + } + + lua_pop(L, 1); + + lua_pushstring(L, "raw"); + lua_gettable(L, pos); + + if (lua_type(L, -1) == LUA_TSTRING) { + gchar *cpy; + p = lua_tolstring(L, -1, &len); + cpy = rspamd_mempool_alloc(task->task_pool, len + 1); + memcpy(cpy, p, len); + cpy[len] = '\0'; + addr->raw_len = len; + addr->raw = cpy; + } + else { + /* Construct raw addr */ + len = addr->addr_len + 3; + + if (addr->name) { + len += strlen(addr->name) + 1; + dst = rspamd_mempool_alloc(task->task_pool, len + 1); + + addr->raw_len = rspamd_snprintf(dst, len, "%s <%*s>", + addr->name, + (int) addr->addr_len, addr->addr); + } + else { + dst = rspamd_mempool_alloc(task->task_pool, len + 1); + + addr->raw_len = rspamd_snprintf(dst, len, "<%*s@%*s>", + (int) addr->user_len, addr->user, + (int) addr->domain_len, addr->domain); + } + + addr->raw = dst; + } + + lua_pop(L, 1); + addr->flags = RSPAMD_EMAIL_ADDR_VALID; + + *paddr = addr; + + return TRUE; +} + +static gint +lua_task_get_recipients(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + GPtrArray *ptrs = NULL; + gint what = 0; + + if (task) { + if (lua_gettop(L) == 2) { + /* Get what value */ + what = lua_task_str_to_get_type(L, task, 2, lua_gettop(L)); + } + + switch (what & LUA_ADDRESS_MASK) { + case LUA_ADDRESS_SMTP: + /* Here we check merely envelope rcpt */ + ptrs = task->rcpt_envelope; + break; + case LUA_ADDRESS_MIME: + /* Here we check merely mime rcpt */ + ptrs = MESSAGE_FIELD_CHECK(task, rcpt_mime); + break; + case LUA_ADDRESS_ANY: + default: + if (task->rcpt_envelope) { + ptrs = task->rcpt_envelope; + } + else { + ptrs = MESSAGE_FIELD_CHECK(task, rcpt_mime); + } + break; + } + if (ptrs) { + lua_push_emails_address_list(L, ptrs, what & ~LUA_ADDRESS_MASK); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_recipients(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + GPtrArray *ptrs = NULL; + struct rspamd_email_address *addr = NULL; + gint what = 0, pos = 3; + const gchar *how = "add"; + gboolean need_update_digest = FALSE; + + if (task && lua_gettop(L) >= 3) { + + /* Get what value */ + what = lua_task_str_to_get_type(L, task, 2, -1); + + if (lua_isstring(L, 4)) { + how = lua_tostring(L, 4); + } + + switch (what) { + case LUA_ADDRESS_SMTP: + /* Here we check merely envelope rcpt */ + if (task->rcpt_envelope) { + ptrs = task->rcpt_envelope; + } + else { + ptrs = g_ptr_array_new(); + task->rcpt_envelope = ptrs; + } + break; + case LUA_ADDRESS_MIME: + /* Here we check merely mime rcpt */ + ptrs = MESSAGE_FIELD_CHECK(task, rcpt_mime); + need_update_digest = TRUE; + break; + case LUA_ADDRESS_ANY: + default: + if (task->rcpt_envelope) { + if (task->rcpt_envelope) { + ptrs = task->rcpt_envelope; + } + else { + ptrs = g_ptr_array_new(); + task->rcpt_envelope = ptrs; + } + } + else { + ptrs = MESSAGE_FIELD_CHECK(task, rcpt_mime); + need_update_digest = TRUE; + } + break; + } + if (ptrs) { + guint i, flags_existing = RSPAMD_EMAIL_ADDR_ORIGINAL, flags_add = 0; + struct rspamd_email_address *tmp; + + if (strcmp(how, "alias") == 0) { + flags_add |= RSPAMD_EMAIL_ADDR_ALIASED; + } + else if (strcmp(how, "rewrite") == 0) { + /* Clear old addresses */ + PTR_ARRAY_FOREACH(ptrs, i, tmp) + { + rspamd_email_address_free(addr); + } + + g_ptr_array_set_size(ptrs, 0); + } + + PTR_ARRAY_FOREACH(ptrs, i, tmp) + { + tmp->flags |= flags_existing; + } + + lua_pushvalue(L, pos); + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + if (lua_import_email_address(L, task, lua_gettop(L), &addr)) { + + if (need_update_digest) { + rspamd_message_update_digest(task->message, + addr->addr, addr->addr_len); + } + + addr->flags |= flags_add; + g_ptr_array_add(ptrs, addr); + } + } + + lua_pop(L, 1); + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +#define CHECK_EMAIL_ADDR(addr) \ + do { \ + if (addr == NULL) { \ + ret = 0; \ + } \ + else { \ + ret = addr->flags & RSPAMD_EMAIL_ADDR_VALID; \ + } \ + } while (0) + +#define CHECK_EMAIL_ADDR_LIST(addr) \ + do { \ + if (addr == NULL) { \ + ret = 0; \ + } \ + else { \ + ret = addr->len > 0; \ + nrcpt = addr->len; \ + } \ + } while (0) + +static gint +lua_task_has_from(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gint what = 0, nrcpt = 0; + gboolean ret = FALSE; + + if (task) { + if (lua_gettop(L) == 2) { + /* Get what value */ + what = lua_task_str_to_get_type(L, task, 2, lua_gettop(L)); + } + + switch (what & LUA_ADDRESS_MASK) { + case LUA_ADDRESS_SMTP: + /* Here we check merely envelope rcpt */ + CHECK_EMAIL_ADDR(task->from_envelope); + break; + case LUA_ADDRESS_MIME: + /* Here we check merely mime rcpt */ + CHECK_EMAIL_ADDR_LIST(MESSAGE_FIELD_CHECK(task, from_mime)); + break; + case LUA_ADDRESS_ANY: + default: + CHECK_EMAIL_ADDR(task->from_envelope); + + if (!ret) { + CHECK_EMAIL_ADDR_LIST(MESSAGE_FIELD_CHECK(task, from_mime)); + } + break; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, ret); + (void) nrcpt; /* Silence warning */ + + return 1; +} + +static inline int +rspamd_check_real_recipients_array_size(GPtrArray *ar) +{ + gint ret = 0, i; + struct rspamd_email_address *addr; + + PTR_ARRAY_FOREACH(ar, i, addr) + { + if (!(addr->flags & RSPAMD_EMAIL_ADDR_ORIGINAL)) { + ret++; + } + } + + return ret; +} + +static gint +lua_task_has_recipients(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gint what = 0, nrcpt = 0; + gboolean ret = FALSE; + + if (task) { + if (lua_gettop(L) == 2) { + /* Get what value */ + what = lua_task_str_to_get_type(L, task, 2, lua_gettop(L)); + } + + switch (what & LUA_ADDRESS_MASK) { + case LUA_ADDRESS_SMTP: + /* Here we check merely envelope rcpt */ + nrcpt = rspamd_check_real_recipients_array_size(task->rcpt_envelope); + ret = nrcpt > 0; + break; + case LUA_ADDRESS_MIME: + /* Here we check merely mime rcpt */ + nrcpt = rspamd_check_real_recipients_array_size(MESSAGE_FIELD_CHECK(task, rcpt_mime)); + ret = nrcpt > 0; + break; + case LUA_ADDRESS_ANY: + default: + nrcpt = rspamd_check_real_recipients_array_size(task->rcpt_envelope); + ret = nrcpt > 0; + + if (!ret) { + nrcpt = rspamd_check_real_recipients_array_size(MESSAGE_FIELD_CHECK(task, rcpt_mime)); + ret = nrcpt > 0; + } + break; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, ret); + lua_pushinteger(L, nrcpt); + + return 2; +} + +static gint +lua_task_get_from(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + GPtrArray *addrs = NULL; + struct rspamd_email_address *addr = NULL; + gint what = 0; + + if (task) { + if (lua_gettop(L) == 2) { + /* Get what value */ + what = lua_task_str_to_get_type(L, task, 2, lua_gettop(L)); + } + + switch (what & LUA_ADDRESS_MASK) { + case LUA_ADDRESS_SMTP: + /* Here we check merely envelope rcpt */ + addr = task->from_envelope; + break; + case LUA_ADDRESS_MIME: + /* Here we check merely mime rcpt */ + addrs = MESSAGE_FIELD_CHECK(task, from_mime); + break; + case LUA_ADDRESS_ANY: + default: + if (task->from_envelope) { + addr = task->from_envelope; + } + else { + addrs = MESSAGE_FIELD_CHECK(task, from_mime); + } + break; + } + + if (addrs && addrs->len > 0) { + lua_push_emails_address_list(L, addrs, what & ~LUA_ADDRESS_MASK); + } + else if (addr) { + /* Create table to preserve compatibility */ + if (addr->addr) { + lua_createtable(L, 1, 0); + if (what & LUA_ADDRESS_ORIGINAL) { + if (task->from_envelope_orig) { + lua_push_email_address(L, task->from_envelope_orig); + } + else { + lua_push_email_address(L, addr); + } + } + else { + lua_push_email_address(L, addr); + } + + lua_rawseti(L, -2, 1); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_from(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *how = "rewrite"; + GPtrArray *addrs = NULL; + struct rspamd_email_address **paddr = NULL, *addr; + gboolean need_update_digest = FALSE; + gint what = 0; + + if (task && lua_gettop(L) >= 3) { + what = lua_task_str_to_get_type(L, task, 2, -1); + + if (lua_isstring(L, 4)) { + how = lua_tostring(L, 4); + } + + switch (what & LUA_ADDRESS_MASK) { + case LUA_ADDRESS_SMTP: + /* Here we check merely envelope rcpt */ + paddr = &task->from_envelope; + break; + case LUA_ADDRESS_MIME: + /* Here we check merely mime rcpt */ + addrs = MESSAGE_FIELD_CHECK(task, from_mime); + need_update_digest = TRUE; + break; + case LUA_ADDRESS_ANY: + default: + if (task->from_envelope) { + paddr = &task->from_envelope; + } + else { + addrs = MESSAGE_FIELD_CHECK(task, from_mime); + need_update_digest = TRUE; + } + break; + } + + if (addrs) { + if (lua_import_email_address(L, task, 3, &addr)) { + guint i, flags_add = RSPAMD_EMAIL_ADDR_ORIGINAL; + struct rspamd_email_address *tmp; + + if (strcmp(how, "alias") == 0) { + flags_add |= RSPAMD_EMAIL_ADDR_ALIASED; + } + + PTR_ARRAY_FOREACH(addrs, i, tmp) + { + tmp->flags |= flags_add; + } + + if (need_update_digest) { + rspamd_message_update_digest(task->message, + addr->addr, addr->addr_len); + } + + g_ptr_array_add(addrs, addr); + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + } + else if (paddr) { + /* SMTP from case */ + if (lua_import_email_address(L, task, 3, &addr)) { + task->from_envelope_orig = *paddr; + task->from_envelope = addr; + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + } + else { + lua_pushboolean(L, false); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_principal_recipient(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *r; + + if (task) { + r = rspamd_task_get_principal_recipient(task); + if (r != NULL) { + lua_pushstring(L, r); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_reply_sender(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_mime_header *rh; + + if (task) { + + rh = rspamd_message_get_header_array(task, "Reply-To", FALSE); + + if (rh) { + GPtrArray *addrs; + + addrs = rspamd_email_address_from_mime(task->task_pool, rh->decoded, + strlen(rh->decoded), NULL, -1); + + if (addrs == NULL || addrs->len == 0) { + lua_pushnil(L); + } + else { + struct rspamd_email_address *addr; + + addr = (struct rspamd_email_address *) g_ptr_array_index(addrs, 0); + lua_pushlstring(L, addr->addr, addr->addr_len); + } + } + else if (MESSAGE_FIELD_CHECK(task, from_mime) && + MESSAGE_FIELD(task, from_mime)->len >= 1) { + struct rspamd_email_address *addr; + + addr = (struct rspamd_email_address *) g_ptr_array_index( + MESSAGE_FIELD(task, from_mime), 0); + + lua_pushlstring(L, addr->addr, addr->addr_len); + } + else if (task->from_envelope) { + lua_pushlstring(L, task->from_envelope->addr, + task->from_envelope->addr_len); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_user(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->auth_user != NULL) { + lua_pushstring(L, task->auth_user); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_user(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *new_user; + + if (task) { + + if (lua_type(L, 2) == LUA_TSTRING) { + new_user = lua_tostring(L, 2); + + if (task->auth_user) { + /* Push old user */ + lua_pushstring(L, task->auth_user); + } + else { + lua_pushnil(L); + } + + task->auth_user = rspamd_mempool_strdup(task->task_pool, new_user); + } + else { + /* Reset user */ + if (task->auth_user) { + /* Push old user */ + lua_pushstring(L, task->auth_user); + } + else { + lua_pushnil(L); + } + + task->auth_user = NULL; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_from_ip(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->from_addr) { + rspamd_lua_ip_push(L, task->from_addr); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_from_ip(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + rspamd_inet_addr_t *addr = NULL; + + if (!task) { + return luaL_error(L, "no task"); + } + else { + if (lua_type(L, 2) == LUA_TSTRING) { + gsize len; + const gchar *ip_str = lua_tolstring(L, 2, &len); + + if (!rspamd_parse_inet_address(&addr, + ip_str, + len, + RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + return luaL_error(L, "invalid IP string: %s", ip_str); + } + else { + if (task->from_addr) { + rspamd_inet_address_free(task->from_addr); + } + + task->from_addr = addr; + } + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + struct rspamd_lua_ip *ip = lua_check_ip(L, 2); + + if (ip && ip->addr) { + if (task->from_addr) { + rspamd_inet_address_free(task->from_addr); + } + + task->from_addr = rspamd_inet_address_copy(ip->addr, NULL); + } + else { + return luaL_error(L, "invalid IP object"); + } + } + else { + return luaL_error(L, "invalid IP argument type: %s", lua_typename(L, lua_type(L, 2))); + } + } + + return 0; +} + +static gint +lua_task_get_from_ip_num(lua_State *L) +{ + LUA_TRACE_POINT; + msg_err("this function is deprecated and should no longer be used"); + lua_pushnil(L); + return 1; +} + +static gint +lua_task_get_client_ip(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->client_addr) { + rspamd_lua_ip_push(L, task->client_addr); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_helo(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->helo != NULL) { + lua_pushstring(L, task->helo); + return 1; + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_subject(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (MESSAGE_FIELD_CHECK(task, subject) != NULL) { + lua_pushstring(L, MESSAGE_FIELD(task, subject)); + return 1; + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_helo(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *new_helo; + + if (task) { + new_helo = luaL_checkstring(L, 2); + if (new_helo) { + task->helo = rspamd_mempool_strdup(task->task_pool, new_helo); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_get_hostname(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->hostname != NULL) { + /* Check whether it looks like an IP address */ + if (*task->hostname == '[') { + /* + * From the milter documentation: + * If the reverse lookup fails or if none of the IP + * addresses of the resolved host name matches the + * original IP address, hostname will contain the + * message sender's IP address enclosed in square + * brackets (e.g. `[a.b.c.d]') + */ + lua_pushnil(L); + } + else { + lua_pushstring(L, task->hostname); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_hostname(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *new_hostname; + + if (task) { + new_hostname = luaL_checkstring(L, 2); + if (new_hostname) { + task->hostname = rspamd_mempool_strdup(task->task_pool, + new_hostname); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_get_images(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + guint nelt = 0, i; + struct rspamd_mime_part *part; + struct rspamd_image **pimg; + + if (task) { + if (task->message) { + if (!lua_task_get_cached(L, task, "images")) { + lua_createtable(L, MESSAGE_FIELD(task, parts)->len, 0); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, parts), i, part) + { + if (part->part_type == RSPAMD_MIME_PART_IMAGE) { + pimg = lua_newuserdata(L, sizeof(struct rspamd_image *)); + rspamd_lua_setclass(L, "rspamd{image}", -1); + *pimg = part->specific.img; + lua_rawseti(L, -2, ++nelt); + } + } + + lua_task_set_cached(L, task, "images", -1); + } + } + else { + lua_newtable(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_archives(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + guint nelt = 0, i; + struct rspamd_mime_part *part; + struct rspamd_archive **parch; + + if (task) { + if (task->message) { + if (!lua_task_get_cached(L, task, "archives")) { + lua_createtable(L, MESSAGE_FIELD(task, parts)->len, 0); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, parts), i, part) + { + if (part->part_type == RSPAMD_MIME_PART_ARCHIVE) { + parch = lua_newuserdata(L, sizeof(struct rspamd_archive *)); + rspamd_lua_setclass(L, "rspamd{archive}", -1); + *parch = part->specific.arch; + lua_rawseti(L, -2, ++nelt); + } + } + + lua_task_set_cached(L, task, "archives", -1); + } + } + else { + lua_newtable(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_dkim_results(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + guint nelt = 0, i; + struct rspamd_dkim_check_result **pres, **cur; + + if (task) { + if (!lua_task_get_cached(L, task, "dkim_results")) { + pres = rspamd_mempool_get_variable(task->task_pool, + RSPAMD_MEMPOOL_DKIM_CHECK_RESULTS); + + if (pres == NULL) { + lua_newtable(L); + } + else { + for (cur = pres; *cur != NULL; cur++) { + nelt++; + } + + lua_createtable(L, nelt, 0); + + for (i = 0; i < nelt; i++) { + struct rspamd_dkim_check_result *res = pres[i]; + const gchar *result_str = "unknown"; + + lua_createtable(L, 0, 4); + + switch (res->rcode) { + case DKIM_CONTINUE: + result_str = "allow"; + break; + case DKIM_REJECT: + result_str = "reject"; + break; + case DKIM_TRYAGAIN: + result_str = "tempfail"; + break; + case DKIM_NOTFOUND: + result_str = "not found"; + break; + case DKIM_RECORD_ERROR: + result_str = "bad record"; + break; + case DKIM_PERM_ERROR: + result_str = "permanent error"; + break; + default: + break; + } + + rspamd_lua_table_set(L, "result", result_str); + + if (res->domain) { + rspamd_lua_table_set(L, "domain", res->domain); + } + + if (res->selector) { + rspamd_lua_table_set(L, "selector", res->selector); + } + + if (res->short_b) { + rspamd_lua_table_set(L, "bhash", res->short_b); + } + + if (res->fail_reason) { + rspamd_lua_table_set(L, "fail_reason", res->fail_reason); + } + + lua_rawseti(L, -2, i + 1); + } + } + + lua_task_set_cached(L, task, "dkim_results", -1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static inline gboolean +lua_push_symbol_result(lua_State *L, + struct rspamd_task *task, + const gchar *symbol, + struct rspamd_symbol_result *symbol_result, + struct rspamd_scan_result *metric_res, + gboolean add_metric, + gboolean add_name) +{ + + struct rspamd_symbol_result *s = NULL; + struct rspamd_symbol_option *opt; + struct rspamd_symbols_group *sym_group; + guint i; + gint j = 1, table_fields_cnt = 4; + + if (!metric_res) { + metric_res = task->result; + } + + if (!symbol_result) { + s = rspamd_task_find_symbol_result(task, symbol, metric_res); + } + else { + s = symbol_result; + } + + if (s && !(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) { + if (add_metric) { + table_fields_cnt++; + } + if (add_name) { + table_fields_cnt++; + } + + lua_createtable(L, 0, table_fields_cnt); + + if (add_name) { + lua_pushstring(L, "name"); + lua_pushstring(L, symbol); + lua_settable(L, -3); + } + lua_pushstring(L, "score"); + lua_pushnumber(L, s->score); + lua_settable(L, -3); + + if (s->sym && s->sym->gr) { + lua_pushstring(L, "group"); + lua_pushstring(L, s->sym->gr->name); + lua_settable(L, -3); + + lua_pushstring(L, "groups"); + lua_createtable(L, s->sym->groups->len, 0); + + PTR_ARRAY_FOREACH(s->sym->groups, i, sym_group) + { + lua_pushstring(L, sym_group->name); + lua_rawseti(L, -2, i + 1); + } + + lua_settable(L, -3); + } + else { + lua_pushstring(L, "group"); + lua_pushstring(L, "ungrouped"); + lua_settable(L, -3); + } + + if (s->options) { + lua_pushstring(L, "options"); + lua_createtable(L, kh_size(s->options), 0); + + DL_FOREACH(s->opts_head, opt) + { + lua_pushlstring(L, opt->option, opt->optlen); + lua_rawseti(L, -2, j++); + } + + lua_settable(L, -3); + } + + return TRUE; + } + + return FALSE; +} + +static gint +lua_task_get_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *symbol; + gboolean found = FALSE; + + symbol = luaL_checkstring(L, 2); + + if (task && symbol) { + struct rspamd_scan_result *sres = NULL; + + if (lua_isstring(L, 3)) { + sres = rspamd_find_metric_result(task, lua_tostring(L, 3)); + + if (sres == NULL) { + return luaL_error(L, "invalid scan result: %s", + lua_tostring(L, 3)); + } + } + + /* Always push as a table for compatibility :( */ + lua_createtable(L, 1, 0); + + if ((found = lua_push_symbol_result(L, task, symbol, + NULL, sres, TRUE, FALSE))) { + lua_rawseti(L, -2, 1); + } + else { + /* Pop table */ + lua_pop(L, 1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + if (!found) { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_task_has_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_symbol_result *s; + const gchar *symbol; + gboolean found = FALSE; + + symbol = luaL_checkstring(L, 2); + + if (task && symbol) { + if (lua_isstring(L, 3)) { + s = rspamd_task_find_symbol_result(task, symbol, + rspamd_find_metric_result(task, lua_tostring(L, 3))); + + if (s && !(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) { + found = TRUE; + } + } + else { + s = rspamd_task_find_symbol_result(task, symbol, NULL); + + if (s && !(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) { + found = TRUE; + } + } + lua_pushboolean(L, found); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_enable_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *symbol; + gboolean found = FALSE; + + symbol = luaL_checkstring(L, 2); + + if (task && symbol) { + found = rspamd_symcache_enable_symbol(task, task->cfg->cache, symbol); + lua_pushboolean(L, found); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_disable_symbol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *symbol; + gboolean found = FALSE; + + symbol = luaL_checkstring(L, 2); + + if (task && symbol) { + found = rspamd_symcache_disable_symbol(task, task->cfg->cache, symbol); + lua_pushboolean(L, found); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_symbols(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_scan_result *mres; + gint i = 1; + struct rspamd_symbol_result *s; + + if (task) { + mres = task->result; + + if (lua_isstring(L, 2)) { + mres = rspamd_find_metric_result(task, lua_tostring(L, 2)); + } + + if (mres) { + lua_createtable(L, kh_size(mres->symbols), 0); + lua_createtable(L, kh_size(mres->symbols), 0); + + kh_foreach_value(mres->symbols, s, { + if (!(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) { + lua_pushstring(L, s->name); + lua_rawseti(L, -3, i); + lua_pushnumber(L, s->score); + lua_rawseti(L, -2, i); + i++; + } + }); + } + else { + lua_createtable(L, 0, 0); + lua_createtable(L, 0, 0); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 2; +} + +static gint +lua_task_get_symbols_all(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_scan_result *mres; + struct rspamd_symbol_result *s; + gboolean found = FALSE; + gint i = 1; + + if (task) { + mres = task->result; + + if (lua_isstring(L, 2)) { + mres = rspamd_find_metric_result(task, lua_tostring(L, 2)); + } + + if (mres) { + found = TRUE; + lua_createtable(L, kh_size(mres->symbols), 0); + + kh_foreach_value(mres->symbols, s, { + if (!(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) { + lua_push_symbol_result(L, task, s->name, s, mres, FALSE, TRUE); + lua_rawseti(L, -2, i++); + } + }); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + if (!found) { + lua_pushnil(L); + } + + return 1; +} + + +static gint +lua_task_get_symbols_numeric(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_scan_result *mres; + gint i = 1, id; + struct rspamd_symbol_result *s; + + if (task) { + mres = task->result; + + if (lua_isstring(L, 2)) { + mres = rspamd_find_metric_result(task, lua_tostring(L, 2)); + } + + if (mres) { + lua_createtable(L, kh_size(mres->symbols), 0); + lua_createtable(L, kh_size(mres->symbols), 0); + + lua_createtable(L, kh_size(mres->symbols), 0); + + kh_foreach_value(mres->symbols, s, { + if (!(s->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) { + id = rspamd_symcache_find_symbol(task->cfg->cache, + s->name); + lua_pushinteger(L, id); + lua_rawseti(L, -3, i); + lua_pushnumber(L, s->score); + lua_rawseti(L, -2, i); + i++; + } + }); + } + else { + lua_createtable(L, 0, 0); + lua_createtable(L, 0, 0); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 2; +} + +static gint +lua_task_get_groups(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gboolean need_private; + struct rspamd_scan_result *mres; + struct rspamd_symbols_group *gr; + gdouble gr_score; + + if (task) { + mres = task->result; + + if (lua_isboolean(L, 2)) { + need_private = lua_toboolean(L, 2); + } + else { + need_private = !(task->cfg->public_groups_only); + } + + if (lua_isstring(L, 3)) { + mres = rspamd_find_metric_result(task, lua_tostring(L, 3)); + } + + if (mres == NULL) { + lua_pushnil(L); + + return 1; + } + + lua_createtable(L, 0, kh_size(mres->sym_groups)); + + kh_foreach(mres->sym_groups, gr, gr_score, { + if (!(gr->flags & RSPAMD_SYMBOL_GROUP_PUBLIC)) { + if (!need_private) { + continue; + } + } + + lua_pushnumber(L, gr_score); + lua_setfield(L, -2, gr->name); + }); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct tokens_foreach_cbdata { + struct rspamd_task *task; + lua_State *L; + gint idx; + gboolean normalize; +}; + +static void +tokens_foreach_cb(struct rspamd_symcache_item *item, gpointer ud) +{ + struct tokens_foreach_cbdata *cbd = ud; + struct rspamd_symbol_result *s; + gint flags; + const gchar *sym; + + sym = rspamd_symcache_item_name(item); + flags = rspamd_symcache_item_flags(item); + + if (flags & SYMBOL_TYPE_NOSTAT) { + return; + } + + if ((s = rspamd_task_find_symbol_result(cbd->task, sym, NULL)) != NULL) { + if (s->flags & RSPAMD_SYMBOL_RESULT_IGNORED) { + lua_pushnumber(cbd->L, 0.0); + } + else { + if (cbd->normalize) { + lua_pushnumber(cbd->L, tanh(s->score)); + } + else { + lua_pushnumber(cbd->L, s->score); + } + } + } + else { + lua_pushnumber(cbd->L, 0.0); + } + + lua_rawseti(cbd->L, -2, cbd->idx++); +} + +static gint +lua_task_get_symbols_tokens(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct tokens_foreach_cbdata cbd; + + if (task) { + cbd.task = task; + cbd.L = L; + cbd.idx = 1; + cbd.normalize = TRUE; + + if (lua_type(L, 2) == LUA_TBOOLEAN) { + cbd.normalize = lua_toboolean(L, 2); + } + else { + cbd.normalize = TRUE; + } + + lua_createtable(L, + rspamd_symcache_stats_symbols_count(task->cfg->cache), 0); + rspamd_symcache_foreach(task->cfg->cache, tokens_foreach_cb, &cbd); + } + else { + return luaL_error(L, "invalid arguments"); + } + + /* Return type is table created */ + return 1; +} + +static gint +lua_task_process_ann_tokens(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gint offset = luaL_checkinteger(L, 4); + gdouble min_score = 0.0; + + if (task && lua_istable(L, 2) && lua_istable(L, 3)) { + guint symlen = rspamd_lua_table_size(L, 2); + if (lua_isnumber(L, 5)) { + min_score = lua_tonumber(L, 5); + } + + for (guint i = 1; i <= symlen; i++, offset++) { + const gchar *sym; + struct rspamd_symbol_result *sres; + + lua_rawgeti(L, 2, i); + sym = lua_tostring(L, -1); + + /* + * TODO: this cycle involves one hash lookup per symbol in a profile + * Basically, in a common case that would be a table of all symbols + * So we need to do N_symbols hash lookups which is not optimal + * The optimal solution is to convert [sym1, sym2, ... symn] profile + * to a set {sym1 = true, sym2 = true, ...} and then for each + * resulting symbol check this table. + * + * That would lead to N_results lookups which is usually MUCH smaller + */ + sres = rspamd_task_find_symbol_result(task, sym, NULL); + + if (sres && !(sres->flags & RSPAMD_SYMBOL_RESULT_IGNORED)) { + + if (!isnan(sres->score) && !isinf(sres->score) && + (!sres->sym || + !(rspamd_symcache_item_flags(sres->sym->cache_item) & SYMBOL_TYPE_NOSTAT))) { + + gdouble norm_score; + + if (sres->sym && !isnan(sres->sym->score)) { + if (sres->sym->score == 0) { + + if (sres->score == 0) { + /* Binary symbol */ + norm_score = 1.0; + } + else { + norm_score = fabs(tanh(sres->score)); + } + } + else { + /* Get dynamic weight */ + norm_score = fabs(sres->score / sres->sym->score); + + if (norm_score > 1.0) { + /* Multiple hits, we assume them as a single one */ + norm_score = 1.0; + } + } + } + else { + norm_score = fabs(tanh(sres->score)); + } + + lua_pushnumber(L, MAX(min_score, norm_score)); + lua_rawseti(L, 3, offset + 1); + } + } + + lua_pop(L, 1); /* Symbol name */ + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +enum lua_date_type { + DATE_CONNECT = 0, + DATE_MESSAGE, + DATE_INVALID +}; + +static enum lua_date_type +lua_task_detect_date_type(struct rspamd_task *task, + lua_State *L, gint idx, gboolean *gmt) +{ + enum lua_date_type type = DATE_CONNECT; + + if (lua_type(L, idx) == LUA_TNUMBER) { + gint num = lua_tonumber(L, idx); + if (num >= DATE_CONNECT && num < DATE_INVALID) { + return num; + } + } + else if (lua_type(L, idx) == LUA_TTABLE) { + const gchar *str; + + lua_pushvalue(L, idx); + lua_pushstring(L, "format"); + lua_gettable(L, -2); + + str = lua_tostring(L, -1); + + if (str) { + if (g_ascii_strcasecmp(str, "message") == 0) { + type = DATE_MESSAGE; + } + } + else { + msg_warn_task("date format has not been specified"); + } + + lua_pop(L, 1); + + lua_pushstring(L, "gmt"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TBOOLEAN) { + *gmt = lua_toboolean(L, -1); + } + + /* Value and table */ + lua_pop(L, 2); + } + + return type; +} + +static gint +lua_task_get_date(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_mime_header *h; + gdouble tim; + enum lua_date_type type = DATE_CONNECT; + gboolean gmt = TRUE; + + if (task != NULL) { + if (lua_gettop(L) > 1) { + type = lua_task_detect_date_type(task, L, 2, &gmt); + } + /* Get GMT date and store it to time_t */ + if (type == DATE_CONNECT) { + tim = task->task_timestamp; + + if (!gmt) { + struct tm t; + time_t tt; + + tt = tim; + rspamd_localtime(tt, &t); +#if !defined(__sun) + t.tm_gmtoff = 0; +#endif + t.tm_isdst = 0; + /* Preserve fractional part as Lua is aware of it */ + tim = mktime(&t) + (tim - tt); + } + } + else { + h = rspamd_message_get_header_array(task, "Date", FALSE); + + if (h) { + time_t tt; + struct tm t; + GError *err = NULL; + + tt = rspamd_parse_smtp_date(h->decoded, strlen(h->decoded), + &err); + + if (err == NULL) { + if (!gmt) { + rspamd_localtime(tt, &t); +#if !defined(__sun) + t.tm_gmtoff = 0; +#endif + t.tm_isdst = 0; + tim = mktime(&t); + } + else { + tim = tt; + } + } + else { + g_error_free(err); + tim = 0.0; + } + } + else { + tim = 0.0; + } + } + + lua_pushnumber(L, tim); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_message_id(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + if (MESSAGE_FIELD_CHECK(task, message_id) != NULL) { + lua_pushstring(L, MESSAGE_FIELD(task, message_id)); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_timeval(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct timeval tv; + + if (task != NULL) { + if (lua_isboolean(L, 2) && !!lua_toboolean(L, 2)) { + lua_pushnumber(L, task->task_timestamp); + } + else { + double_to_tv(task->task_timestamp, &tv); + lua_createtable(L, 0, 2); + lua_pushstring(L, "tv_sec"); + lua_pushinteger(L, (lua_Integer) tv.tv_sec); + lua_settable(L, -3); + lua_pushstring(L, "tv_usec"); + lua_pushinteger(L, (lua_Integer) tv.tv_usec); + lua_settable(L, -3); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_scan_time(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gboolean set = TRUE; + + if (task != NULL) { + if (lua_isboolean(L, 2)) { + set = lua_toboolean(L, 2); + } + + rspamd_task_set_finish_time(task); + gdouble diff = task->time_real_finish - task->task_timestamp; + lua_pushnumber(L, diff); + lua_pushnumber(L, diff); + + if (!set) { + /* Reset to nan to allow further calcs in rspamd_task_set_finish_time */ + task->time_real_finish = NAN; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 2; +} + +static gint +lua_task_get_size(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + lua_pushinteger(L, task->msg.len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/** +* - `no_log`: do not log task summary +* - `no_stat`: do not include task into scanned stats +* - `pass_all`: check all filters for task +* - `extended_urls`: output extended info about urls +* - `skip`: skip task processing +*/ + +#define LUA_TASK_FLAG_WRITE(flag, set) \ + do { \ + task->flags = (set) ? (task->flags | (flag)) : (task->flags & ~(flag)); \ + } while (0) + +#define LUA_TASK_SET_FLAG(flag, strname, macro, set) \ + do { \ + if (!found && strcmp((flag), strname) == 0) { \ + LUA_TASK_FLAG_WRITE((macro), set); \ + found = TRUE; \ + } \ + } while (0) + +#define LUA_TASK_FLAG_READ(flag) \ + do { \ + lua_pushboolean(L, !!(task->flags & (flag))); \ + } while (0) + +#define LUA_TASK_GET_FLAG(flag, strname, macro) \ + do { \ + if (!found && strcmp((flag), strname) == 0) { \ + LUA_TASK_FLAG_READ((macro)); \ + found = TRUE; \ + } \ + } while (0) + +#define LUA_TASK_PROTOCOL_FLAG_READ(flag) \ + do { \ + lua_pushboolean(L, !!(task->protocol_flags & (flag))); \ + } while (0) + +#define LUA_TASK_GET_PROTOCOL_FLAG(flag, strname, macro) \ + do { \ + if (!found && strcmp((flag), strname) == 0) { \ + LUA_TASK_PROTOCOL_FLAG_READ((macro)); \ + found = TRUE; \ + } \ + } while (0) + +static gint +lua_task_set_flag(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *flag = luaL_checkstring(L, 2); + gboolean set = TRUE, found = FALSE; + + if (lua_gettop(L) >= 3) { + set = lua_toboolean(L, 3); + } + + if (task != NULL && flag != NULL) { + LUA_TASK_SET_FLAG(flag, "pass_all", RSPAMD_TASK_FLAG_PASS_ALL, set); + LUA_TASK_SET_FLAG(flag, "no_log", RSPAMD_TASK_FLAG_NO_LOG, set); + LUA_TASK_SET_FLAG(flag, "no_stat", RSPAMD_TASK_FLAG_NO_STAT, set); + LUA_TASK_SET_FLAG(flag, "skip", RSPAMD_TASK_FLAG_SKIP, set); + LUA_TASK_SET_FLAG(flag, "learn_spam", RSPAMD_TASK_FLAG_LEARN_SPAM, set); + LUA_TASK_SET_FLAG(flag, "learn_ham", RSPAMD_TASK_FLAG_LEARN_HAM, set); + LUA_TASK_SET_FLAG(flag, "broken_headers", + RSPAMD_TASK_FLAG_BROKEN_HEADERS, set); + LUA_TASK_SET_FLAG(flag, "greylisted", RSPAMD_TASK_FLAG_GREYLISTED, set); + LUA_TASK_SET_FLAG(flag, "skip_process", RSPAMD_TASK_FLAG_SKIP_PROCESS, set); + LUA_TASK_SET_FLAG(flag, "message_rewrite", RSPAMD_TASK_FLAG_MESSAGE_REWRITE, set); + + if (!found) { + msg_warn_task("unknown flag requested: %s", flag); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_has_flag(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *flag = luaL_checkstring(L, 2); + gboolean found = FALSE; + + if (task != NULL && flag != NULL) { + LUA_TASK_GET_FLAG(flag, "pass_all", RSPAMD_TASK_FLAG_PASS_ALL); + LUA_TASK_GET_FLAG(flag, "no_log", RSPAMD_TASK_FLAG_NO_LOG); + LUA_TASK_GET_FLAG(flag, "no_stat", RSPAMD_TASK_FLAG_NO_STAT); + LUA_TASK_GET_FLAG(flag, "skip", RSPAMD_TASK_FLAG_SKIP); + LUA_TASK_GET_FLAG(flag, "learn_spam", RSPAMD_TASK_FLAG_LEARN_SPAM); + LUA_TASK_GET_FLAG(flag, "learn_ham", RSPAMD_TASK_FLAG_LEARN_HAM); + LUA_TASK_GET_FLAG(flag, "greylisted", RSPAMD_TASK_FLAG_GREYLISTED); + LUA_TASK_GET_FLAG(flag, "broken_headers", + RSPAMD_TASK_FLAG_BROKEN_HEADERS); + LUA_TASK_GET_FLAG(flag, "skip_process", + RSPAMD_TASK_FLAG_SKIP_PROCESS); + LUA_TASK_GET_FLAG(flag, "bad_unicode", + RSPAMD_TASK_FLAG_BAD_UNICODE); + LUA_TASK_GET_FLAG(flag, "mime", + RSPAMD_TASK_FLAG_MIME); + LUA_TASK_GET_FLAG(flag, "message_rewrite", + RSPAMD_TASK_FLAG_MESSAGE_REWRITE); + LUA_TASK_GET_PROTOCOL_FLAG(flag, "milter", + RSPAMD_TASK_PROTOCOL_FLAG_MILTER); + + if (!found) { + msg_warn_task("unknown flag requested: %s", flag); + lua_pushboolean(L, 0); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_flags(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gint idx = 1; + guint flags, bit, i; + + if (task) { + lua_createtable(L, 8, 0); + + flags = task->flags; + + for (i = 0; i <= RSPAMD_TASK_FLAG_MAX_SHIFT; i++) { + bit = (1U << i); + + if (flags & bit) { + switch (bit) { + case RSPAMD_TASK_FLAG_PASS_ALL: + lua_pushstring(L, "pass_all"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_NO_LOG: + lua_pushstring(L, "no_log"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_NO_STAT: + lua_pushstring(L, "no_stat"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_SKIP: + lua_pushstring(L, "skip"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_BROKEN_HEADERS: + lua_pushstring(L, "broken_headers"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_LEARN_SPAM: + lua_pushstring(L, "learn_spam"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_LEARN_HAM: + lua_pushstring(L, "learn_ham"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_GREYLISTED: + lua_pushstring(L, "greylisted"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_SKIP_PROCESS: + lua_pushstring(L, "skip_process"); + lua_rawseti(L, -2, idx++); + break; + case RSPAMD_TASK_FLAG_MESSAGE_REWRITE: + lua_pushstring(L, "message_rewrite"); + lua_rawseti(L, -2, idx++); + break; + default: + break; + } + } + } + + if (task->protocol_flags & RSPAMD_TASK_PROTOCOL_FLAG_MILTER) { + lua_pushstring(L, "milter"); + lua_rawseti(L, -2, idx++); + } + if (task->protocol_flags & RSPAMD_TASK_PROTOCOL_FLAG_BODY_BLOCK) { + lua_pushstring(L, "body_block"); + lua_rawseti(L, -2, idx++); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_digest(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gchar hexbuf[sizeof(MESSAGE_FIELD(task, digest)) * 2 + 1]; + gint r; + + if (task) { + if (task->message) { + r = rspamd_encode_hex_buf(MESSAGE_FIELD(task, digest), + sizeof(MESSAGE_FIELD(task, digest)), + hexbuf, sizeof(hexbuf) - 1); + + if (r > 0) { + hexbuf[r] = '\0'; + lua_pushstring(L, hexbuf); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_learn(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gboolean is_spam = FALSE; + const gchar *clname = NULL; + GError *err = NULL; + int ret = 1; + + if (task == NULL) { + return luaL_error(L, "invalid arguments"); + } + + is_spam = lua_toboolean(L, 2); + if (lua_gettop(L) > 2) { + clname = luaL_checkstring(L, 3); + } + + if (!rspamd_learn_task_spam(task, is_spam, clname, &err)) { + lua_pushboolean(L, FALSE); + if (err != NULL) { + lua_pushstring(L, err->message); + ret = 2; + } + } + else { + lua_pushboolean(L, TRUE); + } + + return ret; +} + +static gint +lua_task_set_settings(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + ucl_object_t *settings; + const ucl_object_t *act, *metric_elt, *vars, *cur; + ucl_object_iter_t it = NULL; + struct rspamd_scan_result *mres; + guint i; + + settings = ucl_object_lua_import(L, 2); + + if (settings != NULL && task != NULL) { + + if (task->settings) { + /* Do not allow to set settings on top of the existing ones */ + ucl_object_unref(settings); + + return luaL_error(L, "invalid invocation: settings has been already set"); + } + + metric_elt = ucl_object_lookup(settings, DEFAULT_METRIC); + + if (metric_elt) { + task->settings = ucl_object_ref(metric_elt); + ucl_object_unref(settings); + } + else { + task->settings = settings; + } + + act = ucl_object_lookup(task->settings, "actions"); + + if (act && ucl_object_type(act) == UCL_OBJECT) { + /* Adjust desired actions */ + mres = task->result; + + it = NULL; + + while ((cur = ucl_object_iterate(act, &it, true)) != NULL) { + const gchar *act_name = ucl_object_key(cur); + struct rspamd_action_config *action_config = NULL; + double act_score; + enum rspamd_action_type act_type; + + if (!rspamd_action_from_str(act_name, &act_type)) { + act_type = -1; + } + + for (i = 0; i < mres->nactions; i++) { + struct rspamd_action_config *cur_act = &mres->actions_config[i]; + + if (cur_act->action->action_type == METRIC_ACTION_CUSTOM && + act_type == -1) { + /* Compare by name */ + if (g_ascii_strcasecmp(act_name, cur_act->action->name) == 0) { + action_config = cur_act; + break; + } + } + else { + if (cur_act->action->action_type == act_type) { + action_config = cur_act; + break; + } + } + } + + if (!action_config) { + act_score = ucl_object_todouble(cur); + if (!isnan(act_score)) { + struct rspamd_action *new_act; + + new_act = rspamd_config_get_action(task->cfg, act_name); + + if (new_act == NULL) { + /* New action! */ + msg_info_task("added new action %s with threshold %.2f " + "due to settings", + act_name, + act_score); + new_act = rspamd_mempool_alloc0(task->task_pool, + sizeof(*new_act)); + new_act->name = rspamd_mempool_strdup(task->task_pool, act_name); + new_act->action_type = METRIC_ACTION_CUSTOM; + new_act->threshold = act_score; + } + else { + /* A disabled action that is enabled */ + msg_info_task("enabled disabled action %s with threshold %.2f " + "due to settings", + act_name, + act_score); + } + + /* Insert it to the mres structure */ + gsize new_actions_cnt = mres->nactions + 1; + struct rspamd_action_config *old_actions = mres->actions_config; + + mres->actions_config = rspamd_mempool_alloc(task->task_pool, + sizeof(struct rspamd_action_config) * new_actions_cnt); + memcpy(mres->actions_config, old_actions, + sizeof(struct rspamd_action_config) * mres->nactions); + mres->actions_config[mres->nactions].action = new_act; + mres->actions_config[mres->nactions].cur_limit = act_score; + mres->nactions++; + } + /* Disabled/missing action is disabled one more time, not an error */ + } + else { + /* Found the existing configured action */ + if (ucl_object_type(cur) == UCL_NULL) { + /* Disable action completely */ + action_config->flags |= RSPAMD_ACTION_RESULT_DISABLED; + msg_info_task("disabled action %s due to settings", + action_config->action->name); + } + else { + act_score = ucl_object_todouble(cur); + if (isnan(act_score)) { + msg_info_task("disabled action %s threshold (was %.2f) due to settings", + action_config->action->name, + action_config->cur_limit); + action_config->flags |= RSPAMD_ACTION_RESULT_NO_THRESHOLD; + } + else { + action_config->cur_limit = act_score; + msg_debug_task("adjusted action %s: %.2f -> %.2f", + act_name, + action_config->cur_limit, + act_score); + } + } + } + } + } + + vars = ucl_object_lookup(task->settings, "variables"); + if (vars && ucl_object_type(vars) == UCL_OBJECT) { + /* Set memory pool variables */ + it = NULL; + + while ((cur = ucl_object_iterate(vars, &it, true)) != NULL) { + if (ucl_object_type(cur) == UCL_STRING) { + rspamd_mempool_set_variable(task->task_pool, + ucl_object_key(cur), rspamd_mempool_strdup(task->task_pool, ucl_object_tostring(cur)), NULL); + } + } + } + + rspamd_symcache_process_settings(task, task->cfg->cache); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_set_milter_reply(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + ucl_object_t *reply, *prev; + + reply = ucl_object_lua_import(L, 2); + + if (reply != NULL && task != NULL) { + prev = rspamd_mempool_get_variable(task->task_pool, + RSPAMD_MEMPOOL_MILTER_REPLY); + + if (prev) { + /* + * We need to be very special about the add_headers part + * If we want to insert some existing object, such as + * add_headers = { + * hdr = {value = val1, order = 1}, + * } + * + * and new header has something similar: + * add_headers = { + * hdr = {value = val2, order = 1}, + * } + * + * then we need to convert it to an array... + * + * add_headers = { + * hdr = [{value = val1, order = 1}, {value = val2, order = 1}], + * } + * + * UCL itself cannot do it directly. So the trick is to extract the + * original object, pack it into an array and then insert it back. + * + * I wish there was a simpler way to do it... + */ + const ucl_object_t *add_hdrs = ucl_object_lookup(prev, "add_headers"); + const ucl_object_t *nadd_hdrs = ucl_object_lookup(reply, "add_headers"); + + if (add_hdrs && nadd_hdrs) { + ucl_object_iter_t it = NULL; + const ucl_object_t *cur; + + while ((cur = ucl_object_iterate(nadd_hdrs, &it, true)) != NULL) { + gsize klen; + const gchar *key = ucl_object_keyl(cur, &klen); + const ucl_object_t *existing; + + existing = ucl_object_lookup_len(add_hdrs, key, klen); + + if (existing && ucl_object_type(existing) != UCL_ARRAY) { + ucl_object_t *ar = ucl_object_typed_new(UCL_ARRAY); + + ucl_array_append(ar, ucl_object_ref(existing)); + /* Avoid double refcount */ + key = ucl_object_keyl(existing, &klen); + ucl_object_delete_keyl((ucl_object_t *) add_hdrs, key, klen); + ucl_object_insert_key((ucl_object_t *) add_hdrs, + ar, key, klen, false); + } + } + } + + if (!ucl_object_merge(prev, reply, false)) { + msg_err_task("internal error: cannot merge two objects when setting milter reply!"); + } + ucl_object_unref(reply); + } + else { + rspamd_mempool_set_variable(task->task_pool, + RSPAMD_MEMPOOL_MILTER_REPLY, + reply, + (rspamd_mempool_destruct_t) ucl_object_unref); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_get_settings(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + + if (task->settings) { + return ucl_object_push_lua(L, task->settings, true); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_lookup_settings(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *key = NULL; + const ucl_object_t *elt; + + if (task != NULL) { + + if (lua_isstring(L, 2)) { + key = lua_tostring(L, 2); + } + + if (task->settings) { + if (key == NULL) { + return ucl_object_push_lua(L, task->settings, true); + } + else { + elt = ucl_object_lookup(task->settings, key); + + if (elt) { + return ucl_object_push_lua(L, elt, true); + } + else { + lua_pushnil(L); + } + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_settings_id(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task != NULL) { + + if (task->settings_elt) { + lua_pushinteger(L, task->settings_elt->id); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_settings_id(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + guint32 id = lua_tointeger(L, 2); + + if (task != NULL && id != 0) { + + struct rspamd_config_settings_elt *selt = + rspamd_config_find_settings_id_ref(task->cfg, id); + + if (selt == NULL) { + return luaL_error(L, "settings id %f is unknown", (lua_Number) id); + } + if (task->settings_elt) { + /* Overwrite existing settings from Lua */ + REF_RELEASE(task->settings_elt); + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + + task->settings_elt = selt; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_cache_get(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *key = luaL_checkstring(L, 2); + + if (task && key) { + if (!lua_task_get_cached(L, task, key)) { + lua_pushnil(L); + } + } + else { + luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_cache_set(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *key = luaL_checkstring(L, 2); + + if (task && key && lua_gettop(L) >= 3) { + lua_task_set_cached(L, task, key, 3); + } + else { + luaL_error(L, "invalid arguments"); + } + + return 0; +} + +struct lua_file_cbdata { + gchar *fname; + gint fd; + gboolean keep; +}; + +static void +lua_tmp_file_dtor(gpointer p) +{ + struct lua_file_cbdata *cbdata = p; + + if (!cbdata->keep) { + unlink(cbdata->fname); + } + + close(cbdata->fd); +} + +static gint +lua_task_store_in_file(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gboolean force_new = FALSE, keep = FALSE; + gchar fpath[PATH_MAX]; + const gchar *tmpmask = NULL, *fname = NULL; + guint mode = 00600; + gint fd; + struct lua_file_cbdata *cbdata; + GError *err = NULL; + + if (task) { + if (lua_istable(L, 2)) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "filename=S;tmpmask=S;mode=I;force_new=B;keep=B", + &fname, &tmpmask, &mode, &force_new, &keep)) { + msg_err_task("cannot get parameters list: %e", err); + + if (err) { + g_error_free(err); + } + + return luaL_error(L, "invalid arguments"); + } + } + else if (lua_isnumber(L, 2)) { + mode = lua_tointeger(L, 2); + } + + if (!force_new && (task->flags & RSPAMD_TASK_FLAG_FILE) && + task->msg.fpath) { + lua_pushstring(L, task->msg.fpath); + } + else { + if (fname == NULL) { + if (tmpmask == NULL) { + rspamd_snprintf(fpath, sizeof(fpath), "%s%c%s", + task->cfg->temp_dir, + G_DIR_SEPARATOR, "rmsg-XXXXXXXXXX"); + } + else { + rspamd_snprintf(fpath, sizeof(fpath), "%s", tmpmask); + } + + fd = g_mkstemp_full(fpath, O_WRONLY | O_CREAT | O_EXCL, mode); + fname = fpath; + + if (fd != -1) { + fchmod(fd, mode); + } + } + else { + fd = rspamd_file_xopen(fname, O_WRONLY | O_CREAT | O_EXCL, + (guint) mode, FALSE); + } + + if (fd == -1) { + msg_err_task("cannot save file: %s", strerror(errno)); + lua_pushnil(L); + } + else { + if (write(fd, task->msg.begin, task->msg.len) == -1) { + msg_err_task("cannot write file %s: %s", fpath, + strerror(errno)); + unlink(fname); + close(fd); + lua_pushnil(L); + + return 1; + } + + cbdata = rspamd_mempool_alloc(task->task_pool, sizeof(*cbdata)); + cbdata->fd = fd; + cbdata->fname = rspamd_mempool_strdup(task->task_pool, fname); + cbdata->keep = keep; + lua_pushstring(L, cbdata->fname); + rspamd_mempool_add_destructor(task->task_pool, + lua_tmp_file_dtor, cbdata); + } + } + } + else { + luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_process_regexp(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_lua_regexp *re = NULL; + gboolean strong = FALSE; + const gchar *type_str = NULL, *header_str = NULL; + gsize header_len = 0; + GError *err = NULL; + gint ret = 0; + enum rspamd_re_type type = RSPAMD_RE_BODY; + + /* + * - `re`* : regular expression object + * - `type`*: type of regular expression: + * + `mime`: mime regexp + * + `rawmime`: raw mime regexp + * + `header`: header regexp + * + `rawheader`: raw header expression + * + `body`: raw body regexp + * + `url`: url regexp + * - `header`: for header and rawheader regexp means the name of header + * - `strong`: case sensitive match for headers + */ + if (task != NULL) { + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "*re=U{regexp};*type=S;header=V;strong=B", + &re, &type_str, &header_len, &header_str, + &strong)) { + msg_err_task("cannot get parameters list: %e", err); + + if (err) { + g_error_free(err); + } + + return luaL_error(L, "invalid arguments"); + } + else { + type = rspamd_re_cache_type_from_string(type_str); + + if ((type == RSPAMD_RE_HEADER || type == RSPAMD_RE_RAWHEADER) && header_str == NULL) { + msg_err_task( + "header argument is mandatory for header/rawheader regexps"); + } + else { + ret = rspamd_re_cache_process(task, re->re, type, + (gpointer) header_str, header_len, strong); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushinteger(L, ret); + + return 1; +} + +static gint +lua_task_get_metric_result(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_scan_result *metric_res; + struct rspamd_action *action; + + if (task) { + metric_res = task->result; + + if (lua_isstring(L, 2)) { + metric_res = rspamd_find_metric_result(task, lua_tostring(L, 2)); + + if (metric_res == NULL) { + lua_pushnil(L); + + return 1; + } + } + + /* Fields added: + * - `score`: current score + * - `action`: current action as a string + * - `nnegative`: number of negative rules matched + * - `npositive`: number of positive rules matched + * - `positive_score`: total score for positive rules + * - `negative_score`: total score for negative rules + * - `passthrough`: set to true if message has a passthrough result + */ + lua_createtable(L, 0, 7); + + lua_pushstring(L, "score"); + lua_pushnumber(L, metric_res->score); + lua_settable(L, -3); + + action = rspamd_check_action_metric(task, NULL, metric_res); + + if (action) { + lua_pushstring(L, "action"); + lua_pushstring(L, action->name); + lua_settable(L, -3); + } + + lua_pushstring(L, "nnegative"); + lua_pushnumber(L, metric_res->nnegative); + lua_settable(L, -3); + + lua_pushstring(L, "npositive"); + lua_pushnumber(L, metric_res->npositive); + lua_settable(L, -3); + + lua_pushstring(L, "positive_score"); + lua_pushnumber(L, metric_res->positive_score); + lua_settable(L, -3); + + lua_pushstring(L, "negative_score"); + lua_pushnumber(L, metric_res->negative_score); + lua_settable(L, -3); + + lua_pushstring(L, "passthrough"); + lua_pushboolean(L, !!(metric_res->passthrough_result != NULL)); + lua_settable(L, -3); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_metric_score(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + gdouble rs; + struct rspamd_scan_result *metric_res; + + if (task) { + metric_res = task->result; + + if (lua_isstring(L, 2)) { + metric_res = rspamd_find_metric_result(task, lua_tostring(L, 2)); + } + + if (metric_res != NULL) { + lua_createtable(L, 2, 0); + lua_pushnumber(L, isnan(metric_res->score) ? 0.0 : metric_res->score); + rs = rspamd_task_get_required_score(task, metric_res); + lua_rawseti(L, -2, 1); + lua_pushnumber(L, rs); + lua_rawseti(L, -2, 2); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_metric_action(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_action *action; + + if (task) { + struct rspamd_scan_result *mres = task->result; + + if (lua_isstring(L, 2)) { + mres = rspamd_find_metric_result(task, lua_tostring(L, 2)); + } + + if (mres == NULL) { + lua_pushnil(L); + + return 1; + } + + action = rspamd_check_action_metric(task, NULL, mres); + lua_pushstring(L, action->name); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_metric_score(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_scan_result *metric_res; + gdouble nscore; + + if (lua_isnumber(L, 2)) { + nscore = luaL_checknumber(L, 2); + } + else { + nscore = luaL_checknumber(L, 3); + } + + if (task) { + metric_res = task->result; + + if (lua_isstring(L, 4)) { + metric_res = rspamd_find_metric_result(task, lua_tostring(L, 4)); + } + + if (metric_res != NULL) { + msg_debug_task("set metric score from %.2f to %.2f", + metric_res->score, nscore); + metric_res->score = nscore; + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_disable_action(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *action_name; + struct rspamd_action_config *action_res; + + action_name = luaL_checkstring(L, 2); + + if (task && action_name) { + + for (guint i = 0; i < task->result->nactions; i++) { + action_res = &task->result->actions_config[i]; + + if (strcmp(action_name, action_res->action->name) == 0) { + if (isnan(action_res->cur_limit)) { + lua_pushboolean(L, false); + } + else { + action_res->cur_limit = NAN; + lua_pushboolean(L, true); + } + + break; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_newlines_type(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + if (task->message) { + switch (MESSAGE_FIELD(task, nlines_type)) { + case RSPAMD_TASK_NEWLINES_CR: + lua_pushstring(L, "cr"); + break; + case RSPAMD_TASK_NEWLINES_LF: + lua_pushstring(L, "lf"); + break; + case RSPAMD_TASK_NEWLINES_CRLF: + default: + lua_pushstring(L, "crlf"); + break; + } + } + else { + lua_pushstring(L, "crlf"); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static void +lua_push_stat_token(lua_State *L, rspamd_token_t *tok) +{ + gchar numbuf[64]; + + /* Table values + * - `data`: 64 bit number encoded as a string + * - `t1`: the first token (if any) + * - `t2`: the second token (if any) + * - `win`: window index + * - `flag`: table of strings: + * - `text`: text token + * - `meta`: meta token + * - `lua`: lua meta token + * - `exception`: exception + * - `subject`: subject token + * - `unigram`: unigram token + */ + lua_createtable(L, 0, 5); + + rspamd_snprintf(numbuf, sizeof(numbuf), "%uL", tok->data); + lua_pushstring(L, "data"); + lua_pushstring(L, numbuf); + lua_settable(L, -3); + + if (tok->t1) { + lua_pushstring(L, "t1"); + lua_pushlstring(L, tok->t1->stemmed.begin, tok->t1->stemmed.len); + lua_settable(L, -3); + } + + if (tok->t2) { + lua_pushstring(L, "t2"); + lua_pushlstring(L, tok->t2->stemmed.begin, tok->t2->stemmed.len); + lua_settable(L, -3); + } + + lua_pushstring(L, "win"); + lua_pushinteger(L, tok->window_idx); + lua_settable(L, -3); + + lua_pushstring(L, "flags"); + lua_createtable(L, 0, 5); + + /* Flags */ + { + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_TEXT) { + lua_pushstring(L, "text"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_META) { + lua_pushstring(L, "meta"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_LUA_META) { + lua_pushstring(L, "lua"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_EXCEPTION) { + lua_pushstring(L, "exception"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + if (tok->flags & RSPAMD_STAT_TOKEN_FLAG_HEADER) { + lua_pushstring(L, "header"); + lua_pushboolean(L, true); + lua_settable(L, -3); + } + } + lua_settable(L, -3); +} + +static gint +lua_task_get_stat_tokens(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + guint i; + rspamd_token_t *tok; + + if (task) { + if (!task->tokens) { + rspamd_stat_process_tokenize(NULL, task); + } + + if (!task->tokens) { + lua_pushnil(L); + } + else { + lua_createtable(L, task->tokens->len, 0); + + PTR_ARRAY_FOREACH(task->tokens, i, tok) + { + lua_push_stat_token(L, tok); + lua_rawseti(L, -2, i + 1); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_set_metric_subject(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *subject; + + subject = luaL_checkstring(L, 2); + + if (task && subject) { + rspamd_mempool_set_variable(task->task_pool, "metric_subject", + rspamd_mempool_strdup(task->task_pool, subject), NULL); + lua_pushboolean(L, true); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_protocol_reply(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + guint flags = 0; + ucl_object_t *obj; + + if (!task) { + return luaL_error(L, "invalid arguments"); + } + + if (!(task->processed_stages & (RSPAMD_TASK_STAGE_POST_FILTERS >> 1))) { + return luaL_error(L, "must not be called before post-filters"); + } + + if (lua_istable(L, 2)) { + for (lua_pushnil(L); lua_next(L, 2); lua_pop(L, 1)) { + if (lua_isstring(L, -1)) { + const gchar *str = lua_tostring(L, -1); + + if (strcmp(str, "default") == 0) { + flags |= RSPAMD_PROTOCOL_DEFAULT; + } + else if (strcmp(str, "basic") == 0) { + flags |= RSPAMD_PROTOCOL_BASIC; + } + else if (strcmp(str, "metrics") == 0) { + flags |= RSPAMD_PROTOCOL_METRICS; + } + else if (strcmp(str, "messages") == 0) { + flags |= RSPAMD_PROTOCOL_MESSAGES; + } + else if (strcmp(str, "rmilter") == 0) { + flags |= RSPAMD_PROTOCOL_RMILTER; + } + else if (strcmp(str, "dkim") == 0) { + flags |= RSPAMD_PROTOCOL_DKIM; + } + else if (strcmp(str, "extra") == 0) { + flags |= RSPAMD_PROTOCOL_EXTRA; + } + else { + msg_err_task("invalid protocol flag: %s", str); + } + } + } + } + else { + flags = RSPAMD_PROTOCOL_DEFAULT; + } + + obj = rspamd_protocol_write_ucl(task, flags); + + if (obj) { + ucl_object_push_lua(L, obj, true); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_task_headers_foreach(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + enum rspamd_lua_task_header_type how = RSPAMD_TASK_HEADER_PUSH_SIMPLE; + struct rspamd_lua_regexp *re = NULL; + struct rspamd_mime_header *hdr, *cur; + gint old_top; + + if (task && lua_isfunction(L, 2)) { + if (task->message) { + if (lua_istable(L, 3)) { + lua_pushstring(L, "full"); + lua_gettable(L, 3); + + if (lua_isboolean(L, -1) && lua_toboolean(L, -1)) { + how = RSPAMD_TASK_HEADER_PUSH_FULL; + } + + lua_pop(L, 1); + + lua_pushstring(L, "raw"); + lua_gettable(L, 3); + + if (lua_isboolean(L, -1) && lua_toboolean(L, -1)) { + how = RSPAMD_TASK_HEADER_PUSH_RAW; + } + + lua_pop(L, 1); + + lua_pushstring(L, "regexp"); + lua_gettable(L, 3); + + if (lua_isuserdata(L, -1)) { + RSPAMD_LUA_CHECK_UDATA_PTR_OR_RETURN(L, -1, "rspamd{regexp}", + struct rspamd_lua_regexp, re); + } + + lua_pop(L, 1); + } + + if (MESSAGE_FIELD(task, headers_order)) { + hdr = MESSAGE_FIELD(task, headers_order); + + LL_FOREACH2(hdr, cur, ord_next) + { + if (re && re->re) { + if (!rspamd_regexp_match(re->re, cur->name, + strlen(cur->name), FALSE)) { + continue; + } + } + + old_top = lua_gettop(L); + lua_pushvalue(L, 2); + lua_pushstring(L, cur->name); + rspamd_lua_push_header(L, cur, how); + + if (lua_pcall(L, 2, LUA_MULTRET, 0) != 0) { + msg_err("call to header_foreach failed: %s", + lua_tostring(L, -1)); + lua_settop(L, old_top); + break; + } + else { + if (lua_gettop(L) > old_top) { + if (lua_isboolean(L, old_top + 1)) { + if (lua_toboolean(L, old_top + 1)) { + lua_settop(L, old_top); + break; + } + } + } + } + + lua_settop(L, old_top); + } + } + } /* if (task->message) */ + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_modify_header(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *hname = luaL_checkstring(L, 2); + + if (hname && task && lua_type(L, 3) == LUA_TTABLE) { + if (task->message) { + ucl_object_t *mods = ucl_object_lua_import(L, 3); + + rspamd_message_set_modified_header(task, + MESSAGE_FIELD(task, raw_headers), + hname, + mods, + &(MESSAGE_FIELD(task, headers_order))); + ucl_object_unref(mods); + + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_get_meta_words(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + enum rspamd_lua_words_type how = RSPAMD_LUA_WORDS_STEM; + + if (task == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (task->meta_words == NULL) { + lua_createtable(L, 0, 0); + } + else { + if (lua_type(L, 2) == LUA_TSTRING) { + const gchar *how_str = lua_tostring(L, 2); + + if (strcmp(how_str, "stem") == 0) { + how = RSPAMD_LUA_WORDS_STEM; + } + else if (strcmp(how_str, "norm") == 0) { + how = RSPAMD_LUA_WORDS_NORM; + } + else if (strcmp(how_str, "raw") == 0) { + how = RSPAMD_LUA_WORDS_RAW; + } + else if (strcmp(how_str, "full") == 0) { + how = RSPAMD_LUA_WORDS_FULL; + } + else { + return luaL_error(L, "unknown words type: %s", how_str); + } + } + + return rspamd_lua_push_words(L, task->meta_words, how); + } + + return 1; +} + +static guint +lua_lookup_words_array(lua_State *L, + gint cbpos, + struct rspamd_task *task, + struct rspamd_lua_map *map, + GArray *words) +{ + rspamd_stat_token_t *tok; + guint i, nmatched = 0; + gint err_idx; + gboolean matched; + const gchar *key; + gsize keylen; + + for (i = 0; i < words->len; i++) { + tok = &g_array_index(words, rspamd_stat_token_t, i); + + matched = FALSE; + + if (tok->normalized.len == 0) { + continue; + } + + key = tok->normalized.begin; + keylen = tok->normalized.len; + + switch (map->type) { + case RSPAMD_LUA_MAP_SET: + case RSPAMD_LUA_MAP_HASH: + /* We know that tok->normalized is zero terminated in fact */ + if (rspamd_match_hash_map(map->data.hash, key, keylen)) { + matched = TRUE; + } + break; + case RSPAMD_LUA_MAP_REGEXP: + case RSPAMD_LUA_MAP_REGEXP_MULTIPLE: + if (rspamd_match_regexp_map_single(map->data.re_map, key, + keylen)) { + matched = TRUE; + } + break; + default: + g_assert_not_reached(); + break; + } + + if (matched) { + nmatched++; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_pushvalue(L, cbpos); /* Function */ + rspamd_lua_push_full_word(L, tok); + + if (lua_pcall(L, 1, 0, err_idx) != 0) { + msg_err_task("cannot call callback function for lookup words: %s", + lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); + } + } + + return nmatched; +} + +static gint +lua_task_lookup_words(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + struct rspamd_lua_map *map = lua_check_map(L, 2); + struct rspamd_mime_text_part *tp; + + guint i, matches = 0; + + if (task == NULL || map == NULL || task->message == NULL || lua_type(L, 3) != LUA_TFUNCTION) { + return luaL_error(L, "invalid arguments"); + } + + if (map->type != RSPAMD_LUA_MAP_SET && + map->type != RSPAMD_LUA_MAP_REGEXP && + map->type != RSPAMD_LUA_MAP_HASH && + map->type != RSPAMD_LUA_MAP_REGEXP_MULTIPLE) { + return luaL_error(L, "invalid map type"); + } + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, tp) + { + if (tp->utf_words) { + matches += lua_lookup_words_array(L, 3, task, map, tp->utf_words); + } + } + + if (task->meta_words) { + matches += lua_lookup_words_array(L, 3, task, map, task->meta_words); + } + + lua_pushinteger(L, matches); + + return 1; +} + +static gint +lua_task_topointer(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + /* XXX: this might cause issues on arm64 and LuaJIT */ + lua_pushlightuserdata(L, task); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_task_add_named_result(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + const gchar *name = luaL_checkstring(L, 2); + gint cbref; + + if (task && name && lua_isfunction(L, 3)) { + lua_pushvalue(L, 3); + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + rspamd_create_metric_result(task, name, cbref); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_task_get_all_named_results(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_task *task = lua_check_task(L, 1); + + if (task) { + gint n = 0; + struct rspamd_scan_result *res; + + DL_COUNT(task->result, res, n); + lua_createtable(L, n, 0); + n = 1; + + DL_FOREACH(task->result, res) + { + if (res->name != NULL) { + lua_pushstring(L, res->name); + } + else { + lua_pushstring(L, DEFAULT_METRIC); + } + + lua_rawseti(L, -2, n++); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +/* Image functions */ +static gint +lua_image_get_width(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_image *img = lua_check_image(L); + + if (img != NULL) { + lua_pushinteger(L, img->width); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_image_get_height(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_image *img = lua_check_image(L); + + if (img != NULL) { + lua_pushinteger(L, img->height); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_image_get_type(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_image *img = lua_check_image(L); + + if (img != NULL) { + lua_pushstring(L, rspamd_image_type_str(img->type)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_image_get_size(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_image *img = lua_check_image(L); + + if (img != NULL) { + lua_pushinteger(L, img->data->len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_image_get_filename(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_image *img = lua_check_image(L); + + if (img != NULL) { + if (img->filename != NULL) { + lua_pushlstring(L, img->filename->begin, img->filename->len); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/* Archive methods */ +static gint +lua_archive_get_type(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + + if (arch != NULL) { + lua_pushstring(L, rspamd_archive_type_str(arch->type)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_archive_get_files(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + guint i, max_files = 0; + struct rspamd_archive_file *f; + + if (arch != NULL) { + if (lua_isnumber(L, 2)) { + max_files = lua_tointeger(L, 2); + max_files = MIN(arch->files->len, max_files); + } + else { + max_files = arch->files->len; + } + + lua_createtable(L, max_files, 0); + + for (i = 0; i < max_files; i++) { + f = g_ptr_array_index(arch->files, i); + + lua_pushlstring(L, f->fname->str, f->fname->len); + lua_rawseti(L, -2, i + 1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_archive_get_files_full(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + guint i, max_files = 0; + struct rspamd_archive_file *f; + + if (arch != NULL) { + if (lua_isnumber(L, 2)) { + max_files = lua_tointeger(L, 2); + max_files = MIN(arch->files->len, max_files); + } + else { + max_files = arch->files->len; + } + + lua_createtable(L, max_files, 0); + + for (i = 0; i < max_files; i++) { + f = g_ptr_array_index(arch->files, i); + + lua_createtable(L, 0, 4); + + lua_pushstring(L, "name"); + lua_pushlstring(L, f->fname->str, f->fname->len); + lua_settable(L, -3); + + lua_pushstring(L, "compressed_size"); + lua_pushinteger(L, f->compressed_size); + lua_settable(L, -3); + + lua_pushstring(L, "uncompressed_size"); + lua_pushinteger(L, f->uncompressed_size); + lua_settable(L, -3); + + lua_pushstring(L, "encrypted"); + lua_pushboolean(L, (f->flags & RSPAMD_ARCHIVE_FILE_ENCRYPTED) ? true : false); + lua_settable(L, -3); + + lua_rawseti(L, -2, i + 1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_archive_is_encrypted(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + + if (arch != NULL) { + lua_pushboolean(L, (arch->flags & RSPAMD_ARCHIVE_ENCRYPTED) ? true : false); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_archive_is_obfuscated(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + + if (arch != NULL) { + lua_pushboolean(L, + (arch->flags & RSPAMD_ARCHIVE_HAS_OBFUSCATED_FILES) ? true : false); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_archive_is_unreadable(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + + if (arch != NULL) { + lua_pushboolean(L, (arch->flags & RSPAMD_ARCHIVE_CANNOT_READ) ? true : false); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_archive_get_size(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + + if (arch != NULL) { + lua_pushinteger(L, arch->size); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_archive_get_filename(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_archive *arch = lua_check_archive(L); + + if (arch != NULL) { + lua_pushlstring(L, arch->archive_name->begin, arch->archive_name->len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/* Init part */ + +static gint +lua_load_task(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, tasklib_f); + + return 1; +} + +static void +luaopen_archive(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{archive}", archivelib_m); + lua_pop(L, 1); +} + +void luaopen_task(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{task}", tasklib_m); + lua_pop(L, 1); + + rspamd_lua_add_preload(L, "rspamd_task", lua_load_task); + + luaopen_archive(L); +} + +void luaopen_image(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{image}", imagelib_m); + lua_pop(L, 1); +} + +void rspamd_lua_task_push(lua_State *L, struct rspamd_task *task) +{ + struct rspamd_task **ptask; + + ptask = lua_newuserdata(L, sizeof(gpointer)); + rspamd_lua_setclass(L, "rspamd{task}", -1); + *ptask = task; +} diff --git a/src/lua/lua_tcp.c b/src/lua/lua_tcp.c new file mode 100644 index 0000000..45faa79 --- /dev/null +++ b/src/lua/lua_tcp.c @@ -0,0 +1,2566 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_thread_pool.h" +#include "libserver/ssl_util.h" +#include "utlist.h" +#include "unix-std.h" +#include <math.h> + +static const gchar *M = "rspamd lua tcp"; + +/*** + * @module rspamd_tcp + * Rspamd TCP module represents generic TCP asynchronous client available from LUA code. + * This module hides all complexity: DNS resolving, sessions management, zero-copy + * text transfers and so on under the hood. It can work in partial or complete modes: + * + * - partial mode is used when you need to call a continuation routine each time data is available for read + * - complete mode calls for continuation merely when all data is read from socket (e.g. when a server sends reply and closes a connection) + * @example +local logger = require "rspamd_logger" +local tcp = require "rspamd_tcp" + +rspamd_config.SYM = function(task) + + local function cb(err, data) + logger.infox('err: %1, data: %2', err, tostring(data)) + end + + tcp.request({ + task = task, + host = "google.com", + port = 80, + data = {"GET / HTTP/1.0\r\n", "Host: google.com\r\n", "\r\n"}, + callback = cb}) +end + +-- New TCP syntax test +rspamd_config:register_symbol({ + name = 'TCP_TEST', + type = "normal", + callback = function(task) + local logger = require "rspamd_logger" + local function rcpt_done_cb(err, data, conn) + logger.errx(task, 'RCPT: got reply: %s, error: %s', data, err) + conn:close() + end + local function rcpt_cb(err, conn) + logger.errx(task, 'written rcpt, error: %s', err) + conn:add_read(rcpt_done_cb, '\r\n') + end + local function from_done_cb(err, data, conn) + logger.errx(task, 'FROM: got reply: %s, error: %s', data, err) + conn:add_write(rcpt_cb, 'RCPT TO: <test@yandex.ru>\r\n') + end + local function from_cb(err, conn) + logger.errx(task, 'written from, error: %s', err) + conn:add_read(from_done_cb, '\r\n') + end + local function hello_done_cb(err, data, conn) + logger.errx(task, 'HELO: got reply: %s, error: %s', data, err) + conn:add_write(from_cb, 'MAIL FROM: <>\r\n') + end + local function hello_cb(err, conn) + logger.errx(task, 'written hello, error: %s', err) + conn:add_read(hello_done_cb, '\r\n') + end + local function init_cb(err, data, conn) + logger.errx(task, 'got reply: %s, error: %s', data, err) + conn:add_write(hello_cb, 'HELO example.com\r\n') + end + tcp.request{ + task = task, + callback = init_cb, + stop_pattern = '\r\n', + host = 'mx.yandex.ru', + port = 25 + } + end, + priority = 10, +}) + */ + +LUA_FUNCTION_DEF(tcp, request); + +/*** + * @function rspamd_tcp.connect_sync() + * + * Creates pseudo-synchronous TCP connection. + * Each method of the connection requiring IO, becomes a yielding point, + * i.e. current thread Lua thread is get suspended and resumes as soon as IO is done + * + * This class represents low-level API, using of "lua_tcp_sync" module is recommended. + * + * @example + +local rspamd_tcp = require "rspamd_tcp" +local logger = require "rspamd_logger" + +local function http_simple_tcp_symbol(task) + + local err + local is_ok, connection = rspamd_tcp.connect_sync { + task = task, + host = '127.0.0.1', + timeout = 20, + port = 18080, + ssl = false, -- If SSL connection is needed + ssl_verify = true, -- set to false if verify is not needed + } + + is_ok, err = connection:write('GET /request_sync HTTP/1.1\r\nConnection: keep-alive\r\n\r\n') + + logger.errx(task, 'write %1, %2', is_ok, err) + if not is_ok then + logger.errx(task, 'write error: %1', err) + end + + local data + is_ok, data = connection:read_once(); + + logger.errx(task, 'read_once: is_ok: %1, data: %2', is_ok, data) + + is_ok, err = connection:write("POST /request2 HTTP/1.1\r\n\r\n") + logger.errx(task, 'write[2] %1, %2', is_ok, err) + + is_ok, data = connection:read_once(); + logger.errx(task, 'read_once[2]: is_ok %1, data: %2', is_ok, data) + + connection:close() +end + +rspamd_config:register_symbol({ + name = 'SIMPLE_TCP_TEST', + score = 1.0, + callback = http_simple_tcp_symbol, + no_squeeze = true +}) + * + */ +LUA_FUNCTION_DEF(tcp, connect_sync); + +/*** + * @method tcp:close() + * + * Closes TCP connection + */ +LUA_FUNCTION_DEF(tcp, close); + +/*** + * @method tcp:add_read(callback, [pattern]) + * + * Adds new read event to the tcp connection + * @param {function} callback to be called when data is read + * @param {string} pattern optional stop pattern + */ +LUA_FUNCTION_DEF(tcp, add_read); + +/*** + * @method tcp:add_write(callback, data) + * + * Adds new write event to the tcp connection + * @param {function} optional callback to be called when data is completely written + * @param {table/string/text} data to send to a remote server + */ +LUA_FUNCTION_DEF(tcp, add_write); + +/*** + * @method tcp:shift_callback() + * + * Shifts the current callback and go to the next one (if any) + */ +LUA_FUNCTION_DEF(tcp, shift_callback); + +/*** + * @method tcp:starttls([no_verify]) + * + * Starts tls connection + * @param {boolean} no_verify used to skip ssl verification + */ +LUA_FUNCTION_DEF(tcp, starttls); + +static const struct luaL_reg tcp_libf[] = { + LUA_INTERFACE_DEF(tcp, request), + {"new", lua_tcp_request}, + {"connect", lua_tcp_request}, + {"connect_sync", lua_tcp_connect_sync}, + {NULL, NULL}}; + +static const struct luaL_reg tcp_libm[] = { + LUA_INTERFACE_DEF(tcp, close), + LUA_INTERFACE_DEF(tcp, add_read), + LUA_INTERFACE_DEF(tcp, add_write), + LUA_INTERFACE_DEF(tcp, shift_callback), + LUA_INTERFACE_DEF(tcp, starttls), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +/*** + * @method tcp:close() + * + * Closes TCP connection + */ +LUA_FUNCTION_DEF(tcp_sync, close); + +/*** + * @method read_once() + * + * Performs one read operation. If syscall returned with EAGAIN/EINT, + * restarts the operation, so it always returns either data or error. + */ +LUA_FUNCTION_DEF(tcp_sync, read_once); + +/*** + * @method eof() + * + * True if last IO operation ended with EOF, i.e. endpoint closed connection + */ +LUA_FUNCTION_DEF(tcp_sync, eof); + +/*** + * @method shutdown() + * + * Half-shutdown TCP connection + */ +LUA_FUNCTION_DEF(tcp_sync, shutdown); + +/*** + * @method write() + * + * Writes data into the stream. If syscall returned with EAGAIN/EINT + * restarts the operation. If performs write() until all the passed + * data is written completely. + */ +LUA_FUNCTION_DEF(tcp_sync, write); + +LUA_FUNCTION_DEF(tcp_sync, gc); + +static void lua_tcp_sync_session_dtor(gpointer ud); + +static const struct luaL_reg tcp_sync_libm[] = { + LUA_INTERFACE_DEF(tcp_sync, close), + LUA_INTERFACE_DEF(tcp_sync, read_once), + LUA_INTERFACE_DEF(tcp_sync, write), + LUA_INTERFACE_DEF(tcp_sync, eof), + LUA_INTERFACE_DEF(tcp_sync, shutdown), + {"__gc", lua_tcp_sync_gc}, + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +struct lua_tcp_read_handler { + gchar *stop_pattern; + guint plen; + gint cbref; +}; + +struct lua_tcp_write_handler { + struct iovec *iov; + guint iovlen; + gint cbref; + gsize pos; + gsize total_bytes; +}; + +enum lua_tcp_handler_type { + LUA_WANT_WRITE = 0, + LUA_WANT_READ, + LUA_WANT_CONNECT +}; + +struct lua_tcp_handler { + union { + struct lua_tcp_read_handler r; + struct lua_tcp_write_handler w; + } h; + enum lua_tcp_handler_type type; +}; + +struct lua_tcp_dtor { + rspamd_mempool_destruct_t dtor; + void *data; + struct lua_tcp_dtor *next; +}; + +#define LUA_TCP_FLAG_PARTIAL (1u << 0u) +#define LUA_TCP_FLAG_SHUTDOWN (1u << 2u) +#define LUA_TCP_FLAG_CONNECTED (1u << 3u) +#define LUA_TCP_FLAG_FINISHED (1u << 4u) +#define LUA_TCP_FLAG_SYNC (1u << 5u) +#define LUA_TCP_FLAG_RESOLVED (1u << 6u) +#define LUA_TCP_FLAG_SSL (1u << 7u) +#define LUA_TCP_FLAG_SSL_NOVERIFY (1u << 8u) + +#undef TCP_DEBUG_REFS +#ifdef TCP_DEBUG_REFS +#define TCP_RETAIN(x) \ + do { \ + msg_info("retain ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RETAIN(x); \ + } while (0) + +#define TCP_RELEASE(x) \ + do { \ + msg_info("release ref %p, refcount: %d", (x), (x)->ref.refcount); \ + REF_RELEASE(x); \ + } while (0) +#else +#define TCP_RETAIN(x) REF_RETAIN(x) +#define TCP_RELEASE(x) REF_RELEASE(x) +#endif + +struct lua_tcp_cbdata { + struct rspamd_async_session *session; + struct rspamd_async_event *async_ev; + struct ev_loop *event_loop; + rspamd_inet_addr_t *addr; + GByteArray *in; + GQueue *handlers; + gint fd; + gint connect_cb; + guint port; + guint flags; + gchar tag[7]; + struct rspamd_io_ev ev; + struct lua_tcp_dtor *dtors; + ref_entry_t ref; + struct rspamd_task *task; + struct rspamd_symcache_dynamic_item *item; + struct thread_entry *thread; + struct rspamd_config *cfg; + struct rspamd_ssl_connection *ssl_conn; + gchar *hostname; + struct upstream *up; + gboolean eof; +}; + +#define IS_SYNC(c) (((c)->flags & LUA_TCP_FLAG_SYNC) != 0) + +#define msg_debug_tcp(...) rspamd_conditional_debug_fast(NULL, cbd->addr, \ + rspamd_lua_tcp_log_id, "lua_tcp", cbd->tag, \ + G_STRFUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(lua_tcp) + +static void lua_tcp_handler(int fd, short what, gpointer ud); +static void lua_tcp_plan_handler_event(struct lua_tcp_cbdata *cbd, + gboolean can_read, gboolean can_write); +static void lua_tcp_unregister_event(struct lua_tcp_cbdata *cbd); + +static void +lua_tcp_void_finalyser(gpointer arg) +{ +} + +static const gdouble default_tcp_timeout = 5.0; + +static struct rspamd_dns_resolver * +lua_tcp_global_resolver(struct ev_loop *ev_base, + struct rspamd_config *cfg) +{ + static struct rspamd_dns_resolver *global_resolver; + + if (cfg && cfg->dns_resolver) { + return cfg->dns_resolver; + } + + if (global_resolver == NULL) { + global_resolver = rspamd_dns_resolver_init(NULL, ev_base, cfg); + } + + return global_resolver; +} + +static gboolean +lua_tcp_shift_handler(struct lua_tcp_cbdata *cbd) +{ + struct lua_tcp_handler *hdl; + + hdl = g_queue_pop_head(cbd->handlers); + + if (hdl == NULL) { + /* We are done */ + return FALSE; + } + + if (hdl->type == LUA_WANT_READ) { + msg_debug_tcp("switch from read handler %d", hdl->h.r.cbref); + if (hdl->h.r.cbref && hdl->h.r.cbref != -1) { + luaL_unref(cbd->cfg->lua_state, LUA_REGISTRYINDEX, hdl->h.r.cbref); + } + + if (hdl->h.r.stop_pattern) { + g_free(hdl->h.r.stop_pattern); + } + } + else if (hdl->type == LUA_WANT_WRITE) { + msg_debug_tcp("switch from write handler %d", hdl->h.r.cbref); + if (hdl->h.w.cbref && hdl->h.w.cbref != -1) { + luaL_unref(cbd->cfg->lua_state, LUA_REGISTRYINDEX, hdl->h.w.cbref); + } + + if (hdl->h.w.iov) { + g_free(hdl->h.w.iov); + } + } + else { + msg_debug_tcp("removing connect handler"); + /* LUA_WANT_CONNECT: it doesn't allocate anything, nothing to do here */ + } + + g_free(hdl); + + return TRUE; +} + +static void +lua_tcp_fin(gpointer arg) +{ + struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *) arg; + struct lua_tcp_dtor *dtor, *dttmp; + + if (IS_SYNC(cbd) && cbd->task) { + /* + pointer is now becoming invalid, we should remove registered destructor, + all the necessary steps are done here + */ + rspamd_mempool_replace_destructor(cbd->task->task_pool, + lua_tcp_sync_session_dtor, cbd, NULL); + } + + msg_debug_tcp("finishing TCP %s connection", IS_SYNC(cbd) ? "sync" : "async"); + + if (cbd->connect_cb != -1) { + luaL_unref(cbd->cfg->lua_state, LUA_REGISTRYINDEX, cbd->connect_cb); + } + + if (cbd->ssl_conn) { + /* TODO: postpone close in case ssl is used ! */ + rspamd_ssl_connection_free(cbd->ssl_conn); + } + + if (cbd->fd != -1) { + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + if (cbd->addr) { + rspamd_inet_address_free(cbd->addr); + } + + if (cbd->up) { + rspamd_upstream_unref(cbd->up); + } + + while (lua_tcp_shift_handler(cbd)) {} + g_queue_free(cbd->handlers); + + LL_FOREACH_SAFE(cbd->dtors, dtor, dttmp) + { + dtor->dtor(dtor->data); + g_free(dtor); + } + + g_byte_array_unref(cbd->in); + g_free(cbd->hostname); + g_free(cbd); +} + +static struct lua_tcp_cbdata * +lua_check_tcp(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{tcp}"); + luaL_argcheck(L, ud != NULL, pos, "'tcp' expected"); + return ud ? *((struct lua_tcp_cbdata **) ud) : NULL; +} + +static void +lua_tcp_maybe_free(struct lua_tcp_cbdata *cbd) +{ + if (IS_SYNC(cbd)) { + /* + * in this mode, we don't remove object, we only remove the event + * Object is owned by lua and will be destroyed on __gc() + */ + + if (cbd->item) { + rspamd_symcache_item_async_dec_check(cbd->task, cbd->item, M); + cbd->item = NULL; + } + + if (cbd->async_ev) { + rspamd_session_remove_event(cbd->session, lua_tcp_void_finalyser, cbd); + } + + cbd->async_ev = NULL; + } + else { + if (cbd->item) { + rspamd_symcache_item_async_dec_check(cbd->task, cbd->item, M); + cbd->item = NULL; + } + + if (cbd->async_ev) { + rspamd_session_remove_event(cbd->session, lua_tcp_fin, cbd); + } + else { + lua_tcp_fin(cbd); + } + } +} + +#ifdef __GNUC__ +static void +lua_tcp_push_error(struct lua_tcp_cbdata *cbd, gboolean is_fatal, + const char *err, ...) __attribute__((format(printf, 3, 4))); +#endif + +static void lua_tcp_resume_thread_error_argp(struct lua_tcp_cbdata *cbd, const gchar *error, va_list argp); + +static void +lua_tcp_push_error(struct lua_tcp_cbdata *cbd, gboolean is_fatal, + const char *err, ...) +{ + va_list ap, ap_copy; + struct lua_tcp_cbdata **pcbd; + struct lua_tcp_handler *hdl; + gint cbref, top; + struct lua_callback_state cbs; + lua_State *L; + gboolean callback_called = FALSE; + + if (is_fatal && cbd->up) { + rspamd_upstream_fail(cbd->up, false, err); + } + + if (cbd->thread) { + va_start(ap, err); + lua_tcp_resume_thread_error_argp(cbd, err, ap); + va_end(ap); + + return; + } + + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + va_start(ap, err); + + for (;;) { + hdl = g_queue_peek_head(cbd->handlers); + + if (hdl == NULL) { + break; + } + + if (hdl->type == LUA_WANT_READ) { + cbref = hdl->h.r.cbref; + } + else { + cbref = hdl->h.w.cbref; + } + + if (cbref != -1) { + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbref); + + /* Error message */ + va_copy(ap_copy, ap); + lua_pushvfstring(L, err, ap_copy); + va_end(ap_copy); + + /* Body */ + lua_pushnil(L); + /* Connection */ + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + rspamd_lua_setclass(L, "rspamd{tcp}", -1); + TCP_RETAIN(cbd); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 3, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + + TCP_RELEASE(cbd); + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + + callback_called = TRUE; + } + + if (!is_fatal) { + if (callback_called) { + /* Stop on the first callback found */ + break; + } + else { + /* Shift to another callback to inform about non fatal error */ + msg_debug_tcp("non fatal error find matching callback"); + lua_tcp_shift_handler(cbd); + continue; + } + } + else { + msg_debug_tcp("fatal error rollback all handlers"); + lua_tcp_shift_handler(cbd); + } + } + + va_end(ap); + + lua_thread_pool_restore_callback(&cbs); +} + +static void lua_tcp_resume_thread(struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len); + +static void +lua_tcp_push_data(struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) +{ + struct rspamd_lua_text *t; + struct lua_tcp_cbdata **pcbd; + struct lua_tcp_handler *hdl; + gint cbref, arg_cnt, top; + struct lua_callback_state cbs; + lua_State *L; + + if (cbd->thread) { + lua_tcp_resume_thread(cbd, str, len); + return; + } + + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + hdl = g_queue_peek_head(cbd->handlers); + + g_assert(hdl != NULL); + + if (hdl->type == LUA_WANT_READ) { + cbref = hdl->h.r.cbref; + } + else { + cbref = hdl->h.w.cbref; + } + + if (cbref != -1) { + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbref); + /* Error */ + lua_pushnil(L); + /* Body */ + + if (hdl->type == LUA_WANT_READ) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = (const gchar *) str; + t->len = len; + t->flags = 0; + arg_cnt = 3; + } + else { + arg_cnt = 2; + } + /* Connection */ + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + rspamd_lua_setclass(L, "rspamd{tcp}", -1); + + TCP_RETAIN(cbd); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, arg_cnt, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + TCP_RELEASE(cbd); + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + + lua_thread_pool_restore_callback(&cbs); +} + +static void +lua_tcp_resume_thread_error_argp(struct lua_tcp_cbdata *cbd, const gchar *error, va_list argp) +{ + struct thread_entry *thread = cbd->thread; + lua_State *L = thread->lua_state; + + lua_pushboolean(L, FALSE); + lua_pushvfstring(L, error, argp); + + lua_tcp_shift_handler(cbd); + // lua_tcp_unregister_event (cbd); + lua_thread_pool_set_running_entry(cbd->cfg->lua_thread_pool, cbd->thread); + lua_thread_resume(thread, 2); + TCP_RELEASE(cbd); +} + +static void +lua_tcp_resume_thread(struct lua_tcp_cbdata *cbd, const guint8 *str, gsize len) +{ + /* + * typical call returns: + * + * read: + * error: + * (nil, error message) + * got data: + * (true, data) + * write/connect: + * error: + * (nil, error message) + * wrote + * (true) + */ + + lua_State *L = cbd->thread->lua_state; + struct lua_tcp_handler *hdl; + + hdl = g_queue_peek_head(cbd->handlers); + + lua_pushboolean(L, TRUE); + if (hdl->type == LUA_WANT_READ) { + lua_pushlstring(L, str, len); + } + else { + lua_pushnil(L); + } + + lua_tcp_shift_handler(cbd); + lua_thread_pool_set_running_entry(cbd->cfg->lua_thread_pool, + cbd->thread); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + lua_thread_resume(cbd->thread, 2); + + TCP_RELEASE(cbd); +} + +static void +lua_tcp_plan_read(struct lua_tcp_cbdata *cbd) +{ + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, EV_READ); +} + +static void +lua_tcp_connect_helper(struct lua_tcp_cbdata *cbd) +{ + /* This is used for sync mode only */ + lua_State *L = cbd->thread->lua_state; + + struct lua_tcp_cbdata **pcbd; + + lua_pushboolean(L, TRUE); + + lua_thread_pool_set_running_entry(cbd->cfg->lua_thread_pool, cbd->thread); + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + rspamd_lua_setclass(L, "rspamd{tcp_sync}", -1); + msg_debug_tcp("tcp connected"); + + lua_tcp_shift_handler(cbd); + + // lua_tcp_unregister_event (cbd); + lua_thread_resume(cbd->thread, 2); + TCP_RELEASE(cbd); +} + +static void +lua_tcp_write_helper(struct lua_tcp_cbdata *cbd) +{ + struct iovec *start; + guint niov, i; + gint flags = 0; + bool allocated_iov = false; + gsize remain; + gssize r; + struct iovec *cur_iov; + struct lua_tcp_handler *hdl; + struct lua_tcp_write_handler *wh; + struct msghdr msg; + + hdl = g_queue_peek_head(cbd->handlers); + + g_assert(hdl != NULL && hdl->type == LUA_WANT_WRITE); + wh = &hdl->h.w; + + if (wh->pos == wh->total_bytes) { + goto call_finish_handler; + } + + start = &wh->iov[0]; + niov = wh->iovlen; + remain = wh->pos; + /* We know that niov is small enough for that */ + + if (niov < 1024) { + cur_iov = g_alloca(niov * sizeof(struct iovec)); + } + else { + cur_iov = g_malloc0(niov * sizeof(struct iovec)); + allocated_iov = true; + } + + memcpy(cur_iov, wh->iov, niov * sizeof(struct iovec)); + + for (i = 0; i < wh->iovlen && remain > 0; i++) { + /* Find out the first iov required */ + start = &cur_iov[i]; + if (start->iov_len <= remain) { + remain -= start->iov_len; + start = &cur_iov[i + 1]; + niov--; + } + else { + start->iov_base = (void *) ((char *) start->iov_base + remain); + start->iov_len -= remain; + remain = 0; + } + } + + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = start; + msg.msg_iovlen = MIN(IOV_MAX, niov); + g_assert(niov > 0); +#ifdef MSG_NOSIGNAL + flags = MSG_NOSIGNAL; +#endif + + msg_debug_tcp("want write %d io vectors of %d", (int) msg.msg_iovlen, + (int) niov); + + if (cbd->ssl_conn) { + r = rspamd_ssl_writev(cbd->ssl_conn, msg.msg_iov, msg.msg_iovlen); + } + else { + r = sendmsg(cbd->fd, &msg, flags); + } + + if (allocated_iov) { + g_free(cur_iov); + } + + if (r == -1) { + if (!(cbd->ssl_conn)) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + msg_debug_tcp("got temporary failure, retry write"); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + return; + } + else { + lua_tcp_push_error(cbd, TRUE, + "IO write error while trying to write %d bytes: %s", + (gint) remain, strerror(errno)); + + msg_debug_tcp("write error, terminate connection"); + TCP_RELEASE(cbd); + } + } + + return; + } + else { + wh->pos += r; + } + + msg_debug_tcp("written %z bytes: %z/%z", r, + wh->pos, wh->total_bytes); + + if (wh->pos >= wh->total_bytes) { + goto call_finish_handler; + } + else { + /* Want to write more */ + if (r > 0) { + /* XXX: special case: we know that we want to write more data + * than it is available in iov function. + * + * Hence, we need to check if we can write more at some point... + */ + lua_tcp_write_helper(cbd); + } + } + + return; + +call_finish_handler: + + msg_debug_tcp("finishing TCP write, calling TCP handler"); + + if ((cbd->flags & LUA_TCP_FLAG_SHUTDOWN)) { + /* Half close the connection */ + shutdown(cbd->fd, SHUT_WR); + cbd->flags &= ~LUA_TCP_FLAG_SHUTDOWN; + } + + if (cbd->up) { + rspamd_upstream_ok(cbd->up); + } + + lua_tcp_push_data(cbd, NULL, 0); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + } +} + +static gboolean +lua_tcp_process_read_handler(struct lua_tcp_cbdata *cbd, + struct lua_tcp_read_handler *rh, gboolean eof) +{ + guint slen; + goffset pos; + + if (rh->stop_pattern) { + slen = rh->plen; + + if (cbd->in->len >= slen) { + if ((pos = rspamd_substring_search(cbd->in->data, cbd->in->len, + rh->stop_pattern, slen)) != -1) { + msg_debug_tcp("found TCP stop pattern"); + lua_tcp_push_data(cbd, cbd->in->data, pos); + + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + } + if (pos + slen < cbd->in->len) { + /* We have a leftover */ + memmove(cbd->in->data, cbd->in->data + pos + slen, + cbd->in->len - (pos + slen)); + cbd->in->len = cbd->in->len - (pos + slen); + } + else { + cbd->in->len = 0; + } + + return TRUE; + } + else { + /* Plan new read */ + msg_debug_tcp("NOT found TCP stop pattern"); + + if (!cbd->eof) { + lua_tcp_plan_read(cbd); + } + else { + /* Got session finished but no stop pattern */ + lua_tcp_push_error(cbd, TRUE, + "IO read error: connection terminated"); + } + } + } + } + else { + msg_debug_tcp("read TCP partial data %d bytes", cbd->in->len); + slen = cbd->in->len; + + /* we have eaten all the data, handler should not know that there is something */ + cbd->in->len = 0; + lua_tcp_push_data(cbd, cbd->in->data, slen); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + } + + return TRUE; + } + + return FALSE; +} + +static void +lua_tcp_process_read(struct lua_tcp_cbdata *cbd, + guchar *in, gssize r) +{ + struct lua_tcp_handler *hdl; + struct lua_tcp_read_handler *rh; + + hdl = g_queue_peek_head(cbd->handlers); + + g_assert(hdl != NULL && hdl->type == LUA_WANT_READ); + rh = &hdl->h.r; + + if (r > 0) { + if (cbd->flags & LUA_TCP_FLAG_PARTIAL) { + lua_tcp_push_data(cbd, in, r); + /* Plan next event */ + lua_tcp_plan_read(cbd); + } + else { + g_byte_array_append(cbd->in, in, r); + + if (!lua_tcp_process_read_handler(cbd, rh, FALSE)) { + /* Plan more read */ + lua_tcp_plan_read(cbd); + } + else { + /* Go towards the next handler */ + if (!IS_SYNC(cbd)) { + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + } + } + } + } + else if (r == 0) { + /* EOF */ + cbd->eof = TRUE; + if (cbd->in->len > 0) { + /* We have some data to process */ + lua_tcp_process_read_handler(cbd, rh, TRUE); + } + else { + lua_tcp_push_error(cbd, TRUE, "IO read error: connection terminated"); + + if ((cbd->flags & LUA_TCP_FLAG_FINISHED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + + lua_tcp_plan_handler_event(cbd, FALSE, FALSE); + } + else { + /* An error occurred */ + if (errno == EAGAIN || errno == EINTR) { + /* Restart call */ + lua_tcp_plan_read(cbd); + + return; + } + + /* Fatal error */ + cbd->eof = TRUE; + if (cbd->in->len > 0) { + /* We have some data to process */ + lua_tcp_process_read_handler(cbd, rh, TRUE); + } + else { + lua_tcp_push_error(cbd, TRUE, + "IO read error while trying to read data: %s", + strerror(errno)); + + if ((cbd->flags & LUA_TCP_FLAG_FINISHED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + + lua_tcp_plan_handler_event(cbd, FALSE, FALSE); + } +} + +static void +lua_tcp_handler(int fd, short what, gpointer ud) +{ + struct lua_tcp_cbdata *cbd = ud; + guchar inbuf[8192]; + gssize r; + gint so_error = 0; + socklen_t so_len = sizeof(so_error); + struct lua_callback_state cbs; + lua_State *L; + enum lua_tcp_handler_type event_type; + TCP_RETAIN(cbd); + + msg_debug_tcp("processed TCP event: %d", what); + + struct lua_tcp_handler *rh = g_queue_peek_head(cbd->handlers); + event_type = rh->type; + + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + + if (what == EV_READ) { + if (cbd->ssl_conn) { + r = rspamd_ssl_read(cbd->ssl_conn, inbuf, sizeof(inbuf)); + } + else { + r = read(cbd->fd, inbuf, sizeof(inbuf)); + } + + lua_tcp_process_read(cbd, inbuf, r); + } + else if (what == EV_WRITE) { + + if (!(cbd->flags & LUA_TCP_FLAG_CONNECTED)) { + if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &so_error, &so_len) == -1) { + lua_tcp_push_error(cbd, TRUE, "Cannot get socket error: %s", + strerror(errno)); + TCP_RELEASE(cbd); + goto out; + } + else if (so_error != 0) { + lua_tcp_push_error(cbd, TRUE, "Socket error detected: %s", + strerror(so_error)); + TCP_RELEASE(cbd); + goto out; + } + else { + cbd->flags |= LUA_TCP_FLAG_CONNECTED; + + if (cbd->connect_cb != -1) { + struct lua_tcp_cbdata **pcbd; + gint top; + + lua_thread_pool_prepare_callback(cbd->cfg->lua_thread_pool, &cbs); + L = cbs.L; + + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->connect_cb); + pcbd = lua_newuserdata(L, sizeof(*pcbd)); + *pcbd = cbd; + TCP_RETAIN(cbd); + rspamd_lua_setclass(L, "rspamd{tcp}", -1); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 1, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + TCP_RELEASE(cbd); + lua_thread_pool_restore_callback(&cbs); + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + } + } + + if (event_type == LUA_WANT_WRITE) { + lua_tcp_write_helper(cbd); + } + else if (event_type == LUA_WANT_CONNECT) { + lua_tcp_connect_helper(cbd); + } + else { + g_assert_not_reached(); + } + + if ((cbd->flags & (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) == + (LUA_TCP_FLAG_FINISHED | LUA_TCP_FLAG_CONNECTED)) { + /* A callback has called `close` method, so we need to release a refcount */ + TCP_RELEASE(cbd); + } + } + else { + lua_tcp_push_error(cbd, TRUE, "IO timeout"); + TCP_RELEASE(cbd); + } + +out: + TCP_RELEASE(cbd); +} + +static void +lua_tcp_plan_handler_event(struct lua_tcp_cbdata *cbd, gboolean can_read, + gboolean can_write) +{ + struct lua_tcp_handler *hdl; + + hdl = g_queue_peek_head(cbd->handlers); + + if (hdl == NULL) { + if (!(cbd->flags & LUA_TCP_FLAG_FINISHED)) { + /* We are finished with a connection */ + msg_debug_tcp("no handlers left, finish session"); + cbd->flags |= LUA_TCP_FLAG_FINISHED; + TCP_RELEASE(cbd); + } + } + else { + if (hdl->type == LUA_WANT_READ) { + + /* We need to check if we have some leftover in the buffer */ + if (cbd->in->len > 0) { + msg_debug_tcp("process read buffer leftover"); + if (lua_tcp_process_read_handler(cbd, &hdl->h.r, FALSE)) { + if (!IS_SYNC(cbd)) { + /* We can go to the next handler */ + lua_tcp_plan_handler_event(cbd, can_read, can_write); + } + } + } + else { + if (can_read) { + /* We need to plan a new event */ + msg_debug_tcp("plan new read"); + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, + EV_READ); + } + else { + /* Cannot read more */ + msg_debug_tcp("cannot read more"); + lua_tcp_push_error(cbd, FALSE, "EOF, cannot read more data"); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, can_read, can_write); + } + } + } + } + else if (hdl->type == LUA_WANT_WRITE) { + /* + * We need to plan write event if there is something in the + * write request + */ + + if (hdl->h.w.pos < hdl->h.w.total_bytes) { + msg_debug_tcp("plan new write"); + if (can_write) { + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, + EV_WRITE); + } + else { + /* Cannot write more */ + lua_tcp_push_error(cbd, FALSE, "EOF, cannot write more data"); + if (!IS_SYNC(cbd)) { + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, can_read, can_write); + } + } + } + else { + /* We shouldn't have empty write handlers */ + g_assert_not_reached(); + } + } + else { /* LUA_WANT_CONNECT */ + msg_debug_tcp("plan new connect"); + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, + EV_WRITE); + } + } +} + +static gboolean +lua_tcp_register_event(struct lua_tcp_cbdata *cbd) +{ + if (cbd->session) { + event_finalizer_t fin = IS_SYNC(cbd) ? lua_tcp_void_finalyser : lua_tcp_fin; + + if (cbd->item) { + cbd->async_ev = rspamd_session_add_event_full(cbd->session, fin, cbd, M, + rspamd_symcache_dyn_item_name(cbd->task, cbd->item)); + } + else { + cbd->async_ev = rspamd_session_add_event(cbd->session, fin, cbd, M); + } + + if (!cbd->async_ev) { + return FALSE; + } + } + + return TRUE; +} + +static void +lua_tcp_register_watcher(struct lua_tcp_cbdata *cbd) +{ + if (cbd->item && cbd->task) { + rspamd_symcache_item_async_inc(cbd->task, cbd->item, M); + } +} + +static void +lua_tcp_ssl_on_error(gpointer ud, GError *err) +{ + struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *) ud; + + if (err) { + lua_tcp_push_error(cbd, TRUE, "ssl error: %s", err->message); + } + else { + lua_tcp_push_error(cbd, TRUE, "ssl error: unknown error"); + } + + TCP_RELEASE(cbd); +} + +static gboolean +lua_tcp_make_connection(struct lua_tcp_cbdata *cbd) +{ + int fd; + + rspamd_inet_address_set_port(cbd->addr, cbd->port); + fd = rspamd_inet_address_connect(cbd->addr, SOCK_STREAM, TRUE); + + if (fd == -1) { + if (cbd->session) { + rspamd_mempool_t *pool = rspamd_session_mempool(cbd->session); + msg_info_pool("cannot connect to %s (%s): %s", + rspamd_inet_address_to_string(cbd->addr), + cbd->hostname, + strerror(errno)); + } + else { + msg_info("cannot connect to %s (%s): %s", + rspamd_inet_address_to_string(cbd->addr), + cbd->hostname, + strerror(errno)); + } + + return FALSE; + } + + cbd->fd = fd; + +#if 0 + if (!(cbd->flags & LUA_TCP_FLAG_RESOLVED)) { + /* We come here without resolving, so we need to add a watcher */ + lua_tcp_register_watcher (cbd); + } + else { + cbd->flags |= LUA_TCP_FLAG_RESOLVED; + } +#endif + + if (cbd->flags & LUA_TCP_FLAG_SSL) { + gpointer ssl_ctx; + gboolean verify_peer; + + if (cbd->flags & LUA_TCP_FLAG_SSL_NOVERIFY) { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx_noverify; + verify_peer = FALSE; + } + else { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx; + verify_peer = TRUE; + } + + cbd->ssl_conn = rspamd_ssl_connection_new(ssl_ctx, + cbd->event_loop, + verify_peer, + cbd->tag); + + if (!rspamd_ssl_connect_fd(cbd->ssl_conn, fd, cbd->hostname, &cbd->ev, + cbd->ev.timeout, lua_tcp_handler, lua_tcp_ssl_on_error, cbd)) { + lua_tcp_push_error(cbd, TRUE, "ssl connection failed: %s", + strerror(errno)); + + return FALSE; + } + else { + lua_tcp_register_event(cbd); + } + } + else { + rspamd_ev_watcher_init(&cbd->ev, cbd->fd, EV_WRITE, + lua_tcp_handler, cbd); + lua_tcp_register_event(cbd); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + } + + + return TRUE; +} + +static void +lua_tcp_dns_handler(struct rdns_reply *reply, gpointer ud) +{ + struct lua_tcp_cbdata *cbd = (struct lua_tcp_cbdata *) ud; + const struct rdns_request_name *rn; + + if (reply->code != RDNS_RC_NOERROR) { + rn = rdns_request_get_name(reply->request, NULL); + lua_tcp_push_error(cbd, TRUE, "unable to resolve host: %s", + rn->name); + TCP_RELEASE(cbd); + } + else { + /* + * We set this flag as it means that we have already registered the watcher + * when started DNS query + */ + struct rdns_reply_entry *entry; + + DL_FOREACH(reply->entries, entry) + { + if (entry->type == RDNS_REQUEST_A) { + cbd->addr = rspamd_inet_address_new(AF_INET, + &entry->content.a.addr); + break; + } + else if (entry->type == RDNS_REQUEST_AAAA) { + cbd->addr = rspamd_inet_address_new(AF_INET6, + &entry->content.aaa.addr); + break; + } + } + + if (cbd->addr == NULL) { + rn = rdns_request_get_name(reply->request, NULL); + lua_tcp_push_error(cbd, TRUE, "unable to resolve host: %s; no records with this name", + rn->name); + TCP_RELEASE(cbd); + return; + } + + cbd->flags |= LUA_TCP_FLAG_RESOLVED; + rspamd_inet_address_set_port(cbd->addr, cbd->port); + + if (!lua_tcp_make_connection(cbd)) { + lua_tcp_push_error(cbd, TRUE, "unable to make connection to the host %s", + rspamd_inet_address_to_string(cbd->addr)); + TCP_RELEASE(cbd); + } + } +} + +static gboolean +lua_tcp_arg_toiovec(lua_State *L, gint pos, struct lua_tcp_cbdata *cbd, + struct iovec *vec) +{ + struct rspamd_lua_text *t; + gsize len; + const gchar *str; + struct lua_tcp_dtor *dtor; + + if (lua_type(L, pos) == LUA_TUSERDATA) { + t = lua_check_text(L, pos); + + if (t) { + vec->iov_base = (void *) t->start; + vec->iov_len = t->len; + + if (t->flags & RSPAMD_TEXT_FLAG_OWN) { + /* Steal ownership */ + t->flags = 0; + dtor = g_malloc0(sizeof(*dtor)); + dtor->dtor = g_free; + dtor->data = (void *) t->start; + LL_PREPEND(cbd->dtors, dtor); + } + } + else { + msg_err("bad userdata argument at position %d", pos); + return FALSE; + } + } + else if (lua_type(L, pos) == LUA_TSTRING) { + str = luaL_checklstring(L, pos, &len); + vec->iov_base = g_malloc(len); + dtor = g_malloc0(sizeof(*dtor)); + dtor->dtor = g_free; + dtor->data = vec->iov_base; + LL_PREPEND(cbd->dtors, dtor); + memcpy(vec->iov_base, str, len); + vec->iov_len = len; + } + else { + msg_err("bad argument at position %d", pos); + return FALSE; + } + + return TRUE; +} + +/*** + * @function rspamd_tcp.request({params}) + * This function creates and sends TCP request to the specified host and port, + * resolves hostname (if needed) and invokes continuation callback upon data received + * from the remote peer. This function accepts table of arguments with the following + * attributes + * + * - `task`: rspamd task objects (implies `pool`, `session`, `ev_base` and `resolver` arguments) + * - `ev_base`: event base (if no task specified) + * - `resolver`: DNS resolver (no task) + * - `session`: events session (no task) + * - `host`: IP or name of the peer (required) + * - `port`: remote port to use + * - `data`: a table of strings or `rspamd_text` objects that contains data pieces + * - `callback`: continuation function (required) + * - `on_connect`: callback called on connection success + * - `timeout`: floating point value that specifies timeout for IO operations in **seconds** + * - `partial`: boolean flag that specifies that callback should be called on any data portion received + * - `stop_pattern`: stop reading on finding a certain pattern (e.g. \r\n.\r\n for smtp) + * - `shutdown`: half-close socket after writing (boolean: default false) + * - `read`: read response after sending request (boolean: default true) + * - `upstream`: optional upstream object that would be used to get an address + * @return {boolean} true if request has been sent + */ +static gint +lua_tcp_request(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *host; + gchar *stop_pattern = NULL; + guint port; + gint cbref, tp, conn_cbref = -1; + gsize plen = 0; + struct ev_loop *event_loop = NULL; + struct lua_tcp_cbdata *cbd; + struct rspamd_dns_resolver *resolver = NULL; + struct rspamd_async_session *session = NULL; + struct rspamd_task *task = NULL; + struct rspamd_config *cfg = NULL; + struct iovec *iov = NULL; + struct upstream *up = NULL; + guint niov = 0, total_out; + guint64 h; + gdouble timeout = default_tcp_timeout; + gboolean partial = FALSE, do_shutdown = FALSE, do_read = TRUE, + ssl = FALSE, ssl_noverify = FALSE; + + if (lua_type(L, 1) == LUA_TTABLE) { + lua_pushstring(L, "host"); + lua_gettable(L, -2); + host = luaL_checkstring(L, -1); + lua_pop(L, 1); + + lua_pushstring(L, "port"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + port = lua_tointeger(L, -1); + } + else { + /* We assume that it is a unix socket */ + port = 0; + } + + lua_pop(L, 1); + + lua_pushstring(L, "callback"); + lua_gettable(L, -2); + if (host == NULL || lua_type(L, -1) != LUA_TFUNCTION) { + lua_pop(L, 1); + msg_err("tcp request has bad params"); + lua_pushboolean(L, FALSE); + return 1; + } + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + cbd = g_malloc0(sizeof(*cbd)); + + lua_pushstring(L, "task"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + task = lua_check_task(L, -1); + event_loop = task->event_loop; + resolver = task->resolver; + session = task->s; + cfg = task->cfg; + } + lua_pop(L, 1); + + if (task == NULL) { + lua_pushstring(L, "ev_base"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{ev_base}")) { + event_loop = *(struct ev_loop **) lua_touserdata(L, -1); + } + else { + g_free(cbd); + + return luaL_error(L, "event loop is required"); + } + lua_pop(L, 1); + + lua_pushstring(L, "session"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{session}")) { + session = *(struct rspamd_async_session **) lua_touserdata(L, -1); + } + else { + session = NULL; + } + lua_pop(L, 1); + + lua_pushstring(L, "config"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{config}")) { + cfg = *(struct rspamd_config **) lua_touserdata(L, -1); + } + else { + cfg = NULL; + } + lua_pop(L, 1); + + lua_pushstring(L, "resolver"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{resolver}")) { + resolver = *(struct rspamd_dns_resolver **) lua_touserdata(L, -1); + } + else { + resolver = lua_tcp_global_resolver(event_loop, cfg); + } + lua_pop(L, 1); + } + + lua_pushstring(L, "timeout"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "stop_pattern"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *p; + + p = lua_tolstring(L, -1, &plen); + + if (p && plen > 0) { + stop_pattern = g_malloc(plen); + memcpy(stop_pattern, p, plen); + } + } + lua_pop(L, 1); + + lua_pushstring(L, "partial"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + partial = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "shutdown"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + do_shutdown = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "read"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + do_read = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "ssl"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + ssl = lua_toboolean(L, -1); + } + lua_pop(L, 1); + + lua_pushstring(L, "ssl_noverify"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TBOOLEAN) { + ssl_noverify = lua_toboolean(L, -1); + lua_pop(L, 1); + } + else { + lua_pop(L, 1); /* Previous nil... */ + /* Similar to lua http, meh... */ + lua_pushstring(L, "no_ssl_verify"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TBOOLEAN) { + ssl_noverify = lua_toboolean(L, -1); + } + + lua_pop(L, 1); + } + + lua_pushstring(L, "on_connect"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TFUNCTION) { + conn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + else { + lua_pop(L, 1); + } + + lua_pushstring(L, "upstream"); + lua_gettable(L, 1); + + if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_upstream *lup = lua_check_upstream(L, -1); + + if (lup) { + /* Preserve pointer in case if lup is destructed */ + up = lup->up; + } + } + + lua_pop(L, 1); + + lua_pushstring(L, "data"); + lua_gettable(L, -2); + total_out = 0; + + tp = lua_type(L, -1); + if (tp == LUA_TSTRING || tp == LUA_TUSERDATA) { + iov = g_malloc(sizeof(*iov)); + niov = 1; + + if (!lua_tcp_arg_toiovec(L, -1, cbd, iov)) { + lua_pop(L, 1); + msg_err("tcp request has bad data argument"); + lua_pushboolean(L, FALSE); + g_free(iov); + g_free(cbd); + + return 1; + } + + total_out = iov[0].iov_len; + } + else if (tp == LUA_TTABLE) { + /* Count parts */ + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + niov++; + lua_pop(L, 1); + } + + iov = g_malloc(sizeof(*iov) * niov); + lua_pushnil(L); + niov = 0; + + while (lua_next(L, -2) != 0) { + if (!lua_tcp_arg_toiovec(L, -1, cbd, &iov[niov])) { + lua_pop(L, 2); + msg_err("tcp request has bad data argument at pos %d", niov); + lua_pushboolean(L, FALSE); + g_free(iov); + g_free(cbd); + + return 1; + } + + total_out += iov[niov].iov_len; + niov++; + + lua_pop(L, 1); + } + } + + lua_pop(L, 1); + } + else { + return luaL_error(L, "tcp request has bad params"); + } + + if (resolver == NULL && cfg == NULL && task == NULL) { + g_free(cbd); + g_free(iov); + + return luaL_error(L, "tcp request has bad params: one of " + "{resolver,task,config} should be set"); + } + + cbd->task = task; + + if (task) { + cbd->item = rspamd_symcache_get_cur_item(task); + } + + cbd->cfg = cfg; + h = rspamd_random_uint64_fast(); + rspamd_snprintf(cbd->tag, sizeof(cbd->tag), "%uxL", h); + cbd->handlers = g_queue_new(); + cbd->hostname = g_strdup(host); + + if (total_out > 0) { + struct lua_tcp_handler *wh; + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_WRITE; + wh->h.w.iov = iov; + wh->h.w.iovlen = niov; + wh->h.w.total_bytes = total_out; + wh->h.w.pos = 0; + /* Cannot set write handler here */ + wh->h.w.cbref = -1; + + if (cbref != -1 && !do_read) { + /* We have write only callback */ + wh->h.w.cbref = cbref; + } + else { + /* We have simple client callback */ + wh->h.w.cbref = -1; + } + + g_queue_push_tail(cbd->handlers, wh); + } + + cbd->event_loop = event_loop; + cbd->fd = -1; + cbd->port = port; + cbd->ev.timeout = timeout; + + if (ssl) { + cbd->flags |= LUA_TCP_FLAG_SSL; + + if (ssl_noverify) { + cbd->flags |= LUA_TCP_FLAG_SSL_NOVERIFY; + } + } + + if (do_read) { + cbd->in = g_byte_array_sized_new(8192); + } + else { + /* Save some space... */ + cbd->in = g_byte_array_new(); + } + + if (partial) { + cbd->flags |= LUA_TCP_FLAG_PARTIAL; + } + + if (do_shutdown) { + cbd->flags |= LUA_TCP_FLAG_SHUTDOWN; + } + + if (do_read) { + struct lua_tcp_handler *rh; + + rh = g_malloc0(sizeof(*rh)); + rh->type = LUA_WANT_READ; + rh->h.r.cbref = cbref; + rh->h.r.stop_pattern = stop_pattern; + rh->h.r.plen = plen; + g_queue_push_tail(cbd->handlers, rh); + } + + cbd->connect_cb = conn_cbref; + REF_INIT_RETAIN(cbd, lua_tcp_maybe_free); + + if (up) { + cbd->up = rspamd_upstream_ref(up); + } + + if (session) { + cbd->session = session; + + if (rspamd_session_blocked(session)) { + lua_tcp_push_error(cbd, TRUE, "async session is the blocked state"); + TCP_RELEASE(cbd); + cbd->item = NULL; /* To avoid decrease with no watcher */ + lua_pushboolean(L, FALSE); + + return 1; + } + } + + if (cbd->up) { + /* Use upstream to get addr */ + cbd->addr = rspamd_inet_address_copy(rspamd_upstream_addr_next(cbd->up), NULL); + + /* Host is numeric IP, no need to resolve */ + lua_tcp_register_watcher(cbd); + + if (!lua_tcp_make_connection(cbd)) { + lua_tcp_push_error(cbd, TRUE, "cannot connect to the host: %s", host); + lua_pushboolean(L, FALSE); + + rspamd_upstream_fail(cbd->up, true, "failed to connect"); + + /* No reset of the item as watcher has been registered */ + TCP_RELEASE(cbd); + + return 1; + } + } + else if (rspamd_parse_inet_address(&cbd->addr, + host, strlen(host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + rspamd_inet_address_set_port(cbd->addr, port); + /* Host is numeric IP, no need to resolve */ + lua_tcp_register_watcher(cbd); + + if (!lua_tcp_make_connection(cbd)) { + lua_tcp_push_error(cbd, TRUE, "cannot connect to the host: %s", host); + lua_pushboolean(L, FALSE); + + /* No reset of the item as watcher has been registered */ + TCP_RELEASE(cbd); + + return 1; + } + } + else { + if (task == NULL) { + if (!rspamd_dns_resolver_request(resolver, session, NULL, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + lua_tcp_push_error(cbd, TRUE, "cannot resolve host: %s", host); + lua_pushboolean(L, FALSE); + cbd->item = NULL; /* To avoid decrease with no watcher */ + TCP_RELEASE(cbd); + + return 1; + } + else { + lua_tcp_register_watcher(cbd); + } + } + else { + if (!rspamd_dns_resolver_request_task(task, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + lua_tcp_push_error(cbd, TRUE, "cannot resolve host: %s", host); + lua_pushboolean(L, FALSE); + cbd->item = NULL; /* To avoid decrease with no watcher */ + + TCP_RELEASE(cbd); + + return 1; + } + else { + lua_tcp_register_watcher(cbd); + } + } + } + + lua_pushboolean(L, TRUE); + return 1; +} + +/*** + * @function rspamd_tcp.connect_sync({params}) + * Creates new pseudo-synchronous connection to the specific address:port + * + * - `task`: rspamd task objects (implies `pool`, `session`, `ev_base` and `resolver` arguments) + * - `ev_base`: event base (if no task specified) + * - `resolver`: DNS resolver (no task) + * - `session`: events session (no task) + * - `config`: config (no task) + * - `host`: IP or name of the peer (required) + * - `port`: remote port to use + * - `timeout`: floating point value that specifies timeout for IO operations in **seconds** + * @return {boolean} true if request has been sent + */ +static gint +lua_tcp_connect_sync(lua_State *L) +{ + LUA_TRACE_POINT; + GError *err = NULL; + + gint64 port = -1; + gdouble timeout = default_tcp_timeout; + const gchar *host = NULL; + gint ret; + guint64 h; + + struct rspamd_task *task = NULL; + struct rspamd_async_session *session = NULL; + struct rspamd_dns_resolver *resolver = NULL; + struct rspamd_config *cfg = NULL; + struct ev_loop *ev_base = NULL; + struct lua_tcp_cbdata *cbd; + + + int arguments_validated = rspamd_lua_parse_table_arguments(L, 1, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "task=U{task};session=U{session};resolver=U{resolver};ev_base=U{ev_base};" + "*host=S;*port=I;timeout=D;config=U{config}", + &task, &session, &resolver, &ev_base, + &host, &port, &timeout, &cfg); + + if (!arguments_validated) { + if (err) { + ret = luaL_error(L, "invalid arguments: %s", err->message); + g_error_free(err); + + return ret; + } + + return luaL_error(L, "invalid arguments"); + } + + if (0 > port || port > 65535) { + return luaL_error(L, "invalid port given (correct values: 1..65535)"); + } + + if (task == NULL && (cfg == NULL || ev_base == NULL || session == NULL)) { + return luaL_error(L, "invalid arguments: either task or config+ev_base+session should be set"); + } + + if (isnan(timeout)) { + /* rspamd_lua_parse_table_arguments() sets missing N field to zero */ + timeout = default_tcp_timeout; + } + + cbd = g_new0(struct lua_tcp_cbdata, 1); + + if (task) { + static const gchar hexdigests[16] = "0123456789abcdef"; + + cfg = task->cfg; + ev_base = task->event_loop; + session = task->s; + /* Make a readable tag */ + memcpy(cbd->tag, task->task_pool->tag.uid, sizeof(cbd->tag) - 2); + cbd->tag[sizeof(cbd->tag) - 2] = hexdigests[GPOINTER_TO_INT(cbd) & 0xf]; + cbd->tag[sizeof(cbd->tag) - 1] = 0; + } + else { + h = rspamd_random_uint64_fast(); + rspamd_snprintf(cbd->tag, sizeof(cbd->tag), "%uxL", h); + } + + if (resolver == NULL) { + if (task) { + resolver = task->resolver; + } + else { + resolver = lua_tcp_global_resolver(ev_base, cfg); + } + } + + cbd->task = task; + cbd->cfg = cfg; + cbd->thread = lua_thread_pool_get_running_entry(cfg->lua_thread_pool); + + + cbd->handlers = g_queue_new(); + + cbd->event_loop = ev_base; + cbd->flags |= LUA_TCP_FLAG_SYNC; + cbd->fd = -1; + cbd->port = (guint16) port; + + cbd->in = g_byte_array_new(); + + cbd->connect_cb = -1; + + REF_INIT_RETAIN(cbd, lua_tcp_maybe_free); + + if (task) { + rspamd_mempool_add_destructor(task->task_pool, lua_tcp_sync_session_dtor, cbd); + } + + struct lua_tcp_handler *wh; + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_CONNECT; + + g_queue_push_tail(cbd->handlers, wh); + + if (session) { + cbd->session = session; + + if (rspamd_session_blocked(session)) { + TCP_RELEASE(cbd); + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Session is being destroyed, requests are not allowed"); + + return 2; + } + } + + if (rspamd_parse_inet_address(&cbd->addr, + host, strlen(host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + rspamd_inet_address_set_port(cbd->addr, (guint16) port); + /* Host is numeric IP, no need to resolve */ + if (!lua_tcp_make_connection(cbd)) { + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Failed to initiate connection"); + + TCP_RELEASE(cbd); + + return 2; + } + } + else { + if (task == NULL) { + if (!rspamd_dns_resolver_request(resolver, session, NULL, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Failed to initiate dns request"); + + TCP_RELEASE(cbd); + + return 2; + } + else { + lua_tcp_register_watcher(cbd); + } + } + else { + cbd->item = rspamd_symcache_get_cur_item(task); + + if (!rspamd_dns_resolver_request_task(task, lua_tcp_dns_handler, cbd, + RDNS_REQUEST_A, host)) { + cbd->item = NULL; /* We have not registered watcher */ + lua_pushboolean(L, FALSE); + lua_pushliteral(L, "Failed to initiate dns request"); + TCP_RELEASE(cbd); + + return 2; + } + else { + lua_tcp_register_watcher(cbd); + } + } + } + + return lua_thread_yield(cbd->thread, 0); +} + +static gint +lua_tcp_close(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->ssl_conn) { + /* TODO: postpone close in case ssl is used ! */ + rspamd_ssl_connection_free(cbd->ssl_conn); + cbd->ssl_conn = NULL; + } + + if (cbd->fd != -1) { + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + if (cbd->addr) { + rspamd_inet_address_free(cbd->addr); + cbd->addr = NULL; + } + + if (cbd->up) { + rspamd_upstream_unref(cbd->up); + cbd->up = NULL; + } + /* Do not release refcount as it will be handled elsewhere */ + /* TCP_RELEASE (cbd); */ + + return 0; +} + +static gint +lua_tcp_add_read(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + struct lua_tcp_handler *rh; + gchar *stop_pattern = NULL; + const gchar *p; + gsize plen = 0; + gint cbref = -1; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + if (lua_type(L, 3) == LUA_TSTRING) { + p = lua_tolstring(L, 3, &plen); + + if (p && plen > 0) { + stop_pattern = g_malloc(plen); + memcpy(stop_pattern, p, plen); + } + } + + rh = g_malloc0(sizeof(*rh)); + rh->type = LUA_WANT_READ; + rh->h.r.cbref = cbref; + rh->h.r.stop_pattern = stop_pattern; + rh->h.r.plen = plen; + msg_debug_tcp("added read event, cbref: %d", cbref); + + g_queue_push_tail(cbd->handlers, rh); + + return 0; +} + +static gint +lua_tcp_add_write(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + struct lua_tcp_handler *wh; + gint cbref = -1, tp; + struct iovec *iov = NULL; + guint niov = 0, total_out = 0; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TFUNCTION) { + lua_pushvalue(L, 2); + cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + + tp = lua_type(L, 3); + if (tp == LUA_TSTRING || tp == LUA_TUSERDATA) { + iov = g_malloc(sizeof(*iov)); + niov = 1; + + if (!lua_tcp_arg_toiovec(L, 3, cbd, iov)) { + msg_err("tcp request has bad data argument"); + lua_pushboolean(L, FALSE); + g_free(iov); + + return 1; + } + + total_out = iov[0].iov_len; + } + else if (tp == LUA_TTABLE) { + /* Count parts */ + lua_pushvalue(L, 3); + + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + niov++; + lua_pop(L, 1); + } + + iov = g_malloc(sizeof(*iov) * niov); + lua_pushnil(L); + niov = 0; + + while (lua_next(L, -2) != 0) { + if (!lua_tcp_arg_toiovec(L, -1, cbd, &iov[niov])) { + lua_pop(L, 2); + msg_err("tcp request has bad data argument at pos %d", niov); + lua_pushboolean(L, FALSE); + g_free(iov); + g_free(cbd); + + return 1; + } + + total_out += iov[niov].iov_len; + niov++; + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_WRITE; + wh->h.w.iov = iov; + wh->h.w.iovlen = niov; + wh->h.w.total_bytes = total_out; + wh->h.w.pos = 0; + /* Cannot set write handler here */ + wh->h.w.cbref = cbref; + msg_debug_tcp("added write event, cbref: %d", cbref); + + g_queue_push_tail(cbd->handlers, wh); + lua_pushboolean(L, TRUE); + + return 1; +} + +static gint +lua_tcp_shift_callback(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments"); + } + + lua_tcp_shift_handler(cbd); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + + return 0; +} + +static struct lua_tcp_cbdata * +lua_check_sync_tcp(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{tcp_sync}"); + luaL_argcheck(L, ud != NULL, pos, "'tcp' expected"); + return ud ? *((struct lua_tcp_cbdata **) ud) : NULL; +} + +static int +lua_tcp_sync_close(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->fd != -1) { + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + return 0; +} + +static void +lua_tcp_sync_session_dtor(gpointer ud) +{ + struct lua_tcp_cbdata *cbd = ud; + cbd->flags |= LUA_TCP_FLAG_FINISHED; + + if (cbd->fd != -1) { + msg_debug("closing sync TCP connection"); + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->fd); + cbd->fd = -1; + } + + /* Task is gone, we should not try use it anymore */ + cbd->task = NULL; + + /* All events are removed when task is done, we should not refer them */ + cbd->async_ev = NULL; +} + +static int +lua_tcp_sync_read_once(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + struct lua_tcp_handler *rh; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + struct thread_entry *thread = lua_thread_pool_get_running_entry(cbd->cfg->lua_thread_pool); + + rh = g_malloc0(sizeof(*rh)); + rh->type = LUA_WANT_READ; + rh->h.r.cbref = -1; + + msg_debug_tcp("added read sync event, thread: %p", thread); + + g_queue_push_tail(cbd->handlers, rh); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + + TCP_RETAIN(cbd); + + return lua_thread_yield(thread, 0); +} + +static int +lua_tcp_sync_write(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + struct lua_tcp_handler *wh; + gint tp; + struct iovec *iov = NULL; + guint niov = 0; + gsize total_out = 0; + + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + struct thread_entry *thread = lua_thread_pool_get_running_entry(cbd->cfg->lua_thread_pool); + + tp = lua_type(L, 2); + if (tp == LUA_TSTRING || tp == LUA_TUSERDATA) { + iov = g_malloc(sizeof(*iov)); + niov = 1; + + if (!lua_tcp_arg_toiovec(L, 2, cbd, iov)) { + msg_err("tcp request has bad data argument"); + g_free(iov); + g_free(cbd); + + return luaL_error(L, "invalid arguments second parameter (data) is expected to be either string or rspamd{text}"); + } + + total_out = iov[0].iov_len; + } + else if (tp == LUA_TTABLE) { + /* Count parts */ + lua_pushvalue(L, 3); + + lua_pushnil(L); + while (lua_next(L, -2) != 0) { + niov++; + lua_pop(L, 1); + } + + iov = g_malloc(sizeof(*iov) * niov); + lua_pushnil(L); + niov = 0; + + while (lua_next(L, -2) != 0) { + if (!lua_tcp_arg_toiovec(L, -1, cbd, &iov[niov])) { + msg_err("tcp request has bad data argument at pos %d", niov); + g_free(iov); + g_free(cbd); + + return luaL_error(L, "invalid arguments second parameter (data) is expected to be either string or rspamd{text}"); + } + + total_out += iov[niov].iov_len; + niov++; + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + wh = g_malloc0(sizeof(*wh)); + wh->type = LUA_WANT_WRITE; + wh->h.w.iov = iov; + wh->h.w.iovlen = niov; + wh->h.w.total_bytes = total_out; + wh->h.w.pos = 0; + wh->h.w.cbref = -1; + msg_debug_tcp("added sync write event, thread: %p", thread); + + g_queue_push_tail(cbd->handlers, wh); + lua_tcp_plan_handler_event(cbd, TRUE, TRUE); + + TCP_RETAIN(cbd); + + return lua_thread_yield(thread, 0); +} + +static gint +lua_tcp_sync_eof(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + lua_pushboolean(L, cbd->eof); + + return 1; +} + +static gint +lua_tcp_sync_shutdown(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + if (cbd == NULL) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + shutdown(cbd->fd, SHUT_WR); + + return 0; +} + +static gint +lua_tcp_starttls(lua_State *L) +{ + LUA_TRACE_POINT; + struct lua_tcp_cbdata *cbd = lua_check_tcp(L, 1); + gpointer ssl_ctx; + gboolean verify_peer; + + if (cbd == NULL || cbd->ssl_conn != NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (cbd->flags & LUA_TCP_FLAG_SSL_NOVERIFY) { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx_noverify; + verify_peer = FALSE; + } + else { + ssl_ctx = cbd->cfg->libs_ctx->ssl_ctx; + verify_peer = TRUE; + } + + cbd->ssl_conn = rspamd_ssl_connection_new(ssl_ctx, + cbd->event_loop, + verify_peer, + cbd->tag); + + if (!rspamd_ssl_connect_fd(cbd->ssl_conn, cbd->fd, cbd->hostname, &cbd->ev, + cbd->ev.timeout, lua_tcp_handler, lua_tcp_ssl_on_error, cbd)) { + lua_tcp_push_error(cbd, TRUE, "ssl connection failed: %s", + strerror(errno)); + } + + return 0; +} + +static gint +lua_tcp_sync_gc(lua_State *L) +{ + struct lua_tcp_cbdata *cbd = lua_check_sync_tcp(L, 1); + if (!cbd) { + return luaL_error(L, "invalid arguments [self is not rspamd{tcp_sync}]"); + } + + lua_tcp_maybe_free(cbd); + lua_tcp_fin(cbd); + + return 0; +} + +static gint +lua_load_tcp(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, tcp_libf); + + return 1; +} + +void luaopen_tcp(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_tcp", lua_load_tcp); + rspamd_lua_new_class(L, "rspamd{tcp}", tcp_libm); + rspamd_lua_new_class(L, "rspamd{tcp_sync}", tcp_sync_libm); + lua_pop(L, 1); +} diff --git a/src/lua/lua_tensor.c b/src/lua/lua_tensor.c new file mode 100644 index 0000000..75e6139 --- /dev/null +++ b/src/lua/lua_tensor.c @@ -0,0 +1,817 @@ +/*- + * Copyright 2020 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "lua_tensor.h" +#include "contrib/kann/kautodiff.h" +#include "blas-config.h" + +/*** + * @module rspamd_tensor + * `rspamd_tensor` is a simple Lua library to abstract matrices and vectors + * Internally, they are represented as arrays of float variables + * So far, merely 1D and 2D tensors are supported + */ + +LUA_FUNCTION_DEF(tensor, load); +LUA_FUNCTION_DEF(tensor, save); +LUA_FUNCTION_DEF(tensor, new); +LUA_FUNCTION_DEF(tensor, fromtable); +LUA_FUNCTION_DEF(tensor, destroy); +LUA_FUNCTION_DEF(tensor, mul); +LUA_FUNCTION_DEF(tensor, tostring); +LUA_FUNCTION_DEF(tensor, index); +LUA_FUNCTION_DEF(tensor, newindex); +LUA_FUNCTION_DEF(tensor, len); +LUA_FUNCTION_DEF(tensor, eigen); +LUA_FUNCTION_DEF(tensor, mean); +LUA_FUNCTION_DEF(tensor, transpose); +LUA_FUNCTION_DEF(tensor, has_blas); +LUA_FUNCTION_DEF(tensor, scatter_matrix); + +static luaL_reg rspamd_tensor_f[] = { + LUA_INTERFACE_DEF(tensor, load), + LUA_INTERFACE_DEF(tensor, new), + LUA_INTERFACE_DEF(tensor, fromtable), + LUA_INTERFACE_DEF(tensor, has_blas), + LUA_INTERFACE_DEF(tensor, scatter_matrix), + {NULL, NULL}, +}; + +static luaL_reg rspamd_tensor_m[] = { + LUA_INTERFACE_DEF(tensor, save), + {"__gc", lua_tensor_destroy}, + {"__mul", lua_tensor_mul}, + {"mul", lua_tensor_mul}, + {"tostring", lua_tensor_tostring}, + {"__tostring", lua_tensor_tostring}, + {"__index", lua_tensor_index}, + {"__newindex", lua_tensor_newindex}, + {"__len", lua_tensor_len}, + LUA_INTERFACE_DEF(tensor, eigen), + LUA_INTERFACE_DEF(tensor, mean), + LUA_INTERFACE_DEF(tensor, transpose), + {NULL, NULL}, +}; + +struct rspamd_lua_tensor * +lua_newtensor(lua_State *L, int ndims, const int *dim, bool zero_fill, bool own) +{ + struct rspamd_lua_tensor *res; + + res = lua_newuserdata(L, sizeof(struct rspamd_lua_tensor)); + memset(res, 0, sizeof(*res)); + + res->ndims = ndims; + res->size = 1; + + for (guint i = 0; i < ndims; i++) { + res->size *= dim[i]; + res->dim[i] = dim[i]; + } + + /* To avoid allocating large stuff in Lua */ + if (own) { + res->data = g_malloc(sizeof(rspamd_tensor_num_t) * res->size); + + if (zero_fill) { + memset(res->data, 0, sizeof(rspamd_tensor_num_t) * res->size); + } + } + else { + /* Mark size negative to distinguish */ + res->size = -(res->size); + } + + rspamd_lua_setclass(L, TENSOR_CLASS, -1); + + return res; +} + +/*** + * @function tensor.new(ndims, [dim1, ... dimN]) + * Creates a new zero filled tensor with the specific number of dimensions + * @return + */ +static gint +lua_tensor_new(lua_State *L) +{ + gint ndims = luaL_checkinteger(L, 1); + + if (ndims > 0 && ndims <= 2) { + gint *dims = g_alloca(sizeof(gint) * ndims); + + for (guint i = 0; i < ndims; i++) { + dims[i] = lua_tointeger(L, i + 2); + } + + (void) lua_newtensor(L, ndims, dims, true, true); + } + else { + return luaL_error(L, "incorrect dimensions number: %d", ndims); + } + + return 1; +} + +/*** + * @function tensor.fromtable(tbl) + * Creates a new zero filled tensor with the specific number of dimensions + * @return + */ +static gint +lua_tensor_fromtable(lua_State *L) +{ + if (lua_istable(L, 1)) { + lua_rawgeti(L, 1, 1); + + if (lua_isnumber(L, -1)) { + lua_pop(L, 1); + /* Input vector */ + gint dims[2]; + dims[0] = 1; + dims[1] = rspamd_lua_table_size(L, 1); + + struct rspamd_lua_tensor *res = lua_newtensor(L, 2, + dims, false, true); + + for (guint i = 0; i < dims[1]; i++) { + lua_rawgeti(L, 1, i + 1); + res->data[i] = lua_tonumber(L, -1); + lua_pop(L, 1); + } + } + else if (lua_istable(L, -1)) { + /* Input matrix */ + lua_pop(L, 1); + + /* Calculate the overall size */ + gint nrows = rspamd_lua_table_size(L, 1), ncols = 0; + gint err; + + for (gint i = 0; i < nrows; i++) { + lua_rawgeti(L, 1, i + 1); + + if (ncols == 0) { + ncols = rspamd_lua_table_size(L, -1); + + if (ncols == 0) { + lua_pop(L, 1); + err = luaL_error(L, "invalid params at pos %d: " + "bad input dimension %d", + i, + (int) ncols); + + return err; + } + } + else { + if (ncols != rspamd_lua_table_size(L, -1)) { + gint t = rspamd_lua_table_size(L, -1); + + lua_pop(L, 1); + err = luaL_error(L, "invalid params at pos %d: " + "bad input dimension %d; %d expected", + i, + t, + ncols); + + return err; + } + } + + lua_pop(L, 1); + } + + gint dims[2]; + dims[0] = nrows; + dims[1] = ncols; + + struct rspamd_lua_tensor *res = lua_newtensor(L, 2, + dims, false, true); + + for (gint i = 0; i < nrows; i++) { + lua_rawgeti(L, 1, i + 1); + + for (gint j = 0; j < ncols; j++) { + lua_rawgeti(L, -1, j + 1); + + res->data[i * ncols + j] = lua_tonumber(L, -1); + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + } + else { + lua_pop(L, 1); + return luaL_error(L, "incorrect table"); + } + } + else { + return luaL_error(L, "incorrect input"); + } + + return 1; +} + + +/*** + * @method tensor:destroy() + * Tensor destructor + * @return + */ +static gint +lua_tensor_destroy(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1); + + if (t) { + if (t->size > 0) { + g_free(t->data); + } + } + + return 0; +} + +/*** + * @method tensor:save() + * Tensor serialisation function + * @return + */ +static gint +lua_tensor_save(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1); + gint size; + + if (t) { + if (t->size > 0) { + size = t->size; + } + else { + size = -(t->size); + } + + gsize sz = sizeof(gint) * 4 + size * sizeof(rspamd_tensor_num_t); + guchar *data; + + struct rspamd_lua_text *out = lua_new_text(L, NULL, 0, TRUE); + + data = g_malloc(sz); + memcpy(data, &t->ndims, sizeof(int)); + memcpy(data + sizeof(int), &size, sizeof(int)); + memcpy(data + 2 * sizeof(int), t->dim, sizeof(int) * 2); + memcpy(data + 4 * sizeof(int), t->data, + size * sizeof(rspamd_tensor_num_t)); + + out->start = (const gchar *) data; + out->len = sz; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_tensor_tostring(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1); + + if (t) { + GString *out = g_string_sized_new(128); + + if (t->ndims == 1) { + /* Print as a vector */ + for (gint i = 0; i < t->dim[0]; i++) { + rspamd_printf_gstring(out, "%.4f ", t->data[i]); + } + /* Trim last space */ + out->len--; + } + else { + for (gint i = 0; i < t->dim[0]; i++) { + for (gint j = 0; j < t->dim[1]; j++) { + rspamd_printf_gstring(out, "%.4f ", + t->data[i * t->dim[1] + j]); + } + /* Trim last space */ + out->len--; + rspamd_printf_gstring(out, "\n"); + } + /* Trim last ; */ + out->len--; + } + + lua_pushlstring(L, out->str, out->len); + + g_string_free(out, TRUE); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_tensor_index(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1); + gint idx; + + if (t) { + if (lua_isnumber(L, 2)) { + idx = lua_tointeger(L, 2); + + if (t->ndims == 1) { + /* Individual element */ + if (idx <= t->dim[0]) { + lua_pushnumber(L, t->data[idx - 1]); + } + else { + lua_pushnil(L); + } + } + else { + /* Push row */ + gint dim = t->dim[1]; + + + if (idx <= t->dim[0]) { + /* Non-owning tensor */ + struct rspamd_lua_tensor *res = + lua_newtensor(L, 1, &dim, false, false); + res->data = &t->data[(idx - 1) * t->dim[1]]; + } + else { + lua_pushnil(L); + } + } + } + else if (lua_isstring(L, 2)) { + /* Access to methods */ + lua_getmetatable(L, 1); + lua_pushvalue(L, 2); + lua_rawget(L, -2); + } + } + + return 1; +} +static gint +lua_tensor_newindex(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1); + gint idx; + + if (t) { + if (lua_isnumber(L, 2)) { + idx = lua_tointeger(L, 2); + + if (t->ndims == 1) { + /* Individual element */ + if (idx <= t->dim[0] && idx > 0) { + rspamd_tensor_num_t value = lua_tonumber(L, 3), old; + + old = t->data[idx - 1]; + t->data[idx - 1] = value; + lua_pushnumber(L, old); + } + else { + return luaL_error(L, "invalid index: %d", idx); + } + } + else { + if (lua_isnumber(L, 3)) { + return luaL_error(L, "cannot assign number to a row"); + } + else if (lua_isuserdata(L, 3)) { + /* Tensor assignment */ + struct rspamd_lua_tensor *row = lua_check_tensor(L, 3); + + if (row) { + if (row->ndims == 1) { + if (row->dim[0] == t->dim[1]) { + if (idx > 0 && idx <= t->dim[0]) { + idx--; /* Zero based index */ + memcpy(&t->data[idx * t->dim[1]], + row->data, + t->dim[1] * sizeof(rspamd_tensor_num_t)); + + return 0; + } + else { + return luaL_error(L, "invalid index: %d", idx); + } + } + } + else { + return luaL_error(L, "cannot assign matrix to row"); + } + } + else { + return luaL_error(L, "cannot assign row, invalid tensor"); + } + } + else { + /* TODO: add table assignment */ + return luaL_error(L, "cannot assign row, not a tensor"); + } + } + } + else { + /* Access to methods? NYI */ + return luaL_error(L, "cannot assign method of a tensor"); + } + } + + return 1; +} + +/*** + * @method tensor:mul(other, [transA, [transB]]) + * Multiply two tensors (optionally transposed) and return a new tensor + * @return + */ +static gint +lua_tensor_mul(lua_State *L) +{ + struct rspamd_lua_tensor *t1 = lua_check_tensor(L, 1), + *t2 = lua_check_tensor(L, 2), *res; + int transA = 0, transB = 0; + + if (lua_isboolean(L, 3)) { + transA = lua_toboolean(L, 3); + } + + if (lua_isboolean(L, 4)) { + transB = lua_toboolean(L, 4); + } + + if (t1 && t2) { + gint dims[2], shadow_dims[2]; + dims[0] = abs(transA ? t1->dim[1] : t1->dim[0]); + shadow_dims[0] = abs(transB ? t2->dim[1] : t2->dim[0]); + dims[1] = abs(transB ? t2->dim[0] : t2->dim[1]); + shadow_dims[1] = abs(transA ? t1->dim[0] : t1->dim[1]); + + if (shadow_dims[0] != shadow_dims[1]) { + return luaL_error(L, "incompatible dimensions %d x %d * %d x %d", + dims[0], shadow_dims[1], shadow_dims[0], dims[1]); + } + else if (shadow_dims[0] == 0) { + /* Row * Column -> matrix */ + shadow_dims[0] = 1; + shadow_dims[1] = 1; + } + + if (dims[0] == 0) { + /* Column */ + dims[0] = 1; + + if (dims[1] == 0) { + /* Column * row -> number */ + dims[1] = 1; + } + res = lua_newtensor(L, 2, dims, true, true); + } + else if (dims[1] == 0) { + /* Row */ + res = lua_newtensor(L, 1, dims, true, true); + dims[1] = 1; + } + else { + res = lua_newtensor(L, 2, dims, true, true); + } + + kad_sgemm_simple(transA, transB, dims[0], dims[1], shadow_dims[0], + t1->data, t2->data, res->data); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @function tensor.load(rspamd_text) + * Deserialize tensor + * @return + */ +static gint +lua_tensor_load(lua_State *L) +{ + const guchar *data; + gsize sz; + + if (lua_type(L, 1) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (!t) { + return luaL_error(L, "invalid argument"); + } + + data = (const guchar *) t->start; + sz = t->len; + } + else { + data = (const guchar *) lua_tolstring(L, 1, &sz); + } + + if (sz >= sizeof(gint) * 4) { + int ndims, nelts, dims[2]; + + memcpy(&ndims, data, sizeof(int)); + memcpy(&nelts, data + sizeof(int), sizeof(int)); + memcpy(dims, data + sizeof(int) * 2, sizeof(int) * 2); + + if (sz == nelts * sizeof(rspamd_tensor_num_t) + sizeof(int) * 4) { + if (ndims == 1) { + if (nelts == dims[0]) { + struct rspamd_lua_tensor *t = lua_newtensor(L, ndims, dims, false, true); + memcpy(t->data, data + sizeof(int) * 4, nelts * sizeof(rspamd_tensor_num_t)); + } + else { + return luaL_error(L, "invalid argument: bad dims: %d x %d != %d", + dims[0], 1, nelts); + } + } + else if (ndims == 2) { + if (nelts == dims[0] * dims[1]) { + struct rspamd_lua_tensor *t = lua_newtensor(L, ndims, dims, false, true); + memcpy(t->data, data + sizeof(int) * 4, nelts * sizeof(rspamd_tensor_num_t)); + } + else { + return luaL_error(L, "invalid argument: bad dims: %d x %d != %d", + dims[0], dims[1], nelts); + } + } + else { + return luaL_error(L, "invalid argument: bad ndims: %d", ndims); + } + } + else { + return luaL_error(L, "invalid size: %d, %d required, %d elts", (int) sz, + (int) (nelts * sizeof(rspamd_tensor_num_t) + sizeof(int) * 4), + nelts); + } + } + else { + return luaL_error(L, "invalid arguments; sz = %d", (int) sz); + } + + return 1; +} + +static gint +lua_tensor_len(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1); + gint nret = 1; + + if (t) { + /* Return the main dimension first */ + if (t->ndims == 1) { + lua_pushinteger(L, t->dim[0]); + } + else { + lua_pushinteger(L, t->dim[0]); + lua_pushinteger(L, t->dim[1]); + nret = 2; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return nret; +} + +static gint +lua_tensor_eigen(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1), *eigen; + + if (t) { + if (t->ndims != 2 || t->dim[0] != t->dim[1]) { + return luaL_error(L, "expected square matrix NxN but got %dx%d", + t->dim[0], t->dim[1]); + } + + eigen = lua_newtensor(L, 1, &t->dim[0], true, true); + + if (!kad_ssyev_simple(t->dim[0], t->data, eigen->data)) { + lua_pop(L, 1); + return luaL_error(L, "kad_ssyev_simple failed (no blas?)"); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static inline rspamd_tensor_num_t +mean_vec(rspamd_tensor_num_t *x, gsize n) +{ + float sum = rspamd_sum_floats(x, &n); + return sum / (rspamd_tensor_num_t) n; +} + +static gint +lua_tensor_mean(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1); + + if (t) { + if (t->ndims == 1) { + /* Mean of all elements in a vector */ + lua_pushnumber(L, mean_vec(t->data, t->dim[0])); + } + else { + /* Row-wise mean vector output */ + struct rspamd_lua_tensor *res; + + res = lua_newtensor(L, 1, &t->dim[0], false, true); + + for (int i = 0; i < t->dim[0]; i++) { + res->data[i] = mean_vec(&t->data[i * t->dim[1]], t->dim[1]); + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_tensor_transpose(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1), *res; + int dims[2]; + + if (t) { + if (t->ndims == 1) { + /* Row to column */ + dims[0] = 1; + dims[1] = t->dim[0]; + res = lua_newtensor(L, 2, dims, false, true); + memcpy(res->data, t->data, t->dim[0] * sizeof(rspamd_tensor_num_t)); + } + else { + /* Cache friendly algorithm */ + struct rspamd_lua_tensor *res; + + dims[0] = t->dim[1]; + dims[1] = t->dim[0]; + res = lua_newtensor(L, 2, dims, false, true); + + static const int block = 32; + + for (int i = 0; i < t->dim[0]; i += block) { + for (int j = 0; j < t->dim[1]; ++j) { + for (int boff = 0; boff < block && i + boff < t->dim[0]; ++boff) { + res->data[j * t->dim[0] + i + boff] = + t->data[(i + boff) * t->dim[1] + j]; + } + } + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_tensor_has_blas(lua_State *L) +{ +#ifdef HAVE_CBLAS + lua_pushboolean(L, true); +#else + lua_pushboolean(L, false); +#endif + + return 1; +} + +static gint +lua_tensor_scatter_matrix(lua_State *L) +{ + struct rspamd_lua_tensor *t = lua_check_tensor(L, 1), *res; + int dims[2]; + + if (t) { + if (t->ndims != 2) { + return luaL_error(L, "matrix required"); + } + + /* X * X square matrix */ + dims[0] = t->dim[1]; + dims[1] = t->dim[1]; + res = lua_newtensor(L, 2, dims, true, true); + + /* Auxiliary vars */ + rspamd_tensor_num_t *means, /* means vector */ + *tmp_row, /* temp row for Kahan's algorithm */ + *tmp_square /* temp matrix for multiplications */; + means = g_malloc0(sizeof(rspamd_tensor_num_t) * t->dim[1]); + tmp_row = g_malloc0(sizeof(rspamd_tensor_num_t) * t->dim[1]); + tmp_square = g_malloc(sizeof(rspamd_tensor_num_t) * t->dim[1] * t->dim[1]); + + /* + * Column based means + * means will have s, tmp_row will have c + */ + for (int i = 0; i < t->dim[0]; i++) { + /* Cycle by rows */ + for (int j = 0; j < t->dim[1]; j++) { + rspamd_tensor_num_t v = t->data[i * t->dim[1] + j]; + rspamd_tensor_num_t y = v - tmp_row[j]; + rspamd_tensor_num_t st = means[j] + y; + tmp_row[j] = (st - means[j]) - y; + means[j] = st; + } + } + + for (int j = 0; j < t->dim[1]; j++) { + means[j] /= t->dim[0]; + } + + for (int i = 0; i < t->dim[0]; i++) { + /* Update for each sample */ + for (int j = 0; j < t->dim[1]; j++) { + tmp_row[j] = t->data[i * t->dim[1] + j] - means[j]; + } + + memset(tmp_square, 0, t->dim[1] * t->dim[1] * sizeof(rspamd_tensor_num_t)); + kad_sgemm_simple(1, 0, t->dim[1], t->dim[1], 1, + tmp_row, tmp_row, tmp_square); + + for (int j = 0; j < t->dim[1]; j++) { + kad_saxpy(t->dim[1], 1.0, &tmp_square[j * t->dim[1]], + &res->data[j * t->dim[1]]); + } + } + + g_free(tmp_row); + g_free(means); + g_free(tmp_square); + } + else { + return luaL_error(L, "tensor required"); + } + + return 1; +} + +static gint +lua_load_tensor(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, rspamd_tensor_f); + + return 1; +} + + +void luaopen_tensor(lua_State *L) +{ + /* Metatables */ + rspamd_lua_new_class(L, TENSOR_CLASS, rspamd_tensor_m); + lua_pop(L, 1); /* No need in metatable... */ + rspamd_lua_add_preload(L, "rspamd_tensor", lua_load_tensor); + lua_settop(L, 0); +} + +struct rspamd_lua_tensor * +lua_check_tensor(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, TENSOR_CLASS); + luaL_argcheck(L, ud != NULL, pos, "'tensor' expected"); + return ud ? ((struct rspamd_lua_tensor *) ud) : NULL; +} diff --git a/src/lua/lua_tensor.h b/src/lua/lua_tensor.h new file mode 100644 index 0000000..2103868 --- /dev/null +++ b/src/lua/lua_tensor.h @@ -0,0 +1,34 @@ +/*- + * Copyright 2020 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RSPAMD_LUA_TENSOR_H +#define RSPAMD_LUA_TENSOR_H + +#define TENSOR_CLASS "rspamd{tensor}" + +typedef float rspamd_tensor_num_t; + +struct rspamd_lua_tensor { + int ndims; + int size; /* overall size (product of dims) */ + int dim[2]; + rspamd_tensor_num_t *data; +}; + +struct rspamd_lua_tensor *lua_check_tensor(lua_State *L, int pos); +struct rspamd_lua_tensor *lua_newtensor(lua_State *L, int ndims, + const int *dim, bool zero_fill, bool own); + +#endif diff --git a/src/lua/lua_text.c b/src/lua/lua_text.c new file mode 100644 index 0000000..26a5c08 --- /dev/null +++ b/src/lua/lua_text.c @@ -0,0 +1,1789 @@ +/*- + * Copyright 2019 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "libcryptobox/cryptobox.h" +#include "contrib/fastutf8/fastutf8.h" +#include "unix-std.h" + +/*** + * @module rspamd_text + * This module provides access to opaque text structures used widely to prevent + * copying between Lua and C for various concerns: performance, security etc... + * + * You can convert rspamd_text into string but it will copy data. + */ + +/*** + * @function rspamd_text.fromstring(str) + * Creates rspamd_text from Lua string (copied to the text) + * @param {string} str string to use + * @return {rspamd_text} resulting text + */ +LUA_FUNCTION_DEF(text, fromstring); + +/*** + * @function rspamd_text.null() + * Creates rspamd_text with NULL pointer for testing purposes + * @param {string} str string to use + * @return {rspamd_text} resulting text + */ +LUA_FUNCTION_DEF(text, null); +/*** + * @function rspamd_text.randombytes(nbytes) + * Creates rspamd_text with random bytes inside (raw bytes) + * @param {number} nbytes number of random bytes generated + * @return {rspamd_text} random bytes text + */ +LUA_FUNCTION_DEF(text, randombytes); + +/*** + * @function rspamd_text.fromtable(tbl[, delim]) + * Same as `table.concat` but generates rspamd_text instead of the Lua string + * @param {table} tbl table to use + * @param {string} delim optional delimiter + * @return {rspamd_text} resulting text + */ +LUA_FUNCTION_DEF(text, fromtable); +/*** + * @method rspamd_text:byte(pos[, pos2]) + * Returns a byte at the position `pos` or bytes from `pos` to `pos2` if specified + * @param {integer} pos index + * @param {integer} pos2 index + * @return {integer} byte at the position `pos` or varargs of bytes + */ +LUA_FUNCTION_DEF(text, byte); +/*** + * @method rspamd_text:len() + * Returns length of a string + * @return {number} length of string in **bytes** + */ +LUA_FUNCTION_DEF(text, len); +/*** + * @method rspamd_text:str() + * Converts text to string by copying its content + * @return {string} copy of text as Lua string + */ +LUA_FUNCTION_DEF(text, str); +/*** + * @method rspamd_text:ptr() + * Converts text to lightuserdata + * @return {lightuserdata} pointer value of rspamd_text + */ +LUA_FUNCTION_DEF(text, ptr); +/*** + * @method rspamd_text:save_in_file(fname[, mode]) + * Saves text in file + * @return {boolean} true if save has been completed + */ +LUA_FUNCTION_DEF(text, save_in_file); +/*** + * @method rspamd_text:span(start[, len]) + * Returns a span for lua_text starting at pos [start] (1 indexed) and with + * length `len` (or to the end of the text) + * @param {integer} start start index + * @param {integer} len length of span + * @return {rspamd_text} new rspamd_text with span (must be careful when using with owned texts...) + */ +LUA_FUNCTION_DEF(text, span); +/*** + * @method rspamd_text:sub(start[, len]) + * Returns a substring for lua_text similar to string.sub from Lua + * @return {rspamd_text} new rspamd_text with span (must be careful when using with owned texts...) + */ +LUA_FUNCTION_DEF(text, sub); +/*** + * @method rspamd_text:lines([stringify]) + * Returns an iter over all lines as rspamd_text objects or as strings if `stringify` is true + * @param {boolean} stringify stringify lines + * @return {iterator} iterator triplet + */ +LUA_FUNCTION_DEF(text, lines); +/*** + * @method rspamd_text:split(regexp, [stringify]) + * Returns an iter over all encounters of the specific regexp as rspamd_text objects or as strings if `stringify` is true + * @param {rspamd_regexp} regexp regexp (pcre syntax) used for splitting + * @param {boolean} stringify stringify lines + * @return {iterator} iterator triplet + */ +LUA_FUNCTION_DEF(text, split); +/*** + * @method rspamd_text:at(pos) + * Returns a byte at the position `pos` + * @param {integer} pos index + * @return {integer} byte at the position `pos` or nil if pos out of bound + */ +LUA_FUNCTION_DEF(text, at); +/*** + * @method rspamd_text:memchr(chr, [reverse]) + * Returns the first or the last position of the character `chr` in the text or + * -1 in case if a character has not been found. Indexes start from `1` + * @param {string/number} chr character or a character code to find + * @param {boolean} reverse last character if `true` + * @return {integer} position of the character or `-1` + */ +LUA_FUNCTION_DEF(text, memchr); +/*** + * @method rspamd_text:bytes() + * Converts text to an array of bytes + * @return {table|integer} bytes in the array (as unsigned char) + */ +LUA_FUNCTION_DEF(text, bytes); +/*** + * @method rspamd_text:lower([is_utf, [inplace]]) + * Return a new text with lowercased characters, if is_utf is true then Rspamd applies utf8 lowercase + * @param {boolean} is_utf apply utf8 lowercase + * @param {boolean} inplace lowercase the original text + * @return {rspamd_text} new rspamd_text (or the original text if inplace) with lowercased letters + */ +LUA_FUNCTION_DEF(text, lower); +LUA_FUNCTION_DEF(text, take_ownership); +/*** + * @method rspamd_text:exclude_chars(set_to_exclude, [always_copy]) + * Returns a text (if owned, then the original text is modified, if not, then it is copied and owned) + * where all chars from `set_to_exclude` are removed + * Patterns supported: + * + * - %s - all space characters + * - %n - all newline characters + * - %c - all control characters (it includes 8bit characters and spaces) + * - %8 - all 8 bit characters + * - %% - just a percent character + * + * @param {string} set_to_exclude characters to exclude + * @param {boolean} always_copy always copy the source text + * @return {rspamd_text} modified or copied text + */ +LUA_FUNCTION_DEF(text, exclude_chars); +/*** + * @method rspamd_text:oneline([always_copy]) + * Returns a text (if owned, then the original text is modified, if not, then it is copied and owned) + * where the following transformations are made: + * - All spaces sequences are replaced with a single space + * - All newlines sequences are replaced with a single space + * - Trailing and leading spaces are removed + * - Control characters are excluded + * - UTF8 sequences are normalised + * + * @param {boolean} always_copy always copy the source text + * @return {rspamd_text} modified or copied text + */ +LUA_FUNCTION_DEF(text, oneline); +/*** + * @method rspamd_text:base32([b32type]) + * Returns a text encoded in base32 (new rspamd_text is allocated) + * + * @param {string} b32type base32 type (default, bleach, rfc) + * @return {rspamd_text} new text encoded in base32 + */ +LUA_FUNCTION_DEF(text, base32); +/*** + * @method rspamd_text:base64([line_length, [nline, [fold]]]) + * Returns a text encoded in base64 (new rspamd_text is allocated) + * + * @param {number} line_length return text split with newlines up to this attribute + * @param {string} nline newline type: `cr`, `lf`, `crlf` + * @param {boolean} fold use folding when splitting into lines (false by default) + * @return {rspamd_text} new text encoded in base64 + */ +LUA_FUNCTION_DEF(text, base64); +/*** + * @method rspamd_text:hex() + * Returns a text encoded in hex (new rspamd_text is allocated) + * + * @return {rspamd_text} new text encoded in hex + */ +LUA_FUNCTION_DEF(text, hex); +/*** + * @method rspamd_text:find(pattern [, init]) + * Looks for the first match of pattern in the string s. + * If it finds a match, then find returns the indices of s where this occurrence + * starts and ends; otherwise, it returns nil. A third, + * optional numerical argument init specifies where to start the search; + * its default value is 1 and can be negative. + * This method currently supports merely a plain search, no patterns. + * + * @param {string} pattern pattern to find + * @param {number} init specifies where to start the search (1 default) + * @return {number,number/nil} If it finds a match, then find returns the indices of s where this occurrence starts and ends; otherwise, it returns nil + */ +LUA_FUNCTION_DEF(text, find); +LUA_FUNCTION_DEF(text, gc); +LUA_FUNCTION_DEF(text, eq); +LUA_FUNCTION_DEF(text, lt); +LUA_FUNCTION_DEF(text, concat); +LUA_FUNCTION_DEF(text, strtoul); + +static const struct luaL_reg textlib_f[] = { + LUA_INTERFACE_DEF(text, fromstring), + {"from_string", lua_text_fromstring}, + LUA_INTERFACE_DEF(text, fromtable), + {"from_table", lua_text_fromtable}, + LUA_INTERFACE_DEF(text, null), + LUA_INTERFACE_DEF(text, randombytes), + {NULL, NULL}}; + +static const struct luaL_reg textlib_m[] = { + LUA_INTERFACE_DEF(text, len), + LUA_INTERFACE_DEF(text, str), + LUA_INTERFACE_DEF(text, ptr), + LUA_INTERFACE_DEF(text, take_ownership), + LUA_INTERFACE_DEF(text, save_in_file), + LUA_INTERFACE_DEF(text, span), + LUA_INTERFACE_DEF(text, sub), + LUA_INTERFACE_DEF(text, lines), + LUA_INTERFACE_DEF(text, split), + LUA_INTERFACE_DEF(text, at), + LUA_INTERFACE_DEF(text, memchr), + LUA_INTERFACE_DEF(text, byte), + LUA_INTERFACE_DEF(text, bytes), + LUA_INTERFACE_DEF(text, lower), + LUA_INTERFACE_DEF(text, exclude_chars), + LUA_INTERFACE_DEF(text, oneline), + LUA_INTERFACE_DEF(text, base32), + LUA_INTERFACE_DEF(text, base64), + LUA_INTERFACE_DEF(text, hex), + LUA_INTERFACE_DEF(text, find), + LUA_INTERFACE_DEF(text, strtoul), + {"write", lua_text_save_in_file}, + {"__len", lua_text_len}, + {"__tostring", lua_text_str}, + {"__gc", lua_text_gc}, + {"__eq", lua_text_eq}, + {"__lt", lua_text_lt}, + {"__concat", lua_text_concat}, + {NULL, NULL}}; + +struct rspamd_lua_text * +lua_check_text(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{text}"); + luaL_argcheck(L, ud != NULL, pos, "'text' expected"); + return ud ? (struct rspamd_lua_text *) ud : NULL; +} + +struct rspamd_lua_text * +lua_check_text_or_string(lua_State *L, gint pos) +{ + gint pos_type = lua_type(L, pos); + + if (pos_type == LUA_TUSERDATA) { + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{text}"); + luaL_argcheck(L, ud != NULL, pos, "'text' expected"); + return ud ? (struct rspamd_lua_text *) ud : NULL; + } + else if (pos_type == LUA_TSTRING) { + /* + * Fake static lua_text, we allow to use this function multiple times + * by having a small array of static structures. + */ + static unsigned cur_txt_idx = 0; + static struct rspamd_lua_text fake_text[4]; + gsize len; + int sel_idx; + + sel_idx = cur_txt_idx++ % G_N_ELEMENTS(fake_text); + fake_text[sel_idx].start = lua_tolstring(L, pos, &len); + + if (len >= G_MAXUINT) { + return NULL; + } + + fake_text[sel_idx].len = len; + fake_text[sel_idx].flags = RSPAMD_TEXT_FLAG_FAKE; + + return &fake_text[sel_idx]; + } + + return NULL; +} + +struct rspamd_lua_text * +lua_new_text(lua_State *L, const gchar *start, gsize len, gboolean own) +{ + struct rspamd_lua_text *t; + + t = lua_newuserdata(L, sizeof(*t)); + t->flags = 0; + + if (own) { + gchar *storage; + + if (len > 0) { + storage = g_malloc(len); + + if (start != NULL) { + memcpy(storage, start, len); + } + + t->start = storage; + t->flags = RSPAMD_TEXT_FLAG_OWN; + } + else { + t->start = ""; + } + } + else { + t->start = start; + } + + t->len = len; + rspamd_lua_setclass(L, "rspamd{text}", -1); + + return t; +} + +struct rspamd_lua_text * +lua_new_text_task(lua_State *L, struct rspamd_task *task, + const gchar *start, gsize len, gboolean own) +{ + struct rspamd_lua_text *t; + + t = lua_newuserdata(L, sizeof(*t)); + t->flags = 0; + + if (own) { + gchar *storage; + + if (len > 0) { + storage = rspamd_mempool_alloc(task->task_pool, len); + + if (start != NULL) { + memcpy(storage, start, len); + } + + t->start = storage; + } + else { + t->start = ""; + } + } + else { + t->start = start; + } + + t->len = len; + rspamd_lua_setclass(L, "rspamd{text}", -1); + + return t; +} + +bool lua_is_text_binary(struct rspamd_lua_text *t) +{ + if (t == NULL || t->len == 0) { + return false; + } + + if (rspamd_str_has_8bit(t->start, t->len)) { + if (rspamd_fast_utf8_validate(t->start, t->len) == 0) { + return false; + } + return true; + } + + return false; +} + + +static gint +lua_text_fromstring(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *str; + gsize l = 0; + gboolean transparent = FALSE; + + str = luaL_checklstring(L, 1, &l); + + if (str) { + if (lua_isboolean(L, 2)) { + transparent = lua_toboolean(L, 2); + } + + lua_new_text(L, str, l, !transparent); + } + else { + return luaL_error(L, "invalid arguments"); + } + + + return 1; +} + +static gint +lua_text_null(lua_State *L) +{ + LUA_TRACE_POINT; + + lua_new_text(L, NULL, 0, false); + + return 1; +} + +static gint +lua_text_randombytes(lua_State *L) +{ + LUA_TRACE_POINT; + guint nbytes = luaL_checkinteger(L, 1); + struct rspamd_lua_text *out; + + out = lua_new_text(L, NULL, nbytes, TRUE); + randombytes_buf((char *) out->start, nbytes); + out->len = nbytes; + + return 1; +} + +#define MAX_REC 10 + +static void +lua_text_tbl_length(lua_State *L, gsize dlen, gsize *dest, guint rec) +{ + gsize tblen, stlen; + struct rspamd_lua_text *elt; + + if (rec > MAX_REC) { + luaL_error(L, "lua_text_tbl_length: recursion limit exceeded"); + + return; + } + + tblen = rspamd_lua_table_size(L, -1); + + for (gsize i = 0; i < tblen; i++) { + lua_rawgeti(L, -1, i + 1); + + if (lua_type(L, -1) == LUA_TSTRING) { +#if LUA_VERSION_NUM >= 502 + stlen = lua_rawlen(L, -1); +#else + stlen = lua_objlen(L, -1); +#endif + (*dest) += stlen; + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + elt = (struct rspamd_lua_text *) lua_touserdata(L, -1); + + if (elt) { + (*dest) += elt->len; + } + } + else if (lua_type(L, -1) == LUA_TTABLE) { + lua_text_tbl_length(L, dlen, dest, rec + 1); + } + + if (i != tblen - 1) { + (*dest) += dlen; + } + + lua_pop(L, 1); + } +} + +static void +lua_text_tbl_append(lua_State *L, + const gchar *delim, + gsize dlen, + gchar **dest, + guint rec) +{ + const gchar *st; + gsize tblen, stlen; + struct rspamd_lua_text *elt; + + if (rec > MAX_REC) { + luaL_error(L, "lua_text_tbl_length: recursion limit exceeded"); + + return; + } + + tblen = rspamd_lua_table_size(L, -1); + + for (guint i = 0; i < tblen; i++) { + lua_rawgeti(L, -1, i + 1); + + if (lua_type(L, -1) == LUA_TSTRING) { + st = lua_tolstring(L, -1, &stlen); + memcpy((*dest), st, stlen); + (*dest) += stlen; + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + elt = (struct rspamd_lua_text *) lua_touserdata(L, -1); + + if (elt) { + memcpy((*dest), elt->start, elt->len); + (*dest) += elt->len; + } + } + else if (lua_type(L, -1) == LUA_TTABLE) { + lua_text_tbl_append(L, delim, dlen, dest, rec + 1); + } + + if (dlen && i != tblen - 1) { + memcpy((*dest), delim, dlen); + (*dest) += dlen; + } + + lua_pop(L, 1); + } +} + +static gint +lua_text_fromtable(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *delim = ""; + struct rspamd_lua_text *t; + gsize textlen = 0, dlen, oldtop = lua_gettop(L); + gchar *dest; + + if (!lua_istable(L, 1)) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TSTRING) { + delim = lua_tolstring(L, 2, &dlen); + } + else { + dlen = 0; + } + + /* Calculate length needed */ + lua_pushvalue(L, 1); + lua_text_tbl_length(L, dlen, &textlen, 0); + lua_pop(L, 1); + + /* Allocate new text */ + t = lua_newuserdata(L, sizeof(*t)); + dest = g_malloc(textlen); + t->start = dest; + t->len = textlen; + t->flags = RSPAMD_TEXT_FLAG_OWN; + rspamd_lua_setclass(L, "rspamd{text}", -1); + + lua_pushvalue(L, 1); + lua_text_tbl_append(L, delim, dlen, &dest, 0); + lua_pop(L, 1); /* Table arg */ + + gint newtop = lua_gettop(L); + g_assert(newtop == oldtop + 1); + + return 1; +} + +static gint +lua_text_len(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + gsize l = 0; + + if (t != NULL) { + l = t->len; + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushinteger(L, l); + + return 1; +} + +static gint +lua_text_str(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (t != NULL) { + lua_pushlstring(L, t->start, t->len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_ptr(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (t != NULL) { + lua_pushlightuserdata(L, (gpointer) t->start); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_take_ownership(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + gchar *dest; + + if (t != NULL) { + if (t->flags & RSPAMD_TEXT_FLAG_OWN) { + /* We already own it */ + lua_pushboolean(L, true); + } + else { + dest = g_malloc(t->len); + memcpy(dest, t->start, t->len); + t->start = dest; + t->flags |= RSPAMD_TEXT_FLAG_OWN; + lua_pushboolean(L, true); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_span(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + gint64 start = lua_tointeger(L, 2), len = -1; + + if (t && start >= 1 && start <= t->len) { + if (lua_isnumber(L, 3)) { + len = lua_tonumber(L, 3); + } + + if (len == -1) { + len = t->len - (start - 1); + } + + if (len < 0 || (len > (t->len - (start - 1)))) { + return luaL_error(L, "invalid length"); + } + + lua_new_text(L, t->start + (start - 1), len, FALSE); + } + else { + if (!t) { + return luaL_error(L, "invalid arguments, text required"); + } + else { + return luaL_error(L, "invalid arguments: start offset %d " + "is larger than text len %d", + (int) start, (int) t->len); + } + } + + return 1; +} + +/* Helpers to behave exactly as Lua does */ +static inline gsize +relative_pos_start(gint pos, gsize len) +{ + if (pos > 0) { + return pos; + } + else if (pos == 0) { + return 1; + } + else if (pos < -((gint) len)) { + return 1; + } + + /* Negative pos inside str */ + return len + ((gsize) pos) + 1; +} + +static inline gsize +relative_pos_end(gint pos, gsize len) +{ + if (pos > (gint) len) { + return len; + } + else if (pos >= 0) { + return (size_t) pos; + } + else if (pos < -((gint) len)) { + return 0; + } + + return len + ((gsize) pos) + 1; +} + +static gint +lua_text_sub(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (t) { + size_t start = relative_pos_start(luaL_checkinteger(L, 2), + t->len); + size_t end = relative_pos_end(luaL_optinteger(L, 3, -1), + t->len); + + + if (start <= end) { + lua_new_text(L, t->start + (start - 1), + (end - start) + 1, FALSE); + } + else { + lua_new_text(L, "", 0, TRUE); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint64 +rspamd_lua_text_push_line(lua_State *L, + struct rspamd_lua_text *t, + gint64 start_offset, + const gchar *sep_pos, + gboolean stringify) +{ + const gchar *start; + gsize len; + gint64 ret; + + start = t->start + start_offset; + len = sep_pos ? (sep_pos - start) : (t->len - start_offset); + ret = start_offset + len; + + /* Trim line */ + while (len > 0) { + if (start[len - 1] == '\r' || start[len - 1] == '\n') { + len--; + } + else { + break; + } + } + + if (stringify) { + lua_pushlstring(L, start, len); + } + else { + struct rspamd_lua_text *ntext; + + ntext = lua_newuserdata(L, sizeof(*ntext)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + ntext->start = start; + ntext->len = len; + ntext->flags = 0; /* Not own as it must be owned by a top object */ + } + + return ret; +} + +static gint +rspamd_lua_text_readline(lua_State *L) +{ + struct rspamd_lua_text *t = lua_touserdata(L, lua_upvalueindex(1)); + gboolean stringify = lua_toboolean(L, lua_upvalueindex(2)); + gint64 pos = lua_tointeger(L, lua_upvalueindex(3)); + + if (pos < 0) { + return luaL_error(L, "invalid pos: %d", (gint) pos); + } + + if (pos >= t->len) { + /* We are done */ + return 0; + } + + const gchar *sep_pos; + + /* We look just for `\n` ignoring `\r` as it is very rare nowadays */ + sep_pos = memchr(t->start + pos, '\n', t->len - pos); + + if (sep_pos == NULL) { + /* Either last `\n` or `\r` separated text */ + sep_pos = memchr(t->start + pos, '\r', t->len - pos); + } + + pos = rspamd_lua_text_push_line(L, t, pos, sep_pos, stringify); + + /* Skip separators */ + while (pos < t->len) { + if (t->start[pos] == '\n' || t->start[pos] == '\r') { + pos++; + } + else { + break; + } + } + + /* Update pos */ + lua_pushinteger(L, pos); + lua_replace(L, lua_upvalueindex(3)); + + return 1; +} + +static gint +lua_text_lines(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + gboolean stringify = FALSE; + + if (t) { + if (lua_isboolean(L, 2)) { + stringify = lua_toboolean(L, 2); + } + + lua_pushvalue(L, 1); + lua_pushboolean(L, stringify); + lua_pushinteger(L, 0); /* Current pos */ + lua_pushcclosure(L, rspamd_lua_text_readline, 3); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +rspamd_lua_text_regexp_split(lua_State *L) +{ + struct rspamd_lua_text *t = lua_touserdata(L, lua_upvalueindex(1)), + *new_t; + struct rspamd_lua_regexp *re = *(struct rspamd_lua_regexp **) + lua_touserdata(L, lua_upvalueindex(2)); + gboolean stringify = lua_toboolean(L, lua_upvalueindex(3)); + gint64 pos = lua_tointeger(L, lua_upvalueindex(4)); + gboolean matched; + + if (pos < 0) { + return luaL_error(L, "invalid pos: %d", (gint) pos); + } + + if (pos >= t->len) { + /* We are done */ + return 0; + } + + const gchar *start, *end, *old_start; + + end = t->start + pos; + + for (;;) { + old_start = end; + + matched = rspamd_regexp_search(re->re, t->start, t->len, &start, &end, FALSE, + NULL); + + if (matched) { + if (start - old_start > 0) { + if (stringify) { + lua_pushlstring(L, old_start, start - old_start); + } + else { + new_t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + new_t->start = old_start; + new_t->len = start - old_start; + new_t->flags = 0; + } + + break; + } + else { + if (start == end) { + matched = FALSE; + break; + } + /* + * All match separators (e.g. starting separator, + * we need to skip it). Continue iterations. + */ + } + } + else { + /* No match, stop */ + break; + } + } + + if (!matched && (t->len > 0 && (end == NULL || end < t->start + t->len))) { + /* No more matches, but we might need to push the last element */ + if (end == NULL) { + end = t->start; + } + /* No separators, need to push the whole remaining part */ + if (stringify) { + lua_pushlstring(L, end, (t->start + t->len) - end); + } + else { + new_t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + new_t->start = end; + new_t->len = (t->start + t->len) - end; + new_t->flags = 0; + } + + pos = t->len; + } + else { + + pos = end - t->start; + } + + /* Update pos */ + lua_pushinteger(L, pos); + lua_replace(L, lua_upvalueindex(4)); + + return 1; +} + +static gint +lua_text_split(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + struct rspamd_lua_regexp *re; + gboolean stringify = FALSE, own_re = FALSE; + + if (t == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TUSERDATA) { + re = lua_check_regexp(L, 2); + } + else { + rspamd_regexp_t *c_re; + GError *err = NULL; + + c_re = rspamd_regexp_new(lua_tostring(L, 2), NULL, &err); + if (c_re == NULL) { + + gint ret = luaL_error(L, "cannot parse regexp: %s, error: %s", + lua_tostring(L, 2), + err == NULL ? "undefined" : err->message); + if (err) { + g_error_free(err); + } + + return ret; + } + + re = g_malloc0(sizeof(struct rspamd_lua_regexp)); + re->re = c_re; + re->re_pattern = g_strdup(lua_tostring(L, 2)); + re->module = rspamd_lua_get_module_name(L); + own_re = TRUE; + } + + if (re) { + if (lua_isboolean(L, 3)) { + stringify = lua_toboolean(L, 3); + } + + /* Upvalues */ + lua_pushvalue(L, 1); /* text */ + + if (own_re) { + struct rspamd_lua_regexp **pre; + pre = lua_newuserdata(L, sizeof(struct rspamd_lua_regexp *)); + rspamd_lua_setclass(L, "rspamd{regexp}", -1); + *pre = re; + } + else { + lua_pushvalue(L, 2); /* regexp */ + } + + lua_pushboolean(L, stringify); + lua_pushinteger(L, 0); /* Current pos */ + lua_pushcclosure(L, rspamd_lua_text_regexp_split, 4); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +static gint +lua_text_at(lua_State *L) +{ + return lua_text_byte(L); +} + +static gint +lua_text_byte(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + gsize start = relative_pos_start(luaL_optinteger(L, 2, 1), t->len); + gsize end = relative_pos_end(luaL_optinteger(L, 3, start), t->len); + start--; + + if (start >= end) { + return 0; + } + + for (gsize i = start; i < end; i++) { + lua_pushinteger(L, t->start[i]); + } + return end - start; +} + +static gint +lua_text_memchr(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + int c; + bool reverse = false; + + if (lua_isnumber(L, 2)) { + c = lua_tonumber(L, 2); + } + else { + gsize l; + const gchar *str = lua_tolstring(L, 2, &l); + + if (str) { + c = str[0]; + + if (l != 1) { + return luaL_error(L, "need exactly one character to search"); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + } + + if (t) { + void *f; + + if (lua_isboolean(L, 3)) { + reverse = lua_toboolean(L, 3); + } + + if (reverse) { + f = rspamd_memrchr(t->start, c, t->len); + } + else { + f = memchr(t->start, c, t->len); + } + + if (f) { + lua_pushinteger(L, ((const char *) f) - t->start + 1); + } + else { + lua_pushinteger(L, -1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_bytes(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (t) { + lua_createtable(L, t->len, 0); + + for (gsize i = 0; i < t->len; i++) { + lua_pushinteger(L, (guchar) t->start[i]); + lua_rawseti(L, -2, i + 1); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_save_in_file(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + const gchar *fname = NULL; + guint mode = 00644; + gint fd = -1; + gboolean need_close = FALSE; + + if (t != NULL) { + if (lua_type(L, 2) == LUA_TSTRING) { + fname = luaL_checkstring(L, 2); + + if (lua_type(L, 3) == LUA_TNUMBER) { + mode = lua_tointeger(L, 3); + } + } + else if (lua_type(L, 2) == LUA_TNUMBER) { + /* Created fd */ + fd = lua_tointeger(L, 2); + } + + if (fd == -1) { + if (fname) { + fd = rspamd_file_xopen(fname, O_CREAT | O_WRONLY | O_EXCL, mode, 0); + + if (fd == -1) { + lua_pushboolean(L, false); + lua_pushstring(L, strerror(errno)); + + return 2; + } + need_close = TRUE; + } + else { + fd = STDOUT_FILENO; + } + } + + if (write(fd, t->start, t->len) == -1) { + if (fd != STDOUT_FILENO) { + close(fd); + } + + lua_pushboolean(L, false); + lua_pushstring(L, strerror(errno)); + + return 2; + } + + if (need_close) { + close(fd); + } + + lua_pushboolean(L, true); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_gc(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (t != NULL) { + g_assert(!(t->flags & RSPAMD_TEXT_FLAG_FAKE)); + + if (t->flags & RSPAMD_TEXT_FLAG_OWN) { + if (t->flags & RSPAMD_TEXT_FLAG_WIPE) { + rspamd_explicit_memzero((guchar *) t->start, t->len); + } + + if (t->flags & RSPAMD_TEXT_FLAG_MMAPED) { + munmap((gpointer) t->start, t->len); + } + else { + if (t->flags & RSPAMD_TEXT_FLAG_SYSMALLOC) { + free((gpointer) t->start); + } + else { + g_free((gpointer) t->start); + } + } + } + } + + return 0; +} + +static gint +lua_text_eq(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t1 = lua_check_text_or_string(L, 1), + *t2 = lua_check_text_or_string(L, 2); + + if (t1->len == t2->len) { + lua_pushboolean(L, memcmp(t1->start, t2->start, t1->len) == 0); + } + else { + lua_pushboolean(L, false); + } + + return 1; +} + +static gint +lua_text_lt(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t1 = lua_check_text_or_string(L, 1), + *t2 = lua_check_text_or_string(L, 2); + + if (t1 && t2) { + if (t1->len == t2->len) { + lua_pushboolean(L, memcmp(t1->start, t2->start, t1->len) < 0); + } + else { + lua_pushboolean(L, t1->len < t2->len); + } + } + + return 1; +} + +static gint +lua_text_concat(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t1 = lua_check_text_or_string(L, 1), + *t2 = lua_check_text_or_string(L, 2); + + if (t1 && t2) { + struct rspamd_lua_text *final; + + final = lua_new_text(L, NULL, t1->len + t2->len, TRUE); + memcpy((void *) final->start, t1->start, t1->len); + memcpy((void *) (final->start + t1->len), t2->start, t2->len); + } + + return 1; +} + +static gint +lua_text_wipe(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (t != NULL) { + if (t->flags & RSPAMD_TEXT_FLAG_OWN) { + rspamd_explicit_memzero((guchar *) t->start, t->len); + } + else { + return luaL_error(L, "cannot wipe not owned text"); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_text_base32(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1), *out; + enum rspamd_base32_type btype = RSPAMD_BASE32_DEFAULT; + + if (t != NULL) { + if (lua_type(L, 2) == LUA_TSTRING) { + btype = rspamd_base32_decode_type_from_str(lua_tostring(L, 2)); + + if (btype == RSPAMD_BASE32_INVALID) { + return luaL_error(L, "invalid b32 type: %s", lua_tostring(L, 2)); + } + } + + out = lua_new_text(L, NULL, t->len * 8 / 5 + 2, TRUE); + out->len = rspamd_encode_base32_buf(t->start, t->len, (gchar *) out->start, + out->len, btype); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_base64(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1), *out; + gsize line_len = 0; + gboolean fold = FALSE; + + if (t != NULL) { + if (lua_type(L, 2) == LUA_TNUMBER) { + line_len = lua_tointeger(L, 2); + + if (line_len <= 8) { + return luaL_error(L, "too small line length (at least 8 is required)"); + } + } + + enum rspamd_newlines_type how = RSPAMD_TASK_NEWLINES_CRLF; + + if (lua_type(L, 3) == LUA_TSTRING) { + const gchar *how_str = lua_tostring(L, 3); + + if (g_ascii_strcasecmp(how_str, "cr") == 0) { + how = RSPAMD_TASK_NEWLINES_CR; + } + else if (g_ascii_strcasecmp(how_str, "lf") == 0) { + how = RSPAMD_TASK_NEWLINES_LF; + } + else if (g_ascii_strcasecmp(how_str, "crlf") != 0) { + return luaL_error(L, "invalid newline style: %s", how_str); + } + } + + if (lua_type(L, 4) == LUA_TBOOLEAN) { + fold = lua_toboolean(L, 4); + } + + gsize sz_len; + + out = lua_newuserdata(L, sizeof(*t)); + out->flags = RSPAMD_TEXT_FLAG_OWN; + out->start = rspamd_encode_base64_common(t->start, t->len, + line_len, &sz_len, fold, how); + out->len = sz_len; + rspamd_lua_setclass(L, "rspamd{text}", -1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_hex(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1), *out; + + if (t != NULL) { + + out = lua_new_text(L, NULL, t->len * 2, TRUE); + out->len = rspamd_encode_hex_buf(t->start, t->len, (gchar *) out->start, + out->len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_find(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + gsize patlen, init = 1; + const gchar *pat = luaL_checklstring(L, 2, &patlen); + + if (t != NULL && pat != NULL) { + + if (lua_isnumber(L, 3)) { + init = relative_pos_start(lua_tointeger(L, 3), t->len); + } + + init--; + + if (init > t->len) { + return luaL_error(L, "invalid arguments to find: init too large"); + } + + goffset pos = rspamd_substring_search(t->start + init, + t->len - init, + pat, patlen); + + if (pos == -1) { + lua_pushnil(L); + + return 1; + } + + lua_pushinteger(L, pos + 1); + lua_pushinteger(L, pos + patlen); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 2; +} + +#define BITOP(a, b, op) \ + ((a)[(guint64) (b) / (8u * sizeof *(a))] op(guint64) 1 << ((guint64) (b) % (8u * sizeof *(a)))) + +static gint +lua_text_exclude_chars(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + gssize patlen; + const gchar *pat = lua_tolstring(L, 2, &patlen), *p, *end; + gchar *dest, *d; + guint64 byteset[32 / sizeof(guint64)]; /* Bitset for ascii */ + gboolean copy = TRUE; + guint *plen; + + if (t != NULL && pat && patlen > 0) { + if (lua_isboolean(L, 3)) { + copy = lua_toboolean(L, 3); + } + else if (t->flags & RSPAMD_TEXT_FLAG_OWN) { + copy = FALSE; + } + + if (!copy) { + dest = (gchar *) t->start; + plen = &t->len; + lua_pushvalue(L, 1); /* Push text as a result */ + } + else { + /* We need to copy read only text */ + struct rspamd_lua_text *nt; + + dest = g_malloc(t->len); + nt = lua_newuserdata(L, sizeof(*nt)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + nt->len = t->len; + nt->flags = RSPAMD_TEXT_FLAG_OWN; + memcpy(dest, t->start, t->len); + nt->start = dest; + plen = &nt->len; + } + + /* Fill pattern bitset */ + memset(byteset, 0, sizeof byteset); + + while (patlen > 0) { + if (*pat == '%') { + pat++; + patlen--; + + if (patlen > 0) { + /* + * This stuff assumes little endian, but GUINT64_FROM_LE should + * deal with proper conversion + */ + switch (*pat) { + case '%': + BITOP(byteset, *(guchar *) pat, |=); + break; + case 's': + /* "\r\n\t\f " */ + byteset[0] |= GUINT64_FROM_LE(0x100003600LLU); + break; + case 'n': + /* newlines: "\r\n" */ + byteset[0] |= GUINT64_FROM_LE(0x2400LLU); + break; + case '8': + /* 8 bit characters */ + byteset[2] |= GUINT64_FROM_LE(0xffffffffffffffffLLU); + byteset[3] |= GUINT64_FROM_LE(0xffffffffffffffffLLU); + break; + case 'c': + /* Non printable (control) characters */ + byteset[0] |= GUINT64_FROM_LE(0xffffffffLLU); + /* Del character */ + byteset[1] |= GUINT64_FROM_LE(0x8000000000000000LLU); + break; + } + } + else { + /* Last '%' */ + BITOP(byteset, (guchar) '%', |=); + } + } + else { + BITOP(byteset, *(guchar *) pat, |=); + } + + pat++; + patlen--; + } + for (; patlen > 0 && BITOP(byteset, *(guchar *) pat, |=); pat++, patlen--) + ; + + p = t->start; + end = t->start + t->len; + d = dest; + + while (p < end) { + if (!BITOP(byteset, *(guchar *) p, &)) { + *d++ = *p; + } + + p++; + } + + *(plen) = d - dest; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_oneline(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + const gchar *p, *end; + gchar *dest, *d; + guint64 byteset[32 / sizeof(guint64)]; /* Bitset for ascii */ + gboolean copy = TRUE, seen_8bit = FALSE; + guint *plen; + + if (t != NULL) { + if (lua_isboolean(L, 2)) { + copy = lua_toboolean(L, 2); + } + else if (t->flags & RSPAMD_TEXT_FLAG_OWN) { + copy = FALSE; + } + + if (!copy) { + dest = (gchar *) t->start; + plen = &t->len; + lua_pushvalue(L, 1); /* Push text as a result */ + } + else { + /* We need to copy read only text */ + struct rspamd_lua_text *nt; + + dest = g_malloc(t->len); + nt = lua_newuserdata(L, sizeof(*nt)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + nt->len = t->len; + nt->flags = RSPAMD_TEXT_FLAG_OWN; + memcpy(dest, t->start, t->len); + nt->start = dest; + plen = &nt->len; + } + + /* Fill pattern bitset */ + memset(byteset, 0, sizeof byteset); + /* All spaces */ + byteset[0] |= GUINT64_FROM_LE(0x100003600LLU); + /* Control characters */ + byteset[0] |= GUINT64_FROM_LE(0xffffffffLLU); + /* Del character */ + byteset[1] |= GUINT64_FROM_LE(0x8000000000000000LLU); + /* 8 bit characters */ + byteset[2] |= GUINT64_FROM_LE(0xffffffffffffffffLLU); + byteset[3] |= GUINT64_FROM_LE(0xffffffffffffffffLLU); + + p = t->start; + end = t->start + t->len; + d = dest; + + while (p < end) { + if (!BITOP(byteset, *(guchar *) p, &)) { + *d++ = *p; + } + else { + if ((*(guchar *) p) & 0x80) { + seen_8bit = TRUE; + *d++ = *p; + } + else { + if (*p == ' ') { + if (d != dest) { + *d++ = *p++; + } + + while (p < end && g_ascii_isspace(*p)) { + p++; + } + + continue; /* To avoid p++ */ + } + else if (*p == '\r' || *p == '\n') { + if (d != dest) { + *d++ = ' '; + p++; + } + + while (p < end && g_ascii_isspace(*p)) { + p++; + } + + continue; /* To avoid p++ */ + } + } + } + + p++; + } + + while (d > dest && g_ascii_isspace(*(d - 1))) { + d--; + } + + if (seen_8bit) { + if (rspamd_fast_utf8_validate(dest, d - dest) != 0) { + /* Need to make it valid :( */ + UChar32 uc; + goffset err_offset; + gsize remain = d - dest; + gchar *nd = dest; + + while (remain > 0 && (err_offset = rspamd_fast_utf8_validate(nd, remain)) > 0) { + gint i = 0; + + err_offset--; /* As it returns it 1 indexed */ + nd += err_offset; + remain -= err_offset; + + /* Each invalid character of input requires 3 bytes of output (+2 bytes) */ + while (i < remain) { + gint old_pos = i; + U8_NEXT(nd, i, remain, uc); + + if (uc < 0) { + nd[old_pos] = '?'; + } + else { + break; + } + } + + nd += i; + remain -= i; + } + } + } + + *(plen) = d - dest; + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_lower(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1), *nt; + gboolean is_utf8 = FALSE, is_inplace = FALSE; + + if (t != NULL) { + if (lua_isboolean(L, 2)) { + is_utf8 = lua_toboolean(L, 2); + } + if (lua_isboolean(L, 3)) { + is_inplace = lua_toboolean(L, 3); + } + + if (is_inplace) { + nt = t; + lua_pushvalue(L, 1); + } + else { + nt = lua_new_text(L, t->start, t->len, TRUE); + } + + if (!is_utf8) { + rspamd_str_lc((gchar *) nt->start, nt->len); + } + else { + rspamd_str_lc_utf8((gchar *) nt->start, nt->len); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_text_strtoul(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text(L, 1); + + if (t) { + unsigned long ll; + + if (rspamd_strtoul(t->start, t->len, &ll)) { + lua_pushinteger(L, ll); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/* Used to distinguish lua text metatable */ +static const guint rspamd_lua_text_cookie = 0x2b21ef6fU; + +static gint +lua_load_text(lua_State *L) +{ + lua_newtable(L); + lua_pushstring(L, "cookie"); + lua_pushnumber(L, rspamd_lua_text_cookie); + lua_settable(L, -3); + luaL_register(L, NULL, textlib_f); + + return 1; +} + +void luaopen_text(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{text}", textlib_m); + lua_pushstring(L, "cookie"); + lua_pushnumber(L, rspamd_lua_text_cookie); + lua_settable(L, -3); + lua_pop(L, 1); + + rspamd_lua_add_preload(L, "rspamd_text", lua_load_text); +} diff --git a/src/lua/lua_thread_pool.cxx b/src/lua/lua_thread_pool.cxx new file mode 100644 index 0000000..295f33d --- /dev/null +++ b/src/lua/lua_thread_pool.cxx @@ -0,0 +1,369 @@ +/*- + * Copyright 2018 Mikhail Galanin + * Copyright 2019 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "config.h" + +#include "lua_common.h" +#include "lua_thread_pool.h" + +#include <vector> + +#define msg_debug_lua_threads(...) rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_lua_threads_log_id, "lua_threads", NULL, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(lua_threads) + +static struct thread_entry *thread_entry_new(lua_State *L); +static void thread_entry_free(lua_State *L, struct thread_entry *ent); + +#define CFG_POOL_GET(cfg) (reinterpret_cast<lua_thread_pool *>((cfg)->lua_thread_pool)) + +struct lua_thread_pool { + std::vector<struct thread_entry *> available_items; + lua_State *L; + gint max_items; + struct thread_entry *running_entry; + static const int default_max_items = 100; + + lua_thread_pool(lua_State *L, gint max_items = default_max_items) + : L(L), max_items(max_items) + { + running_entry = nullptr; + available_items.reserve(max_items); + + for (auto i = 0; i < MAX(2, max_items / 10); i++) { + auto *ent = thread_entry_new(L); + available_items.push_back(ent); + } + } + + ~lua_thread_pool() + { + for (auto *ent: available_items) { + thread_entry_free(L, ent); + } + } + + auto get_thread() -> struct thread_entry * + { + struct thread_entry *ent; + + if (!available_items.empty()) { + ent = available_items.back(); + available_items.pop_back(); + } + else { + ent = thread_entry_new(L); + } + + running_entry = ent; + + return ent; + } + + auto return_thread(struct thread_entry *thread_entry, const gchar *loc) -> void + { + /* we can't return a running/yielded thread into the pool */ + g_assert(lua_status(thread_entry->lua_state) == 0); + + if (running_entry == thread_entry) { + running_entry = NULL; + } + + if (available_items.size() <= max_items) { + thread_entry->cd = NULL; + thread_entry->finish_callback = NULL; + thread_entry->error_callback = NULL; + thread_entry->task = NULL; + thread_entry->cfg = NULL; + + msg_debug_lua_threads("%s: returned thread to the threads pool %ud items", + loc, + available_items.size()); + + available_items.push_back(thread_entry); + } + else { + msg_debug_lua_threads("%s: removed thread as thread pool has %ud items", + loc, + available_items.size()); + thread_entry_free(L, thread_entry); + } + } + + auto terminate_thread(struct thread_entry *thread_entry, + const gchar *loc, + bool enforce) -> void + { + struct thread_entry *ent = NULL; + + if (!enforce) { + /* we should only terminate failed threads */ + g_assert(lua_status(thread_entry->lua_state) != 0 && + lua_status(thread_entry->lua_state) != LUA_YIELD); + } + + if (running_entry == thread_entry) { + running_entry = NULL; + } + + msg_debug_lua_threads("%s: terminated thread entry", loc); + thread_entry_free(L, thread_entry); + + if (available_items.size() <= max_items) { + ent = thread_entry_new(L); + available_items.push_back(ent); + } + } + + auto get_running_entry(void) -> struct thread_entry * + { + return running_entry; + }; + + auto set_running_entry(struct thread_entry *ent) -> struct thread_entry * + { + auto *old_entry = running_entry; + running_entry = ent; + return old_entry; + }; +}; + +static struct thread_entry * +thread_entry_new(lua_State *L) +{ + struct thread_entry *ent; + ent = g_new0(struct thread_entry, 1); + ent->lua_state = lua_newthread(L); + ent->thread_index = luaL_ref(L, LUA_REGISTRYINDEX); + + return ent; +} + +static void +thread_entry_free(lua_State *L, struct thread_entry *ent) +{ + luaL_unref(L, LUA_REGISTRYINDEX, ent->thread_index); + g_free(ent); +} + +struct lua_thread_pool * +lua_thread_pool_new(lua_State *L) +{ + auto *pool = new lua_thread_pool(L); + return pool; +} + +void lua_thread_pool_free(struct lua_thread_pool *pool) +{ + delete pool; +} + + +struct thread_entry * +lua_thread_pool_get_for_task(struct rspamd_task *task) +{ + struct thread_entry *ent = CFG_POOL_GET(task->cfg)->get_thread(); + + ent->task = task; + + return ent; +} + +struct thread_entry * +lua_thread_pool_get_for_config(struct rspamd_config *cfg) +{ + struct thread_entry *ent = CFG_POOL_GET(cfg)->get_thread(); + + ent->cfg = cfg; + + return ent; +} + +void lua_thread_pool_return_full(struct lua_thread_pool *pool, + struct thread_entry *thread_entry, const gchar *loc) +{ + pool->return_thread(thread_entry, loc); +} + +void lua_thread_pool_terminate_entry_full(struct lua_thread_pool *pool, + struct thread_entry *thread_entry, const gchar *loc, + bool enforce) +{ + pool->terminate_thread(thread_entry, loc, enforce); +} + +struct thread_entry * +lua_thread_pool_get_running_entry_full(struct lua_thread_pool *pool, + const gchar *loc) +{ + msg_debug_lua_threads("%s: lua_thread_pool_get_running_entry_full", loc); + return pool->get_running_entry(); +} + +void lua_thread_pool_set_running_entry_full(struct lua_thread_pool *pool, + struct thread_entry *thread_entry, + const gchar *loc) +{ + msg_debug_lua_threads("%s: lua_thread_pool_set_running_entry_full", loc); + pool->set_running_entry(thread_entry); +} + +static void +lua_thread_pool_set_running_entry_for_thread(struct thread_entry *thread_entry, + const gchar *loc) +{ + struct lua_thread_pool *pool; + + if (thread_entry->task) { + pool = CFG_POOL_GET(thread_entry->task->cfg); + } + else { + pool = CFG_POOL_GET(thread_entry->cfg); + } + + lua_thread_pool_set_running_entry_full(pool, thread_entry, loc); +} + +void lua_thread_pool_prepare_callback_full(struct lua_thread_pool *pool, + struct lua_callback_state *cbs, + const gchar *loc) +{ + msg_debug_lua_threads("%s: lua_thread_pool_prepare_callback_full", loc); + cbs->thread_pool = pool; + cbs->previous_thread = lua_thread_pool_get_running_entry_full(pool, loc); + cbs->my_thread = pool->get_thread(); + cbs->L = cbs->my_thread->lua_state; +} + +void lua_thread_pool_restore_callback_full(struct lua_callback_state *cbs, + const gchar *loc) +{ + lua_thread_pool_return_full(cbs->thread_pool, cbs->my_thread, loc); + lua_thread_pool_set_running_entry_full(cbs->thread_pool, + cbs->previous_thread, loc); +} + +static gint +lua_do_resume_full(lua_State *L, gint narg, const gchar *loc) +{ +#if LUA_VERSION_NUM >= 504 + int nres; +#endif + msg_debug_lua_threads("%s: lua_do_resume_full", loc); +#if LUA_VERSION_NUM < 502 + return lua_resume(L, narg); +#else +#if LUA_VERSION_NUM >= 504 + return lua_resume(L, NULL, narg, &nres); +#else + return lua_resume(L, NULL, narg); +#endif +#endif +} + +static void +lua_resume_thread_internal_full(struct thread_entry *thread_entry, + gint narg, const gchar *loc) +{ + gint ret; + struct lua_thread_pool *pool; + struct rspamd_task *task; + + msg_debug_lua_threads("%s: lua_resume_thread_internal_full", loc); + ret = lua_do_resume_full(thread_entry->lua_state, narg, loc); + + if (ret != LUA_YIELD) { + /* + LUA_YIELD state should not be handled here. + It should only happen when the thread initiated a asynchronous event and it will be restored as soon + the event is finished + */ + + if (thread_entry->task) { + pool = CFG_POOL_GET(thread_entry->task->cfg); + } + else { + pool = CFG_POOL_GET(thread_entry->cfg); + } + + if (ret == 0) { + if (thread_entry->finish_callback) { + thread_entry->finish_callback(thread_entry, ret); + } + + pool->return_thread(thread_entry, loc); + } + else { + rspamd_lua_traceback(thread_entry->lua_state); + if (thread_entry->error_callback) { + thread_entry->error_callback(thread_entry, ret, + lua_tostring(thread_entry->lua_state, -1)); + } + else if (thread_entry->task) { + task = thread_entry->task; + msg_err_task("lua call failed (%d): %s", ret, + lua_tostring(thread_entry->lua_state, -1)); + } + else { + msg_err("lua call failed (%d): %s", ret, + lua_tostring(thread_entry->lua_state, -1)); + } + + /* + * Maybe there is a way to recover here. + * For now, just remove faulty thread + */ + pool->terminate_thread(thread_entry, loc, false); + } + } +} + +void lua_thread_resume_full(struct thread_entry *thread_entry, gint narg, + const gchar *loc) +{ + /* + * The only state where we can resume from is LUA_YIELD + * Another acceptable status is OK (0) but in that case we should push function on stack + * to start the thread from, which is happening in lua_thread_call(), not in this function. + */ + g_assert(lua_status(thread_entry->lua_state) == LUA_YIELD); + msg_debug_lua_threads("%s: lua_thread_resume_full", loc); + lua_thread_pool_set_running_entry_for_thread(thread_entry, loc); + lua_resume_thread_internal_full(thread_entry, narg, loc); +} + +void lua_thread_call_full(struct thread_entry *thread_entry, + int narg, const gchar *loc) +{ + g_assert(lua_status(thread_entry->lua_state) == 0); /* we can't call running/yielded thread */ + g_assert(thread_entry->task != NULL || thread_entry->cfg != NULL); /* we can't call without pool */ + + lua_resume_thread_internal_full(thread_entry, narg, loc); +} + +gint lua_thread_yield_full(struct thread_entry *thread_entry, + gint nresults, + const gchar *loc) +{ + g_assert(lua_status(thread_entry->lua_state) == 0); + + msg_debug_lua_threads("%s: lua_thread_yield_full", loc); + return lua_yield(thread_entry->lua_state, nresults); +} diff --git a/src/lua/lua_thread_pool.h b/src/lua/lua_thread_pool.h new file mode 100644 index 0000000..b612ac3 --- /dev/null +++ b/src/lua/lua_thread_pool.h @@ -0,0 +1,194 @@ +#ifndef LUA_THREAD_POOL_H_ +#define LUA_THREAD_POOL_H_ + +#include <lua.h> + +#ifdef __cplusplus +extern "C" { +#endif + +struct thread_entry; +struct lua_thread_pool; + +typedef void (*lua_thread_finish_t)(struct thread_entry *thread, int ret); + +typedef void (*lua_thread_error_t)(struct thread_entry *thread, int ret, const char *msg); + +struct thread_entry { + lua_State *lua_state; + gint thread_index; + gpointer cd; + + /* function to handle result of called method, can be NULL */ + lua_thread_finish_t finish_callback; + + /* function to log result, i.e. if you want to modify error logging message or somehow process this state, can be NUL */ + lua_thread_error_t error_callback; + struct rspamd_task *task; + struct rspamd_config *cfg; +}; + +struct lua_callback_state { + lua_State *L; + struct thread_entry *my_thread; + struct thread_entry *previous_thread; + struct lua_thread_pool *thread_pool; +}; + +/** + * Allocates new thread pool on state L. Pre-creates number of lua-threads to use later on + * + * @param L + * @return + */ +struct lua_thread_pool * +lua_thread_pool_new(lua_State *L); + +/** + * Destroys the pool + * @param pool + */ +void lua_thread_pool_free(struct lua_thread_pool *pool); + +/** + * Extracts a thread from the list of available ones. + * It immediately becomes the running one and should be used to run a Lua script/function straight away. + * as soon as the code is finished, it should be either returned into list of available threads by + * calling lua_thread_pool_return() or terminated by calling lua_thread_pool_terminate_entry() + * if the code finished with error. + * + * If the code performed YIELD, the thread is still running and it's live should be controlled by the callee + * + * @param task + * @return + */ +struct thread_entry * +lua_thread_pool_get_for_task(struct rspamd_task *task); + +/** + * The same, but used when task is not available + * + * @param cfg + * @return + */ +struct thread_entry * +lua_thread_pool_get_for_config(struct rspamd_config *cfg); + +/** + * Return thread into the list of available ones. It can't be done with yielded or dead threads. + * + * @param pool + * @param thread_entry + */ +void lua_thread_pool_return_full(struct lua_thread_pool *pool, + struct thread_entry *thread_entry, + const gchar *loc); + +#define lua_thread_pool_return(pool, thread_entry) \ + lua_thread_pool_return_full(pool, thread_entry, G_STRLOC) + +/** + * Currently running thread. Typically needed in yielding point - to fill-up continuation. + * + * @param pool + * @return + */ +struct thread_entry * +lua_thread_pool_get_running_entry_full(struct lua_thread_pool *pool, + const gchar *loc); + +#define lua_thread_pool_get_running_entry(pool) \ + lua_thread_pool_get_running_entry_full(pool, G_STRLOC) + +/** + * Updates currently running thread + * + * @param pool + * @param thread_entry + */ +void lua_thread_pool_set_running_entry_full(struct lua_thread_pool *pool, + struct thread_entry *thread_entry, + const gchar *loc); + +#define lua_thread_pool_set_running_entry(pool, thread_entry) \ + lua_thread_pool_set_running_entry_full(pool, thread_entry, G_STRLOC) + +/** + * Prevents yielded thread to be used for callback execution. lua_thread_pool_restore_callback() should be called afterwards. + * + * @param pool + * @param cbs + */ +void lua_thread_pool_prepare_callback_full(struct lua_thread_pool *pool, + struct lua_callback_state *cbs, const gchar *loc); + +#define lua_thread_pool_prepare_callback(pool, cbs) \ + lua_thread_pool_prepare_callback_full(pool, cbs, G_STRLOC) + +/** + * Restores state after lua_thread_pool_prepare_callback () usage + * + * @param cbs + */ +void lua_thread_pool_restore_callback_full(struct lua_callback_state *cbs, + const gchar *loc); + +#define lua_thread_pool_restore_callback(cbs) \ + lua_thread_pool_restore_callback_full(cbs, G_STRLOC) + +/** + * Acts like lua_call but the tread is able to suspend execution. + * As soon as the call is over, call either thread_entry::finish_callback or thread_entry::error_callback. + * + * @param thread_entry + * @param narg + */ +void lua_thread_call_full(struct thread_entry *thread_entry, + int narg, + const gchar *loc); + +#define lua_thread_call(thread_entry, narg) \ + lua_thread_call_full(thread_entry, narg, G_STRLOC) + +/** + * Yields thread. should be only called in return statement + * @param thread_entry + * @param nresults + * @return + */ +int lua_thread_yield_full(struct thread_entry *thread_entry, int nresults, + const gchar *loc); + +#define lua_thread_yield(thread_entry, narg) \ + lua_thread_yield_full(thread_entry, narg, G_STRLOC) + +/** + * Resumes suspended by lua_yield_thread () thread + * @param task + * @param thread_entry + * @param narg + */ +void lua_thread_resume_full(struct thread_entry *thread_entry, + int narg, + const gchar *loc); + +#define lua_thread_resume(thread_entry, narg) \ + lua_thread_resume_full(thread_entry, narg, G_STRLOC) + +/** + * Terminates thread pool entry and fill the pool with another thread entry if needed + * @param pool + * @param thread_entry + * @param loc + */ +void lua_thread_pool_terminate_entry_full(struct lua_thread_pool *pool, + struct thread_entry *thread_entry, + const gchar *loc, bool enforce); +#define lua_thread_pool_terminate_entry(pool, thread_entry) \ + lua_thread_pool_terminate_entry_full(pool, thread_entry, G_STRLOC, false) + +#ifdef __cplusplus +} +#endif + +#endif /* LUA_THREAD_POOL_H_ */ diff --git a/src/lua/lua_trie.c b/src/lua/lua_trie.c new file mode 100644 index 0000000..3b0b55e --- /dev/null +++ b/src/lua/lua_trie.c @@ -0,0 +1,500 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "message.h" +#include "libutil/multipattern.h" + +/*** + * @module rspamd_trie + * Rspamd trie module provides the data structure suitable for searching of many + * patterns in arbitrary texts (or binary chunks). The algorithmic complexity of + * this algorithm is at most O(n + m + z), where `n` is the length of text, `m` is a length of pattern and `z` is a number of patterns in the text. + * + * Here is a typical example of trie usage: + * @example +local rspamd_trie = require "rspamd_trie" +local patterns = {'aab', 'ab', 'bcd\0ef'} + +local trie = rspamd_trie.create(patterns) + +local function trie_callback(number, pos) + print('Matched pattern number ' .. tostring(number) .. ' at pos: ' .. tostring(pos)) +end + +trie:match('some big text', trie_callback) + */ + +/* Suffix trie */ +LUA_FUNCTION_DEF(trie, create); +LUA_FUNCTION_DEF(trie, has_hyperscan); +LUA_FUNCTION_DEF(trie, match); +LUA_FUNCTION_DEF(trie, search_mime); +LUA_FUNCTION_DEF(trie, search_rawmsg); +LUA_FUNCTION_DEF(trie, search_rawbody); +LUA_FUNCTION_DEF(trie, destroy); + +static const struct luaL_reg trielib_m[] = { + LUA_INTERFACE_DEF(trie, match), + LUA_INTERFACE_DEF(trie, search_mime), + LUA_INTERFACE_DEF(trie, search_rawmsg), + LUA_INTERFACE_DEF(trie, search_rawbody), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_trie_destroy}, + {NULL, NULL}}; +static const struct luaL_reg trielib_f[] = { + LUA_INTERFACE_DEF(trie, create), + LUA_INTERFACE_DEF(trie, has_hyperscan), + {NULL, NULL}}; + +static struct rspamd_multipattern * +lua_check_trie(lua_State *L, gint idx) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{trie}"); + + luaL_argcheck(L, ud != NULL, 1, "'trie' expected"); + return ud ? *((struct rspamd_multipattern **) ud) : NULL; +} + +static gint +lua_trie_destroy(lua_State *L) +{ + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + + if (trie) { + rspamd_multipattern_destroy(trie); + } + + return 0; +} + +/*** + * function trie.has_hyperscan() + * Checks for hyperscan support + * + * @return {bool} true if hyperscan is supported + */ +static gint +lua_trie_has_hyperscan(lua_State *L) +{ + lua_pushboolean(L, rspamd_multipattern_has_hyperscan()); + return 1; +} + +/*** + * function trie.create(patterns, [flags]) + * Creates new trie data structure + * @param {table} array of string patterns + * @return {trie} new trie object + */ +static gint +lua_trie_create(lua_State *L) +{ + struct rspamd_multipattern *trie, **ptrie; + gint npat = 0, flags = RSPAMD_MULTIPATTERN_ICASE | RSPAMD_MULTIPATTERN_GLOB; + GError *err = NULL; + + if (lua_isnumber(L, 2)) { + flags = lua_tointeger(L, 2); + } + + if (!lua_istable(L, 1)) { + return luaL_error(L, "lua trie expects array of patterns for now"); + } + else { + lua_pushvalue(L, 1); + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + if (lua_isstring(L, -1)) { + npat++; + } + + lua_pop(L, 1); + } + + trie = rspamd_multipattern_create_sized(npat, flags); + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + if (lua_isstring(L, -1)) { + const gchar *pat; + gsize patlen; + + pat = lua_tolstring(L, -1, &patlen); + rspamd_multipattern_add_pattern_len(trie, pat, patlen, flags); + } + + lua_pop(L, 1); + } + + lua_pop(L, 1); /* table */ + + if (!rspamd_multipattern_compile(trie, &err)) { + msg_err("cannot compile multipattern: %e", err); + g_error_free(err); + rspamd_multipattern_destroy(trie); + lua_pushnil(L); + } + else { + ptrie = lua_newuserdata(L, sizeof(void *)); + rspamd_lua_setclass(L, "rspamd{trie}", -1); + *ptrie = trie; + } + } + + return 1; +} + +#define PUSH_TRIE_MATCH(L, start, end, report_start) \ + do { \ + if (report_start) { \ + lua_createtable(L, 2, 0); \ + lua_pushinteger(L, (start)); \ + lua_rawseti(L, -2, 1); \ + lua_pushinteger(L, (end)); \ + lua_rawseti(L, -2, 2); \ + } \ + else { \ + lua_pushinteger(L, (end)); \ + } \ + } while (0) + +/* Normal callback type */ +static gint +lua_trie_lua_cb_callback(struct rspamd_multipattern *mp, + guint strnum, + gint match_start, + gint textpos, + const gchar *text, + gsize len, + void *context) +{ + lua_State *L = context; + gint ret; + + gboolean report_start = lua_toboolean(L, -1); + + /* Function */ + lua_pushvalue(L, 3); + lua_pushinteger(L, strnum + 1); + + PUSH_TRIE_MATCH(L, match_start, textpos, report_start); + + if (lua_pcall(L, 2, 1, 0) != 0) { + msg_info("call to trie callback has failed: %s", + lua_tostring(L, -1)); + lua_pop(L, 1); + + return 1; + } + + ret = lua_tonumber(L, -1); + lua_pop(L, 1); + + return ret; +} + +/* Table like callback, expect result table on top of the stack */ +static gint +lua_trie_table_callback(struct rspamd_multipattern *mp, + guint strnum, + gint match_start, + gint textpos, + const gchar *text, + gsize len, + void *context) +{ + lua_State *L = context; + + gint report_start = lua_toboolean(L, -2); + /* Set table, indexed by pattern number */ + lua_rawgeti(L, -1, strnum + 1); + + if (lua_istable(L, -1)) { + /* Already have table, add offset */ + gsize last = rspamd_lua_table_size(L, -1); + PUSH_TRIE_MATCH(L, match_start, textpos, report_start); + lua_rawseti(L, -2, last + 1); + /* Remove table from the stack */ + lua_pop(L, 1); + } + else { + /* Pop none */ + lua_pop(L, 1); + /* New table */ + lua_newtable(L); + PUSH_TRIE_MATCH(L, match_start, textpos, report_start); + lua_rawseti(L, -2, 1); + lua_rawseti(L, -2, strnum + 1); + } + + return 0; +} + +/* + * We assume that callback argument is at pos 3 and icase is in position 4 + */ +static gint +lua_trie_search_str(lua_State *L, struct rspamd_multipattern *trie, + const gchar *str, gsize len, rspamd_multipattern_cb_t cb) +{ + gint ret; + guint nfound = 0; + + if ((ret = rspamd_multipattern_lookup(trie, str, len, + cb, L, &nfound)) == 0) { + return nfound; + } + + return ret; +} + +/*** + * @method trie:match(input, [cb][, report_start]) + * Search for patterns in `input` invoking `cb` optionally ignoring case + * @param {table or string} input one or several (if `input` is an array) strings of input text + * @param {function} cb callback called on each pattern match in form `function (idx, pos)` where `idx` is a numeric index of pattern (starting from 1) and `pos` is a numeric offset where the pattern ends + * @param {boolean} report_start report both start and end offset when matching patterns + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however). If `cb` is not defined then it returns a table of match positions indexed by pattern number + */ +static gint +lua_trie_match(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + const gchar *text; + gsize len; + gboolean found = FALSE, report_start = FALSE; + struct rspamd_lua_text *t; + rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback; + + gint old_top = lua_gettop(L); + + if (trie) { + if (lua_type(L, 3) != LUA_TFUNCTION) { + if (lua_isboolean(L, 3)) { + report_start = lua_toboolean(L, 3); + } + + lua_pushboolean(L, report_start); + /* Table like match */ + lua_newtable(L); + cb = lua_trie_table_callback; + } + else { + if (lua_isboolean(L, 4)) { + report_start = lua_toboolean(L, 4); + } + lua_pushboolean(L, report_start); + } + + if (lua_type(L, 2) == LUA_TTABLE) { + lua_pushvalue(L, 2); + lua_pushnil(L); + + while (lua_next(L, -2) != 0) { + if (lua_isstring(L, -1)) { + text = lua_tolstring(L, -1, &len); + + if (lua_trie_search_str(L, trie, text, len, cb)) { + found = TRUE; + } + } + else if (lua_isuserdata(L, -1)) { + t = lua_check_text(L, -1); + + if (t) { + if (lua_trie_search_str(L, trie, t->start, t->len, cb)) { + found = TRUE; + } + } + } + lua_pop(L, 1); + } + } + else if (lua_type(L, 2) == LUA_TSTRING) { + text = lua_tolstring(L, 2, &len); + + if (lua_trie_search_str(L, trie, text, len, cb)) { + found = TRUE; + } + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + t = lua_check_text(L, 2); + + if (t && lua_trie_search_str(L, trie, t->start, t->len, cb)) { + found = TRUE; + } + } + } + + if (lua_type(L, 3) == LUA_TFUNCTION) { + lua_settop(L, old_top); + lua_pushboolean(L, found); + } + else { + lua_remove(L, -2); + } + + return 1; +} + +/*** + * @method trie:search_mime(task, cb) + * This is a helper mehthod to search pattern within text parts of a message in rspamd task + * @param {task} task object + * @param {function} cb callback called on each pattern match @see trie:match + * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only) + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however) + */ +static gint +lua_trie_search_mime(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + struct rspamd_task *task = lua_check_task(L, 2); + struct rspamd_mime_text_part *part; + const gchar *text; + gsize len, i; + gboolean found = FALSE; + rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback; + + if (trie && task) { + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) + { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_content.len > 0) { + text = part->utf_content.begin; + len = part->utf_content.len; + + if (lua_trie_search_str(L, trie, text, len, cb) != 0) { + found = TRUE; + } + } + } + } + + lua_pushboolean(L, found); + return 1; +} + +/*** + * @method trie:search_rawmsg(task, cb[, caseless]) + * This is a helper mehthod to search pattern within the whole undecoded content of rspamd task + * @param {task} task object + * @param {function} cb callback called on each pattern match @see trie:match + * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only) + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however) + */ +static gint +lua_trie_search_rawmsg(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + struct rspamd_task *task = lua_check_task(L, 2); + const gchar *text; + gsize len; + gboolean found = FALSE; + + if (trie && task) { + text = task->msg.begin; + len = task->msg.len; + + if (lua_trie_search_str(L, trie, text, len, lua_trie_lua_cb_callback) != 0) { + found = TRUE; + } + } + + lua_pushboolean(L, found); + return 1; +} + +/*** + * @method trie:search_rawbody(task, cb[, caseless]) + * This is a helper mehthod to search pattern within the whole undecoded content of task's body (not including headers) + * @param {task} task object + * @param {function} cb callback called on each pattern match @see trie:match + * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only) + * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however) + */ +static gint +lua_trie_search_rawbody(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_multipattern *trie = lua_check_trie(L, 1); + struct rspamd_task *task = lua_check_task(L, 2); + const gchar *text; + gsize len; + gboolean found = FALSE; + + if (trie && task) { + if (MESSAGE_FIELD(task, raw_headers_content).len > 0) { + text = task->msg.begin + MESSAGE_FIELD(task, raw_headers_content).len; + len = task->msg.len - MESSAGE_FIELD(task, raw_headers_content).len; + } + else { + /* Treat as raw message */ + text = task->msg.begin; + len = task->msg.len; + } + + if (lua_trie_search_str(L, trie, text, len, lua_trie_lua_cb_callback) != 0) { + found = TRUE; + } + } + + lua_pushboolean(L, found); + return 1; +} + +static gint +lua_load_trie(lua_State *L) +{ + lua_newtable(L); + + /* Flags */ + lua_pushstring(L, "flags"); + lua_newtable(L); + + lua_pushinteger(L, RSPAMD_MULTIPATTERN_GLOB); + lua_setfield(L, -2, "glob"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_RE); + lua_setfield(L, -2, "re"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_ICASE); + lua_setfield(L, -2, "icase"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_UTF8); + lua_setfield(L, -2, "utf8"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_TLD); + lua_setfield(L, -2, "tld"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_DOTALL); + lua_setfield(L, -2, "dot_all"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_SINGLEMATCH); + lua_setfield(L, -2, "single_match"); + lua_pushinteger(L, RSPAMD_MULTIPATTERN_NO_START); + lua_setfield(L, -2, "no_start"); + lua_settable(L, -3); + + /* Main content */ + luaL_register(L, NULL, trielib_f); + + return 1; +} + +void luaopen_trie(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{trie}", trielib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_trie", lua_load_trie); +} diff --git a/src/lua/lua_udp.c b/src/lua/lua_udp.c new file mode 100644 index 0000000..c79e35a --- /dev/null +++ b/src/lua/lua_udp.c @@ -0,0 +1,594 @@ +/*- + * Copyright 2019 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_thread_pool.h" +#include "utlist.h" +#include "unix-std.h" +#include <math.h> +#include <src/libutil/libev_helper.h> + +static const gchar *M = "rspamd lua udp"; + +/*** + * @module rspamd_udp + * Rspamd UDP module is available from the version 1.9.0 and represents a generic + * UDP asynchronous client available from the LUA code. + * This module is quite simple: it can either send requests to some address or + * it can send requests and wait for replies, potentially handling retransmits. + * @example +local logger = require "rspamd_logger" +local udp = require "rspamd_udp" + +rspamd_config.SYM = function(task) + udp.sento{ + host = addr, -- must be ip address object (e.g. received by upstream module) + port = 500, + data = {'str1', 'str2'}, -- can be table, string or rspamd_text + timeout = 0.5, -- default = 1s + task = task, -- if has task + session = session, -- optional + ev_base = ev_base, -- if no task available + -- You can include callback and then Rspamd will try to read replies + callback = function(success, data) + -- success is bool, data is either data or an error (string) + end, + retransmits = 0, -- Or more if retransmitting is necessary + } +end + */ + +static const double default_udp_timeout = 1.0; + +LUA_FUNCTION_DEF(udp, sendto); + +static const struct luaL_reg udp_libf[] = { + LUA_INTERFACE_DEF(udp, sendto), + {NULL, NULL}}; + +struct lua_udp_cbdata { + struct ev_loop *event_loop; + struct rspamd_io_ev ev; + struct rspamd_async_event *async_ev; + struct rspamd_task *task; + rspamd_mempool_t *pool; + rspamd_inet_addr_t *addr; + struct rspamd_symcache_dynamic_item *item; + struct rspamd_async_session *s; + struct iovec *iov; + lua_State *L; + guint retransmits; + guint iovlen; + gint sock; + gint cbref; + gboolean sent; +}; + +#define msg_debug_udp(...) rspamd_conditional_debug_fast(NULL, cbd->addr, \ + rspamd_lua_udp_log_id, "lua_udp", cbd->pool->tag.uid, \ + G_STRFUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(lua_udp) + +static inline void +lua_fill_iov(lua_State *L, rspamd_mempool_t *pool, + struct iovec *iov, gint pos) +{ + if (lua_type(L, pos) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, pos); + + if (t) { + iov->iov_base = rspamd_mempool_alloc(pool, t->len); + iov->iov_len = t->len; + memcpy(iov->iov_base, t->start, t->len); + } + } + else { + const gchar *s; + gsize len; + + s = lua_tolstring(L, pos, &len); + + iov->iov_base = rspamd_mempool_alloc(pool, len); + iov->iov_len = len; + memcpy(iov->iov_base, s, len); + } +} + +static void +lua_udp_cbd_fin(gpointer p) +{ + struct lua_udp_cbdata *cbd = (struct lua_udp_cbdata *) p; + + if (cbd->sock != -1) { + rspamd_ev_watcher_stop(cbd->event_loop, &cbd->ev); + close(cbd->sock); + } + + if (cbd->addr) { + rspamd_inet_address_free(cbd->addr); + } + + if (cbd->cbref) { + luaL_unref(cbd->L, LUA_REGISTRYINDEX, cbd->cbref); + } +} + +static void +lua_udp_maybe_free(struct lua_udp_cbdata *cbd) +{ + if (cbd->item) { + rspamd_symcache_item_async_dec_check(cbd->task, cbd->item, M); + cbd->item = NULL; + } + + if (cbd->async_ev) { + rspamd_session_remove_event(cbd->s, lua_udp_cbd_fin, cbd); + } + else { + lua_udp_cbd_fin(cbd); + } +} + + +enum rspamd_udp_send_result { + RSPAMD_SENT_OK, + RSPAMD_SENT_RETRY, + RSPAMD_SENT_FAILURE +}; + +static enum rspamd_udp_send_result +lua_try_send_request(struct lua_udp_cbdata *cbd) +{ + struct msghdr msg; + gint r; + + memset(&msg, 0, sizeof(msg)); + msg.msg_iov = cbd->iov; + msg.msg_iovlen = cbd->iovlen; + msg.msg_name = rspamd_inet_address_get_sa(cbd->addr, &msg.msg_namelen); + + r = sendmsg(cbd->sock, &msg, 0); + + if (r != -1) { + return RSPAMD_SENT_OK; + } + + if (errno == EAGAIN || errno == EINTR) { + return RSPAMD_SENT_RETRY; + } + + return RSPAMD_SENT_FAILURE; +} + +static void +lua_udp_maybe_push_error(struct lua_udp_cbdata *cbd, const gchar *err) +{ + if (cbd->cbref != -1) { + gint top; + lua_State *L = cbd->L; + + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref); + + /* Error message */ + lua_pushboolean(L, false); + lua_pushstring(L, err); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 2, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + } + + lua_udp_maybe_free(cbd); +} + +static void +lua_udp_push_data(struct lua_udp_cbdata *cbd, const gchar *data, + gssize len) +{ + if (cbd->cbref != -1) { + gint top; + lua_State *L = cbd->L; + + top = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref); + + /* Error message */ + lua_pushboolean(L, true); + lua_pushlstring(L, data, len); + + if (cbd->item) { + rspamd_symcache_set_cur_item(cbd->task, cbd->item); + } + + if (lua_pcall(L, 2, 0, 0) != 0) { + msg_info("callback call failed: %s", lua_tostring(L, -1)); + } + + lua_settop(L, top); + } + + lua_udp_maybe_free(cbd); +} + +static gboolean +lua_udp_maybe_register_event(struct lua_udp_cbdata *cbd) +{ + if (cbd->s && !cbd->async_ev) { + if (cbd->item) { + cbd->async_ev = rspamd_session_add_event_full(cbd->s, lua_udp_cbd_fin, + cbd, M, + rspamd_symcache_dyn_item_name(cbd->task, cbd->item)); + } + else { + cbd->async_ev = rspamd_session_add_event(cbd->s, lua_udp_cbd_fin, + cbd, M); + } + + if (!cbd->async_ev) { + return FALSE; + } + } + + if (cbd->task && !cbd->item) { + cbd->item = rspamd_symcache_get_cur_item(cbd->task); + rspamd_symcache_item_async_inc(cbd->task, cbd->item, M); + } + + return TRUE; +} + +static void +lua_udp_io_handler(gint fd, short what, gpointer p) +{ + struct lua_udp_cbdata *cbd = (struct lua_udp_cbdata *) p; + gssize r; + + if (what == EV_TIMEOUT) { + if (cbd->sent && cbd->retransmits > 0) { + r = lua_try_send_request(cbd); + + if (r == RSPAMD_SENT_OK) { + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, EV_READ); + lua_udp_maybe_register_event(cbd); + cbd->retransmits--; + } + else if (r == RSPAMD_SENT_FAILURE) { + lua_udp_maybe_push_error(cbd, "write error"); + } + else { + cbd->retransmits--; + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, EV_WRITE); + } + } + else { + if (!cbd->sent) { + lua_udp_maybe_push_error(cbd, "sent timeout"); + } + else { + lua_udp_maybe_push_error(cbd, "read timeout"); + } + } + } + else if (what == EV_WRITE) { + r = lua_try_send_request(cbd); + + if (r == RSPAMD_SENT_OK) { + if (cbd->cbref != -1) { + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, EV_READ); + cbd->sent = TRUE; + } + else { + lua_udp_maybe_free(cbd); + } + } + else if (r == RSPAMD_SENT_FAILURE) { + lua_udp_maybe_push_error(cbd, "write error"); + } + else { + cbd->retransmits--; + rspamd_ev_watcher_reschedule(cbd->event_loop, &cbd->ev, EV_WRITE); + } + } + else if (what == EV_READ) { + guchar udpbuf[4096]; + socklen_t slen; + struct sockaddr *sa; + + sa = rspamd_inet_address_get_sa(cbd->addr, &slen); + + r = recvfrom(cbd->sock, udpbuf, sizeof(udpbuf), 0, sa, &slen); + + if (r == -1) { + lua_udp_maybe_push_error(cbd, strerror(errno)); + } + else { + lua_udp_push_data(cbd, udpbuf, r); + } + } +} + +/*** + * @function rspamd_udp.sendto({params}) + * This function simply sends data to an external UDP service + * + * - `task`: rspamd task objects (implies `pool`, `session` and `ev_base` arguments) + * - `ev_base`: event base (if no task specified) + * - `session`: events session (no task, optional) + * - `pool`: memory pool (if no task specified) + * - `host`: IP or name of the peer (required) + * - `port`: remote port to use (if `host` has no port part this is required) + * - `data`: a table of strings or `rspamd_text` objects that contains data pieces + * - `retransmits`: number of retransmits if needed + * - `callback`: optional callback if reply should be read + * @return {boolean} true if request has been sent (additional string if it has not) + */ +static gint +lua_udp_sendto(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *host; + guint port; + struct ev_loop *ev_base = NULL; + struct lua_udp_cbdata *cbd; + struct rspamd_async_session *session = NULL; + struct rspamd_task *task = NULL; + rspamd_inet_addr_t *addr; + rspamd_mempool_t *pool = NULL; + gdouble timeout = default_udp_timeout; + + if (lua_type(L, 1) == LUA_TTABLE) { + lua_pushstring(L, "port"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TNUMBER) { + port = lua_tointeger(L, -1); + } + else { + /* We assume that it is a unix socket */ + port = 0; + } + + lua_pop(L, 1); + + lua_pushstring(L, "host"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TSTRING) { + host = luaL_checkstring(L, -1); + + if (rspamd_parse_inet_address(&addr, + host, strlen(host), RSPAMD_INET_ADDRESS_PARSE_DEFAULT)) { + if (port != 0) { + rspamd_inet_address_set_port(addr, port); + } + } + else { + lua_pop(L, 1); + return luaL_error(L, "invalid host: %s", host); + } + } + else if (lua_type(L, -1) == LUA_TUSERDATA) { + struct rspamd_lua_ip *lip; + + lip = lua_check_ip(L, -1); + + if (lip == NULL || lip->addr == NULL) { + lua_pop(L, 1); + return luaL_error(L, "invalid host class"); + } + + addr = rspamd_inet_address_copy(lip->addr, NULL); + + if (port != 0) { + rspamd_inet_address_set_port(addr, port); + } + } + else { + lua_pop(L, 1); + return luaL_error(L, "invalid host"); + } + + lua_pop(L, 1); + + lua_pushstring(L, "task"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TUSERDATA) { + task = lua_check_task(L, -1); + ev_base = task->event_loop; + session = task->s; + pool = task->task_pool; + } + lua_pop(L, 1); + + if (task == NULL) { + lua_pushstring(L, "ev_base"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{ev_base}")) { + ev_base = *(struct ev_loop **) lua_touserdata(L, -1); + } + else { + ev_base = NULL; + } + lua_pop(L, 1); + + lua_pushstring(L, "session"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{session}")) { + session = *(struct rspamd_async_session **) lua_touserdata(L, -1); + } + else { + session = NULL; + } + lua_pop(L, 1); + + lua_pushstring(L, "pool"); + lua_gettable(L, -2); + if (rspamd_lua_check_udata_maybe(L, -1, "rspamd{mempool}")) { + pool = *(rspamd_mempool_t **) lua_touserdata(L, -1); + } + else { + pool = NULL; + } + lua_pop(L, 1); + } + + lua_pushstring(L, "timeout"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + timeout = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + if (!ev_base || !pool) { + rspamd_inet_address_free(addr); + + return luaL_error(L, "invalid arguments"); + } + + + cbd = rspamd_mempool_alloc0(pool, sizeof(*cbd)); + cbd->event_loop = ev_base; + cbd->pool = pool; + cbd->s = session; + cbd->addr = addr; + cbd->sock = rspamd_socket_create(rspamd_inet_address_get_af(addr), + SOCK_DGRAM, 0, TRUE); + cbd->cbref = -1; + cbd->ev.timeout = timeout; + + if (cbd->sock == -1) { + rspamd_inet_address_free(addr); + + return luaL_error(L, "cannot open socket: %s", strerror(errno)); + } + + cbd->L = L; + + gsize data_len; + + lua_pushstring(L, "data"); + lua_gettable(L, -2); + + if (lua_type(L, -1) == LUA_TTABLE) { + data_len = rspamd_lua_table_size(L, -1); + cbd->iov = rspamd_mempool_alloc(pool, + sizeof(*cbd->iov) * data_len); + + for (int i = 0; i < data_len; i++) { + lua_rawgeti(L, -1, i + 1); + lua_fill_iov(L, pool, &cbd->iov[i], -1); + lua_pop(L, 1); + } + + cbd->iovlen = data_len; + } + else { + cbd->iov = rspamd_mempool_alloc(pool, sizeof(*cbd->iov)); + cbd->iovlen = 1; + lua_fill_iov(L, pool, cbd->iov, -1); + } + + lua_pop(L, 1); + + lua_pushstring(L, "callback"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TFUNCTION) { + cbd->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + else { + lua_pop(L, 1); + } + + lua_pushstring(L, "retransmits"); + lua_gettable(L, -2); + if (lua_type(L, -1) == LUA_TNUMBER) { + cbd->retransmits = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + enum rspamd_udp_send_result r; + + r = lua_try_send_request(cbd); + if (r == RSPAMD_SENT_OK) { + if (cbd->cbref == -1) { + lua_udp_maybe_free(cbd); + } + else { + if (!lua_udp_maybe_register_event(cbd)) { + lua_pushboolean(L, false); + lua_pushstring(L, "session error"); + lua_udp_maybe_free(cbd); + + return 2; + } + + rspamd_ev_watcher_init(&cbd->ev, cbd->sock, EV_READ, + lua_udp_io_handler, cbd); + rspamd_ev_watcher_start(cbd->event_loop, &cbd->ev, timeout); + cbd->sent = TRUE; + } + + lua_pushboolean(L, true); + } + else if (r == RSPAMD_SENT_FAILURE) { + lua_pushboolean(L, false); + lua_pushstring(L, strerror(errno)); + lua_udp_maybe_free(cbd); + + return 2; + } + else { + rspamd_ev_watcher_init(&cbd->ev, cbd->sock, EV_WRITE, + lua_udp_io_handler, cbd); + rspamd_ev_watcher_start(cbd->event_loop, &cbd->ev, timeout); + + if (!lua_udp_maybe_register_event(cbd)) { + lua_pushboolean(L, false); + lua_pushstring(L, "session error"); + lua_udp_maybe_free(cbd); + + return 2; + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_load_udp(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, udp_libf); + + return 1; +} + +void luaopen_udp(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_udp", lua_load_udp); +} diff --git a/src/lua/lua_upstream.c b/src/lua/lua_upstream.c new file mode 100644 index 0000000..583ee6a --- /dev/null +++ b/src/lua/lua_upstream.c @@ -0,0 +1,672 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "lua_common.h" + + +/*** + * @module rspamd_upstream_list + * This module implements upstreams manipulation from LUA API. This functionality + * can be used for load balancing using different strategies including: + * + * - round-robin: balance upstreams one by one selecting accordingly to their weight + * - hash: use stable hashing algorithm to distribute values according to some static strings + * - master-slave: always prefer upstream with higher priority unless it is not available + * + * Here is an example of upstreams manipulations: + * @example +local rspamd_logger = require "rspamd_logger" +local rspamd_redis = require "rspamd_redis" +local upstream_list = require "rspamd_upstream_list" +local upstreams = upstream_list.create('127.0.0.1,10.0.0.1,10.0.0.2', 6379) + +local function sym_callback(task) + local upstream = upstreams:get_upstream_by_hash(task:get_from()[1]['domain']) + + local function cb(task, err, data) + if err then + upstream:fail() + else + upstream:ok() + end + end + + local addr = upstream:get_addr() + rspamd_redis.make_request(task, addr, cb, + 'PUSH', {'key', 'value'}) +end + */ +/* Upstream list functions */ +LUA_FUNCTION_DEF(upstream_list, create); +LUA_FUNCTION_DEF(upstream_list, destroy); +LUA_FUNCTION_DEF(upstream_list, all_upstreams); +LUA_FUNCTION_DEF(upstream_list, get_upstream_by_hash); +LUA_FUNCTION_DEF(upstream_list, get_upstream_round_robin); +LUA_FUNCTION_DEF(upstream_list, get_upstream_master_slave); +LUA_FUNCTION_DEF(upstream_list, add_watcher); + +static const struct luaL_reg upstream_list_m[] = { + + LUA_INTERFACE_DEF(upstream_list, get_upstream_by_hash), + LUA_INTERFACE_DEF(upstream_list, get_upstream_round_robin), + LUA_INTERFACE_DEF(upstream_list, get_upstream_master_slave), + LUA_INTERFACE_DEF(upstream_list, all_upstreams), + LUA_INTERFACE_DEF(upstream_list, add_watcher), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_upstream_list_destroy}, + {NULL, NULL}}; +static const struct luaL_reg upstream_list_f[] = { + LUA_INTERFACE_DEF(upstream_list, create), + {NULL, NULL}}; + +/* Upstream functions */ +LUA_FUNCTION_DEF(upstream, ok); +LUA_FUNCTION_DEF(upstream, fail); +LUA_FUNCTION_DEF(upstream, get_addr); +LUA_FUNCTION_DEF(upstream, get_name); +LUA_FUNCTION_DEF(upstream, get_port); +LUA_FUNCTION_DEF(upstream, destroy); + +static const struct luaL_reg upstream_m[] = { + LUA_INTERFACE_DEF(upstream, ok), + LUA_INTERFACE_DEF(upstream, fail), + LUA_INTERFACE_DEF(upstream, get_addr), + LUA_INTERFACE_DEF(upstream, get_port), + LUA_INTERFACE_DEF(upstream, get_name), + {"__tostring", rspamd_lua_class_tostring}, + {"__gc", lua_upstream_destroy}, + {NULL, NULL}}; + +/* Upstream class */ + +struct rspamd_lua_upstream * +lua_check_upstream(lua_State *L, int pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{upstream}"); + + luaL_argcheck(L, ud != NULL, 1, "'upstream' expected"); + return ud ? (struct rspamd_lua_upstream *) ud : NULL; +} + +/*** + * @method upstream:get_addr() + * Get ip of upstream + * @return {ip} ip address object + */ +static gint +lua_upstream_get_addr(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); + + if (up) { + rspamd_lua_ip_push(L, rspamd_upstream_addr_next(up->up)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method upstream:get_name() + * Get name of upstream + * @return {string} name of the upstream + */ +static gint +lua_upstream_get_name(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); + + if (up) { + lua_pushstring(L, rspamd_upstream_name(up->up)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method upstream:get_port() + * Get port of upstream + * @return {int} port of the upstream + */ +static gint +lua_upstream_get_port(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); + + if (up) { + lua_pushinteger(L, rspamd_upstream_port(up->up)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method upstream:fail() + * Indicate upstream failure. After certain amount of failures during specified time frame, an upstream is marked as down and does not participate in rotations. + */ +static gint +lua_upstream_fail(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); + gboolean fail_addr = FALSE; + const gchar *reason = "unknown"; + + if (up) { + + if (lua_isboolean(L, 2)) { + fail_addr = lua_toboolean(L, 2); + + if (lua_isstring(L, 3)) { + reason = lua_tostring(L, 3); + } + } + else if (lua_isstring(L, 2)) { + reason = lua_tostring(L, 2); + } + + rspamd_upstream_fail(up->up, fail_addr, reason); + } + + return 0; +} + +/*** + * @method upstream:ok() + * Indicates upstream success. Resets errors count for an upstream. + */ +static gint +lua_upstream_ok(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); + + if (up) { + rspamd_upstream_ok(up->up); + } + + return 0; +} + +static gint +lua_upstream_destroy(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_upstream *up = lua_check_upstream(L, 1); + + if (up) { + /* Remove reference to the parent */ + luaL_unref(L, LUA_REGISTRYINDEX, up->upref); + /* Upstream belongs to the upstream list, so no free here */ + } + + return 0; +} + +/* Upstream list class */ + +static struct upstream_list * +lua_check_upstream_list(lua_State *L) +{ + void *ud = rspamd_lua_check_udata(L, 1, "rspamd{upstream_list}"); + + luaL_argcheck(L, ud != NULL, 1, "'upstream_list' expected"); + return ud ? *((struct upstream_list **) ud) : NULL; +} + +static struct rspamd_lua_upstream * +lua_push_upstream(lua_State *L, gint up_idx, struct upstream *up) +{ + struct rspamd_lua_upstream *lua_ups; + + if (up_idx < 0) { + up_idx = lua_gettop(L) + up_idx + 1; + } + + lua_ups = lua_newuserdata(L, sizeof(*lua_ups)); + lua_ups->up = up; + rspamd_lua_setclass(L, "rspamd{upstream}", -1); + /* Store parent in the upstream to prevent gc */ + lua_pushvalue(L, up_idx); + lua_ups->upref = luaL_ref(L, LUA_REGISTRYINDEX); + + return lua_ups; +} + +/*** + * @function upstream_list.create(cfg, def, [default_port]) + * Create new upstream list from its string definition in form `<upstream>,<upstream>;<upstream>` + * @param {rspamd_config} cfg configuration reference + * @param {string} def upstream list definition + * @param {number} default_port default port for upstreams + * @return {upstream_list} upstream list structure + */ +static gint +lua_upstream_list_create(lua_State *L) +{ + LUA_TRACE_POINT; + struct upstream_list *new = NULL, **pnew; + struct rspamd_config *cfg = NULL; + const gchar *def; + guint default_port = 0; + gint top; + + + if (lua_type(L, 1) == LUA_TUSERDATA) { + cfg = lua_check_config(L, 1); + top = 2; + } + else { + top = 1; + } + + if (lua_gettop(L) >= top + 1) { + default_port = luaL_checknumber(L, top + 1); + } + + if (lua_type(L, top) == LUA_TSTRING) { + def = luaL_checkstring(L, top); + + new = rspamd_upstreams_create(cfg ? cfg->ups_ctx : NULL); + + if (rspamd_upstreams_parse_line(new, def, default_port, NULL)) { + pnew = lua_newuserdata(L, sizeof(struct upstream_list *)); + rspamd_lua_setclass(L, "rspamd{upstream_list}", -1); + *pnew = new; + } + else { + rspamd_upstreams_destroy(new); + lua_pushnil(L); + } + } + else if (lua_type(L, top) == LUA_TTABLE) { + new = rspamd_upstreams_create(cfg ? cfg->ups_ctx : NULL); + pnew = lua_newuserdata(L, sizeof(struct upstream_list *)); + rspamd_lua_setclass(L, "rspamd{upstream_list}", -1); + *pnew = new; + + lua_pushvalue(L, top); + + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + def = lua_tostring(L, -1); + + if (!def || !rspamd_upstreams_parse_line(new, def, default_port, NULL)) { + msg_warn("cannot parse upstream %s", def); + } + } + + lua_pop(L, 1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/** + * Destroy a single upstream list object + * @param L + * @return + */ +static gint +lua_upstream_list_destroy(lua_State *L) +{ + LUA_TRACE_POINT; + struct upstream_list *upl = lua_check_upstream_list(L); + + rspamd_upstreams_destroy(upl); + + return 0; +} + +/*** + * @method upstream_list:get_upstream_by_hash(key) + * Get upstream by hash from key + * @param {string} key a string used as input for stable hash algorithm + * @return {upstream} upstream from a list corresponding to the given key + */ +static gint +lua_upstream_list_get_upstream_by_hash(lua_State *L) +{ + LUA_TRACE_POINT; + struct upstream_list *upl; + struct upstream *selected; + const gchar *key; + gsize keyl; + + upl = lua_check_upstream_list(L); + if (upl) { + key = luaL_checklstring(L, 2, &keyl); + if (key) { + selected = rspamd_upstream_get(upl, RSPAMD_UPSTREAM_HASHED, key, + (guint) keyl); + + if (selected) { + lua_push_upstream(L, 1, selected); + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method upstream_list:get_upstream_round_robin() + * Get upstream round robin (by current weight) + * @return {upstream} upstream from a list in round-robin matter + */ +static gint +lua_upstream_list_get_upstream_round_robin(lua_State *L) +{ + LUA_TRACE_POINT; + struct upstream_list *upl; + struct upstream *selected; + + upl = lua_check_upstream_list(L); + if (upl) { + + selected = rspamd_upstream_get(upl, RSPAMD_UPSTREAM_ROUND_ROBIN, NULL, 0); + if (selected) { + lua_push_upstream(L, 1, selected); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +/*** + * @method upstream_list:get_upstream_master_slave() + * Get upstream master slave order (by static priority) + * @return {upstream} upstream from a list in master-slave order + */ +static gint +lua_upstream_list_get_upstream_master_slave(lua_State *L) +{ + LUA_TRACE_POINT; + struct upstream_list *upl; + struct upstream *selected; + + upl = lua_check_upstream_list(L); + if (upl) { + + selected = rspamd_upstream_get(upl, RSPAMD_UPSTREAM_MASTER_SLAVE, + NULL, + 0); + if (selected) { + lua_push_upstream(L, 1, selected); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct upstream_foreach_cbdata { + lua_State *L; + gint ups_pos; +}; + +static void lua_upstream_inserter(struct upstream *up, guint idx, void *ud) +{ + struct upstream_foreach_cbdata *cbd = (struct upstream_foreach_cbdata *) ud; + + lua_push_upstream(cbd->L, cbd->ups_pos, up); + lua_rawseti(cbd->L, -2, idx + 1); +} +/*** + * @method upstream_list:all_upstreams() + * Returns all upstreams for this list + * @return {table|upstream} all upstreams defined + */ +static gint +lua_upstream_list_all_upstreams(lua_State *L) +{ + LUA_TRACE_POINT; + struct upstream_list *upl; + struct upstream_foreach_cbdata cbd; + + upl = lua_check_upstream_list(L); + if (upl) { + cbd.L = L; + cbd.ups_pos = 1; + + lua_createtable(L, rspamd_upstreams_count(upl), 0); + rspamd_upstreams_foreach(upl, lua_upstream_inserter, &cbd); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static inline enum rspamd_upstreams_watch_event +lua_str_to_upstream_flag(const gchar *str) +{ + enum rspamd_upstreams_watch_event fl = 0; + + if (strcmp(str, "success") == 0) { + fl = RSPAMD_UPSTREAM_WATCH_SUCCESS; + } + else if (strcmp(str, "failure") == 0) { + fl = RSPAMD_UPSTREAM_WATCH_FAILURE; + } + else if (strcmp(str, "online") == 0) { + fl = RSPAMD_UPSTREAM_WATCH_ONLINE; + } + else if (strcmp(str, "offline") == 0) { + fl = RSPAMD_UPSTREAM_WATCH_OFFLINE; + } + else { + msg_err("invalid flag: %s", str); + } + + return fl; +} + +static inline const gchar * +lua_upstream_flag_to_str(enum rspamd_upstreams_watch_event fl) +{ + const gchar *res = "unknown"; + + /* Works with single flags, not combinations */ + if (fl & RSPAMD_UPSTREAM_WATCH_SUCCESS) { + res = "success"; + } + else if (fl & RSPAMD_UPSTREAM_WATCH_FAILURE) { + res = "failure"; + } + else if (fl & RSPAMD_UPSTREAM_WATCH_ONLINE) { + res = "online"; + } + else if (fl & RSPAMD_UPSTREAM_WATCH_OFFLINE) { + res = "offline"; + } + else { + msg_err("invalid flag: %d", fl); + } + + return res; +} + +struct rspamd_lua_upstream_watcher_cbdata { + lua_State *L; + gint cbref; + gint parent_cbref; /* Reference to the upstream list */ + struct upstream_list *upl; +}; + +static void +lua_upstream_watch_func(struct upstream *up, + enum rspamd_upstreams_watch_event event, + guint cur_errors, + void *ud) +{ + struct rspamd_lua_upstream_watcher_cbdata *cdata = + (struct rspamd_lua_upstream_watcher_cbdata *) ud; + lua_State *L; + const gchar *what; + gint err_idx; + + L = cdata->L; + what = lua_upstream_flag_to_str(event); + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, cdata->cbref); + lua_pushstring(L, what); + + struct rspamd_lua_upstream *lua_ups = lua_newuserdata(L, sizeof(*lua_ups)); + lua_ups->up = up; + rspamd_lua_setclass(L, "rspamd{upstream}", -1); + /* Store parent in the upstream to prevent gc */ + lua_rawgeti(L, LUA_REGISTRYINDEX, cdata->parent_cbref); + lua_ups->upref = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushinteger(L, cur_errors); + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err("cannot call watch function for upstream: %s", lua_tostring(L, -1)); + lua_settop(L, 0); + + return; + } + + lua_settop(L, 0); +} + +static void +lua_upstream_watch_dtor(gpointer ud) +{ + struct rspamd_lua_upstream_watcher_cbdata *cdata = + (struct rspamd_lua_upstream_watcher_cbdata *) ud; + + luaL_unref(cdata->L, LUA_REGISTRYINDEX, cdata->cbref); + luaL_unref(cdata->L, LUA_REGISTRYINDEX, cdata->parent_cbref); + g_free(cdata); +} + +/*** + * @method upstream_list:add_watcher(what, cb) + * Add new watcher to the upstream lists events (table or a string): + * - `success` - called whenever upstream successfully used + * - `failure` - called on upstream error + * - `online` - called when upstream is being taken online from offline + * - `offline` - called when upstream is being taken offline from online + * Callback is a function: function(what, upstream, cur_errors) ... end + * @example +ups:add_watcher('success', function(what, up, cur_errors) ... end) +ups:add_watcher({'online', 'offline'}, function(what, up, cur_errors) ... end) + * @return nothing + */ +static gint +lua_upstream_list_add_watcher(lua_State *L) +{ + LUA_TRACE_POINT; + struct upstream_list *upl; + + upl = lua_check_upstream_list(L); + if (upl && + (lua_type(L, 2) == LUA_TTABLE || lua_type(L, 2) == LUA_TSTRING) && + lua_type(L, 3) == LUA_TFUNCTION) { + + enum rspamd_upstreams_watch_event flags = 0; + struct rspamd_lua_upstream_watcher_cbdata *cdata; + + if (lua_type(L, 2) == LUA_TSTRING) { + flags = lua_str_to_upstream_flag(lua_tostring(L, 2)); + } + else { + for (lua_pushnil(L); lua_next(L, -2); lua_pop(L, 1)) { + if (lua_isstring(L, -1)) { + flags |= lua_str_to_upstream_flag(lua_tostring(L, -1)); + } + else { + lua_pop(L, 1); + + return luaL_error(L, "invalid arguments"); + } + } + } + + cdata = g_malloc0(sizeof(*cdata)); + lua_pushvalue(L, 3); /* callback */ + cdata->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + cdata->L = L; + cdata->upl = upl; + lua_pushvalue(L, 1); /* upstream list itself */ + cdata->parent_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + rspamd_upstreams_add_watch_callback(upl, flags, + lua_upstream_watch_func, lua_upstream_watch_dtor, cdata); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 0; +} + +static gint +lua_load_upstream_list(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, upstream_list_f); + + return 1; +} + +void luaopen_upstream(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{upstream_list}", upstream_list_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_upstream_list", lua_load_upstream_list); + + rspamd_lua_new_class(L, "rspamd{upstream}", upstream_m); + lua_pop(L, 1); +} diff --git a/src/lua/lua_url.c b/src/lua/lua_url.c new file mode 100644 index 0000000..913469f --- /dev/null +++ b/src/lua/lua_url.c @@ -0,0 +1,1481 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "lua_url.h" + + +/*** + * @module rspamd_url + * This module provides routines to handle URL's and extract URL's from the text. + * Objects of this class are returned, for example, by `task:get_urls()` or `task:get_emails()`. + * You can also create `rspamd_url` from any text. + * @example +local url = require "rspamd_url" +local mpool = require "rspamd_mempool" + +url.init("/usr/share/rspamd/effective_tld_names.dat") +local pool = mpool.create() +local res = url.create(pool, 'Look at: http://user@test.example.com/test?query") +local t = res:to_table() +-- Content of t: +-- url = ['http://test.example.com/test?query'] +-- host = ['test.example.com'] +-- user = ['user'] +-- path = ['test'] +-- tld = ['example.com'] + +pool:destroy() -- res is destroyed here, so you should not use it afterwards + +local mistake = res:to_table() -- INVALID! as pool is destroyed + */ + +/* URL methods */ +LUA_FUNCTION_DEF(url, get_length); +LUA_FUNCTION_DEF(url, get_host); +LUA_FUNCTION_DEF(url, get_port); +LUA_FUNCTION_DEF(url, get_user); +LUA_FUNCTION_DEF(url, get_path); +LUA_FUNCTION_DEF(url, get_query); +LUA_FUNCTION_DEF(url, get_fragment); +LUA_FUNCTION_DEF(url, get_text); +LUA_FUNCTION_DEF(url, tostring); +LUA_FUNCTION_DEF(url, get_raw); +LUA_FUNCTION_DEF(url, get_tld); +LUA_FUNCTION_DEF(url, get_flags); +LUA_FUNCTION_DEF(url, get_flags_num); +LUA_FUNCTION_DEF(url, get_protocol); +LUA_FUNCTION_DEF(url, to_table); +LUA_FUNCTION_DEF(url, is_phished); +LUA_FUNCTION_DEF(url, is_redirected); +LUA_FUNCTION_DEF(url, is_obscured); +LUA_FUNCTION_DEF(url, is_html_displayed); +LUA_FUNCTION_DEF(url, is_subject); +LUA_FUNCTION_DEF(url, get_phished); +LUA_FUNCTION_DEF(url, set_redirected); +LUA_FUNCTION_DEF(url, get_count); +LUA_FUNCTION_DEF(url, get_visible); +LUA_FUNCTION_DEF(url, create); +LUA_FUNCTION_DEF(url, init); +LUA_FUNCTION_DEF(url, all); +LUA_FUNCTION_DEF(url, lt); +LUA_FUNCTION_DEF(url, eq); +LUA_FUNCTION_DEF(url, get_order); +LUA_FUNCTION_DEF(url, get_part_order); + +static const struct luaL_reg urllib_m[] = { + LUA_INTERFACE_DEF(url, get_length), + LUA_INTERFACE_DEF(url, get_host), + LUA_INTERFACE_DEF(url, get_port), + LUA_INTERFACE_DEF(url, get_user), + LUA_INTERFACE_DEF(url, get_path), + LUA_INTERFACE_DEF(url, get_query), + LUA_INTERFACE_DEF(url, get_fragment), + LUA_INTERFACE_DEF(url, get_text), + LUA_INTERFACE_DEF(url, get_tld), + LUA_INTERFACE_DEF(url, get_raw), + LUA_INTERFACE_DEF(url, get_protocol), + LUA_INTERFACE_DEF(url, to_table), + LUA_INTERFACE_DEF(url, is_phished), + LUA_INTERFACE_DEF(url, is_redirected), + LUA_INTERFACE_DEF(url, is_obscured), + LUA_INTERFACE_DEF(url, is_html_displayed), + LUA_INTERFACE_DEF(url, is_subject), + LUA_INTERFACE_DEF(url, get_phished), + + LUA_INTERFACE_DEF(url, get_visible), + LUA_INTERFACE_DEF(url, get_count), + LUA_INTERFACE_DEF(url, get_flags), + LUA_INTERFACE_DEF(url, get_flags_num), + LUA_INTERFACE_DEF(url, get_order), + LUA_INTERFACE_DEF(url, get_part_order), + {"get_redirected", lua_url_get_phished}, + LUA_INTERFACE_DEF(url, set_redirected), + {"__tostring", lua_url_tostring}, + {"__eq", lua_url_eq}, + {"__lt", lua_url_lt}, + {NULL, NULL}}; + +static const struct luaL_reg urllib_f[] = { + LUA_INTERFACE_DEF(url, init), + LUA_INTERFACE_DEF(url, create), + LUA_INTERFACE_DEF(url, all), + {NULL, NULL}}; + +struct rspamd_lua_url * +lua_check_url(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{url}"); + luaL_argcheck(L, ud != NULL, pos, "'url' expected"); + return ud ? ((struct rspamd_lua_url *) ud) : NULL; +} + +static gboolean +lua_url_single_inserter(struct rspamd_url *url, gsize start_offset, + gsize end_offset, gpointer ud) +{ + lua_State *L = ud; + struct rspamd_lua_url *lua_url; + + lua_url = lua_newuserdata(L, sizeof(struct rspamd_lua_url)); + rspamd_lua_setclass(L, "rspamd{url}", -1); + lua_url->url = url; + + return TRUE; +} + +/*** + * @method url:get_length() + * Get length of the url + * @return {number} length of url in bytes + */ +static gint +lua_url_get_length(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushinteger(L, url->url->urllen); + } + else { + lua_pushnil(L); + } + return 1; +} + +/*** + * @method url:get_host() + * Get domain part of the url + * @return {string} domain part of URL + */ +static gint +lua_url_get_host(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url && url->url->hostlen > 0) { + lua_pushlstring(L, rspamd_url_host(url->url), url->url->hostlen); + } + else { + lua_pushnil(L); + } + return 1; +} + +/*** + * @method url:get_port() + * Get port of the url + * @return {number} url port + */ +static gint +lua_url_get_port(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + if (rspamd_url_get_port_if_special(url->url) == 0) { + lua_pushnil(L); + } + else { + lua_pushinteger(L, rspamd_url_get_port_if_special(url->url)); + } + } + else { + lua_pushnil(L); + } + return 1; +} + +/*** + * @method url:get_user() + * Get user part of the url (e.g. username in email) + * @return {string} user part of URL + */ +static gint +lua_url_get_user(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && rspamd_url_user(url->url) != NULL) { + lua_pushlstring(L, rspamd_url_user(url->url), url->url->userlen); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_path() + * Get path of the url + * @return {string} path part of URL + */ +static gint +lua_url_get_path(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url->datalen > 0) { + lua_pushlstring(L, rspamd_url_data_unsafe(url->url), url->url->datalen); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_query() + * Get query of the url + * @return {string} query part of URL + */ +static gint +lua_url_get_query(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url->querylen > 0) { + lua_pushlstring(L, rspamd_url_query_unsafe(url->url), url->url->querylen); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_fragment() + * Get fragment of the url + * @return {string} fragment part of URL + */ +static gint +lua_url_get_fragment(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url->fragmentlen > 0) { + lua_pushlstring(L, rspamd_url_fragment_unsafe(url->url), url->url->fragmentlen); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_text() + * Get full content of the url + * @return {string} url string + */ +static gint +lua_url_get_text(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushlstring(L, url->url->string, url->url->urllen); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:tostring() + * Get full content of the url or user@domain in case of email + * @return {string} url as a string + */ +static gint +lua_url_tostring(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url != NULL) { + if (url->url->protocol == PROTOCOL_MAILTO) { + gchar *tmp = g_malloc(url->url->userlen + 1 + + url->url->hostlen); + if (url->url->userlen) { + memcpy(tmp, url->url->string + url->url->usershift, url->url->userlen); + } + + tmp[url->url->userlen] = '@'; + memcpy(tmp + url->url->userlen + 1, rspamd_url_host_unsafe(url->url), + url->url->hostlen); + + lua_pushlstring(L, tmp, url->url->userlen + 1 + url->url->hostlen); + g_free(tmp); + } + else { + lua_pushlstring(L, url->url->string, url->url->urllen); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_raw() + * Get full content of the url as it was parsed (e.g. with urldecode) + * @return {string} url string + */ +static gint +lua_url_get_raw(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushlstring(L, url->url->raw, url->url->rawlen); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:is_phished() + * Check whether URL is treated as phished + * @return {boolean} `true` if URL is phished + */ +static gint +lua_url_is_phished(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushboolean(L, url->url->flags & RSPAMD_URL_FLAG_PHISHED); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:is_redirected() + * Check whether URL was redirected + * @return {boolean} `true` if URL is redirected + */ +static gint +lua_url_is_redirected(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushboolean(L, url->url->flags & RSPAMD_URL_FLAG_REDIRECTED); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:is_obscured() + * Check whether URL is treated as obscured or obfuscated (e.g. numbers in IP address or other hacks) + * @return {boolean} `true` if URL is obscured + */ +static gint +lua_url_is_obscured(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushboolean(L, url->url->flags & RSPAMD_URL_FLAG_OBSCURED); + } + else { + lua_pushnil(L); + } + + return 1; +} + + +/*** + * @method url:is_html_displayed() + * Check whether URL is just displayed in HTML (e.g. NOT a real href) + * @return {boolean} `true` if URL is displayed only + */ +static gint +lua_url_is_html_displayed(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushboolean(L, url->url->flags & RSPAMD_URL_FLAG_HTML_DISPLAYED); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:is_subject() + * Check whether URL is found in subject + * @return {boolean} `true` if URL is found in subject + */ +static gint +lua_url_is_subject(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL) { + lua_pushboolean(L, url->url->flags & RSPAMD_URL_FLAG_SUBJECT); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_phished() + * Get another URL that pretends to be this URL (e.g. used in phishing) + * @return {url} phished URL + */ +static gint +lua_url_get_phished(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *purl, *url = lua_check_url(L, 1); + + if (url) { + if (url->url->ext && url->url->ext->linked_url != NULL) { + /* XXX: in fact, this is the only possible combination of flags, so this check is redundant */ + if (url->url->flags & + (RSPAMD_URL_FLAG_PHISHED | RSPAMD_URL_FLAG_REDIRECTED)) { + purl = lua_newuserdata(L, sizeof(struct rspamd_lua_url)); + rspamd_lua_setclass(L, "rspamd{url}", -1); + purl->url = url->url->ext->linked_url; + + return 1; + } + } + } + + lua_pushnil(L); + return 1; +} + +/*** + * @method url:set_redirected(url, pool) + * Set url as redirected to another url + * @param {string|url} url new url that is redirecting an old one + * @param {pool} pool memory pool to allocate memory if needed + * @return {url} parsed redirected url (if needed) + */ +static gint +lua_url_set_redirected(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1), *redir; + rspamd_mempool_t *pool = NULL; + + if (url == NULL) { + return luaL_error(L, "url is required as the first argument"); + } + + if (lua_type(L, 2) == LUA_TSTRING) { + /* Parse url */ + if (lua_type(L, 3) != LUA_TUSERDATA) { + return luaL_error(L, "mempool is required as the third argument"); + } + + pool = rspamd_lua_check_mempool(L, 3); + + if (pool == NULL) { + return luaL_error(L, "mempool is required as the third argument"); + } + + gsize len; + const gchar *urlstr = lua_tolstring(L, 2, &len); + + rspamd_url_find_single(pool, urlstr, len, RSPAMD_URL_FIND_ALL, + lua_url_single_inserter, L); + + if (lua_type(L, -1) != LUA_TUSERDATA) { + /* URL is actually not found */ + lua_pushnil(L); + } + else { + redir = lua_check_url(L, -1); + + url->url->flags |= RSPAMD_URL_FLAG_REDIRECTED; + + if (url->url->ext == NULL) { + url->url->ext = rspamd_mempool_alloc0_type(pool, struct rspamd_url_ext); + } + url->url->ext->linked_url = redir->url; + } + } + else { + redir = lua_check_url(L, 2); + + if (redir == NULL) { + return luaL_error(L, "url is required as the second argument"); + } + + pool = rspamd_lua_check_mempool(L, 3); + + if (pool == NULL) { + return luaL_error(L, "mempool is required as the third argument"); + } + + url->url->flags |= RSPAMD_URL_FLAG_REDIRECTED; + if (url->url->ext == NULL) { + url->url->ext = rspamd_mempool_alloc0_type(pool, struct rspamd_url_ext); + } + url->url->ext->linked_url = redir->url; + + /* Push back on stack */ + lua_pushvalue(L, 2); + } + + return 1; +} + +/*** + * @method url:get_tld() + * Get effective second level domain part (eSLD) of the url host + * @return {string} effective second level domain part (eSLD) of the url host + */ +static gint +lua_url_get_tld(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url->tldlen > 0) { + lua_pushlstring(L, rspamd_url_tld_unsafe(url->url), url->url->tldlen); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_protocol() + * Get protocol name + * @return {string} protocol as a string + */ +static gint +lua_url_get_protocol(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url->protocol != PROTOCOL_UNKNOWN) { + lua_pushstring(L, rspamd_url_protocol_name(url->url->protocol)); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:get_count() + * Return number of occurrences for this particular URL + * @return {number} number of occurrences + */ +static gint +lua_url_get_count(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url != NULL) { + lua_pushinteger(L, url->url->count); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** +* @method url:get_visible() +* Get visible part of the url with html tags stripped +* @return {string} url string +*/ +static gint +lua_url_get_visible(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url != NULL && url->url->ext && url->url->ext->visible_part) { + lua_pushstring(L, url->url->ext->visible_part); + } + else { + lua_pushnil(L); + } + + return 1; +} + +/*** + * @method url:to_table() + * Return url as a table with the following fields: + * + * - `url`: full content + * - `host`: hostname part + * - `user`: user part + * - `path`: path part + * - `tld`: top level domain + * - `protocol`: url protocol + * @return {table} URL as a table + */ +static gint +lua_url_to_table(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + struct rspamd_url *u; + + if (url != NULL) { + u = url->url; + lua_createtable(L, 0, 12); + lua_pushstring(L, "url"); + lua_pushlstring(L, u->string, u->urllen); + lua_settable(L, -3); + + if (u->hostlen > 0) { + lua_pushstring(L, "host"); + lua_pushlstring(L, rspamd_url_host_unsafe(u), u->hostlen); + lua_settable(L, -3); + } + + if (rspamd_url_get_port_if_special(u) != 0) { + lua_pushstring(L, "port"); + lua_pushinteger(L, rspamd_url_get_port_if_special(u)); + lua_settable(L, -3); + } + + if (u->tldlen > 0) { + lua_pushstring(L, "tld"); + lua_pushlstring(L, rspamd_url_tld_unsafe(u), u->tldlen); + lua_settable(L, -3); + } + + if (u->userlen > 0) { + lua_pushstring(L, "user"); + lua_pushlstring(L, rspamd_url_user(u), u->userlen); + lua_settable(L, -3); + } + + if (u->datalen > 0) { + lua_pushstring(L, "path"); + lua_pushlstring(L, rspamd_url_data_unsafe(u), u->datalen); + lua_settable(L, -3); + } + + if (u->querylen > 0) { + lua_pushstring(L, "query"); + lua_pushlstring(L, rspamd_url_query_unsafe(u), u->querylen); + lua_settable(L, -3); + } + + if (u->fragmentlen > 0) { + lua_pushstring(L, "fragment"); + lua_pushlstring(L, rspamd_url_fragment_unsafe(u), u->fragmentlen); + lua_settable(L, -3); + } + + + lua_pushstring(L, "protocol"); + lua_pushstring(L, rspamd_url_protocol_name(u->protocol)); + lua_settable(L, -3); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static rspamd_mempool_t *static_lua_url_pool; + +RSPAMD_CONSTRUCTOR(rspamd_urls_static_pool_ctor) +{ + static_lua_url_pool = rspamd_mempool_new(rspamd_mempool_suggest_size(), + "static_lua_url", 0); +} + +RSPAMD_DESTRUCTOR(rspamd_urls_static_pool_dtor) +{ + rspamd_mempool_delete(static_lua_url_pool); +} + +/*** + * @function url.create([mempool,] str, [{flags_table}]) + * @param {rspamd_mempool} memory pool for URL, e.g. `task:get_mempool()` + * @param {string} text that contains URL (can also contain other stuff) + * @return {url} new url object that exists as long as the corresponding mempool exists + */ +static gint +lua_url_create(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_mempool_t *pool; + struct rspamd_lua_text *t; + struct rspamd_lua_url *u; + + if (lua_type(L, 1) == LUA_TUSERDATA) { + pool = rspamd_lua_check_mempool(L, 1); + t = lua_check_text_or_string(L, 2); + } + else { + pool = static_lua_url_pool; + t = lua_check_text_or_string(L, 2); + } + + if (pool == NULL || t == NULL) { + return luaL_error(L, "invalid arguments"); + } + else { + rspamd_url_find_single(pool, t->start, t->len, RSPAMD_URL_FIND_ALL, + lua_url_single_inserter, L); + + if (lua_type(L, -1) != LUA_TUSERDATA) { + /* URL is actually not found */ + lua_pushnil(L); + + return 1; + } + + u = (struct rspamd_lua_url *) lua_touserdata(L, -1); + + if (lua_type(L, 3) == LUA_TTABLE) { + /* Add flags */ + for (lua_pushnil(L); lua_next(L, 3); lua_pop(L, 1)) { + int nmask = 0; + const gchar *fname = lua_tostring(L, -1); + + if (rspamd_url_flag_from_string(fname, &nmask)) { + u->url->flags |= nmask; + } + else { + lua_pop(L, 1); + return luaL_error(L, "invalid flag: %s", fname); + } + } + } + } + + return 1; +} + +/*** + * @function url.init(tld_file) + * Initialize url library if not initialized yet by Rspamd + * @param {string} tld_file path to effective_tld_names.dat file (public suffix list) + * @return nothing + */ +static gint +lua_url_init(lua_State *L) +{ + const gchar *tld_path; + + tld_path = luaL_checkstring(L, 1); + + rspamd_url_init(tld_path); + + return 0; +} + +static gboolean +lua_url_table_inserter(struct rspamd_url *url, gsize start_offset, + gsize end_offset, gpointer ud) +{ + lua_State *L = ud; + struct rspamd_lua_url *lua_url; + gint n; + + n = rspamd_lua_table_size(L, -1); + lua_url = lua_newuserdata(L, sizeof(struct rspamd_lua_url)); + rspamd_lua_setclass(L, "rspamd{url}", -1); + lua_url->url = url; + lua_rawseti(L, -2, n + 1); + + return TRUE; +} + + +static gint +lua_url_all(lua_State *L) +{ + LUA_TRACE_POINT; + rspamd_mempool_t *pool = rspamd_lua_check_mempool(L, 1); + const gchar *text; + size_t length; + + if (pool == NULL) { + lua_pushnil(L); + } + else { + text = luaL_checklstring(L, 2, &length); + + if (text != NULL) { + lua_newtable(L); + rspamd_url_find_multiple(pool, text, length, + RSPAMD_URL_FIND_ALL, NULL, + lua_url_table_inserter, L); + } + else { + lua_pushnil(L); + } + } + + return 1; +} + +/*** + * @method url:get_flags() + * Return flags for a specified URL as map 'flag'->true for all flags set, + * possible flags are: + * + * - `phished`: URL is likely phished + * - `numeric`: URL is numeric (e.g. IP address) + * - `obscured`: URL was obscured + * - `redirected`: URL comes from redirector + * - `html_displayed`: URL is used just for displaying purposes + * - `text`: URL comes from the text + * - `subject`: URL comes from the subject + * - `host_encoded`: URL host part is encoded + * - `schema_encoded`: URL schema part is encoded + * - `query_encoded`: URL query part is encoded + * - `missing_slashes`: URL has some slashes missing + * - `idn`: URL has international characters + * - `has_port`: URL has port + * - `has_user`: URL has user part + * - `schemaless`: URL has no schema + * - `unnormalised`: URL has some unicode unnormalities + * - `zw_spaces`: URL has some zero width spaces + * - `url_displayed`: URL has some other url-like string in visible part + * - `image`: URL is from src attribute of img HTML tag + * @return {table} URL flags + */ +#define PUSH_FLAG(fl) \ + do { \ + if (flags & (fl)) { \ + lua_pushstring(L, rspamd_url_flag_to_string(fl)); \ + lua_pushboolean(L, true); \ + lua_settable(L, -3); \ + } \ + } while (0) + +static gint +lua_url_get_flags(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + enum rspamd_url_flags flags; + + if (url != NULL) { + flags = url->url->flags; + + lua_createtable(L, 0, 4); + + for (gint i = 0; i < RSPAMD_URL_MAX_FLAG_SHIFT; i++) { + PUSH_FLAG(1u << i); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +#undef PUSH_FLAG + +static gint +lua_url_get_flags_num(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url) { + lua_pushinteger(L, url->url->flags); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_url_get_order(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url) { + if (url->url->order != (uint16_t) -1) { + lua_pushinteger(L, url->url->order); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_url_get_part_order(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *url = lua_check_url(L, 1); + + if (url) { + if (url->url->part_order != (uint16_t) -1) { + lua_pushinteger(L, url->url->part_order); + } + else { + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +void lua_tree_url_callback(gpointer key, gpointer value, gpointer ud) +{ + struct rspamd_lua_url *lua_url; + struct rspamd_url *url = (struct rspamd_url *) value; + struct lua_tree_cb_data *cb = ud; + + if ((url->protocol & cb->protocols_mask) == url->protocol) { + + /* Handle different flags application logic */ + switch (cb->flags_mode) { + case url_flags_mode_include_any: + if (url->flags != (url->flags & cb->flags_mask)) { + return; + } + break; + case url_flags_mode_include_explicit: + if ((url->flags & cb->flags_mask) != cb->flags_mask) { + return; + } + break; + case url_flags_mode_exclude_include: + if ((url->flags & cb->flags_exclude_mask) != 0) { + return; + } + if ((url->flags & cb->flags_mask) == 0) { + return; + } + break; + } + + if (cb->skip_prob > 0) { + gdouble coin = rspamd_random_double_fast_seed(&cb->random_seed); + + if (coin < cb->skip_prob) { + return; + } + } + + lua_url = lua_newuserdata(cb->L, sizeof(struct rspamd_lua_url)); + lua_pushvalue(cb->L, cb->metatable_pos); + lua_setmetatable(cb->L, -2); + lua_url->url = url; + lua_rawseti(cb->L, -2, cb->i++); + } +} + +gboolean +lua_url_cbdata_fill(lua_State *L, + gint pos, + struct lua_tree_cb_data *cbd, + guint default_protocols, + guint default_flags, + gsize max_urls) +{ + gint protocols_mask = 0; + + gint pos_arg_type = lua_type(L, pos); + guint flags_mask = default_flags; + gboolean seen_flags = FALSE, seen_protocols = FALSE; + + memset(cbd, 0, sizeof(*cbd)); + cbd->flags_mode = url_flags_mode_include_any; + + if (pos_arg_type == LUA_TBOOLEAN) { + protocols_mask = default_protocols; + if (lua_toboolean(L, 2)) { + protocols_mask |= PROTOCOL_MAILTO; + } + } + else if (pos_arg_type == LUA_TTABLE) { + if (rspamd_lua_geti(L, 1, pos) == LUA_TNIL) { + /* New method: indexed table */ + + lua_getfield(L, pos, "flags"); + if (lua_istable(L, -1)) { + gint top = lua_gettop(L); + + lua_getfield(L, pos, "flags_mode"); + if (lua_isstring(L, -1)) { + const gchar *mode_str = lua_tostring(L, -1); + + if (strcmp(mode_str, "explicit") == 0) { + cbd->flags_mode = url_flags_mode_include_explicit; + /* + * Ignore default flags in this mode and include + * merely flags specified by a caller + */ + flags_mask = 0; + } + } + lua_pop(L, 1); + + for (lua_pushnil(L); lua_next(L, top); lua_pop(L, 1)) { + int nmask = 0; + + + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *fname = lua_tostring(L, -1); + + + if (rspamd_url_flag_from_string(fname, &nmask)) { + flags_mask |= nmask; + } + else { + msg_info("bad url flag: %s", fname); + return FALSE; + } + } + else { + flags_mask |= lua_tointeger(L, -1); + } + } + + seen_flags = TRUE; + } + else { + flags_mask |= default_flags; + } + lua_pop(L, 1); + + lua_getfield(L, pos, "protocols"); + if (lua_istable(L, -1)) { + gint top = lua_gettop(L); + + for (lua_pushnil(L); lua_next(L, top); lua_pop(L, 1)) { + int nmask; + const gchar *pname = lua_tostring(L, -1); + + nmask = rspamd_url_protocol_from_string(pname); + + if (nmask != PROTOCOL_UNKNOWN) { + protocols_mask |= nmask; + } + else { + msg_info("bad url protocol: %s", pname); + return FALSE; + } + } + seen_protocols = TRUE; + } + else { + protocols_mask = default_protocols; + } + lua_pop(L, 1); + + if (!seen_protocols) { + lua_getfield(L, pos, "emails"); + if (lua_isboolean(L, -1)) { + if (lua_toboolean(L, -1)) { + protocols_mask |= PROTOCOL_MAILTO; + } + } + lua_pop(L, 1); + } + + if (!seen_flags) { + lua_getfield(L, pos, "images"); + if (lua_isboolean(L, -1)) { + if (lua_toboolean(L, -1)) { + flags_mask |= RSPAMD_URL_FLAG_IMAGE; + } + else { + flags_mask &= ~RSPAMD_URL_FLAG_IMAGE; + } + } + else { + flags_mask &= ~RSPAMD_URL_FLAG_IMAGE; + } + lua_pop(L, 1); + } + + if (!seen_flags) { + lua_getfield(L, pos, "content"); + if (lua_isboolean(L, -1)) { + if (lua_toboolean(L, -1)) { + flags_mask |= RSPAMD_URL_FLAG_CONTENT; + } + else { + flags_mask &= ~RSPAMD_URL_FLAG_CONTENT; + } + } + else { + flags_mask &= ~RSPAMD_URL_FLAG_CONTENT; + } + lua_pop(L, 1); + } + + lua_getfield(L, pos, "max_urls"); + if (lua_isnumber(L, -1)) { + max_urls = lua_tonumber(L, -1); + } + lua_pop(L, 1); + + lua_getfield(L, pos, "sort"); + if (lua_isboolean(L, -1)) { + cbd->sort = TRUE; + } + lua_pop(L, 1); + } + else { + /* Plain table of the protocols */ + for (lua_pushnil(L); lua_next(L, pos); lua_pop(L, 1)) { + int nmask; + const gchar *pname = lua_tostring(L, -1); + + nmask = rspamd_url_protocol_from_string(pname); + + if (nmask != PROTOCOL_UNKNOWN) { + protocols_mask |= nmask; + } + else { + msg_info("bad url protocol: %s", pname); + return FALSE; + } + } + } + + lua_pop(L, 1); /* After rspamd_lua_geti */ + } + else if (pos_arg_type == LUA_TSTRING) { + const gchar *plist = lua_tostring(L, pos); + gchar **strvec; + gchar *const *cvec; + + strvec = g_strsplit_set(plist, ",;", -1); + cvec = strvec; + + while (*cvec) { + int nmask; + + nmask = rspamd_url_protocol_from_string(*cvec); + + if (nmask != PROTOCOL_UNKNOWN) { + protocols_mask |= nmask; + } + else { + msg_info("bad url protocol: %s", *cvec); + g_strfreev(strvec); + + return FALSE; + } + + cvec++; + } + + g_strfreev(strvec); + } + else if (pos_arg_type == LUA_TNONE || pos_arg_type == LUA_TNIL) { + protocols_mask = default_protocols; + flags_mask = default_flags; + } + else { + return FALSE; + } + + if (lua_type(L, pos + 1) == LUA_TBOOLEAN) { + if (lua_toboolean(L, pos + 1)) { + flags_mask |= RSPAMD_URL_FLAG_IMAGE; + } + else { + flags_mask &= ~RSPAMD_URL_FLAG_IMAGE; + } + } + + cbd->i = 1; + cbd->L = L; + cbd->max_urls = max_urls; + cbd->protocols_mask = protocols_mask; + cbd->flags_mask = flags_mask; + + /* This needs to be removed from the stack */ + rspamd_lua_class_metatable(L, "rspamd{url}"); + cbd->metatable_pos = lua_gettop(L); + (void) lua_checkstack(L, cbd->metatable_pos + 4); + + return TRUE; +} + +gboolean +lua_url_cbdata_fill_exclude_include(lua_State *L, + gint pos, + struct lua_tree_cb_data *cbd, + guint default_protocols, + gsize max_urls) +{ + guint protocols_mask = default_protocols; + guint include_flags_mask, exclude_flags_mask; + + gint pos_arg_type = lua_type(L, pos); + + memset(cbd, 0, sizeof(*cbd)); + cbd->flags_mode = url_flags_mode_exclude_include; + + /* Include flags */ + if (pos_arg_type == LUA_TTABLE) { + include_flags_mask = 0; /* Reset to no flags */ + + for (lua_pushnil(L); lua_next(L, pos); lua_pop(L, 1)) { + int nmask = 0; + + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *fname = lua_tostring(L, -1); + + if (rspamd_url_flag_from_string(fname, &nmask)) { + include_flags_mask |= nmask; + } + else { + msg_info("bad url include flag: %s", fname); + return FALSE; + } + } + else { + include_flags_mask |= lua_tointeger(L, -1); + } + } + } + else if (pos_arg_type == LUA_TNIL || pos_arg_type == LUA_TNONE) { + /* Include all flags */ + include_flags_mask = ~0U; + } + else { + msg_info("bad arguments: wrong include mask"); + return FALSE; + } + + /* Exclude flags */ + pos_arg_type = lua_type(L, pos + 1); + if (pos_arg_type == LUA_TTABLE) { + exclude_flags_mask = 0; /* Reset to no flags */ + + for (lua_pushnil(L); lua_next(L, pos + 1); lua_pop(L, 1)) { + int nmask = 0; + + if (lua_type(L, -1) == LUA_TSTRING) { + const gchar *fname = lua_tostring(L, -1); + + if (rspamd_url_flag_from_string(fname, &nmask)) { + exclude_flags_mask |= nmask; + } + else { + msg_info("bad url exclude flag: %s", fname); + return FALSE; + } + } + else { + exclude_flags_mask |= lua_tointeger(L, -1); + } + } + } + else if (pos_arg_type == LUA_TNIL || pos_arg_type == LUA_TNONE) { + /* Empty all exclude flags */ + exclude_flags_mask = 0U; + } + else { + msg_info("bad arguments: wrong exclude mask"); + return FALSE; + } + + if (lua_type(L, pos + 2) == LUA_TTABLE) { + protocols_mask = 0U; /* Reset all protocols */ + + for (lua_pushnil(L); lua_next(L, pos + 2); lua_pop(L, 1)) { + int nmask; + const gchar *pname = lua_tostring(L, -1); + + nmask = rspamd_url_protocol_from_string(pname); + + if (nmask != PROTOCOL_UNKNOWN) { + protocols_mask |= nmask; + } + else { + msg_info("bad url protocol: %s", pname); + return FALSE; + } + } + } + else { + protocols_mask = default_protocols; + } + + cbd->i = 1; + cbd->L = L; + cbd->max_urls = max_urls; + cbd->protocols_mask = protocols_mask; + cbd->flags_mask = include_flags_mask; + cbd->flags_exclude_mask = exclude_flags_mask; + + /* This needs to be removed from the stack */ + rspamd_lua_class_metatable(L, "rspamd{url}"); + cbd->metatable_pos = lua_gettop(L); + (void) lua_checkstack(L, cbd->metatable_pos + 4); + + return TRUE; +} + + +void lua_url_cbdata_dtor(struct lua_tree_cb_data *cbd) +{ + if (cbd->metatable_pos != -1) { + lua_remove(cbd->L, cbd->metatable_pos); + } +} + +gsize lua_url_adjust_skip_prob(float timestamp, + guchar digest[16], + struct lua_tree_cb_data *cb, + gsize sz) +{ + if (cb->max_urls > 0 && sz > cb->max_urls) { + cb->skip_prob = 1.0 - ((gdouble) cb->max_urls) / (gdouble) sz; + /* + * Use task dependent probabilistic seed to ensure that + * consequent task:get_urls return the same list of urls + * We use both digest and timestamp here to avoid attack surface + * based just on digest. + */ + memcpy(&cb->random_seed, digest, 4); + memcpy(((unsigned char *) &cb->random_seed) + 4, ×tamp, 4); + sz = cb->max_urls; + } + + return sz; +} + +static gint +lua_url_eq(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *u1 = lua_check_url(L, 1), + *u2 = lua_check_url(L, 2); + + if (u1 && u2) { + lua_pushboolean(L, (rspamd_url_cmp(u1->url, u2->url) == 0)); + } + else { + lua_pushboolean(L, false); + } + + return 1; +} + +static gint +lua_url_lt(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_url *u1 = lua_check_url(L, 1), + *u2 = lua_check_url(L, 2); + + if (u1 && u2) { + lua_pushinteger(L, rspamd_url_cmp(u1->url, u2->url)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_load_url(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, urllib_f); + + /* Push flags */ + lua_createtable(L, 0, RSPAMD_URL_MAX_FLAG_SHIFT); + for (int i = 0; i < RSPAMD_URL_MAX_FLAG_SHIFT; i++) { + guint flag = 1u << i; + + lua_pushinteger(L, flag); + lua_setfield(L, -2, rspamd_url_flag_to_string(flag)); + } + + lua_setfield(L, -2, "flags"); + + return 1; +} + +void luaopen_url(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{url}", urllib_m); + lua_pop(L, 1); + + rspamd_lua_add_preload(L, "rspamd_url", lua_load_url); +} diff --git a/src/lua/lua_url.h b/src/lua/lua_url.h new file mode 100644 index 0000000..a78dffa --- /dev/null +++ b/src/lua/lua_url.h @@ -0,0 +1,87 @@ +/*- + * Copyright 2020 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef RSPAMD_LUA_URL_H +#define RSPAMD_LUA_URL_H + +#include "lua_common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +struct lua_tree_cb_data { + lua_State *L; + int i; + int metatable_pos; + guint flags_mask; + guint flags_exclude_mask; + guint protocols_mask; + enum { + url_flags_mode_include_any, + url_flags_mode_include_explicit, + url_flags_mode_exclude_include, + } flags_mode; + gboolean sort; + gsize max_urls; + gdouble skip_prob; + guint64 random_seed; +}; + +void lua_tree_url_callback(gpointer key, gpointer value, gpointer ud); + +/** + * Fills a cbdata table based on the parameter at position pos + * @param L + * @param pos + * @param cbd + * @return + */ +gboolean lua_url_cbdata_fill(lua_State *L, gint pos, + struct lua_tree_cb_data *cbd, + guint default_protocols, + guint default_flags, + gsize max_urls); + +gboolean lua_url_cbdata_fill_exclude_include(lua_State *L, gint pos, + struct lua_tree_cb_data *cbd, + guint default_protocols, + gsize max_urls); + +/** + * Cleanup url cbdata + * @param cbd + */ +void lua_url_cbdata_dtor(struct lua_tree_cb_data *cbd); + +/** + * Adjust probabilistic skip of the urls + * @param timestamp + * @param digest + * @param cb + * @param sz + * @param max_urls + * @return + */ +gsize lua_url_adjust_skip_prob(float timestamp, + guchar digest[16], + struct lua_tree_cb_data *cb, + gsize sz); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/lua/lua_util.c b/src/lua/lua_util.c new file mode 100644 index 0000000..152c02d --- /dev/null +++ b/src/lua/lua_util.c @@ -0,0 +1,3585 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" +#include "unix-std.h" +#include "lua_compress.h" +#include "libmime/email_addr.h" +#include "libmime/content_type.h" +#include "libmime/mime_headers.h" +#include "libutil/hash.h" + +#include "lua_parsers.h" + +#ifdef WITH_LUA_REPL +#include "replxx.h" +#endif + +#include <math.h> +#include <glob.h> + +#include "unicode/uspoof.h" +#include "unicode/uscript.h" +#include "contrib/fastutf8/fastutf8.h" + +/*** + * @module rspamd_util + * This module contains some generic purpose utilities that could be useful for + * testing and production rules. + */ + +/*** + * @function util.create_event_base() + * Creates new event base for processing asynchronous events + * @return {ev_base} new event processing base + */ +LUA_FUNCTION_DEF(util, create_event_base); +/*** + * @function util.load_rspamd_config(filename) + * Load rspamd config from the specified file + * @return {confg} new configuration object suitable for access + */ +LUA_FUNCTION_DEF(util, load_rspamd_config); +/*** + * @function util.config_from_ucl(any, string) + * Load rspamd config from ucl represented by any lua table + * @return {confg} new configuration object suitable for access + */ +LUA_FUNCTION_DEF(util, config_from_ucl); +/*** + * @function util.encode_base64(input[, str_len, [newlines_type]]) + * Encodes data in base64 breaking lines if needed + * @param {text or string} input input data + * @param {number} str_len optional size of lines or 0 if split is not needed + * @return {rspamd_text} encoded data chunk + */ +LUA_FUNCTION_DEF(util, encode_base64); +/*** + * @function util.encode_qp(input[, str_len, [newlines_type]]) + * Encodes data in quoted printable breaking lines if needed + * @param {text or string} input input data + * @param {number} str_len optional size of lines or 0 if split is not needed + * @return {rspamd_text} encoded data chunk + */ +LUA_FUNCTION_DEF(util, encode_qp); + +/*** + * @function util.decode_qp(input) + * Decodes data from quoted printable + * @param {text or string} input input data + * @return {rspamd_text} decoded data chunk + */ +LUA_FUNCTION_DEF(util, decode_qp); + +/*** + * @function util.decode_base64(input) + * Decodes data from base64 ignoring whitespace characters + * @param {text or string} input data to decode; if `rspamd{text}` is used then the string is modified **in-place** + * @return {rspamd_text} decoded data chunk + */ +LUA_FUNCTION_DEF(util, decode_base64); + +/*** + * @function util.encode_base32(input, [b32type = 'default']) + * Encodes data in base32 breaking lines if needed + * @param {text or string} input input data + * @param {string} b32type base32 type (default, bleach, rfc) + * @return {rspamd_text} encoded data chunk + */ +LUA_FUNCTION_DEF(util, encode_base32); +/*** + * @function util.decode_base32(input, [b32type = 'default']) + * Decodes data from base32 ignoring whitespace characters + * @param {text or string} input data to decode + * @param {string} b32type base32 type (default, bleach, rfc) + * @return {rspamd_text} decoded data chunk + */ +LUA_FUNCTION_DEF(util, decode_base32); + +/*** + * @function util.decode_url(input) + * Decodes data from url encoding + * @param {text or string} input data to decode + * @return {rspamd_text} decoded data chunk + */ +LUA_FUNCTION_DEF(util, decode_url); + +/*** + * @function util.tokenize_text(input[, exceptions]) + * Create tokens from a text using optional exceptions list + * @param {text/string} input input data + * @param {table} exceptions, a table of pairs containing <start_pos,length> of exceptions in the input + * @return {table/strings} list of strings representing words in the text + */ +LUA_FUNCTION_DEF(util, tokenize_text); +LUA_FUNCTION_DEF(util, process_message); +/*** + * @function util.tanh(num) + * Calculates hyperbolic tangent of the specified floating point value + * @param {number} num input number + * @return {number} hyperbolic tangent of the variable + */ +LUA_FUNCTION_DEF(util, tanh); + +/*** + * @function util.parse_html(input) + * Parses HTML and returns the according text + * @param {string|text} in input HTML + * @return {rspamd_text} processed text with no HTML tags + */ +LUA_FUNCTION_DEF(util, parse_html); + +/*** + * @function util.levenshtein_distance(s1, s2) + * Returns levenstein distance between two strings + * @param {string} s1 the first string + * @param {string} s2 the second string + * @return {number} number of differences in two strings + */ +LUA_FUNCTION_DEF(util, levenshtein_distance); + +/*** + * @function util.fold_header(name, value, [how, [stop_chars]]) + * Fold rfc822 header according to the folding rules + * + * @param {string} name name of the header + * @param {string} value value of the header + * @param {string} how "cr" for \r, "lf" for \n and "crlf" for \r\n (default) + * @param {string} stop_chars also fold header when the + * @return {string} Folded value of the header + */ +LUA_FUNCTION_DEF(util, fold_header); + +/*** + * @function util.is_uppercase(str) + * Returns true if a string is all uppercase + * + * @param {string} str input string + * @return {bool} true if a string is all uppercase + */ +LUA_FUNCTION_DEF(util, is_uppercase); + +/*** + * @function util.humanize_number(num) + * Returns humanized representation of given number (like 1k instead of 1000) + * + * @param {number} num number to humanize + * @return {string} humanized representation of a number + */ +LUA_FUNCTION_DEF(util, humanize_number); + +/*** + * @function util.get_tld(host) + * Returns effective second level domain part (eSLD) for the specified host + * + * @param {string} host hostname + * @return {string} eSLD part of the hostname or the full hostname if eSLD was not found + */ +LUA_FUNCTION_DEF(util, get_tld); + +/*** + * @function util.glob(pattern) + * Returns results for the glob match for the specified pattern + * + * @param {string} pattern glob pattern to match ('?' and '*' are supported) + * @return {table/string} list of matched files + */ +LUA_FUNCTION_DEF(util, glob); + +/*** + * @function util.parse_mail_address(str, [pool]) + * Parses email address and returns a table of tables in the following format: + * + * - `raw` - the original value without any processing + * - `name` - name of internet address in UTF8, e.g. for `Vsevolod Stakhov <blah@foo.com>` it returns `Vsevolod Stakhov` + * - `addr` - address part of the address + * - `user` - user part (if present) of the address, e.g. `blah` + * - `domain` - domain part (if present), e.g. `foo.com` + * - `flags` - table with following keys set to true if given condition fulfilled: + * - [valid] - valid SMTP address in conformity with https://tools.ietf.org/html/rfc5321#section-4.1. + * - [ip] - domain is IPv4/IPv6 address + * - [braced] - angled `<blah@foo.com>` address + * - [quoted] - quoted user part + * - [empty] - empty address + * - [backslash] - user part contains backslash + * - [8bit] - contains 8bit characters + * + * @param {string} str input string + * @param {rspamd_mempool} pool memory pool to use + * @return {table/tables} parsed list of mail addresses + */ +LUA_FUNCTION_DEF(util, parse_mail_address); + +/*** + * @function util.strlen_utf8(str) + * Returns length of string encoded in utf-8 in characters. + * If invalid characters are found, then this function returns number of bytes. + * @param {string} str utf8 encoded string + * @return {number} number of characters in string + */ +LUA_FUNCTION_DEF(util, strlen_utf8); + +/*** + * @function util.lower_utf8(str) + * Converts utf8 string to lower case + * @param {string} str utf8 encoded string + * @return {string} lowercased utf8 string + */ +LUA_FUNCTION_DEF(util, lower_utf8); + +/*** + * @function util.normalize_utf8(str) + * Gets a string in UTF8 and normalises it to NFKC_Casefold form + * @param {string} str utf8 encoded string + * @return {string,integer} lowercased utf8 string + result of the normalisation (use bit.band to check): + * RSPAMD_UNICODE_NORM_NORMAL = 0, + * RSPAMD_UNICODE_NORM_UNNORMAL = (1 << 0), + * RSPAMD_UNICODE_NORM_ZERO_SPACES = (1 << 1), + * RSPAMD_UNICODE_NORM_ERROR = (1 << 2), + * RSPAMD_UNICODE_NORM_OVERFLOW = (1 << 3) + */ +LUA_FUNCTION_DEF(util, normalize_utf8); + + +/*** + * @function util.transliterate(str) + * Converts utf8 encoded string to latin transliteration + * @param {string/text} str utf8 encoded string + * @return {text} transliterated string + */ +LUA_FUNCTION_DEF(util, transliterate); + +/*** + * @function util.strequal_caseless(str1, str2) + * Compares two strings regardless of their case using ascii comparison. + * Returns `true` if `str1` is equal to `str2` + * @param {string} str1 utf8 encoded string + * @param {string} str2 utf8 encoded string + * @return {bool} result of comparison + */ +LUA_FUNCTION_DEF(util, strequal_caseless); + + +/*** + * @function util.strequal_caseless_utf8(str1, str2) + * Compares two utf8 strings regardless of their case using utf8 collation rules. + * Returns `true` if `str1` is equal to `str2` + * @param {string} str1 utf8 encoded string + * @param {string} str2 utf8 encoded string + * @return {bool} result of comparison + */ +LUA_FUNCTION_DEF(util, strequal_caseless_utf8); + + +/*** + * @function util.get_ticks() + * Returns current number of ticks as floating point number + * @return {number} number of current clock ticks (monotonically increasing) + */ +LUA_FUNCTION_DEF(util, get_ticks); + +/*** + * @function util.get_time() + * Returns current time as unix time in floating point representation + * @return {number} number of seconds since 01.01.1970 + */ +LUA_FUNCTION_DEF(util, get_time); + +/*** + * @function util.time_to_string(seconds) + * Converts time from Unix time to HTTP date format + * @param {number} seconds unix timestamp + * @return {string} date as HTTP date + */ +LUA_FUNCTION_DEF(util, time_to_string); + +/*** + * @function util.stat(fname) + * Performs stat(2) on a specified filepath and returns table of values + * + * - `size`: size of file in bytes + * - `type`: type of filepath: `regular`, `directory`, `special` + * - `mtime`: modification time as unix time + * + * @return {string,table} string is returned when error is occurred + * @example + * + * local err,st = util.stat('/etc/password') + * + * if err then + * -- handle error + * else + * print(st['size']) + * end + */ +LUA_FUNCTION_DEF(util, stat); + +/*** + * @function util.unlink(fname) + * Removes the specified file from the filesystem + * + * @param {string} fname filename to remove + * @return {boolean,[string]} true if file has been deleted or false,'error string' + */ +LUA_FUNCTION_DEF(util, unlink); + +/*** + * @function util.lock_file(fname, [fd]) + * Lock the specified file. This function returns {number} which must be passed to `util.unlock_file` after usage + * or you'll have a resource leak + * + * @param {string} fname filename to lock + * @param {number} fd use the specified fd instead of opening one + * @return {number|nil,string} number if locking was successful or nil + error otherwise + */ +LUA_FUNCTION_DEF(util, lock_file); + +/*** + * @function util.unlock_file(fd, [close_fd]) + * Unlock the specified file closing the file descriptor associated. + * + * @param {number} fd descriptor to unlock + * @param {boolean} close_fd close descriptor on unlocking (default: TRUE) + * @return {boolean[,string]} true if a file was unlocked + */ +LUA_FUNCTION_DEF(util, unlock_file); + +/*** + * @function util.create_file(fname, [mode]) + * Creates the specified file with the default mode 0644 + * + * @param {string} fname filename to create + * @param {number} mode open mode (you should use octal number here) + * @return {number|nil,string} file descriptor or pair nil + error string + */ +LUA_FUNCTION_DEF(util, create_file); + +/*** + * @function util.close_file(fd) + * Closes descriptor fd + * + * @param {number} fd descriptor to close + * @return {boolean[,string]} true if a file was closed + */ +LUA_FUNCTION_DEF(util, close_file); + +/*** + * @function util.random_hex(size) + * Returns random hex string of the specified size + * + * @param {number} len length of desired string in bytes + * @return {string} string with random hex digests + */ +LUA_FUNCTION_DEF(util, random_hex); + +/*** + * @function util.zstd_compress(data, [level=1]) + * Compresses input using zstd compression + * + * @param {string/rspamd_text} data input data + * @return {rspamd_text} compressed data + */ +LUA_FUNCTION_DEF(util, zstd_compress); + +/*** + * @function util.zstd_decompress(data) + * Decompresses input using zstd algorithm + * + * @param {string/rspamd_text} data compressed data + * @return {error,rspamd_text} pair of error + decompressed text + */ +LUA_FUNCTION_DEF(util, zstd_decompress); + +/*** + * @function util.gzip_decompress(data, [size_limit]) + * Decompresses input using gzip algorithm + * + * @param {string/rspamd_text} data compressed data + * @param {integer} size_limit optional size limit + * @return {rspamd_text} decompressed text + */ +LUA_FUNCTION_DEF(util, gzip_decompress); + +/*** + * @function util.inflate(data, [size_limit]) + * Decompresses input using inflate algorithm + * + * @param {string/rspamd_text} data compressed data + * @param {integer} size_limit optional size limit + * @return {rspamd_text} decompressed text + */ +LUA_FUNCTION_DEF(util, inflate); + +/*** + * @function util.gzip_compress(data, [level=1]) + * Compresses input using gzip compression + * + * @param {string/rspamd_text} data input data + * @return {rspamd_text} compressed data + */ +LUA_FUNCTION_DEF(util, gzip_compress); + +/*** + * @function util.normalize_prob(prob, [bias = 0.5]) + * Normalize probabilities using polynom + * + * @param {number} prob probability param + * @param {number} bias number to subtract for making the final solution + * @return {number} normalized number + */ +LUA_FUNCTION_DEF(util, normalize_prob); +/*** + * @function util.is_utf_spoofed(str, [str2]) + * Returns true if a string is spoofed (possibly with another string `str2`) + * @return {boolean} true if a string is spoofed + */ +LUA_FUNCTION_DEF(util, is_utf_spoofed); + +/** +* @function util.is_utf_mixed_script(str) +* Returns true if a string contains mixed unicode scripts +* @param {string} String to check +* @return {boolean} true if a string contains chars with mixed unicode script +*/ +LUA_FUNCTION_DEF(util, is_utf_mixed_script); + +/** +* @function util.is_utf_outside_range(str, range_start, range_end) +* Returns true if a string contains chars outside range +* @param {string} String to check +* @param {number} start of character range similar to uset_addRange +* @param {number} end of character range similar to uset_addRange +* @return {boolean} true if a string contains chars outside selected utf range +*/ +LUA_FUNCTION_DEF(util, is_utf_outside_range); + +/*** +* @function util.get_string_stats(str) +* Returns table with number of letters and digits in string +* @return {table} with string stats keys are "digits" and "letters" +*/ +LUA_FUNCTION_DEF(util, get_string_stats); + +/*** + * @function util.is_valid_utf8(str) + * Returns true if a string is valid UTF8 string + * @return {boolean} true if a string is spoofed + */ +LUA_FUNCTION_DEF(util, is_valid_utf8); + +/*** + * @function util.has_obscured_unicode(str) + * Returns true if a string has obscure UTF symbols (zero width spaces, order marks), ignores invalid utf characters + * @return {boolean} true if a has obscured unicode characters (+ character and offset if found) + */ +LUA_FUNCTION_DEF(util, has_obscured_unicode); + +/*** + * @function util.readline([prompt]) + * Returns string read from stdin with history and editing support + * @return {string} string read from the input (with line endings stripped) + */ +LUA_FUNCTION_DEF(util, readline); + +/*** + * @function util.readpassphrase([prompt]) + * Returns string read from stdin disabling echo + * @return {string} string read from the input (with line endings stripped) + */ +LUA_FUNCTION_DEF(util, readpassphrase); + +/*** + * @function util.file_exists(file) + * Checks if a specified file exists and is available for reading + * @return {boolean,string} true if file exists + string error if not + */ +LUA_FUNCTION_DEF(util, file_exists); + +/*** + * @function util.mkdir(dir[, recursive]) + * Creates a specified directory + * @return {boolean[,error]} true if directory has been created + */ +LUA_FUNCTION_DEF(util, mkdir); + +/*** + * @function util.umask(mask) + * Sets new umask. Accepts either numeric octal string, e.g. '022' or a plain + * number, e.g. 0x12 (since Lua does not support octal integrals) + * @return {number} old umask + */ +LUA_FUNCTION_DEF(util, umask); + +/*** + * @function util.isatty() + * Returns if stdout is a tty + * @return {boolean} true in case of output being tty + */ +LUA_FUNCTION_DEF(util, isatty); + +/*** + * @function util.pack(fmt, ...) + * + * Backport of Lua 5.3 `string.pack` function: + * Returns a binary string containing the values v1, v2, etc. packed (that is, + * serialized in binary form) according to the format string `fmt` + * A format string is a sequence of conversion options. The conversion + * options are as follows: + * + * * <: sets little endian + * * >: sets big endian + * * =: sets native endian + * * ![n]: sets maximum alignment to n (default is native alignment) + * * b: a signed byte (char) + * * B: an unsigned byte (char) + * * h: a signed short (native size) + * * H: an unsigned short (native size) + * * l: a signed long (native size) + * * L: an unsigned long (native size) + * * j: a lua_Integer + * * J: a lua_Unsigned + * * T: a size_t (native size) + * * i[n]: a signed int with n bytes (default is native size) + * * I[n]: an unsigned int with n bytes (default is native size) + * * f: a float (native size) + * * d: a double (native size) + * * n: a lua_Number + * * cn: a fixed-sized string with n bytes + * * z: a zero-terminated string + * * s[n]: a string preceded by its length coded as an unsigned integer with + * * n bytes (default is a size_t) + * * x: one byte of padding + * * Xop: an empty item that aligns according to option op (which is otherwise ignored) + * * ' ': (empty space) ignored + * + * (A "[n]" means an optional integral numeral.) Except for padding, spaces, + * and configurations (options "xX <=>!"), each option corresponds to an + * argument (in string.pack) or a result (in string.unpack). + * + * For options "!n", "sn", "in", and "In", n can be any integer between 1 and + * All integral options check overflows; string.pack checks whether the given + * value fits in the given size; string.unpack checks whether the read value + * fits in a Lua integer. + * + * Any format string starts as if prefixed by "!1=", that is, with maximum + * alignment of 1 (no alignment) and native endianness. + * + * Alignment works as follows: For each option, the format gets extra padding + * until the data starts at an offset that is a multiple of the minimum + * between the option size and the maximum alignment; this minimum must be a + * power of 2. Options "c" and "z" are not aligned; option "s" follows the + * alignment of its starting integer. + * + * All padding is filled with zeros by string.pack (and ignored by unpack). + */ +LUA_FUNCTION_DEF(util, pack); + +/*** + * @function util.packsize(fmt) + * + * Returns size of the packed binary string returned for the same `fmt` argument + * by @see util.pack + */ +LUA_FUNCTION_DEF(util, packsize); + +/*** + * @function util.unpack(fmt, s [, pos]) + * Unpacks string `s` according to the format string `fmt` as described in + * @see util.pack + * + * @returns {multiple} list of unpacked values according to `fmt` + */ +LUA_FUNCTION_DEF(util, unpack); + +/*** + * @function util.caseless_hash(str[, seed]) + * Calculates caseless non-crypto hash from a string or rspamd text + * @param str string or lua_text + * @param seed mandatory seed (0xdeadbabe by default) + * @return {int64} boxed int64_t + */ +LUA_FUNCTION_DEF(util, caseless_hash); + +/*** + * @function util.caseless_hash_fast(str[, seed]) + * Calculates caseless non-crypto hash from a string or rspamd text + * @param str string or lua_text + * @param seed mandatory seed (0xdeadbabe by default) + * @return {number} number from int64_t + */ +LUA_FUNCTION_DEF(util, caseless_hash_fast); + +/*** + * @function util.get_hostname() + * Returns hostname for this machine + * @return {string} hostname + */ +LUA_FUNCTION_DEF(util, get_hostname); + +/*** + * @function util.parse_content_type(ct_string, mempool) + * Parses content-type string to a table: + * - `type` + * - `subtype` + * - `charset` + * - `boundary` + * - other attributes + * + * @param {string} ct_string content type as string + * @param {rspamd_mempool} mempool needed to store temporary data (e.g. task pool) + * @return table or nil if cannot parse content type + */ +LUA_FUNCTION_DEF(util, parse_content_type); + +/*** + * @function util.mime_header_encode(hdr) + * Encodes header if needed + * @param {string} hdr input header + * @return encoded header + */ +LUA_FUNCTION_DEF(util, mime_header_encode); + +/*** + * @function util.btc_polymod(input_values) + * Performs bitcoin polymod function + * @param {table|numbers} input_values + * @return {boolean} true if polymod has been successful + */ +LUA_FUNCTION_DEF(util, btc_polymod); + +/*** + * @function util.parse_smtp_date(str[, local_tz]) + * Converts an SMTP date string to unix timestamp + * @param {string} str input string + * @param {boolean} local_tz convert to local tz if `true` + * @return {number} time as unix timestamp (converted to float) + */ +LUA_FUNCTION_DEF(util, parse_smtp_date); + + +static const struct luaL_reg utillib_f[] = { + LUA_INTERFACE_DEF(util, create_event_base), + LUA_INTERFACE_DEF(util, load_rspamd_config), + LUA_INTERFACE_DEF(util, config_from_ucl), + LUA_INTERFACE_DEF(util, process_message), + LUA_INTERFACE_DEF(util, encode_base64), + LUA_INTERFACE_DEF(util, encode_qp), + LUA_INTERFACE_DEF(util, decode_qp), + LUA_INTERFACE_DEF(util, decode_base64), + LUA_INTERFACE_DEF(util, encode_base32), + LUA_INTERFACE_DEF(util, decode_base32), + LUA_INTERFACE_DEF(util, decode_url), + LUA_INTERFACE_DEF(util, tokenize_text), + LUA_INTERFACE_DEF(util, tanh), + LUA_INTERFACE_DEF(util, parse_html), + LUA_INTERFACE_DEF(util, levenshtein_distance), + LUA_INTERFACE_DEF(util, fold_header), + LUA_INTERFACE_DEF(util, is_uppercase), + LUA_INTERFACE_DEF(util, humanize_number), + LUA_INTERFACE_DEF(util, get_tld), + LUA_INTERFACE_DEF(util, glob), + {"parse_addr", lua_util_parse_mail_address}, + LUA_INTERFACE_DEF(util, parse_mail_address), + LUA_INTERFACE_DEF(util, strlen_utf8), + LUA_INTERFACE_DEF(util, lower_utf8), + LUA_INTERFACE_DEF(util, normalize_utf8), + LUA_INTERFACE_DEF(util, transliterate), + LUA_INTERFACE_DEF(util, strequal_caseless), + LUA_INTERFACE_DEF(util, strequal_caseless_utf8), + LUA_INTERFACE_DEF(util, get_ticks), + LUA_INTERFACE_DEF(util, get_time), + LUA_INTERFACE_DEF(util, time_to_string), + LUA_INTERFACE_DEF(util, stat), + LUA_INTERFACE_DEF(util, unlink), + LUA_INTERFACE_DEF(util, lock_file), + LUA_INTERFACE_DEF(util, unlock_file), + LUA_INTERFACE_DEF(util, create_file), + LUA_INTERFACE_DEF(util, close_file), + LUA_INTERFACE_DEF(util, random_hex), + LUA_INTERFACE_DEF(util, zstd_compress), + LUA_INTERFACE_DEF(util, zstd_decompress), + LUA_INTERFACE_DEF(util, gzip_compress), + LUA_INTERFACE_DEF(util, gzip_decompress), + LUA_INTERFACE_DEF(util, inflate), + LUA_INTERFACE_DEF(util, normalize_prob), + LUA_INTERFACE_DEF(util, caseless_hash), + LUA_INTERFACE_DEF(util, caseless_hash_fast), + LUA_INTERFACE_DEF(util, is_utf_spoofed), + LUA_INTERFACE_DEF(util, is_utf_mixed_script), + LUA_INTERFACE_DEF(util, is_utf_outside_range), + LUA_INTERFACE_DEF(util, get_string_stats), + LUA_INTERFACE_DEF(util, is_valid_utf8), + LUA_INTERFACE_DEF(util, has_obscured_unicode), + LUA_INTERFACE_DEF(util, readline), + LUA_INTERFACE_DEF(util, readpassphrase), + LUA_INTERFACE_DEF(util, file_exists), + LUA_INTERFACE_DEF(util, mkdir), + LUA_INTERFACE_DEF(util, umask), + LUA_INTERFACE_DEF(util, isatty), + LUA_INTERFACE_DEF(util, get_hostname), + LUA_INTERFACE_DEF(util, parse_content_type), + LUA_INTERFACE_DEF(util, mime_header_encode), + LUA_INTERFACE_DEF(util, pack), + LUA_INTERFACE_DEF(util, unpack), + LUA_INTERFACE_DEF(util, packsize), + LUA_INTERFACE_DEF(util, btc_polymod), + LUA_INTERFACE_DEF(util, parse_smtp_date), + {NULL, NULL}}; + +LUA_FUNCTION_DEF(int64, tostring); +LUA_FUNCTION_DEF(int64, fromstring); +LUA_FUNCTION_DEF(int64, tonumber); +LUA_FUNCTION_DEF(int64, hex); + +static const struct luaL_reg int64lib_f[] = { + LUA_INTERFACE_DEF(int64, fromstring), + {NULL, NULL}}; +static const struct luaL_reg int64lib_m[] = { + LUA_INTERFACE_DEF(int64, tostring), + LUA_INTERFACE_DEF(int64, tonumber), + LUA_INTERFACE_DEF(int64, hex), + {"__tostring", lua_int64_tostring}, + {NULL, NULL}}; + +LUA_FUNCTION_DEF(ev_base, loop); + +static const struct luaL_reg ev_baselib_m[] = { + LUA_INTERFACE_DEF(ev_base, loop), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static gint64 +lua_check_int64(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{int64}"); + luaL_argcheck(L, ud != NULL, pos, "'int64' expected"); + return ud ? *((gint64 *) ud) : 0LL; +} + + +static gint +lua_util_create_event_base(lua_State *L) +{ + LUA_TRACE_POINT; + struct ev_loop **pev_base; + + pev_base = lua_newuserdata(L, sizeof(struct ev_loop *)); + rspamd_lua_setclass(L, "rspamd{ev_base}", -1); + *pev_base = ev_loop_new(EVFLAG_SIGNALFD | EVBACKEND_ALL); + + return 1; +} + +static gint +lua_util_load_rspamd_config(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg, **pcfg; + const gchar *cfg_name; + + cfg_name = luaL_checkstring(L, 1); + + if (cfg_name) { + cfg = rspamd_config_new(RSPAMD_CONFIG_INIT_SKIP_LUA); + cfg->lua_state = L; + + if (rspamd_config_read(cfg, cfg_name, NULL, NULL, NULL, FALSE, NULL)) { + msg_err_config("cannot load config from %s", cfg_name); + lua_pushnil(L); + } + else { + rspamd_config_post_load(cfg, 0); + pcfg = lua_newuserdata(L, sizeof(struct rspamd_config *)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + *pcfg = cfg; + } + } + + return 1; +} + +static gint +parse_config_options(const char *str_options) +{ + gint ret = 0; + gchar **vec; + const gchar *str; + guint i, l; + + vec = g_strsplit_set(str_options, ",;", -1); + if (vec) { + l = g_strv_length(vec); + for (i = 0; i < l; i++) { + str = vec[i]; + + if (g_ascii_strcasecmp(str, "INIT_URL") == 0) { + ret |= RSPAMD_CONFIG_INIT_URL; + } + else if (g_ascii_strcasecmp(str, "INIT_LIBS") == 0) { + ret |= RSPAMD_CONFIG_INIT_LIBS; + } + else if (g_ascii_strcasecmp(str, "INIT_SYMCACHE") == 0) { + ret |= RSPAMD_CONFIG_INIT_SYMCACHE; + } + else if (g_ascii_strcasecmp(str, "INIT_VALIDATE") == 0) { + ret |= RSPAMD_CONFIG_INIT_VALIDATE; + } + else if (g_ascii_strcasecmp(str, "INIT_NO_TLD") == 0) { + ret |= RSPAMD_CONFIG_INIT_NO_TLD; + } + else if (g_ascii_strcasecmp(str, "INIT_PRELOAD_MAPS") == 0) { + ret |= RSPAMD_CONFIG_INIT_PRELOAD_MAPS; + } + else { + msg_warn("bad type: %s", str); + } + } + + g_strfreev(vec); + } + + return ret; +} + +static gint +lua_util_config_from_ucl(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = NULL, **pcfg; + struct rspamd_rcl_sections_map *top; + GError *err = NULL; + ucl_object_t *obj; + const char *str_options = NULL; + gint int_options = 0; + + + obj = ucl_object_lua_import(L, 1); + if (lua_gettop(L) == 2) { + if (lua_type(L, 2) == LUA_TSTRING) { + str_options = lua_tostring(L, 2); + int_options = parse_config_options(str_options); + } + else { + msg_err("config_from_ucl: second parameter is expected to be string"); + ucl_object_unref(obj); + lua_pushnil(L); + } + } + + if (obj) { + cfg = rspamd_config_new(RSPAMD_CONFIG_INIT_SKIP_LUA); + cfg->lua_state = L; + + cfg->cfg_ucl_obj = obj; + top = rspamd_rcl_config_init(cfg, NULL); + + if (!rspamd_rcl_parse(top, cfg, cfg, cfg->cfg_pool, cfg->cfg_ucl_obj, &err)) { + msg_err("rcl parse error: %s", err->message); + ucl_object_unref(obj); + lua_pushnil(L); + } + else { + + if (int_options & RSPAMD_CONFIG_INIT_LIBS) { + cfg->libs_ctx = rspamd_init_libs(); + } + + rspamd_config_post_load(cfg, int_options); + pcfg = lua_newuserdata(L, sizeof(struct rspamd_config *)); + rspamd_lua_setclass(L, "rspamd{config}", -1); + *pcfg = cfg; + } + + rspamd_rcl_sections_free(top); + } + + return 1; +} + +static gboolean +lua_util_task_fin(struct rspamd_task *task, void *ud) +{ + ucl_object_t **target = ud; + + *target = rspamd_protocol_write_ucl(task, RSPAMD_PROTOCOL_DEFAULT); + rdns_resolver_release(task->resolver->r); + + return TRUE; +} + +static gint +lua_util_process_message(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_config *cfg = lua_check_config(L, 1); + const gchar *message; + gsize mlen; + struct rspamd_task *task; + struct ev_loop *base; + ucl_object_t *res = NULL; + + message = luaL_checklstring(L, 2, &mlen); + + if (cfg != NULL && message != NULL) { + base = ev_loop_new(EVFLAG_SIGNALFD | EVBACKEND_ALL); + rspamd_init_filters(cfg, false, false); + task = rspamd_task_new(NULL, cfg, NULL, NULL, base, FALSE); + task->msg.begin = rspamd_mempool_alloc(task->task_pool, mlen); + rspamd_strlcpy((gpointer) task->msg.begin, message, mlen); + task->msg.len = mlen; + task->fin_callback = lua_util_task_fin; + task->fin_arg = &res; + task->resolver = rspamd_dns_resolver_init(NULL, base, cfg); + task->s = rspamd_session_create(task->task_pool, rspamd_task_fin, + NULL, (event_finalizer_t) rspamd_task_free, task); + + if (!rspamd_task_load_message(task, NULL, message, mlen)) { + lua_pushnil(L); + } + else { + if (rspamd_task_process(task, RSPAMD_TASK_PROCESS_ALL)) { + ev_loop(base, 0); + + if (res != NULL) { + ucl_object_push_lua(L, res, true); + + ucl_object_unref(res); + } + else { + ucl_object_push_lua(L, + rspamd_protocol_write_ucl(task, RSPAMD_PROTOCOL_DEFAULT), + true); + rdns_resolver_release(task->resolver->r); + rspamd_session_destroy(task->s); + } + } + else { + lua_pushnil(L); + } + } + + ev_loop_destroy(base); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_util_encode_base64(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + gchar *out; + gsize outlen; + long str_lim = 0; + gboolean fold = FALSE; + + t = lua_check_text_or_string(L, 1); + + if (lua_gettop(L) > 1) { + str_lim = luaL_checkinteger(L, 2); + fold = str_lim > 0; + } + + if (t == NULL) { + return luaL_error(L, "invalid arguments"); + } + else { + + if (fold) { + out = rspamd_encode_base64(t->start, t->len, str_lim, &outlen); + } + else { + enum rspamd_newlines_type how = RSPAMD_TASK_NEWLINES_CRLF; + + if (lua_type(L, 3) == LUA_TSTRING) { + const gchar *how_str = lua_tostring(L, 3); + + if (g_ascii_strcasecmp(how_str, "cr") == 0) { + how = RSPAMD_TASK_NEWLINES_CR; + } + else if (g_ascii_strcasecmp(how_str, "lf") == 0) { + how = RSPAMD_TASK_NEWLINES_LF; + } + else if (g_ascii_strcasecmp(how_str, "crlf") != 0) { + return luaL_error(L, "invalid newline style: %s", how_str); + } + } + + out = rspamd_encode_base64_fold(t->start, t->len, str_lim, &outlen, how); + } + + if (out != NULL) { + lua_new_text(L, out, outlen, TRUE); + } + else { + lua_pushnil(L); + } + } + + return 1; +} + +static gint +lua_util_encode_qp(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + const gchar *s = NULL; + gchar *out; + gsize inlen, outlen; + guint str_lim = 0; + + if (lua_type(L, 1) == LUA_TSTRING) { + s = luaL_checklstring(L, 1, &inlen); + } + else if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (t != NULL) { + s = t->start; + inlen = t->len; + } + } + + if (lua_gettop(L) > 1) { + str_lim = luaL_checknumber(L, 2); + } + + if (s == NULL) { + lua_pushnil(L); + } + else { + enum rspamd_newlines_type how = RSPAMD_TASK_NEWLINES_CRLF; + + if (lua_type(L, 3) == LUA_TSTRING) { + const gchar *how_str = lua_tostring(L, 3); + + if (g_ascii_strcasecmp(how_str, "cr") == 0) { + how = RSPAMD_TASK_NEWLINES_CR; + } + else if (g_ascii_strcasecmp(how_str, "lf") == 0) { + how = RSPAMD_TASK_NEWLINES_LF; + } + else if (g_ascii_strcasecmp(how_str, "crlf") != 0) { + return luaL_error(L, "invalid newline style: %s", how_str); + } + } + + out = rspamd_encode_qp_fold(s, inlen, str_lim, &outlen, how); + + if (out != NULL) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = out; + t->len = outlen; + /* Need destruction */ + t->flags = RSPAMD_TEXT_FLAG_OWN; + } + else { + lua_pushnil(L); + } + } + + return 1; +} + +static gint +lua_util_decode_qp(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t, *out; + const gchar *s = NULL; + gsize inlen = 0; + gssize outlen; + + if (lua_type(L, 1) == LUA_TSTRING) { + s = luaL_checklstring(L, 1, &inlen); + } + else if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (t != NULL) { + s = t->start; + inlen = t->len; + } + } + + if (s == NULL) { + lua_pushnil(L); + } + else { + out = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + out->start = g_malloc(inlen + 1); + out->flags = RSPAMD_TEXT_FLAG_OWN; + outlen = rspamd_decode_qp_buf(s, inlen, (gchar *) out->start, inlen + 1); + + if (outlen > 0) { + out->len = outlen; + } + else { + /* + * It removes out and frees memory on gc due to RSPAMD_TEXT_FLAG_OWN + */ + lua_pop(L, 1); + lua_pushnil(L); + } + } + + return 1; +} + +static gint +lua_util_decode_base64(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + const gchar *s = NULL; + gsize inlen = 0, outlen; + + if (lua_type(L, 1) == LUA_TSTRING) { + s = luaL_checklstring(L, 1, &inlen); + } + else if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (t != NULL) { + s = t->start; + inlen = t->len; + } + } + + if (s != NULL) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->len = (inlen / 4) * 3 + 3; + t->start = g_malloc(t->len); + + rspamd_cryptobox_base64_decode(s, inlen, (guchar *) t->start, + &outlen); + t->len = outlen; + t->flags = RSPAMD_TEXT_FLAG_OWN; + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_util_encode_base32(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + const gchar *s = NULL; + gchar *out; + enum rspamd_base32_type btype = RSPAMD_BASE32_DEFAULT; + gsize inlen, outlen; + + if (lua_type(L, 1) == LUA_TSTRING) { + s = luaL_checklstring(L, 1, &inlen); + } + else if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (t != NULL) { + s = t->start; + inlen = t->len; + } + } + + if (lua_type(L, 2) == LUA_TSTRING) { + btype = rspamd_base32_decode_type_from_str(lua_tostring(L, 2)); + + if (btype == RSPAMD_BASE32_INVALID) { + return luaL_error(L, "invalid b32 type: %s", lua_tostring(L, 2)); + } + } + + if (s == NULL) { + return luaL_error(L, "invalid arguments"); + } + else { + out = rspamd_encode_base32(s, inlen, btype); + + if (out != NULL) { + t = lua_newuserdata(L, sizeof(*t)); + outlen = strlen(out); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = out; + t->len = outlen; + /* Need destruction */ + t->flags = RSPAMD_TEXT_FLAG_OWN; + } + else { + lua_pushnil(L); + } + } + + return 1; +} + +static gint +lua_util_decode_base32(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + const gchar *s = NULL; + gsize inlen, outlen; + enum rspamd_base32_type btype = RSPAMD_BASE32_DEFAULT; + + if (lua_type(L, 1) == LUA_TSTRING) { + s = luaL_checklstring(L, 1, &inlen); + } + else if (lua_type(L, 1) == LUA_TUSERDATA) { + t = lua_check_text(L, 1); + + if (t != NULL) { + s = t->start; + inlen = t->len; + } + } + + if (lua_type(L, 2) == LUA_TSTRING) { + btype = rspamd_base32_decode_type_from_str(lua_tostring(L, 2)); + + if (btype == RSPAMD_BASE32_INVALID) { + return luaL_error(L, "invalid b32 type: %s", lua_tostring(L, 2)); + } + } + + if (s != NULL) { + guchar *decoded; + + decoded = rspamd_decode_base32(s, inlen, &outlen, btype); + + if (decoded) { + t = lua_newuserdata(L, sizeof(*t)); + rspamd_lua_setclass(L, "rspamd{text}", -1); + t->start = (const gchar *) decoded; + t->len = outlen; + t->flags = RSPAMD_TEXT_FLAG_OWN; + } + else { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_util_decode_url(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + + t = lua_check_text_or_string(L, 1); + + if (t != NULL) { + struct rspamd_lua_text *out = lua_new_text(L, NULL, t->len, TRUE); + + out->len = rspamd_url_decode((char *) out->start, t->start, t->len); + } + else { + lua_pushnil(L); + } + + return 1; +} + + +static gint +lua_util_tokenize_text(lua_State *L) +{ + return lua_parsers_tokenize_text(L); +} + +static gint +lua_util_tanh(lua_State *L) +{ + LUA_TRACE_POINT; + gdouble in = luaL_checknumber(L, 1); + + lua_pushnumber(L, tanh(in)); + + return 1; +} + +static gint +lua_util_parse_html(lua_State *L) +{ + return lua_parsers_parse_html(L); +} + +static gint +lua_util_levenshtein_distance(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t1, *t2; + gint dist = 0; + guint replace_cost = 1; + + t1 = lua_check_text_or_string(L, 1); + t2 = lua_check_text_or_string(L, 2); + if (lua_isnumber(L, 3)) { + replace_cost = lua_tointeger(L, 3); + } + + if (t1 && t2) { + dist = rspamd_strings_levenshtein_distance(t1->start, t1->len, t2->start, t2->len, + replace_cost); + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushinteger(L, dist); + + return 1; +} + +static gint +lua_util_fold_header(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *how, *stop_chars = NULL; + struct rspamd_lua_text *name, *value; + GString *folded; + + name = lua_check_text_or_string(L, 1); + value = lua_check_text_or_string(L, 2); + + if (name && value) { + + if (lua_isstring(L, 3)) { + + how = lua_tostring(L, 3); + + if (lua_isstring(L, 4)) { + stop_chars = lua_tostring(L, 4); + } + + if (strcmp(how, "cr") == 0) { + folded = rspamd_header_value_fold(name->start, name->len, + value->start, value->len, + 0, + RSPAMD_TASK_NEWLINES_CR, stop_chars); + } + else if (strcmp(how, "lf") == 0) { + folded = rspamd_header_value_fold(name->start, name->len, + value->start, value->len, 0, + RSPAMD_TASK_NEWLINES_LF, stop_chars); + } + else { + folded = rspamd_header_value_fold(name->start, name->len, + value->start, value->len, 0, + RSPAMD_TASK_NEWLINES_CRLF, stop_chars); + } + } + else { + folded = rspamd_header_value_fold(name->start, name->len, + value->start, value->len, 0, + RSPAMD_TASK_NEWLINES_CRLF, stop_chars); + } + + if (folded) { + lua_pushlstring(L, folded->str, folded->len); + g_string_free(folded, TRUE); + + return 1; + } + } + + lua_pushnil(L); + return 1; +} + +static gint +lua_util_is_uppercase(lua_State *L) +{ + LUA_TRACE_POINT; + gint32 i = 0; + UChar32 uc; + guint nlc = 0, nuc = 0; + + struct rspamd_lua_text *t = lua_check_text_or_string(L, 1); + if (t) { + while (i >= 0 && i < t->len) { + U8_NEXT(t->start, i, t->len, uc); + + if (uc < 0) { + break; + } + + if (u_isupper(uc)) { + nuc++; + } + else if (u_islower(uc)) { + nlc++; + } + } + } + + if (nuc > 0 && nlc == 0) { + lua_pushboolean(L, TRUE); + } + else { + lua_pushboolean(L, FALSE); + } + + return 1; +} + +static gint +lua_util_humanize_number(lua_State *L) +{ + LUA_TRACE_POINT; + gint64 number = luaL_checkinteger(L, 1); + gchar numbuf[32]; + + + rspamd_snprintf(numbuf, sizeof(numbuf), "%hL", number); + lua_pushstring(L, numbuf); + + return 1; +} + +static gint +lua_util_get_tld(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *host; + gsize hostlen; + rspamd_ftok_t tld; + + host = luaL_checklstring(L, 1, &hostlen); + + if (host) { + if (!rspamd_url_find_tld(host, hostlen, &tld)) { + lua_pushlstring(L, host, hostlen); + } + else { + lua_pushlstring(L, tld.begin, tld.len); + } + } + else { + lua_pushnil(L); + } + + return 1; +} + + +static gint +lua_util_glob(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *pattern; + glob_t gl; + gint top, i, flags = 0; + + top = lua_gettop(L); + memset(&gl, 0, sizeof(gl)); + + for (i = 1; i <= top; i++, flags |= GLOB_APPEND) { + pattern = luaL_checkstring(L, i); + + if (pattern) { + if (glob(pattern, flags, NULL, &gl) != 0) { + /* There is no way to return error here, so just create an table */ + lua_createtable(L, 0, 0); + globfree(&gl); + + return 1; + } + } + } + + lua_createtable(L, gl.gl_pathc, 0); + /* Push results */ + for (i = 0; i < (gint) gl.gl_pathc; i++) { + lua_pushstring(L, gl.gl_pathv[i]); + lua_rawseti(L, -2, i + 1); + } + + globfree(&gl); + + return 1; +} + +static gint +lua_util_parse_mail_address(lua_State *L) +{ + return lua_parsers_parse_mail_address(L); +} + +static gint +lua_util_strlen_utf8(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + + t = lua_check_text_or_string(L, 1); + + if (t) { + gint32 i = 0, nchars = 0; + UChar32 uc; + + while (i < t->len) { + U8_NEXT((guint8 *) t->start, i, t->len, uc); + nchars++; + } + + lua_pushinteger(L, nchars); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_lower_utf8(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + + gchar *dst; + UChar32 uc; + UBool err = 0; + gint32 i = 0, j = 0; + + t = lua_check_text_or_string(L, 1); + + if (t) { + dst = g_malloc(t->len); + + while (i < t->len && err == 0) { + U8_NEXT((guint8 *) t->start, i, t->len, uc); + uc = u_tolower(uc); + U8_APPEND(dst, j, t->len, uc, err); + } + + if (lua_isstring(L, 1)) { + lua_pushlstring(L, dst, j); + g_free(dst); + } + else { + t = lua_new_text(L, dst, j, FALSE); + /* We have actually allocated text data before */ + t->flags |= RSPAMD_TEXT_FLAG_OWN; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_normalize_utf8(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + bool is_text = lua_type(L, 1) == LUA_TUSERDATA; + + t = lua_check_text_or_string(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + char *cpy = g_malloc(t->len + 1); + memcpy(cpy, t->start, t->len); + cpy[t->len] = '\0'; + gsize len = t->len; + enum rspamd_utf8_normalise_result res = rspamd_normalise_unicode_inplace(cpy, &len); + + if (is_text) { + struct rspamd_lua_text *out = lua_new_text(L, cpy, len, FALSE); + out->flags |= RSPAMD_TEXT_FLAG_OWN; + } + else { + lua_pushlstring(L, cpy, len); + g_free(cpy); + } + + lua_pushinteger(L, res); + + return 2; +} + +static gint +lua_util_transliterate(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t; + t = lua_check_text_or_string(L, 1); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + gsize outlen; + char *transliterated = rspamd_utf8_transliterate(t->start, t->len, &outlen); + lua_new_text(L, transliterated, outlen, TRUE); + + return 1; +} + +static gint +lua_util_strequal_caseless(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t1, *t2; + gint ret = -1; + + t1 = lua_check_text_or_string(L, 1); + t2 = lua_check_text_or_string(L, 2); + + if (t1 && t2) { + + if (t1->len == t2->len) { + ret = rspamd_lc_cmp(t1->start, t2->start, t1->len); + } + else { + ret = t1->len - t2->len; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, (ret == 0) ? true : false); + return 1; +} + +static gint +lua_util_strequal_caseless_utf8(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t1, *t2; + gint ret = -1; + + t1 = lua_check_text_or_string(L, 1); + t2 = lua_check_text_or_string(L, 2); + + if (t1 && t2) { + ret = rspamd_utf8_strcmp_sizes(t1->start, t1->len, t2->start, t2->len); + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, (ret == 0) ? true : false); + + return 1; +} + +static gint +lua_util_get_ticks(lua_State *L) +{ + LUA_TRACE_POINT; + gdouble ticks; + gboolean rdtsc = FALSE; + + if (lua_isboolean(L, 1)) { + rdtsc = lua_toboolean(L, 1); + } + + ticks = rspamd_get_ticks(rdtsc); + lua_pushnumber(L, ticks); + + return 1; +} + +static gint +lua_util_get_time(lua_State *L) +{ + LUA_TRACE_POINT; + + lua_pushnumber(L, ev_time()); + + return 1; +} + +static gint +lua_util_time_to_string(lua_State *L) +{ + LUA_TRACE_POINT; + gdouble seconds; + char timebuf[128]; + + if (lua_isnumber(L, 1)) { + seconds = lua_tonumber(L, 1); + } + else { + seconds = ev_time(); + } + + rspamd_http_date_format(timebuf, sizeof(timebuf), seconds); + lua_pushstring(L, timebuf); + + return 1; +} + +static gint +lua_util_stat(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *fpath; + struct stat st; + + fpath = luaL_checkstring(L, 1); + + if (fpath) { + if (stat(fpath, &st) == -1) { + lua_pushstring(L, strerror(errno)); + lua_pushnil(L); + } + else { + lua_pushnil(L); + lua_createtable(L, 0, 3); + + lua_pushstring(L, "size"); + lua_pushinteger(L, st.st_size); + lua_settable(L, -3); + + lua_pushstring(L, "mtime"); + lua_pushinteger(L, st.st_mtime); + lua_settable(L, -3); + + lua_pushstring(L, "type"); + if (S_ISREG(st.st_mode)) { + lua_pushstring(L, "regular"); + } + else if (S_ISDIR(st.st_mode)) { + lua_pushstring(L, "directory"); + } + else { + lua_pushstring(L, "special"); + } + lua_settable(L, -3); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 2; +} + +static gint +lua_util_unlink(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *fpath; + gint ret; + + fpath = luaL_checkstring(L, 1); + + if (fpath) { + ret = unlink(fpath); + + if (ret == -1) { + lua_pushboolean(L, false); + lua_pushstring(L, strerror(errno)); + + return 2; + } + + lua_pushboolean(L, true); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_lock_file(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *fpath; + gint fd = -1; + gboolean own = FALSE; + +#if !HAVE_FLOCK + struct flock fl = { + .l_type = F_WRLCK, + .l_whence = SEEK_SET, + .l_start = 0, + .l_len = 0}; +#endif + + fpath = luaL_checkstring(L, 1); + + if (fpath) { + if (lua_isnumber(L, 2)) { + fd = lua_tointeger(L, 2); + } + else { + fd = open(fpath, O_RDONLY); + own = TRUE; + } + + if (fd == -1) { + lua_pushnil(L); + lua_pushstring(L, strerror(errno)); + + return 2; + } + +#if HAVE_FLOCK + if (flock(fd, LOCK_EX) == -1) { +#else + if (fcntl(fd, F_SETLKW, &fl) == -1) { +#endif + lua_pushnil(L); + lua_pushstring(L, strerror(errno)); + + if (own) { + close(fd); + } + + return 2; + } + + lua_pushinteger(L, fd); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_unlock_file(lua_State *L) +{ + LUA_TRACE_POINT; + gint fd = -1, ret, serrno; + gboolean do_close = TRUE; + +#if !HAVE_FLOCK + struct flock fl = { + .l_type = F_UNLCK, + .l_whence = SEEK_SET, + .l_start = 0, + .l_len = 0}; +#endif + + if (lua_isnumber(L, 1)) { + fd = lua_tointeger(L, 1); + + if (lua_isboolean(L, 2)) { + do_close = lua_toboolean(L, 2); + } + +#if HAVE_FLOCK + ret = flock(fd, LOCK_UN); +#else + ret = fcntl(fd, F_SETLKW, &fl); +#endif + + if (do_close) { + serrno = errno; + close(fd); + errno = serrno; + } + + if (ret == -1) { + lua_pushboolean(L, false); + lua_pushstring(L, strerror(errno)); + + return 2; + } + + lua_pushboolean(L, true); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_create_file(lua_State *L) +{ + LUA_TRACE_POINT; + gint fd, mode = 00644; + const gchar *fpath; + + fpath = luaL_checkstring(L, 1); + + if (fpath) { + if (lua_isnumber(L, 2)) { + mode = lua_tointeger(L, 2); + } + + fd = rspamd_file_xopen(fpath, O_RDWR | O_CREAT | O_TRUNC, mode, 0); + + if (fd == -1) { + lua_pushnil(L); + lua_pushstring(L, strerror(errno)); + + return 2; + } + + lua_pushinteger(L, fd); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_close_file(lua_State *L) +{ + LUA_TRACE_POINT; + gint fd = -1; + + if (lua_isnumber(L, 1)) { + fd = lua_tointeger(L, 1); + + if (close(fd) == -1) { + lua_pushboolean(L, false); + lua_pushstring(L, strerror(errno)); + + return 2; + } + + lua_pushboolean(L, true); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_random_hex(lua_State *L) +{ + LUA_TRACE_POINT; + gchar *buf; + gint buflen; + + buflen = lua_tointeger(L, 1); + + if (buflen <= 0) { + return luaL_error(L, "invalid arguments"); + } + + buf = g_malloc(buflen); + rspamd_random_hex(buf, buflen); + lua_pushlstring(L, buf, buflen); + g_free(buf); + + return 1; +} + +static gint +lua_util_zstd_compress(lua_State *L) +{ + return lua_compress_zstd_compress(L); +} + +static gint +lua_util_zstd_decompress(lua_State *L) +{ + return lua_compress_zstd_decompress(L); +} + +static gint +lua_util_gzip_compress(lua_State *L) +{ + return lua_compress_zlib_compress(L); +} + +static gint +lua_util_gzip_decompress(lua_State *L) +{ + return lua_compress_zlib_decompress(L, true); +} + +static gint +lua_util_inflate(lua_State *L) +{ + return lua_compress_zlib_decompress(L, false); +} + +static gint +lua_util_normalize_prob(lua_State *L) +{ + LUA_TRACE_POINT; + gdouble x, bias = 0.5; + + x = lua_tonumber(L, 1); + + if (lua_type(L, 2) == LUA_TNUMBER) { + bias = lua_tonumber(L, 2); + } + + lua_pushnumber(L, rspamd_normalize_probability(x, bias)); + + return 1; +} + +static gint +lua_util_caseless_hash(lua_State *L) +{ + LUA_TRACE_POINT; + guint64 seed = 0xdeadbabe, h; + struct rspamd_lua_text *t = NULL; + gint64 *r; + + t = lua_check_text_or_string(L, 1); + + if (t == NULL || t->start == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TNUMBER) { + seed = lua_tointeger(L, 2); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + seed = lua_check_int64(L, 2); + } + + h = rspamd_icase_hash(t->start, t->len, seed); + r = lua_newuserdata(L, sizeof(*r)); + *r = h; + rspamd_lua_setclass(L, "rspamd{int64}", -1); + + return 1; +} + +static gint +lua_util_caseless_hash_fast(lua_State *L) +{ + LUA_TRACE_POINT; + guint64 seed = 0xdeadbabe, h; + struct rspamd_lua_text *t = NULL; + union { + guint64 i; + double d; + } u; + + t = lua_check_text_or_string(L, 1); + + if (t == NULL || t->start == NULL) { + return luaL_error(L, "invalid arguments"); + } + + if (lua_type(L, 2) == LUA_TNUMBER) { + seed = lua_tointeger(L, 2); + } + else if (lua_type(L, 2) == LUA_TUSERDATA) { + seed = lua_check_int64(L, 2); + } + + /* + * Here, we loose entropy from 64 bits to 52 bits roughly, however, + * it is still fine for practical applications + */ + + h = rspamd_icase_hash(t->start, t->len, seed); + u.i = G_GUINT64_CONSTANT(0x3FF) << 52 | h >> 12; + lua_pushnumber(L, u.d - 1.0); + + return 1; +} + +static gint +lua_util_is_utf_spoofed(lua_State *L) +{ + LUA_TRACE_POINT; + gsize l1, l2; + gint ret, nres = 2; + const gchar *s1 = lua_tolstring(L, 1, &l1), + *s2 = lua_tolstring(L, 2, &l2); + static USpoofChecker *spc, *spc_sgl; + UErrorCode uc_err = U_ZERO_ERROR; + + if (s1 && s2) { + if (spc == NULL) { + spc = uspoof_open(&uc_err); + + if (uc_err != U_ZERO_ERROR) { + msg_err("cannot init spoof checker: %s", u_errorName(uc_err)); + lua_pushboolean(L, false); + + return 1; + } + } + + ret = uspoof_areConfusableUTF8(spc, s1, l1, s2, l2, &uc_err); + } + else if (s1) { + /* We have just s1, not s2 */ + if (spc_sgl == NULL) { + spc_sgl = uspoof_open(&uc_err); + + if (uc_err != U_ZERO_ERROR) { + msg_err("cannot init spoof checker: %s", u_errorName(uc_err)); + lua_pushboolean(L, false); + + return 1; + } + + uspoof_setChecks(spc_sgl, + USPOOF_INVISIBLE | USPOOF_MIXED_SCRIPT_CONFUSABLE | USPOOF_ANY_CASE, + &uc_err); + if (uc_err != U_ZERO_ERROR) { + msg_err("Cannot set proper checks for uspoof: %s", u_errorName(uc_err)); + lua_pushboolean(L, false); + uspoof_close(spc); + return 1; + } + } + + ret = uspoof_checkUTF8(spc_sgl, s1, l1, NULL, &uc_err); + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, !!(ret != 0)); + + switch (ret) { + case 0: + nres = 1; + break; + case USPOOF_SINGLE_SCRIPT_CONFUSABLE: + lua_pushstring(L, "single"); + break; + case USPOOF_MIXED_SCRIPT_CONFUSABLE: + lua_pushstring(L, "multiple"); + break; + case USPOOF_WHOLE_SCRIPT_CONFUSABLE: + lua_pushstring(L, "whole"); + break; + default: + lua_pushstring(L, "unknown"); + break; + } + + return nres; +} + +static gint +lua_util_is_utf_mixed_script(lua_State *L) +{ + LUA_TRACE_POINT; + gsize len_of_string; + const guchar *string_to_check = lua_tolstring(L, 1, &len_of_string); + UScriptCode last_script_code = USCRIPT_INVALID_CODE; + UErrorCode uc_err = U_ZERO_ERROR; + + if (string_to_check) { + uint index = 0; + UChar32 char_to_check = 0; + + while (index < len_of_string) { + U8_NEXT(string_to_check, index, len_of_string, char_to_check); + + if (char_to_check < 0) { + return luaL_error(L, "passed string is not valid utf"); + } + + UScriptCode current_script_code = uscript_getScript(char_to_check, &uc_err); + + if (uc_err != U_ZERO_ERROR) { + msg_err("cannot get unicode script for character, error: %s", + u_errorName(uc_err)); + lua_pushboolean(L, false); + + return 1; + } + + if (current_script_code != USCRIPT_COMMON && + current_script_code != USCRIPT_INHERITED) { + + if (last_script_code == USCRIPT_INVALID_CODE) { + last_script_code = current_script_code; + } + else { + if (last_script_code != current_script_code) { + lua_pushboolean(L, true); + + return 1; + } + } + } + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, false); + + return 1; +} + +static gint +lua_util_get_string_stats(lua_State *L) +{ + LUA_TRACE_POINT; + gint num_of_digits = 0, num_of_letters = 0; + struct rspamd_lua_text *t; + + t = lua_check_text_or_string(L, 1); + + if (t) { + const gchar *p = t->start, *end = t->start + t->len; + while (p < end) { + if (g_ascii_isdigit(*p)) { + num_of_digits++; + } + else if (g_ascii_isalpha(*p)) { + num_of_letters++; + } + p++; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_createtable(L, 0, 2); + lua_pushstring(L, "digits"); + lua_pushinteger(L, num_of_digits); + lua_settable(L, -3); + lua_pushstring(L, "letters"); + lua_pushinteger(L, num_of_letters); + lua_settable(L, -3); + + return 1; +} + + +static gint +lua_util_is_utf_outside_range(lua_State *L) +{ + LUA_TRACE_POINT; + gint ret; + struct rspamd_lua_text *t = lua_check_text_or_string(L, 1); + guint32 range_start = lua_tointeger(L, 2); + guint32 range_end = lua_tointeger(L, 3); + + static rspamd_lru_hash_t *validators; + + if (validators == NULL) { + validators = rspamd_lru_hash_new_full(16, g_free, (GDestroyNotify) uspoof_close, g_int64_hash, g_int64_equal); + } + + if (t) { + guint64 hash_key = (guint64) range_end << 32 || range_start; + + USpoofChecker *validator = rspamd_lru_hash_lookup(validators, &hash_key, 0); + + UErrorCode uc_err = U_ZERO_ERROR; + + if (validator == NULL) { + USet *allowed_chars; + guint64 *creation_hash_key = g_malloc(sizeof(guint64)); + *creation_hash_key = hash_key; + + validator = uspoof_open(&uc_err); + if (uc_err != U_ZERO_ERROR) { + msg_err("cannot init spoof checker: %s", u_errorName(uc_err)); + lua_pushboolean(L, false); + uspoof_close(validator); + g_free(creation_hash_key); + return 1; + } + + allowed_chars = uset_openEmpty(); + uset_addRange(allowed_chars, range_start, range_end); + uspoof_setAllowedChars(validator, allowed_chars, &uc_err); + + uspoof_setChecks(validator, + USPOOF_CHAR_LIMIT | USPOOF_ANY_CASE, &uc_err); + + uset_close(allowed_chars); + + if (uc_err != U_ZERO_ERROR) { + msg_err("Cannot configure uspoof: %s", u_errorName(uc_err)); + lua_pushboolean(L, false); + uspoof_close(validator); + g_free(creation_hash_key); + return 1; + } + + rspamd_lru_hash_insert(validators, creation_hash_key, validator, + 0, 0); + } + + gint32 pos = 0; + ret = uspoof_checkUTF8(validator, t->start, t->len, &pos, + &uc_err); + } + else { + return luaL_error(L, "invalid arguments"); + } + + lua_pushboolean(L, !!(ret != 0)); + + return 1; +} + + +static gint +lua_util_get_hostname(lua_State *L) +{ + LUA_TRACE_POINT; + gchar *hostbuf; + gsize hostlen; + + hostlen = sysconf(_SC_HOST_NAME_MAX); + + if (hostlen <= 0) { + hostlen = 256; + } + else { + hostlen++; + } + + hostbuf = g_alloca(hostlen); + memset(hostbuf, 0, hostlen); + gethostname(hostbuf, hostlen - 1); + + lua_pushstring(L, hostbuf); + + return 1; +} + +static gint +lua_util_parse_content_type(lua_State *L) +{ + return lua_parsers_parse_content_type(L); +} + + +static gint +lua_util_mime_header_encode(lua_State *L) +{ + LUA_TRACE_POINT; + gsize len; + const gchar *hdr = luaL_checklstring(L, 1, &len); + gchar *encoded; + + if (!hdr) { + return luaL_error(L, "invalid arguments"); + } + + encoded = rspamd_mime_header_encode(hdr, len); + lua_pushstring(L, encoded); + g_free(encoded); + + return 1; +} + +static gint +lua_util_is_valid_utf8(lua_State *L) +{ + LUA_TRACE_POINT; + struct rspamd_lua_text *t = lua_check_text_or_string(L, 1); + + if (t) { + goffset error_offset = rspamd_fast_utf8_validate(t->start, t->len); + + if (error_offset == 0) { + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + lua_pushinteger(L, error_offset); + + return 2; + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_util_has_obscured_unicode(lua_State *L) +{ + LUA_TRACE_POINT; + gint32 i = 0, prev_i; + UChar32 uc; + + struct rspamd_lua_text *t = lua_check_text_or_string(L, 1); + + while (i < t->len) { + prev_i = i; + U8_NEXT(t->start, i, t->len, uc); + + if (uc > 0) { + if (IS_OBSCURED_CHAR(uc)) { + lua_pushboolean(L, true); + lua_pushinteger(L, uc); /* Character */ + lua_pushinteger(L, prev_i); /* Offset */ + + return 3; + } + } + } + + lua_pushboolean(L, false); + + return 1; +} + +static gint +lua_util_readline(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *prompt = ""; + gchar *input; + + if (lua_type(L, 1) == LUA_TSTRING) { + prompt = lua_tostring(L, 1); + } +#ifdef WITH_LUA_REPL + static Replxx *rx_instance = NULL; + + if (rx_instance == NULL) { + rx_instance = replxx_init(); + /* See https://github.com/AmokHuginnsson/replxx/issues/137 */ + replxx_history_add(rx_instance, ""); + } + + input = (gchar *) replxx_input(rx_instance, prompt); + + if (input) { + lua_pushstring(L, input); + } + else { + lua_pushnil(L); + } +#else + size_t linecap = 0; + ssize_t linelen; + + fprintf(stdout, "%s ", prompt); + + linelen = getline(&input, &linecap, stdin); + + if (linelen > 0) { + if (input[linelen - 1] == '\n') { + linelen--; + } + + lua_pushlstring(L, input, linelen); + free(input); + } + else { + lua_pushnil(L); + } +#endif + + return 1; +} + +static gint +lua_util_readpassphrase(lua_State *L) +{ + LUA_TRACE_POINT; + gchar test_password[8192]; + gsize r; + + r = rspamd_read_passphrase(test_password, sizeof(test_password), 0, NULL); + + if (r > 0) { + lua_pushlstring(L, test_password, r); + } + else { + lua_pushnil(L); + } + + /* In fact, we still pass it to Lua which is not very safe */ + rspamd_explicit_memzero(test_password, sizeof(test_password)); + + return 1; +} + +static gint +lua_util_file_exists(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *fname = luaL_checkstring(L, 1); + gint serrno; + + if (fname) { + if (access(fname, R_OK) == -1) { + serrno = errno; + lua_pushboolean(L, false); + lua_pushstring(L, strerror(serrno)); + } + else { + lua_pushboolean(L, true); + lua_pushnil(L); + } + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 2; +} + +static gint +lua_util_mkdir(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *dname = luaL_checkstring(L, 1); + gboolean recursive = FALSE; + gint r = -1; + + if (dname) { + if (lua_isboolean(L, 2)) { + recursive = lua_toboolean(L, 2); + } + + if (recursive) { + char path[PATH_MAX]; + gsize len, i; + + len = rspamd_strlcpy(path, dname, sizeof(path)); + + /* Strip last / */ + if (path[len - 1] == '/') { + path[len - 1] = '\0'; + len--; + } + + for (i = 1; i < len; i++) { + if (path[i] == '/') { + path[i] = '\0'; + + errno = 0; + r = mkdir(path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH); + + if (r == -1 && errno != EEXIST) { + break; + } + + path[i] = '/'; + } + } + + /* Final path component */ + r = mkdir(path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH); + } + else { + r = mkdir(dname, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH); + } + + if (r == -1 && errno != EEXIST) { + lua_pushboolean(L, false); + lua_pushstring(L, strerror(errno)); + + return 2; + } + + lua_pushboolean(L, true); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +static gint +lua_util_umask(lua_State *L) +{ + LUA_TRACE_POINT; + mode_t mask = 0, old; + + if (lua_type(L, 1) == LUA_TSTRING) { + const gchar *str = lua_tostring(L, 1); + + if (str[0] == '0') { + /* e.g. '022' */ + mask = strtol(str, NULL, 8); + } + else { + /* XXX: implement modestring parsing at some point */ + return luaL_error(L, "invalid arguments"); + } + } + else if (lua_type(L, 1) == LUA_TNUMBER) { + mask = lua_tointeger(L, 1); + } + else { + return luaL_error(L, "invalid arguments"); + } + + old = umask(mask); + + lua_pushinteger(L, old); + + return 1; +} + +static gint +lua_util_isatty(lua_State *L) +{ + LUA_TRACE_POINT; + if (isatty(STDOUT_FILENO)) { + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + + return 1; +} + +/* Backport from Lua 5.3 */ + +/****************************************************************************** +* Copyright (C) 1994-2016 Lua.org, PUC-Rio. +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files (the +* "Software"), to deal in the Software without restriction, including +* without limitation the rights to use, copy, modify, merge, publish, +* distribute, sublicense, and/or sell copies of the Software, and to +* permit persons to whom the Software is furnished to do so, subject to +* the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +******************************************************************************/ + +/* +** {====================================================== +** PACK/UNPACK +** ======================================================= +*/ + + +/* value used for padding */ +#if !defined(LUA_PACKPADBYTE) +#define LUA_PACKPADBYTE 0x00 +#endif + +/* maximum size for the binary representation of an integer */ +#define MAXINTSIZE 16 + +/* number of bits in a character */ +#define NB CHAR_BIT + +/* mask for one character (NB 1's) */ +#define MC ((1 << NB) - 1) + +/* size of a lua_Integer */ +#define SZINT ((int) sizeof(lua_Integer)) + +#define MAX_SIZET ((size_t) (~(size_t) 0)) + +#define MAXSIZE \ + (sizeof(size_t) < sizeof(int) ? MAX_SIZET : (size_t) (INT_MAX)) + + +/* dummy union to get native endianness */ +static const union { + int dummy; + char little; /* true if machine is little endian */ +} nativeendian = {1}; + + +/* dummy structure to get native alignment requirements */ +struct cD { + char c; + union { + double d; + void *p; + lua_Integer i; + lua_Number n; + } u; +}; + +#define MAXALIGN (offsetof(struct cD, u)) + +/* +** Union for serializing floats +*/ +typedef union Ftypes { + float f; + double d; + lua_Number n; + char buff[5 * sizeof(lua_Number)]; /* enough for any float type */ +} Ftypes; + + +/* +** information to pack/unpack stuff +*/ +typedef struct Header { + lua_State *L; + int islittle; + int maxalign; +} Header; + +/* +** options for pack/unpack +*/ +typedef enum KOption { + Kint, /* signed integers */ + Kuint, /* unsigned integers */ + Kfloat, /* floating-point numbers */ + Kchar, /* fixed-length strings */ + Kstring, /* strings with prefixed length */ + Kzstr, /* zero-terminated strings */ + Kpadding, /* padding */ + Kpaddalign, /* padding for alignment */ + Knop /* no-op (configuration or spaces) */ +} KOption; + +#if LUA_VERSION_NUM <= 502 +#define lua_Unsigned size_t +#endif + +#if LUA_VERSION_NUM < 502 + +#define lua_Unsigned size_t + +typedef struct luaL_Buffer_53 { + luaL_Buffer b; /* make incorrect code crash! */ + char *ptr; + size_t nelems; + size_t capacity; + lua_State *L2; +} luaL_Buffer_53; + +#define luaL_Buffer luaL_Buffer_53 +#define COMPAT53_PREFIX lua +#undef COMPAT53_API + +#if defined(__GNUC__) || defined(__clang__) +#define COMPAT53_API __attribute__((__unused__)) static +#else +#define COMPAT53_API static +#endif + +#define COMPAT53_CONCAT_HELPER(a, b) a##b +#define COMPAT53_CONCAT(a, b) COMPAT53_CONCAT_HELPER(a, b) + +#define luaL_buffinit COMPAT53_CONCAT(COMPAT53_PREFIX, _buffinit_53) +COMPAT53_API void luaL_buffinit(lua_State *L, luaL_Buffer_53 *B); +#define luaL_prepbuffsize COMPAT53_CONCAT(COMPAT53_PREFIX, _prepbufsize_53) +COMPAT53_API char *luaL_prepbuffsize(luaL_Buffer_53 *B, size_t s); +#define luaL_addlstring COMPAT53_CONCAT(COMPAT53_PREFIX, _addlstring_53) +COMPAT53_API void luaL_addlstring(luaL_Buffer_53 *B, const char *s, size_t l); +#define luaL_addvalue COMPAT53_CONCAT(COMPAT53_PREFIX, _addvalue_53) +COMPAT53_API void luaL_addvalue(luaL_Buffer_53 *B); +#define luaL_pushresult COMPAT53_CONCAT(COMPAT53_PREFIX, _pushresult_53) +COMPAT53_API void luaL_pushresult(luaL_Buffer_53 *B); +#undef luaL_buffinitsize +#define luaL_buffinitsize(L, B, s) \ + (luaL_buffinit(L, B), luaL_prepbuffsize(B, s)) + +#undef luaL_prepbuffer +#define luaL_prepbuffer(B) \ + luaL_prepbuffsize(B, LUAL_BUFFERSIZE) + +#undef luaL_addchar +#define luaL_addchar(B, c) \ + ((void) ((B)->nelems < (B)->capacity || luaL_prepbuffsize(B, 1)), \ + ((B)->ptr[(B)->nelems++] = (c))) + +#undef luaL_addsize +#define luaL_addsize(B, s) \ + ((B)->nelems += (s)) + +#undef luaL_addstring +#define luaL_addstring(B, s) \ + luaL_addlstring(B, s, strlen(s)) + +#undef luaL_pushresultsize +#define luaL_pushresultsize(B, s) \ + (luaL_addsize(B, s), luaL_pushresult(B)) + +COMPAT53_API void +luaL_buffinit(lua_State *L, luaL_Buffer_53 *B) +{ + /* make it crash if used via pointer to a 5.1-style luaL_Buffer */ + B->b.p = NULL; + B->b.L = NULL; + B->b.lvl = 0; + /* reuse the buffer from the 5.1-style luaL_Buffer though! */ + B->ptr = B->b.buffer; + B->nelems = 0; + B->capacity = LUAL_BUFFERSIZE; + B->L2 = L; +} + + +COMPAT53_API char * +luaL_prepbuffsize(luaL_Buffer_53 *B, size_t s) +{ + if (B->capacity - B->nelems < s) { /* needs to grow */ + char *newptr = NULL; + size_t newcap = B->capacity * 2; + if (newcap - B->nelems < s) + newcap = B->nelems + s; + if (newcap < B->capacity) /* overflow */ + luaL_error(B->L2, "buffer too large"); + newptr = (char *) lua_newuserdata(B->L2, newcap); + memcpy(newptr, B->ptr, B->nelems); + if (B->ptr != B->b.buffer) { + lua_replace(B->L2, -2); /* remove old buffer */ + } + B->ptr = newptr; + B->capacity = newcap; + } + return B->ptr + B->nelems; +} + + +COMPAT53_API void +luaL_addlstring(luaL_Buffer_53 *B, const char *s, size_t l) +{ + memcpy(luaL_prepbuffsize(B, l), s, l); + luaL_addsize(B, l); +} + + +COMPAT53_API void +luaL_addvalue(luaL_Buffer_53 *B) +{ + size_t len = 0; + const char *s = lua_tolstring(B->L2, -1, &len); + if (!s) + luaL_error(B->L2, "cannot convert value to string"); + if (B->ptr != B->b.buffer) { + lua_insert(B->L2, -2); /* userdata buffer must be at stack top */ + } + luaL_addlstring(B, s, len); + lua_remove(B->L2, B->ptr != B->b.buffer ? -2 : -1); +} + + +COMPAT53_API void +luaL_pushresult(luaL_Buffer_53 *B) +{ + lua_pushlstring(B->L2, B->ptr, B->nelems); + if (B->ptr != B->b.buffer) { + lua_replace(B->L2, -2); /* remove userdata buffer */ + } +} + +#endif + +/* +** Read an integer numeral from string 'fmt' or return 'df' if +** there is no numeral +*/ +static int +digit(int c) +{ + return '0' <= c && c <= '9'; +} + +static int +getnum(const char **fmt, int df) +{ + if (!digit(**fmt)) /* no number? */ + return df; /* return default value */ + else { + int a = 0; + do { + a = a * 10 + (*((*fmt)++) - '0'); + } while (digit(**fmt) && a <= ((int) MAXSIZE - 9) / 10); + return a; + } +} + + +/* +** Read an integer numeral and raises an error if it is larger +** than the maximum size for integers. +*/ +static int +getnumlimit(Header *h, const char **fmt, int df) +{ + int sz = getnum(fmt, df); + if (sz > MAXINTSIZE || sz <= 0) + luaL_error(h->L, "integral size (%d) out of limits [1,%d]", + sz, MAXINTSIZE); + return sz; +} + + +/* +** Initialize Header +*/ +static void +initheader(lua_State *L, Header *h) +{ + h->L = L; + h->islittle = nativeendian.little; + h->maxalign = 1; +} + + +/* +** Read and classify next option. 'size' is filled with option's size. +*/ +static KOption +getoption(Header *h, const char **fmt, int *size) +{ + int opt = *((*fmt)++); + *size = 0; /* default */ + switch (opt) { + case 'b': + *size = sizeof(char); + return Kint; + case 'B': + *size = sizeof(char); + return Kuint; + case 'h': + *size = sizeof(short); + return Kint; + case 'H': + *size = sizeof(short); + return Kuint; + case 'l': + *size = sizeof(long); + return Kint; + case 'L': + *size = sizeof(long); + return Kuint; + case 'j': + *size = sizeof(lua_Integer); + return Kint; + case 'J': + *size = sizeof(lua_Integer); + return Kuint; + case 'T': + *size = sizeof(size_t); + return Kuint; + case 'f': + *size = sizeof(float); + return Kfloat; + case 'd': + *size = sizeof(double); + return Kfloat; + case 'n': + *size = sizeof(lua_Number); + return Kfloat; + case 'i': + *size = getnumlimit(h, fmt, sizeof(int)); + return Kint; + case 'I': + *size = getnumlimit(h, fmt, sizeof(int)); + return Kuint; + case 's': + *size = getnumlimit(h, fmt, sizeof(size_t)); + return Kstring; + case 'c': + *size = getnum(fmt, -1); + if (*size == -1) + luaL_error(h->L, "missing size for format option 'c'"); + return Kchar; + case 'z': + return Kzstr; + case 'x': + *size = 1; + return Kpadding; + case 'X': + return Kpaddalign; + case ' ': + break; + case '<': + h->islittle = 1; + break; + case '>': + h->islittle = 0; + break; + case '=': + h->islittle = nativeendian.little; + break; + case '!': + h->maxalign = getnumlimit(h, fmt, MAXALIGN); + break; + default: + luaL_error(h->L, "invalid format option '%c'", opt); + } + return Knop; +} + + +/* +** Read, classify, and fill other details about the next option. +** 'psize' is filled with option's size, 'notoalign' with its +** alignment requirements. +** Local variable 'size' gets the size to be aligned. (Kpadal option +** always gets its full alignment, other options are limited by +** the maximum alignment ('maxalign'). Kchar option needs no alignment +** despite its size. +*/ +static KOption +getdetails(Header *h, size_t totalsize, + const char **fmt, int *psize, int *ntoalign) +{ + KOption opt = getoption(h, fmt, psize); + int align = *psize; /* usually, alignment follows size */ + if (opt == Kpaddalign) { /* 'X' gets alignment from following option */ + if (**fmt == '\0' || getoption(h, fmt, &align) == Kchar || align == 0) + luaL_argerror(h->L, 1, "invalid next option for option 'X'"); + } + if (align <= 1 || opt == Kchar) /* need no alignment? */ + *ntoalign = 0; + else { + if (align > h->maxalign) /* enforce maximum alignment */ + align = h->maxalign; + if ((align & (align - 1)) != 0) /* is 'align' not a power of 2? */ + luaL_argerror(h->L, 1, "format asks for alignment not power of 2"); + *ntoalign = (align - (int) (totalsize & (align - 1))) & (align - 1); + } + return opt; +} + + +/* +** Pack integer 'n' with 'size' bytes and 'islittle' endianness. +** The final 'if' handles the case when 'size' is larger than +** the size of a Lua integer, correcting the extra sign-extension +** bytes if necessary (by default they would be zeros). +*/ +static void +packint(luaL_Buffer *b, lua_Unsigned n, + int islittle, int size, int neg) +{ + char *buff = luaL_prepbuffsize(b, size); + int i; + buff[islittle ? 0 : size - 1] = (char) (n & MC); /* first byte */ + for (i = 1; i < size; i++) { + n >>= NB; + buff[islittle ? i : size - 1 - i] = (char) (n & MC); + } + if (neg && size > SZINT) { /* negative number need sign extension? */ + for (i = SZINT; i < size; i++) /* correct extra bytes */ + buff[islittle ? i : size - 1 - i] = (char) MC; + } + luaL_addsize(b, size); /* add result to buffer */ +} + + +/* +** Copy 'size' bytes from 'src' to 'dest', correcting endianness if +** given 'islittle' is different from native endianness. +*/ +static void +copywithendian(volatile char *dest, volatile const char *src, + int size, int islittle) +{ + if (islittle == nativeendian.little) { + while (size-- != 0) + *(dest++) = *(src++); + } + else { + dest += size - 1; + while (size-- != 0) + *(dest--) = *(src++); + } +} + + +static int +lua_util_pack(lua_State *L) +{ + luaL_Buffer b; + Header h; + const char *fmt = luaL_checkstring(L, 1); /* format string */ + int arg = 1; /* current argument to pack */ + size_t totalsize = 0; /* accumulate total size of result */ + initheader(L, &h); + lua_pushnil(L); /* mark to separate arguments from string buffer */ + luaL_buffinit(L, &b); + + while (*fmt != '\0') { + int size, ntoalign; + KOption opt = getdetails(&h, totalsize, &fmt, &size, &ntoalign); + totalsize += ntoalign + size; + while (ntoalign-- > 0) + luaL_addchar(&b, LUA_PACKPADBYTE); /* fill alignment */ + arg++; + switch (opt) { + case Kint: { /* signed integers */ + lua_Integer n = luaL_checkinteger(L, arg); + if (size < SZINT) { /* need overflow check? */ + lua_Integer lim = (lua_Integer) 1 << ((size * NB) - 1); + luaL_argcheck(L, -lim <= n && n < lim, arg, "integer overflow"); + } + packint(&b, (lua_Unsigned) n, h.islittle, size, (n < 0)); + break; + } + case Kuint: { /* unsigned integers */ + lua_Integer n = luaL_checkinteger(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, + (lua_Unsigned) n < ((lua_Unsigned) 1 << (size * NB)), + arg, + "unsigned overflow"); + packint(&b, (lua_Unsigned) n, h.islittle, size, 0); + break; + } + case Kfloat: { /* floating-point options */ + volatile Ftypes u; + char *buff = luaL_prepbuffsize(&b, size); + lua_Number n = luaL_checknumber(L, arg); /* get argument */ + if (size == sizeof(u.f)) + u.f = (float) n; /* copy it into 'u' */ + else if (size == sizeof(u.d)) + u.d = (double) n; + else + u.n = n; + /* move 'u' to final result, correcting endianness if needed */ + copywithendian(buff, u.buff, size, h.islittle); + luaL_addsize(&b, size); + break; + } + case Kchar: { /* fixed-size string */ + size_t len; + const char *s = luaL_checklstring(L, arg, &len); + if ((size_t) size <= + len) /* string larger than (or equal to) needed? */ + luaL_addlstring(&b, + s, + size); /* truncate string to asked size */ + else { /* string smaller than needed */ + luaL_addlstring(&b, s, len); /* add it all */ + while (len++ < (size_t) size) /* pad extra space */ + luaL_addchar(&b, LUA_PACKPADBYTE); + } + break; + } + case Kstring: { /* strings with length count */ + size_t len; + const char *s = luaL_checklstring(L, arg, &len); + luaL_argcheck(L, size >= (int) sizeof(size_t) || len < ((size_t) 1 << (size * NB)), + arg, "string length does not fit in given size"); + packint(&b, + (lua_Unsigned) len, + h.islittle, + size, + 0); /* pack length */ + luaL_addlstring(&b, s, len); + totalsize += len; + break; + } + case Kzstr: { /* zero-terminated string */ + size_t len; + const char *s = luaL_checklstring(L, arg, &len); + luaL_argcheck(L, strlen(s) == len, arg, "string contains zeros"); + luaL_addlstring(&b, s, len); + luaL_addchar(&b, '\0'); /* add zero at the end */ + totalsize += len + 1; + break; + } + case Kpadding: + luaL_addchar(&b, LUA_PACKPADBYTE); /* FALLTHROUGH */ + case Kpaddalign: + case Knop: + arg--; /* undo increment */ + break; + } + } + luaL_pushresult(&b); + return 1; +} + + +static int +lua_util_packsize(lua_State *L) +{ + Header h; + const char *fmt = luaL_checkstring(L, 1); /* format string */ + size_t totalsize = 0; /* accumulate total size of result */ + initheader(L, &h); + while (*fmt != '\0') { + int size, ntoalign; + KOption opt = getdetails(&h, totalsize, &fmt, &size, &ntoalign); + size += ntoalign; /* total space used by option */ + luaL_argcheck(L, totalsize <= MAXSIZE - size, 1, + "format result too large"); + totalsize += size; + switch (opt) { + case Kstring: /* strings with length count */ + case Kzstr: /* zero-terminated string */ + luaL_argerror(L, 1, "variable-length format"); + /* call never return, but to avoid warnings: */ /* FALLTHROUGH */ + default: + break; + } + } + lua_pushinteger(L, (lua_Integer) totalsize); + return 1; +} + + +/* +** Unpack an integer with 'size' bytes and 'islittle' endianness. +** If size is smaller than the size of a Lua integer and integer +** is signed, must do sign extension (propagating the sign to the +** higher bits); if size is larger than the size of a Lua integer, +** it must check the unread bytes to see whether they do not cause an +** overflow. +*/ +static lua_Integer +unpackint(lua_State *L, const char *str, + int islittle, int size, int issigned) +{ + lua_Unsigned res = 0; + int i; + int limit = (size <= SZINT) ? size : SZINT; + for (i = limit - 1; i >= 0; i--) { + res <<= NB; + res |= (lua_Unsigned) (unsigned char) str[islittle ? i : size - 1 - i]; + } + if (size < SZINT) { /* real size smaller than lua_Integer? */ + if (issigned) { /* needs sign extension? */ + lua_Unsigned mask = (lua_Unsigned) 1 << (size * NB - 1); + res = ((res ^ mask) - mask); /* do sign extension */ + } + } + else if (size > SZINT) { /* must check unread bytes */ + int mask = (!issigned || (lua_Integer) res >= 0) ? 0 : MC; + for (i = limit; i < size; i++) { + if ((unsigned char) str[islittle ? i : size - 1 - i] != mask) + luaL_error(L, + "%d-byte integer does not fit into Lua Integer", + size); + } + } + return (lua_Integer) res; +} + +static lua_Integer +posrelat(lua_Integer pos, size_t len) +{ + if (pos >= 0) + return pos; + else if (0u - (size_t) pos > len) + return 0; + else + return (lua_Integer) len + pos + 1; +} + +static int +lua_util_unpack(lua_State *L) +{ + Header h; + const char *fmt = luaL_checkstring(L, 1); + size_t ld; + const char *data; + int n = 0; /* number of results */ + + if (lua_type(L, 2) == LUA_TUSERDATA) { + struct rspamd_lua_text *t = lua_check_text(L, 2); + + if (!t) { + return luaL_error(L, "invalid arguments"); + } + + data = t->start; + ld = t->len; + } + else { + data = luaL_checklstring(L, 2, &ld); + } + + size_t pos = (size_t) posrelat(luaL_optinteger(L, 3, 1), ld) - 1; + luaL_argcheck(L, pos <= ld, 3, "initial position out of string"); + + initheader(L, &h); + + while (*fmt != '\0') { + int size, ntoalign; + KOption opt = getdetails(&h, pos, &fmt, &size, &ntoalign); + if ((size_t) ntoalign + size > ~pos || pos + ntoalign + size > ld) + luaL_argerror(L, 2, "data string too short"); + pos += ntoalign; /* skip alignment */ + /* stack space for item + next position */ + luaL_checkstack(L, 2, "too many results"); + n++; + switch (opt) { + case Kint: + case Kuint: { + lua_Integer res = unpackint(L, data + pos, h.islittle, size, + (opt == Kint)); + lua_pushinteger(L, res); + break; + } + case Kfloat: { + volatile Ftypes u; + lua_Number num; + copywithendian(u.buff, data + pos, size, h.islittle); + if (size == sizeof(u.f)) + num = (lua_Number) u.f; + else if (size == sizeof(u.d)) + num = (lua_Number) u.d; + else + num = u.n; + lua_pushnumber(L, num); + break; + } + case Kchar: { + lua_pushlstring(L, data + pos, size); + break; + } + case Kstring: { + size_t len = (size_t) unpackint(L, + data + pos, + h.islittle, + size, + 0); + luaL_argcheck(L, + pos + len + size <= ld, + 2, + "data string too short"); + lua_pushlstring(L, data + pos + size, len); + pos += len; /* skip string */ + break; + } + case Kzstr: { + size_t len = (int) strlen(data + pos); + lua_pushlstring(L, data + pos, len); + pos += len + 1; /* skip string plus final '\0' */ + break; + } + case Kpaddalign: + case Kpadding: + case Knop: + n--; /* undo increment */ + break; + } + pos += size; + } + lua_pushinteger(L, pos + 1); /* next position */ + return n + 1; +} + +static int +lua_util_btc_polymod(lua_State *L) +{ + guint64 c = 1; + + if (lua_type(L, 1) != LUA_TTABLE) { + return luaL_error(L, "invalid arguments"); + } + + for (lua_pushnil(L); lua_next(L, 1); lua_pop(L, 1)) { + guint8 c0 = c >> 35; + guint64 d = lua_tointeger(L, -1); + + c = ((c & 0x07ffffffff) << 5) ^ d; + + if (c0 & 0x01) c ^= 0x98f2bc8e61; + if (c0 & 0x02) c ^= 0x79b76d99e2; + if (c0 & 0x04) c ^= 0xf33e5fb3c4; + if (c0 & 0x08) c ^= 0xae2eabe2a8; + if (c0 & 0x10) c ^= 0x1e4f43e470; + } + + if ((c ^ 1) == 0) { + lua_pushboolean(L, true); + } + else { + lua_pushboolean(L, false); + } + + return 1; +} + +static int +lua_util_parse_smtp_date(lua_State *L) +{ + return lua_parsers_parse_smtp_date(L); +} + + +static gint +lua_load_util(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, utillib_f); + + return 1; +} + +static gint +lua_load_int64(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, int64lib_f); + + return 1; +} + + +void luaopen_util(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{ev_base}", ev_baselib_m); + lua_pop(L, 1); + rspamd_lua_new_class(L, "rspamd{int64}", int64lib_m); + lua_pop(L, 1); + rspamd_lua_add_preload(L, "rspamd_util", lua_load_util); + rspamd_lua_add_preload(L, "rspamd_int64", lua_load_int64); +} + +static int +lua_int64_tostring(lua_State *L) +{ + gint64 n = lua_check_int64(L, 1); + gchar buf[32]; + bool is_signed = false; + + if (lua_isboolean(L, 2)) { + is_signed = lua_toboolean(L, 2); + } + + if (is_signed) { + rspamd_snprintf(buf, sizeof(buf), "%L", n); + } + else { + rspamd_snprintf(buf, sizeof(buf), "%uL", n); + } + lua_pushstring(L, buf); + + return 1; +} + +static int +lua_int64_fromstring(lua_State *L) +{ + struct rspamd_lua_text *t = lua_check_text_or_string(L, 1); + + if (t && t->len > 0) { + guint64 u64; + const char *p = t->start; + gsize len = t->len; + bool neg = false; + + /* + * We use complicated negation to allow both signed and unsinged values to + * fit into result. + * So we read int64 as unsigned and copy it to signed number. + * If we wanted u64 this allows to have the same memory representation of + * signed and unsigned. + * If we wanted signed i64 we still can use -1000500 and it will be parsed + * properly + */ + if (*p == '-') { + neg = true; + p++; + len--; + } + if (!rspamd_strtou64(p, len, &u64)) { + lua_pushnil(L); + lua_pushstring(L, "invalid number"); + return 2; + } + + gint64 *i64_p = lua_newuserdata(L, sizeof(gint64)); + rspamd_lua_setclass(L, "rspamd{int64}", -1); + memcpy(i64_p, &u64, sizeof(u64)); + + if (neg) { + *i64_p = -(*i64_p); + } + } + else { + } + + return 1; +} + +static int +lua_int64_tonumber(lua_State *L) +{ + gint64 n = lua_check_int64(L, 1); + gdouble d; + + d = n; + lua_pushinteger(L, d); + + return 1; +} + +static int +lua_int64_hex(lua_State *L) +{ + gint64 n = lua_check_int64(L, 1); + gchar buf[32]; + + rspamd_snprintf(buf, sizeof(buf), "%XL", n); + lua_pushstring(L, buf); + + return 1; +} + +static int +lua_ev_base_loop(lua_State *L) +{ + int flags = 0; + struct ev_loop *ev_base; + + ev_base = lua_check_ev_base(L, 1); + if (lua_isnumber(L, 2)) { + flags = lua_tointeger(L, 2); + } + + int ret = ev_run(ev_base, flags); + lua_pushinteger(L, ret); + + return 1; +} diff --git a/src/lua/lua_worker.c b/src/lua/lua_worker.c new file mode 100644 index 0000000..025b97b --- /dev/null +++ b/src/lua/lua_worker.c @@ -0,0 +1,883 @@ +/* + * Copyright 2023 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "lua_common.h" +#include "unix-std.h" +#include "worker_util.h" +#include "rspamd_control.h" +#include "ottery.h" + +#ifdef WITH_JEMALLOC +#include <jemalloc/jemalloc.h> +#endif + +#include <sys/wait.h> +#include <src/libserver/rspamd_control.h> + +/*** + * @module rspamd_worker + * This module provides methods to access worker related functions in various + * places, such as periodic events or on_load events. + */ + + +LUA_FUNCTION_DEF(worker, get_name); +LUA_FUNCTION_DEF(worker, get_stat); +LUA_FUNCTION_DEF(worker, get_index); +LUA_FUNCTION_DEF(worker, get_count); +LUA_FUNCTION_DEF(worker, get_pid); +LUA_FUNCTION_DEF(worker, is_scanner); +LUA_FUNCTION_DEF(worker, is_primary_controller); +LUA_FUNCTION_DEF(worker, spawn_process); +LUA_FUNCTION_DEF(worker, get_mem_stats); +LUA_FUNCTION_DEF(worker, add_control_handler); + +const luaL_reg worker_reg[] = { + LUA_INTERFACE_DEF(worker, get_name), + {"get_type", lua_worker_get_name}, + LUA_INTERFACE_DEF(worker, get_stat), + LUA_INTERFACE_DEF(worker, get_index), + LUA_INTERFACE_DEF(worker, get_count), + LUA_INTERFACE_DEF(worker, get_pid), + LUA_INTERFACE_DEF(worker, spawn_process), + LUA_INTERFACE_DEF(worker, is_scanner), + LUA_INTERFACE_DEF(worker, is_primary_controller), + LUA_INTERFACE_DEF(worker, get_mem_stats), + LUA_INTERFACE_DEF(worker, add_control_handler), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +static struct rspamd_worker * +lua_check_worker(lua_State *L, gint pos) +{ + void *ud = rspamd_lua_check_udata(L, pos, "rspamd{worker}"); + luaL_argcheck(L, ud != NULL, pos, "'worker' expected"); + return ud ? *((struct rspamd_worker **) ud) : NULL; +} + +static gint +lua_worker_get_stat(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { + rspamd_mempool_stat_t mem_st; + struct rspamd_stat *stat, stat_copy; + ucl_object_t *top, *sub; + gint i; + guint64 spam = 0, ham = 0; + + memset(&mem_st, 0, sizeof(mem_st)); + rspamd_mempool_stat(&mem_st); + memcpy(&stat_copy, w->srv->stat, sizeof(stat_copy)); + stat = &stat_copy; + top = ucl_object_typed_new(UCL_OBJECT); + ucl_object_insert_key(top, ucl_object_fromint(stat->messages_scanned), "scanned", 0, false); + ucl_object_insert_key(top, ucl_object_fromint(stat->messages_learned), "learned", 0, false); + if (stat->messages_scanned > 0) { + sub = ucl_object_typed_new(UCL_OBJECT); + for (i = METRIC_ACTION_REJECT; i <= METRIC_ACTION_NOACTION; i++) { + ucl_object_insert_key(sub, + ucl_object_fromint(stat->actions_stat[i]), + rspamd_action_to_str(i), 0, false); + if (i < METRIC_ACTION_GREYLIST) { + spam += stat->actions_stat[i]; + } + else { + ham += stat->actions_stat[i]; + } + } + ucl_object_insert_key(top, sub, "actions", 0, false); + } + else { + sub = ucl_object_typed_new(UCL_OBJECT); + for (i = METRIC_ACTION_REJECT; i <= METRIC_ACTION_NOACTION; i++) { + ucl_object_insert_key(sub, + 0, + rspamd_action_to_str(i), 0, false); + } + ucl_object_insert_key(top, sub, "actions", 0, false); + } + ucl_object_insert_key(top, ucl_object_fromint(spam), "spam_count", 0, false); + ucl_object_insert_key(top, ucl_object_fromint(ham), "ham_count", 0, false); + ucl_object_insert_key(top, + ucl_object_fromint(stat->connections_count), "connections", 0, false); + ucl_object_insert_key(top, + ucl_object_fromint(stat->control_connections_count), + "control_connections", 0, false); + ucl_object_insert_key(top, + ucl_object_fromint(mem_st.pools_allocated), "pools_allocated", 0, + false); + ucl_object_insert_key(top, + ucl_object_fromint(mem_st.pools_freed), "pools_freed", 0, false); + ucl_object_insert_key(top, + ucl_object_fromint(mem_st.bytes_allocated), "bytes_allocated", 0, + false); + ucl_object_insert_key(top, + ucl_object_fromint( + mem_st.chunks_allocated), + "chunks_allocated", 0, false); + ucl_object_insert_key(top, + ucl_object_fromint(mem_st.shared_chunks_allocated), + "shared_chunks_allocated", 0, false); + ucl_object_insert_key(top, + ucl_object_fromint(mem_st.chunks_freed), "chunks_freed", 0, false); + ucl_object_insert_key(top, + ucl_object_fromint( + mem_st.oversized_chunks), + "chunks_oversized", 0, false); + + ucl_object_push_lua(L, top, true); + ucl_object_unref(top); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_worker_get_name(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { + lua_pushstring(L, g_quark_to_string(w->type)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_worker_get_index(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { + lua_pushinteger(L, w->index); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_worker_get_count(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { + lua_pushinteger(L, w->cf->count); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_worker_get_pid(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { + lua_pushinteger(L, w->pid); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + + +static gint +lua_worker_is_scanner(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { + lua_pushboolean(L, rspamd_worker_is_scanner(w)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +static gint +lua_worker_is_primary_controller(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { + lua_pushboolean(L, rspamd_worker_is_primary_controller(w)); + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct rspamd_control_cbdata { + lua_State *L; + rspamd_mempool_t *pool; + struct rspamd_worker *w; + struct rspamd_config *cfg; + struct ev_loop *event_loop; + struct rspamd_async_session *session; + enum rspamd_control_type cmd; + gint cbref; + gint fd; +}; + +static gboolean +lua_worker_control_fin_session(void *ud) +{ + struct rspamd_control_reply rep; + struct rspamd_control_cbdata *cbd = (struct rspamd_control_cbdata *) ud; + rspamd_mempool_t *pool; + + pool = cbd->pool; + + memset(&rep, 0, sizeof(rep)); + rep.type = cbd->cmd; + + if (write(cbd->fd, &rep, sizeof(rep)) != sizeof(rep)) { + msg_err_pool("cannot write reply to the control socket: %s", + strerror(errno)); + } + + return TRUE; +} + +static void +lua_worker_control_session_dtor(void *ud) +{ + struct rspamd_control_cbdata *cbd = (struct rspamd_control_cbdata *) ud; + + rspamd_mempool_delete(cbd->pool); +} + +static gboolean +lua_worker_control_handler(struct rspamd_main *rspamd_main, + struct rspamd_worker *worker, + gint fd, + gint attached_fd, + struct rspamd_control_command *cmd, + gpointer ud) +{ + struct rspamd_async_session *session, **psession; + struct rspamd_control_cbdata *cbd = (struct rspamd_control_cbdata *) ud; + rspamd_mempool_t *pool; + lua_State *L; + gint err_idx, status; + + L = cbd->L; + pool = cbd->pool; + session = rspamd_session_create(cbd->pool, + lua_worker_control_fin_session, + NULL, + lua_worker_control_session_dtor, + cbd); + cbd->session = session; + cbd->fd = fd; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cbd->cbref); + psession = lua_newuserdata(L, sizeof(*psession)); + rspamd_lua_setclass(L, "rspamd{session}", -1); + *psession = session; + + /* Command name */ + lua_pushstring(L, rspamd_control_command_to_string(cmd->type)); + + /* Command's extras */ + lua_newtable(L); + + switch (cmd->type) { + case RSPAMD_CONTROL_CHILD_CHANGE: + lua_pushinteger(L, cmd->cmd.child_change.pid); + lua_setfield(L, -2, "pid"); + switch (cmd->cmd.child_change.what) { + case rspamd_child_offline: + lua_pushstring(L, "offline"); + lua_setfield(L, -2, "what"); + break; + case rspamd_child_online: + lua_pushstring(L, "online"); + lua_setfield(L, -2, "what"); + break; + case rspamd_child_terminated: + lua_pushstring(L, "terminated"); + lua_setfield(L, -2, "what"); + status = cmd->cmd.child_change.additional; + + if (WIFEXITED(status)) { + lua_pushinteger(L, WEXITSTATUS(status)); + lua_setfield(L, -2, "exit_code"); + } + + if (WIFSIGNALED(status)) { + lua_pushinteger(L, WTERMSIG(status)); + lua_setfield(L, -2, "signal"); +#ifdef WCOREDUMP + lua_pushboolean(L, WCOREDUMP(status)); + lua_setfield(L, -2, "core"); +#endif + } + break; + } + break; + case RSPAMD_CONTROL_MONITORED_CHANGE: + lua_pushinteger(L, cmd->cmd.monitored_change.sender); + lua_setfield(L, -2, "sender"); + lua_pushboolean(L, cmd->cmd.monitored_change.alive); + lua_setfield(L, -2, "alive"); + lua_pushlstring(L, cmd->cmd.monitored_change.tag, + sizeof(cmd->cmd.monitored_change.tag)); + lua_setfield(L, -2, "tag"); + break; + case RSPAMD_CONTROL_HYPERSCAN_LOADED: + lua_pushstring(L, cmd->cmd.hs_loaded.cache_dir); + lua_setfield(L, -2, "cache_dir"); + lua_pushboolean(L, cmd->cmd.hs_loaded.forced); + lua_setfield(L, -2, "forced"); + break; + case RSPAMD_CONTROL_STAT: + case RSPAMD_CONTROL_RELOAD: + case RSPAMD_CONTROL_RERESOLVE: + case RSPAMD_CONTROL_RECOMPILE: + case RSPAMD_CONTROL_LOG_PIPE: + case RSPAMD_CONTROL_FUZZY_STAT: + case RSPAMD_CONTROL_FUZZY_SYNC: + default: + break; + } + + if (lua_pcall(L, 3, 0, err_idx) != 0) { + msg_err_pool("cannot init lua parser script: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + + struct rspamd_control_reply rep; + + memset(&rep, 0, sizeof(rep)); + rep.type = cbd->cmd; + rep.reply.monitored_change.status = -1; + + if (write(fd, &rep, sizeof(rep)) != sizeof(rep)) { + msg_err_pool("cannot write reply to the control socket: %s", + strerror(errno)); + } + + rspamd_session_destroy(session); + } + else { + lua_settop(L, err_idx - 1); + rspamd_session_pending(session); + } + + return TRUE; +} + +static gint +lua_worker_add_control_handler(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + struct rspamd_config *cfg = lua_check_config(L, 2); + struct ev_loop *event_loop = lua_check_ev_base(L, 3); + const gchar *cmd_name = luaL_checkstring(L, 4); + enum rspamd_control_type cmd; + struct rspamd_control_cbdata *cbd; + + if (w && cfg && event_loop && cmd_name && lua_isfunction(L, 5)) { + cmd = rspamd_control_command_from_string(cmd_name); + + if (cmd == RSPAMD_CONTROL_MAX) { + return luaL_error(L, "invalid command type: %s", cmd_name); + } + + rspamd_mempool_t *pool = rspamd_mempool_new( + rspamd_mempool_suggest_size(), "lua_control", 0); + cbd = rspamd_mempool_alloc0(pool, sizeof(*cbd)); + cbd->pool = pool; + cbd->event_loop = event_loop; + cbd->w = w; + cbd->cfg = cfg; + cbd->cmd = cmd; + cbd->L = L; + /* Refcount callback */ + lua_pushvalue(L, 5); + cbd->cbref = luaL_ref(L, LUA_REGISTRYINDEX); + + rspamd_control_worker_add_cmd_handler(w, cmd, lua_worker_control_handler, + cbd); + } + else { + return luaL_error(L, "invalid arguments, need worker, cfg, " + "ev_loop, cmd_name and callback function"); + } + + return 0; +} + +#ifdef WITH_JEMALLOC +static void +lua_worker_jemalloc_stats_cb(void *ud, const char *msg) +{ + lua_State *L = (lua_State *) ud; + + lua_pushstring(L, msg); +} +#endif + +static gint +lua_worker_get_mem_stats(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + + if (w) { +#ifdef WITH_JEMALLOC + malloc_stats_print(lua_worker_jemalloc_stats_cb, (void *) L, NULL); +#else + lua_pushstring(L, "no stats, jemalloc support is required"); +#endif + } + else { + return luaL_error(L, "invalid arguments"); + } + + return 1; +} + +struct rspamd_lua_process_cbdata { + gint sp[2]; + gint func_cbref; + gint cb_cbref; + gboolean replied; + gboolean is_error; + pid_t cpid; + lua_State *L; + guint64 sz; + GString *io_buf; + GString *out_buf; + goffset out_pos; + struct rspamd_worker *wrk; + struct ev_loop *event_loop; + ev_io ev; +}; + +static void +rspamd_lua_execute_lua_subprocess(lua_State *L, + struct rspamd_lua_process_cbdata *cbdata) +{ + gint err_idx, r; + guint64 wlen = 0; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, cbdata->func_cbref); + + if (lua_pcall(L, 0, 1, err_idx) != 0) { + const gchar *s = lua_tostring(L, -1); + gsize slen = strlen(s); + + msg_err("call to subprocess failed: %s", s); + /* Indicate error */ + wlen = (1ULL << 63u) + slen; + + r = write(cbdata->sp[1], &wlen, sizeof(wlen)); + if (r == -1) { + msg_err("write failed: %s", strerror(errno)); + } + + r = write(cbdata->sp[1], s, slen); + if (r == -1) { + msg_err("write failed: %s", strerror(errno)); + } + } + else { + struct rspamd_lua_text *t = lua_check_text_or_string(L, -1); + + if (t) { + wlen = t->len; + r = write(cbdata->sp[1], &wlen, sizeof(wlen)); + + if (r == -1) { + msg_err("write failed: %s", strerror(errno)); + } + + r = write(cbdata->sp[1], t->start, wlen); + + if (r == -1) { + msg_err("write failed: %s", strerror(errno)); + } + } + else { + msg_err("subprocess: invalid return value: %s", + lua_typename(L, lua_type(L, -1))); + } + } + + lua_settop(L, err_idx - 1); +} + +static void +rspamd_lua_call_on_complete(lua_State *L, + struct rspamd_lua_process_cbdata *cbdata, + const gchar *err_msg, + const gchar *data, gsize datalen) +{ + gint err_idx; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, cbdata->cb_cbref); + + if (err_msg) { + lua_pushstring(L, err_msg); + } + else { + lua_pushnil(L); + } + + if (data) { + lua_pushlstring(L, data, datalen); + } + else { + lua_pushnil(L); + } + + if (lua_pcall(L, 2, 0, err_idx) != 0) { + msg_err("call to on_complete script failed: %s", + lua_tostring(L, -1)); + } + + lua_settop(L, err_idx - 1); +} + +static gboolean +rspamd_lua_cld_handler(struct rspamd_worker_signal_handler *sigh, void *ud) +{ + struct rspamd_lua_process_cbdata *cbdata = ud; + struct rspamd_srv_command srv_cmd; + lua_State *L; + pid_t died; + gint res = 0; + + /* Are we called by a correct children ? */ + died = waitpid(cbdata->cpid, &res, WNOHANG); + + if (died <= 0) { + /* Wait more */ + return TRUE; + } + + L = cbdata->L; + msg_info("handled SIGCHLD from %P", cbdata->cpid); + + if (!cbdata->replied) { + /* We still need to call on_complete callback */ + ev_io_stop(cbdata->event_loop, &cbdata->ev); + rspamd_lua_call_on_complete(cbdata->L, cbdata, + "Worker has died without reply", NULL, 0); + } + + /* Free structures */ + close(cbdata->sp[0]); + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->func_cbref); + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->cb_cbref); + g_string_free(cbdata->io_buf, TRUE); + + if (cbdata->out_buf) { + g_string_free(cbdata->out_buf, TRUE); + } + + /* Notify main */ + memset(&srv_cmd, 0, sizeof(srv_cmd)); + srv_cmd.type = RSPAMD_SRV_ON_FORK; + srv_cmd.cmd.on_fork.state = child_dead; + srv_cmd.cmd.on_fork.cpid = cbdata->cpid; + srv_cmd.cmd.on_fork.ppid = getpid(); + rspamd_srv_send_command(cbdata->wrk, cbdata->event_loop, &srv_cmd, -1, + NULL, NULL); + g_free(cbdata); + + /* We are done with this SIGCHLD */ + return FALSE; +} + +static void +rspamd_lua_subprocess_io(EV_P_ ev_io *w, int revents) +{ + struct rspamd_lua_process_cbdata *cbdata = + (struct rspamd_lua_process_cbdata *) w->data; + gssize r; + + if (cbdata->sz == (guint64) -1) { + guint64 sz; + + /* We read size of reply + flags first */ + r = read(cbdata->sp[0], cbdata->io_buf->str + cbdata->io_buf->len, + sizeof(guint64) - cbdata->io_buf->len); + + if (r == 0) { + ev_io_stop(cbdata->event_loop, &cbdata->ev); + rspamd_lua_call_on_complete(cbdata->L, cbdata, + "Unexpected EOF", NULL, 0); + cbdata->replied = TRUE; + kill(cbdata->cpid, SIGTERM); + + return; + } + else if (r == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + else { + ev_io_stop(cbdata->event_loop, &cbdata->ev); + rspamd_lua_call_on_complete(cbdata->L, cbdata, + strerror(errno), NULL, 0); + cbdata->replied = TRUE; + kill(cbdata->cpid, SIGTERM); + + return; + } + } + + cbdata->io_buf->len += r; + + if (cbdata->io_buf->len == sizeof(guint64)) { + memcpy((guchar *) &sz, cbdata->io_buf->str, sizeof(sz)); + + if (sz & (1ULL << 63)) { + cbdata->is_error = TRUE; + sz &= ~(1ULL << 63); + } + + cbdata->io_buf->len = 0; + cbdata->sz = sz; + g_string_set_size(cbdata->io_buf, sz + 1); + cbdata->io_buf->len = 0; + } + } + else { + /* Read data */ + r = read(cbdata->sp[0], cbdata->io_buf->str + cbdata->io_buf->len, + cbdata->sz - cbdata->io_buf->len); + + if (r == 0) { + ev_io_stop(cbdata->event_loop, &cbdata->ev); + rspamd_lua_call_on_complete(cbdata->L, cbdata, + "Unexpected EOF", NULL, 0); + cbdata->replied = TRUE; + kill(cbdata->cpid, SIGTERM); + + return; + } + else if (r == -1) { + if (errno == EAGAIN || errno == EINTR) { + return; + } + else { + ev_io_stop(cbdata->event_loop, &cbdata->ev); + rspamd_lua_call_on_complete(cbdata->L, cbdata, + strerror(errno), NULL, 0); + cbdata->replied = TRUE; + kill(cbdata->cpid, SIGTERM); + + return; + } + } + + cbdata->io_buf->len += r; + + if (cbdata->io_buf->len == cbdata->sz) { + gchar rep[4]; + + ev_io_stop(cbdata->event_loop, &cbdata->ev); + /* Finished reading data */ + if (cbdata->is_error) { + cbdata->io_buf->str[cbdata->io_buf->len] = '\0'; + rspamd_lua_call_on_complete(cbdata->L, cbdata, + cbdata->io_buf->str, NULL, 0); + } + else { + rspamd_lua_call_on_complete(cbdata->L, cbdata, + NULL, cbdata->io_buf->str, cbdata->io_buf->len); + } + + cbdata->replied = TRUE; + + /* Write reply to the child */ + rspamd_socket_blocking(cbdata->sp[0]); + memset(rep, 0, sizeof(rep)); + (void) !write(cbdata->sp[0], rep, sizeof(rep)); + } + } +} + +static gint +lua_worker_spawn_process(lua_State *L) +{ + struct rspamd_worker *w = lua_check_worker(L, 1); + struct rspamd_lua_process_cbdata *cbdata; + struct rspamd_abstract_worker_ctx *actx; + struct rspamd_srv_command srv_cmd; + const gchar *cmdline = NULL, *input = NULL, *proctitle = NULL; + gsize inputlen = 0; + pid_t pid; + GError *err = NULL; + gint func_cbref, cb_cbref; + + if (!rspamd_lua_parse_table_arguments(L, 2, &err, + RSPAMD_LUA_PARSE_ARGUMENTS_DEFAULT, + "func=F;exec=S;stdin=V;*on_complete=F;proctitle=S", &func_cbref, + &cmdline, &inputlen, &input, &cb_cbref, &proctitle)) { + msg_err("cannot get parameters list: %e", err); + + if (err) { + g_error_free(err); + } + + return 0; + } + + cbdata = g_malloc0(sizeof(*cbdata)); + cbdata->cb_cbref = cb_cbref; + cbdata->func_cbref = func_cbref; + + if (input) { + cbdata->out_buf = g_string_new_len(input, inputlen); + cbdata->out_pos = 0; + } + + if (rspamd_socketpair(cbdata->sp, SOCK_STREAM) == -1) { + msg_err("cannot spawn socketpair: %s", strerror(errno)); + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->func_cbref); + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->cb_cbref); + g_free(cbdata); + + return 0; + } + + actx = w->ctx; + cbdata->wrk = w; + cbdata->L = L; + cbdata->event_loop = actx->event_loop; + cbdata->sz = (guint64) -1; + + pid = fork(); + + if (pid == -1) { + msg_err("cannot spawn process: %s", strerror(errno)); + close(cbdata->sp[0]); + close(cbdata->sp[1]); + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->func_cbref); + luaL_unref(L, LUA_REGISTRYINDEX, cbdata->cb_cbref); + g_free(cbdata); + + return 0; + } + else if (pid == 0) { + /* Child */ + gint rc; + gchar inbuf[4]; + + rspamd_log_on_fork(w->cf->type, w->srv->cfg, w->srv->logger); + rc = ottery_init(w->srv->cfg->libs_ctx->ottery_cfg); + + if (rc != OTTERY_ERR_NONE) { + msg_err("cannot initialize PRNG: %d", rc); + abort(); + } + rspamd_random_seed_fast(); +#ifdef HAVE_EVUTIL_RNG_INIT + evutil_secure_rng_init(); +#endif + + close(cbdata->sp[0]); + /* Here we assume that we can block on writing results */ + rspamd_socket_blocking(cbdata->sp[1]); + g_hash_table_remove_all(w->signal_events); + ev_loop_destroy(cbdata->event_loop); + + if (proctitle) { + rspamd_setproctitle("lua process: %s", proctitle); + } + else { + rspamd_setproctitle("lua process: unnamed"); + } + + cbdata->event_loop = ev_loop_new(EVFLAG_SIGNALFD); + rspamd_worker_unblock_signals(); + rspamd_lua_execute_lua_subprocess(L, cbdata); + + /* Wait for parent to reply and exit */ + rc = read(cbdata->sp[1], inbuf, sizeof(inbuf)); + + if (rc >= sizeof(inbuf) && + memcmp(inbuf, "\0\0\0\0", sizeof(inbuf)) == 0) { + exit(EXIT_SUCCESS); + } + else { + msg_err("got invalid reply from parent"); + + exit(EXIT_FAILURE); + } + } + + cbdata->cpid = pid; + cbdata->io_buf = g_string_sized_new(8); + /* Notify main */ + memset(&srv_cmd, 0, sizeof(srv_cmd)); + srv_cmd.type = RSPAMD_SRV_ON_FORK; + srv_cmd.cmd.on_fork.state = child_create; + srv_cmd.cmd.on_fork.cpid = pid; + srv_cmd.cmd.on_fork.ppid = getpid(); + rspamd_srv_send_command(w, cbdata->event_loop, &srv_cmd, -1, NULL, NULL); + + close(cbdata->sp[1]); + rspamd_socket_nonblocking(cbdata->sp[0]); + /* Parent */ + rspamd_worker_set_signal_handler(SIGCHLD, w, cbdata->event_loop, + rspamd_lua_cld_handler, + cbdata); + + /* Add result pipe waiting */ + ev_io_init(&cbdata->ev, rspamd_lua_subprocess_io, cbdata->sp[0], EV_READ); + cbdata->ev.data = cbdata; + ev_io_start(cbdata->event_loop, &cbdata->ev); + + return 0; +} + +void luaopen_worker(lua_State *L) +{ + rspamd_lua_new_class(L, "rspamd{worker}", worker_reg); +} diff --git a/src/lua/lua_xmlrpc.c b/src/lua/lua_xmlrpc.c new file mode 100644 index 0000000..efb2b22 --- /dev/null +++ b/src/lua/lua_xmlrpc.c @@ -0,0 +1,796 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "lua_common.h" + + +LUA_FUNCTION_DEF(xmlrpc, parse_reply); +LUA_FUNCTION_DEF(xmlrpc, make_request); + +static const struct luaL_reg xmlrpclib_m[] = { + LUA_INTERFACE_DEF(xmlrpc, parse_reply), + LUA_INTERFACE_DEF(xmlrpc, make_request), + {"__tostring", rspamd_lua_class_tostring}, + {NULL, NULL}}; + +#define msg_debug_xmlrpc(...) rspamd_conditional_debug_fast(NULL, NULL, \ + rspamd_xmlrpc_log_id, "xmlrpc", "", \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(xmlrpc) + +enum lua_xmlrpc_state { + read_method_response = 0, + read_params = 1, + read_param = 2, + read_param_value = 3, + read_param_element = 4, + read_struct = 5, + read_struct_member_name = 6, + read_struct_member_value = 7, + read_struct_element = 8, + read_string = 9, + read_int = 10, + read_double = 11, + read_array = 12, + read_array_value = 13, + read_array_element = 14, + error_state = 99, + success_state = 100, +}; + +enum lua_xmlrpc_stack { + st_array = 1, + st_struct = 2, +}; + +struct lua_xmlrpc_ud { + enum lua_xmlrpc_state parser_state; + GQueue *st; + gint param_count; + gboolean got_text; + lua_State *L; +}; + +static void xmlrpc_start_element(GMarkupParseContext *context, + const gchar *name, + const gchar **attribute_names, + const gchar **attribute_values, + gpointer user_data, + GError **error); +static void xmlrpc_end_element(GMarkupParseContext *context, + const gchar *element_name, + gpointer user_data, + GError **error); +static void xmlrpc_error(GMarkupParseContext *context, + GError *error, + gpointer user_data); +static void xmlrpc_text(GMarkupParseContext *context, + const gchar *text, + gsize text_len, + gpointer user_data, + GError **error); + +static GMarkupParser xmlrpc_parser = { + .start_element = xmlrpc_start_element, + .end_element = xmlrpc_end_element, + .passthrough = NULL, + .text = xmlrpc_text, + .error = xmlrpc_error, +}; + +static GQuark +xmlrpc_error_quark(void) +{ + return g_quark_from_static_string("xmlrpc-error-quark"); +} + +static void +xmlrpc_start_element(GMarkupParseContext *context, + const gchar *name, + const gchar **attribute_names, + const gchar **attribute_values, + gpointer user_data, + GError **error) +{ + struct lua_xmlrpc_ud *ud = user_data; + enum lua_xmlrpc_state last_state; + + last_state = ud->parser_state; + + msg_debug_xmlrpc("got start element %s on state %d", name, last_state); + + switch (ud->parser_state) { + case read_method_response: + /* Expect tag methodResponse */ + if (g_ascii_strcasecmp(name, "methodResponse") == 0) { + ud->parser_state = read_params; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_params: + /* Expect tag params */ + if (g_ascii_strcasecmp(name, "params") == 0) { + ud->parser_state = read_param; + /* result -> table of params indexed by int */ + lua_newtable(ud->L); + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_param: + /* Expect tag param */ + if (g_ascii_strcasecmp(name, "param") == 0) { + ud->parser_state = read_param_value; + /* Create new param */ + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_param_value: + /* Expect tag value */ + if (g_ascii_strcasecmp(name, "value") == 0) { + ud->parser_state = read_param_element; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_param_element: + /* Expect tag struct */ + if (g_ascii_strcasecmp(name, "struct") == 0) { + ud->parser_state = read_struct; + /* Create new param of table type */ + lua_newtable(ud->L); + g_queue_push_head(ud->st, GINT_TO_POINTER(st_struct)); + msg_debug_xmlrpc("push struct"); + } + else if (g_ascii_strcasecmp(name, "array") == 0) { + ud->parser_state = read_array; + /* Create new param of table type */ + lua_newtable(ud->L); + g_queue_push_head(ud->st, GINT_TO_POINTER(st_array)); + msg_debug_xmlrpc("push array"); + } + else if (g_ascii_strcasecmp(name, "string") == 0) { + ud->parser_state = read_string; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp(name, "int") == 0) { + ud->parser_state = read_int; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp(name, "double") == 0) { + ud->parser_state = read_double; + ud->got_text = FALSE; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct: + /* Parse structure */ + /* Expect tag member */ + if (g_ascii_strcasecmp(name, "member") == 0) { + ud->parser_state = read_struct_member_name; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct_member_name: + /* Expect tag name */ + if (g_ascii_strcasecmp(name, "name") == 0) { + ud->parser_state = read_struct_member_value; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct_member_value: + /* Accept value */ + if (g_ascii_strcasecmp(name, "value") == 0) { + ud->parser_state = read_struct_element; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct_element: + /* Parse any values */ + /* Primitives */ + if (g_ascii_strcasecmp(name, "string") == 0) { + ud->parser_state = read_string; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp(name, "int") == 0) { + ud->parser_state = read_int; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp(name, "double") == 0) { + ud->parser_state = read_double; + ud->got_text = FALSE; + } + /* Structure */ + else if (g_ascii_strcasecmp(name, "struct") == 0) { + ud->parser_state = read_struct; + /* Create new param of table type */ + lua_newtable(ud->L); + g_queue_push_head(ud->st, GINT_TO_POINTER(st_struct)); + msg_debug_xmlrpc("push struct"); + } + else if (g_ascii_strcasecmp(name, "array") == 0) { + ud->parser_state = read_array; + /* Create new param of table type */ + lua_newtable(ud->L); + g_queue_push_head(ud->st, GINT_TO_POINTER(st_array)); + msg_debug_xmlrpc("push array"); + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_array: + /* Parse array */ + /* Expect data */ + if (g_ascii_strcasecmp(name, "data") == 0) { + ud->parser_state = read_array_value; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_array_value: + /* Accept array value */ + if (g_ascii_strcasecmp(name, "value") == 0) { + ud->parser_state = read_array_element; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_array_element: + /* Parse any values */ + /* Primitives */ + if (g_ascii_strcasecmp(name, "string") == 0) { + ud->parser_state = read_string; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp(name, "int") == 0) { + ud->parser_state = read_int; + ud->got_text = FALSE; + } + else if (g_ascii_strcasecmp(name, "double") == 0) { + ud->parser_state = read_double; + ud->got_text = FALSE; + } + /* Structure */ + else if (g_ascii_strcasecmp(name, "struct") == 0) { + ud->parser_state = read_struct; + /* Create new param of table type */ + lua_newtable(ud->L); + g_queue_push_head(ud->st, GINT_TO_POINTER(st_struct)); + msg_debug_xmlrpc("push struct"); + } + else if (g_ascii_strcasecmp(name, "array") == 0) { + ud->parser_state = read_array; + /* Create new param of table type */ + lua_newtable(ud->L); + g_queue_push_head(ud->st, GINT_TO_POINTER(st_array)); + msg_debug_xmlrpc("push array"); + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + default: + break; + } + + msg_debug_xmlrpc("switched state on start tag %d->%d", last_state, + ud->parser_state); + + if (ud->parser_state == error_state) { + g_set_error(error, + xmlrpc_error_quark(), 1, "xmlrpc parse error on state: %d, while parsing start tag: %s", + last_state, name); + } +} + +static void +xmlrpc_end_element(GMarkupParseContext *context, + const gchar *name, + gpointer user_data, + GError **error) +{ + struct lua_xmlrpc_ud *ud = user_data; + enum lua_xmlrpc_state last_state; + int last_queued; + + last_state = ud->parser_state; + + msg_debug_xmlrpc("got end element %s on state %d", name, last_state); + + switch (ud->parser_state) { + case read_method_response: + ud->parser_state = error_state; + break; + case read_params: + /* Got methodResponse */ + if (g_ascii_strcasecmp(name, "methodResponse") == 0) { + /* End processing */ + ud->parser_state = success_state; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_param: + /* Got tag params */ + if (g_ascii_strcasecmp(name, "params") == 0) { + ud->parser_state = read_params; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_param_value: + /* Got tag param */ + if (g_ascii_strcasecmp(name, "param") == 0) { + ud->parser_state = read_param; + lua_rawseti(ud->L, -2, ++ud->param_count); + msg_debug_xmlrpc("set param element idx: %d", ud->param_count); + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_param_element: + /* Got tag value */ + if (g_ascii_strcasecmp(name, "value") == 0) { + if (g_queue_get_length(ud->st) == 0) { + ud->parser_state = read_param_value; + } + else { + if (GPOINTER_TO_INT(g_queue_peek_head(ud->st)) == st_struct) { + ud->parser_state = read_struct_member_name; + } + else { + ud->parser_state = read_array_value; + } + } + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct: + /* Got tag struct */ + if (g_ascii_strcasecmp(name, "struct") == 0) { + g_assert(GPOINTER_TO_INT(g_queue_pop_head(ud->st)) == st_struct); + + if (g_queue_get_length(ud->st) == 0) { + ud->parser_state = read_param_element; + } + else { + last_queued = GPOINTER_TO_INT(g_queue_peek_head(ud->st)); + if (last_queued == st_struct) { + ud->parser_state = read_struct_element; + } + else { + ud->parser_state = read_array_element; + } + } + + msg_debug_xmlrpc("pop struct"); + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct_member_name: + /* Got tag member */ + if (g_ascii_strcasecmp(name, "member") == 0) { + ud->parser_state = read_struct; + /* Set table */ + msg_debug_xmlrpc("set struct element idx: %s", + lua_tostring(ud->L, -2)); + lua_settable(ud->L, -3); + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct_member_value: + /* Got tag name */ + if (g_ascii_strcasecmp(name, "name") == 0) { + ud->parser_state = read_struct_member_value; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_struct_element: + /* Got tag value */ + if (g_ascii_strcasecmp(name, "value") == 0) { + ud->parser_state = read_struct_member_name; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_string: + case read_int: + case read_double: + /* Parse any values */ + /* Handle empty tags */ + if (!ud->got_text) { + lua_pushnil(ud->L); + } + else { + ud->got_text = FALSE; + } + /* Primitives */ + if (g_ascii_strcasecmp(name, "string") == 0 || + g_ascii_strcasecmp(name, "int") == 0 || + g_ascii_strcasecmp(name, "double") == 0) { + if (GPOINTER_TO_INT(g_queue_peek_head(ud->st)) == st_struct) { + ud->parser_state = read_struct_element; + } + else { + ud->parser_state = read_array_element; + } + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_array: + /* Got tag array */ + if (g_ascii_strcasecmp(name, "array") == 0) { + g_assert(GPOINTER_TO_INT(g_queue_pop_head(ud->st)) == st_array); + + if (g_queue_get_length(ud->st) == 0) { + ud->parser_state = read_param_element; + } + else { + last_queued = GPOINTER_TO_INT(g_queue_peek_head(ud->st)); + if (last_queued == st_struct) { + ud->parser_state = read_struct_element; + } + else { + ud->parser_state = read_array_element; + } + } + + msg_debug_xmlrpc("pop array"); + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_array_value: + /* Got tag data */ + if (g_ascii_strcasecmp(name, "data") == 0) { + ud->parser_state = read_array; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + case read_array_element: + /* Got tag value */ + if (g_ascii_strcasecmp(name, "value") == 0) { + guint tbl_len = rspamd_lua_table_size(ud->L, -2); + lua_rawseti(ud->L, -2, tbl_len + 1); + msg_debug_xmlrpc("set array element idx: %d", tbl_len + 1); + ud->parser_state = read_array_value; + } + else { + /* Error state */ + ud->parser_state = error_state; + } + break; + default: + break; + } + + msg_debug_xmlrpc("switched state on end tag %d->%d", + last_state, ud->parser_state); + + if (ud->parser_state == error_state) { + g_set_error(error, + xmlrpc_error_quark(), 1, "xmlrpc parse error on state: %d, while parsing end tag: %s", + last_state, name); + } +} + +static void +xmlrpc_text(GMarkupParseContext *context, + const gchar *text, + gsize text_len, + gpointer user_data, + GError **error) +{ + struct lua_xmlrpc_ud *ud = user_data; + gulong num; + gdouble dnum; + + /* Strip line */ + while (text_len > 0 && g_ascii_isspace(*text)) { + text++; + text_len--; + } + while (text_len > 0 && g_ascii_isspace(text[text_len - 1])) { + text_len--; + } + + if (text_len > 0) { + msg_debug_xmlrpc("got data on state %d", ud->parser_state); + switch (ud->parser_state) { + case read_struct_member_value: + /* Push key */ + lua_pushlstring(ud->L, text, text_len); + break; + case read_string: + /* Push string value */ + lua_pushlstring(ud->L, text, text_len); + break; + case read_int: + /* Push integer value */ + rspamd_strtoul(text, text_len, &num); + lua_pushinteger(ud->L, num); + break; + case read_double: + /* Push integer value */ + dnum = strtod(text, NULL); + lua_pushnumber(ud->L, dnum); + break; + default: + break; + } + ud->got_text = TRUE; + } +} + +static void +xmlrpc_error(GMarkupParseContext *context, GError *error, gpointer user_data) +{ + msg_err("xmlrpc parser error: %s", error->message); +} + +static gint +lua_xmlrpc_parse_reply(lua_State *L) +{ + LUA_TRACE_POINT; + const gchar *data; + GMarkupParseContext *ctx; + GError *err = NULL; + struct lua_xmlrpc_ud ud; + gsize s; + gboolean res; + + data = luaL_checklstring(L, 1, &s); + + if (data != NULL) { + ud.L = L; + ud.parser_state = read_method_response; + ud.param_count = 0; + ud.st = g_queue_new(); + + ctx = g_markup_parse_context_new(&xmlrpc_parser, + G_MARKUP_TREAT_CDATA_AS_TEXT, &ud, NULL); + res = g_markup_parse_context_parse(ctx, data, s, &err); + + g_markup_parse_context_free(ctx); + if (!res) { + lua_pushnil(L); + } + } + else { + lua_pushnil(L); + } + + /* Return table or nil */ + return 1; +} + +static gint +lua_xmlrpc_parse_table(lua_State *L, + gint pos, + gchar *databuf, + gint pr, + gsize size) +{ + gint r = pr, num; + double dnum; + + r += rspamd_snprintf(databuf + r, size - r, "<struct>"); + lua_pushnil(L); /* first key */ + while (lua_next(L, pos) != 0) { + /* uses 'key' (at index -2) and 'value' (at index -1) */ + if (lua_type(L, -2) != LUA_TSTRING) { + /* Ignore non sting keys */ + lua_pop(L, 1); + continue; + } + r += rspamd_snprintf(databuf + r, + size - r, + "<member><name>%s</name><value>", + lua_tostring(L, -2)); + switch (lua_type(L, -1)) { + case LUA_TNUMBER: + num = lua_tointeger(L, -1); + dnum = lua_tonumber(L, -1); + + /* Try to avoid conversion errors */ + if (dnum != (double) num) { + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "<double>%f</double>", + dnum); + } + else { + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "<int>%d</int>", + num); + } + break; + case LUA_TBOOLEAN: + r += rspamd_snprintf(databuf + r, + size - r, + "<boolean>%d</boolean>", + lua_toboolean(L, -1) ? 1 : 0); + break; + case LUA_TSTRING: + r += rspamd_snprintf(databuf + r, size - r, "<string>%s</string>", + lua_tostring(L, -1)); + break; + case LUA_TTABLE: + /* Recursive call */ + r += lua_xmlrpc_parse_table(L, -1, databuf + r, r, size); + break; + } + r += rspamd_snprintf(databuf + r, size - r, "</value></member>"); + /* removes 'value'; keeps 'key' for next iteration */ + lua_pop(L, 1); + } + r += rspamd_snprintf(databuf + r, size - r, "</struct>"); + + return r - pr; +} + +/* + * Internal limitation: xmlrpc request must NOT be more than + * BUFSIZ * 2 (16384 bytes) + */ +static gint +lua_xmlrpc_make_request(lua_State *L) +{ + LUA_TRACE_POINT; + gchar databuf[BUFSIZ * 2]; + const gchar *func; + gint r, top, i, num; + double dnum; + + func = luaL_checkstring(L, 1); + + if (func) { + r = rspamd_snprintf(databuf, sizeof(databuf), + "<?xml version=\"1.0\" encoding=\"UTF-8\"?>" + "<methodCall><methodName>%s</methodName><params>", + func); + /* Extract arguments */ + top = lua_gettop(L); + /* Get additional options */ + for (i = 2; i <= top; i++) { + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "<param><value>"); + switch (lua_type(L, i)) { + case LUA_TNUMBER: + num = lua_tointeger(L, i); + dnum = lua_tonumber(L, i); + + /* Try to avoid conversion errors */ + if (dnum != (double) num) { + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "<double>%f</double>", + dnum); + } + else { + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "<int>%d</int>", + num); + } + break; + case LUA_TBOOLEAN: + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "<boolean>%d</boolean>", + lua_toboolean(L, i) ? 1 : 0); + break; + case LUA_TSTRING: + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "<string>%s</string>", + lua_tostring(L, i)); + break; + case LUA_TTABLE: + r += + lua_xmlrpc_parse_table(L, i, databuf, r, sizeof(databuf)); + break; + } + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "</value></param>"); + } + + r += rspamd_snprintf(databuf + r, + sizeof(databuf) - r, + "</params></methodCall>"); + lua_pushlstring(L, databuf, r); + } + else { + lua_pushnil(L); + } + + return 1; +} + +static gint +lua_load_xmlrpc(lua_State *L) +{ + lua_newtable(L); + luaL_register(L, NULL, xmlrpclib_m); + + return 1; +} + +void luaopen_xmlrpc(lua_State *L) +{ + rspamd_lua_add_preload(L, "rspamd_xmlrpc", lua_load_xmlrpc); +} diff --git a/src/lua/rspamd.luadoc b/src/lua/rspamd.luadoc new file mode 100644 index 0000000..7f2c5cc --- /dev/null +++ b/src/lua/rspamd.luadoc @@ -0,0 +1,124 @@ +--- Rspamd interaction package +-- contains several subclasses: +-- config - for parsing config files +-- metric - for handling metrics callbacks +-- task - for interaction with task object +-- message - gate to GMime functions +-- textpart - a single textual part of message +module Rspamd + +--- Each lua module has global rspamd_config that can be used for getting config +-- options and registering callbacks (via metric interface) + +------------------------------------- CONFIG ----------------------------------------- +-- +--- Get module option from config +-- @param mname module name +-- @param option option +-- @return string with value +function config:get_module_opt (mname, option) + +--- Get all module options as a table like ['param' => 'value'] +-- @param mname module name +-- @return table with options +function config:get_all_opt (mname) + +--- Get specified metric +-- @param name metric name +-- @return metric object +function config:get_metric (name) + +------------------------------------- METRIC ----------------------------------------- + +--- Register symbol in metric +-- @param symbol name of symbol +-- @param weight weight of symbol +-- @param callback function that would be called as callback for symbol +function metric:register_symbol (symbol, weight, callback) + +------------------------------------- TASK ------------------------------------------- + +--- Get message object from task +-- @return message object +function task:get_message () + +--- Insert result to specified metric with specified weight (obsoleted) +-- @param metric metric name +-- @param symbol symbol name +-- @param weight weight of symbol +function task:insert_result (metric, symbol, weight) + +--- Get all urls as array +-- @return array of urls in textual form +function task:get_urls () + +--- Get all text parts +-- @return array of textpart objects +function task:get_text_parts () + +--- Get raw headers +-- @return string that contains raw headers +function task:get_raw_headers () + +--- Get array of received headers +-- @return array of received headers that are tables itself +function task:get_received_headers () + +--- Resolve A record using rspamd async resolver +-- @param host host to resolve +-- @param callback name of callback function +function task:resolve_dns_a (host, callback) + +--- Resolve PTR record using rspamd async resolver +-- @param host host to resolve +-- @param callback name of callback function +function task:resolve_dns_ptr (host, callback) + +--- Callback function for dns resolving +-- @param task task object +-- @param to_resolve ptr or a record that was resolved +-- @param results results of dns query (array or nil) +-- @param err resolver error or nil +function dns_cb(task, to_resolve, results, err) + +------------------------------------- TEXTPART --------------------------------------- + +--- Get part's content +-- @return string that contains part's content +function textpart:get_content () + +--- Check if part is empty +-- @return boolean value +function textpart:is_empty () + +--- Check if part is html +-- @return boolean value +function textpart:is_html () + +--- Get part's fuzzy +-- @return string that contains part's fuzzy +function textpart:get_fuzzy () + +------------------------------------- MESSAGE ---------------------------------------- + +--- Get message subject +-- @return message subject +function message:get_subject () + +--- Get message id +-- @return message id +function message:get_message_id () + +--- Get sender of message +-- @return sender's credits +function message:get_sender () + +--- Get reply-to field +-- @return value of reply-to header +function message:get_reply_to () + +--- Get header +-- @param header_name name of header +-- @return array of headers with specified name +function message:get_header (header_name) + |