author    Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-10 21:30:40 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-10 21:30:40 +0000
commit    133a45c109da5310add55824db21af5239951f93
tree      ba6ac4c0a950a0dda56451944315d66409923918  /src/plugins/lua/neural.lua
parent    Initial commit.
Adding upstream version 3.8.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/plugins/lua/neural.lua')
-rw-r--r-- | src/plugins/lua/neural.lua | 1000
1 file changed, 1000 insertions, 0 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua new file mode 100644 index 0000000..f3b26f1 --- /dev/null +++ b/src/plugins/lua/neural.lua @@ -0,0 +1,1000 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + + +if confighelp then + return +end + +local fun = require "fun" +local lua_redis = require "lua_redis" +local lua_util = require "lua_util" +local lua_verdict = require "lua_verdict" +local neural_common = require "plugins/neural" +local rspamd_kann = require "rspamd_kann" +local rspamd_logger = require "rspamd_logger" +local rspamd_tensor = require "rspamd_tensor" +local rspamd_text = require "rspamd_text" +local rspamd_util = require "rspamd_util" +local ts = require("tableshape").types + +local N = "neural" + +local settings = neural_common.settings + +local redis_profile_schema = ts.shape { + digest = ts.string, + symbols = ts.array_of(ts.string), + version = ts.number, + redis_key = ts.string, + distance = ts.number:is_optional(), +} + +local has_blas = rspamd_tensor.has_blas() +local text_cookie = rspamd_text.cookie + +-- Creates and stores ANN profile in Redis +local function new_ann_profile(task, rule, set, version) + local ann_key = neural_common.new_ann_key(rule, set, version, settings) + + local profile = { + symbols = set.symbols, + redis_key = ann_key, + version = version, + digest = set.digest, + distance = 0 -- Since we are using our own profile + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + + local function add_cb(err, _) + if err then + rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s', + rule.prefix, set.name, profile.redis_key, err) + else + rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s', + rule.prefix, set.name, profile.redis_key) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, -- is write + add_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + ) + + return profile +end + + +-- ANN filter function, used to insert scores based on the existing symbols +local function ann_scores_filter(task) + + for _, rule in pairs(settings.rules) do + local sid = task:get_settings_id() or -1 + local ann + local profile + + local set = neural_common.get_rule_settings(task, rule) + if set then + if set.ann then + ann = set.ann.ann + profile = set.ann + else + lua_util.debugm(N, task, 'no ann loaded for %s:%s', + rule.prefix, set.name) + end + else + lua_util.debugm(N, task, 'no ann defined in %s for settings id %s', + rule.prefix, sid) + end + + if ann then + local vec = neural_common.result_to_vector(task, profile) + + local score + local out = ann:apply1(vec, set.ann.pca) + score = out[1] + + local symscore = string.format('%.3f', score) + task:cache_set(rule.prefix .. 
'_neural_score', score) + lua_util.debugm(N, task, '%s:%s:%s ann score: %s', + rule.prefix, set.name, set.ann.version, symscore) + + if score > 0 then + local result = score + + -- If spam_score_threshold is defined, override all other thresholds. + local spam_threshold = 0 + if rule.spam_score_threshold then + spam_threshold = rule.spam_score_threshold + elseif rule.roc_enabled and not set.ann.roc_thresholds then + spam_threshold = set.ann.roc_thresholds[1] + end + + if result >= spam_threshold then + if rule.flat_threshold_curve then + task:insert_result(rule.symbol_spam, 1.0, symscore) + else + task:insert_result(rule.symbol_spam, result, symscore) + end + else + lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)', + rule.prefix, set.name, set.ann.version, symscore, + spam_threshold) + end + else + local result = -(score) + + -- If ham_score_threshold is defined, override all other thresholds. + local ham_threshold = 0 + if rule.ham_score_threshold then + ham_threshold = rule.ham_score_threshold + elseif rule.roc_enabled and not set.ann.roc_thresholds then + ham_threshold = set.ann.roc_thresholds[2] + end + + if result >= ham_threshold then + if rule.flat_threshold_curve then + task:insert_result(rule.symbol_ham, 1.0, symscore) + else + task:insert_result(rule.symbol_ham, result, symscore) + end + else + lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)', + rule.prefix, set.name, set.ann.version, result, + ham_threshold) + end + end + end + end +end + +local function ann_push_task_result(rule, task, verdict, score, set) + local train_opts = rule.train + local learn_spam, learn_ham + local skip_reason = 'unknown' + + if not train_opts.store_pool_only and train_opts.autotrain then + if train_opts.spam_score then + learn_spam = score >= train_opts.spam_score + + if not learn_spam then + skip_reason = string.format('score < spam_score: %f < %f', + score, train_opts.spam_score) + end + else + learn_spam = verdict == 'spam' or verdict == 'junk' + + if not learn_spam then + skip_reason = string.format('verdict: %s', + verdict) + end + end + + if train_opts.ham_score then + learn_ham = score <= train_opts.ham_score + if not learn_ham then + skip_reason = string.format('score > ham_score: %f > %f', + score, train_opts.ham_score) + end + else + learn_ham = verdict == 'ham' + + if not learn_ham then + skip_reason = string.format('verdict: %s', + verdict) + end + end + else + -- Train by request header + local hdr = task:get_request_header('ANN-Train') + + if hdr then + if hdr:lower() == 'spam' then + learn_spam = true + elseif hdr:lower() == 'ham' then + learn_ham = true + else + skip_reason = 'no explicit header' + end + elseif train_opts.store_pool_only then + local ucl = require "ucl" + learn_ham = false + learn_spam = false + + -- Explicitly store tokens in cache + local vec = neural_common.result_to_vector(task, set) + task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack')) + task:cache_set(rule.prefix .. 
'_neural_profile_digest', set.digest) + skip_reason = 'store_pool_only has been set' + end + end + + if learn_spam or learn_ham then + local learn_type + if learn_spam then + learn_type = 'spam' + else + learn_type = 'ham' + end + + local function vectors_len_cb(err, data) + if not err and type(data) == 'table' then + local nspam, nham = data[1], data[2] + + if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then + local vec = neural_common.result_to_vector(task, set) + + local str = rspamd_util.zstd_compress(table.concat(vec, ';')) + local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set' + + local function learn_vec_cb(redis_err) + if redis_err then + rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s', + rule.prefix, set.name, redis_err) + else + lua_util.debugm(N, task, + "add train data for ANN rule " .. + "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed", + rule.prefix, set.name, learn_type, #vec, target_key, #str) + end + end + + lua_redis.redis_make_request(task, + rule.redis, + nil, + true, -- is write + learn_vec_cb, --callback + 'SADD', -- command + { target_key, str } -- arguments + ) + else + lua_util.debugm(N, task, + "do not add %s train data for ANN rule " .. + "%s:%s", + learn_type, rule.prefix, set.name) + end + else + if err then + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s', + rule.prefix, set.name, err) + elseif type(data) == 'string' then + -- nil return value + rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s", + learn_type, rule.prefix, set.name, set.ann.redis_key, data) + else + rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table' .. + 'please remove this key from Redis manually if you perform upgrade from the previous version', + rule.prefix, set.name, set.ann.redis_key, type(data)) + end + end + end + + -- Check if we can learn + if set.can_store_vectors then + if not set.ann then + -- Need to create or load a profile corresponding to the current configuration + set.ann = new_ann_profile(task, rule, set, 0) + lua_util.debugm(N, task, + 'requested new profile for %s, set.ann is missing', + set.name) + end + + lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len, + { task = task, is_write = false }, + vectors_len_cb, + { + set.ann.redis_key, + }) + else + lua_util.debugm(N, task, + 'do not push data: train condition not satisfied; reason: not checked existing ANNs') + end + else + lua_util.debugm(N, task, + 'do not push data to key %s: train condition not satisfied; reason: %s', + (set.ann or {}).redis_key, + skip_reason) + end +end + +--- Offline training logic + +-- Utility to extract and split saved training vectors to a table of tables +local function process_training_vectors(data) + return fun.totable(fun.map(function(tok) + local _, str = rspamd_util.zstd_decompress(tok) + return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';'))) + end, data)) +end + +-- This function does the following: +-- * Tries to lock ANN +-- * Loads spam and ham vectors +-- * Spawn learning process +local function do_train_ann(worker, ev_base, rule, set, ann_key) + local spam_elts = {} + local ham_elts = {} + + local function redis_ham_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s', + ann_key, err) + -- Unlock on error + lua_redis.redis_make_request_taskless(ev_base, + 
rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } + ) + else + -- Decompress and convert to numbers each training vector + ham_elts = process_training_vectors(data) + neural_common.spawn_train({ worker = worker, ev_base = ev_base, + rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts, + spam_vec = spam_elts }) + end + end + + -- Spam vectors received + local function redis_spam_cb(err, data) + if err or type(data) ~= 'table' then + rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s', + ann_key, err) + -- Unlock ANN on error + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + neural_common.gen_unlock_cb(rule, set, ann_key), --callback + 'HDEL', -- command + { ann_key, 'lock' } + ) + else + -- Decompress and convert to numbers each training vector + spam_elts = process_training_vectors(data) + -- Now get ham vectors... + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_ham_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_ham_set' } + ) + end + end + + local function redis_lock_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s', + ann_key, err) + elseif type(data) == 'number' and data == 1 then + -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_spam_cb, --callback + 'SMEMBERS', -- command + { ann_key .. '_spam_set' } + ) + + rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning', + rule.prefix, set.name, ann_key) + else + local lock_tm = tonumber(data[1]) + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' .. + 'locked by another host %s at %s', rule.prefix, set.name, ann_key, + data[2], os.date('%c', lock_tm)) + end + end + + -- Check if we are already learning this network + if set.learning_spawned then + rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN', + ann_key) + return + end + + -- Call Redis script that tries to acquire a lock + -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when + -- ANN is locked by another host (or a process, meh) + lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock, + { ev_base = ev_base, is_write = true }, + redis_lock_cb, + { + ann_key, + tostring(os.time()), + tostring(math.max(10.0, rule.watch_interval * 2)), + rspamd_util.get_hostname() + }) +end + +-- This function loads new ann from Redis +-- This is based on `profile` attribute. 
+-- ANN is loaded from `profile.redis_key` +-- Rank of `profile` key is also increased, unfortunately, it means that we need to +-- serialize profile one more time and set its rank to the current time +-- set.ann fields are set according to Redis data received +local function load_new_ann(rule, ev_base, set, profile, min_diff) + local ann_key = profile.redis_key + + local function data_cb(err, data) + if err then + rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s', + ann_key, err) + else + if type(data) == 'table' then + if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then + local _err, ann_data = rspamd_util.zstd_decompress(data[1]) + local ann + + if _err or not ann_data then + rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s', + rule.prefix .. ':' .. set.name, ann_key, _err) + return + else + ann = rspamd_kann.load(ann_data) + + if ann then + set.ann = { + digest = profile.digest, + version = profile.version, + symbols = profile.symbols, + distance = min_diff, + redis_key = profile.redis_key + } + + local ucl = require "ucl" + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + set.ann.ann = ann -- To avoid serialization + + local function rank_cb(_, _) + -- TODO: maybe add some logging + end + -- Also update rank for the loaded ANN to avoid removal + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + rank_cb, --callback + 'ZADD', -- command + { set.prefix, tostring(rspamd_util.get_time()), profile_serialized } + ) + rspamd_logger.infox(rspamd_config, + 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[1], profile.version) + else + rspamd_logger.errx(rspamd_config, + 'cannot unpack/deserialise ANN for %s:%s from Redis key %s', + rule.prefix, set.name, ann_key) + end + end + else + lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s', + rule.prefix, set.name, ann_key) + end + + if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then + if rule.roc_enabled then + local ucl = require "ucl" + local parser = ucl.parser() + local ok, parse_err = parser:parse_text(data[2]) + assert(ok, parse_err) + local roc_thresholds = parser:get_object() + set.ann.roc_thresholds = roc_thresholds + rspamd_logger.infox(rspamd_config, + 'loaded ROC thresholds for %s:%s; version=%s', + rule.prefix, set.name, profile.version) + rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds) + end + end + + if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then + -- PCA table + local _err, pca_data = rspamd_util.zstd_decompress(data[3]) + if pca_data then + if rule.max_inputs then + -- We can use PCA + set.ann.pca = rspamd_tensor.load(pca_data) + rspamd_logger.infox(rspamd_config, + 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s', + rule.prefix, set.name, ann_key, #data[3], profile.version) + else + -- no need in pca, why is it there? 
+ rspamd_logger.warnx(rspamd_config, + 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined', + rule.prefix, set.name, ann_key) + end + else + -- pca can be missing merely if we have no max_inputs + if rule.max_inputs then + rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s', + rule.prefix, set.name, ann_key, _err) + set.ann.ann = nil + else + -- It is okay + set.ann.pca = nil + end + end + end + + else + lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s', + rule.prefix, set.name, ann_key) + end + end + end + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + data_cb, --callback + 'HMGET', -- command + { ann_key, 'ann', 'roc_thresholds', 'pca' }, -- arguments + { opaque_data = true } + ) +end + +-- Used to check an element in Redis serialized as JSON +-- for some specific rule + some specific setting +-- This function tries to load more fresh or more specific ANNs in lieu of +-- the existing ones. +-- Use this function to load ANNs as `callback` parameter for `check_anns` function +local function process_existing_ann(_, ev_base, rule, set, profiles) + local my_symbols = set.symbols + local min_diff = math.huge + local sel_elt + + for _, elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + -- Check distance + if dist < #my_symbols * .3 then + if dist < min_diff then + min_diff = dist + sel_elt = elt + end + end + end + end + + if sel_elt then + -- We can load element from ANN + if set.ann then + -- We have an existing ANN, probably the same... + if set.ann.digest == sel_elt.digest then + -- Same ANN, check version + if set.ann.version < sel_elt.version then + -- Load new ann + rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' .. + 'our version = %s, remote version = %s', + rule.prefix .. ':' .. set.name, + set.ann.version, + sel_elt.version) + end + else + -- We have some different ANN, so we need to compare distance + if set.ann.distance > min_diff then + -- Load more specific ANN + rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + else + lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' .. + 'our distance = %s, remote distance = %s', + rule.prefix .. ':' .. set.name, + set.ann.distance, + min_diff) + end + end + else + -- We have no ANN, load new one + load_new_ann(rule, ev_base, set, sel_elt, min_diff) + end + end +end + + +-- This function checks all profiles and selects if we can train our +-- ANN. By our we mean that it has exactly the same symbols in profile. 
+-- Use this function to train ANN as `callback` parameter for `check_anns` function +local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles) + local my_symbols = set.symbols + local sel_elt + local lens = { + spam = 0, + ham = 0, + } + + for _, elt in fun.iter(profiles) do + if elt and elt.symbols then + local dist = lua_util.distance_sorted(elt.symbols, my_symbols) + -- Check distance + if dist == 0 then + sel_elt = elt + break + end + end + end + + if sel_elt then + -- We have our ANN and that's train vectors, check if we can learn + local ann_key = sel_elt.redis_key + + lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained", + ann_key) + + -- Create continuation closure + local redis_len_cb_gen = function(cont_cb, what, is_final) + return function(err, data) + if err then + rspamd_logger.errx(rspamd_config, + 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err) + elseif data and type(data) == 'number' or type(data) == 'string' then + local ntrains = tonumber(data) or 0 + lens[what] = ntrains + if is_final then + -- Ensure that we have the following: + -- one class has reached max_trains + -- other class(es) are at least as full as classes_bias + -- e.g. if classes_bias = 0.25 and we have 10 max_trains then + -- one class must have 10 or more trains whilst another should have + -- at least (10 * (1 - 0.25)) = 8 trains + + local max_len = math.max(lua_util.unpack(lua_util.values(lens))) + local min_len = math.min(lua_util.unpack(lua_util.values(lens))) + + if rule.train.learn_type == 'balanced' then + local len_bias_check_pred = function(_, l) + return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias) + end + if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then + rspamd_logger.debugm(N, rspamd_config, + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) + end + else + -- Probabilistic mode, just ensure that at least one vector is okay + if min_len > 0 and max_len >= rule.train.max_trains then + rspamd_logger.debugm(N, rspamd_config, + 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors', + ann_key, lens, rule.train.max_trains, what) + cont_cb() + else + rspamd_logger.debugm(N, rspamd_config, + 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)', + ann_key, what, lens, rule.train.max_trains) + end + end + + else + rspamd_logger.debugm(N, rspamd_config, + 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors', + what, ann_key, ntrains, rule.train.max_trains) + cont_cb() + end + end + end + + end + + local function initiate_train() + rspamd_logger.infox(rspamd_config, + 'need to learn ANN %s after %s required learn vectors', + ann_key, lens) + do_train_ann(worker, ev_base, rule, set, ann_key) + end + + -- Spam vector is OK, check ham vector length + local function check_ham_len() + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(initiate_train, 'ham', true), --callback + 'SCARD', -- command + { ann_key .. 
'_ham_set' } + ) + end + + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + false, -- is write + redis_len_cb_gen(check_ham_len, 'spam', false), --callback + 'SCARD', -- command + { ann_key .. '_spam_set' } + ) + end +end + +-- Used to deserialise ANN element from a list +local function load_ann_profile(element) + local ucl = require "ucl" + + local parser = ucl.parser() + local res, ucl_err = parser:parse_string(element) + if not res then + rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s', + ucl_err) + return nil + else + local profile = parser:get_object() + local checked, schema_err = redis_profile_schema:transform(profile) + if not checked then + rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err) + + return nil + end + return checked + end +end + +-- Function to check or load ANNs from Redis +local function check_anns(worker, cfg, ev_base, rule, process_callback, what) + for _, set in pairs(rule.settings) do + local function members_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s', + err) + set.can_store_vectors = true + elseif type(data) == 'table' then + lua_util.debugm(N, cfg, '%s: process element %s:%s', + what, rule.prefix, set.name) + process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data)) + set.can_store_vectors = true + end + end + + if type(set) == 'table' then + -- Extract all profiles for some specific settings id + -- Get the last `max_profiles` recently used + -- Select the most appropriate to our profile but it should not differ by more + -- than 30% of symbols + lua_redis.redis_make_request_taskless(ev_base, + cfg, + rule.redis, + nil, + false, -- is write + members_cb, --callback + 'ZREVRANGE', -- command + { set.prefix, '0', tostring(settings.max_profiles) } -- arguments + ) + end + end -- Cycle over all settings + + return rule.watch_interval +end + +-- Function to clean up old ANNs +local function cleanup_anns(rule, cfg, ev_base) + for _, set in pairs(rule.settings) do + local function invalidate_cb(err, data) + if err then + rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s', + err) + elseif type(data) == 'table' then + for _, expired in ipairs(data) do + local profile = load_ann_profile(expired) + rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s', + rule.prefix .. ':' .. 
set.name, + profile.redis_key, + profile.version) + end + end + end + + if type(set) == 'table' then + lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate, + { ev_base = ev_base, is_write = true }, + invalidate_cb, + { set.prefix, tostring(settings.max_profiles) }) + end + end +end + +local function ann_push_vector(task) + if task:has_flag('skip') then + lua_util.debugm(N, task, 'do not push data for skipped task') + return + end + if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then + lua_util.debugm(N, task, 'do not push data for manual scan') + return + end + + local verdict, score = lua_verdict.get_specific_verdict(N, task) + + if verdict == 'passthrough' then + lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)', + verdict, score) + + return + end + + if score ~= score then + lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)', + verdict) + + return + end + + for _, rule in pairs(settings.rules) do + local set = neural_common.get_rule_settings(task, rule) + + if set then + ann_push_task_result(rule, task, verdict, score, set) + else + lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix) + end + + end +end + + +-- Initialization part +if not (neural_common.module_config and type(neural_common.module_config) == 'table') + or not neural_common.redis_params then + rspamd_logger.infox(rspamd_config, 'Module is unconfigured') + lua_util.disable_module(N, "redis") + return +end + +local rules = neural_common.module_config['rules'] + +if not rules then + -- Use legacy configuration + rules = {} + rules['default'] = neural_common.module_config +end + +local id = rspamd_config:register_symbol({ + name = 'NEURAL_CHECK', + type = 'postfilter,callback', + flags = 'nostat', + priority = lua_util.symbols_priorities.medium, + callback = ann_scores_filter +}) + +neural_common.settings.rules = {} -- Reset unless validated further in the cycle + +if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then + -- Transform to hash for simplicity + settings.blacklisted_symbols = lua_util.list_to_hash(settings.blacklisted_symbols) +end + +-- Check all rules +for k, r in pairs(rules) do + local rule_elt = lua_util.override_defaults(neural_common.default_options, r) + rule_elt['redis'] = neural_common.redis_params + rule_elt['anns'] = {} -- Store ANNs here + + if not rule_elt.prefix then + rule_elt.prefix = k + end + if not rule_elt.name then + rule_elt.name = k + end + if rule_elt.train.max_train and not rule_elt.train.max_trains then + rule_elt.train.max_trains = rule_elt.train.max_train + end + + if not rule_elt.profile then + rule_elt.profile = {} + end + + if rule_elt.max_inputs and not has_blas then + rspamd_logger.errx('cannot set max inputs to %s as BLAS is not compiled in', + rule_elt.name, rule_elt.max_inputs) + rule_elt.max_inputs = nil + end + + rspamd_logger.infox(rspamd_config, "register ann rule %s", k) + settings.rules[k] = rule_elt + rspamd_config:set_metric_symbol({ + name = rule_elt.symbol_spam, + score = 0.0, + description = 'Neural network SPAM', + group = 'neural' + }) + rspamd_config:register_symbol({ + name = rule_elt.symbol_spam, + type = 'virtual', + flags = 'nostat', + parent = id + }) + + rspamd_config:set_metric_symbol({ + name = rule_elt.symbol_ham, + score = -0.0, + description = 'Neural network HAM', + group = 'neural' + }) + rspamd_config:register_symbol({ + name = rule_elt.symbol_ham, + type = 'virtual', + flags = 'nostat', + parent = id + }) +end + 
+rspamd_config:register_symbol({ + name = 'NEURAL_LEARN', + type = 'idempotent,callback', + flags = 'nostat,explicit_disable,ignore_passthrough', + callback = ann_push_vector +}) + +-- We also need to deal with settings +rspamd_config:add_post_init(neural_common.process_rules_settings) + +-- Add training scripts +for _, rule in pairs(settings.rules) do + neural_common.load_scripts(rule.redis) + -- This function will check ANNs in Redis when a worker is loaded + rspamd_config:add_on_load(function(cfg, ev_base, worker) + if worker:is_scanner() then + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + return check_anns(worker, cfg, ev_base, rule, process_existing_ann, + 'try_load_ann') + end) + end + + if worker:is_primary_controller() then + -- We also want to train neural nets when they have enough data + rspamd_config:add_periodic(ev_base, 0.0, + function(_, _) + -- Clean old ANNs + cleanup_anns(rule, cfg, ev_base) + return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann, + 'try_train_ann') + end) + end + end) +end |
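
For reference, a minimal standalone sketch of the train-vector round trip used by ann_push_task_result() and process_training_vectors() in the file above. It is not part of the committed file; it assumes rspamd's Lua environment with the same modules the plugin already requires, and the vector values are made up. On the store side a vector is joined with ';', zstd-compressed and SADD'ed to the per-class set ('<profile redis_key>_<spam|ham>_set'); on the training side each set member is decompressed and split back into numbers:

local fun = require "fun"
local lua_util = require "lua_util"
local rspamd_util = require "rspamd_util"

-- Store side (ann_push_task_result): join with ';' and compress before SADD
local vec = { 0.0, 1.0, 0.5, 0.25 } -- toy symbol vector (hypothetical values)
local compressed = rspamd_util.zstd_compress(table.concat(vec, ';'))

-- Training side (process_training_vectors): decompress each member and split back to numbers
local _, raw = rspamd_util.zstd_decompress(compressed)
local restored = fun.totable(fun.map(tonumber, lua_util.str_split(tostring(raw), ';')))
assert(#restored == #vec)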