summaryrefslogtreecommitdiffstats
path: root/src/plugins/lua/reputation.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/plugins/lua/reputation.lua')
-rw-r--r--src/plugins/lua/reputation.lua1390
1 files changed, 1390 insertions, 0 deletions
diff --git a/src/plugins/lua/reputation.lua b/src/plugins/lua/reputation.lua
new file mode 100644
index 0000000..a3af26c
--- /dev/null
+++ b/src/plugins/lua/reputation.lua
@@ -0,0 +1,1390 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+if confighelp then
+ return
+end
+
+-- A generic plugin for reputation handling
+
+local E = {}
+local N = 'reputation'
+
+local rspamd_logger = require "rspamd_logger"
+local rspamd_util = require "rspamd_util"
+local lua_util = require "lua_util"
+local lua_maps = require "lua_maps"
+local lua_maps_exprs = require "lua_maps_expressions"
+local hash = require 'rspamd_cryptobox_hash'
+local lua_redis = require "lua_redis"
+local fun = require "fun"
+local lua_selectors = require "lua_selectors"
+local ts = require("tableshape").types
+
+local redis_params = nil
+local default_expiry = 864000 -- 10 day by default
+local default_prefix = 'RR:' -- Rspamd Reputation
+
+local tanh = math.tanh or rspamd_util.tanh
+
+-- Get reputation from ham/spam/probable hits
+local function generic_reputation_calc(token, rule, mult, task)
+ local cfg = rule.selector.config or E
+ local reject_threshold = task:get_metric_score()[2] or 10.0
+
+ if cfg.score_calc_func then
+ return cfg.score_calc_func(rule, token, mult)
+ end
+
+ if tonumber(token[1]) < cfg.lower_bound then
+ lua_util.debugm(N, task, "not enough matches %s < %s for rule %s",
+ token[1], cfg.lower_bound, rule.symbol)
+ return 0
+ end
+
+ -- Get average score
+ local avg_score = fun.foldl(function(acc, v)
+ return acc + v
+ end, 0.0, fun.map(tonumber, token[2])) / #token[2]
+
+ -- Apply function tanh(x / reject_score * atanh(0.95) - atanh(0.5))
+ -- 1.83178 0.5493
+ local score = tanh(avg_score / reject_threshold * 1.83178 - 0.5493) * mult
+ lua_util.debugm(N, task, "got generic average score %s (reject threshold=%s, mult=%s) -> %s for rule %s",
+ avg_score, reject_threshold, mult, score, rule.symbol)
+ return score
+end
+
+local function add_symbol_score(task, rule, mult, params)
+ if not params then
+ params = { tostring(mult) }
+ end
+
+ if rule.selector.config.split_symbols then
+ local sym_spam = rule.symbol .. '_SPAM'
+ local sym_ham = rule.symbol .. '_HAM'
+ if not rule.static_symbols then
+ rule.static_symbols = {}
+ rule.static_symbols.ham = rspamd_config:get_symbol(sym_ham)
+ rule.static_symbols.spam = rspamd_config:get_symbol(sym_spam)
+ end
+ if mult >= 0 then
+ task:insert_result(sym_spam, mult, params)
+ else
+ -- Avoid multiplication of negative the `mult` by negative static score of the
+ -- ham symbol
+ if rule.static_symbols.ham and rule.static_symbols.ham.score then
+ if rule.static_symbols.ham.score < 0 then
+ mult = math.abs(mult)
+ end
+ end
+ task:insert_result(sym_ham, mult, params)
+ end
+ else
+ task:insert_result(rule.symbol, mult, params)
+ end
+end
+
+local function sub_symbol_score(task, rule, score)
+ local function sym_score(sym)
+ local s = task:get_symbol(sym)[1]
+ return s.score
+ end
+ if rule.selector.config.split_symbols then
+ local spam_sym = rule.symbol .. '_SPAM'
+ local ham_sym = rule.symbol .. '_HAM'
+
+ if task:has_symbol(spam_sym) then
+ score = score - sym_score(spam_sym)
+ elseif task:has_symbol(ham_sym) then
+ score = score - sym_score(ham_sym)
+ end
+ else
+ if task:has_symbol(rule.symbol) then
+ score = score - sym_score(rule.symbol)
+ end
+ end
+
+ return score
+end
+
+-- Extracts task score and subtracts score of the rule itself
+local function extract_task_score(task, rule)
+ local lua_verdict = require "lua_verdict"
+ local verdict, score = lua_verdict.get_specific_verdict(N, task)
+
+ if not score or verdict == 'passthrough' then
+ return nil
+ end
+
+ return sub_symbol_score(task, rule, score)
+end
+
+-- DKIM Selector functions
+local gr
+local function gen_dkim_queries(task, rule)
+ local dkim_trace = (task:get_symbol('DKIM_TRACE') or E)[1]
+ local lpeg = require 'lpeg'
+ local ret = {}
+
+ if not gr then
+ local semicolon = lpeg.P(':')
+ local domain = lpeg.C((1 - semicolon) ^ 1)
+ local res = lpeg.S '+-?~'
+
+ local function res_to_label(ch)
+ if ch == '+' then
+ return 'a'
+ elseif ch == '-' then
+ return 'r'
+ end
+
+ return 'u'
+ end
+
+ gr = domain * semicolon * (lpeg.C(res ^ 1) / res_to_label)
+ end
+
+ if dkim_trace and dkim_trace.options then
+ for _, opt in ipairs(dkim_trace.options) do
+ local dom, res = lpeg.match(gr, opt)
+
+ if dom and res then
+ local tld = rspamd_util.get_tld(dom)
+ ret[tld] = res
+ end
+ end
+ end
+
+ return ret
+end
+
+local function dkim_reputation_filter(task, rule)
+ local requests = gen_dkim_queries(task, rule)
+ local results = {}
+ local dkim_tlds = lua_util.keys(requests)
+ local requests_left = #dkim_tlds
+ local rep_accepted = 0.0
+
+ lua_util.debugm(N, task, 'dkim reputation tokens: %s', requests)
+
+ local function tokens_cb(err, token, values)
+ requests_left = requests_left - 1
+
+ if values then
+ results[token] = values
+ end
+
+ if requests_left == 0 then
+ for k, v in pairs(results) do
+ -- `k` in results is a prefixed and suffixed tld, so we need to look through
+ -- all requests to find any request with the matching tld
+ local sel_tld
+ for _, tld in ipairs(dkim_tlds) do
+ if k:find(tld, 1, true) then
+ sel_tld = tld
+ break
+ end
+ end
+
+ if sel_tld and requests[sel_tld] then
+ if requests[sel_tld] == 'a' then
+ rep_accepted = rep_accepted + generic_reputation_calc(v, rule, 1.0, task)
+ end
+ else
+ rspamd_logger.warnx(task, "cannot find the requested tld for a request: %s (%s tlds noticed)",
+ k, dkim_tlds)
+ end
+ end
+
+ -- Set local reputation symbol
+ local rep_accepted_abs = math.abs(rep_accepted or 0)
+ lua_util.debugm(N, task, "dkim reputation accepted: %s",
+ rep_accepted_abs)
+ if rep_accepted_abs then
+ local final_rep = rep_accepted
+ if rep_accepted > 1.0 then
+ final_rep = 1.0
+ end
+ if rep_accepted < -1.0 then
+ final_rep = -1.0
+ end
+ add_symbol_score(task, rule, final_rep)
+
+ -- Store results for future DKIM results adjustments
+ task:get_mempool():set_variable("dkim_reputation_accept", tostring(rep_accepted))
+ end
+ end
+ end
+
+ for dom, res in pairs(requests) do
+ -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs
+ local query = string.format('%s.%s', dom, res)
+ rule.backend.get_token(task, rule, nil, query, tokens_cb, 'string')
+ end
+end
+
+local function dkim_reputation_idempotent(task, rule)
+ local requests = gen_dkim_queries(task, rule)
+ local sc = extract_task_score(task, rule)
+
+ if sc then
+ for dom, res in pairs(requests) do
+ -- tld + "." + check_result, e.g. example.com.+ - reputation for valid sigs
+ local query = string.format('%s.%s', dom, res)
+ rule.backend.set_token(task, rule, nil, query, sc)
+ end
+ end
+end
+
+local function dkim_reputation_postfilter(task, rule)
+ local sym_accepted = (task:get_symbol('R_DKIM_ALLOW') or E)[1]
+ local accept_adjustment = task:get_mempool():get_variable("dkim_reputation_accept")
+ local cfg = rule.selector.config or E
+
+ if sym_accepted and sym_accepted.score and
+ accept_adjustment and type(cfg.max_accept_adjustment) == 'number' then
+ local final_adjustment = cfg.max_accept_adjustment *
+ rspamd_util.tanh(tonumber(accept_adjustment) or 0)
+ lua_util.debugm(N, task, "adjust DKIM_ALLOW: " ..
+ "cfg.max_accept_adjustment=%s accept_adjustment=%s final_adjustment=%s sym_accepted.score=%s",
+ cfg.max_accept_adjustment, accept_adjustment, final_adjustment,
+ sym_accepted.score)
+
+ task:adjust_result('R_DKIM_ALLOW', sym_accepted.score + final_adjustment)
+ end
+end
+
+local dkim_selector = {
+ config = {
+ symbol = 'DKIM_SCORE', -- symbol to be inserted
+ lower_bound = 10, -- minimum number of messages to be scored
+ min_score = nil,
+ max_score = nil,
+ outbound = true,
+ inbound = true,
+ max_accept_adjustment = 2.0, -- How to adjust accepted DKIM score
+ },
+ dependencies = { "DKIM_TRACE" },
+ filter = dkim_reputation_filter, -- used to get scores
+ postfilter = dkim_reputation_postfilter, -- used to adjust DKIM scores
+ idempotent = dkim_reputation_idempotent, -- used to set scores
+}
+
+-- URL Selector functions
+
+local function gen_url_queries(task, rule)
+ local domains = {}
+
+ fun.each(function(u)
+ if u:is_redirected() then
+ local redir = u:get_redirected() -- get the original url
+ local redir_tld = redir:get_tld()
+ if domains[redir_tld] then
+ domains[redir_tld] = domains[redir_tld] - 1
+ end
+ end
+ local dom = u:get_tld()
+ if not domains[dom] then
+ domains[dom] = 1
+ else
+ domains[dom] = domains[dom] + 1
+ end
+ end, fun.filter(function(u)
+ return not u:is_html_displayed()
+ end,
+ task:get_urls(true)))
+
+ local results = {}
+ for k, v in lua_util.spairs(domains,
+ function(t, a, b)
+ return t[a] > t[b]
+ end, rule.selector.config.max_urls) do
+ if v > 0 then
+ table.insert(results, { k, v })
+ end
+ end
+
+ return results
+end
+
+local function url_reputation_filter(task, rule)
+ local requests = gen_url_queries(task, rule)
+ local url_keys = lua_util.keys(requests)
+ local requests_left = #url_keys
+ local results = {}
+
+ local function indexed_tokens_cb(err, index, values)
+ requests_left = requests_left - 1
+
+ if values then
+ results[index] = values
+ end
+
+ if requests_left == 0 then
+ -- Check the url with maximum hits
+ local mhits = 0
+
+ for i, res in ipairs(results) do
+ local req = requests[i]
+ if req then
+ local hits = tonumber(res[1])
+ if hits > mhits then
+ mhits = hits
+ end
+ else
+ rspamd_logger.warnx(task, "cannot find the requested response for a request: %s (%s requests noticed)",
+ i, #requests)
+ end
+ end
+
+ if mhits > 0 then
+ local score = 0
+ for i, res in pairs(results) do
+ local req = requests[i]
+ if req then
+ local url_score = generic_reputation_calc(res, rule,
+ req[2] / mhits, task)
+ lua_util.debugm(N, task, "score for url %s is %s, score=%s", req[1], url_score, score)
+ score = score + url_score
+ end
+ end
+
+ if math.abs(score) > 1e-3 then
+ -- TODO: add description
+ add_symbol_score(task, rule, score)
+ end
+ end
+ end
+ end
+
+ for i, req in ipairs(requests) do
+ local function tokens_cb(err, token, values)
+ indexed_tokens_cb(err, i, values)
+ end
+
+ rule.backend.get_token(task, rule, nil, req[1], tokens_cb, 'string')
+ end
+end
+
+local function url_reputation_idempotent(task, rule)
+ local requests = gen_url_queries(task, rule)
+ local sc = extract_task_score(task, rule)
+
+ if sc then
+ for _, tld in ipairs(requests) do
+ rule.backend.set_token(task, rule, nil, tld[1], sc)
+ end
+ end
+end
+
+local url_selector = {
+ config = {
+ symbol = 'URL_SCORE', -- symbol to be inserted
+ lower_bound = 10, -- minimum number of messages to be scored
+ min_score = nil,
+ max_score = nil,
+ max_urls = 10,
+ check_from = true,
+ outbound = true,
+ inbound = true,
+ },
+ filter = url_reputation_filter, -- used to get scores
+ idempotent = url_reputation_idempotent -- used to set scores
+}
+-- IP Selector functions
+
+local function ip_reputation_init(rule)
+ local cfg = rule.selector.config
+
+ if cfg.asn_cc_whitelist then
+ cfg.asn_cc_whitelist = lua_maps.map_add('reputation',
+ 'asn_cc_whitelist',
+ 'map',
+ 'IP score whitelisted ASNs/countries')
+ end
+
+ return true
+end
+
+local function ip_reputation_filter(task, rule)
+
+ local ip = task:get_from_ip()
+
+ if not ip or not ip:is_valid() then
+ return
+ end
+ if lua_util.is_rspamc_or_controller(task) then
+ return
+ end
+
+ local cfg = rule.selector.config
+
+ if ip:get_version() == 4 and cfg.ipv4_mask then
+ ip = ip:apply_mask(cfg.ipv4_mask)
+ elseif cfg.ipv6_mask then
+ ip = ip:apply_mask(cfg.ipv6_mask)
+ end
+
+ local pool = task:get_mempool()
+ local asn = pool:get_variable("asn")
+ local country = pool:get_variable("country")
+
+ if country and cfg.asn_cc_whitelist then
+ if cfg.asn_cc_whitelist:get_key(country) then
+ return
+ end
+ if asn and cfg.asn_cc_whitelist:get_key(asn) then
+ return
+ end
+ end
+
+ -- These variables are used to define if we have some specific token
+ local has_asn = not asn
+ local has_country = not country
+ local has_ip = false
+
+ local asn_stats, country_stats, ip_stats
+
+ local function ipstats_check()
+ local score = 0.0
+ local description_t = {}
+
+ if asn_stats then
+ local asn_score = generic_reputation_calc(asn_stats, rule, cfg.scores.asn, task)
+ score = score + asn_score
+ table.insert(description_t, string.format('asn: %s(%.2f)',
+ asn, asn_score))
+ end
+ if country_stats then
+ local country_score = generic_reputation_calc(country_stats, rule,
+ cfg.scores.country, task)
+ score = score + country_score
+ table.insert(description_t, string.format('country: %s(%.2f)',
+ country, country_score))
+ end
+ if ip_stats then
+ local ip_score = generic_reputation_calc(ip_stats, rule, cfg.scores.ip,
+ task)
+ score = score + ip_score
+ table.insert(description_t, string.format('ip: %s(%.2f)',
+ tostring(ip), ip_score))
+ end
+
+ if math.abs(score) > 0.001 then
+ add_symbol_score(task, rule, score, table.concat(description_t, ', '))
+ end
+ end
+
+ local function gen_token_callback(what)
+ return function(err, _, values)
+ if not err and values then
+ if what == 'asn' then
+ has_asn = true
+ asn_stats = values
+ elseif what == 'country' then
+ has_country = true
+ country_stats = values
+ elseif what == 'ip' then
+ has_ip = true
+ ip_stats = values
+ end
+ else
+ if what == 'asn' then
+ has_asn = true
+ elseif what == 'country' then
+ has_country = true
+ elseif what == 'ip' then
+ has_ip = true
+ end
+ end
+
+ if has_asn and has_country and has_ip then
+ -- Check reputation
+ ipstats_check()
+ end
+ end
+ end
+
+ if asn then
+ rule.backend.get_token(task, rule, cfg.asn_prefix, asn,
+ gen_token_callback('asn'), 'string')
+ end
+ if country then
+ rule.backend.get_token(task, rule, cfg.country_prefix, country,
+ gen_token_callback('country'), 'string')
+ end
+
+ rule.backend.get_token(task, rule, cfg.ip_prefix, ip,
+ gen_token_callback('ip'), 'ip')
+end
+
+-- Used to set scores
+local function ip_reputation_idempotent(task, rule)
+ if not rule.backend.set_token then
+ return
+ end -- Read only backend
+ local ip = task:get_from_ip()
+ local cfg = rule.selector.config
+
+ if not ip or not ip:is_valid() then
+ return
+ end
+
+ if lua_util.is_rspamc_or_controller(task) then
+ return
+ end
+
+ if ip:get_version() == 4 and cfg.ipv4_mask then
+ ip = ip:apply_mask(cfg.ipv4_mask)
+ elseif cfg.ipv6_mask then
+ ip = ip:apply_mask(cfg.ipv6_mask)
+ end
+
+ local pool = task:get_mempool()
+ local asn = pool:get_variable("asn")
+ local country = pool:get_variable("country")
+
+ if country and cfg.asn_cc_whitelist then
+ if cfg.asn_cc_whitelist:get_key(country) then
+ return
+ end
+ if asn and cfg.asn_cc_whitelist:get_key(asn) then
+ return
+ end
+ end
+ local sc = extract_task_score(task, rule)
+ if sc then
+ if asn then
+ rule.backend.set_token(task, rule, cfg.asn_prefix, asn, sc, nil, 'string')
+ end
+ if country then
+ rule.backend.set_token(task, rule, cfg.country_prefix, country, sc, nil, 'string')
+ end
+
+ rule.backend.set_token(task, rule, cfg.ip_prefix, ip, sc, nil, 'ip')
+ end
+end
+
+-- Selectors are used to extract reputation tokens
+local ip_selector = {
+ config = {
+ scores = { -- how each component is evaluated
+ ['asn'] = 0.4,
+ ['country'] = 0.01,
+ ['ip'] = 1.0
+ },
+ symbol = 'SENDER_REP', -- symbol to be inserted
+ split_symbols = true,
+ asn_prefix = 'a:', -- prefix for ASN hashes
+ country_prefix = 'c:', -- prefix for country hashes
+ ip_prefix = 'i:',
+ lower_bound = 10, -- minimum number of messages to be scored
+ min_score = nil,
+ max_score = nil,
+ score_divisor = 1,
+ outbound = false,
+ inbound = true,
+ ipv4_mask = 32, -- Mask bits for ipv4
+ ipv6_mask = 64, -- Mask bits for ipv6
+ },
+ --dependencies = {"ASN"}, -- ASN is a prefilter now...
+ init = ip_reputation_init,
+ filter = ip_reputation_filter, -- used to get scores
+ idempotent = ip_reputation_idempotent, -- used to set scores
+}
+
+-- SPF Selector functions
+
+local function spf_reputation_filter(task, rule)
+ local spf_record = task:get_mempool():get_variable('spf_record')
+ local spf_allow = task:has_symbol('R_SPF_ALLOW')
+
+ -- Don't care about bad/missing spf
+ if not spf_record or not spf_allow then
+ return
+ end
+
+ local cr = require "rspamd_cryptobox_hash"
+ local hkey = cr.create(spf_record):base32():sub(1, 32)
+
+ lua_util.debugm(N, task, 'check spf record %s -> %s', spf_record, hkey)
+
+ local function tokens_cb(err, token, values)
+ if values then
+ local score = generic_reputation_calc(values, rule, 1.0, task)
+
+ if math.abs(score) > 1e-3 then
+ -- TODO: add description
+ add_symbol_score(task, rule, score)
+ end
+ end
+ end
+
+ rule.backend.get_token(task, rule, nil, hkey, tokens_cb, 'string')
+end
+
+local function spf_reputation_idempotent(task, rule)
+ local sc = extract_task_score(task, rule)
+ local spf_record = task:get_mempool():get_variable('spf_record')
+ local spf_allow = task:has_symbol('R_SPF_ALLOW')
+
+ if not spf_record or not spf_allow or not sc then
+ return
+ end
+
+ local cr = require "rspamd_cryptobox_hash"
+ local hkey = cr.create(spf_record):base32():sub(1, 32)
+
+ lua_util.debugm(N, task, 'set spf record %s -> %s = %s',
+ spf_record, hkey, sc)
+ rule.backend.set_token(task, rule, nil, hkey, sc)
+end
+
+local spf_selector = {
+ config = {
+ symbol = 'SPF_REP', -- symbol to be inserted
+ split_symbols = true,
+ lower_bound = 10, -- minimum number of messages to be scored
+ min_score = nil,
+ max_score = nil,
+ outbound = true,
+ inbound = true,
+ },
+ dependencies = { "R_SPF_ALLOW" },
+ filter = spf_reputation_filter, -- used to get scores
+ idempotent = spf_reputation_idempotent, -- used to set scores
+}
+
+-- Generic selector based on lua_selectors framework
+
+local function generic_reputation_init(rule)
+ local cfg = rule.selector.config
+
+ if not cfg.selector then
+ rspamd_logger.errx(rspamd_config, 'cannot configure generic rule: no selector specified')
+ return false
+ end
+
+ local selector = lua_selectors.create_selector_closure(rspamd_config,
+ cfg.selector, cfg.delimiter)
+
+ if not selector then
+ rspamd_logger.errx(rspamd_config, 'cannot configure generic rule: bad selector: %s',
+ cfg.selector)
+ return false
+ end
+
+ cfg.selector = selector -- Replace with closure
+
+ if cfg.whitelist then
+ cfg.whitelist = lua_maps.map_add('reputation',
+ 'generic_whitelist',
+ 'map',
+ 'Whitelisted selectors')
+ end
+
+ return true
+end
+
+local function generic_reputation_filter(task, rule)
+ local cfg = rule.selector.config
+ local selector_res = cfg.selector(task)
+
+ local function tokens_cb(err, token, values)
+ if values then
+ local score = generic_reputation_calc(values, rule, 1.0, task)
+
+ if math.abs(score) > 1e-3 then
+ -- TODO: add description
+ add_symbol_score(task, rule, score)
+ end
+ end
+ end
+
+ if selector_res then
+ if type(selector_res) == 'table' then
+ fun.each(function(e)
+ lua_util.debugm(N, task, 'check generic reputation (%s) %s',
+ rule['symbol'], e)
+ rule.backend.get_token(task, rule, nil, e, tokens_cb, 'string')
+ end, selector_res)
+ else
+ lua_util.debugm(N, task, 'check generic reputation (%s) %s',
+ rule['symbol'], selector_res)
+ rule.backend.get_token(task, rule, nil, selector_res, tokens_cb, 'string')
+ end
+ end
+end
+
+local function generic_reputation_idempotent(task, rule)
+ local sc = extract_task_score(task, rule)
+ local cfg = rule.selector.config
+
+ local selector_res = cfg.selector(task)
+ if not selector_res then
+ return
+ end
+
+ if sc then
+ if type(selector_res) == 'table' then
+ fun.each(function(e)
+ lua_util.debugm(N, task, 'set generic selector (%s) %s = %s',
+ rule['symbol'], e, sc)
+ rule.backend.set_token(task, rule, nil, e, sc)
+ end, selector_res)
+ else
+ lua_util.debugm(N, task, 'set generic selector (%s) %s = %s',
+ rule['symbol'], selector_res, sc)
+ rule.backend.set_token(task, rule, nil, selector_res, sc)
+ end
+ end
+end
+
+local generic_selector = {
+ schema = ts.shape {
+ lower_bound = ts.number + ts.string / tonumber,
+ max_score = ts.number:is_optional(),
+ min_score = ts.number:is_optional(),
+ outbound = ts.boolean,
+ inbound = ts.boolean,
+ selector = ts.string,
+ delimiter = ts.string,
+ whitelist = ts.one_of(lua_maps.map_schema, lua_maps_exprs.schema):is_optional(),
+ },
+ config = {
+ lower_bound = 10, -- minimum number of messages to be scored
+ min_score = nil,
+ max_score = nil,
+ outbound = true,
+ inbound = true,
+ selector = nil,
+ delimiter = ':',
+ whitelist = nil
+ },
+ init = generic_reputation_init,
+ filter = generic_reputation_filter, -- used to get scores
+ idempotent = generic_reputation_idempotent -- used to set scores
+}
+
+local selectors = {
+ ip = ip_selector,
+ sender = ip_selector, -- Better name
+ url = url_selector,
+ dkim = dkim_selector,
+ spf = spf_selector,
+ generic = generic_selector
+}
+
+local function reputation_dns_init(rule, _, _, _)
+ if not rule.backend.config.list then
+ rspamd_logger.errx(rspamd_config, "rule %s with DNS backend has no `list` parameter defined",
+ rule.symbol)
+ return false
+ end
+
+ return true
+end
+
+local function gen_token_key(prefix, token, rule)
+ if prefix then
+ token = prefix .. token
+ end
+ local res = token
+ if rule.backend.config.hashed then
+ local hash_alg = rule.backend.config.hash_alg or "blake2"
+ local encoding = "base32"
+
+ if rule.backend.config.hash_encoding then
+ encoding = rule.backend.config.hash_encoding
+ end
+
+ local h = hash.create_specific(hash_alg, res)
+ if encoding == 'hex' then
+ res = h:hex()
+ elseif encoding == 'base64' then
+ res = h:base64()
+ else
+ res = h:base32()
+ end
+ end
+
+ if rule.backend.config.hashlen then
+ res = string.sub(res, 1, rule.backend.config.hashlen)
+ end
+
+ if rule.backend.config.prefix then
+ res = rule.backend.config.prefix .. res
+ end
+
+ return res
+end
+
+--[[
+-- Generic interface for get and set tokens functions:
+-- get_token(task, rule, prefix, token, continuation, token_type), where `continuation` is the following function:
+--
+-- function(err, token, values) ... end
+-- `err`: string value for error (similar to redis or DNS callbacks)
+-- `token`: string value of a token
+-- `values`: table of key=number, parsed from backend. It is selector's duty
+-- to deal with missing, invalid or other values
+--
+-- set_token(task, rule, token, values, continuation_cb)
+-- This function takes values, encodes them using whatever suitable format
+-- and calls for continuation:
+--
+-- function(err, token) ... end
+-- `err`: string value for error (similar to redis or DNS callbacks)
+-- `token`: string value of a token
+--
+-- example of tokens: {'s': 0, 'h': 0, 'p': 1}
+--]]
+
+local function reputation_dns_get_token(task, rule, prefix, token, continuation_cb, token_type)
+ -- local r = task:get_resolver()
+ -- In DNS we never ever use prefix as prefix, we use if as a suffix!
+ if token_type == 'ip' then
+ token = table.concat(token:inversed_str_octets(), '.')
+ end
+
+ local key = gen_token_key(nil, token, rule)
+ local dns_name = key .. '.' .. rule.backend.config.list
+
+ if prefix then
+ dns_name = string.format('%s.%s.%s', key, prefix,
+ rule.backend.config.list)
+ else
+ dns_name = string.format('%s.%s', key, rule.backend.config.list)
+ end
+
+ local function dns_cb(_, _, results, err)
+ if err and (err ~= 'requested record is not found' and
+ err ~= 'no records with this name') then
+ rspamd_logger.warnx(task, 'error looking up %s: %s', dns_name, err)
+ end
+
+ lua_util.debugm(N, task, 'DNS RESPONSE: label=%1 results=%2 err=%3 list=%4',
+ dns_name, results, err, rule.backend.config.list)
+
+ -- Now split tokens to list of values
+ if results and results[1] then
+ -- Format: num_messages;sc1;sc2...scn
+ local dns_tokens = lua_util.rspamd_str_split(results[1], ";")
+ -- Convert all to numbers excluding any possible non-numbers
+ dns_tokens = fun.totable(fun.filter(function(e)
+ return type(e) == 'number'
+ end,
+ fun.map(function(e)
+ local n = tonumber(e)
+ if n then
+ return n
+ end
+ return "BAD"
+ end, dns_tokens)))
+
+ if #dns_tokens < 2 then
+ rspamd_logger.warnx(task, 'cannot parse response for reputation token %s: %s',
+ dns_name, results[1])
+ continuation_cb(results, dns_name, nil)
+ else
+ local cnt = table.remove(dns_tokens, 1)
+ continuation_cb(nil, dns_name, { cnt, dns_tokens })
+ end
+ else
+ rspamd_logger.messagex(task, 'invalid response for reputation token %s: %s',
+ dns_name, results[1])
+ continuation_cb(results, dns_name, nil)
+ end
+ end
+
+ task:get_resolver():resolve_a({
+ task = task,
+ name = dns_name,
+ callback = dns_cb,
+ forced = true,
+ })
+end
+
+local function reputation_redis_init(rule, cfg, ev_base, worker)
+ local our_redis_params = {}
+
+ our_redis_params = lua_redis.try_load_redis_servers(rule.backend.config, rspamd_config,
+ true)
+ if not our_redis_params then
+ our_redis_params = redis_params
+ end
+ if not our_redis_params then
+ rspamd_logger.errx(rspamd_config, 'cannot init redis for reputation rule: %s',
+ rule)
+ return false
+ end
+ -- Init scripts for buckets
+ -- Redis script to extract data from Redis buckets
+ -- KEYS[1] - key to extract
+ -- Value returned - table of scores as a strings vector + number of scores
+ local redis_get_script_tpl = [[
+ local cnt = redis.call('HGET', KEYS[1], 'n')
+ local results = {}
+ if cnt then
+ {% for w in windows %}
+ local sc = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}'))
+ table.insert(results, tostring(sc * {= w.mult =}))
+ {% endfor %}
+ else
+ {% for w in windows %}
+ table.insert(results, '0')
+ {% endfor %}
+ end
+
+ return {cnt or 0, results}
+ ]]
+
+ local get_script = lua_util.jinja_template(redis_get_script_tpl,
+ { windows = rule.backend.config.buckets })
+ rspamd_logger.debugm(N, rspamd_config, 'added extraction script %s', get_script)
+ rule.backend.script_get = lua_redis.add_redis_script(get_script, our_redis_params)
+
+ -- Redis script to update Redis buckets
+ -- KEYS[1] - key to update
+ -- KEYS[2] - current time in milliseconds
+ -- KEYS[3] - message score
+ -- KEYS[4] - expire for a bucket
+ -- Value returned - table of scores as a strings vector
+ local redis_adaptive_emea_script_tpl = [[
+ local last = redis.call('HGET', KEYS[1], 'l')
+ local score = tonumber(KEYS[3])
+ local now = tonumber(KEYS[2])
+ local scores = {}
+
+ if last then
+ {% for w in windows %}
+ local last_value = tonumber(redis.call('HGET', KEYS[1], 'v' .. '{= w.name =}'))
+ local window = {= w.time =}
+ -- Adjust alpha
+ local time_diff = now - last
+ if time_diff < 0 then
+ time_diff = 0
+ end
+ local alpha = 1.0 - math.exp((-time_diff) / (1000 * window))
+ local nscore = alpha * score + (1.0 - alpha) * last_value
+ table.insert(scores, tostring(nscore * {= w.mult =}))
+ {% endfor %}
+ else
+ {% for w in windows %}
+ table.insert(scores, tostring(score * {= w.mult =}))
+ {% endfor %}
+ end
+
+ local i = 1
+ {% for w in windows %}
+ redis.call('HSET', KEYS[1], 'v' .. '{= w.name =}', scores[i])
+ i = i + 1
+ {% endfor %}
+ redis.call('HSET', KEYS[1], 'l', now)
+ redis.call('HINCRBY', KEYS[1], 'n', 1)
+ redis.call('EXPIRE', KEYS[1], tonumber(KEYS[4]))
+
+ return scores
+]]
+
+ local set_script = lua_util.jinja_template(redis_adaptive_emea_script_tpl,
+ { windows = rule.backend.config.buckets })
+ rspamd_logger.debugm(N, rspamd_config, 'added emea update script %s', set_script)
+ rule.backend.script_set = lua_redis.add_redis_script(set_script, our_redis_params)
+
+ return true
+end
+
+local function reputation_redis_get_token(task, rule, prefix, token, continuation_cb, token_type)
+ if token_type and token_type == 'ip' then
+ token = tostring(token)
+ end
+ local key = gen_token_key(prefix, token, rule)
+
+ local function redis_get_cb(err, data)
+ if data then
+ if type(data) == 'table' then
+ lua_util.debugm(N, task, 'rule %s - got values for key %s -> %s',
+ rule['symbol'], key, data)
+ continuation_cb(nil, key, data)
+ else
+ rspamd_logger.errx(task, 'rule %s - invalid type while getting reputation keys %s: %s',
+ rule['symbol'], key, type(data))
+ continuation_cb("invalid type", key, nil)
+ end
+
+ elseif err then
+ rspamd_logger.errx(task, 'rule %s - got error while getting reputation keys %s: %s',
+ rule['symbol'], key, err)
+ continuation_cb(err, key, nil)
+ else
+ rspamd_logger.errx(task, 'rule %s - got error while getting reputation keys %s: %s',
+ rule['symbol'], key, "unknown error")
+ continuation_cb("unknown error", key, nil)
+ end
+ end
+
+ local ret = lua_redis.exec_redis_script(rule.backend.script_get,
+ { task = task, is_write = false },
+ redis_get_cb,
+ { key })
+ if not ret then
+ rspamd_logger.errx(task, 'cannot make redis request to check results')
+ end
+end
+
+local function reputation_redis_set_token(task, rule, prefix, token, sc, continuation_cb, token_type)
+ if token_type and token_type == 'ip' then
+ token = tostring(token)
+ end
+ local key = gen_token_key(prefix, token, rule)
+
+ local function redis_set_cb(err, data)
+ if err then
+ rspamd_logger.errx(task, 'rule %s - got error while setting reputation keys %s: %s',
+ rule['symbol'], key, err)
+ if continuation_cb then
+ continuation_cb(err, key)
+ end
+ else
+ if continuation_cb then
+ continuation_cb(nil, key)
+ end
+ end
+ end
+
+ lua_util.debugm(N, task, 'rule %s - set values for key %s -> %s',
+ rule['symbol'], key, sc)
+ local ret = lua_redis.exec_redis_script(rule.backend.script_set,
+ { task = task, is_write = true },
+ redis_set_cb,
+ { key, tostring(os.time() * 1000),
+ tostring(sc),
+ tostring(rule.backend.config.expiry) })
+ if not ret then
+ rspamd_logger.errx(task, 'got error while connecting to redis')
+ end
+end
+
+--[[ Backends are responsible for getting reputation tokens
+ -- Common config options:
+ -- `hashed`: if `true` then apply hash function to the key
+ -- `hash_alg`: use specific hash type (`blake2` by default)
+ -- `hash_len`: strip hash to this amount of bytes (no strip by default)
+ -- `hash_encoding`: use specific hash encoding (base32 by default)
+--]]
+local backends = {
+ redis = {
+ schema = lua_redis.enrich_schema({
+ prefix = ts.string:is_optional(),
+ expiry = (ts.number + ts.string / lua_util.parse_time_interval):is_optional(),
+ buckets = ts.array_of(ts.shape {
+ time = ts.number + ts.string / lua_util.parse_time_interval,
+ name = ts.string,
+ mult = ts.number + ts.string / tonumber
+ }) :is_optional(),
+ }),
+ config = {
+ expiry = default_expiry,
+ prefix = default_prefix,
+ buckets = {
+ {
+ time = 60 * 60 * 24 * 30,
+ name = '1m',
+ mult = 1.0,
+ }
+ }, -- What buckets should be used, default 1h and 1month
+ },
+ init = reputation_redis_init,
+ get_token = reputation_redis_get_token,
+ set_token = reputation_redis_set_token,
+ },
+ dns = {
+ schema = ts.shape {
+ list = ts.string,
+ },
+ config = {
+ -- list = rep.example.com
+ },
+ get_token = reputation_dns_get_token,
+ -- No set token for DNS
+ init = reputation_dns_init,
+ }
+}
+
+local function is_rule_applicable(task, rule)
+ local ip = task:get_from_ip()
+ if not (rule.selector.config.outbound and rule.selector.config.inbound) then
+ if rule.selector.config.outbound then
+ if not (task:get_user() or (ip and ip:is_local())) then
+ return false
+ end
+ elseif rule.selector.config.inbound then
+ if task:get_user() or (ip and ip:is_local()) then
+ return false
+ end
+ end
+ end
+
+ if rule.config.whitelist_map then
+ if rule.config.whitelist_map:process(task) then
+ return false
+ end
+ end
+
+ return true
+end
+
+local function reputation_filter_cb(task, rule)
+ if (is_rule_applicable(task, rule)) then
+ rule.selector.filter(task, rule, rule.backend)
+ end
+end
+
+local function reputation_postfilter_cb(task, rule)
+ if (is_rule_applicable(task, rule)) then
+ rule.selector.postfilter(task, rule, rule.backend)
+ end
+end
+
+local function reputation_idempotent_cb(task, rule)
+ if (is_rule_applicable(task, rule)) then
+ rule.selector.idempotent(task, rule, rule.backend)
+ end
+end
+
+local function callback_gen(cb, rule)
+ return function(task)
+ if rule.enabled then
+ cb(task, rule)
+ end
+ end
+end
+
+local function parse_rule(name, tbl)
+ local sel_type, sel_conf = fun.head(tbl.selector)
+ local selector = selectors[sel_type]
+
+ if not selector then
+ rspamd_logger.errx(rspamd_config, "unknown selector defined for rule %s: %s", name,
+ sel_type)
+ return false
+ end
+
+ local bk_type, bk_conf = fun.head(tbl.backend)
+
+ local backend = backends[bk_type]
+ if not backend then
+ rspamd_logger.errx(rspamd_config, "unknown backend defined for rule %s: %s", name,
+ tbl.backend.type)
+ return false
+ end
+ -- Allow config override
+ local rule = {
+ selector = lua_util.shallowcopy(selector),
+ backend = lua_util.shallowcopy(backend),
+ config = {}
+ }
+
+ -- Override default config params
+ rule.backend.config = lua_util.override_defaults(rule.backend.config, bk_conf)
+ if backend.schema then
+ local checked, schema_err = backend.schema:transform(rule.backend.config)
+ if not checked then
+ rspamd_logger.errx(rspamd_config, "cannot parse backend config for %s: %s",
+ sel_type, schema_err)
+
+ return false
+ end
+
+ rule.backend.config = checked
+ end
+
+ rule.selector.config = lua_util.override_defaults(rule.selector.config, sel_conf)
+ if selector.schema then
+ local checked, schema_err = selector.schema:transform(rule.selector.config)
+
+ if not checked then
+ rspamd_logger.errx(rspamd_config, "cannot parse selector config for %s: %s (%s)",
+ sel_type,
+ schema_err, sel_conf)
+ return
+ end
+
+ rule.selector.config = checked
+ end
+ -- Generic options
+ tbl.selector = nil
+ tbl.backend = nil
+ rule.config = lua_util.override_defaults(rule.config, tbl)
+
+ if rule.config.whitelist then
+ if lua_maps_exprs.schema(rule.config.whitelist) then
+ rule.config.whitelist_map = lua_maps_exprs.create(rspamd_config,
+ rule.config.whitelist, N)
+ elseif lua_maps.map_schema(rule.config.whitelist) then
+ local map = lua_maps.map_add_from_ucl(rule.config.whitelist,
+ 'radix',
+ sel_type .. ' reputation whitelist')
+
+ if not map then
+ rspamd_logger.errx(rspamd_config, "cannot parse whitelist map config for %s: (%s)",
+ sel_type,
+ rule.config.whitelist)
+ return
+ end
+
+ rule.config.whitelist_map = {
+ process = function(_, task)
+ -- Hack: we assume that it is an ip whitelist :(
+ local ip = task:get_from_ip()
+
+ if ip and map:get_key(ip) then
+ return true
+ end
+ return false
+ end
+ }
+ else
+ rspamd_logger.errx(rspamd_config, "cannot parse whitelist map config for %s: (%s)",
+ sel_type,
+ rule.config.whitelist)
+ return false
+ end
+ end
+
+ local symbol = rule.selector.config.symbol or name
+ if tbl.symbol then
+ symbol = tbl.symbol
+ end
+
+ rule.symbol = symbol
+ rule.enabled = true
+ if rule.selector.init then
+ rule.enabled = false
+ end
+ if rule.backend.init then
+ rule.enabled = false
+ end
+ -- Perform additional initialization if needed
+ rspamd_config:add_on_load(function(cfg, ev_base, worker)
+ if rule.selector.init then
+ if not rule.selector.init(rule, cfg, ev_base, worker) then
+ rule.enabled = false
+ rspamd_logger.errx(rspamd_config, 'Cannot init selector %s (backend %s) for symbol %s',
+ sel_type, bk_type, rule.symbol)
+ else
+ rule.enabled = true
+ end
+ end
+ if rule.backend.init then
+ if not rule.backend.init(rule, cfg, ev_base, worker) then
+ rule.enabled = false
+ rspamd_logger.errx(rspamd_config, 'Cannot init backend (%s) for rule %s for symbol %s',
+ bk_type, sel_type, rule.symbol)
+ else
+ rule.enabled = true
+ end
+ end
+
+ if rule.enabled then
+ rspamd_logger.infox(rspamd_config, 'Enable %s (%s backend) rule for symbol %s (split symbols: %s)',
+ sel_type, bk_type, rule.symbol,
+ rule.selector.config.split_symbols)
+ end
+ end)
+
+ -- We now generate symbol for checking
+ local rule_type = 'normal'
+ if rule.selector.config.split_symbols then
+ rule_type = 'callback'
+ end
+
+ local id = rspamd_config:register_symbol {
+ name = rule.symbol,
+ type = rule_type,
+ callback = callback_gen(reputation_filter_cb, rule),
+ augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) },
+ }
+
+ if rule.selector.config.split_symbols then
+ rspamd_config:register_symbol {
+ name = rule.symbol .. '_HAM',
+ type = 'virtual',
+ parent = id,
+ }
+ rspamd_config:register_symbol {
+ name = rule.symbol .. '_SPAM',
+ type = 'virtual',
+ parent = id,
+ }
+ end
+
+ if rule.selector.dependencies then
+ fun.each(function(d)
+ rspamd_config:register_dependency(symbol, d)
+ end, rule.selector.dependencies)
+ end
+
+ if rule.selector.postfilter then
+ -- Also register a postfilter
+ rspamd_config:register_symbol {
+ name = rule.symbol .. '_POST',
+ type = 'postfilter',
+ flags = 'nostat,explicit_disable,ignore_passthrough',
+ callback = callback_gen(reputation_postfilter_cb, rule),
+ augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) },
+ }
+ end
+
+ if rule.selector.idempotent then
+ -- Has also idempotent component (e.g. saving data to the backend)
+ rspamd_config:register_symbol {
+ name = rule.symbol .. '_IDEMPOTENT',
+ type = 'idempotent',
+ flags = 'explicit_disable,ignore_passthrough',
+ callback = callback_gen(reputation_idempotent_cb, rule),
+ augmentations = { string.format("timeout=%f", redis_params.timeout or 0.0) },
+ }
+ end
+
+ return true
+end
+
+redis_params = lua_redis.parse_redis_server('reputation')
+local opts = rspamd_config:get_all_opt("reputation")
+
+-- Initialization part
+if not (opts and type(opts) == 'table') then
+ rspamd_logger.infox(rspamd_config, 'Module is not configured, disabling it')
+ return
+end
+
+if opts['rules'] then
+ for k, v in pairs(opts['rules']) do
+ if not ((v or E).selector) then
+ rspamd_logger.errx(rspamd_config, "no selector defined for rule %s", k)
+ lua_util.config_utils.push_config_error(N, "no selector defined for rule: " .. k)
+ else
+ if not parse_rule(k, v) then
+ lua_util.config_utils.push_config_error(N, "reputation rule is misconfigured: " .. k)
+ end
+ end
+ end
+else
+ lua_util.disable_module(N, "config")
+end