Diffstat (limited to 'src/libstat/stat_process.c')
-rw-r--r-- | src/libstat/stat_process.c | 1250 |
1 files changed, 1250 insertions, 0 deletions
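The new file wires statistics into rspamd's task pipeline in discrete stages: rspamd_stat_classify() handles RSPAMD_TASK_STAGE_CLASSIFIERS_PRE (tokenization and backend runtime setup), RSPAMD_TASK_STAGE_CLASSIFIERS (backend token processing) and RSPAMD_TASK_STAGE_CLASSIFIERS_POST (classifier finalization), and rspamd_stat_learn() mirrors this with its own pre/learn/post stages. Below is a minimal standalone sketch of that per-stage dispatch shape; the toy types, names and stage enum are illustrative, only the overall structure corresponds to the code in the diff.

#include <stdio.h>

/* Toy stand-ins for the scan stages handled by rspamd_stat_classify();
 * the real RSPAMD_TASK_STAGE_* constants come from rspamd's task API. */
enum toy_stage {
    TOY_STAGE_CLASSIFIERS_PRE,  /* tokenize parts/metadata, create backend runtimes */
    TOY_STAGE_CLASSIFIERS,      /* ask every backend for token statistics */
    TOY_STAGE_CLASSIFIERS_POST, /* finalize backends and run the classifier */
};

/* Hypothetical task object: only tracks which stages have been processed. */
struct toy_task {
    unsigned processed_stages;
};

/* Mirrors the shape of rspamd_stat_classify(): one call per stage,
 * each call marks the stage as processed before returning. */
static int toy_stat_classify(struct toy_task *task, enum toy_stage stage)
{
    switch (stage) {
    case TOY_STAGE_CLASSIFIERS_PRE:
        puts("pre: tokenize and set up backend runtimes");
        break;
    case TOY_STAGE_CLASSIFIERS:
        puts("main: process tokens in each backend");
        break;
    case TOY_STAGE_CLASSIFIERS_POST:
        puts("post: finalize backends, call classify_func");
        break;
    }

    task->processed_stages |= (1u << stage);
    return 0;
}

int main(void)
{
    struct toy_task task = {0};

    toy_stat_classify(&task, TOY_STAGE_CLASSIFIERS_PRE);
    toy_stat_classify(&task, TOY_STAGE_CLASSIFIERS);
    toy_stat_classify(&task, TOY_STAGE_CLASSIFIERS_POST);

    return 0;
}

In the real function each stage value is itself a flag and is OR-ed into task->processed_stages directly; the sketch shifts only because its toy enum is sequential. Keeping each stage independent lets the caller drive backend work between the pre and post phases.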
diff --git a/src/libstat/stat_process.c b/src/libstat/stat_process.c new file mode 100644 index 0000000..8c1d8ff --- /dev/null +++ b/src/libstat/stat_process.c @@ -0,0 +1,1250 @@ +/*- + * Copyright 2016 Vsevolod Stakhov + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "config.h" +#include "stat_api.h" +#include "rspamd.h" +#include "stat_internal.h" +#include "libmime/message.h" +#include "libmime/images.h" +#include "libserver/html/html.h" +#include "lua/lua_common.h" +#include "libserver/mempool_vars_internal.h" +#include "utlist.h" +#include <math.h> + +#define RSPAMD_CLASSIFY_OP 0 +#define RSPAMD_LEARN_OP 1 +#define RSPAMD_UNLEARN_OP 2 + +static const gdouble similarity_threshold = 80.0; + +static void +rspamd_stat_tokenize_parts_metadata(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + GArray *ar; + rspamd_stat_token_t elt; + guint i; + lua_State *L = task->cfg->lua_state; + + ar = g_array_sized_new(FALSE, FALSE, sizeof(elt), 16); + memset(&elt, 0, sizeof(elt)); + elt.flags = RSPAMD_STAT_TOKEN_FLAG_META; + + if (st_ctx->lua_stat_tokens_ref != -1) { + gint err_idx, ret; + struct rspamd_task **ptask; + + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, st_ctx->lua_stat_tokens_ref); + + ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if ((ret = lua_pcall(L, 1, 1, err_idx)) != 0) { + msg_err_task("call to stat_tokens lua " + "script failed (%d): %s", + ret, lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) != LUA_TTABLE) { + msg_err_task("stat_tokens invocation must return " + "table and not %s", + lua_typename(L, lua_type(L, -1))); + } + else { + guint vlen; + rspamd_ftok_t tok; + + vlen = rspamd_lua_table_size(L, -1); + + for (i = 0; i < vlen; i++) { + lua_rawgeti(L, -1, i + 1); + tok.begin = lua_tolstring(L, -1, &tok.len); + + if (tok.begin && tok.len > 0) { + elt.original.begin = + rspamd_mempool_ftokdup(task->task_pool, &tok); + elt.original.len = tok.len; + elt.stemmed.begin = elt.original.begin; + elt.stemmed.len = elt.original.len; + elt.normalized.begin = elt.original.begin; + elt.normalized.len = elt.original.len; + + g_array_append_val(ar, elt); + } + + lua_pop(L, 1); + } + } + } + + lua_settop(L, 0); + } + + + if (ar->len > 0) { + st_ctx->tokenizer->tokenize_func(st_ctx, + task, + ar, + TRUE, + "M", + task->tokens); + } + + rspamd_mempool_add_destructor(task->task_pool, + rspamd_array_free_hard, ar); +} + +/* + * Tokenize task using the tokenizer specified + */ +void rspamd_stat_process_tokenize(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + struct rspamd_mime_text_part *part; + rspamd_cryptobox_hash_state_t hst; + rspamd_token_t *st_tok; + guint i, reserved_len = 0; + gdouble *pdiff; + guchar hout[rspamd_cryptobox_HASHBYTES]; + gchar *b32_hout; + + if (st_ctx == NULL) { + st_ctx = rspamd_stat_get_ctx(); + } + + g_assert(st_ctx != NULL); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, 
part) + { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { + reserved_len += part->utf_words->len; + } + /* XXX: normal window size */ + reserved_len += 5; + } + + task->tokens = g_ptr_array_sized_new(reserved_len); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, task->tokens); + rspamd_mempool_notify_alloc(task->task_pool, reserved_len * sizeof(gpointer)); + pdiff = rspamd_mempool_get_variable(task->task_pool, "parts_distance"); + + PTR_ARRAY_FOREACH(MESSAGE_FIELD(task, text_parts), i, part) + { + if (!IS_TEXT_PART_EMPTY(part) && part->utf_words != NULL) { + st_ctx->tokenizer->tokenize_func(st_ctx, task, + part->utf_words, IS_TEXT_PART_UTF(part), + NULL, task->tokens); + } + + + if (pdiff != NULL && (1.0 - *pdiff) * 100.0 > similarity_threshold) { + msg_debug_bayes("message has two common parts (%.2f), so skip the last one", + *pdiff); + break; + } + } + + if (task->meta_words != NULL) { + st_ctx->tokenizer->tokenize_func(st_ctx, + task, + task->meta_words, + TRUE, + "SUBJECT", + task->tokens); + } + + rspamd_stat_tokenize_parts_metadata(st_ctx, task); + + /* Produce signature */ + rspamd_cryptobox_hash_init(&hst, NULL, 0); + + PTR_ARRAY_FOREACH(task->tokens, i, st_tok) + { + rspamd_cryptobox_hash_update(&hst, (guchar *) &st_tok->data, + sizeof(st_tok->data)); + } + + rspamd_cryptobox_hash_final(&hst, hout); + b32_hout = rspamd_encode_base32(hout, sizeof(hout), RSPAMD_BASE32_DEFAULT); + /* + * We need to strip it to 32 characters providing ~160 bits of + * hash distribution + */ + b32_hout[32] = '\0'; + rspamd_mempool_set_variable(task->task_pool, RSPAMD_MEMPOOL_STAT_SIGNATURE, + b32_hout, g_free); +} + +static gboolean +rspamd_stat_classifier_is_skipped(struct rspamd_task *task, + struct rspamd_classifier *cl, gboolean is_learn, gboolean is_spam) +{ + GList *cur = is_learn ? cl->cfg->learn_conditions : cl->cfg->classify_conditions; + lua_State *L = task->cfg->lua_state; + gboolean ret = FALSE; + + while (cur) { + gint cb_ref = GPOINTER_TO_INT(cur->data); + gint old_top = lua_gettop(L); + gint nargs; + + lua_rawgeti(L, LUA_REGISTRYINDEX, cb_ref); + /* Push task and two booleans: is_spam and is_unlearn */ + struct rspamd_task **ptask = lua_newuserdata(L, sizeof(*ptask)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (is_learn) { + lua_pushboolean(L, is_spam); + lua_pushboolean(L, + task->flags & RSPAMD_TASK_FLAG_UNLEARN ? true : false); + nargs = 3; + } + else { + nargs = 1; + } + + if (lua_pcall(L, nargs, LUA_MULTRET, 0) != 0) { + msg_err_task("call to %s failed: %s", + "condition callback", + lua_tostring(L, -1)); + } + else { + if (lua_isboolean(L, 1)) { + if (!lua_toboolean(L, 1)) { + ret = TRUE; + } + } + + if (lua_isstring(L, 2)) { + if (ret) { + msg_notice_task("%s condition for classifier %s returned: %s; skip classifier", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, 2)); + } + else { + msg_info_task("%s condition for classifier %s returned: %s", + is_learn ? "learn" : "classify", cl->cfg->name, + lua_tostring(L, 2)); + } + } + else if (ret) { + msg_notice_task("%s condition for classifier %s returned false; skip classifier", + is_learn ? 
"learn" : "classify", cl->cfg->name); + } + + if (ret) { + lua_settop(L, old_top); + break; + } + } + + lua_settop(L, old_top); + cur = g_list_next(cur); + } + + return ret; +} + +static void +rspamd_stat_preprocess(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, gboolean is_learn, gboolean is_spam) +{ + guint i; + struct rspamd_statfile *st; + gpointer bk_run; + + if (task->tokens == NULL) { + rspamd_stat_process_tokenize(st_ctx, task); + } + + task->stat_runtimes = g_ptr_array_sized_new(st_ctx->statfiles->len); + g_ptr_array_set_size(task->stat_runtimes, st_ctx->statfiles->len); + rspamd_mempool_add_destructor(task->task_pool, + rspamd_ptr_array_free_hard, task->stat_runtimes); + + /* Temporary set all stat_runtimes to some max size to distinguish from NULL */ + for (i = 0; i < st_ctx->statfiles->len; i++) { + g_ptr_array_index(task->stat_runtimes, i) = GSIZE_TO_POINTER(G_MAXSIZE); + } + + for (i = 0; i < st_ctx->classifiers->len; i++) { + struct rspamd_classifier *cl = g_ptr_array_index(st_ctx->classifiers, i); + gboolean skip_classifier = FALSE; + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + skip_classifier = TRUE; + } + else { + if (rspamd_stat_classifier_is_skipped(task, cl, is_learn, is_spam)) { + skip_classifier = TRUE; + } + } + + if (skip_classifier) { + /* Set NULL for all statfiles indexed by id */ + for (int j = 0; j < cl->statfiles_ids->len; j++) { + int id = g_array_index(cl->statfiles_ids, gint, j); + g_ptr_array_index(task->stat_runtimes, id) = NULL; + } + } + } + + for (i = 0; i < st_ctx->statfiles->len; i++) { + st = g_ptr_array_index(st_ctx->statfiles, i); + g_assert(st != NULL); + + if (g_ptr_array_index(task->stat_runtimes, i) == NULL) { + /* The whole classifier is skipped */ + continue; + } + + if (is_learn && st->backend->read_only) { + /* Read only backend, skip it */ + g_ptr_array_index(task->stat_runtimes, i) = NULL; + continue; + } + + if (!is_learn && !rspamd_symcache_is_symbol_enabled(task, task->cfg->cache, + st->stcf->symbol)) { + g_ptr_array_index(task->stat_runtimes, i) = NULL; + msg_debug_bayes("symbol %s is disabled, skip classification", + st->stcf->symbol); + continue; + } + + bk_run = st->backend->runtime(task, st->stcf, is_learn, st->bkcf, i); + + if (bk_run == NULL) { + msg_err_task("cannot init backend %s for statfile %s", + st->backend->name, st->stcf->symbol); + } + + g_ptr_array_index(task->stat_runtimes, i) = bk_run; + } +} + +static void +rspamd_stat_backends_process(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + guint i; + struct rspamd_statfile *st; + gpointer bk_run; + + g_assert(task->stat_runtimes != NULL); + + for (i = 0; i < st_ctx->statfiles->len; i++) { + st = g_ptr_array_index(st_ctx->statfiles, i); + bk_run = g_ptr_array_index(task->stat_runtimes, i); + + if (bk_run != NULL) { + st->backend->process_tokens(task, task->tokens, i, bk_run); + } + } +} + +static void +rspamd_stat_classifiers_process(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task) +{ + guint i, j, id; + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer bk_run; + gboolean skip; + + if (st_ctx->classifiers->len == 0) { + return; + } + + /* + * Do not classify a message if some class is missing + */ + if (!(task->flags & RSPAMD_TASK_FLAG_HAS_SPAM_TOKENS)) { + msg_info_task("skip statistics as SPAM class is missing"); + + return; + } + if (!(task->flags & RSPAMD_TASK_FLAG_HAS_HAM_TOKENS)) { + msg_info_task("skip statistics as HAM class is missing"); + + return; + } + + for (i = 0; i < 
st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + cl->spam_learns = 0; + cl->ham_learns = 0; + } + + g_assert(task->stat_runtimes != NULL); + + for (i = 0; i < st_ctx->statfiles->len; i++) { + st = g_ptr_array_index(st_ctx->statfiles, i); + cl = st->classifier; + + bk_run = g_ptr_array_index(task->stat_runtimes, i); + g_assert(st != NULL); + + if (bk_run != NULL) { + if (st->stcf->is_spam) { + cl->spam_learns += st->backend->total_learns(task, + bk_run, + st_ctx); + } + else { + cl->ham_learns += st->backend->total_learns(task, + bk_run, + st_ctx); + } + } + } + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + g_assert(cl != NULL); + + skip = FALSE; + + /* Do not process classifiers on backend failures */ + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + st = g_ptr_array_index(st_ctx->statfiles, id); + + if (bk_run != NULL) { + if (!st->backend->finalize_process(task, bk_run, st_ctx)) { + skip = TRUE; + break; + } + } + } + + /* Ensure that all symbols enabled */ + if (!skip && !(cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND)) { + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + st = g_ptr_array_index(st_ctx->statfiles, id); + + if (bk_run == NULL) { + skip = TRUE; + msg_debug_bayes("disable classifier %s as statfile symbol %s is disabled", + cl->cfg->name, st->stcf->symbol); + break; + } + } + } + + if (!skip) { + if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { + msg_debug_bayes( + "contains less tokens than required for %s classifier: " + "%ud < %ud", + cl->cfg->name, + task->tokens->len, + cl->cfg->min_tokens); + continue; + } + else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { + msg_debug_bayes( + "contains more tokens than allowed for %s classifier: " + "%ud > %ud", + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + continue; + } + + cl->subrs->classify_func(cl, task->tokens, task); + } + } +} + +rspamd_stat_result_t +rspamd_stat_classify(struct rspamd_task *task, lua_State *L, guint stage, + GError **err) +{ + struct rspamd_stat_ctx *st_ctx; + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK; + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + if (st_ctx->classifiers->len == 0) { + task->processed_stages |= stage; + return ret; + } + + if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_PRE) { + /* Preprocess tokens */ + rspamd_stat_preprocess(st_ctx, task, FALSE, FALSE); + } + else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS) { + /* Process backends */ + rspamd_stat_backends_process(st_ctx, task); + } + else if (stage == RSPAMD_TASK_STAGE_CLASSIFIERS_POST) { + /* Process classifiers */ + rspamd_stat_classifiers_process(st_ctx, task); + } + + task->processed_stages |= stage; + + return ret; +} + +static gboolean +rspamd_stat_cache_check(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + rspamd_learn_t learn_res = RSPAMD_LEARN_OK; + struct rspamd_classifier *cl, *sel = NULL; + gpointer rt; + guint i; + + /* Check whether we have learned that file */ + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && 
(cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + sel = cl; + + if (sel->cache && sel->cachecf) { + rt = cl->cache->runtime(task, sel->cachecf, FALSE); + learn_res = cl->cache->check(task, spam, rt); + } + + if (learn_res == RSPAMD_LEARN_IGNORE) { + /* Do not learn twice */ + g_set_error(err, rspamd_stat_quark(), 404, "<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + spam ? "spam" : "ham"); + task->flags |= RSPAMD_TASK_FLAG_ALREADY_LEARNED; + + return FALSE; + } + else if (learn_res == RSPAMD_LEARN_UNLEARN) { + task->flags |= RSPAMD_TASK_FLAG_UNLEARN; + break; + } + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + + return TRUE; +} + +static gboolean +rspamd_stat_classifiers_learn(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + struct rspamd_classifier *cl, *sel = NULL; + guint i; + gboolean learned = FALSE, too_small = FALSE, too_large = FALSE; + + if ((task->flags & RSPAMD_TASK_FLAG_ALREADY_LEARNED) && err != NULL && + *err == NULL) { + /* Do not learn twice */ + g_set_error(err, rspamd_stat_quark(), 208, "<%s> has been already " + "learned as %s, ignore it", + MESSAGE_FIELD(task, message_id), + spam ? "spam" : "ham"); + + return FALSE; + } + + /* Check whether we have learned that file */ + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + sel = cl; + + /* Now check max and min tokens */ + if (cl->cfg->min_tokens > 0 && task->tokens->len < cl->cfg->min_tokens) { + msg_info_task( + "<%s> contains less tokens than required for %s classifier: " + "%ud < %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->min_tokens); + too_small = TRUE; + continue; + } + else if (cl->cfg->max_tokens > 0 && task->tokens->len > cl->cfg->max_tokens) { + msg_info_task( + "<%s> contains more tokens than allowed for %s classifier: " + "%ud > %ud", + MESSAGE_FIELD(task, message_id), + cl->cfg->name, + task->tokens->len, + cl->cfg->max_tokens); + too_large = TRUE; + continue; + } + + if (cl->subrs->learn_spam_func(cl, task->tokens, task, spam, + task->flags & RSPAMD_TASK_FLAG_UNLEARN, err)) { + learned = TRUE; + } + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + + if (!learned && err && *err == NULL) { + if (too_large) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains more tokens than allowed for %s classifier: " + "%d > %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->max_tokens); + } + else if (too_small) { + g_set_error(err, rspamd_stat_quark(), 204, + "<%s> contains less tokens than required for %s classifier: " + "%d < %d", + MESSAGE_FIELD(task, message_id), + sel->cfg->name, + task->tokens->len, + sel->cfg->min_tokens); + } + } + + return learned; +} + +static gboolean 
+rspamd_stat_backends_learn(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + struct rspamd_classifier *cl, *sel = NULL; + struct rspamd_statfile *st; + gpointer bk_run; + guint i, j; + gint id; + gboolean res = FALSE, backend_found = FALSE; + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + res = TRUE; + continue; + } + + sel = cl; + + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + st = g_ptr_array_index(st_ctx->statfiles, id); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + + g_assert(st != NULL); + + if (bk_run == NULL) { + /* XXX: must be error */ + if (task->result->passthrough_result) { + /* Passthrough email, cannot learn */ + g_set_error(err, rspamd_stat_quark(), 204, + "Cannot learn statistics when passthrough " + "result has been set; not classified"); + + res = FALSE; + goto end; + } + + msg_debug_task("no runtime for backend %s; classifier %s; symbol %s", + st->backend->name, cl->cfg->name, st->stcf->symbol); + continue; + } + + /* We set sel merely when we have runtime */ + backend_found = TRUE; + + if (!(task->flags & RSPAMD_TASK_FLAG_UNLEARN)) { + if (!!spam != !!st->stcf->is_spam) { + /* If we are not unlearning, then do not touch another class */ + continue; + } + } + + if (!st->backend->learn_tokens(task, task->tokens, id, bk_run)) { + g_set_error(err, rspamd_stat_quark(), 500, + "Cannot push " + "learned results to the backend"); + + res = FALSE; + goto end; + } + else { + if (!!spam == !!st->stcf->is_spam) { + st->backend->inc_learns(task, bk_run, st_ctx); + } + else if (task->flags & RSPAMD_TASK_FLAG_UNLEARN) { + st->backend->dec_learns(task, bk_run, st_ctx); + } + + res = TRUE; + } + } + } + +end: + + if (!res) { + if (err && *err) { + /* Error has been set already */ + return res; + } + + if (sel == NULL) { + if (classifier) { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find classifier " + "with name %s", + classifier); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "no classifiers defined"); + } + + return FALSE; + } + else if (!backend_found) { + g_set_error(err, rspamd_stat_quark(), 204, "all learn conditions " + "denied learning %s in %s", + spam ? "spam" : "ham", + classifier ? classifier : "default classifier"); + } + else { + g_set_error(err, rspamd_stat_quark(), 404, "cannot find statfile " + "backend to learn %s in %s", + spam ? "spam" : "ham", + classifier ? 
classifier : "default classifier"); + } + } + + return res; +} + +static gboolean +rspamd_stat_backends_post_learn(struct rspamd_stat_ctx *st_ctx, + struct rspamd_task *task, + const gchar *classifier, + gboolean spam, + GError **err) +{ + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer bk_run, cache_run; + guint i, j; + gint id; + gboolean res = TRUE; + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + /* Skip other classifiers if they are not needed */ + if (classifier != NULL && (cl->cfg->name == NULL || + g_ascii_strcasecmp(classifier, cl->cfg->name) != 0)) { + continue; + } + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + res = TRUE; + continue; + } + + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + st = g_ptr_array_index(st_ctx->statfiles, id); + bk_run = g_ptr_array_index(task->stat_runtimes, id); + + g_assert(st != NULL); + + if (bk_run == NULL) { + /* XXX: must be error */ + continue; + } + + if (!st->backend->finalize_learn(task, bk_run, st_ctx, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + + if (cl->cache) { + cache_run = cl->cache->runtime(task, cl->cachecf, TRUE); + cl->cache->learn(task, spam, cache_run); + } + } + + g_atomic_int_add(&task->worker->srv->stat->messages_learned, 1); + + return res; +} + +rspamd_stat_result_t +rspamd_stat_learn(struct rspamd_task *task, + gboolean spam, lua_State *L, const gchar *classifier, guint stage, + GError **err) +{ + struct rspamd_stat_ctx *st_ctx; + rspamd_stat_result_t ret = RSPAMD_STAT_PROCESS_OK; + + /* + * We assume now that a task has been already classified before + * coming to learn + */ + g_assert(RSPAMD_TASK_IS_CLASSIFIED(task)); + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + if (st_ctx->classifiers->len == 0) { + task->processed_stages |= stage; + return ret; + } + + if (stage == RSPAMD_TASK_STAGE_LEARN_PRE) { + /* Process classifiers */ + rspamd_stat_preprocess(st_ctx, task, TRUE, spam); + + if (!rspamd_stat_cache_check(st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN) { + /* Process classifiers */ + if (!rspamd_stat_classifiers_learn(st_ctx, task, classifier, + spam, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when learning classifiers;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + + /* Process backends */ + if (!rspamd_stat_backends_learn(st_ctx, task, classifier, spam, err)) { + if (err && *err == NULL) { + g_set_error(err, rspamd_stat_quark(), 500, + "Unknown statistics error, found when storing data on backend;" + " classifier: %s", + task->classifier); + } + return RSPAMD_STAT_PROCESS_ERROR; + } + } + else if (stage == RSPAMD_TASK_STAGE_LEARN_POST) { + if (!rspamd_stat_backends_post_learn(st_ctx, task, classifier, spam, err)) { + return RSPAMD_STAT_PROCESS_ERROR; + } + } + + task->processed_stages |= stage; + + return ret; +} + +static gboolean +rspamd_stat_has_classifier_symbols(struct rspamd_task *task, + struct rspamd_scan_result *mres, + struct rspamd_classifier *cl) +{ + guint i; + gint id; + struct rspamd_statfile *st; + struct rspamd_stat_ctx *st_ctx; + gboolean is_spam; + + if (mres == NULL) { + return FALSE; + } + + st_ctx = rspamd_stat_get_ctx(); + is_spam = !!(task->flags & RSPAMD_TASK_FLAG_LEARN_SPAM); + + for (i = 0; i < cl->statfiles_ids->len; 
i++) { + id = g_array_index(cl->statfiles_ids, gint, i); + st = g_ptr_array_index(st_ctx->statfiles, id); + + if (rspamd_task_find_symbol_result(task, st->stcf->symbol, NULL)) { + if (is_spam == !!st->stcf->is_spam) { + msg_debug_bayes("do not autolearn %s as symbol %s is already " + "added", + is_spam ? "spam" : "ham", st->stcf->symbol); + + return TRUE; + } + } + } + + return FALSE; +} + +gboolean +rspamd_stat_check_autolearn(struct rspamd_task *task) +{ + struct rspamd_stat_ctx *st_ctx; + struct rspamd_classifier *cl; + const ucl_object_t *obj, *elt1, *elt2; + struct rspamd_scan_result *mres = NULL; + struct rspamd_task **ptask; + lua_State *L; + guint i; + gint err_idx; + gboolean ret = FALSE; + gdouble ham_score, spam_score; + const gchar *lua_script, *lua_ret; + + g_assert(RSPAMD_TASK_IS_CLASSIFIED(task)); + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + L = task->cfg->lua_state; + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + ret = FALSE; + + if (cl->cfg->opts) { + obj = ucl_object_lookup(cl->cfg->opts, "autolearn"); + + if (ucl_object_type(obj) == UCL_BOOLEAN) { + /* Legacy true/false */ + if (ucl_object_toboolean(obj)) { + /* + * Default learning algorithm: + * + * - We learn spam if action is ACTION_REJECT + * - We learn ham if score is less than zero + */ + mres = task->result; + + if (mres) { + if (mres->score > rspamd_task_get_required_score(task, mres)) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + + ret = TRUE; + } + else if (mres->score < 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + } + } + } + else if (ucl_object_type(obj) == UCL_ARRAY && obj->len == 2) { + /* Legacy thresholds */ + /* + * We have an array of 2 elements, treat it as a + * ham_score, spam_score + */ + elt1 = ucl_array_find_index(obj, 0); + elt2 = ucl_array_find_index(obj, 1); + + if ((ucl_object_type(elt1) == UCL_FLOAT || + ucl_object_type(elt1) == UCL_INT) && + (ucl_object_type(elt2) == UCL_FLOAT || + ucl_object_type(elt2) == UCL_INT)) { + ham_score = ucl_object_todouble(elt1); + spam_score = ucl_object_todouble(elt2); + + if (ham_score > spam_score) { + gdouble t; + + t = ham_score; + ham_score = spam_score; + spam_score = t; + } + + mres = task->result; + + if (mres) { + if (mres->score >= spam_score) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + + ret = TRUE; + } + else if (mres->score <= ham_score) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + } + } + } + else if (ucl_object_type(obj) == UCL_STRING) { + /* Legacy script */ + lua_script = ucl_object_tostring(obj); + + if (luaL_dostring(L, lua_script) != 0) { + msg_err_task("cannot execute lua script for autolearn " + "extraction: %s", + lua_tostring(L, -1)); + } + else { + if (lua_type(L, -1) == LUA_TFUNCTION) { + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_pushvalue(L, -2); /* Function itself */ + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + + if (lua_pcall(L, 1, 1, err_idx) != 0) { + msg_err_task("call to autolearn script failed: " + "%s", + lua_tostring(L, -1)); + } + else { + lua_ret = lua_tostring(L, -1); + + /* We can have immediate results */ + if (lua_ret) { + if (strcmp(lua_ret, "ham") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + else if (strcmp(lua_ret, "spam") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + ret = TRUE; + } + } + } + + /* Result + error function + 
original function */ + lua_pop(L, 3); + } + else { + msg_err_task("lua script must return " + "function(task) and not %s", + lua_typename(L, lua_type( + L, -1))); + } + } + } + else if (ucl_object_type(obj) == UCL_OBJECT) { + /* Try to find autolearn callback */ + if (cl->autolearn_cbref == 0) { + /* We don't have preprocessed cb id, so try to get it */ + if (!rspamd_lua_require_function(L, "lua_bayes_learn", + "autolearn")) { + msg_err_task("cannot get autolearn library from " + "`lua_bayes_learn`"); + } + else { + cl->autolearn_cbref = luaL_ref(L, LUA_REGISTRYINDEX); + } + } + + if (cl->autolearn_cbref != -1) { + lua_pushcfunction(L, &rspamd_lua_traceback); + err_idx = lua_gettop(L); + lua_rawgeti(L, LUA_REGISTRYINDEX, cl->autolearn_cbref); + + ptask = lua_newuserdata(L, sizeof(struct rspamd_task *)); + *ptask = task; + rspamd_lua_setclass(L, "rspamd{task}", -1); + /* Push the whole object as well */ + ucl_object_push_lua(L, obj, true); + + if (lua_pcall(L, 2, 1, err_idx) != 0) { + msg_err_task("call to autolearn script failed: " + "%s", + lua_tostring(L, -1)); + } + else { + lua_ret = lua_tostring(L, -1); + + if (lua_ret) { + if (strcmp(lua_ret, "ham") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_HAM; + ret = TRUE; + } + else if (strcmp(lua_ret, "spam") == 0) { + task->flags |= RSPAMD_TASK_FLAG_LEARN_SPAM; + ret = TRUE; + } + } + } + + lua_settop(L, err_idx - 1); + } + } + + if (ret) { + /* Do not autolearn if we have this symbol already */ + if (rspamd_stat_has_classifier_symbols(task, mres, cl)) { + ret = FALSE; + task->flags &= ~(RSPAMD_TASK_FLAG_LEARN_HAM | + RSPAMD_TASK_FLAG_LEARN_SPAM); + } + else if (mres != NULL) { + if (task->flags & RSPAMD_TASK_FLAG_LEARN_HAM) { + msg_info_task("<%s>: autolearn ham for classifier " + "'%s' as message's " + "score is negative: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + else { + msg_info_task("<%s>: autolearn spam for classifier " + "'%s' as message's " + "action is reject, score: %.2f", + MESSAGE_FIELD(task, message_id), cl->cfg->name, + mres->score); + } + + task->classifier = cl->cfg->name; + break; + } + } + } + } + + return ret; +} + +/** + * Get the overall statistics for all statfile backends + * @param cfg configuration + * @param total_learns the total number of learns is stored here + * @return array of statistical information + */ +rspamd_stat_result_t +rspamd_stat_statistics(struct rspamd_task *task, + struct rspamd_config *cfg, + guint64 *total_learns, + ucl_object_t **target) +{ + struct rspamd_stat_ctx *st_ctx; + struct rspamd_classifier *cl; + struct rspamd_statfile *st; + gpointer backend_runtime; + ucl_object_t *res = NULL, *elt; + guint64 learns = 0; + guint i, j; + gint id; + + st_ctx = rspamd_stat_get_ctx(); + g_assert(st_ctx != NULL); + + res = ucl_object_typed_new(UCL_ARRAY); + + for (i = 0; i < st_ctx->classifiers->len; i++) { + cl = g_ptr_array_index(st_ctx->classifiers, i); + + if (cl->cfg->flags & RSPAMD_FLAG_CLASSIFIER_NO_BACKEND) { + continue; + } + + for (j = 0; j < cl->statfiles_ids->len; j++) { + id = g_array_index(cl->statfiles_ids, gint, j); + st = g_ptr_array_index(st_ctx->statfiles, id); + backend_runtime = st->backend->runtime(task, st->stcf, FALSE, + st->bkcf, id); + elt = st->backend->get_stat(backend_runtime, st->bkcf); + + if (elt && ucl_object_type(elt) == UCL_OBJECT) { + const ucl_object_t *rev = ucl_object_lookup(elt, "revision"); + + learns += ucl_object_toint(rev); + } + else { + learns += st->backend->total_learns(task, backend_runtime, + st->bkcf); + } + + if 
(elt != NULL) { + ucl_array_append(res, elt); + } + } + } + + if (total_learns != NULL) { + *total_learns = learns; + } + + if (target) { + *target = res; + } + else { + ucl_object_unref(res); + } + + return RSPAMD_STAT_PROCESS_OK; +}
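rspamd_stat_process_tokenize() ends by hashing every token value and storing a base32 signature truncated to 32 characters, i.e. 32 x 5 = 160 bits of the digest, in the RSPAMD_MEMPOOL_STAT_SIGNATURE pool variable. A rough standalone illustration of that truncation arithmetic follows; it uses FNV-1a and a plain big-endian base32 packing instead of rspamd_cryptobox and rspamd_encode_base32, so the output does not match a real signature, only the shape of the computation does.

#include <stdint.h>
#include <stdio.h>

/* zbase32-style alphabet; the bit packing below is a generic big-endian
 * base32 and is NOT byte-compatible with rspamd_encode_base32(). */
static const char b32_alphabet[] = "ybndrfg8ejkmcpqxot1uwisza345h769";

/* FNV-1a, used here only as a stand-in for the cryptobox hash. */
static uint64_t fnv1a64(const void *data, size_t len, uint64_t h)
{
    const unsigned char *p = data;

    for (size_t i = 0; i < len; i++) {
        h ^= p[i];
        h *= 0x100000001b3ULL;
    }

    return h;
}

static void base32_encode(const unsigned char *in, size_t inlen,
                          char *out, size_t outlen)
{
    uint32_t acc = 0;
    size_t bits = 0, o = 0;

    for (size_t i = 0; i < inlen && o + 1 < outlen; i++) {
        acc = (acc << 8) | in[i];
        bits += 8;

        while (bits >= 5 && o + 1 < outlen) {
            out[o++] = b32_alphabet[(acc >> (bits - 5)) & 0x1f];
            bits -= 5;
        }
    }

    out[o] = '\0';
}

int main(void)
{
    /* Pretend these are the 64-bit token values collected in task->tokens. */
    uint64_t tokens[] = {0x1122334455667788ULL, 0xdeadbeefcafef00dULL, 42};
    unsigned char digest[20]; /* 160 bits: all the signature keeps anyway */
    uint64_t h = 0xcbf29ce484222325ULL;

    for (size_t i = 0; i < sizeof(tokens) / sizeof(tokens[0]); i++) {
        h = fnv1a64(&tokens[i], sizeof(tokens[i]), h);
    }

    /* Stretch the toy 64-bit hash into a 20-byte buffer (illustrative only;
     * the real code has a full-width cryptobox digest here). */
    for (size_t i = 0; i < sizeof(digest); i++) {
        h = fnv1a64(&i, sizeof(i), h);
        digest[i] = (unsigned char) (h & 0xff);
    }

    /* 32 base32 characters x 5 bits = 160 bits, matching the comment in the
     * diff about truncating the encoded hash. */
    char sig[33];
    base32_encode(digest, sizeof(digest), sig, sizeof(sig));
    printf("stat signature (toy): %s\n", sig);

    return 0;
}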
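rspamd_stat_check_autolearn() accepts the legacy "autolearn = [ham_score, spam_score]" array: the two thresholds are swapped if written in the wrong order, then a message scoring at or above spam_score is queued for spam learning and one at or below ham_score for ham learning. A minimal sketch of just that threshold logic, with simplified types in place of the UCL objects and scan result:

#include <stdio.h>

/* Toy version of the legacy two-element autolearn option handling in
 * rspamd_stat_check_autolearn(); names and types are illustrative. */
enum autolearn_verdict { AUTOLEARN_NONE, AUTOLEARN_HAM, AUTOLEARN_SPAM };

static enum autolearn_verdict
check_autolearn_thresholds(double score, double ham_score, double spam_score)
{
    if (ham_score > spam_score) {
        /* Same normalisation as the original: always treat the pair as
         * (ham, spam) regardless of the order used in the config. */
        double t = ham_score;
        ham_score = spam_score;
        spam_score = t;
    }

    if (score >= spam_score) {
        return AUTOLEARN_SPAM;
    }
    else if (score <= ham_score) {
        return AUTOLEARN_HAM;
    }

    return AUTOLEARN_NONE;
}

int main(void)
{
    /* e.g. autolearn = [-2.5, 15.0]: 17.0 autolearns spam, -4.0 autolearns
     * ham, 3.0 learns nothing. */
    const double scores[] = {17.0, -4.0, 3.0};
    const char *names[] = {"none", "ham", "spam"};

    for (int i = 0; i < 3; i++) {
        enum autolearn_verdict v =
            check_autolearn_thresholds(scores[i], -2.5, 15.0);
        printf("score %.1f -> autolearn %s\n", scores[i], names[v]);
    }

    return 0;
}

Note that in the diff the verdict is additionally suppressed by rspamd_stat_has_classifier_symbols(), so a message whose statfile symbol is already present for the same class is never autolearned again.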