summaryrefslogtreecommitdiffstats
path: root/lualib/plugins
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/plugins')
-rw-r--r--lualib/plugins/dmarc.lua359
-rw-r--r--lualib/plugins/neural.lua892
-rw-r--r--lualib/plugins/rbl.lua232
3 files changed, 1483 insertions, 0 deletions
diff --git a/lualib/plugins/dmarc.lua b/lualib/plugins/dmarc.lua
new file mode 100644
index 0000000..7791f4e
--- /dev/null
+++ b/lualib/plugins/dmarc.lua
@@ -0,0 +1,359 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+Copyright (c) 2015-2016, Andrew Lewis <nerf@judo.za.org>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+-- Common dmarc stuff
+local rspamd_logger = require "rspamd_logger"
+local lua_util = require "lua_util"
+local N = "dmarc"
+
+local exports = {}
+
+exports.default_settings = {
+ auth_and_local_conf = false,
+ symbols = {
+ spf_allow_symbol = 'R_SPF_ALLOW',
+ spf_deny_symbol = 'R_SPF_FAIL',
+ spf_softfail_symbol = 'R_SPF_SOFTFAIL',
+ spf_neutral_symbol = 'R_SPF_NEUTRAL',
+ spf_tempfail_symbol = 'R_SPF_DNSFAIL',
+ spf_permfail_symbol = 'R_SPF_PERMFAIL',
+ spf_na_symbol = 'R_SPF_NA',
+
+ dkim_allow_symbol = 'R_DKIM_ALLOW',
+ dkim_deny_symbol = 'R_DKIM_REJECT',
+ dkim_tempfail_symbol = 'R_DKIM_TEMPFAIL',
+ dkim_na_symbol = 'R_DKIM_NA',
+ dkim_permfail_symbol = 'R_DKIM_PERMFAIL',
+
+ -- DMARC symbols
+ allow = 'DMARC_POLICY_ALLOW',
+ badpolicy = 'DMARC_BAD_POLICY',
+ dnsfail = 'DMARC_DNSFAIL',
+ na = 'DMARC_NA',
+ reject = 'DMARC_POLICY_REJECT',
+ softfail = 'DMARC_POLICY_SOFTFAIL',
+ quarantine = 'DMARC_POLICY_QUARANTINE',
+ },
+ no_sampling_domains = nil,
+ no_reporting_domains = nil,
+ reporting = {
+ report_local_controller = false, -- Store reports for local/controller scans (for testing only)
+ redis_keys = {
+ index_prefix = 'dmarc_idx',
+ report_prefix = 'dmarc_rpt',
+ join_char = ';',
+ },
+ helo = 'rspamd.localhost',
+ smtp = '127.0.0.1',
+ smtp_port = 25,
+ retries = 2,
+ from_name = 'Rspamd',
+ msgid_from = 'rspamd',
+ enabled = false,
+ max_entries = 1000,
+ keys_expire = 172800,
+ only_domains = nil,
+ },
+ actions = {},
+}
+
+
+-- Returns a key used to be inserted into dmarc report sample
+exports.dmarc_report = function(task, settings, data)
+ local rspamd_lua_utils = require "lua_util"
+ local E = {}
+
+ local ip = task:get_from_ip()
+ if not ip or not ip:is_valid() then
+ rspamd_logger.infox(task, 'cannot store dmarc report for %s: no valid source IP',
+ data.domain)
+ return nil
+ end
+
+ ip = ip:to_string()
+
+ if rspamd_lua_utils.is_rspamc_or_controller(task) and not settings.reporting.report_local_controller then
+ rspamd_logger.infox(task, 'cannot store dmarc report for %s from IP %s: has come from controller/rspamc',
+ data.domain, ip)
+ return
+ end
+
+ local dkim_pass = table.concat(data.dkim_results.pass or E, '|')
+ local dkim_fail = table.concat(data.dkim_results.fail or E, '|')
+ local dkim_temperror = table.concat(data.dkim_results.temperror or E, '|')
+ local dkim_permerror = table.concat(data.dkim_results.permerror or E, '|')
+ local disposition_to_return = data.disposition
+ local res = table.concat({
+ ip, data.spf_ok, data.dkim_ok,
+ disposition_to_return, (data.sampled_out and 'sampled_out' or ''), data.domain,
+ dkim_pass, dkim_fail, dkim_temperror, dkim_permerror, data.spf_domain, data.spf_result }, ',')
+
+ return res
+end
+
+exports.gen_munging_callback = function(munging_opts, settings)
+ local rspamd_util = require "rspamd_util"
+ local lua_mime = require "lua_mime"
+ return function(task)
+ if munging_opts.mitigate_allow_only then
+ if not task:has_symbol(settings.symbols.allow) then
+ lua_util.debugm(N, task, 'skip munging, no %s symbol',
+ settings.symbols.allow)
+ -- Excepted
+ return
+ end
+ else
+ local has_dmarc = task:has_symbol(settings.symbols.allow) or
+ task:has_symbol(settings.symbols.quarantine) or
+ task:has_symbol(settings.symbols.reject) or
+ task:has_symbol(settings.symbols.softfail)
+
+ if not has_dmarc then
+ lua_util.debugm(N, task, 'skip munging, no %s symbol',
+ settings.symbols.allow)
+ -- Excepted
+ return
+ end
+ end
+ if munging_opts.mitigate_strict_only then
+ local s = task:get_symbol(settings.symbols.allow) or { [1] = {} }
+ local sopts = s[1].options or {}
+
+ local seen_strict
+ for _, o in ipairs(sopts) do
+ if o == 'reject' or o == 'quarantine' then
+ seen_strict = true
+ break
+ end
+ end
+
+ if not seen_strict then
+ lua_util.debugm(N, task, 'skip munging, no strict policy found in %s',
+ settings.symbols.allow)
+ -- Excepted
+ return
+ end
+ end
+ if munging_opts.munge_map_condition then
+ local accepted, trace = munging_opts.munge_map_condition:process(task)
+ if not accepted then
+ lua_util.debugm(N, task, 'skip munging, maps condition not satisfied: (%s)',
+ trace)
+ -- Excepted
+ return
+ end
+ end
+ -- Now, look for domain for munging
+ local mr = task:get_recipients({ 'mime', 'orig' })
+ local rcpt_found
+ if mr then
+ for _, r in ipairs(mr) do
+ if r.domain and munging_opts.list_map:get_key(r.addr) then
+ rcpt_found = r
+ break
+ end
+ end
+ end
+
+ if not rcpt_found then
+ lua_util.debugm(N, task, 'skip munging, recipients are not in list_map')
+ -- Excepted
+ return
+ end
+
+ local from = task:get_from({ 'mime', 'orig' })
+
+ if not from or not from[1] then
+ lua_util.debugm(N, task, 'skip munging, from is bad')
+ -- Excepted
+ return
+ end
+
+ from = from[1]
+ local via_user = rcpt_found.user
+ local via_addr = rcpt_found.addr
+ local via_name
+
+ if from.name == "" then
+ via_name = string.format('%s via %s', from.user or 'unknown', via_user)
+ else
+ via_name = string.format('%s via %s', from.name, via_user)
+ end
+
+ local hdr_encoded = rspamd_util.fold_header('From',
+ rspamd_util.mime_header_encode(string.format('%s <%s>',
+ via_name, via_addr)), task:get_newlines_type())
+ local orig_from_encoded = rspamd_util.fold_header('X-Original-From',
+ rspamd_util.mime_header_encode(string.format('%s <%s>',
+ from.name or '', from.addr)), task:get_newlines_type())
+ local add_hdrs = {
+ ['From'] = { order = 1, value = hdr_encoded },
+ }
+ local remove_hdrs = { ['From'] = 0 }
+
+ local nreply = from.addr
+ if munging_opts.reply_goes_to_list then
+ -- Reply-to goes to the list
+ nreply = via_addr
+ end
+
+ if task:has_header('Reply-To') then
+ -- If we have reply-to header, then we need to insert an additional
+ -- address there
+ local orig_reply = task:get_header_full('Reply-To')[1]
+ if orig_reply.value then
+ nreply = string.format('%s, %s', orig_reply.value, nreply)
+ end
+ remove_hdrs['Reply-To'] = 1
+ end
+
+ add_hdrs['Reply-To'] = { order = 0, value = nreply }
+
+ add_hdrs['X-Original-From'] = { order = 0, value = orig_from_encoded }
+ lua_mime.modify_headers(task, {
+ remove = remove_hdrs,
+ add = add_hdrs
+ })
+ lua_util.debugm(N, task, 'munged DMARC header for %s: %s -> %s',
+ from.domain, hdr_encoded, from.addr)
+ rspamd_logger.infox(task, 'munged DMARC header for %s', from.addr)
+ task:insert_result('DMARC_MUNGED', 1.0, from.addr)
+ end
+end
+
+local function gen_dmarc_grammar()
+ local lpeg = require "lpeg"
+ lpeg.locale(lpeg)
+ local space = lpeg.space ^ 0
+ local name = lpeg.C(lpeg.alpha ^ 1) * space
+ local sep = space * (lpeg.S("\\;") * space) + (lpeg.space ^ 1)
+ local value = lpeg.C(lpeg.P(lpeg.graph - sep) ^ 1)
+ local pair = lpeg.Cg(name * "=" * space * value) * sep ^ -1
+ local list = lpeg.Cf(lpeg.Ct("") * pair ^ 0, rawset)
+ local version = lpeg.P("v") * space * lpeg.P("=") * space * lpeg.P("DMARC1")
+ local record = version * sep * list
+
+ return record
+end
+
+local dmarc_grammar = gen_dmarc_grammar()
+
+local function dmarc_key_value_case(elts)
+ if type(elts) ~= "table" then
+ return elts
+ end
+ local result = {}
+ for k, v in pairs(elts) do
+ k = k:lower()
+ if k ~= "v" then
+ v = v:lower()
+ end
+
+ result[k] = v
+ end
+
+ return result
+end
+
+--[[
+-- Used to check dmarc record, check elements and produce dmarc policy processed
+-- result.
+-- Returns:
+-- false,false - record is garbage
+-- false,error_message - record is invalid
+-- true,policy_table - record is valid and parsed
+]]
+local function dmarc_check_record(log_obj, record, is_tld)
+ local failed_policy
+ local result = {
+ dmarc_policy = 'none'
+ }
+
+ local elts = dmarc_grammar:match(record)
+ lua_util.debugm(N, log_obj, "got DMARC record: %s, tld_flag=%s, processed=%s",
+ record, is_tld, elts)
+
+ if elts then
+ elts = dmarc_key_value_case(elts)
+
+ local dkim_pol = elts['adkim']
+ if dkim_pol then
+ if dkim_pol == 's' then
+ result.strict_dkim = true
+ elseif dkim_pol ~= 'r' then
+ failed_policy = 'adkim tag has invalid value: ' .. dkim_pol
+ return false, failed_policy
+ end
+ end
+
+ local spf_pol = elts['aspf']
+ if spf_pol then
+ if spf_pol == 's' then
+ result.strict_spf = true
+ elseif spf_pol ~= 'r' then
+ failed_policy = 'aspf tag has invalid value: ' .. spf_pol
+ return false, failed_policy
+ end
+ end
+
+ local policy = elts['p']
+ if policy then
+ if (policy == 'reject') then
+ result.dmarc_policy = 'reject'
+ elseif (policy == 'quarantine') then
+ result.dmarc_policy = 'quarantine'
+ elseif (policy ~= 'none') then
+ failed_policy = 'p tag has invalid value: ' .. policy
+ return false, failed_policy
+ end
+ end
+
+ -- Adjust policy if we are in tld mode
+ local subdomain_policy = elts['sp']
+ if elts['sp'] and is_tld then
+ result.subdomain_policy = elts['sp']
+
+ if (subdomain_policy == 'reject') then
+ result.dmarc_policy = 'reject'
+ elseif (subdomain_policy == 'quarantine') then
+ result.dmarc_policy = 'quarantine'
+ elseif (subdomain_policy == 'none') then
+ result.dmarc_policy = 'none'
+ elseif (subdomain_policy ~= 'none') then
+ failed_policy = 'sp tag has invalid value: ' .. subdomain_policy
+ return false, failed_policy
+ end
+ end
+ result.pct = elts['pct']
+ if result.pct then
+ result.pct = tonumber(result.pct)
+ end
+
+ if elts.rua then
+ result.rua = elts['rua']
+ end
+ result.raw_elts = elts
+ else
+ return false, false -- Ignore garbage
+ end
+
+ return true, result
+end
+
+exports.dmarc_check_record = dmarc_check_record
+
+return exports
diff --git a/lualib/plugins/neural.lua b/lualib/plugins/neural.lua
new file mode 100644
index 0000000..6e88ef2
--- /dev/null
+++ b/lualib/plugins/neural.lua
@@ -0,0 +1,892 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local fun = require "fun"
+local lua_redis = require "lua_redis"
+local lua_settings = require "lua_settings"
+local lua_util = require "lua_util"
+local meta_functions = require "lua_meta"
+local rspamd_kann = require "rspamd_kann"
+local rspamd_logger = require "rspamd_logger"
+local rspamd_tensor = require "rspamd_tensor"
+local rspamd_util = require "rspamd_util"
+local ucl = require "ucl"
+
+local N = 'neural'
+
+-- Used in prefix to avoid wrong ANN to be loaded
+local plugin_ver = '2'
+
+-- Module vars
+local default_options = {
+ train = {
+ max_trains = 1000,
+ max_epoch = 1000,
+ max_usages = 10,
+ max_iterations = 25, -- Torch style
+ mse = 0.001,
+ autotrain = true,
+ train_prob = 1.0,
+ learn_threads = 1,
+ learn_mode = 'balanced', -- Possible values: balanced, proportional
+ learning_rate = 0.01,
+ classes_bias = 0.0, -- balanced mode: what difference is allowed between classes (1:1 proportion means 0 bias)
+ spam_skip_prob = 0.0, -- proportional mode: spam skip probability (0-1)
+ ham_skip_prob = 0.0, -- proportional mode: ham skip probability
+ store_pool_only = false, -- store tokens in cache only (disables autotrain);
+ -- neural_vec_mpack stores vector of training data in messagepack neural_profile_digest stores profile digest
+ },
+ watch_interval = 60.0,
+ lock_expire = 600,
+ learning_spawned = false,
+ ann_expire = 60 * 60 * 24 * 2, -- 2 days
+ hidden_layer_mult = 1.5, -- number of neurons in the hidden layer
+ roc_enabled = false, -- Use ROC to find the best possible thresholds for ham and spam. If spam_score_threshold or ham_score_threshold is defined, it takes precedence over ROC thresholds.
+ roc_misclassification_cost = 0.5, -- Cost of misclassifying a spam message (must be 0..1).
+ spam_score_threshold = nil, -- neural score threshold for spam (must be 0..1 or nil to disable)
+ ham_score_threshold = nil, -- neural score threshold for ham (must be 0..1 or nil to disable)
+ flat_threshold_curve = false, -- use binary classification 0/1 when threshold is reached
+ symbol_spam = 'NEURAL_SPAM',
+ symbol_ham = 'NEURAL_HAM',
+ max_inputs = nil, -- when PCA is used
+ blacklisted_symbols = {}, -- list of symbols skipped in neural processing
+}
+
+-- Rule structure:
+-- * static config fields (see `default_options`)
+-- * prefix - name or defined prefix
+-- * settings - table of settings indexed by settings id, -1 is used when no settings defined
+
+-- Rule settings element defines elements for specific settings id:
+-- * symbols - static symbols profile (defined by config or extracted from symcache)
+-- * name - name of settings id
+-- * digest - digest of all symbols
+-- * ann - dynamic ANN configuration loaded from Redis
+-- * train - train data for ANN (e.g. the currently trained ANN)
+
+-- Settings ANN table is loaded from Redis and represents dynamic profile for ANN
+-- Some elements are directly stored in Redis, ANN is, in turn loaded dynamically
+-- * version - version of ANN loaded from redis
+-- * redis_key - name of ANN key in Redis
+-- * symbols - symbols in THIS PARTICULAR ANN (might be different from set.symbols)
+-- * distance - distance between set.symbols and set.ann.symbols
+-- * ann - kann object
+
+local settings = {
+ rules = {},
+ prefix = 'rn', -- Neural network default prefix
+ max_profiles = 3, -- Maximum number of NN profiles stored
+}
+
+-- Get module & Redis configuration
+local module_config = rspamd_config:get_all_opt(N)
+settings = lua_util.override_defaults(settings, module_config)
+local redis_params = lua_redis.parse_redis_server('neural')
+
+local redis_lua_script_vectors_len = "neural_train_size.lua"
+local redis_lua_script_maybe_invalidate = "neural_maybe_invalidate.lua"
+local redis_lua_script_maybe_lock = "neural_maybe_lock.lua"
+local redis_lua_script_save_unlock = "neural_save_unlock.lua"
+
+local redis_script_id = {}
+
+local function load_scripts()
+ redis_script_id.vectors_len = lua_redis.load_redis_script_from_file(redis_lua_script_vectors_len,
+ redis_params)
+ redis_script_id.maybe_invalidate = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_invalidate,
+ redis_params)
+ redis_script_id.maybe_lock = lua_redis.load_redis_script_from_file(redis_lua_script_maybe_lock,
+ redis_params)
+ redis_script_id.save_unlock = lua_redis.load_redis_script_from_file(redis_lua_script_save_unlock,
+ redis_params)
+end
+
+local function create_ann(n, nlayers, rule)
+ -- We ignore number of layers so far when using kann
+ local nhidden = math.floor(n * (rule.hidden_layer_mult or 1.0) + 1.0)
+ local t = rspamd_kann.layer.input(n)
+ t = rspamd_kann.transform.relu(t)
+ t = rspamd_kann.layer.dense(t, nhidden);
+ t = rspamd_kann.layer.cost(t, 1, rspamd_kann.cost.ceb_neg)
+ return rspamd_kann.new.kann(t)
+end
+
+-- Fills ANN data for a specific settings element
+local function fill_set_ann(set, ann_key)
+ if not set.ann then
+ set.ann = {
+ symbols = set.symbols,
+ distance = 0,
+ digest = set.digest,
+ redis_key = ann_key,
+ version = 0,
+ }
+ end
+end
+
+-- This function takes all inputs, applies PCA transformation and returns the final
+-- PCA matrix as rspamd_tensor
+local function learn_pca(inputs, max_inputs)
+ local scatter_matrix = rspamd_tensor.scatter_matrix(rspamd_tensor.fromtable(inputs))
+ local eigenvals = scatter_matrix:eigen()
+ -- scatter matrix is not filled with eigenvectors
+ lua_util.debugm(N, 'eigenvalues: %s', eigenvals)
+ local w = rspamd_tensor.new(2, max_inputs, #scatter_matrix[1])
+ for i = 1, max_inputs do
+ w[i] = scatter_matrix[#scatter_matrix - i + 1]
+ end
+
+ lua_util.debugm(N, 'pca matrix: %s', w)
+
+ return w
+end
+
+-- This function computes optimal threshold using ROC for the given set of inputs.
+-- Returns a threshold that minimizes:
+-- alpha * (false_positive_rate) + beta * (false_negative_rate)
+-- Where alpha is cost of false positive result
+-- beta is cost of false negative result
+local function get_roc_thresholds(ann, inputs, outputs, alpha, beta)
+
+ -- Sorts list x and list y based on the values in list x.
+ local sort_relative = function(x, y)
+
+ local r = {}
+
+ assert(#x == #y)
+ local n = #x
+
+ local a = {}
+ local b = {}
+ for i = 1, n do
+ r[i] = i
+ end
+
+ local cmp = function(p, q)
+ return p < q
+ end
+
+ table.sort(r, function(p, q)
+ return cmp(x[p], x[q])
+ end)
+
+ for i = 1, n do
+ a[i] = x[r[i]]
+ b[i] = y[r[i]]
+ end
+
+ return a, b
+ end
+
+ local function get_scores(nn, input_vectors)
+ local scores = {}
+ for i = 1, #inputs do
+ local score = nn:apply1(input_vectors[i], nn.pca)[1]
+ scores[#scores + 1] = score
+ end
+
+ return scores
+ end
+
+ local fpr = {}
+ local fnr = {}
+ local scores = get_scores(ann, inputs)
+
+ scores, outputs = sort_relative(scores, outputs)
+
+ local n_samples = #outputs
+ local n_spam = 0
+ local n_ham = 0
+ local ham_count_ahead = {}
+ local spam_count_ahead = {}
+ local ham_count_behind = {}
+ local spam_count_behind = {}
+
+ ham_count_ahead[n_samples + 1] = 0
+ spam_count_ahead[n_samples + 1] = 0
+
+ for i = n_samples, 1, -1 do
+
+ if outputs[i][1] == 0 then
+ n_ham = n_ham + 1
+ ham_count_ahead[i] = 1
+ spam_count_ahead[i] = 0
+ else
+ n_spam = n_spam + 1
+ ham_count_ahead[i] = 0
+ spam_count_ahead[i] = 1
+ end
+
+ ham_count_ahead[i] = ham_count_ahead[i] + ham_count_ahead[i + 1]
+ spam_count_ahead[i] = spam_count_ahead[i] + spam_count_ahead[i + 1]
+ end
+
+ for i = 1, n_samples do
+ if outputs[i][1] == 0 then
+ ham_count_behind[i] = 1
+ spam_count_behind[i] = 0
+ else
+ ham_count_behind[i] = 0
+ spam_count_behind[i] = 1
+ end
+
+ if i ~= 1 then
+ ham_count_behind[i] = ham_count_behind[i] + ham_count_behind[i - 1]
+ spam_count_behind[i] = spam_count_behind[i] + spam_count_behind[i - 1]
+ end
+ end
+
+ for i = 1, n_samples do
+ fpr[i] = 0
+ fnr[i] = 0
+
+ if (ham_count_ahead[i + 1] + ham_count_behind[i]) ~= 0 then
+ fpr[i] = ham_count_ahead[i + 1] / (ham_count_ahead[i + 1] + ham_count_behind[i])
+ end
+
+ if (spam_count_behind[i] + spam_count_ahead[i + 1]) ~= 0 then
+ fnr[i] = spam_count_behind[i] / (spam_count_behind[i] + spam_count_ahead[i + 1])
+ end
+ end
+
+ local p = n_spam / (n_spam + n_ham)
+
+ local cost = {}
+ local min_cost_idx = 0
+ local min_cost = math.huge
+ for i = 1, n_samples do
+ cost[i] = ((1 - p) * alpha * fpr[i]) + (p * beta * fnr[i])
+ if min_cost >= cost[i] then
+ min_cost = cost[i]
+ min_cost_idx = i
+ end
+ end
+
+ return scores[min_cost_idx]
+end
+
+-- This function is intended to extend lock for ANN during training
+-- It registers periodic that increases locked key each 30 seconds unless
+-- `set.learning_spawned` is set to `true`
+local function register_lock_extender(rule, set, ev_base, ann_key)
+ rspamd_config:add_periodic(ev_base, 30.0,
+ function()
+ local function redis_lock_extend_cb(err, _)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot lock ANN %s from redis: %s',
+ ann_key, err)
+ else
+ rspamd_logger.infox(rspamd_config, 'extend lock for ANN %s for 30 seconds',
+ ann_key)
+ end
+ end
+
+ if set.learning_spawned then
+ lua_redis.redis_make_request_taskless(ev_base,
+ rspamd_config,
+ rule.redis,
+ nil,
+ true, -- is write
+ redis_lock_extend_cb, --callback
+ 'HINCRBY', -- command
+ { ann_key, 'lock', '30' }
+ )
+ else
+ lua_util.debugm(N, rspamd_config, "stop lock extension as learning_spawned is false")
+ return false -- do not plan any more updates
+ end
+
+ return true
+ end
+ )
+end
+
+local function can_push_train_vector(rule, task, learn_type, nspam, nham)
+ local train_opts = rule.train
+ local coin = math.random()
+
+ if train_opts.train_prob and coin < 1.0 - train_opts.train_prob then
+ rspamd_logger.infox(task, 'probabilistically skip sample: %s', coin)
+ return false
+ end
+
+ if train_opts.learn_mode == 'balanced' then
+ -- Keep balanced training set based on number of spam and ham samples
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if nspam > nham then
+ -- Apply sampling
+ local skip_rate = 1.0 - nham / (nspam + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else
+ -- Enough learns
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many spam samples: %s',
+ learn_type,
+ nspam)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if nham > nspam then
+ -- Apply sampling
+ local skip_rate = 1.0 - nspam / (nham + 1)
+ if coin < skip_rate - train_opts.classes_bias then
+ rspamd_logger.infox(task,
+ 'skip %s sample to keep spam/ham balance; probability %s; %s spam and %s ham vectors stored',
+ learn_type,
+ skip_rate - train_opts.classes_bias,
+ nspam, nham)
+ return false
+ end
+ end
+ return true
+ else
+ rspamd_logger.infox(task, 'skip %s sample to keep spam/ham balance; too many ham samples: %s', learn_type,
+ nham)
+ end
+ end
+ else
+ -- Probabilistic learn mode, we just skip learn if we already have enough samples or
+ -- if our coin drop is less than desired probability
+ if learn_type == 'spam' then
+ if nspam <= train_opts.max_trains then
+ if train_opts.spam_skip_prob then
+ if coin <= train_opts.spam_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.spam_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many spam samples: %s (%s limit)', learn_type,
+ nspam, train_opts.max_trains)
+ end
+ else
+ if nham <= train_opts.max_trains then
+ if train_opts.ham_skip_prob then
+ if coin <= train_opts.ham_skip_prob then
+ rspamd_logger.infox(task, 'skip %s sample probabilistically; probability %s (%s skip chance)', learn_type,
+ coin, train_opts.ham_skip_prob)
+ return false
+ end
+
+ return true
+ end
+ else
+ rspamd_logger.infox(task, 'skip %s sample; too many ham samples: %s (%s limit)', learn_type,
+ nham, train_opts.max_trains)
+ end
+ end
+ end
+
+ return false
+end
+
+-- Closure generator for unlock function
+local function gen_unlock_cb(rule, set, ann_key)
+ return function(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot unlock ANN %s:%s at %s from redis: %s',
+ rule.prefix, set.name, ann_key, err)
+ else
+ lua_util.debugm(N, rspamd_config, 'unlocked ANN %s:%s at %s',
+ rule.prefix, set.name, ann_key)
+ end
+ end
+end
+
+-- Used to generate new ANN key for specific profile
+local function new_ann_key(rule, set, version)
+ local ann_key = string.format('%s_%s_%s_%s_%s', settings.prefix,
+ rule.prefix, set.name, set.digest:sub(1, 8), tostring(version))
+
+ return ann_key
+end
+
+local function redis_ann_prefix(rule, settings_name)
+ -- We also need to count metatokens:
+ local n = meta_functions.version
+ return string.format('%s%d_%s_%d_%s',
+ settings.prefix, plugin_ver, rule.prefix, n, settings_name)
+end
+
+-- This function receives training vectors, checks them, spawn learning and saves ANN in Redis
+local function spawn_train(params)
+ -- Check training data sanity
+ -- Now we need to join inputs and create the appropriate test vectors
+ local n = #params.set.symbols +
+ meta_functions.rspamd_count_metatokens()
+
+ -- Now we can train ann
+ local train_ann = create_ann(params.rule.max_inputs or n, 3, params.rule)
+
+ if #params.ham_vec + #params.spam_vec < params.rule.train.max_trains / 2 then
+ -- Invalidate ANN as it is definitely invalid
+ -- TODO: add invalidation
+ assert(false)
+ else
+ local inputs, outputs = {}, {}
+
+ -- Used to show parsed vectors in a convenient format (for debugging only)
+ local function debug_vec(t)
+ local ret = {}
+ for i, v in ipairs(t) do
+ if v ~= 0 then
+ ret[#ret + 1] = string.format('%d=%.2f', i, v)
+ end
+ end
+
+ return ret
+ end
+
+ -- Make training set by joining vectors
+ -- KANN automatically shuffles those samples
+ -- 1.0 is used for spam and -1.0 is used for ham
+ -- It implies that output layer can express that (e.g. tanh output)
+ for _, e in ipairs(params.spam_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = { 1.0 }
+ --rspamd_logger.debugm(N, rspamd_config, 'spam vector: %s', debug_vec(e))
+ end
+ for _, e in ipairs(params.ham_vec) do
+ inputs[#inputs + 1] = e
+ outputs[#outputs + 1] = { -1.0 }
+ --rspamd_logger.debugm(N, rspamd_config, 'ham vector: %s', debug_vec(e))
+ end
+
+ -- Called in child process
+ local function train()
+ local log_thresh = params.rule.train.max_iterations / 10
+ local seen_nan = false
+
+ local function train_cb(iter, train_cost, value_cost)
+ if (iter * (params.rule.train.max_iterations / log_thresh)) % (params.rule.train.max_iterations) == 0 then
+ if train_cost ~= train_cost and not seen_nan then
+ -- We have nan :( try to log lot's of stuff to dig into a problem
+ seen_nan = true
+ rspamd_logger.errx(rspamd_config, 'ANN %s:%s: train error: observed nan in error cost!; value cost = %s',
+ params.rule.prefix, params.set.name,
+ value_cost)
+ for i, e in ipairs(inputs) do
+ lua_util.debugm(N, rspamd_config, 'train vector %s -> %s',
+ debug_vec(e), outputs[i][1])
+ end
+ end
+
+ rspamd_logger.infox(rspamd_config,
+ "ANN %s:%s: learned from %s redis key in %s iterations, error: %s, value cost: %s",
+ params.rule.prefix, params.set.name,
+ params.ann_key,
+ iter,
+ train_cost,
+ value_cost)
+ end
+ end
+
+ lua_util.debugm(N, rspamd_config, "subprocess to learn ANN %s:%s has been started",
+ params.rule.prefix, params.set.name)
+
+ local pca
+ if params.rule.max_inputs then
+ -- Train PCA in the main process, presumably it is not that long
+ lua_util.debugm(N, rspamd_config, "start PCA train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ pca = learn_pca(inputs, params.rule.max_inputs)
+ end
+
+ lua_util.debugm(N, rspamd_config, "start neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ local ret, err = pcall(train_ann.train1, train_ann,
+ inputs, outputs, {
+ lr = params.rule.train.learning_rate,
+ max_epoch = params.rule.train.max_iterations,
+ cb = train_cb,
+ pca = pca
+ })
+
+ if not ret then
+ rspamd_logger.errx(rspamd_config, "cannot train ann %s:%s: %s",
+ params.rule.prefix, params.set.name, err)
+
+ return nil
+ else
+ lua_util.debugm(N, rspamd_config, "finished neural train for ANN %s:%s",
+ params.rule.prefix, params.set.name)
+ end
+
+ local roc_thresholds = {}
+ if params.rule.roc_enabled then
+ local spam_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ 1 - params.rule.roc_misclassification_cost,
+ params.rule.roc_misclassification_cost)
+ local ham_threshold = get_roc_thresholds(train_ann,
+ inputs,
+ outputs,
+ params.rule.roc_misclassification_cost,
+ 1 - params.rule.roc_misclassification_cost)
+ roc_thresholds = { spam_threshold, ham_threshold }
+
+ rspamd_logger.messagex("ROC thresholds: (spam_threshold: %s, ham_threshold: %s)",
+ roc_thresholds[1], roc_thresholds[2])
+ end
+
+ if not seen_nan then
+ -- Convert to strings as ucl cannot rspamd_text properly
+ local pca_data
+ if pca then
+ pca_data = tostring(pca:save())
+ end
+ local out = {
+ ann_data = tostring(train_ann:save()),
+ pca_data = pca_data,
+ roc_thresholds = roc_thresholds,
+ }
+
+ local final_data = ucl.to_format(out, 'msgpack')
+ lua_util.debugm(N, rspamd_config, "subprocess for ANN %s:%s returned %s bytes",
+ params.rule.prefix, params.set.name, #final_data)
+ return final_data
+ else
+ return nil
+ end
+ end
+
+ params.set.learning_spawned = true
+
+ local function redis_save_cb(err)
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot save ANN %s:%s to redis key %s: %s',
+ params.rule.prefix, params.set.name, params.ann_key, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ false, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ { params.ann_key, 'lock' }
+ )
+ else
+ rspamd_logger.infox(rspamd_config, 'saved ANN %s:%s to redis: %s',
+ params.rule.prefix, params.set.name, params.set.ann.redis_key)
+ end
+ end
+
+ local function ann_trained(err, data)
+ params.set.learning_spawned = false
+ if err then
+ rspamd_logger.errx(rspamd_config, 'cannot train ANN %s:%s : %s',
+ params.rule.prefix, params.set.name, err)
+ lua_redis.redis_make_request_taskless(params.ev_base,
+ rspamd_config,
+ params.rule.redis,
+ nil,
+ true, -- is write
+ gen_unlock_cb(params.rule, params.set, params.ann_key), --callback
+ 'HDEL', -- command
+ { params.ann_key, 'lock' }
+ )
+ else
+ local parser = ucl.parser()
+ local ok, parse_err = parser:parse_text(data, 'msgpack')
+ assert(ok, parse_err)
+ local parsed = parser:get_object()
+ local ann_data = rspamd_util.zstd_compress(parsed.ann_data)
+ local pca_data = parsed.pca_data
+ local roc_thresholds = parsed.roc_thresholds
+
+ fill_set_ann(params.set, params.ann_key)
+ if pca_data then
+ params.set.ann.pca = rspamd_tensor.load(pca_data)
+ pca_data = rspamd_util.zstd_compress(pca_data)
+ end
+
+ if roc_thresholds then
+ params.set.ann.roc_thresholds = roc_thresholds
+ end
+
+
+ -- Deserialise ANN from the child process
+ ann_trained = rspamd_kann.load(parsed.ann_data)
+ local version = (params.set.ann.version or 0) + 1
+ params.set.ann.version = version
+ params.set.ann.ann = ann_trained
+ params.set.ann.symbols = params.set.symbols
+ params.set.ann.redis_key = new_ann_key(params.rule, params.set, version)
+
+ local profile = {
+ symbols = params.set.symbols,
+ digest = params.set.digest,
+ redis_key = params.set.ann.redis_key,
+ version = version
+ }
+
+ local profile_serialized = ucl.to_format(profile, 'json-compact', true)
+ local roc_thresholds_serialized = ucl.to_format(roc_thresholds, 'json-compact', true)
+
+ rspamd_logger.infox(rspamd_config,
+ 'trained ANN %s:%s, %s bytes (%s compressed); %s rows in pca (%sb compressed); redis key: %s (old key %s)',
+ params.rule.prefix, params.set.name,
+ #data, #ann_data,
+ #(params.set.ann.pca or {}), #(pca_data or {}),
+ params.set.ann.redis_key, params.ann_key)
+
+ lua_redis.exec_redis_script(redis_script_id.save_unlock,
+ { ev_base = params.ev_base, is_write = true },
+ redis_save_cb,
+ { profile.redis_key,
+ redis_ann_prefix(params.rule, params.set.name),
+ ann_data,
+ profile_serialized,
+ tostring(params.rule.ann_expire),
+ tostring(os.time()),
+ params.ann_key, -- old key to unlock...
+ roc_thresholds_serialized,
+ pca_data,
+ })
+ end
+ end
+
+ if params.rule.max_inputs then
+ fill_set_ann(params.set, params.ann_key)
+ end
+
+ params.worker:spawn_process {
+ func = train,
+ on_complete = ann_trained,
+ proctitle = string.format("ANN train for %s/%s", params.rule.prefix, params.set.name),
+ }
+ -- Spawn learn and register lock extension
+ params.set.learning_spawned = true
+ register_lock_extender(params.rule, params.set, params.ev_base, params.ann_key)
+ return
+
+ end
+end
+
+-- This function is used to adjust profiles and allowed setting ids for each rule
+-- It must be called when all settings are already registered (e.g. at post-init for config)
+local function process_rules_settings()
+ local function process_settings_elt(rule, selt)
+ local profile = rule.profile[selt.name]
+ if profile then
+ -- Use static user defined profile
+ -- Ensure that we have an array...
+ lua_util.debugm(N, rspamd_config, "use static profile for %s (%s): %s",
+ rule.prefix, selt.name, profile)
+ if not profile[1] then
+ profile = lua_util.keys(profile)
+ end
+ selt.symbols = profile
+ else
+ lua_util.debugm(N, rspamd_config, "use dynamic cfg based profile for %s (%s)",
+ rule.prefix, selt.name)
+ end
+
+ local function filter_symbols_predicate(sname)
+ if settings.blacklisted_symbols and settings.blacklisted_symbols[sname] then
+ return false
+ end
+ local fl = rspamd_config:get_symbol_flags(sname)
+ if fl then
+ fl = lua_util.list_to_hash(fl)
+
+ return not (fl.nostat or fl.idempotent or fl.skip or fl.composite)
+ end
+
+ return false
+ end
+
+ -- Generic stuff
+ if not profile then
+ -- Do filtering merely if we are using a dynamic profile
+ selt.symbols = fun.totable(fun.filter(filter_symbols_predicate, selt.symbols))
+ end
+
+ table.sort(selt.symbols)
+
+ selt.digest = lua_util.table_digest(selt.symbols)
+ selt.prefix = redis_ann_prefix(rule, selt.name)
+
+ rspamd_logger.messagex(rspamd_config,
+ 'use NN prefix for rule %s; settings id "%s"; symbols digest: "%s"',
+ selt.prefix, selt.name, selt.digest)
+
+ lua_redis.register_prefix(selt.prefix, N,
+ string.format('NN prefix for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'zlist',
+ })
+ -- Versions
+ lua_redis.register_prefix(selt.prefix .. '_\\d+', N,
+ string.format('NN storage for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'hash',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_spam_set', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ selt.prefix, selt.name), {
+ persistent = true,
+ type = 'set',
+ })
+ lua_redis.register_prefix(selt.prefix .. '_\\d+_ham_set', N,
+ string.format('NN learning set (spam) for rule "%s"; settings id "%s"',
+ rule.prefix, selt.name), {
+ persistent = true,
+ type = 'set',
+ })
+ end
+
+ for k, rule in pairs(settings.rules) do
+ if not rule.allowed_settings then
+ rule.allowed_settings = {}
+ elseif rule.allowed_settings == 'all' then
+ -- Extract all settings ids
+ rule.allowed_settings = lua_util.keys(lua_settings.all_settings())
+ end
+
+ -- Convert to a map <setting_id> -> true
+ rule.allowed_settings = lua_util.list_to_hash(rule.allowed_settings)
+
+ -- Check if we can work without settings
+ if k == 'default' or type(rule.default) ~= 'boolean' then
+ rule.default = true
+ end
+
+ rule.settings = {}
+
+ if rule.default then
+ local default_settings = {
+ symbols = lua_settings.default_symbols(),
+ name = 'default'
+ }
+
+ process_settings_elt(rule, default_settings)
+ rule.settings[-1] = default_settings -- Magic constant, but OK as settings are positive int32
+ end
+
+ -- Now, for each allowed settings, we store sorted symbols + digest
+ -- We set table rule.settings[id] -> { name = name, symbols = symbols, digest = digest }
+ for s, _ in pairs(rule.allowed_settings) do
+ -- Here, we have a name, set of symbols and
+ local settings_id = s
+ if type(settings_id) ~= 'number' then
+ settings_id = lua_settings.numeric_settings_id(s)
+ end
+ local selt = lua_settings.settings_by_id(settings_id)
+
+ local nelt = {
+ symbols = selt.symbols, -- Already sorted
+ name = selt.name
+ }
+
+ process_settings_elt(rule, nelt)
+ for id, ex in pairs(rule.settings) do
+ if type(ex) == 'table' then
+ if nelt and lua_util.distance_sorted(ex.symbols, nelt.symbols) == 0 then
+ -- Equal symbols, add reference
+ lua_util.debugm(N, rspamd_config,
+ 'added reference from settings id %s to %s; same symbols',
+ nelt.name, ex.name)
+ rule.settings[settings_id] = id
+ nelt = nil
+ end
+ end
+ end
+
+ if nelt then
+ rule.settings[settings_id] = nelt
+ lua_util.debugm(N, rspamd_config, 'added new settings id %s(%s) to %s',
+ nelt.name, settings_id, rule.prefix)
+ end
+ end
+ end
+end
+
+-- Extract settings element for a specific settings id
+local function get_rule_settings(task, rule)
+ local sid = task:get_settings_id() or -1
+ local set = rule.settings[sid]
+
+ if not set then
+ return nil
+ end
+
+ while type(set) == 'number' do
+ -- Reference to another settings!
+ set = rule.settings[set]
+ end
+
+ return set
+end
+
+local function result_to_vector(task, profile)
+ if not profile.zeros then
+ -- Fill zeros vector
+ local zeros = {}
+ for i = 1, meta_functions.count_metatokens() do
+ zeros[i] = 0.0
+ end
+ for _, _ in ipairs(profile.symbols) do
+ zeros[#zeros + 1] = 0.0
+ end
+ profile.zeros = zeros
+ end
+
+ local vec = lua_util.shallowcopy(profile.zeros)
+ local mt = meta_functions.rspamd_gen_metatokens(task)
+
+ for i, v in ipairs(mt) do
+ vec[i] = v
+ end
+
+ task:process_ann_tokens(profile.symbols, vec, #mt, 0.1)
+
+ return vec
+end
+
+return {
+ can_push_train_vector = can_push_train_vector,
+ create_ann = create_ann,
+ default_options = default_options,
+ gen_unlock_cb = gen_unlock_cb,
+ get_rule_settings = get_rule_settings,
+ load_scripts = load_scripts,
+ module_config = module_config,
+ new_ann_key = new_ann_key,
+ plugin_ver = plugin_ver,
+ process_rules_settings = process_rules_settings,
+ redis_ann_prefix = redis_ann_prefix,
+ redis_params = redis_params,
+ redis_script_id = redis_script_id,
+ result_to_vector = result_to_vector,
+ settings = settings,
+ spawn_train = spawn_train,
+}
diff --git a/lualib/plugins/rbl.lua b/lualib/plugins/rbl.lua
new file mode 100644
index 0000000..af5d6bd
--- /dev/null
+++ b/lualib/plugins/rbl.lua
@@ -0,0 +1,232 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+local ts = require("tableshape").types
+local lua_maps = require "lua_maps"
+local lua_util = require "lua_util"
+
+-- Common RBL plugin definitions
+
+local check_types = {
+ from = {
+ connfilter = true,
+ },
+ received = {},
+ helo = {
+ connfilter = true,
+ },
+ urls = {},
+ content_urls = {},
+ numeric_urls = {},
+ emails = {},
+ replyto = {},
+ dkim = {},
+ rdns = {
+ connfilter = true,
+ },
+ selector = {
+ require_argument = true,
+ },
+}
+
+local default_options = {
+ ['default_enabled'] = true,
+ ['default_ipv4'] = true,
+ ['default_ipv6'] = true,
+ ['default_unknown'] = false,
+ ['default_dkim_domainonly'] = true,
+ ['default_emails_domainonly'] = false,
+ ['default_exclude_users'] = false,
+ ['default_exclude_local'] = true,
+ ['default_no_ip'] = false,
+ ['default_dkim_match_from'] = false,
+ ['default_selector_flatten'] = true,
+}
+
+local return_codes_schema = ts.map_of(
+ ts.string / string.upper, -- Symbol name
+ (
+ ts.array_of(ts.string) +
+ (ts.string / function(s)
+ return { s }
+ end) -- List of IP patterns
+ )
+)
+local return_bits_schema = ts.map_of(
+ ts.string / string.upper, -- Symbol name
+ (
+ ts.array_of(ts.number + ts.string / tonumber) +
+ (ts.string / function(s)
+ return { tonumber(s) }
+ end) +
+ (ts.number / function(s)
+ return { s }
+ end)
+ )
+)
+
+local rule_schema_tbl = {
+ content_urls = ts.boolean:is_optional(),
+ disable_monitoring = ts.boolean:is_optional(),
+ disabled = ts.boolean:is_optional(),
+ dkim = ts.boolean:is_optional(),
+ dkim_domainonly = ts.boolean:is_optional(),
+ dkim_match_from = ts.boolean:is_optional(),
+ emails = ts.boolean:is_optional(),
+ emails_delimiter = ts.string:is_optional(),
+ emails_domainonly = ts.boolean:is_optional(),
+ enabled = ts.boolean:is_optional(),
+ exclude_local = ts.boolean:is_optional(),
+ exclude_users = ts.boolean:is_optional(),
+ from = ts.boolean:is_optional(),
+ hash = ts.one_of { "sha1", "sha256", "sha384", "sha512", "md5", "blake2" }:is_optional(),
+ hash_format = ts.one_of { "hex", "base32", "base64" }:is_optional(),
+ hash_len = (ts.integer + ts.string / tonumber):is_optional(),
+ helo = ts.boolean:is_optional(),
+ ignore_default = ts.boolean:is_optional(), -- alias
+ ignore_defaults = ts.boolean:is_optional(),
+ ignore_url_whitelist = ts.boolean:is_optional(),
+ ignore_whitelist = ts.boolean:is_optional(),
+ ignore_whitelists = ts.boolean:is_optional(), -- alias
+ images = ts.boolean:is_optional(),
+ ipv4 = ts.boolean:is_optional(),
+ ipv6 = ts.boolean:is_optional(),
+ is_whitelist = ts.boolean:is_optional(),
+ local_exclude_ip_map = ts.string:is_optional(),
+ monitored_address = ts.string:is_optional(),
+ no_ip = ts.boolean:is_optional(),
+ process_script = ts.string:is_optional(),
+ random_monitored = ts.boolean:is_optional(),
+ rbl = ts.string,
+ rdns = ts.boolean:is_optional(),
+ received = ts.boolean:is_optional(),
+ received_flags = ts.array_of(ts.string):is_optional(),
+ received_max_pos = ts.number:is_optional(),
+ received_min_pos = ts.number:is_optional(),
+ received_nflags = ts.array_of(ts.string):is_optional(),
+ replyto = ts.boolean:is_optional(),
+ requests_limit = (ts.integer + ts.string / tonumber):is_optional(),
+ require_symbols = (
+ ts.array_of(ts.string) + (ts.string / function(s)
+ return { s }
+ end)
+ ):is_optional(),
+ resolve_ip = ts.boolean:is_optional(),
+ return_bits = return_bits_schema:is_optional(),
+ return_codes = return_codes_schema:is_optional(),
+ returnbits = return_bits_schema:is_optional(),
+ returncodes = return_codes_schema:is_optional(),
+ returncodes_matcher = ts.one_of { "equality", "glob", "luapattern", "radix", "regexp" }:is_optional(),
+ selector = ts.one_of { ts.string, ts.table }:is_optional(),
+ selector_flatten = ts.boolean:is_optional(),
+ symbol = ts.string:is_optional(),
+ symbols_prefixes = ts.map_of(ts.string, ts.string):is_optional(),
+ unknown = ts.boolean:is_optional(),
+ url_compose_map = lua_maps.map_schema:is_optional(),
+ url_full_hostname = ts.boolean:is_optional(),
+ url_whitelist = lua_maps.map_schema:is_optional(),
+ urls = ts.boolean:is_optional(),
+ whitelist = lua_maps.map_schema:is_optional(),
+ whitelist_exception = (
+ ts.array_of(ts.string) + (ts.string / function(s)
+ return { s }
+ end)
+ ):is_optional(),
+ checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(),
+ exclude_checks = ts.array_of(ts.one_of(lua_util.keys(check_types))):is_optional(),
+}
+
+local function convert_checks(rule, name)
+ local rspamd_logger = require "rspamd_logger"
+ if rule.checks then
+ local all_connfilter = true
+ local exclude_checks = lua_util.list_to_hash(rule.exclude_checks or {})
+ for _, check in ipairs(rule.checks) do
+ if not exclude_checks[check] then
+ local check_type = check_types[check]
+ if check_type.require_argument then
+ if not rule[check] then
+ rspamd_logger.errx(rspamd_config, 'rbl rule %s has check %s which requires an argument',
+ name, check)
+ return nil
+ end
+ end
+
+ rule[check] = check_type
+
+ if not check_type.connfilter then
+ all_connfilter = false
+ end
+
+ if not check_type then
+ rspamd_logger.errx(rspamd_config, 'rbl rule %s has invalid check type: %s',
+ name, check)
+ return nil
+ end
+ else
+ rspamd_logger.infox(rspamd_config, 'disable check %s in %s: excluded explicitly',
+ check, name)
+ end
+ end
+ rule.connfilter = all_connfilter
+ end
+
+ -- Now check if we have any check enabled at all
+ local check_found = false
+ for k, _ in pairs(check_types) do
+ if type(rule[k]) ~= 'nil' then
+ check_found = true
+ break
+ end
+ end
+
+ if not check_found then
+ -- Enable implicit `from` check to allow upgrade
+ rspamd_logger.warnx(rspamd_config, 'rbl rule %s has no check enabled, enable default `from` check',
+ name)
+ rule.from = true
+ end
+
+ if rule.returncodes and not rule.returncodes_matcher then
+ for _, v in pairs(rule.returncodes) do
+ for _, e in ipairs(v) do
+ if e:find('[%%%[]') then
+ rspamd_logger.warn(rspamd_config, 'implicitly enabling luapattern returncodes_matcher for rule %s', name)
+ rule.returncodes_matcher = 'luapattern'
+ break
+ end
+ end
+ if rule.returncodes_matcher then
+ break
+ end
+ end
+ end
+
+ return rule
+end
+
+
+-- Add default boolean flags to the schema
+for def_k, _ in pairs(default_options) do
+ rule_schema_tbl[def_k:sub(#('default_') + 1)] = ts.boolean:is_optional()
+end
+
+return {
+ check_types = check_types,
+ rule_schema = ts.shape(rule_schema_tbl),
+ default_options = default_options,
+ convert_checks = convert_checks,
+}