diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
commit | 133a45c109da5310add55824db21af5239951f93 (patch) | |
tree | ba6ac4c0a950a0dda56451944315d66409923918 /src/plugins/lua/clustering.lua | |
parent | Initial commit. (diff) | |
download | rspamd-133a45c109da5310add55824db21af5239951f93.tar.xz rspamd-133a45c109da5310add55824db21af5239951f93.zip |
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/plugins/lua/clustering.lua')
-rw-r--r-- | src/plugins/lua/clustering.lua | 322 |
1 files changed, 322 insertions, 0 deletions
diff --git a/src/plugins/lua/clustering.lua b/src/plugins/lua/clustering.lua new file mode 100644 index 0000000..d97bdb9 --- /dev/null +++ b/src/plugins/lua/clustering.lua @@ -0,0 +1,322 @@ +--[[ +Copyright (c) 2018, Vsevolod Stakhov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +if confighelp then + return +end + +-- Plugin for finding patterns in email flows + +local N = 'clustering' + +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local lua_verdict = require "lua_verdict" +local lua_redis = require "lua_redis" +local lua_selectors = require "lua_selectors" +local ts = require("tableshape").types + +local redis_params + +local rules = {} -- Rules placement + +local default_rule = { + max_elts = 100, -- Maximum elements in a cluster + expire = 3600, -- Expire for a bucket when limit is not reached + expire_overflow = 36000, -- Expire for a bucket when limit is reached + spam_mult = 1.0, -- Increase on spam hit + junk_mult = 0.5, -- Increase on junk + ham_mult = -0.1, -- Increase on ham + size_mult = 0.01, -- Reaches 1.0 on `max_elts` + score_mult = 0.1, +} + +local rule_schema = ts.shape { + max_elts = ts.number + ts.string / tonumber, + expire = ts.number + ts.string / lua_util.parse_time_interval, + expire_overflow = ts.number + ts.string / lua_util.parse_time_interval, + spam_mult = ts.number, + junk_mult = ts.number, + ham_mult = ts.number, + size_mult = ts.number, + score_mult = ts.number, + source_selector = ts.string, + cluster_selector = ts.string, + symbol = ts.string:is_optional(), + prefix = ts.string:is_optional(), +} + +-- Redis scripts + +-- Queries for a cluster's data +-- Arguments: +-- 1. Source selector (string) +-- 2. Cluster selector (string) +-- Returns: {cur_elts, total_score, element_score} +local query_cluster_script = [[ +local sz = redis.call('HLEN', KEYS[1]) + +if not sz or not tonumber(sz) then + -- New bucket, will update on idempotent phase + return {0, '0', '0'} +end + +local total_score = redis.call('HGET', KEYS[1], '__s') +total_score = tonumber(total_score) or 0 +local score = redis.call('HGET', KEYS[1], KEYS[2]) +if not score or not tonumber(score) then + return {sz, tostring(total_score), '0'} +end +return {sz, tostring(total_score), tostring(score)} +]] +local query_cluster_id + +-- Updates cluster's data +-- Arguments: +-- 1. Source selector (string) +-- 2. Cluster selector (string) +-- 3. Score (number) +-- 4. Max buckets (number) +-- 5. Expire (number) +-- 6. Expire overflow (number) +-- Returns: nothing +local update_cluster_script = [[ +local sz = redis.call('HLEN', KEYS[1]) + +if not sz or not tonumber(sz) then + -- Create bucket + redis.call('HSET', KEYS[1], KEYS[2], math.abs(KEYS[3])) + redis.call('HSET', KEYS[1], '__s', KEYS[3]) + redis.call('EXPIRE', KEYS[1], KEYS[5]) + + return +end + +sz = tonumber(sz) +local lim = tonumber(KEYS[4]) + +if sz > lim then + + if k then + -- Existing key + redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3])) + end +else + redis.call('HINCRBYFLOAT', KEYS[1], KEYS[2], math.abs(KEYS[3])) + redis.call('EXPIRE', KEYS[1], KEYS[6]) +end + +redis.call('HINCRBYFLOAT', KEYS[1], '__s', KEYS[3]) +redis.call('EXPIRE', KEYS[1], KEYS[5]) +]] +local update_cluster_id + +-- Callbacks and logic + +local function clusterting_filter_cb(task, rule) + local source_selector = rule.source_selector(task) + local cluster_selector + + if source_selector then + cluster_selector = rule.cluster_selector(task) + end + + if not cluster_selector or not source_selector then + rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"', + rule.name, source_selector, cluster_selector) + return + end + + local function combine_scores(cur_elts, total_score, element_score) + local final_score + + local size_score = cur_elts * rule.size_mult + local cluster_score = total_score * rule.score_mult + + if element_score > 0 then + -- We have seen this element mostly in junk/spam + final_score = math.min(1.0, size_score + cluster_score) + else + -- We have seen this element in ham mostly, so subtract average it from the size score + final_score = math.min(1.0, size_score - cluster_score / cur_elts) + end + rspamd_logger.debugm(N, task, + 'processed rule %s, selectors: source="%s", cluster="%s"; data: %s elts, %s score, %s elt score', + rule.name, source_selector, cluster_selector, cur_elts, total_score, element_score) + if final_score > 0.1 then + task:insert_result(rule.symbol, final_score, { source_selector, + tostring(size_score), + tostring(cluster_score) }) + end + end + + local function redis_get_cb(err, data) + if data then + if type(data) == 'table' then + combine_scores(tonumber(data[1]), tonumber(data[2]), tonumber(data[3])) + else + rspamd_logger.errx(task, 'invalid type while getting clustering keys %s: %s', + source_selector, type(data)) + end + + elseif err then + rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s', + source_selector, err) + else + rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s', + source_selector, "unknown error") + end + end + + lua_redis.exec_redis_script(query_cluster_id, + { task = task, is_write = false, key = source_selector }, + redis_get_cb, + { source_selector, cluster_selector }) +end + +local function clusterting_idempotent_cb(task, rule) + if task:has_flag('skip') then + return + end + if not rule.allow_local and lua_util.is_rspamc_or_controller(task) then + return + end + + local verdict = lua_verdict.get_specific_verdict(N, task) + local score + + if verdict == 'ham' then + score = rule.ham_mult + elseif verdict == 'spam' then + score = rule.spam_mult + elseif verdict == 'junk' then + score = rule.junk_mult + else + rspamd_logger.debugm(N, task, 'skip rule %s, verdict=%s', + rule.name, verdict) + return + end + + local source_selector = rule.source_selector(task) + local cluster_selector + + if source_selector then + cluster_selector = rule.cluster_selector(task) + end + + if not cluster_selector or not source_selector then + rspamd_logger.debugm(N, task, 'skip rule %s, selectors: source="%s", cluster="%s"', + rule.name, source_selector, cluster_selector) + return + end + + local function redis_set_cb(err, data) + if err then + rspamd_logger.errx(task, 'got error while getting clustering keys %s: %s', + source_selector, err) + else + rspamd_logger.debugm(N, task, 'set clustering key for %s: %s{%s} = %s', + source_selector, "unknown error") + end + end + + lua_redis.exec_redis_script(update_cluster_id, + { task = task, is_write = true, key = source_selector }, + redis_set_cb, + { + source_selector, + cluster_selector, + tostring(score), + tostring(rule.max_elts), + tostring(rule.expire), + tostring(rule.expire_overflow) + } + ) +end +-- Init part +redis_params = lua_redis.parse_redis_server('clustering') +local opts = rspamd_config:get_all_opt("clustering") + +-- Initialization part +if not (opts and type(opts) == 'table') then + lua_util.disable_module(N, "config") + return +end + +if not redis_params then + lua_util.disable_module(N, "redis") + return +end + +if opts['rules'] then + for k, v in pairs(opts['rules']) do + local raw_rule = lua_util.override_defaults(default_rule, v) + + local rule, err = rule_schema:transform(raw_rule) + + if not rule then + rspamd_logger.errx(rspamd_config, 'invalid clustering rule %s: %s', + k, err) + else + + if not rule.symbol then + rule.symbol = k + end + if not rule.prefix then + rule.prefix = k .. "_" + end + + rule.source_selector = lua_selectors.create_selector_closure(rspamd_config, + rule.source_selector, '') + rule.cluster_selector = lua_selectors.create_selector_closure(rspamd_config, + rule.cluster_selector, '') + if rule.source_selector and rule.cluster_selector then + rule.name = k + table.insert(rules, rule) + end + end + end + + if #rules > 0 then + + query_cluster_id = lua_redis.add_redis_script(query_cluster_script, redis_params) + update_cluster_id = lua_redis.add_redis_script(update_cluster_script, redis_params) + local function callback_gen(f, rule) + return function(task) + return f(task, rule) + end + end + + for _, rule in ipairs(rules) do + rspamd_config:register_symbol { + name = rule.symbol, + type = 'normal', + callback = callback_gen(clusterting_filter_cb, rule), + } + rspamd_config:register_symbol { + name = rule.symbol .. '_STORE', + type = 'idempotent', + flags = 'empty,explicit_disable,ignore_passthrough', + callback = callback_gen(clusterting_idempotent_cb, rule), + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) } + } + end + else + lua_util.disable_module(N, "config") + end +else + lua_util.disable_module(N, "config") +end |