diff options
Diffstat (limited to 'src/plugins/lua/reputation.lua')
-rw-r--r-- | src/plugins/lua/reputation.lua | 1390 |
1 files changed, 1390 insertions, 0 deletions
-- diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua new file mode 100644 index 0000000..a3af26c --- /dev/null +++ b/src/plugins/lua/reputation.lua @@ -0,0 +1,1390 @@
--[[
Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--

if confighelp then
  return
end

-- A generic plugin for reputation handling

local E = {}
local N = 'reputation'

local rspamd_logger = require "rspamd_logger"
local rspamd_util = require "rspamd_util"
local lua_util = require "lua_util"
local lua_maps = require "lua_maps"
local lua_maps_exprs = require "lua_maps_expressions"
local hash = require 'rspamd_cryptobox_hash'
local lua_redis = require "lua_redis"
local fun = require "fun"
local lua_selectors = require "lua_selectors"
local ts = require("tableshape").types

local redis_params = nil
local default_expiry = 864000 -- 10 day by default
local default_prefix = 'RR:' -- Rspamd Reputation

local tanh = math.tanh or rspamd_util.tanh

-- Get reputation from ham/spam/probable hits
-- `token` is `{ hits, { score1, score2, ... } }` as produced by a backend,
-- `mult` scales the resulting value. Returns a number (0 when there is not
-- enough data to judge).
local function generic_reputation_calc(token, rule, mult, task)
  local cfg = rule.selector.config or E
  local reject_threshold = task:get_metric_score()[2] or 10.0

  if cfg.score_calc_func then
    return cfg.score_calc_func(rule, token, mult)
  end

  -- Both values may come as strings (backend reply / config), so coerce them
  -- before comparing
  local hits = tonumber(token[1]) or 0
  local lower_bound = tonumber(cfg.lower_bound) or 0
  if hits < lower_bound then
    lua_util.debugm(N, task, "not enough matches %s < %s for rule %s",
        token[1], cfg.lower_bound, rule.symbol)
    return 0
  end

  -- An empty scores vector would give 0/0 (NaN) below and poison every sum
  -- the caller builds from this value
  local scores = token[2]
  if type(scores) ~= 'table' or #scores == 0 then
    return 0
  end

  -- Get average score
  local sum = 0.0
  for _, v in ipairs(scores) do
    sum = sum + (tonumber(v) or 0)
  end
  local avg_score = sum / #scores

  -- Apply function tanh(x / reject_score * atanh(0.95) - atanh(0.5))
  --                                        1.83178       0.5493
  local score = tanh(avg_score / reject_threshold * 1.83178 - 0.5493) * mult
  lua_util.debugm(N, task, "got generic average score %s (reject threshold=%s, mult=%s) -> %s for rule %s",
      avg_score, reject_threshold, mult, score, rule.symbol)
  return score
end

-- Inserts the rule's symbol (or its `_SPAM`/`_HAM` variant when
-- `split_symbols` is set) with weight `mult` and options `params`
local function add_symbol_score(task, rule, mult, params)
  if not params then
    params = { tostring(mult) }
  end

  if not rule.selector.config.split_symbols then
    task:insert_result(rule.symbol, mult, params)
    return
  end

  local sym_spam = rule.symbol .. '_SPAM'
  local sym_ham = rule.symbol .. '_HAM'
  if not rule.static_symbols then
    -- Cache static definitions of both symbols on first use
    rule.static_symbols = {}
    rule.static_symbols.ham = rspamd_config:get_symbol(sym_ham)
    rule.static_symbols.spam = rspamd_config:get_symbol(sym_spam)
  end

  if mult >= 0 then
    task:insert_result(sym_spam, mult, params)
    return
  end

  -- Avoid multiplication of the negative `mult` by negative static score of the
  -- ham symbol
  local ham = rule.static_symbols.ham
  if ham and ham.score and ham.score < 0 then
    mult = math.abs(mult)
  end
  task:insert_result(sym_ham, mult, params)
end

-- Subtracts the score already contributed by the rule's own symbol(s) from
-- `score` and returns the result
local function sub_symbol_score(task, rule, score)
  local function sym_score(sym)
    local s = (task:get_symbol(sym) or E)[1]
    return s and s.score or 0
  end

  if rule.selector.config.split_symbols then
    local spam_sym = rule.symbol .. '_SPAM'
    local ham_sym = rule.symbol .. '_HAM'

    if task:has_symbol(spam_sym) then
      score = score - sym_score(spam_sym)
    elseif task:has_symbol(ham_sym) then
      score = score - sym_score(ham_sym)
    end
  elseif task:has_symbol(rule.symbol) then
    score = score - sym_score(rule.symbol)
  end

  return score
end

-- Extracts task score and subtracts score of the rule itself
-- Returns nil when there is no score or the verdict is `passthrough`
local function extract_task_score(task, rule)
  local lua_verdict = require "lua_verdict"
  local verdict, score = lua_verdict.get_specific_verdict(N, task)

  if not score or verdict == 'passthrough' then
    return nil
  end

  return sub_symbol_score(task, rule, score)
end

-- DKIM Selector functions

-- Lazily built lpeg grammar for `domain:result` options of DKIM_TRACE
local gr

-- Builds a map `tld -> label` from the DKIM_TRACE options, where label is
-- 'a' (accepted, `+`), 'r' (rejected, `-`) or 'u' (anything else)
local function gen_dkim_queries(task, rule)
  local dkim_trace = (task:get_symbol('DKIM_TRACE') or E)[1]
  local lpeg = require 'lpeg'
  local ret = {}

  if not gr then
    local colon = lpeg.P(':')
    local domain = lpeg.C((1 - colon) ^ 1)
    local res = lpeg.S '+-?~'

    local function res_to_label(ch)
      if ch == '+' then
        return 'a'
      elseif ch == '-' then
        return 'r'
      end

      return 'u'
    end

    gr = domain * colon * (lpeg.C(res ^ 1) / res_to_label)
  end

  if dkim_trace and dkim_trace.options then
    for _, opt in ipairs(dkim_trace.options) do
      local dom, res = lpeg.match(gr, opt)

      if dom and res then
        local tld = rspamd_util.get_tld(dom)
        -- A nil key cannot be stored in a table, skip such domains
        if tld then
          ret[tld] = res
        end
      end
    end
  end

  return ret
end

-- Queries the backend for each signing domain and inserts the rule symbol
-- built from the reputation of domains with valid signatures
local function dkim_reputation_filter(task, rule)
  local requests = gen_dkim_queries(task, rule)
  local results = {}
  local requests_left = #lua_util.keys(requests)

  lua_util.debugm(N, task, 'dkim reputation tokens: %s', requests)

  -- Results are keyed by the requested tld (captured by the per-request
  -- closure below) and not by the backend key: the latter may be prefixed,
  -- truncated or hashed and thus cannot be reliably mapped back to a request
  local function tokens_cb(tld, values)
    requests_left = requests_left - 1

    if values then
      results[tld] = values
    end

    if requests_left ~= 0 then
      return
    end

    local rep_accepted = 0.0
    for sel_tld, v in pairs(results) do
      if requests[sel_tld] == 'a' then
        rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0, task)
      end
    end

    -- Set local reputation symbol
    local rep_accepted_abs = math.abs(rep_accepted)
    lua_util.debugm(N, task, "dkim reputation accepted: %s",
        rep_accepted_abs)
    -- Same threshold as in other selectors (0 is truthy in Lua, so a bare
    -- `if rep_accepted_abs` test would never filter anything out)
    if rep_accepted_abs > 1e-3 then
      local final_rep = math.max(-1.0, math.min(1.0, rep_accepted))
      add_symbol_score(task, rule, final_rep)

      -- Store results for future DKIM results adjustments
      task:get_mempool():set_variable("dkim_reputation_accept", tostring(rep_accepted))
    end
  end

  for dom, res in pairs(requests) do
    -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs
    local query = string.format('%s.%s', dom, res)
    rule.backend.get_token(task, rule, nil, query, function(_, _, values)
      tokens_cb(dom, values)
    end, 'string')
  end
end

-- Stores the final task score for each signing domain
local function dkim_reputation_idempotent(task, rule)
  if not rule.backend.set_token then
    return
  end -- Read only backend
  local requests = gen_dkim_queries(task, rule)
  local sc = extract_task_score(task, rule)

  if sc then
    for dom, res in pairs(requests) do
      -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs
      local query = string.format('%s.%s', dom, res)
      rule.backend.set_token(task, rule, nil, query, sc)
    end
  end
end

-- Adjusts the score of R_DKIM_ALLOW using the reputation stored by the filter
local function dkim_reputation_postfilter(task, rule)
  local sym_accepted = (task:get_symbol('R_DKIM_ALLOW') or E)[1]
  local accept_adjustment = task:get_mempool():get_variable("dkim_reputation_accept")
  local cfg = rule.selector.config or E

  if sym_accepted and sym_accepted.score and
      accept_adjustment and type(cfg.max_accept_adjustment) == 'number' then
    local final_adjustment = cfg.max_accept_adjustment *
        tanh(tonumber(accept_adjustment) or 0)
    lua_util.debugm(N, task, "adjust DKIM_ALLOW: " ..
        "cfg.max_accept_adjustment=%s accept_adjustment=%s final_adjustment=%s sym_accepted.score=%s",
        cfg.max_accept_adjustment, accept_adjustment, final_adjustment,
        sym_accepted.score)

    task:adjust_result('R_DKIM_ALLOW', sym_accepted.score + final_adjustment)
  end
end

local dkim_selector = {
  config = {
    symbol = 'DKIM_SCORE', -- symbol to be inserted
    lower_bound = 10, -- minimum number of messages to be scored
    min_score = nil,
    max_score = nil,
    outbound = true,
    inbound = true,
    max_accept_adjustment = 2.0, -- How to adjust accepted DKIM score
  },
  dependencies = { "DKIM_TRACE" },
  filter = dkim_reputation_filter, -- used to get scores
  postfilter = dkim_reputation_postfilter, -- used to adjust DKIM scores
  idempotent = dkim_reputation_idempotent, -- used to set scores
}

-- URL Selector functions

-- Returns an array of `{ tld, count }` pairs for urls that are not displayed
-- in HTML, ordered by descending count and limited to `max_urls` entries
local function gen_url_queries(task, rule)
  local domains = {}

  fun.each(function(u)
    if u:is_redirected() then
      local redir = u:get_redirected() -- get the original url
      local redir_tld = redir and redir:get_tld()
      if redir_tld and domains[redir_tld] then
        domains[redir_tld] = domains[redir_tld] - 1
      end
    end
    local dom = u:get_tld()
    -- A nil key cannot be stored in a table, skip urls with no tld
    if dom then
      domains[dom] = (domains[dom] or 0) + 1
    end
  end, fun.filter(function(u)
    return not u:is_html_displayed()
  end,
      task:get_urls(true) or E))

  local results = {}
  for k, v in lua_util.spairs(domains,
      function(t, a, b)
        return t[a] > t[b]
      end, rule.selector.config.max_urls) do
    if v > 0 then
      table.insert(results, { k, v })
    end
  end

  return results
end

-- Queries the backend for each url domain and inserts the rule symbol with
-- the sum of per-domain scores weighted relatively to the maximum hits
local function url_reputation_filter(task, rule)
  local requests = gen_url_queries(task, rule)
  local requests_left = #requests
  local results = {}

  local function indexed_tokens_cb(err, index, values)
    requests_left = requests_left - 1

    if values then
      results[index] = values
    end

    if requests_left ~= 0 then
      return
    end

    -- Check the url with maximum hits
    -- `results` has holes for lookups that returned nothing, so it must be
    -- traversed with `pairs`: `ipairs` stops at the first missing index
    local mhits = 0

    for i, res in pairs(results) do
      local req = requests[i]
      if req then
        local hits = tonumber(res[1]) or 0
        if hits > mhits then
          mhits = hits
        end
      else
        rspamd_logger.warnx(task, "cannot find the requested response for a request: %s (%s requests noticed)",
            i, #requests)
      end
    end

    if mhits > 0 then
      local score = 0
      for i, res in pairs(results) do
        local req = requests[i]
        if req then
          local url_score = generic_reputation_calc(res, rule,
              req[2] / mhits, task)
          lua_util.debugm(N, task, "score for url %s is %s, score=%s", req[1], url_score, score)
          score = score + url_score
        end
      end

      if math.abs(score) > 1e-3 then
        -- TODO: add description
        add_symbol_score(task, rule, score)
      end
    end
  end

  for i, req in ipairs(requests) do
    local function tokens_cb(err, token, values)
      indexed_tokens_cb(err, i, values)
    end

    rule.backend.get_token(task, rule, nil, req[1], tokens_cb, 'string')
  end
end

-- Stores the final task score for each url domain
local function url_reputation_idempotent(task, rule)
  if not rule.backend.set_token then
    return
  end -- Read only backend
  local requests = gen_url_queries(task, rule)
  local sc = extract_task_score(task, rule)

  if sc then
    for _, tld in ipairs(requests) do
      rule.backend.set_token(task, rule, nil, tld[1], sc)
    end
  end
end

local url_selector = {
  config = {
    symbol = 'URL_SCORE', -- symbol to be inserted
    lower_bound = 10, -- minimum number of messages to be scored
    min_score = nil,
    max_score = nil,
    max_urls = 10,
    check_from = true,
    outbound = true,
    inbound = true,
  },
  filter = url_reputation_filter, -- used to get scores
  idempotent = url_reputation_idempotent -- used to set scores
}
-- IP Selector functions

local function ip_reputation_init(rule)
  local cfg = rule.selector.config

  if cfg.asn_cc_whitelist then
    cfg.asn_cc_whitelist = lua_maps.map_add('reputation',
        'asn_cc_whitelist',
        'map',
        'IP score whitelisted ASNs/countries')
  end

  return true
end

-- Applies the mask configured for the address family of `ip`; the address is
-- returned as is when no mask is configured for its family
local function ip_reputation_apply_mask(ip, cfg)
  if ip:get_version() == 4 then
    if cfg.ipv4_mask then
      return ip:apply_mask(cfg.ipv4_mask)
    end
  elseif cfg.ipv6_mask then
    return ip:apply_mask(cfg.ipv6_mask)
  end

  return ip
end

-- Returns true if either the country or the ASN is listed in the
-- `asn_cc_whitelist` map (each one is checked independently)
local function ip_reputation_is_whitelisted(cfg, asn, country)
  local wl = cfg.asn_cc_whitelist

  if not wl then
    return false
  end
  if country and wl:get_key(country) then
    return true
  end
  if asn and wl:get_key(asn) then
    return true
  end

  return false
end

-- Used to get scores: combines reputation of the sender's ASN, country and IP
local function ip_reputation_filter(task, rule)

  local ip = task:get_from_ip()

  if not ip or not ip:is_valid() then
    return
  end
  if lua_util.is_rspamc_or_controller(task) then
    return
  end

  local cfg = rule.selector.config

  ip = ip_reputation_apply_mask(ip, cfg)

  local pool = task:get_mempool()
  local asn = pool:get_variable("asn")
  local country = pool:get_variable("country")

  if ip_reputation_is_whitelisted(cfg, asn, country) then
    return
  end

  -- These variables are used to define if we have some specific token
  -- (a component that is not requested at all counts as already received)
  local has_asn = not asn
  local has_country = not country
  local has_ip = false

  local asn_stats, country_stats, ip_stats

  local function ipstats_check()
    local score = 0.0
    local description_t = {}

    if asn_stats then
      local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn, task)
      score = score + asn_score
      table.insert(description_t, string.format('asn: %s(%.2f)',
          asn, asn_score))
    end
    if country_stats then
      local country_score = generic_reputation_calc(country_stats, rule,
          cfg.scores.country, task)
      score = score + country_score
      table.insert(description_t, string.format('country: %s(%.2f)',
          country, country_score))
    end
    if ip_stats then
      local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip,
          task)
      score = score + ip_score
      table.insert(description_t, string.format('ip: %s(%.2f)',
          tostring(ip), ip_score))
    end

    if math.abs(score) > 0.001 then
      add_symbol_score(task, rule, score, table.concat(description_t, ', '))
    end
  end

  local function gen_token_callback(what)
    return function(err, _, values)
      local stats = (not err and values) and values or nil

      -- Mark the component as received even on error to let others proceed
      if what == 'asn' then
        has_asn = true
        asn_stats = stats or asn_stats
      elseif what == 'country' then
        has_country = true
        country_stats = stats or country_stats
      elseif what == 'ip' then
        has_ip = true
        ip_stats = stats or ip_stats
      end

      if has_asn and has_country and has_ip then
        -- Check reputation
        ipstats_check()
      end
    end
  end

  if asn then
    rule.backend.get_token(task, rule, cfg.asn_prefix, asn,
        gen_token_callback('asn'), 'string')
  end
  if country then
    rule.backend.get_token(task, rule, cfg.country_prefix, country,
        gen_token_callback('country'), 'string')
  end

  rule.backend.get_token(task, rule, cfg.ip_prefix, ip,
      gen_token_callback('ip'), 'ip')
end

-- Used to set scores
local function ip_reputation_idempotent(task, rule)
  if not rule.backend.set_token then
    return
  end -- Read only backend
  local ip = task:get_from_ip()
  local cfg = rule.selector.config

  if not ip or not ip:is_valid() then
    return
  end

  if lua_util.is_rspamc_or_controller(task) then
    return
  end

  ip = ip_reputation_apply_mask(ip, cfg)

  local pool = task:get_mempool()
  local asn = pool:get_variable("asn")
  local country = pool:get_variable("country")

  if ip_reputation_is_whitelisted(cfg, asn, country) then
    return
  end

  local sc = extract_task_score(task, rule)
  if sc then
    if asn then
      rule.backend.set_token(task, rule, cfg.asn_prefix, asn, sc, nil, 'string')
    end
    if country then
      rule.backend.set_token(task, rule, cfg.country_prefix, country, sc, nil, 'string')
    end

    rule.backend.set_token(task, rule, cfg.ip_prefix, ip, sc, nil, 'ip')
  end
end

-- Selectors are used to extract reputation tokens
local ip_selector = {
  config = {
    scores = { -- how each component is evaluated
      ['asn'] = 0.4,
      ['country'] = 0.01,
      ['ip'] = 1.0
    },
    symbol = 'SENDER_REP', -- symbol to be inserted
    split_symbols = true,
    asn_prefix = 'a:', -- prefix for ASN hashes
    country_prefix = 'c:', -- prefix for country hashes
    ip_prefix = 'i:',
    lower_bound = 10, -- minimum number of messages to be scored
    min_score = nil,
    max_score = nil,
    score_divisor = 1,
    outbound = false,
    inbound = true,
    ipv4_mask = 32, -- Mask bits for ipv4
    ipv6_mask = 64, -- Mask bits for ipv6
  },
  --dependencies = {"ASN"}, -- ASN is a prefilter now...
  init = ip_reputation_init,
  filter = ip_reputation_filter, -- used to get scores
  idempotent = ip_reputation_idempotent, -- used to set scores
}

-- SPF Selector functions

-- Reputation key of an SPF record: first 32 characters of the base32 encoded
-- hash of the record
local function spf_reputation_key(spf_record)
  return hash.create(spf_record):base32():sub(1, 32)
end

local function spf_reputation_filter(task, rule)
  local spf_record = task:get_mempool():get_variable('spf_record')
  local spf_allow = task:has_symbol('R_SPF_ALLOW')

  -- Don't care about bad/missing spf
  if not spf_record or not spf_allow then
    return
  end

  local hkey = spf_reputation_key(spf_record)

  lua_util.debugm(N, task, 'check spf record %s -> %s', spf_record, hkey)

  local function tokens_cb(err, token, values)
    if values then
      local score = generic_reputation_calc(values, rule, 1.0, task)

      if math.abs(score) > 1e-3 then
        -- TODO: add description
        add_symbol_score(task, rule, score)
      end
    end
  end

  rule.backend.get_token(task, rule, nil, hkey, tokens_cb, 'string')
end

local function spf_reputation_idempotent(task, rule)
  if not rule.backend.set_token then
    return
  end -- Read only backend
  local sc = extract_task_score(task, rule)
  local spf_record = task:get_mempool():get_variable('spf_record')
  local spf_allow = task:has_symbol('R_SPF_ALLOW')

  if not spf_record or not spf_allow or not sc then
    return
  end

  local hkey = spf_reputation_key(spf_record)

  lua_util.debugm(N, task, 'set spf record %s -> %s = %s',
      spf_record, hkey, sc)
  rule.backend.set_token(task, rule, nil, hkey, sc)
end

local spf_selector = {
  config = {
    symbol = 'SPF_REP', -- symbol to be inserted
    split_symbols = true,
    lower_bound = 10, -- minimum number of messages to be scored
    min_score = nil,
    max_score = nil,
    outbound = true,
    inbound = true,
  },
  dependencies = { "R_SPF_ALLOW" },
  filter = spf_reputation_filter, -- used to get scores
  idempotent = spf_reputation_idempotent, -- used to set scores
}

-- Generic selector based on lua_selectors framework

-- Compiles `cfg.selector` into a closure; returns false on a missing or
-- invalid selector
local function generic_reputation_init(rule)
  local cfg = rule.selector.config

  if not cfg.selector then
    rspamd_logger.errx(rspamd_config, 'cannot configure generic rule: no selector specified')
    return false
  end

  local selector = lua_selectors.create_selector_closure(rspamd_config,
      cfg.selector, cfg.delimiter)

  if not selector then
    rspamd_logger.errx(rspamd_config, 'cannot configure generic rule: bad selector: %s',
        cfg.selector)
    return false
  end

  cfg.selector = selector -- Replace with closure

  if cfg.whitelist then
    cfg.whitelist = lua_maps.map_add('reputation',
        'generic_whitelist',
        'map',
        'Whitelisted selectors')
  end

  return true
end

-- Queries the backend for each value produced by the selector; every reply is
-- scored and inserted independently
local function generic_reputation_filter(task, rule)
  local cfg = rule.selector.config
  local selector_res = cfg.selector(task)

  local function tokens_cb(err, token, values)
    if values then
      local score = generic_reputation_calc(values, rule, 1.0, task)

      if math.abs(score) > 1e-3 then
        -- TODO: add description
        add_symbol_score(task, rule, score)
      end
    end
  end

  local function check_token(e)
    lua_util.debugm(N, task, 'check generic reputation (%s) %s',
        rule['symbol'], e)
    rule.backend.get_token(task, rule, nil, e, tokens_cb, 'string')
  end

  if selector_res then
    if type(selector_res) == 'table' then
      fun.each(check_token, selector_res)
    else
      check_token(selector_res)
    end
  end
end

-- Stores the final task score for each value produced by the selector
local function generic_reputation_idempotent(task, rule)
  if not rule.backend.set_token then
    return
  end -- Read only backend
  local sc = extract_task_score(task, rule)
  local cfg = rule.selector.config

  local selector_res = cfg.selector(task)
  if not selector_res then
    return
  end

  local function store_token(e)
    lua_util.debugm(N, task, 'set generic selector (%s) %s = %s',
        rule['symbol'], e, sc)
    rule.backend.set_token(task, rule, nil, e, sc)
  end

  if sc then
    if type(selector_res) == 'table' then
      fun.each(store_token, selector_res)
    else
      store_token(selector_res)
    end
  end
end

local generic_selector = {
  schema = ts.shape {
    lower_bound = ts.number + ts.string / tonumber,
    max_score = ts.number:is_optional(),
    min_score = ts.number:is_optional(),
    outbound = ts.boolean,
    inbound = ts.boolean,
    selector = ts.string,
    delimiter = ts.string,
    whitelist = ts.one_of(lua_maps.map_schema, lua_maps_exprs.schema):is_optional(),
  },
  config = {
    lower_bound = 10, -- minimum number of messages to be scored
    min_score = nil,
    max_score = nil,
    outbound = true,
    inbound = true,
    selector = nil,
    delimiter = ':',
    whitelist = nil
  },
  init = generic_reputation_init,
  filter = generic_reputation_filter, -- used to get scores
  idempotent = generic_reputation_idempotent -- used to set scores
}

local selectors = {
  ip = ip_selector,
  sender = ip_selector, -- Better name
  url = url_selector,
  dkim = dkim_selector,
  spf = spf_selector,
  generic = generic_selector
}

local function reputation_dns_init(rule, _, _, _)
  if not rule.backend.config.list then
    rspamd_logger.errx(rspamd_config, "rule %s with DNS backend has no `list` parameter defined",
        rule.symbol)
    return false
  end

  return true
end

-- Builds a backend key from `prefix` and `token` according to the backend
-- config: optional hashing (`hashed`, `hash_alg`, `hash_encoding`), optional
-- truncation (`hashlen`) and optional backend-wide `prefix`
local function gen_token_key(prefix, token, rule)
  local bcfg = rule.backend.config

  if prefix then
    token = prefix .. token
  end
  local res = token
  if bcfg.hashed then
    local hash_alg = bcfg.hash_alg or "blake2"
    local encoding = bcfg.hash_encoding or "base32"

    local h = hash.create_specific(hash_alg, res)
    if encoding == 'hex' then
      res = h:hex()
    elseif encoding == 'base64' then
      res = h:base64()
    else
      res = h:base32()
    end
  end

  if bcfg.hashlen then
    res = string.sub(res, 1, bcfg.hashlen)
  end

  if bcfg.prefix then
    res = bcfg.prefix .. res
  end

  return res
end

--[[
-- Generic interface for get and set tokens functions:
-- get_token(task, rule, prefix, token, continuation, token_type), where `continuation` is the following function:
--
-- function(err, token, values) ... end
-- `err`: string value for error (similar to redis or DNS callbacks)
-- `token`: string value of a token (the backend key or DNS name)
-- `values`: table `{ count, { score1, ..., scoreN } }` parsed from backend.
-- It is selector's duty to deal with missing, invalid or other values
--
-- set_token(task, rule, prefix, token, score, continuation_cb, token_type)
-- This function takes the score, stores it using whatever suitable format
-- and calls for continuation (if it is specified):
--
-- function(err, token) ... end
-- `err`: string value for error (similar to redis or DNS callbacks)
-- `token`: string value of a token
--]]

local function reputation_dns_get_token(task, rule, prefix, token, continuation_cb, token_type)
  -- local r = task:get_resolver()
  -- In DNS we never ever use prefix as prefix, we use it as a suffix!
  if token_type == 'ip' then
    token = table.concat(token:inversed_str_octets(), '.')
  end

  local key = gen_token_key(nil, token, rule)
  local dns_name

  if prefix then
    dns_name = string.format('%s.%s.%s', key, prefix,
        rule.backend.config.list)
  else
    dns_name = string.format('%s.%s', key, rule.backend.config.list)
  end

  local function dns_cb(_, _, results, err)
    if err and (err ~= 'requested record is not found' and
        err ~= 'no records with this name') then
      rspamd_logger.warnx(task, 'error looking up %s: %s', dns_name, err)
    end

    lua_util.debugm(N, task, 'DNS RESPONSE: label=%1 results=%2 err=%3 list=%4',
        dns_name, results, err, rule.backend.config.list)

    -- `results` is nil when nothing is found, so never index it directly
    local first = results and results[1]

    -- Now split tokens to list of values
    if first then
      -- Format: num_messages;sc1;sc2...scn
      -- Convert all to numbers excluding any possible non-numbers
      local dns_tokens = {}
      for _, e in ipairs(lua_util.rspamd_str_split(tostring(first), ";") or E) do
        local n = tonumber(e)
        if n then
          dns_tokens[#dns_tokens + 1] = n
        end
      end

      if #dns_tokens < 2 then
        rspamd_logger.warnx(task, 'cannot parse response for reputation token %s: %s',
            dns_name, first)
        continuation_cb(err or 'cannot parse response', dns_name, nil)
      else
        local cnt = table.remove(dns_tokens, 1)
        continuation_cb(nil, dns_name, { cnt, dns_tokens })
      end
    else
      rspamd_logger.messagex(task, 'invalid response for reputation token %s: %s',
          dns_name, first)
      continuation_cb(err or 'invalid response', dns_name, nil)
    end
  end

  task:get_resolver():resolve_a({
    task = task,
    name = dns_name,
    callback = dns_cb,
    forced = true,
  })
end

-- Loads redis settings for the rule and registers get/update scripts
-- generated from the configured buckets
local function reputation_redis_init(rule, cfg, ev_base, worker)
  local our_redis_params = lua_redis.try_load_redis_servers(rule.backend.config, rspamd_config,
      true)
  if not our_redis_params then
    our_redis_params = redis_params
  end
  if not our_redis_params then
    rspamd_logger.errx(rspamd_config, 'cannot init redis for reputation rule: %s',
        rule)
    return false
  end
  -- Init scripts for buckets
  -- Redis script to extract data from Redis buckets
  -- KEYS[1] - key to extract
  -- Value returned - table of scores as a strings vector + number of scores
  -- A bucket that is missing in the hash is treated as 0
  local redis_get_script_tpl = [[
  local cnt = redis.call('HGET', KEYS[1], 'n')
  local results = {}
  if cnt then
    {% for w in windows %}
    local sc = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}')) or 0
    table.insert(results, tostring(sc * {= w.mult =}))
    {% endfor %}
  else
    {% for w in windows %}
    table.insert(results, '0')
    {% endfor %}
  end

  return {cnt or 0, results}
  ]]

  local get_script = lua_util.jinja_template(redis_get_script_tpl,
      { windows = rule.backend.config.buckets })
  rspamd_logger.debugm(N, rspamd_config, 'added extraction script %s', get_script)
  rule.backend.script_get = lua_redis.add_redis_script(get_script, our_redis_params)

  -- Redis script to update Redis buckets
  -- KEYS[1] - key to update
  -- KEYS[2] - current time in milliseconds
  -- KEYS[3] - message score
  -- KEYS[4] - expire for a bucket
  -- Value returned - table of scores as a strings vector
  -- A bucket that is missing in the hash starts from the current score
  local redis_adaptive_emea_script_tpl = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local score = tonumber(KEYS[3])
  local now = tonumber(KEYS[2])
  local scores = {}

  if last then
    {% for w in windows %}
    local last_value = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}')) or score
    local window = {= w.time =}
    -- Adjust alpha
    local time_diff = now - last
    if time_diff < 0 then
      time_diff = 0
    end
    local alpha = 1.0 - math.exp((-time_diff) / (1000 * window))
    local nscore = alpha * score + (1.0 - alpha) * last_value
    table.insert(scores, tostring(nscore * {= w.mult =}))
    {% endfor %}
  else
    {% for w in windows %}
    table.insert(scores, tostring(score * {= w.mult =}))
    {% endfor %}
  end

  local i = 1
  {% for w in windows %}
  redis.call('HSET', KEYS[1], 'v' .. '{= w.name =}', scores[i])
  i = i + 1
  {% endfor %}
  redis.call('HSET', KEYS[1], 'l', now)
  redis.call('HINCRBY', KEYS[1], 'n', 1)
  redis.call('EXPIRE', KEYS[1], tonumber(KEYS[4]))

  return scores
]]

  local set_script = lua_util.jinja_template(redis_adaptive_emea_script_tpl,
      { windows = rule.backend.config.buckets })
  rspamd_logger.debugm(N, rspamd_config, 'added emea update script %s', set_script)
  rule.backend.script_set = lua_redis.add_redis_script(set_script, our_redis_params)

  return true
end

-- Runs the extraction script for the token key and passes the reply to
-- `continuation_cb(err, key, values)`
local function reputation_redis_get_token(task, rule, prefix, token, continuation_cb, token_type)
  if token_type == 'ip' then
    token = tostring(token)
  end
  local key = gen_token_key(prefix, token, rule)

  local function redis_get_cb(err, data)
    if data then
      if type(data) == 'table' then
        lua_util.debugm(N, task, 'rule %s - got values for key %s -> %s',
            rule['symbol'], key, data)
        continuation_cb(nil, key, data)
      else
        rspamd_logger.errx(task, 'rule %s - invalid type while getting reputation keys %s: %s',
            rule['symbol'], key, type(data))
        continuation_cb("invalid type", key, nil)
      end

    else
      local reason = err or "unknown error"
      rspamd_logger.errx(task, 'rule %s - got error while getting reputation keys %s: %s',
          rule['symbol'], key, reason)
      continuation_cb(reason, key, nil)
    end
  end

  local ret = lua_redis.exec_redis_script(rule.backend.script_get,
      { task = task, is_write = false },
      redis_get_cb,
      { key })
  if not ret then
    rspamd_logger.errx(task, 'cannot make redis request to check results')
  end
end

-- Runs the update script for the token key with score `sc`; the optional
-- `continuation_cb(err, key)` is called when the script finishes
local function reputation_redis_set_token(task, rule, prefix, token, sc, continuation_cb, token_type)
  if token_type == 'ip' then
    token = tostring(token)
  end
  local key = gen_token_key(prefix, token, rule)

  local function redis_set_cb(err, data)
    if err then
      rspamd_logger.errx(task, 'rule %s - got error while setting reputation keys %s: %s',
          rule['symbol'], key, err)
    end
    if continuation_cb then
      -- `err` is nil on success
      continuation_cb(err or nil, key)
    end
  end

  lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s',
      rule['symbol'], key, sc)
  local ret = lua_redis.exec_redis_script(rule.backend.script_set,
      { task = task, is_write = true },
      redis_set_cb,
      { key, tostring(os.time() * 1000),
        tostring(sc),
        tostring(rule.backend.config.expiry) })
  if not ret then
    rspamd_logger.errx(task, 'got error while connecting to redis')
  end
end

--[[ Backends are responsible for getting reputation tokens
  -- Common config options:
  -- `hashed`: if `true` then apply hash function to the key
  -- `hash_alg`: use specific hash type (`blake2` by default)
  -- `hashlen`: strip the key to this amount of characters (no strip by default)
  -- `hash_encoding`: use specific hash encoding (base32 by default)
--]]
local backends = {
  redis = {
    schema = lua_redis.enrich_schema({
      prefix = ts.string:is_optional(),
      expiry = (ts.number + ts.string / lua_util.parse_time_interval):is_optional(),
      buckets = ts.array_of(ts.shape {
        time = ts.number + ts.string / lua_util.parse_time_interval,
        name = ts.string,
        mult = ts.number + ts.string / tonumber
      })            :is_optional(),
    }),
    config = {
      expiry = default_expiry,
      prefix = default_prefix,
      buckets = {
        {
          time = 60 * 60 * 24 * 30,
          name = '1m',
          mult = 1.0,
        }
      }, -- What buckets should be used, a single 1 month bucket by default
    },
    init = reputation_redis_init,
    get_token = reputation_redis_get_token,
    set_token = reputation_redis_set_token,
  },
  dns = {
    schema = ts.shape {
      list = ts.string,
    },
    config = {
      -- list = rep.example.com
    },
    get_token = reputation_dns_get_token,
    -- No set token for DNS
    init = reputation_dns_init,
  }
}

-- Checks the direction restrictions (`inbound`/`outbound`) of the selector
-- and the rule's whitelist map for this task
local function is_rule_applicable(task, rule)
  local ip = task:get_from_ip()
  local sel_cfg = rule.selector.config

  if not (sel_cfg.outbound and sel_cfg.inbound) then
    -- Authenticated users and local addresses are treated as outbound mail
    local is_outbound = task:get_user() or (ip and ip:is_local())

    if sel_cfg.outbound then
      if not is_outbound then
        return false
      end
    elseif sel_cfg.inbound then
      if is_outbound then
        return false
      end
    end
  end

  if rule.config.whitelist_map then
    if rule.config.whitelist_map:process(task) then
      return false
    end
  end

  return true
end

local function reputation_filter_cb(task, rule)
  if (is_rule_applicable(task, rule)) then
    rule.selector.filter(task, rule, rule.backend)
  end
end

local function reputation_postfilter_cb(task, rule)
  if (is_rule_applicable(task, rule)) then
    rule.selector.postfilter(task, rule, rule.backend)
  end
end

local function reputation_idempotent_cb(task, rule)
  if (is_rule_applicable(task, rule)) then
    rule.selector.idempotent(task, rule, rule.backend)
  end
end

-- Wraps `cb` into a symbol callback that does nothing until the rule is enabled
local function callback_gen(cb, rule)
  return function(task)
    if rule.enabled then
      cb(task, rule)
    end
  end
end

local function parse_rule(name, tbl)
  local sel_type, sel_conf = fun.head(tbl.selector)
  local selector = selectors[sel_type]

  if not selector then
    rspamd_logger.errx(rspamd_config, "unknown selector defined for rule %s: %s", name,
        sel_type)
    return false
  end

  local bk_type, bk_conf = fun.head(tbl.backend)

  local backend = backends[bk_type]
  if not backend then
    rspamd_logger.errx(rspamd_config, "unknown backend defined for rule %s: %s", name,
        tbl.backend.type)
    return false
  end
  -- Allow config override
  local rule = {
    selector = lua_util.shallowcopy(selector),
    backend = lua_util.shallowcopy(backend),
    config = {}
  }

  -- Override default config params
  rule.backend.config = lua_util.override_defaults(rule.backend.config, bk_conf)
  if backend.schema then
    local checked, schema_err = backend.schema:transform(rule.backend.config)
    if not checked then
      rspamd_logger.errx(rspamd_config, "cannot parse backend config for %s: %s",
          sel_type, schema_err)

      return false
    end

    rule.backend.config = checked
  end

  rule.selector.config = lua_util.override_defaults(rule.selector.config, sel_conf)
  if selector.schema then
    local checked, schema_err = selector.schema:transform(rule.selector.config)

    if not checked then
      rspamd_logger.errx(rspamd_config, "cannot parse selector config for %s: %s (%s)",
          sel_type,
          schema_err, sel_conf)
      return
    end

    rule.selector.config = checked
  end
  -- Generic options
  tbl.selector = nil
  tbl.backend = nil
  rule.config = lua_util.override_defaults(rule.config, tbl)

  if rule.config.whitelist then
    if lua_maps_exprs.schema(rule.config.whitelist) then
      rule.config.whitelist_map = lua_maps_exprs.create(rspamd_config,
          rule.config.whitelist, N)
    elseif lua_maps.map_schema(rule.config.whitelist) then
      local map = lua_maps.map_add_from_ucl(rule.config.whitelist,
          'radix',
          sel_type .. ' reputation whitelist')

      if not map then
        rspamd_logger.errx(rspamd_config, "cannot parse whitelist map config for %s: (%s)",
            sel_type,
            rule.config.whitelist)
        return
      end

      rule.config.whitelist_map = {
        process = function(_, task)
          -- Hack: we assume that it is an ip whitelist :(
          local ip = task:get_from_ip()

          if ip and map:get_key(ip) then
            return true
          end
          return false
        end
      }
    else
      rspamd_logger.errx(rspamd_config, "cannot parse whitelist map config for %s: (%s)",
          sel_type,
          rule.config.whitelist)
      return false
    end
  end

  local symbol = rule.selector.config.symbol or name
  if tbl.symbol then
    symbol = tbl.symbol
  end

  rule.symbol = symbol
  rule.enabled = true
  if rule.selector.init then
    rule.enabled = false
  end
  if rule.backend.init then
    rule.enabled = false
  end
  -- Perform additional initialization if needed
  rspamd_config:add_on_load(function(cfg, ev_base, worker)
    if rule.selector.init then
      if not rule.selector.init(rule, cfg, ev_base, worker) then
        rule.enabled = false
        rspamd_logger.errx(rspamd_config, 'Cannot init selector %s (backend %s) for symbol %s',
            sel_type, bk_type, rule.symbol)
      else
        rule.enabled = true
      end
    end
    if rule.backend.init then
      if not rule.backend.init(rule, cfg, ev_base, worker) then
        rule.enabled = false
        rspamd_logger.errx(rspamd_config, 'Cannot init backend (%s) for rule %s for symbol %s',
            bk_type, sel_type, rule.symbol)
      else
        rule.enabled = true
      end
    end

    if rule.enabled then
      rspamd_logger.infox(rspamd_config, 'Enable %s (%s backend) rule for symbol %s (split symbols: %s)',
          sel_type, bk_type, rule.symbol,
          rule.selector.config.split_symbols)
    end
  end)

  -- We now generate symbol for checking
  local rule_type = 'normal'
  if rule.selector.config.split_symbols then
    rule_type = 'callback'
  end

  local id = rspamd_config:register_symbol {
    name = rule.symbol,
    type = rule_type,
    callback = callback_gen(reputation_filter_cb, rule),
+ augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, + } + + if rule.selector.config.split_symbols then + rspamd_config:register_symbol { + name = rule.symbol .. '_HAM', + type = 'virtual', + parent = id, + } + rspamd_config:register_symbol { + name = rule.symbol .. '_SPAM', + type = 'virtual', + parent = id, + } + end + + if rule.selector.dependencies then + fun.each(function(d) + rspamd_config:register_dependency(symbol, d) + end, rule.selector.dependencies) + end + + if rule.selector.postfilter then + -- Also register a postfilter + rspamd_config:register_symbol { + name = rule.symbol .. '_POST', + type = 'postfilter', + flags = 'nostat,explicit_disable,ignore_passthrough', + callback = callback_gen(reputation_postfilter_cb, rule), + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, + } + end + + if rule.selector.idempotent then + -- Has also idempotent component (e.g. saving data to the backend) + rspamd_config:register_symbol { + name = rule.symbol .. '_IDEMPOTENT', + type = 'idempotent', + flags = 'explicit_disable,ignore_passthrough', + callback = callback_gen(reputation_idempotent_cb, rule), + augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) }, + } + end + + return true +end + +redis_params = lua_redis.parse_redis_server('reputation') +local opts = rspamd_config:get_all_opt("reputation") + +-- Initialization part +if not (opts and type(opts) == 'table') then + rspamd_logger.infox(rspamd_config, 'Module is not configured, disabling it') + return +end + +if opts['rules'] then + for k, v in pairs(opts['rules']) do + if not ((v or E).selector) then + rspamd_logger.errx(rspamd_config, "no selector defined for rule %s", k) + lua_util.config_utils.push_config_error(N, "no selector defined for rule: " .. k) + else + if not parse_rule(k, v) then + lua_util.config_utils.push_config_error(N, "reputation rule is misconfigured: " .. 
k) + end + end + end +else + lua_util.disable_module(N, "config") +end |