diff options
Diffstat (limited to 'src/libstat/backends/redis_backend.cxx')
-rw-r--r-- | src/libstat/backends/redis_backend.cxx | 1132 |
1 files changed, 1132 insertions, 0 deletions
diff --git a/src/libstat/backends/redis_backend.cxx b/src/libstat/backends/redis_backend.cxx new file mode 100644 index 0000000..cd0c379 --- /dev/null +++ b/src/libstat/backends/redis_backend.cxx @@ -0,0 +1,1132 @@ +/* + * Copyright 2024 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "lua/lua_common.h" +#include "rspamd.h" +#include "stat_internal.h" +#include "upstream.h" +#include "libserver/mempool_vars_internal.h" +#include "fmt/core.h" + +#include "libutil/cxx/error.hxx" + +#include <string> +#include <cstdint> +#include <vector> +#include <optional> + +#define msg_debug_stat_redis(...) rspamd_conditional_debug_fast(nullptr, nullptr, \ + rspamd_stat_redis_log_id, "stat_redis", task->task_pool->tag.uid, \ + RSPAMD_LOG_FUNC, \ + __VA_ARGS__) + +INIT_LOG_MODULE(stat_redis) + +#define REDIS_CTX(p) (reinterpret_cast<struct redis_stat_ctx *>(p)) +#define REDIS_RUNTIME(p) (reinterpret_cast<struct redis_stat_runtime<float> *>(p)) +#define REDIS_DEFAULT_OBJECT "%s%l" +#define REDIS_DEFAULT_USERS_OBJECT "%s%l%r" +#define REDIS_DEFAULT_TIMEOUT 0.5 +#define REDIS_STAT_TIMEOUT 30 +#define REDIS_MAX_USERS 1000 + +struct redis_stat_ctx { + lua_State *L; + struct rspamd_statfile_config *stcf; + const char *redis_object = REDIS_DEFAULT_OBJECT; + bool enable_users = false; + bool store_tokens = false; + bool enable_signatures = false; + int cbref_user = -1; + + int cbref_classify = -1; + int cbref_learn = -1; + + ucl_object_t *cur_stat = nullptr; + + explicit redis_stat_ctx(lua_State *_L) + : L(_L) + { + } + + ~redis_stat_ctx() + { + if (cbref_user != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref_user); + } + + if (cbref_classify != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref_classify); + } + + if (cbref_learn != -1) { + luaL_unref(L, LUA_REGISTRYINDEX, cbref_learn); + } + } +}; + + +template<class T, std::enable_if_t<std::is_convertible_v<T, float>, bool> = true> +struct redis_stat_runtime { + struct redis_stat_ctx *ctx; + struct rspamd_task *task; + struct rspamd_statfile_config *stcf; + GPtrArray *tokens = nullptr; + const char *redis_object_expanded; + std::uint64_t learned = 0; + int id; + std::vector<std::pair<int, T>> *results = nullptr; + bool need_redis_call = true; + std::optional<rspamd::util::error> err; + + using result_type = std::vector<std::pair<int, T>>; + +private: + /* Called on connection termination */ + static void rt_dtor(gpointer data) + { + auto *rt = REDIS_RUNTIME(data); + + delete rt; + } + + /* Avoid occasional deletion */ + ~redis_stat_runtime() + { + if (tokens) { + g_ptr_array_unref(tokens); + } + + delete results; + } + +public: + explicit redis_stat_runtime(struct redis_stat_ctx *_ctx, struct rspamd_task *_task, const char *_redis_object_expanded) + : ctx(_ctx), task(_task), stcf(_ctx->stcf), redis_object_expanded(_redis_object_expanded) + { + rspamd_mempool_add_destructor(task->task_pool, redis_stat_runtime<T>::rt_dtor, this); + } + + static auto maybe_recover_from_mempool(struct rspamd_task *task, const char *redis_object_expanded, + bool is_spam) -> std::optional<redis_stat_runtime<T> *> + { + auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + auto *res = rspamd_mempool_get_variable(task->task_pool, var_name.c_str()); + + if (res) { + msg_debug_bayes("recovered runtime from mempool at %s", var_name.c_str()); + return reinterpret_cast<redis_stat_runtime<T> *>(res); + } + else { + msg_debug_bayes("no runtime at %s", var_name.c_str()); + return std::nullopt; + } + } + + void set_results(std::vector<std::pair<int, T>> *results) + { + this->results = results; + } + + /* Propagate results from internal representation to the tokens array */ + auto process_tokens(GPtrArray *tokens) const -> bool + { + rspamd_token_t *tok; + + if (!results) { + return false; + } + + for (auto [idx, val]: *results) { + tok = (rspamd_token_t *) g_ptr_array_index(tokens, idx - 1); + tok->values[id] = val; + } + + return true; + } + + auto save_in_mempool(bool is_spam) const + { + auto var_name = fmt::format("{}_{}", redis_object_expanded, is_spam ? "S" : "H"); + /* We do not set destructor for the variable, as it should be already added on creation */ + rspamd_mempool_set_variable(task->task_pool, var_name.c_str(), (gpointer) this, nullptr); + msg_debug_bayes("saved runtime in mempool at %s", var_name.c_str()); + } +}; + +#define GET_TASK_ELT(task, elt) (task == nullptr ? nullptr : (task)->elt) + +static const gchar *M = "redis statistics"; + +static GQuark +rspamd_redis_stat_quark(void) +{ + return g_quark_from_static_string(M); +} + +/* + * Non-static for lua unit testing + */ +gsize rspamd_redis_expand_object(const gchar *pattern, + struct redis_stat_ctx *ctx, + struct rspamd_task *task, + gchar **target) +{ + gsize tlen = 0; + const gchar *p = pattern, *elt; + gchar *d, *end; + enum { + just_char, + percent_char, + mod_char + } state = just_char; + struct rspamd_statfile_config *stcf; + lua_State *L = nullptr; + struct rspamd_task **ptask; + const gchar *rcpt = nullptr; + gint err_idx; + + g_assert(ctx != nullptr); + g_assert(task != nullptr); + stcf = ctx->stcf; + + L = RSPAMD_LUA_CFG_STATE(task->cfg); + g_assert(L != nullptr); + + if (ctx->enable_users) { + if (ctx->cbref_user == -1) { + rcpt = rspamd_task_get_principal_recipient(task); + } + else { + /* Execute lua function to get userdata */ + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + + lua_rawgeti(L, LUA_REGISTRYINDEX, ctx->cbref_user); + ptask = (struct rspamd_task **) lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_err_task("call to user extraction script failed: %s", + lua_tostring(L, -1)); + } + else { + rcpt = rspamd_mempool_strdup(task->task_pool, lua_tostring(L, -1)); + } + + /* Result + error function */ + lua_settop(L, err_idx - 1); + } + + if (rcpt) { + rspamd_mempool_set_variable(task->task_pool, "stat_user", + (gpointer) rcpt, nullptr); + } + } + + /* Length calculation */ + while (*p) { + switch (state) { + case just_char: + if (*p == '%') { + state = percent_char; + } + else { + tlen++; + } + p++; + break; + case percent_char: + switch (*p) { + case '%': + tlen++; + state = just_char; + break; + case 'u': + elt = GET_TASK_ELT(task, auth_user); + if (elt) { + tlen += strlen(elt); + } + break; + case 'r': + + if (rcpt == nullptr) { + elt = rspamd_task_get_principal_recipient(task); + } + else { + elt = rcpt; + } + + if (elt) { + tlen += strlen(elt); + } + break; + case 'l': + if (stcf->label) { + tlen += strlen(stcf->label); + } + /* Label miss is OK */ + break; + case 's': + tlen += sizeof("RS") - 1; + break; + default: + state = just_char; + tlen++; + break; + } + + if (state == percent_char) { + state = mod_char; + } + p++; + break; + + case mod_char: + switch (*p) { + case 'd': + p++; + state = just_char; + break; + default: + state = just_char; + break; + } + break; + } + } + + + if (target == nullptr) { + return -1; + } + + *target = (gchar *) rspamd_mempool_alloc(task->task_pool, tlen + 1); + d = *target; + end = d + tlen + 1; + d[tlen] = '\0'; + p = pattern; + state = just_char; + + /* Expand string */ + while (*p && d < end) { + switch (state) { + case just_char: + if (*p == '%') { + state = percent_char; + } + else { + *d++ = *p; + } + p++; + break; + case percent_char: + switch (*p) { + case '%': + *d++ = *p; + state = just_char; + break; + case 'u': + elt = GET_TASK_ELT(task, auth_user); + if (elt) { + d += rspamd_strlcpy(d, elt, end - d); + } + break; + case 'r': + if (rcpt == nullptr) { + elt = rspamd_task_get_principal_recipient(task); + } + else { + elt = rcpt; + } + + if (elt) { + d += rspamd_strlcpy(d, elt, end - d); + } + break; + case 'l': + if (stcf->label) { + d += rspamd_strlcpy(d, stcf->label, end - d); + } + break; + case 's': + d += rspamd_strlcpy(d, "RS", end - d); + break; + default: + state = just_char; + *d++ = *p; + break; + } + + if (state == percent_char) { + state = mod_char; + } + p++; + break; + + case mod_char: + switch (*p) { + case 'd': + /* TODO: not supported yet */ + p++; + state = just_char; + break; + default: + state = just_char; + break; + } + break; + } + } + + return tlen; +} + +static int +rspamd_redis_stat_cb(lua_State *L) +{ + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *cfg = lua_check_config(L, 1); + auto *backend = REDIS_CTX(rspamd_mempool_get_variable(cfg->cfg_pool, cookie)); + + if (backend == nullptr) { + msg_err("internal error: cookie %s is not found", cookie); + + return 0; + } + + auto *cur_obj = ucl_object_lua_import(L, 2); + msg_debug_bayes_cfg("got stat object for %s", backend->stcf->symbol); + /* Enrich with some default values that are meaningless for redis */ + ucl_object_insert_key(cur_obj, + ucl_object_typed_new(UCL_INT), "used", 0, false); + ucl_object_insert_key(cur_obj, + ucl_object_typed_new(UCL_INT), "total", 0, false); + ucl_object_insert_key(cur_obj, + ucl_object_typed_new(UCL_INT), "size", 0, false); + ucl_object_insert_key(cur_obj, + ucl_object_fromstring(backend->stcf->symbol), + "symbol", 0, false); + ucl_object_insert_key(cur_obj, ucl_object_fromstring("redis"), + "type", 0, false); + ucl_object_insert_key(cur_obj, ucl_object_fromint(0), + "languages", 0, false); + + if (backend->cur_stat) { + ucl_object_unref(backend->cur_stat); + } + + backend->cur_stat = cur_obj; + + return 0; +} + +static void +rspamd_redis_parse_classifier_opts(struct redis_stat_ctx *backend, + const ucl_object_t *statfile_obj, + const ucl_object_t *classifier_obj, + struct rspamd_config *cfg) +{ + const gchar *lua_script; + const ucl_object_t *elt, *users_enabled; + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + + users_enabled = ucl_object_lookup_any(classifier_obj, "per_user", + "users_enabled", nullptr); + + if (users_enabled != nullptr) { + if (ucl_object_type(users_enabled) == UCL_BOOLEAN) { + backend->enable_users = ucl_object_toboolean(users_enabled); + backend->cbref_user = -1; + } + else if (ucl_object_type(users_enabled) == UCL_STRING) { + lua_script = ucl_object_tostring(users_enabled); + + if (luaL_dostring(L, lua_script) != 0) { + msg_err_config("cannot execute lua script for users " + "extraction: %s", + lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) == LUA_TFUNCTION) { + backend->enable_users = TRUE; + backend->cbref_user = luaL_ref(L, + LUA_REGISTRYINDEX); + } + else { + msg_err_config("lua script must return " + "function(task) and not %s", + lua_typename(L, lua_type(L, -1))); + } + } + } + } + else { + backend->enable_users = FALSE; + backend->cbref_user = -1; + } + + elt = ucl_object_lookup(classifier_obj, "prefix"); + if (elt == nullptr || ucl_object_type(elt) != UCL_STRING) { + /* Default non-users statistics */ + if (backend->enable_users || backend->cbref_user != -1) { + backend->redis_object = REDIS_DEFAULT_USERS_OBJECT; + } + else { + backend->redis_object = REDIS_DEFAULT_OBJECT; + } + } + else { + /* XXX: sanity check */ + backend->redis_object = ucl_object_tostring(elt); + } + + elt = ucl_object_lookup(classifier_obj, "store_tokens"); + if (elt) { + backend->store_tokens = ucl_object_toboolean(elt); + } + else { + backend->store_tokens = FALSE; + } + + elt = ucl_object_lookup(classifier_obj, "signatures"); + if (elt) { + backend->enable_signatures = ucl_object_toboolean(elt); + } + else { + backend->enable_signatures = FALSE; + } +} + +gpointer +rspamd_redis_init(struct rspamd_stat_ctx *ctx, + struct rspamd_config *cfg, struct rspamd_statfile *st) +{ + auto *L = RSPAMD_LUA_CFG_STATE(cfg); + + auto backend = std::make_unique<struct redis_stat_ctx>(L); + lua_settop(L, 0); + + rspamd_redis_parse_classifier_opts(backend.get(), st->stcf->opts, st->classifier->cfg->opts, cfg); + + st->stcf->clcf->flags |= RSPAMD_FLAG_CLASSIFIER_INCREMENTING_BACKEND; + backend->stcf = st->stcf; + + lua_pushcfunction(L, &rspamd_lua_traceback); + auto err_idx = lua_gettop(L); + + /* Obtain function */ + if (!rspamd_lua_require_function(L, "lua_bayes_redis", "lua_bayes_init_statfile")) { + msg_err_config("cannot require lua_bayes_redis.lua_bayes_init_statfile"); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* Push arguments */ + ucl_object_push_lua(L, st->classifier->cfg->opts, false); + ucl_object_push_lua(L, st->stcf->opts, false); + lua_pushstring(L, backend->stcf->symbol); + lua_pushboolean(L, backend->stcf->is_spam); + auto **pev_base = (struct ev_loop **) lua_newuserdata(L, sizeof(struct ev_loop *)); + *pev_base = ctx->event_loop; + rspamd_lua_setclass(L, "rspamd{ev_base}", -1); + + /* Store backend in random cookie */ + char *cookie = (char *) rspamd_mempool_alloc(cfg->cfg_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(cfg->cfg_pool, cookie, backend.get(), nullptr); + /* Callback + 1 upvalue */ + lua_pushstring(L, cookie); + lua_pushcclosure(L, &rspamd_redis_stat_cb, 1); + + if (lua_pcall(L, 6, 2, err_idx) != 0) { + msg_err("call to lua_bayes_init_classifier " + "script failed: %s", + lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + + return nullptr; + } + + /* Results are in the stack: + * top - 1 - classifier function (idx = -2) + * top - learn function (idx = -1) + */ + + lua_pushvalue(L, -2); + backend->cbref_classify = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_pushvalue(L, -1); + backend->cbref_learn = luaL_ref(L, LUA_REGISTRYINDEX); + + lua_settop(L, err_idx - 1); + + return backend.release(); +} + +gpointer +rspamd_redis_runtime(struct rspamd_task *task, + struct rspamd_statfile_config *stcf, + gboolean learn, gpointer c, gint _id) +{ + struct redis_stat_ctx *ctx = REDIS_CTX(c); + char *object_expanded = nullptr; + + g_assert(ctx != nullptr); + g_assert(stcf != nullptr); + + if (rspamd_redis_expand_object(ctx->redis_object, ctx, task, + &object_expanded) == 0) { + msg_err_task("expansion for %s failed for symbol %s " + "(maybe learning per user classifier with no user or recipient)", + learn ? "learning" : "classifying", + stcf->symbol); + return nullptr; + } + + /* Look for the cached results */ + if (!learn) { + auto maybe_existing = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + object_expanded, stcf->is_spam); + + if (maybe_existing) { + auto *rt = maybe_existing.value(); + /* Update stcf and ctx to correspond to what we have been asked */ + rt->stcf = stcf; + rt->ctx = ctx; + return rt; + } + } + + /* No cached result (or learn), create new one */ + auto *rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + + if (!learn) { + /* + * For check, we also need to create the opposite class runtime to avoid + * double call for Redis scripts. + * This runtime will be filled later. + */ + auto maybe_opposite_rt = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + object_expanded, + !stcf->is_spam); + + if (!maybe_opposite_rt) { + auto *opposite_rt = new redis_stat_runtime<float>(ctx, task, object_expanded); + opposite_rt->save_in_mempool(!stcf->is_spam); + opposite_rt->need_redis_call = false; + } + } + + rt->save_in_mempool(stcf->is_spam); + + return rt; +} + +void rspamd_redis_close(gpointer p) +{ + struct redis_stat_ctx *ctx = REDIS_CTX(p); + delete ctx; +} + +static constexpr auto +msgpack_emit_str(const std::string_view st, char *out) -> std::size_t +{ + auto len = st.size(); + constexpr const unsigned char fix_mask = 0xA0, l8_ch = 0xd9, l16_ch = 0xda, l32_ch = 0xdb; + auto blen = 0; + if (len <= 0x1F) { + blen = 1; + out[0] = (len | fix_mask) & 0xff; + } + else if (len <= 0xff) { + blen = 2; + out[0] = l8_ch; + out[1] = len & 0xff; + } + else if (len <= 0xffff) { + uint16_t bl = GUINT16_TO_BE(len); + + blen = 3; + out[0] = l16_ch; + memcpy(&out[1], &bl, sizeof(bl)); + } + else { + uint32_t bl = GUINT32_TO_BE(len); + + blen = 5; + out[0] = l32_ch; + memcpy(&out[1], &bl, sizeof(bl)); + } + + memcpy(&out[blen], st.data(), st.size()); + + return blen + len; +} + +static constexpr auto +msgpack_str_len(std::size_t len) -> std::size_t +{ + if (len <= 0x1F) { + return 1 + len; + } + else if (len <= 0xff) { + return 2 + len; + } + else if (len <= 0xffff) { + return 3 + len; + } + else { + return 4 + len; + } +} + +/* + * Serialise stat tokens to message pack + */ +static char * +rspamd_redis_serialize_tokens(struct rspamd_task *task, const gchar *prefix, GPtrArray *tokens, gsize *ser_len) +{ + /* Each token is int64_t that requires 10 bytes (2 int32_t) + 4 bytes array len + 1 byte array magic */ + char max_int64_str[] = "18446744073709551615"; + auto prefix_len = strlen(prefix); + std::size_t req_len = 5; + rspamd_token_t *tok; + + /* Calculate required length */ + req_len += tokens->len * (msgpack_str_len(sizeof(max_int64_str) + prefix_len) + 1); + + auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len); + auto *p = buf; + + /* Array */ + *p++ = (gchar) 0xdd; + /* Length in big-endian (4 bytes) */ + *p++ = (gchar) ((tokens->len >> 24) & 0xff); + *p++ = (gchar) ((tokens->len >> 16) & 0xff); + *p++ = (gchar) ((tokens->len >> 8) & 0xff); + *p++ = (gchar) (tokens->len & 0xff); + + + int i; + auto numbuf_len = sizeof(max_int64_str) + prefix_len + 1; + auto *numbuf = (char *) g_alloca(numbuf_len); + + PTR_ARRAY_FOREACH(tokens, i, tok) + { + std::size_t r = rspamd_snprintf(numbuf, numbuf_len, "%s_%uL", prefix, tok->data); + auto shift = msgpack_emit_str({numbuf, r}, p); + p += shift; + } + + *ser_len = p - buf; + + return buf; +} + +static char * +rspamd_redis_serialize_text_tokens(struct rspamd_task *task, GPtrArray *tokens, gsize *ser_len) +{ + rspamd_token_t *tok; + auto req_len = 5; /* Messagepack array prefix */ + int i; + + /* + * First we need to determine the requested length + */ + PTR_ARRAY_FOREACH(tokens, i, tok) + { + if (tok->t1 && tok->t2) { + /* Two tokens */ + req_len += msgpack_str_len(tok->t1->stemmed.len) + msgpack_str_len(tok->t2->stemmed.len); + } + else if (tok->t1) { + req_len += msgpack_str_len(tok->t1->stemmed.len); + req_len += 1; /* null */ + } + else { + req_len += 2; /* 2 nulls */ + } + } + + auto *buf = (gchar *) rspamd_mempool_alloc(task->task_pool, req_len); + auto *p = buf; + + /* Array */ + std::uint32_t nlen = tokens->len * 2; + nlen = GUINT32_TO_BE(nlen); + *p++ = (gchar) 0xdd; + /* Length in big-endian (4 bytes) */ + memcpy(p, &nlen, sizeof(nlen)); + p += sizeof(nlen); + + PTR_ARRAY_FOREACH(tokens, i, tok) + { + if (tok->t1 && tok->t2) { + auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p); + p += step; + step = msgpack_emit_str({tok->t2->stemmed.begin, tok->t2->stemmed.len}, p); + p += step; + } + else if (tok->t1) { + auto step = msgpack_emit_str({tok->t1->stemmed.begin, tok->t1->stemmed.len}, p); + p += step; + *p++ = 0xc0; + } + else { + *p++ = 0xc0; + *p++ = 0xc0; + } + } + + *ser_len = p - buf; + + return buf; +} + +static gint +rspamd_redis_classified(lua_State *L) +{ + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *task = lua_check_task(L, 1); + auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + + if (rt == nullptr) { + msg_err_task("internal error: cannot find runtime for cookie %s", cookie); + + return 0; + } + + bool result = lua_toboolean(L, 2); + + if (result) { + /* Indexes: + * 3 - learned_ham (int) + * 4 - learned_spam (int) + * 5 - ham_tokens (pair<int, int>) + * 6 - spam_tokens (pair<int, int>) + */ + + /* + * We need to fill our runtime AND the opposite runtime + */ + auto filler_func = [](redis_stat_runtime<float> *rt, lua_State *L, unsigned learned, int tokens_pos) { + rt->learned = learned; + redis_stat_runtime<float>::result_type *res; + + res = new redis_stat_runtime<float>::result_type(); + + for (lua_pushnil(L); lua_next(L, tokens_pos); lua_pop(L, 1)) { + lua_rawgeti(L, -1, 1); + auto idx = lua_tointeger(L, -1); + lua_pop(L, 1); + + lua_rawgeti(L, -1, 2); + auto value = lua_tonumber(L, -1); + lua_pop(L, 1); + + res->emplace_back(idx, value); + } + + rt->set_results(res); + }; + + auto opposite_rt_maybe = redis_stat_runtime<float>::maybe_recover_from_mempool(task, + rt->redis_object_expanded, + !rt->stcf->is_spam); + + if (!opposite_rt_maybe) { + msg_err_task("internal error: cannot find opposite runtime for cookie %s", cookie); + + return 0; + } + + if (rt->stcf->is_spam) { + filler_func(rt, L, lua_tointeger(L, 4), 6); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 3), 5); + } + else { + filler_func(rt, L, lua_tointeger(L, 3), 5); + filler_func(opposite_rt_maybe.value(), L, lua_tointeger(L, 4), 6); + } + + /* Mark task as being processed */ + task->flags |= RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS | RSPAMD_TASK_FLAG_HAS_HAM_TOKENS; + + /* Process all tokens */ + g_assert(rt->tokens != nullptr); + rt->process_tokens(rt->tokens); + opposite_rt_maybe.value()->process_tokens(rt->tokens); + } + else { + /* Error message is on index 3 */ + const auto *err_msg = lua_tostring(L, 3); + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot classify task: %s", + err_msg); + } + + return 0; +} + +gboolean +rspamd_redis_process_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, gpointer p) +{ + auto *rt = REDIS_RUNTIME(p); + auto *L = rt->ctx->L; + + if (rspamd_session_blocked(task->s)) { + return FALSE; + } + + if (tokens == nullptr || tokens->len == 0) { + return FALSE; + } + + if (!rt->need_redis_call) { + /* No need to do anything, as it is already done in the opposite class processing */ + /* However, we need to store id as it is needed for further tokens processing */ + rt->id = id; + rt->tokens = g_ptr_array_ref(tokens); + + return TRUE; + } + + gsize tokens_len; + gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len); + rt->id = id; + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_classify); + rspamd_lua_task_push(L, task); + lua_pushstring(L, rt->redis_object_expanded); + lua_pushinteger(L, id); + lua_pushboolean(L, rt->stcf->is_spam); + lua_new_text(L, tokens_buf, tokens_len, false); + + /* Store rt in random cookie */ + char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr); + /* Callback */ + lua_pushstring(L, cookie); + lua_pushcclosure(L, &rspamd_redis_classified, 1); + + if (lua_pcall(L, 6, 0, err_idx) != 0) { + msg_err_task("call to redis failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return FALSE; + } + + rt->tokens = g_ptr_array_ref(tokens); + + lua_settop(L, err_idx - 1); + return TRUE; +} + +gboolean +rspamd_redis_finalize_process(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return !rt->err.has_value(); +} + + +static gint +rspamd_redis_learned(lua_State *L) +{ + const auto *cookie = lua_tostring(L, lua_upvalueindex(1)); + auto *task = lua_check_task(L, 1); + auto *rt = REDIS_RUNTIME(rspamd_mempool_get_variable(task->task_pool, cookie)); + + if (rt == nullptr) { + msg_err_task("internal error: cannot find runtime for cookie %s", cookie); + + return 0; + } + + bool result = lua_toboolean(L, 2); + + if (result) { + /* TODO: write it */ + } + else { + /* Error message is on index 3 */ + const auto *err_msg = lua_tostring(L, 3); + rt->err = rspamd::util::error(err_msg, 500); + msg_err_task("cannot learn task: %s", err_msg); + } + + return 0; +} + +gboolean +rspamd_redis_learn_tokens(struct rspamd_task *task, + GPtrArray *tokens, + gint id, gpointer p) +{ + auto *rt = REDIS_RUNTIME(p); + auto *L = rt->ctx->L; + + if (rspamd_session_blocked(task->s)) { + return FALSE; + } + + if (tokens == nullptr || tokens->len == 0) { + return FALSE; + } + + gsize tokens_len; + gchar *tokens_buf = rspamd_redis_serialize_tokens(task, rt->redis_object_expanded, tokens, &tokens_len); + + rt->id = id; + + gsize text_tokens_len = 0; + gchar *text_tokens_buf = nullptr; + + if (rt->ctx->store_tokens) { + text_tokens_buf = rspamd_redis_serialize_text_tokens(task, tokens, &text_tokens_len); + } + + lua_pushcfunction(L, &rspamd_lua_traceback); + gint err_idx = lua_gettop(L); + auto nargs = 8; + + /* Function arguments */ + lua_rawgeti(L, LUA_REGISTRYINDEX, rt->ctx->cbref_learn); + rspamd_lua_task_push(L, task); + lua_pushstring(L, rt->redis_object_expanded); + lua_pushinteger(L, id); + lua_pushboolean(L, rt->stcf->is_spam); + lua_pushstring(L, rt->stcf->symbol); + + /* Detect unlearn */ + auto *tok = (rspamd_token_t *) g_ptr_array_index(task->tokens, 0); + + if (tok->values[id] > 0) { + lua_pushboolean(L, FALSE);// Learn + } + else { + lua_pushboolean(L, TRUE);// Unlearn + } + lua_new_text(L, tokens_buf, tokens_len, false); + + /* Store rt in random cookie */ + char *cookie = (char *) rspamd_mempool_alloc(task->task_pool, 16); + rspamd_random_hex(cookie, 16); + cookie[15] = '\0'; + rspamd_mempool_set_variable(task->task_pool, cookie, rt, nullptr); + /* Callback */ + lua_pushstring(L, cookie); + lua_pushcclosure(L, &rspamd_redis_learned, 1); + + if (text_tokens_len) { + nargs = 9; + lua_new_text(L, text_tokens_buf, text_tokens_len, false); + } + + if (lua_pcall(L, nargs, 0, err_idx) != 0) { + msg_err_task("call to script failed: %s", lua_tostring(L, -1)); + lua_settop(L, err_idx - 1); + return FALSE; + } + + rt->tokens = g_ptr_array_ref(tokens); + + lua_settop(L, err_idx - 1); + return TRUE; +} + + +gboolean +rspamd_redis_finalize_learn(struct rspamd_task *task, gpointer runtime, + gpointer ctx, GError **err) +{ + auto *rt = REDIS_RUNTIME(runtime); + + if (rt->err.has_value()) { + rt->err->into_g_error_set(rspamd_redis_stat_quark(), err); + + return FALSE; + } + + return TRUE; +} + +gulong +rspamd_redis_total_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return rt->learned; +} + +gulong +rspamd_redis_inc_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + /* XXX: may cause races */ + return rt->learned + 1; +} + +gulong +rspamd_redis_dec_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + /* XXX: may cause races */ + return rt->learned + 1; +} + +gulong +rspamd_redis_learns(struct rspamd_task *task, gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return rt->learned; +} + +ucl_object_t * +rspamd_redis_get_stat(gpointer runtime, + gpointer ctx) +{ + auto *rt = REDIS_RUNTIME(runtime); + + return ucl_object_ref(rt->ctx->cur_stat); +} + +gpointer +rspamd_redis_load_tokenizer_config(gpointer runtime, + gsize *len) +{ + return nullptr; +} |