author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-10 21:30:40 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-10 21:30:40 +0000
commit     133a45c109da5310add55824db21af5239951f93 (patch)
tree       ba6ac4c0a950a0dda56451944315d66409923918  /src/plugins/lua/neural.lua
parent     Initial commit. (diff)
download   rspamd-133a45c109da5310add55824db21af5239951f93.tar.xz
           rspamd-133a45c109da5310add55824db21af5239951f93.zip

Adding upstream version 3.8.1. (upstream/3.8.1, upstream)

Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>

Diffstat (limited to 'src/plugins/lua/neural.lua')
 -rw-r--r--  src/plugins/lua/neural.lua  1000
 1 file changed, 1000 insertions, 0 deletions
diff --git a/src/plugins/lua/neural.lua b/src/plugins/lua/neural.lua
new file mode 100644
index 0000000..f3b26f1
--- /dev/null
+++ b/src/plugins/lua/neural.lua
@@ -0,0 +1,1000 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+
+if confighelp then
+ return
+end
+
+local fun = require "fun"
+local lua_redis = require "lua_redis"
+local lua_util = require "lua_util"
+local lua_verdict = require "lua_verdict"
+local neural_common = require "plugins/neural"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_tensor = require "rspamd_tensor"
+local rspamd_text = require "rspamd_text"
+local rspamd_util = require "rspamd_util"
+local ts = require("tableshape").types
+
+local N = "neural"
+
+local settings = neural_common.settings
+
+local redis_profile_schema = ts.shape {
+ digest = ts.string,
+ symbols = ts.array_of(ts.string),
+ version = ts.number,
+ redis_key = ts.string,
+ distance = ts.number:is_optional(),
+}
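+
+-- Illustrative only: a made-up profile that satisfies the schema above.
+-- Real profiles are produced by new_ann_profile() below and stored in Redis
+-- as compact JSON; all values here are placeholders.
+local example_profile = {
+  digest = 'deadbeef', -- digest of the symbols set (placeholder value)
+  symbols = { 'BAYES_SPAM', 'DKIM_ALLOW' },
+  version = 3,
+  redis_key = 'rn_example_key', -- real keys come from neural_common.new_ann_key()
+  -- `distance` is optional and omitted here
+}
+assert(redis_profile_schema:transform(example_profile))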
+
+local has_blas = rspamd_tensor.has_blas()
+local text_cookie = rspamd_text.cookie
+
+-- Creates and stores ANN profile in Redis
+local function new_ann_profile(task, rule, set, version)
+ local ann_key = neural_common.new_ann_key(rule, set, version, settings)
+
+ local profile = {
+ symbols = set.symbols,
+ redis_key = ann_key,
+ version = version,
+ digest = set.digest,
+ distance = 0 -- Since we are using our own profile
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+
+ local function add_cb(err, _)
+ if err then
+ rspamd_logger.errx(task, 'cannot store ANN profile for %s:%s at %s : %s',
+ rule.prefix, set.name, profile.redis_key, err)
+ else
+ rspamd_logger.infox(task, 'created new ANN profile for %s:%s, data stored at prefix %s',
+ rule.prefix, set.name, profile.redis_key)
+ end
+ end
+
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ true, -- is write
+ add_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ )
+
+ return profile
+end
+
+
+-- ANN filter function, used to insert scores based on the existing symbols
+local function ann_scores_filter(task)
+
+ for _, rule in pairs(settings.rules) do
+ local sid = task:get_settings_id() or -1
+ local ann
+ local profile
+
+ local set = neural_common.get_rule_settings(task, rule)
+ if set then
+ if set.ann then
+ ann = set.ann.ann
+ profile = set.ann
+ else
+ lua_util.debugm(N, task, 'no ann loaded for %s:%s',
+ rule.prefix, set.name)
+ end
+ else
+ lua_util.debugm(N, task, 'no ann defined in %s for settings id %s',
+ rule.prefix, sid)
+ end
+
+ if ann then
+ local vec = neural_common.result_to_vector(task, profile)
+
+ local score
+ local out = ann:apply1(vec, set.ann.pca)
+ score = out[1]
+
+ local symscore = string.format('%.3f', score)
+ task:cache_set(rule.prefix .. '_neural_score', score)
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s',
+ rule.prefix, set.name, set.ann.version, symscore)
+
+ if score > 0 then
+ local result = score
+
+ -- If spam_score_threshold is defined, override all other thresholds.
+ local spam_threshold = 0
+ if rule.spam_score_threshold then
+ spam_threshold = rule.spam_score_threshold
+        elseif rule.roc_enabled and set.ann.roc_thresholds then
+ spam_threshold = set.ann.roc_thresholds[1]
+ end
+
+ if result >= spam_threshold then
+ if rule.flat_threshold_curve then
+ task:insert_result(rule.symbol_spam, 1.0, symscore)
+ else
+ task:insert_result(rule.symbol_spam, result, symscore)
+ end
+ else
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (spam threshold)',
+ rule.prefix, set.name, set.ann.version, symscore,
+ spam_threshold)
+ end
+ else
+ local result = -(score)
+
+ -- If ham_score_threshold is defined, override all other thresholds.
+ local ham_threshold = 0
+ if rule.ham_score_threshold then
+ ham_threshold = rule.ham_score_threshold
+        elseif rule.roc_enabled and set.ann.roc_thresholds then
+ ham_threshold = set.ann.roc_thresholds[2]
+ end
+
+ if result >= ham_threshold then
+ if rule.flat_threshold_curve then
+ task:insert_result(rule.symbol_ham, 1.0, symscore)
+ else
+ task:insert_result(rule.symbol_ham, result, symscore)
+ end
+ else
+ lua_util.debugm(N, task, '%s:%s:%s ann score: %s < %s (ham threshold)',
+ rule.prefix, set.name, set.ann.version, result,
+ ham_threshold)
+ end
+ end
+ end
+ end
+end
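+
+-- Minimal sketch (not used by the filter above) of how a positive ANN score is
+-- turned into a symbol weight in the spam branch: below the threshold nothing
+-- is inserted, otherwise either a flat 1.0 or the raw score is used. The
+-- threshold values in the asserts are made up.
+local function example_spam_weight(score, spam_threshold, flat_curve)
+  if score < spam_threshold then
+    return nil -- below threshold: no symbol is inserted
+  end
+  return flat_curve and 1.0 or score
+end
+assert(example_spam_weight(0.85, 0.5, false) == 0.85)
+assert(example_spam_weight(0.85, 0.5, true) == 1.0)
+assert(example_spam_weight(0.30, 0.5, false) == nil)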
+
+local function ann_push_task_result(rule, task, verdict, score, set)
+ local train_opts = rule.train
+ local learn_spam, learn_ham
+ local skip_reason = 'unknown'
+
+ if not train_opts.store_pool_only and train_opts.autotrain then
+ if train_opts.spam_score then
+ learn_spam = score >= train_opts.spam_score
+
+ if not learn_spam then
+ skip_reason = string.format('score < spam_score: %f < %f',
+ score, train_opts.spam_score)
+ end
+ else
+ learn_spam = verdict == 'spam' or verdict == 'junk'
+
+ if not learn_spam then
+ skip_reason = string.format('verdict: %s',
+ verdict)
+ end
+ end
+
+ if train_opts.ham_score then
+ learn_ham = score <= train_opts.ham_score
+ if not learn_ham then
+ skip_reason = string.format('score > ham_score: %f > %f',
+ score, train_opts.ham_score)
+ end
+ else
+ learn_ham = verdict == 'ham'
+
+ if not learn_ham then
+ skip_reason = string.format('verdict: %s',
+ verdict)
+ end
+ end
+ else
+ -- Train by request header
+ local hdr = task:get_request_header('ANN-Train')
+
+ if hdr then
+ if hdr:lower() == 'spam' then
+ learn_spam = true
+ elseif hdr:lower() == 'ham' then
+ learn_ham = true
+ else
+ skip_reason = 'no explicit header'
+ end
+ elseif train_opts.store_pool_only then
+ local ucl = require "ucl"
+ learn_ham = false
+ learn_spam = false
+
+ -- Explicitly store tokens in cache
+ local vec = neural_common.result_to_vector(task, set)
+ task:cache_set(rule.prefix .. '_neural_vec_mpack', ucl.to_format(vec, 'msgpack'))
+ task:cache_set(rule.prefix .. '_neural_profile_digest', set.digest)
+ skip_reason = 'store_pool_only has been set'
+ end
+ end
+
+ if learn_spam or learn_ham then
+ local learn_type
+ if learn_spam then
+ learn_type = 'spam'
+ else
+ learn_type = 'ham'
+ end
+
+ local function vectors_len_cb(err, data)
+ if not err and type(data) == 'table' then
+ local nspam, nham = data[1], data[2]
+
+ if neural_common.can_push_train_vector(rule, task, learn_type, nspam, nham) then
+ local vec = neural_common.result_to_vector(task, set)
+
+ local str = rspamd_util.zstd_compress(table.concat(vec, ';'))
+ local target_key = set.ann.redis_key .. '_' .. learn_type .. '_set'
+
+ local function learn_vec_cb(redis_err)
+ if redis_err then
+ rspamd_logger.errx(task, 'cannot store train vector for %s:%s: %s',
+ rule.prefix, set.name, redis_err)
+ else
+ lua_util.debugm(N, task,
+ "add train data for ANN rule " ..
+ "%s:%s, save %s vector of %s elts in %s key; %s bytes compressed",
+ rule.prefix, set.name, learn_type, #vec, target_key, #str)
+ end
+ end
+
+ lua_redis.redis_make_request(task,
+ rule.redis,
+ nil,
+ true, -- is write
+ learn_vec_cb, --callback
+ 'SADD', -- command
+ { target_key, str } -- arguments
+ )
+ else
+ lua_util.debugm(N, task,
+ "do not add %s train data for ANN rule " ..
+ "%s:%s",
+ learn_type, rule.prefix, set.name)
+ end
+ else
+ if err then
+ rspamd_logger.errx(task, 'cannot check if we can train %s:%s : %s',
+ rule.prefix, set.name, err)
+ elseif type(data) == 'string' then
+ -- nil return value
+ rspamd_logger.infox(task, "cannot learn %s ANN %s:%s; redis_key: %s: locked for learning: %s",
+ learn_type, rule.prefix, set.name, set.ann.redis_key, data)
+ else
+        rspamd_logger.errx(task, 'cannot check if we can train %s:%s : type of Redis key %s is %s, expected table; ' ..
+            'please remove this key from Redis manually if you are upgrading from a previous version',
+ rule.prefix, set.name, set.ann.redis_key, type(data))
+ end
+ end
+ end
+
+ -- Check if we can learn
+ if set.can_store_vectors then
+ if not set.ann then
+ -- Need to create or load a profile corresponding to the current configuration
+ set.ann = new_ann_profile(task, rule, set, 0)
+ lua_util.debugm(N, task,
+ 'requested new profile for %s, set.ann is missing',
+ set.name)
+ end
+
+ lua_redis.exec_redis_script(neural_common.redis_script_id.vectors_len,
+ { task = task, is_write = false },
+ vectors_len_cb,
+ {
+ set.ann.redis_key,
+ })
+ else
+ lua_util.debugm(N, task,
+        'do not push data: train condition not satisfied; reason: existing ANNs have not been checked yet')
+ end
+ else
+ lua_util.debugm(N, task,
+ 'do not push data to key %s: train condition not satisfied; reason: %s',
+ (set.ann or {}).redis_key,
+ skip_reason)
+ end
+end
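+
+-- Compact sketch, for illustration only, of the autotrain decision taken in
+-- ann_push_task_result() above; `opts` mirrors rule.train and the values in
+-- the asserts are made up.
+local function example_autotrain_label(verdict, score, opts)
+  if opts.spam_score then
+    if score >= opts.spam_score then
+      return 'spam'
+    end
+  elseif verdict == 'spam' or verdict == 'junk' then
+    return 'spam'
+  end
+  if opts.ham_score then
+    if score <= opts.ham_score then
+      return 'ham'
+    end
+  elseif verdict == 'ham' then
+    return 'ham'
+  end
+  return nil -- skipped: neither score thresholds nor verdict matched
+end
+assert(example_autotrain_label('junk', 10.0, {}) == 'spam')
+assert(example_autotrain_label('ham', -1.0, { ham_score = 0 }) == 'ham')
+assert(example_autotrain_label('spam', 2.0, { spam_score = 6.0 }) == nil)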
+
+--- Offline training logic
+
+-- Utility to extract and split saved training vectors to a table of tables
+local function process_training_vectors(data)
+ return fun.totable(fun.map(function(tok)
+ local _, str = rspamd_util.zstd_decompress(tok)
+ return fun.totable(fun.map(tonumber, lua_util.str_split(tostring(str), ';')))
+ end, data))
+end
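+
+-- Round-trip sketch of the storage format handled above: training vectors are
+-- joined with ';', zstd-compressed and kept in Redis sets (see
+-- ann_push_task_result). The vector below is made up.
+local example_vec = { 1, 0, 0.5 }
+local example_blob = rspamd_util.zstd_compress(table.concat(example_vec, ';'))
+local _example_err, example_str = rspamd_util.zstd_decompress(example_blob)
+assert(tostring(example_str) == '1;0;0.5')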
+
+-- This function does the following:
+-- * Tries to lock ANN
+-- * Loads spam and ham vectors
+-- * Spawn learning process
+local function do_train_ann(worker, ev_base, rule, set, ann_key)
+ local spam_elts = {}
+ local ham_elts = {}
+
+ local function redis_ham_cb(err, data)
+ if err or type(data) ~= 'table' then
+ rspamd_logger.errx(rspamd_config, 'cannot get ham tokens for ANN %s from redis: %s',
+ ann_key, err)
+ -- Unlock on error
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
+ )
+ else
+ -- Decompress and convert to numbers each training vector
+ ham_elts = process_training_vectors(data)
+ neural_common.spawn_train({ worker = worker, ev_base = ev_base,
+ rule = rule, set = set, ann_key = ann_key, ham_vec = ham_elts,
+ spam_vec = spam_elts })
+ end
+ end
+
+ -- Spam vectors received
+ local function redis_spam_cb(err, data)
+ if err or type(data) ~= 'table' then
+ rspamd_logger.errx(rspamd_config, 'cannot get spam tokens for ANN %s from redis: %s',
+ ann_key, err)
+ -- Unlock ANN on error
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ neural_common.gen_unlock_cb(rule, set, ann_key), --callback
+ 'HDEL', -- command
+ { ann_key, 'lock' }
+ )
+ else
+ -- Decompress and convert to numbers each training vector
+ spam_elts = process_training_vectors(data)
+ -- Now get ham vectors...
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_ham_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_ham_set' }
+ )
+ end
+ end
+
+ local function redis_lock_cb(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot call lock script for ANN %s from redis: %s',
+ ann_key, err)
+ elseif type(data) == 'number' and data == 1 then
+ -- ANN is locked, so we can extract SPAM and HAM vectors and spawn learning
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_spam_cb, --callback
+ 'SMEMBERS', -- command
+ { ann_key .. '_spam_set' }
+ )
+
+ rspamd_logger.infox(rspamd_config, 'lock ANN %s:%s (key name %s) for learning',
+ rule.prefix, set.name, ann_key)
+ else
+ local lock_tm = tonumber(data[1])
+ rspamd_logger.infox(rspamd_config, 'do not learn ANN %s:%s (key name %s), ' ..
+ 'locked by another host %s at %s', rule.prefix, set.name, ann_key,
+ data[2], os.date('%c', lock_tm))
+ end
+ end
+
+ -- Check if we are already learning this network
+ if set.learning_spawned then
+ rspamd_logger.infox(rspamd_config, 'do not learn ANN %s, already learning another ANN',
+ ann_key)
+ return
+ end
+
+ -- Call Redis script that tries to acquire a lock
+ -- This script returns either a boolean or a pair {'lock_time', 'hostname'} when
+ -- ANN is locked by another host (or a process, meh)
+ lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_lock,
+ { ev_base = ev_base, is_write = true },
+ redis_lock_cb,
+ {
+ ann_key,
+ tostring(os.time()),
+ tostring(math.max(10.0, rule.watch_interval * 2)),
+ rspamd_util.get_hostname()
+ })
+end
+
+-- This function loads a new ANN from Redis
+-- It is based on the `profile` attribute.
+-- The ANN is loaded from `profile.redis_key`
+-- The rank of the `profile` key is also increased; unfortunately, this means that we need to
+-- serialize the profile one more time and set its rank to the current time
+-- set.ann fields are set according to the Redis data received
+local function load_new_ann(rule, ev_base, set, profile, min_diff)
+ local ann_key = profile.redis_key
+
+ local function data_cb(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot get ANN data from key: %s; %s',
+ ann_key, err)
+ else
+ if type(data) == 'table' then
+ if type(data[1]) == 'userdata' and data[1].cookie == text_cookie then
+ local _err, ann_data = rspamd_util.zstd_decompress(data[1])
+ local ann
+
+ if _err or not ann_data then
+ rspamd_logger.errx(rspamd_config, 'cannot decompress ANN for %s from Redis key %s: %s',
+ rule.prefix .. ':' .. set.name, ann_key, _err)
+ return
+ else
+ ann = rspamd_kann.load(ann_data)
+
+ if ann then
+ set.ann = {
+ digest = profile.digest,
+ version = profile.version,
+ symbols = profile.symbols,
+ distance = min_diff,
+ redis_key = profile.redis_key
+ }
+
+ local ucl = require "ucl"
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+ set.ann.ann = ann -- To avoid serialization
+
+ local function rank_cb(_, _)
+ -- TODO: maybe add some logging
+ end
+ -- Also update rank for the loaded ANN to avoid removal
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ rank_cb, --callback
+ 'ZADD', -- command
+ { set.prefix, tostring(rspamd_util.get_time()), profile_serialized }
+ )
+ rspamd_logger.infox(rspamd_config,
+ 'loaded ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[1], profile.version)
+ else
+ rspamd_logger.errx(rspamd_config,
+ 'cannot unpack/deserialise ANN for %s:%s from Redis key %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+ else
+ lua_util.debugm(N, rspamd_config, 'missing ANN for %s:%s in Redis key %s',
+ rule.prefix, set.name, ann_key)
+ end
+
+ if set.ann and set.ann.ann and type(data[2]) == 'userdata' and data[2].cookie == text_cookie then
+ if rule.roc_enabled then
+ local ucl = require "ucl"
+ local parser = ucl.parser()
+ local ok, parse_err = parser:parse_text(data[2])
+ assert(ok, parse_err)
+ local roc_thresholds = parser:get_object()
+ set.ann.roc_thresholds = roc_thresholds
+ rspamd_logger.infox(rspamd_config,
+ 'loaded ROC thresholds for %s:%s; version=%s',
+ rule.prefix, set.name, profile.version)
+ rspamd_logger.debugx("ROC thresholds: %s", roc_thresholds)
+ end
+ end
+
+ if set.ann and set.ann.ann and type(data[3]) == 'userdata' and data[3].cookie == text_cookie then
+ -- PCA table
+ local _err, pca_data = rspamd_util.zstd_decompress(data[3])
+ if pca_data then
+ if rule.max_inputs then
+ -- We can use PCA
+ set.ann.pca = rspamd_tensor.load(pca_data)
+ rspamd_logger.infox(rspamd_config,
+ 'loaded PCA for ANN for %s:%s from %s; %s bytes compressed; version=%s',
+ rule.prefix, set.name, ann_key, #data[3], profile.version)
+ else
+          -- no need for PCA here, so why is it present?
+ rspamd_logger.warnx(rspamd_config,
+ 'extra PCA for ANN for %s:%s from Redis key %s: no max inputs defined',
+ rule.prefix, set.name, ann_key)
+ end
+ else
+        -- PCA may legitimately be missing only when max_inputs is not set
+ if rule.max_inputs then
+ rspamd_logger.errx(rspamd_config, 'cannot unpack/deserialise ANN for %s:%s from Redis key %s: no PCA: %s',
+ rule.prefix, set.name, ann_key, _err)
+ set.ann.ann = nil
+ else
+ -- It is okay
+ set.ann.pca = nil
+ end
+ end
+ end
+
+ else
+ lua_util.debugm(N, rspamd_config, 'no ANN key for %s:%s in Redis key %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+ end
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ data_cb, --callback
+ 'HMGET', -- command
+ { ann_key, 'ann', 'roc_thresholds', 'pca' }, -- arguments
+ { opaque_data = true }
+ )
+end
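+
+-- For reference, the per-ANN Redis layout touched in this file (key names are
+-- placeholders; real ones come from neural_common.new_ann_key()):
+--   <ann_key>           hash: 'ann' (zstd-compressed kann), 'roc_thresholds'
+--                       (JSON, read when roc_enabled), 'pca' (zstd-compressed
+--                       tensor, used with max_inputs), 'lock' (removed on unlock)
+--   <ann_key>_spam_set  set of zstd-compressed ';'-joined training vectors
+--   <ann_key>_ham_set   set of zstd-compressed ';'-joined training vectors
+--   <set.prefix>        sorted set of JSON profiles scored by last-use time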
+
+-- Used to check an element in Redis serialized as JSON
+-- for some specific rule + some specific setting
+-- This function tries to load fresher or more specific ANNs in place of
+-- the existing ones.
+-- Use this function to load ANNs as the `callback` parameter for the `check_anns` function
+local function process_existing_ann(_, ev_base, rule, set, profiles)
+ local my_symbols = set.symbols
+ local min_diff = math.huge
+ local sel_elt
+
+ for _, elt in fun.iter(profiles) do
+ if elt and elt.symbols then
+ local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
+ -- Check distance
+ if dist < #my_symbols * .3 then
+ if dist < min_diff then
+ min_diff = dist
+ sel_elt = elt
+ end
+ end
+ end
+ end
+
+ if sel_elt then
+ -- We can load element from ANN
+ if set.ann then
+ -- We have an existing ANN, probably the same...
+ if set.ann.digest == sel_elt.digest then
+ -- Same ANN, check version
+ if set.ann.version < sel_elt.version then
+ -- Load new ann
+ rspamd_logger.infox(rspamd_config, 'ann %s is changed, ' ..
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
+ load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+ else
+ lua_util.debugm(N, rspamd_config, 'ann %s is not changed, ' ..
+ 'our version = %s, remote version = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.version,
+ sel_elt.version)
+ end
+ else
+ -- We have some different ANN, so we need to compare distance
+ if set.ann.distance > min_diff then
+ -- Load more specific ANN
+ rspamd_logger.infox(rspamd_config, 'more specific ann is available for %s, ' ..
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
+ load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+ else
+ lua_util.debugm(N, rspamd_config, 'ann %s is not changed or less specific, ' ..
+ 'our distance = %s, remote distance = %s',
+ rule.prefix .. ':' .. set.name,
+ set.ann.distance,
+ min_diff)
+ end
+ end
+ else
+ -- We have no ANN, load new one
+ load_new_ann(rule, ev_base, set, sel_elt, min_diff)
+ end
+ end
+end
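+
+-- Worked example of the selection rule above, with made-up numbers: for a
+-- local profile of 100 symbols, only candidates at distance < 30 (30% of the
+-- symbol count) are eligible and the closest of those wins.
+local example_nsym = 100
+local example_best
+for _, d in ipairs({ 45, 12, 25 }) do -- hypothetical distances
+  if d < example_nsym * .3 and (not example_best or d < example_best) then
+    example_best = d
+  end
+end
+assert(example_best == 12)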
+
+
+-- This function checks all profiles and decides whether we can train our own
+-- ANN; by "our own" we mean a profile with exactly the same set of symbols.
+-- Use this function to train ANNs as the `callback` parameter for the `check_anns` function
+local function maybe_train_existing_ann(worker, ev_base, rule, set, profiles)
+ local my_symbols = set.symbols
+ local sel_elt
+ local lens = {
+ spam = 0,
+ ham = 0,
+ }
+
+ for _, elt in fun.iter(profiles) do
+ if elt and elt.symbols then
+ local dist = lua_util.distance_sorted(elt.symbols, my_symbols)
+ -- Check distance
+ if dist == 0 then
+ sel_elt = elt
+ break
+ end
+ end
+ end
+
+ if sel_elt then
+    -- We have our ANN and its train vectors; check if we can learn
+ local ann_key = sel_elt.redis_key
+
+ lua_util.debugm(N, rspamd_config, "check if ANN %s needs to be trained",
+ ann_key)
+
+ -- Create continuation closure
+ local redis_len_cb_gen = function(cont_cb, what, is_final)
+ return function(err, data)
+ if err then
+ rspamd_logger.errx(rspamd_config,
+ 'cannot get ANN %s trains %s from redis: %s', what, ann_key, err)
+ elseif data and type(data) == 'number' or type(data) == 'string' then
+ local ntrains = tonumber(data) or 0
+ lens[what] = ntrains
+ if is_final then
+ -- Ensure that we have the following:
+ -- one class has reached max_trains
+ -- other class(es) are at least as full as classes_bias
+ -- e.g. if classes_bias = 0.25 and we have 10 max_trains then
+ -- one class must have 10 or more trains whilst another should have
+ -- at least (10 * (1 - 0.25)) = 8 trains
+
+ local max_len = math.max(lua_util.unpack(lua_util.values(lens)))
+ local min_len = math.min(lua_util.unpack(lua_util.values(lens)))
+
+ if rule.train.learn_type == 'balanced' then
+ local len_bias_check_pred = function(_, l)
+ return l >= rule.train.max_trains * (1.0 - rule.train.classes_bias)
+ end
+ if max_len >= rule.train.max_trains and fun.all(len_bias_check_pred, lens) then
+ rspamd_logger.debugm(N, rspamd_config,
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
+ cont_cb()
+ else
+ rspamd_logger.debugm(N, rspamd_config,
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
+ end
+ else
+ -- Probabilistic mode, just ensure that at least one vector is okay
+ if min_len > 0 and max_len >= rule.train.max_trains then
+ rspamd_logger.debugm(N, rspamd_config,
+ 'can start ANN %s learn as it has %s learn vectors; %s required, after checking %s vectors',
+ ann_key, lens, rule.train.max_trains, what)
+ cont_cb()
+ else
+ rspamd_logger.debugm(N, rspamd_config,
+ 'cannot learn ANN %s now: there are not enough %s learn vectors (has %s vectors; %s required)',
+ ann_key, what, lens, rule.train.max_trains)
+ end
+ end
+
+ else
+ rspamd_logger.debugm(N, rspamd_config,
+ 'checked %s vectors in ANN %s: %s vectors; %s required, need to check other class vectors',
+ what, ann_key, ntrains, rule.train.max_trains)
+ cont_cb()
+ end
+ end
+ end
+
+ end
+
+ local function initiate_train()
+ rspamd_logger.infox(rspamd_config,
+ 'need to learn ANN %s after %s required learn vectors',
+ ann_key, lens)
+ do_train_ann(worker, ev_base, rule, set, ann_key)
+ end
+
+ -- Spam vector is OK, check ham vector length
+ local function check_ham_len()
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(initiate_train, 'ham', true), --callback
+ 'SCARD', -- command
+ { ann_key .. '_ham_set' }
+ )
+ end
+
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ false, -- is write
+ redis_len_cb_gen(check_ham_len, 'spam', false), --callback
+ 'SCARD', -- command
+ { ann_key .. '_spam_set' }
+ )
+ end
+end
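+
+-- Numeric sketch of the balanced-mode condition checked above, with made-up
+-- counts: max_trains = 10 and classes_bias = 0.25 mean that one class must
+-- reach 10 vectors while every class holds at least 10 * (1 - 0.25) = 8.
+local function example_balanced_ready(class_lens, max_trains, classes_bias)
+  local max_len = math.max(lua_util.unpack(lua_util.values(class_lens)))
+  return max_len >= max_trains and fun.all(function(_, l)
+    return l >= max_trains * (1.0 - classes_bias)
+  end, class_lens)
+end
+assert(example_balanced_ready({ spam = 10, ham = 8 }, 10, 0.25) == true)
+assert(example_balanced_ready({ spam = 10, ham = 7 }, 10, 0.25) == false)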
+
+-- Used to deserialise ANN element from a list
+local function load_ann_profile(element)
+ local ucl = require "ucl"
+
+ local parser = ucl.parser()
+ local res, ucl_err = parser:parse_string(element)
+ if not res then
+ rspamd_logger.warnx(rspamd_config, 'cannot parse ANN from redis: %s',
+ ucl_err)
+ return nil
+ else
+ local profile = parser:get_object()
+ local checked, schema_err = redis_profile_schema:transform(profile)
+ if not checked then
+ rspamd_logger.errx(rspamd_config, "cannot parse profile schema: %s", schema_err)
+
+ return nil
+ end
+ return checked
+ end
+end
+
+-- Function to check or load ANNs from Redis
+local function check_anns(worker, cfg, ev_base, rule, process_callback, what)
+ for _, set in pairs(rule.settings) do
+ local function members_cb(err, data)
+ if err then
+ rspamd_logger.errx(cfg, 'cannot get ANNs list from redis: %s',
+ err)
+ set.can_store_vectors = true
+ elseif type(data) == 'table' then
+ lua_util.debugm(N, cfg, '%s: process element %s:%s',
+ what, rule.prefix, set.name)
+ process_callback(worker, ev_base, rule, set, fun.map(load_ann_profile, data))
+ set.can_store_vectors = true
+ end
+ end
+
+ if type(set) == 'table' then
+ -- Extract all profiles for some specific settings id
+ -- Get the last `max_profiles` recently used
+      -- Select the one most appropriate to our profile, but it must not differ by more
+      -- than 30% of its symbols
+ lua_redis.redis_make_request_taskless(ev_base,
+ cfg,
+ rule.redis,
+ nil,
+ false, -- is write
+ members_cb, --callback
+ 'ZREVRANGE', -- command
+ { set.prefix, '0', tostring(settings.max_profiles) } -- arguments
+ )
+ end
+ end -- Cycle over all settings
+
+ return rule.watch_interval
+end
+
+-- Function to clean up old ANNs
+local function cleanup_anns(rule, cfg, ev_base)
+ for _, set in pairs(rule.settings) do
+ local function invalidate_cb(err, data)
+ if err then
+ rspamd_logger.errx(cfg, 'cannot exec invalidate script in redis: %s',
+ err)
+ elseif type(data) == 'table' then
+ for _, expired in ipairs(data) do
+ local profile = load_ann_profile(expired)
+ rspamd_logger.infox(cfg, 'invalidated ANN for %s; redis key: %s; version=%s',
+ rule.prefix .. ':' .. set.name,
+ profile.redis_key,
+ profile.version)
+ end
+ end
+ end
+
+ if type(set) == 'table' then
+ lua_redis.exec_redis_script(neural_common.redis_script_id.maybe_invalidate,
+ { ev_base = ev_base, is_write = true },
+ invalidate_cb,
+ { set.prefix, tostring(settings.max_profiles) })
+ end
+ end
+end
+
+local function ann_push_vector(task)
+ if task:has_flag('skip') then
+ lua_util.debugm(N, task, 'do not push data for skipped task')
+ return
+ end
+ if not settings.allow_local and lua_util.is_rspamc_or_controller(task) then
+ lua_util.debugm(N, task, 'do not push data for manual scan')
+ return
+ end
+
+ local verdict, score = lua_verdict.get_specific_verdict(N, task)
+
+ if verdict == 'passthrough' then
+ lua_util.debugm(N, task, 'ignore task as its verdict is %s(%s)',
+ verdict, score)
+
+ return
+ end
+
+ if score ~= score then
+ lua_util.debugm(N, task, 'ignore task as its score is nan (%s verdict)',
+ verdict)
+
+ return
+ end
+
+ for _, rule in pairs(settings.rules) do
+ local set = neural_common.get_rule_settings(task, rule)
+
+ if set then
+ ann_push_task_result(rule, task, verdict, score, set)
+ else
+ lua_util.debugm(N, task, 'settings not found in rule %s', rule.prefix)
+ end
+
+ end
+end
+
+
+-- Initialization part
+if not (neural_common.module_config and type(neural_common.module_config) == 'table')
+ or not neural_common.redis_params then
+ rspamd_logger.infox(rspamd_config, 'Module is unconfigured')
+ lua_util.disable_module(N, "redis")
+ return
+end
+
+local rules = neural_common.module_config['rules']
+
+if not rules then
+ -- Use legacy configuration
+ rules = {}
+ rules['default'] = neural_common.module_config
+end
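+
+-- For orientation, a hypothetical rule entry listing only the options that this
+-- file reads; defaults come from neural_common.default_options and all values
+-- below are illustrative, not recommendations.
+local example_rule = {
+  symbol_spam = 'NEURAL_SPAM',
+  symbol_ham = 'NEURAL_HAM',
+  max_inputs = 150, -- enables PCA; requires BLAS support
+  roc_enabled = true,
+  spam_score_threshold = 0.5,
+  ham_score_threshold = 0.5,
+  flat_threshold_curve = false,
+  watch_interval = 60.0,
+  train = {
+    autotrain = true,
+    store_pool_only = false,
+    max_trains = 1000,
+    learn_type = 'balanced',
+    classes_bias = 0.25,
+    spam_score = 6.0,
+    ham_score = -0.5,
+  },
+}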
+
+local id = rspamd_config:register_symbol({
+ name = 'NEURAL_CHECK',
+ type = 'postfilter,callback',
+ flags = 'nostat',
+ priority = lua_util.symbols_priorities.medium,
+ callback = ann_scores_filter
+})
+
+neural_common.settings.rules = {} -- Reset; validated rules are re-added in the loop below
+
+if settings.blacklisted_symbols and settings.blacklisted_symbols[1] then
+ -- Transform to hash for simplicity
+ settings.blacklisted_symbols = lua_util.list_to_hash(settings.blacklisted_symbols)
+end
+
+-- Check all rules
+for k, r in pairs(rules) do
+ local rule_elt = lua_util.override_defaults(neural_common.default_options, r)
+ rule_elt['redis'] = neural_common.redis_params
+ rule_elt['anns'] = {} -- Store ANNs here
+
+ if not rule_elt.prefix then
+ rule_elt.prefix = k
+ end
+ if not rule_elt.name then
+ rule_elt.name = k
+ end
+ if rule_elt.train.max_train and not rule_elt.train.max_trains then
+ rule_elt.train.max_trains = rule_elt.train.max_train
+ end
+
+ if not rule_elt.profile then
+ rule_elt.profile = {}
+ end
+
+ if rule_elt.max_inputs and not has_blas then
+    rspamd_logger.errx('cannot set max inputs to %s for rule %s as BLAS is not compiled in',
+        rule_elt.max_inputs, rule_elt.name)
+ rule_elt.max_inputs = nil
+ end
+
+ rspamd_logger.infox(rspamd_config, "register ann rule %s", k)
+ settings.rules[k] = rule_elt
+ rspamd_config:set_metric_symbol({
+ name = rule_elt.symbol_spam,
+ score = 0.0,
+ description = 'Neural network SPAM',
+ group = 'neural'
+ })
+ rspamd_config:register_symbol({
+ name = rule_elt.symbol_spam,
+ type = 'virtual',
+ flags = 'nostat',
+ parent = id
+ })
+
+ rspamd_config:set_metric_symbol({
+ name = rule_elt.symbol_ham,
+ score = -0.0,
+ description = 'Neural network HAM',
+ group = 'neural'
+ })
+ rspamd_config:register_symbol({
+ name = rule_elt.symbol_ham,
+ type = 'virtual',
+ flags = 'nostat',
+ parent = id
+ })
+end
+
+rspamd_config:register_symbol({
+ name = 'NEURAL_LEARN',
+ type = 'idempotent,callback',
+ flags = 'nostat,explicit_disable,ignore_passthrough',
+ callback = ann_push_vector
+})
+
+-- We also need to deal with settings
+rspamd_config:add_post_init(neural_common.process_rules_settings)
+
+-- Add training scripts
+for _, rule in pairs(settings.rules) do
+ neural_common.load_scripts(rule.redis)
+ -- This function will check ANNs in Redis when a worker is loaded
+ rspamd_config:add_on_load(function(cfg, ev_base, worker)
+ if worker:is_scanner() then
+ rspamd_config:add_periodic(ev_base, 0.0,
+ function(_, _)
+ return check_anns(worker, cfg, ev_base, rule, process_existing_ann,
+ 'try_load_ann')
+ end)
+ end
+
+ if worker:is_primary_controller() then
+ -- We also want to train neural nets when they have enough data
+ rspamd_config:add_periodic(ev_base, 0.0,
+ function(_, _)
+ -- Clean old ANNs
+ cleanup_anns(rule, cfg, ev_base)
+ return check_anns(worker, cfg, ev_base, rule, maybe_train_existing_ann,
+ 'try_train_ann')
+ end)
+ end
+ end)
+end