diff options
Diffstat (limited to '')
-rw-r--r-- | src/plugins/lua/bayes_expiry.lua | 503 |
1 files changed, 503 insertions, 0 deletions
diff --git a/src/plugins/lua/bayes_expiry.lua b/src/plugins/lua/bayes_expiry.lua new file mode 100644 index 0000000..44ff9da --- /dev/null +++ b/src/plugins/lua/bayes_expiry.lua @@ -0,0 +1,503 @@ +--[[ +Copyright (c) 2017, Andrew Lewis <nerf@judo.za.org> +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]] -- + +if confighelp then + return +end + +local N = 'bayes_expiry' +local E = {} +local logger = require "rspamd_logger" +local rspamd_util = require "rspamd_util" +local lutil = require "lua_util" +local lredis = require "lua_redis" + +local settings = { + interval = 60, -- one iteration step per minute + count = 1000, -- check up to 1000 keys on each iteration + epsilon_common = 0.01, -- eliminate common if spam to ham rate is equal to this epsilon + common_ttl = 10 * 86400, -- TTL of discriminated common elements + significant_factor = 3.0 / 4.0, -- which tokens should we update + classifiers = {}, + cluster_nodes = 0, +} + +local template = {} + +local function check_redis_classifier(cls, cfg) + -- Skip old classifiers + if cls.new_schema then + local symbol_spam, symbol_ham + local expiry = (cls.expiry or cls.expire) + if type(expiry) == 'table' then + expiry = expiry[1] + end + + -- Load symbols from statfiles + + local function check_statfile_table(tbl, def_sym) + local symbol = tbl.symbol or def_sym + + local spam + if tbl.spam then + spam = tbl.spam + else + if string.match(symbol:upper(), 'SPAM') then + spam = true + else + spam = false + end + end + + if spam then + symbol_spam = symbol + else + symbol_ham = symbol + end + end + + local statfiles = cls.statfile + if statfiles[1] then + for _, stf in ipairs(statfiles) do + if not stf.symbol then + for k, v in pairs(stf) do + check_statfile_table(v, k) + end + else + check_statfile_table(stf, 'undefined') + end + end + else + for stn, stf in pairs(statfiles) do + check_statfile_table(stf, stn) + end + end + + if not symbol_spam or not symbol_ham or type(expiry) ~= 'number' then + logger.debugm(N, rspamd_config, + 'disable expiry for classifier %s: no expiry %s', + symbol_spam, cls) + return + end + -- Now try to load redis_params if needed + + local redis_params + redis_params = lredis.try_load_redis_servers(cls, rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, false, 'bayes') + if not redis_params then + redis_params = lredis.try_load_redis_servers(cfg[N] or E, rspamd_config, true) + if not redis_params then + return false + end + end + end + + if redis_params['read_only'] then + logger.infox(rspamd_config, 'disable expiry for classifier %s: read only redis configuration', + symbol_spam) + return + end + + logger.debugm(N, rspamd_config, "enabled expiry for %s/%s -> %s expiry", + symbol_spam, symbol_ham, expiry) + + table.insert(settings.classifiers, { + symbol_spam = symbol_spam, + symbol_ham = symbol_ham, + redis_params = redis_params, + expiry = expiry + }) + end +end + +-- Check classifiers and try find the appropriate ones +local obj = rspamd_config:get_ucl() + +local classifier = obj.classifier + +if classifier then + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.bayes then + cls = cls.bayes + end + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.bayes then + + classifier = classifier.bayes + if classifier[1] then + for _, cls in ipairs(classifier) do + if cls.backend and cls.backend == 'redis' then + check_redis_classifier(cls, obj) + end + end + else + if classifier.backend and classifier.backend == 'redis' then + check_redis_classifier(classifier, obj) + end + end + end + end +end + +local opts = rspamd_config:get_all_opt(N) + +if opts then + for k, v in pairs(opts) do + settings[k] = v + end +end + +-- In clustered setup, we need to increase interval of expiration +-- according to number of nodes in a cluster +if settings.cluster_nodes == 0 then + local neighbours = obj.neighbours or {} + local n_neighbours = 0 + for _, _ in pairs(neighbours) do + n_neighbours = n_neighbours + 1 + end + settings.cluster_nodes = n_neighbours +end + +-- Fill template +template.count = settings.count +template.threshold = settings.threshold +template.common_ttl = settings.common_ttl +template.epsilon_common = settings.epsilon_common +template.significant_factor = settings.significant_factor +template.expire_step = settings.interval +template.hostname = rspamd_util.get_hostname() + +for k, v in pairs(template) do + template[k] = tostring(v) +end + +-- Arguments: +-- [1] = symbol pattern +-- [2] = expire value +-- [3] = cursor +-- returns {cursor for the next step, step number, step statistic counters, cycle statistic counters, tokens occurrences distribution} +local expiry_script = [[ + local unpack_function = table.unpack or unpack + + local hash2list = function (hash) + local res = {} + for k, v in pairs(hash) do + table.insert(res, k) + table.insert(res, v) + end + return res + end + + local function merge_list(table, list) + local k + for i, v in ipairs(list) do + if i % 2 == 1 then + k = v + else + table[k] = v + end + end + end + + local expire = math.floor(KEYS[2]) + local pattern_sha1 = redis.sha1hex(KEYS[1]) + + local lock_key = pattern_sha1 .. '_lock' -- Check locking + local lock = redis.call('GET', lock_key) + + if lock then + if lock ~= '${hostname}' then + return 'locked by ' .. lock + end + end + + redis.replicate_commands() + redis.call('SETEX', lock_key, ${expire_step}, '${hostname}') + + local cursor_key = pattern_sha1 .. '_cursor' + local cursor = tonumber(redis.call('GET', cursor_key) or 0) + + local step = 1 + local step_key = pattern_sha1 .. '_step' + if cursor > 0 then + step = redis.call('GET', step_key) + step = step and (tonumber(step) + 1) or 1 + end + + local ret = redis.call('SCAN', cursor, 'MATCH', KEYS[1], 'COUNT', '${count}') + local next_cursor = ret[1] + local keys = ret[2] + local tokens = {} + + -- Tokens occurrences distribution counters + local occur = { + ham = {}, + spam = {}, + total = {} + } + + -- Expiry step statistics counters + local nelts, extended, discriminated, sum, sum_squares, common, significant, + infrequent, infrequent_ttls_set, insignificant, insignificant_ttls_set = + 0,0,0,0,0,0,0,0,0,0,0 + + for _,key in ipairs(keys) do + local t = redis.call('TYPE', key)["ok"] + if t == 'hash' then + local values = redis.call('HMGET', key, 'H', 'S') + local ham = tonumber(values[1]) or 0 + local spam = tonumber(values[2]) or 0 + local ttl = redis.call('TTL', key) + tokens[key] = { + ham, + spam, + ttl + } + local total = spam + ham + sum = sum + total + sum_squares = sum_squares + total * total + nelts = nelts + 1 + + for k,v in pairs({['ham']=ham, ['spam']=spam, ['total']=total}) do + if tonumber(v) > 19 then v = 20 end + occur[k][v] = occur[k][v] and occur[k][v] + 1 or 1 + end + end + end + + local mean, stddev = 0, 0 + + if nelts > 0 then + mean = sum / nelts + stddev = math.sqrt(sum_squares / nelts - mean * mean) + end + + for key,token in pairs(tokens) do + local ham, spam, ttl = token[1], token[2], tonumber(token[3]) + local threshold = mean + local total = spam + ham + + local function set_ttl() + if expire < 0 then + if ttl ~= -1 then + redis.call('PERSIST', key) + return 1 + end + elseif ttl == -1 or ttl > expire then + redis.call('EXPIRE', key, expire) + return 1 + end + return 0 + end + + if total == 0 or math.abs(ham - spam) <= total * ${epsilon_common} then + common = common + 1 + if ttl > ${common_ttl} then + discriminated = discriminated + 1 + redis.call('EXPIRE', key, ${common_ttl}) + end + elseif total >= threshold and total > 0 then + if ham / total > ${significant_factor} or spam / total > ${significant_factor} then + significant = significant + 1 + if ttl ~= -1 then + redis.call('PERSIST', key) + extended = extended + 1 + end + else + insignificant = insignificant + 1 + insignificant_ttls_set = insignificant_ttls_set + set_ttl() + end + else + infrequent = infrequent + 1 + infrequent_ttls_set = infrequent_ttls_set + set_ttl() + end + end + + -- Expiry cycle statistics counters + local c = {nelts = 0, extended = 0, discriminated = 0, sum = 0, sum_squares = 0, + common = 0, significant = 0, infrequent = 0, infrequent_ttls_set = 0, insignificant = 0, insignificant_ttls_set = 0} + + local counters_key = pattern_sha1 .. '_counters' + + if cursor ~= 0 then + merge_list(c, redis.call('HGETALL', counters_key)) + end + + c.nelts = c.nelts + nelts + c.extended = c.extended + extended + c.discriminated = c.discriminated + discriminated + c.sum = c.sum + sum + c.sum_squares = c.sum_squares + sum_squares + c.common = c.common + common + c.significant = c.significant + significant + c.infrequent = c.infrequent + infrequent + c.infrequent_ttls_set = c.infrequent_ttls_set + infrequent_ttls_set + c.insignificant = c.insignificant + insignificant + c.insignificant_ttls_set = c.insignificant_ttls_set + insignificant_ttls_set + + redis.call('HMSET', counters_key, unpack_function(hash2list(c))) + redis.call('SET', cursor_key, tostring(next_cursor)) + redis.call('SET', step_key, tostring(step)) + redis.call('DEL', lock_key) + + local occ_distr = {} + for _,cl in pairs({'ham', 'spam', 'total'}) do + local occur_key = pattern_sha1 .. '_occurrence_' .. cl + + if cursor ~= 0 then + local n + for i,v in ipairs(redis.call('HGETALL', occur_key)) do + if i % 2 == 1 then + n = tonumber(v) + else + occur[cl][n] = occur[cl][n] and occur[cl][n] + v or v + end + end + + local str = '' + if occur[cl][0] ~= nil then + str = '0:' .. occur[cl][0] .. ',' + end + for k,v in ipairs(occur[cl]) do + if k == 20 then k = '>19' end + str = str .. k .. ':' .. v .. ',' + end + table.insert(occ_distr, str) + else + redis.call('DEL', occur_key) + end + + if next(occur[cl]) ~= nil then + redis.call('HMSET', occur_key, unpack_function(hash2list(occur[cl]))) + end + end + + return { + next_cursor, step, + {nelts, extended, discriminated, mean, stddev, common, significant, infrequent, + infrequent_ttls_set, insignificant, insignificant_ttls_set}, + {c.nelts, c.extended, c.discriminated, c.sum, c.sum_squares, c.common, + c.significant, c.infrequent, c.infrequent_ttls_set, c.insignificant, c.insignificant_ttls_set}, + occ_distr + } +]] + +local function expire_step(cls, ev_base, worker) + local function redis_step_cb(err, args) + if err then + logger.errx(rspamd_config, 'cannot perform expiry step: %s', err) + elseif type(args) == 'table' then + local cur = tonumber(args[1]) + local step = args[2] + local data = args[3] + local c_data = args[4] + local occ_distr = args[5] + + local function log_stat(cycle) + local infrequent_action = (cls.expiry < 0) and 'made persistent' or 'ttls set' + + local c_mean, c_stddev = 0, 0 + if cycle and c_data[1] ~= 0 then + c_mean = c_data[4] / c_data[1] + c_stddev = math.floor(.5 + math.sqrt(c_data[5] / c_data[1] - c_mean * c_mean)) + c_mean = math.floor(.5 + c_mean) + end + + local d = cycle and { + 'cycle in ' .. step .. ' steps', c_data[1], + c_data[7], c_data[2], 'made persistent', + c_data[10], c_data[11], infrequent_action, + c_data[6], c_data[3], + c_data[8], c_data[9], infrequent_action, + c_mean, + c_stddev + } or { + 'step ' .. step, data[1], + data[7], data[2], 'made persistent', + data[10], data[11], infrequent_action, + data[6], data[3], + data[8], data[9], infrequent_action, + data[4], + data[5] + } + logger.infox(rspamd_config, + 'finished expiry %s: %s items checked, %s significant (%s %s), ' .. + '%s insignificant (%s %s), %s common (%s discriminated), ' .. + '%s infrequent (%s %s), %s mean, %s std', + lutil.unpack(d)) + if cycle then + for i, cl in ipairs({ 'in ham', 'in spam', 'total' }) do + logger.infox(rspamd_config, 'tokens occurrences, %s: {%s}', cl, occ_distr[i]) + end + end + end + log_stat(false) + if cur == 0 then + log_stat(true) + end + elseif type(args) == 'string' then + logger.infox(rspamd_config, 'skip expiry step: %s', args) + end + end + lredis.exec_redis_script(cls.script, + { ev_base = ev_base, is_write = true }, + redis_step_cb, + { 'RS*_*', cls.expiry } + ) +end + +rspamd_config:add_on_load(function(_, ev_base, worker) + -- Exit unless we're the first 'controller' worker + if not worker:is_primary_controller() then + return + end + + local unique_redis_params = {} + -- Push redis script to all unique redis servers + for _, cls in ipairs(settings.classifiers) do + if not unique_redis_params[cls.redis_params.hash] then + unique_redis_params[cls.redis_params.hash] = cls.redis_params + end + end + + for h, rp in pairs(unique_redis_params) do + local script_id = lredis.add_redis_script(lutil.template(expiry_script, + template), rp) + + for _, cls in ipairs(settings.classifiers) do + if cls.redis_params.hash == h then + cls.script = script_id + end + end + end + + -- Expire tokens at regular intervals + for _, cls in ipairs(settings.classifiers) do + rspamd_config:add_periodic(ev_base, + settings['interval'], + function() + expire_step(cls, ev_base, worker) + return true + end, true) + end +end) |