diff options
Diffstat (limited to 'lualib/plugins')
-rw-r--r-- | lualib/plugins/dmarc.lua | 359 | ||||
-rw-r--r-- | lualib/plugins/neural.lua | 892 | ||||
-rw-r--r-- | lualib/plugins/rbl.lua | 232 |
3 files changed, 1483 insertions, 0 deletions
diff --git a/lualib/plugins/dmarc.lua b/lualib/plugins/dmarc.lua new file mode 100644 index 0000000..7791f4e --- /dev/null +++ b/lualib/plugins/dmarc.lua @@ -0,0 +1,359 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> +Copyright (c) 2015-2016, Andrew Lewis <nerf@judo.za.org> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +-- Common dmarc stuff +local rspamd_logger = require "rspamd_logger" +local lua_util = require "lua_util" +local N = "dmarc" + +local exports = {} + +exports.default_settings = { + auth_and_local_conf = false, + symbols = { + spf_allow_symbol = 'R_SPF_ALLOW', + spf_deny_symbol = 'R_SPF_FAIL', + spf_softfail_symbol = 'R_SPF_SOFTFAIL', + spf_neutral_symbol = 'R_SPF_NEUTRAL', + spf_tempfail_symbol = 'R_SPF_DNSFAIL', + spf_permfail_symbol = 'R_SPF_PERMFAIL', + spf_na_symbol = 'R_SPF_NA', + + dkim_allow_symbol = 'R_DKIM_ALLOW', + dkim_deny_symbol = 'R_DKIM_REJECT', + dkim_tempfail_symbol = 'R_DKIM_TEMPFAIL', + dkim_na_symbol = 'R_DKIM_NA', + dkim_permfail_symbol = 'R_DKIM_PERMFAIL', + + -- DMARC symbols + allow = 'DMARC_POLICY_ALLOW', + badpolicy = 'DMARC_BAD_POLICY', + dnsfail = 'DMARC_DNSFAIL', + na = 'DMARC_NA', + reject = 'DMARC_POLICY_REJECT', + softfail = 'DMARC_POLICY_SOFTFAIL', + quarantine = 'DMARC_POLICY_QUARANTINE', + }, + no_sampling_domains = nil, + no_reporting_domains = nil, + reporting = { + report_local_controller = false, -- Store reports for local/controller scans (for testing only) + redis_keys = { + index_prefix = 'dmarc_idx', + report_prefix = 'dmarc_rpt', + join_char = ';', + }, + helo = 'rspamd.localhost', + smtp = '127.0.0.1', + smtp_port = 25, + retries = 2, + from_name = 'Rspamd', + msgid_from = 'rspamd', + enabled = false, + max_entries = 1000, + keys_expire = 172800, + only_domains = nil, + }, + actions = {}, +} + + +-- Returns a key used to be inserted into dmarc report sample +exports.dmarc_report = function(task, settings, data) + local rspamd_lua_utils = require "lua_util" + local E = {} + + local ip = task:get_from_ip() + if not ip or not ip:is_valid() then + rspamd_logger.infox(task, 'cannot store dmarc report for %s: no valid source IP', + data.domain) + return nil + end + + ip = ip:to_string() + + if rspamd_lua_utils.is_rspamc_or_controller(task) and not settings.reporting.report_local_controller then + rspamd_logger.infox(task, 'cannot store dmarc report for %s from IP %s: has come from controller/rspamc', + data.domain, ip) + return + end + + local dkim_pass = table.concat(data.dkim_results.pass or E, '|') + local dkim_fail = table.concat(data.dkim_results.fail or E, '|') + local dkim_temperror = table.concat(data.dkim_results.temperror or E, '|') + local dkim_permerror = table.concat(data.dkim_results.permerror or E, '|') + local disposition_to_return = data.disposition + local res = table.concat({ + ip, data.spf_ok, data.dkim_ok, + disposition_to_return, (data.sampled_out and 'sampled_out' or ''), data.domain, + dkim_pass, dkim_fail, dkim_temperror, dkim_permerror, data.spf_domain, data.spf_result }, ',') + + return res +end + +exports.gen_munging_callback = function(munging_opts, settings) + local rspamd_util = require "rspamd_util" + local lua_mime = require "lua_mime" + return function(task) + if munging_opts.mitigate_allow_only then + if not task:has_symbol(settings.symbols.allow) then + lua_util.debugm(N, task, 'skip munging, no %s symbol', + settings.symbols.allow) + -- Excepted + return + end + else + local has_dmarc = task:has_symbol(settings.symbols.allow) or + task:has_symbol(settings.symbols.quarantine) or + task:has_symbol(settings.symbols.reject) or + task:has_symbol(settings.symbols.softfail) + + if not has_dmarc then + lua_util.debugm(N, task, 'skip munging, no %s symbol', + settings.symbols.allow) + -- Excepted + return + end + end + if munging_opts.mitigate_strict_only then + local s = task:get_symbol(settings.symbols.allow) or { [1] = {} } + local sopts = s[1].options or {} + + local seen_strict + for _, o in ipairs(sopts) do + if o == 'reject' or o == 'quarantine' then + seen_strict = true + break + end + end + + if not seen_strict then + lua_util.debugm(N, task, 'skip munging, no strict policy found in %s', + settings.symbols.allow) + -- Excepted + return + end + end + if munging_opts.munge_map_condition then + local accepted, trace = munging_opts.munge_map_condition:process(task) + if not accepted then + lua_util.debugm(N, task, 'skip munging, maps condition not satisfied: (%s)', + trace) + -- Excepted + return + end + end + -- Now, look for domain for munging + local mr = task:get_recipients({ 'mime', 'orig' }) + local rcpt_found + if mr then + for _, r in ipairs(mr) do + if r.domain and munging_opts.list_map:get_key(r.addr) then + rcpt_found = r + break + end + end + end + + if not rcpt_found then + lua_util.debugm(N, task, 'skip munging, recipients are not in list_map') + -- Excepted + return + end + + local from = task:get_from({ 'mime', 'orig' }) + + if not from or not from[1] then + lua_util.debugm(N, task, 'skip munging, from is bad') + -- Excepted + return + end + + from = from[1] + local via_user = rcpt_found.user + local via_addr = rcpt_found.addr + local via_name + + if from.name == "" then + via_name = string.format('%s via %s', from.user or 'unknown', via_user) + else + via_name = string.format('%s via %s', from.name, via_user) + end + + local hdr_encoded = rspamd_util.fold_header('From', + rspamd_util.mime_header_encode(string.format('%s <%s>', + via_name, via_addr)), task:get_newlines_type()) + local orig_from_encoded = rspamd_util.fold_header('X-Original-From', + rspamd_util.mime_header_encode(string.format('%s <%s>', + from.name or '', from.addr)), task:get_newlines_type()) + local add_hdrs = { + ['From'] = { order = 1, value = hdr_encoded }, + } + local remove_hdrs = { ['From'] = 0 } + + local nreply = from.addr + if munging_opts.reply_goes_to_list then + -- Reply-to goes to the list + nreply = via_addr + end + + if task:has_header('Reply-To') then + -- If we have reply-to header, then we need to insert an additional + -- address there + local orig_reply = task:get_header_full('Reply-To')[1] + if orig_reply.value then + nreply = string.format('%s, %s', orig_reply.value, nreply) + end + remove_hdrs['Reply-To'] = 1 + end + + add_hdrs['Reply-To'] = { order = 0, value = nreply } + + add_hdrs['X-Original-From'] = { order = 0, value = orig_from_encoded } + lua_mime.modify_headers(task, { + remove = remove_hdrs, + add = add_hdrs + }) + lua_util.debugm(N, task, 'munged DMARC header for %s: %s -> %s', + from.domain, hdr_encoded, from.addr) + rspamd_logger.infox(task, 'munged DMARC header for %s', from.addr) + task:insert_result('DMARC_MUNGED', 1.0, from.addr) + end +end + +local function gen_dmarc_grammar() + local lpeg = require "lpeg" + lpeg.locale(lpeg) + local space = lpeg.space ^ 0 + local name = lpeg.C(lpeg.alpha ^ 1) * space + local sep = space * (lpeg.S("\\;") * space) + (lpeg.space ^ 1) + local value = lpeg.C(lpeg.P(lpeg.graph - sep) ^ 1) + local pair = lpeg.Cg(name * "=" * space * value) * sep ^ -1 + local list = lpeg.Cf(lpeg.Ct("") * pair ^ 0, rawset) + local version = lpeg.P("v") * space * lpeg.P("=") * space * lpeg.P("DMARC1") + local record = version * sep * list + + return record +end + +local dmarc_grammar = gen_dmarc_grammar() + +local function dmarc_key_value_case(elts) + if type(elts) ~= "table" then + return elts + end + local result = {} + for k, v in pairs(elts) do + k = k:lower() + if k ~= "v" then + v = v:lower() + end + + result[k] = v + end + + return result +end + +--[[ +-- Used to check dmarc record, check elements and produce dmarc policy processed +-- result. +-- Returns: +-- false,false - record is garbage +-- false,error_message - record is invalid +-- true,policy_table - record is valid and parsed +]] +local function dmarc_check_record(log_obj, record, is_tld) + local failed_policy + local result = { + dmarc_policy = 'none' + } + + local elts = dmarc_grammar:match(record) + lua_util.debugm(N, log_obj, "got DMARC record: %s, tld_flag=%s, processed=%s", + record, is_tld, elts) + + if elts then + elts = dmarc_key_value_case(elts) + + local dkim_pol = elts['adkim'] + if dkim_pol then + if dkim_pol == 's' then + result.strict_dkim = true + elseif dkim_pol ~= 'r' then + failed_policy = 'adkim tag has invalid value: ' .. dkim_pol + return false, failed_policy + end + end + + local spf_pol = elts['aspf'] + if spf_pol then + if spf_pol == 's' then + result.strict_spf = true + elseif spf_pol ~= 'r' then + failed_policy = 'aspf tag has invalid value: ' .. spf_pol + return false, failed_policy + end + end + + local policy = elts['p'] + if policy then + if (policy == 'reject') then + result.dmarc_policy = 'reject' + elseif (policy == 'quarantine') then + result.dmarc_policy = 'quarantine' + elseif (policy ~= 'none') then + failed_policy = 'p tag has invalid value: ' .. policy + return false, failed_policy + end + end + + -- Adjust policy if we are in tld mode + local subdomain_policy = elts['sp'] + if elts['sp'] and is_tld then + result.subdomain_policy = elts['sp'] + + if (subdomain_policy == 'reject') then + result.dmarc_policy = 'reject' + elseif (subdomain_policy == 'quarantine') then + result.dmarc_policy = 'quarantine' + elseif (subdomain_policy == 'none') then + result.dmarc_policy = 'none' + elseif (subdomain_policy ~= 'none') then + failed_policy = 'sp tag has invalid value: ' .. subdomain_policy + return false, failed_policy + end + end + result.pct = elts['pct'] + if result.pct then + result.pct = tonumber(result.pct) + end + + if elts.rua then + result.rua = elts['rua'] + end + result.raw_elts = elts + else + return false, false -- Ignore garbage + end + + return true, result +end + +exports.dmarc_check_record = dmarc_check_record + +return exports diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua new file mode 100644 index 0000000..6e88ef2 --- /dev/null +++ b/lualib/plugins/neural.lua @@ -0,0 +1,892 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local fun = require "fun" +local lua_redis = require "lua_redis" +local lua_settings = require "lua_settings" +local lua_util = require "lua_util" +local meta_functions = require "lua_meta" +local rspamd_kann = require "rspamd_kann" +local rspamd_logger = require "rspamd_logger" +local rspamd_tensor = require "rspamd_tensor" +local rspamd_util = require "rspamd_util" +local ucl = require "ucl" + +local N = 'neural' + +-- Used in prefix to avoid wrong ANN to be loaded +local plugin_ver = '2' + +-- Module vars +local default_options = { + train = { + max_trains = 1000, + max_epoch = 1000, + max_usages = 10, + max_iterations = 25, -- Torch style + mse = 0.001, + autotrain = true, + train_prob = 1.0, + learn_threads = 1, + learn_mode = 'balanced', -- Possible values: balanced, proportional + learning_rate = 0.01, + classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias) + spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1) + ham_skip_prob = 0.0, -- proportional mode: ham skip probability + store_pool_only = false, -- store tokens in cache only (disables autotrain); + -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest + }, + watch_interval = 60.0, + lock_expire = 600, + learning_spawned = false, + ann_expire = 60 * 60 * 24 * 2, -- 2 days + hidden_layer_mult = 1.5, -- number of neurons in the hidden layer + roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds. + roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1). + spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable) + ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable) + flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached + symbol_spam = 'NEURAL_SPAM', + symbol_ham = 'NEURAL_HAM', + max_inputs = nil, -- when PCA is used + blacklisted_symbols = {}, -- list of symbols skipped in neural processing +} + +-- Rule structure: +-- * static config fields (see `default_options`) +-- * prefix - name or defined prefix +-- * settings - table of settings indexed by settings id, -1 is used when no settings defined + +-- Rule settings element defines elements for specific settings id: +-- * symbols - static symbols profile (defined by config or extracted from symcache) +-- * name - name of settings id +-- * digest - digest of all symbols +-- * ann - dynamic ANN configuration loaded from Redis +-- * train - train data for ANN (e.g. the currently trained ANN) + +-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN +-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically +-- * version - version of ANN loaded from redis +-- * redis_key - name of ANN key in Redis +-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols) +-- * distance - distance between set.symbols and set.ann.symbols +-- * ann - kann object + +local settings = { + rules = {}, + prefix = 'rn', -- Neural network default prefix + max_profiles = 3, -- Maximum number of NN profiles stored +} + +-- Get module & Redis configuration +local module_config = rspamd_config:get_all_opt(N) +settings = lua_util.override_defaults(settings, module_config) +local redis_params = lua_redis.parse_redis_server('neural') + +local redis_lua_script_vectors_len = "neural_train_size.lua" +local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua" +local redis_lua_script_maybe_lock = "neural_maybe_lock.lua" +local redis_lua_script_save_unlock = "neural_save_unlock.lua" + +local redis_script_id = {} + +local function load_scripts() + redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len, + redis_params) + redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate, + redis_params) + redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock, + redis_params) + redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock, + redis_params) +end + +local function create_ann(n, nlayers, rule) + -- We ignore number of layers so far when using kann + local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0) + local t = rspamd_kann.layer.input(n) + t = rspamd_kann.transform.relu(t) + t = rspamd_kann.layer.dense(t, nhidden); + t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg) + return rspamd_kann.new.kann(t) +end + +-- Fills ANN data for a specific settings element +local function fill_set_ann(set, ann_key) + if not set.ann then + set.ann = { + symbols = set.symbols, + distance = 0, + digest = set.digest, + redis_key = ann_key, + version = 0, + } + end +end + +-- This function takes all inputs, applies PCA transformation and returns the final +-- PCA matrix as rspamd_tensor +local function learn_pca(inputs, max_inputs) + local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs)) + local eigenvals = scatter_matrix:eigen() + -- scatter matrix is not filled with eigenvectors + lua_util.debugm(N, 'eigenvalues: %s', eigenvals) + local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1]) + for i = 1, max_inputs do + w[i] = scatter_matrix[#scatter_matrix - i + 1] + end + + lua_util.debugm(N, 'pca matrix: %s', w) + + return w +end + +-- This function computes optimal threshold using ROC for the given set of inputs. +-- Returns a threshold that minimizes: +-- alpha * (false_positive_rate) + beta * (false_negative_rate) +-- Where alpha is cost of false positive result +-- beta is cost of false negative result +local function get_roc_thresholds(ann, inputs, outputs, alpha, beta) + + -- Sorts list x and list y based on the values in list x. + local sort_relative = function(x, y) + + local r = {} + + assert(#x == #y) + local n = #x + + local a = {} + local b = {} + for i = 1, n do + r[i] = i + end + + local cmp = function(p, q) + return p < q + end + + table.sort(r, function(p, q) + return cmp(x[p], x[q]) + end) + + for i = 1, n do + a[i] = x[r[i]] + b[i] = y[r[i]] + end + + return a, b + end + + local function get_scores(nn, input_vectors) + local scores = {} + for i = 1, #inputs do + local score = nn:apply1(input_vectors[i], nn.pca)[1] + scores[#scores + 1] = score + end + + return scores + end + + local fpr = {} + local fnr = {} + local scores = get_scores(ann, inputs) + + scores, outputs = sort_relative(scores, outputs) + + local n_samples = #outputs + local n_spam = 0 + local n_ham = 0 + local ham_count_ahead = {} + local spam_count_ahead = {} + local ham_count_behind = {} + local spam_count_behind = {} + + ham_count_ahead[n_samples + 1] = 0 + spam_count_ahead[n_samples + 1] = 0 + + for i = n_samples, 1, -1 do + + if outputs[i][1] == 0 then + n_ham = n_ham + 1 + ham_count_ahead[i] = 1 + spam_count_ahead[i] = 0 + else + n_spam = n_spam + 1 + ham_count_ahead[i] = 0 + spam_count_ahead[i] = 1 + end + + ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1] + spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1] + end + + for i = 1, n_samples do + if outputs[i][1] == 0 then + ham_count_behind[i] = 1 + spam_count_behind[i] = 0 + else + ham_count_behind[i] = 0 + spam_count_behind[i] = 1 + end + + if i ~= 1 then + ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1] + spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1] + end + end + + for i = 1, n_samples do + fpr[i] = 0 + fnr[i] = 0 + + if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then + fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i]) + end + + if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then + fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1]) + end + end + + local p = n_spam / (n_spam + n_ham) + + local cost = {} + local min_cost_idx = 0 + local min_cost = math.huge + for i = 1, n_samples do + cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i]) + if min_cost >= cost[i] then + min_cost = cost[i] + min_cost_idx = i + end + end + + return scores[min_cost_idx] +end + +-- This function is intended to extend lock for ANN during training +-- It registers periodic that increases locked key each 30 seconds unless +-- `set.learning_spawned` is set to `true` +local function register_lock_extender(rule, set, ev_base, ann_key) + rspamd_config:add_periodic(ev_base, 30.0, + function() + local function redis_lock_extend_cb(err, _) + if err then + rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s', + ann_key, err) + else + rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds', + ann_key) + end + end + + if set.learning_spawned then + lua_redis.redis_make_request_taskless(ev_base, + rspamd_config, + rule.redis, + nil, + true, -- is write + redis_lock_extend_cb, --callback + 'HINCRBY', -- command + { ann_key, 'lock', '30' } + ) + else + lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false") + return false -- do not plan any more updates + end + + return true + end + ) +end + +local function can_push_train_vector(rule, task, learn_type, nspam, nham) + local train_opts = rule.train + local coin = math.random() + + if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then + rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin) + return false + end + + if train_opts.learn_mode == 'balanced' then + -- Keep balanced training set based on number of spam and ham samples + if learn_type == 'spam' then + if nspam <= train_opts.max_trains then + if nspam > nham then + -- Apply sampling + local skip_rate = 1.0 - nham / (nspam + 1) + if coin < skip_rate - train_opts.classes_bias then + rspamd_logger.infox(task, + 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', + learn_type, + skip_rate - train_opts.classes_bias, + nspam, nham) + return false + end + end + return true + else + -- Enough learns + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s', + learn_type, + nspam) + end + else + if nham <= train_opts.max_trains then + if nham > nspam then + -- Apply sampling + local skip_rate = 1.0 - nspam / (nham + 1) + if coin < skip_rate - train_opts.classes_bias then + rspamd_logger.infox(task, + 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored', + learn_type, + skip_rate - train_opts.classes_bias, + nspam, nham) + return false + end + end + return true + else + rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type, + nham) + end + end + else + -- Probabilistic learn mode, we just skip learn if we already have enough samples or + -- if our coin drop is less than desired probability + if learn_type == 'spam' then + if nspam <= train_opts.max_trains then + if train_opts.spam_skip_prob then + if coin <= train_opts.spam_skip_prob then + rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type, + coin, train_opts.spam_skip_prob) + return false + end + + return true + end + else + rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type, + nspam, train_opts.max_trains) + end + else + if nham <= train_opts.max_trains then + if train_opts.ham_skip_prob then + if coin <= train_opts.ham_skip_prob then + rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type, + coin, train_opts.ham_skip_prob) + return false + end + + return true + end + else + rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type, + nham, train_opts.max_trains) + end + end + end + + return false +end + +-- Closure generator for unlock function +local function gen_unlock_cb(rule, set, ann_key) + return function(err) + if err then + rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s', + rule.prefix, set.name, ann_key, err) + else + lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s', + rule.prefix, set.name, ann_key) + end + end +end + +-- Used to generate new ANN key for specific profile +local function new_ann_key(rule, set, version) + local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix, + rule.prefix, set.name, set.digest:sub(1, 8), tostring(version)) + + return ann_key +end + +local function redis_ann_prefix(rule, settings_name) + -- We also need to count metatokens: + local n = meta_functions.version + return string.format('%s%d_%s_%d_%s', + settings.prefix, plugin_ver, rule.prefix, n, settings_name) +end + +-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis +local function spawn_train(params) + -- Check training data sanity + -- Now we need to join inputs and create the appropriate test vectors + local n = #params.set.symbols + + meta_functions.rspamd_count_metatokens() + + -- Now we can train ann + local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule) + + if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then + -- Invalidate ANN as it is definitely invalid + -- TODO: add invalidation + assert(false) + else + local inputs, outputs = {}, {} + + -- Used to show parsed vectors in a convenient format (for debugging only) + local function debug_vec(t) + local ret = {} + for i, v in ipairs(t) do + if v ~= 0 then + ret[#ret + 1] = string.format('%d=%.2f', i, v) + end + end + + return ret + end + + -- Make training set by joining vectors + -- KANN automatically shuffles those samples + -- 1.0 is used for spam and -1.0 is used for ham + -- It implies that output layer can express that (e.g. tanh output) + for _, e in ipairs(params.spam_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = { 1.0 } + --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e)) + end + for _, e in ipairs(params.ham_vec) do + inputs[#inputs + 1] = e + outputs[#outputs + 1] = { -1.0 } + --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e)) + end + + -- Called in child process + local function train() + local log_thresh = params.rule.train.max_iterations / 10 + local seen_nan = false + + local function train_cb(iter, train_cost, value_cost) + if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then + if train_cost ~= train_cost and not seen_nan then + -- We have nan :( try to log lot's of stuff to dig into a problem + seen_nan = true + rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s', + params.rule.prefix, params.set.name, + value_cost) + for i, e in ipairs(inputs) do + lua_util.debugm(N, rspamd_config, 'train vector %s -> %s', + debug_vec(e), outputs[i][1]) + end + end + + rspamd_logger.infox(rspamd_config, + "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s", + params.rule.prefix, params.set.name, + params.ann_key, + iter, + train_cost, + value_cost) + end + end + + lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started", + params.rule.prefix, params.set.name) + + local pca + if params.rule.max_inputs then + -- Train PCA in the main process, presumably it is not that long + lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s", + params.rule.prefix, params.set.name) + pca = learn_pca(inputs, params.rule.max_inputs) + end + + lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s", + params.rule.prefix, params.set.name) + local ret, err = pcall(train_ann.train1, train_ann, + inputs, outputs, { + lr = params.rule.train.learning_rate, + max_epoch = params.rule.train.max_iterations, + cb = train_cb, + pca = pca + }) + + if not ret then + rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s", + params.rule.prefix, params.set.name, err) + + return nil + else + lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s", + params.rule.prefix, params.set.name) + end + + local roc_thresholds = {} + if params.rule.roc_enabled then + local spam_threshold = get_roc_thresholds(train_ann, + inputs, + outputs, + 1 - params.rule.roc_misclassification_cost, + params.rule.roc_misclassification_cost) + local ham_threshold = get_roc_thresholds(train_ann, + inputs, + outputs, + params.rule.roc_misclassification_cost, + 1 - params.rule.roc_misclassification_cost) + roc_thresholds = { spam_threshold, ham_threshold } + + rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)", + roc_thresholds[1], roc_thresholds[2]) + end + + if not seen_nan then + -- Convert to strings as ucl cannot rspamd_text properly + local pca_data + if pca then + pca_data = tostring(pca:save()) + end + local out = { + ann_data = tostring(train_ann:save()), + pca_data = pca_data, + roc_thresholds = roc_thresholds, + } + + local final_data = ucl.to_format(out, 'msgpack') + lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes", + params.rule.prefix, params.set.name, #final_data) + return final_data + else + return nil + end + end + + params.set.learning_spawned = true + + local function redis_save_cb(err) + if err then + rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s', + params.rule.prefix, params.set.name, params.ann_key, err) + lua_redis.redis_make_request_taskless(params.ev_base, + rspamd_config, + params.rule.redis, + nil, + false, -- is write + gen_unlock_cb(params.rule, params.set, params.ann_key), --callback + 'HDEL', -- command + { params.ann_key, 'lock' } + ) + else + rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s', + params.rule.prefix, params.set.name, params.set.ann.redis_key) + end + end + + local function ann_trained(err, data) + params.set.learning_spawned = false + if err then + rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s', + params.rule.prefix, params.set.name, err) + lua_redis.redis_make_request_taskless(params.ev_base, + rspamd_config, + params.rule.redis, + nil, + true, -- is write + gen_unlock_cb(params.rule, params.set, params.ann_key), --callback + 'HDEL', -- command + { params.ann_key, 'lock' } + ) + else + local parser = ucl.parser() + local ok, parse_err = parser:parse_text(data, 'msgpack') + assert(ok, parse_err) + local parsed = parser:get_object() + local ann_data = rspamd_util.zstd_compress(parsed.ann_data) + local pca_data = parsed.pca_data + local roc_thresholds = parsed.roc_thresholds + + fill_set_ann(params.set, params.ann_key) + if pca_data then + params.set.ann.pca = rspamd_tensor.load(pca_data) + pca_data = rspamd_util.zstd_compress(pca_data) + end + + if roc_thresholds then + params.set.ann.roc_thresholds = roc_thresholds + end + + + -- Deserialise ANN from the child process + ann_trained = rspamd_kann.load(parsed.ann_data) + local version = (params.set.ann.version or 0) + 1 + params.set.ann.version = version + params.set.ann.ann = ann_trained + params.set.ann.symbols = params.set.symbols + params.set.ann.redis_key = new_ann_key(params.rule, params.set, version) + + local profile = { + symbols = params.set.symbols, + digest = params.set.digest, + redis_key = params.set.ann.redis_key, + version = version + } + + local profile_serialized = ucl.to_format(profile, 'json-compact', true) + local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true) + + rspamd_logger.infox(rspamd_config, + 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)', + params.rule.prefix, params.set.name, + #data, #ann_data, + #(params.set.ann.pca or {}), #(pca_data or {}), + params.set.ann.redis_key, params.ann_key) + + lua_redis.exec_redis_script(redis_script_id.save_unlock, + { ev_base = params.ev_base, is_write = true }, + redis_save_cb, + { profile.redis_key, + redis_ann_prefix(params.rule, params.set.name), + ann_data, + profile_serialized, + tostring(params.rule.ann_expire), + tostring(os.time()), + params.ann_key, -- old key to unlock... + roc_thresholds_serialized, + pca_data, + }) + end + end + + if params.rule.max_inputs then + fill_set_ann(params.set, params.ann_key) + end + + params.worker:spawn_process { + func = train, + on_complete = ann_trained, + proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name), + } + -- Spawn learn and register lock extension + params.set.learning_spawned = true + register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key) + return + + end +end + +-- This function is used to adjust profiles and allowed setting ids for each rule +-- It must be called when all settings are already registered (e.g. at post-init for config) +local function process_rules_settings() + local function process_settings_elt(rule, selt) + local profile = rule.profile[selt.name] + if profile then + -- Use static user defined profile + -- Ensure that we have an array... + lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s", + rule.prefix, selt.name, profile) + if not profile[1] then + profile = lua_util.keys(profile) + end + selt.symbols = profile + else + lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)", + rule.prefix, selt.name) + end + + local function filter_symbols_predicate(sname) + if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then + return false + end + local fl = rspamd_config:get_symbol_flags(sname) + if fl then + fl = lua_util.list_to_hash(fl) + + return not (fl.nostat or fl.idempotent or fl.skip or fl.composite) + end + + return false + end + + -- Generic stuff + if not profile then + -- Do filtering merely if we are using a dynamic profile + selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols)) + end + + table.sort(selt.symbols) + + selt.digest = lua_util.table_digest(selt.symbols) + selt.prefix = redis_ann_prefix(rule, selt.name) + + rspamd_logger.messagex(rspamd_config, + 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"', + selt.prefix, selt.name, selt.digest) + + lua_redis.register_prefix(selt.prefix, N, + string.format('NN prefix for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'zlist', + }) + -- Versions + lua_redis.register_prefix(selt.prefix .. '_\\d+', N, + string.format('NN storage for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'hash', + }) + lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N, + string.format('NN learning set (spam) for rule "%s"; settings id "%s"', + selt.prefix, selt.name), { + persistent = true, + type = 'set', + }) + lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N, + string.format('NN learning set (spam) for rule "%s"; settings id "%s"', + rule.prefix, selt.name), { + persistent = true, + type = 'set', + }) + end + + for k, rule in pairs(settings.rules) do + if not rule.allowed_settings then + rule.allowed_settings = {} + elseif rule.allowed_settings == 'all' then + -- Extract all settings ids + rule.allowed_settings = lua_util.keys(lua_settings.all_settings()) + end + + -- Convert to a map <setting_id> -> true + rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings) + + -- Check if we can work without settings + if k == 'default' or type(rule.default) ~= 'boolean' then + rule.default = true + end + + rule.settings = {} + + if rule.default then + local default_settings = { + symbols = lua_settings.default_symbols(), + name = 'default' + } + + process_settings_elt(rule, default_settings) + rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32 + end + + -- Now, for each allowed settings, we store sorted symbols + digest + -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest } + for s, _ in pairs(rule.allowed_settings) do + -- Here, we have a name, set of symbols and + local settings_id = s + if type(settings_id) ~= 'number' then + settings_id = lua_settings.numeric_settings_id(s) + end + local selt = lua_settings.settings_by_id(settings_id) + + local nelt = { + symbols = selt.symbols, -- Already sorted + name = selt.name + } + + process_settings_elt(rule, nelt) + for id, ex in pairs(rule.settings) do + if type(ex) == 'table' then + if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then + -- Equal symbols, add reference + lua_util.debugm(N, rspamd_config, + 'added reference from settings id %s to %s; same symbols', + nelt.name, ex.name) + rule.settings[settings_id] = id + nelt = nil + end + end + end + + if nelt then + rule.settings[settings_id] = nelt + lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s', + nelt.name, settings_id, rule.prefix) + end + end + end +end + +-- Extract settings element for a specific settings id +local function get_rule_settings(task, rule) + local sid = task:get_settings_id() or -1 + local set = rule.settings[sid] + + if not set then + return nil + end + + while type(set) == 'number' do + -- Reference to another settings! + set = rule.settings[set] + end + + return set +end + +local function result_to_vector(task, profile) + if not profile.zeros then + -- Fill zeros vector + local zeros = {} + for i = 1, meta_functions.count_metatokens() do + zeros[i] = 0.0 + end + for _, _ in ipairs(profile.symbols) do + zeros[#zeros + 1] = 0.0 + end + profile.zeros = zeros + end + + local vec = lua_util.shallowcopy(profile.zeros) + local mt = meta_functions.rspamd_gen_metatokens(task) + + for i, v in ipairs(mt) do + vec[i] = v + end + + task:process_ann_tokens(profile.symbols, vec, #mt, 0.1) + + return vec +end + +return { + can_push_train_vector = can_push_train_vector, + create_ann = create_ann, + default_options = default_options, + gen_unlock_cb = gen_unlock_cb, + get_rule_settings = get_rule_settings, + load_scripts = load_scripts, + module_config = module_config, + new_ann_key = new_ann_key, + plugin_ver = plugin_ver, + process_rules_settings = process_rules_settings, + redis_ann_prefix = redis_ann_prefix, + redis_params = redis_params, + redis_script_id = redis_script_id, + result_to_vector = result_to_vector, + settings = settings, + spawn_train = spawn_train, +} diff --git a/lualib/plugins/rbl.lua b/lualib/plugins/rbl.lua new file mode 100644 index 0000000..af5d6bd --- /dev/null +++ b/lualib/plugins/rbl.lua @@ -0,0 +1,232 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local ts = require("tableshape").types +local lua_maps = require "lua_maps" +local lua_util = require "lua_util" + +-- Common RBL plugin definitions + +local check_types = { + from = { + connfilter = true, + }, + received = {}, + helo = { + connfilter = true, + }, + urls = {}, + content_urls = {}, + numeric_urls = {}, + emails = {}, + replyto = {}, + dkim = {}, + rdns = { + connfilter = true, + }, + selector = { + require_argument = true, + }, +} + +local default_options = { + ['default_enabled'] = true, + ['default_ipv4'] = true, + ['default_ipv6'] = true, + ['default_unknown'] = false, + ['default_dkim_domainonly'] = true, + ['default_emails_domainonly'] = false, + ['default_exclude_users'] = false, + ['default_exclude_local'] = true, + ['default_no_ip'] = false, + ['default_dkim_match_from'] = false, + ['default_selector_flatten'] = true, +} + +local return_codes_schema = ts.map_of( + ts.string / string.upper, -- Symbol name + ( + ts.array_of(ts.string) + + (ts.string / function(s) + return { s } + end) -- List of IP patterns + ) +) +local return_bits_schema = ts.map_of( + ts.string / string.upper, -- Symbol name + ( + ts.array_of(ts.number + ts.string / tonumber) + + (ts.string / function(s) + return { tonumber(s) } + end) + + (ts.number / function(s) + return { s } + end) + ) +) + +local rule_schema_tbl = { + content_urls = ts.boolean:is_optional(), + disable_monitoring = ts.boolean:is_optional(), + disabled = ts.boolean:is_optional(), + dkim = ts.boolean:is_optional(), + dkim_domainonly = ts.boolean:is_optional(), + dkim_match_from = ts.boolean:is_optional(), + emails = ts.boolean:is_optional(), + emails_delimiter = ts.string:is_optional(), + emails_domainonly = ts.boolean:is_optional(), + enabled = ts.boolean:is_optional(), + exclude_local = ts.boolean:is_optional(), + exclude_users = ts.boolean:is_optional(), + from = ts.boolean:is_optional(), + hash = ts.one_of { "sha1", "sha256", "sha384", "sha512", "md5", "blake2" }:is_optional(), + hash_format = ts.one_of { "hex", "base32", "base64" }:is_optional(), + hash_len = (ts.integer + ts.string / tonumber):is_optional(), + helo = ts.boolean:is_optional(), + ignore_default = ts.boolean:is_optional(), -- alias + ignore_defaults = ts.boolean:is_optional(), + ignore_url_whitelist = ts.boolean:is_optional(), + ignore_whitelist = ts.boolean:is_optional(), + ignore_whitelists = ts.boolean:is_optional(), -- alias + images = ts.boolean:is_optional(), + ipv4 = ts.boolean:is_optional(), + ipv6 = ts.boolean:is_optional(), + is_whitelist = ts.boolean:is_optional(), + local_exclude_ip_map = ts.string:is_optional(), + monitored_address = ts.string:is_optional(), + no_ip = ts.boolean:is_optional(), + process_script = ts.string:is_optional(), + random_monitored = ts.boolean:is_optional(), + rbl = ts.string, + rdns = ts.boolean:is_optional(), + received = ts.boolean:is_optional(), + received_flags = ts.array_of(ts.string):is_optional(), + received_max_pos = ts.number:is_optional(), + received_min_pos = ts.number:is_optional(), + received_nflags = ts.array_of(ts.string):is_optional(), + replyto = ts.boolean:is_optional(), + requests_limit = (ts.integer + ts.string / tonumber):is_optional(), + require_symbols = ( + ts.array_of(ts.string) + (ts.string / function(s) + return { s } + end) + ):is_optional(), + resolve_ip = ts.boolean:is_optional(), + return_bits = return_bits_schema:is_optional(), + return_codes = return_codes_schema:is_optional(), + returnbits = return_bits_schema:is_optional(), + returncodes = return_codes_schema:is_optional(), + returncodes_matcher = ts.one_of { "equality", "glob", "luapattern", "radix", "regexp" }:is_optional(), + selector = ts.one_of { ts.string, ts.table }:is_optional(), + selector_flatten = ts.boolean:is_optional(), + symbol = ts.string:is_optional(), + symbols_prefixes = ts.map_of(ts.string, ts.string):is_optional(), + unknown = ts.boolean:is_optional(), + url_compose_map = lua_maps.map_schema:is_optional(), + url_full_hostname = ts.boolean:is_optional(), + url_whitelist = lua_maps.map_schema:is_optional(), + urls = ts.boolean:is_optional(), + whitelist = lua_maps.map_schema:is_optional(), + whitelist_exception = ( + ts.array_of(ts.string) + (ts.string / function(s) + return { s } + end) + ):is_optional(), + checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(), + exclude_checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(), +} + +local function convert_checks(rule, name) + local rspamd_logger = require "rspamd_logger" + if rule.checks then + local all_connfilter = true + local exclude_checks = lua_util.list_to_hash(rule.exclude_checks or {}) + for _, check in ipairs(rule.checks) do + if not exclude_checks[check] then + local check_type = check_types[check] + if check_type.require_argument then + if not rule[check] then + rspamd_logger.errx(rspamd_config, 'rbl rule %s has check %s which requires an argument', + name, check) + return nil + end + end + + rule[check] = check_type + + if not check_type.connfilter then + all_connfilter = false + end + + if not check_type then + rspamd_logger.errx(rspamd_config, 'rbl rule %s has invalid check type: %s', + name, check) + return nil + end + else + rspamd_logger.infox(rspamd_config, 'disable check %s in %s: excluded explicitly', + check, name) + end + end + rule.connfilter = all_connfilter + end + + -- Now check if we have any check enabled at all + local check_found = false + for k, _ in pairs(check_types) do + if type(rule[k]) ~= 'nil' then + check_found = true + break + end + end + + if not check_found then + -- Enable implicit `from` check to allow upgrade + rspamd_logger.warnx(rspamd_config, 'rbl rule %s has no check enabled, enable default `from` check', + name) + rule.from = true + end + + if rule.returncodes and not rule.returncodes_matcher then + for _, v in pairs(rule.returncodes) do + for _, e in ipairs(v) do + if e:find('[%%%[]') then + rspamd_logger.warn(rspamd_config, 'implicitly enabling luapattern returncodes_matcher for rule %s', name) + rule.returncodes_matcher = 'luapattern' + break + end + end + if rule.returncodes_matcher then + break + end + end + end + + return rule +end + + +-- Add default boolean flags to the schema +for def_k, _ in pairs(default_options) do + rule_schema_tbl[def_k:sub(#('default_') + 1)] = ts.boolean:is_optional() +end + +return { + check_types = check_types, + rule_schema = ts.shape(rule_schema_tbl), + default_options = default_options, + convert_checks = convert_checks, +} |